Various improvements (#47)
* Update ggml
* Pack only rwkv.dll for Windows releases; test executables are no longer packed
* Move test code into a separate file
* Remove redundant zeroing
* Refactor chat script
This commit is contained in:
parent 3621172428
commit 5eb8f09c14
				|  | @ -230,7 +230,7 @@ jobs: | |||
|         id: pack_artifacts | ||||
|         if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} | ||||
|         run: | | ||||
|           7z a rwkv-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip .\build\bin\Release\* | ||||
|           7z a rwkv-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip .\build\bin\Release\rwkv.dll | ||||
| 
 | ||||
|       - name: Upload artifacts | ||||
|         if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} | ||||

ggml (2 changed lines)
							|  | @ -1 +1 @@ | |||
| Subproject commit b237714db49cc09b63a372aeb33ca83bc56b3977 | ||||
| Subproject commit 9d7974c3cf1284b4ddb926d94552e9fe4c4ad483 | ||||

rwkv.cpp (1 changed line)
							|  | @ -568,7 +568,6 @@ bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float | |||
| 
 | ||||
|     RWKV_ASSERT_FALSE(token >= 0 && token < n_vocab, "Token is out of range 0..%d", n_vocab - 1); | ||||
| 
 | ||||
|     ggml_set_i32(ctx->token_index, 0); | ||||
|     ggml_set_i32_1d(ctx->token_index, 0, token); | ||||
| 
 | ||||
|     if (state_in == NULL) { | ||||
|  |  | |||
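The "Remove redundant zeroing" bullet in the commit message refers to the hunk above: `ggml_set_i32(ctx->token_index, 0)` zero-fills the `token_index` tensor, and the `ggml_set_i32_1d(ctx->token_index, 0, token)` call right after it overwrites element 0 anyway, so for a single-token index the zero-fill does no useful work. A rough Python analogy (a sketch only, not the ggml API):

```python
# Rough analogy of the removed call: token_index holds one value per eval,
# so zero-filling it immediately before writing element 0 is wasted work.
token_index = [0]
token = 42

# Before: zero-fill (ggml_set_i32), then write the element (ggml_set_i32_1d).
token_index = [0 for _ in token_index]   # redundant: overwritten on the next line
token_index[0] = token

# After: only the single write remains.
token_index[0] = token
```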
|  | @ -12,22 +12,15 @@ import tokenizers | |||
| import rwkv_cpp_model | ||||
| import rwkv_cpp_shared_library | ||||
| import json | ||||
| from typing import Optional | ||||
| 
 | ||||
| # ======================================== Script settings ======================================== | ||||
| 
 | ||||
| # English, Chinese, Japanese | ||||
| LANGUAGE: str = 'English' | ||||
| # QA: Question and Answer prompt  | ||||
| # Chat: chat prompt (you need a large model for adequate quality, 7B+) | ||||
| PROMPT_TYPE: str = "Chat" | ||||
| 
 | ||||
| PROMPT_FILE: str = f'./rwkv/prompt/{LANGUAGE}-{PROMPT_TYPE}.json' | ||||
| 
 | ||||
| def load_prompt(PROMPT_FILE: str): | ||||
|     with open(PROMPT_FILE, 'r') as json_file: | ||||
|         variables = json.load(json_file) | ||||
|         user, bot, separator, prompt = variables['user'], variables['bot'], variables['separator'], variables['prompt'] | ||||
|         return user, bot, separator, prompt | ||||
| # QA: Question and Answer prompt to talk to an AI assistant. | ||||
| # Chat: chat prompt (need a large model for adequate quality, 7B+). | ||||
| PROMPT_TYPE: str = 'QA' | ||||
| 
 | ||||
| MAX_GENERATION_LENGTH: int = 250 | ||||
| 
 | ||||
|  | @ -39,6 +32,7 @@ TOP_P: float = 0.5 | |||
| PRESENCE_PENALTY: float = 0.2 | ||||
| # Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. | ||||
| FREQUENCY_PENALTY: float = 0.2 | ||||
| 
 | ||||
| END_OF_LINE_TOKEN: int = 187 | ||||
| END_OF_TEXT_TOKEN: int = 0 | ||||
| 
 | ||||
|  | @ -48,11 +42,17 @@ parser = argparse.ArgumentParser(description='Provide terminal-based chat interf | |||
| parser.add_argument('model_path', help='Path to RWKV model in ggml format') | ||||
| args = parser.parse_args() | ||||
| 
 | ||||
| user, bot, separator, init_prompt = load_prompt(PROMPT_FILE) | ||||
| script_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent | ||||
| 
 | ||||
| with open(script_dir / 'prompt' / f'{LANGUAGE}-{PROMPT_TYPE}.json', 'r') as json_file: | ||||
|     prompt_data = json.load(json_file) | ||||
| 
 | ||||
|     user, bot, separator, init_prompt = prompt_data['user'], prompt_data['bot'], prompt_data['separator'], prompt_data['prompt'] | ||||
| 
 | ||||
| assert init_prompt != '', 'Prompt must not be empty' | ||||
| 
 | ||||
| print('Loading 20B tokenizer') | ||||
| tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json' | ||||
| tokenizer_path = script_dir / '20B_tokenizer.json' | ||||
| tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path)) | ||||
| 
 | ||||
| library = rwkv_cpp_shared_library.load_rwkv_shared_library() | ||||
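The refactored loader above resolves the prompt file relative to the script directory and unpacks `user`, `bot`, `separator`, and `prompt` from it. A minimal sketch of such a prompt file, with the key names taken from the diff and the values and path purely illustrative:

```python
# Write and read an example prompt file; the keys mirror the script above,
# the field values and the relative path are illustrative only.
import json
import pathlib

example_prompt = {
    'user': 'Bob',
    'bot': 'Alice',
    'separator': ':',
    'prompt': 'The following is a conversation between Bob and Alice.\n\n'
}

prompt_path = pathlib.Path('prompt') / 'English-QA.json'
prompt_path.parent.mkdir(exist_ok=True)
prompt_path.write_text(json.dumps(example_prompt), encoding='utf-8')

with open(prompt_path, 'r') as json_file:
    prompt_data = json.load(json_file)

user, bot, separator, init_prompt = prompt_data['user'], prompt_data['bot'], prompt_data['separator'], prompt_data['prompt']
assert init_prompt != '', 'Prompt must not be empty'
```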
|  | @ -64,48 +64,48 @@ model = rwkv_cpp_model.RWKVModel(library, args.model_path) | |||
| prompt_tokens = tokenizer.encode(init_prompt).ids | ||||
| prompt_token_count = len(prompt_tokens) | ||||
| 
 | ||||
| ######################################################################################################## | ||||
| # ================================================================================================= | ||||
| 
 | ||||
| model_tokens: list[int] = [] | ||||
| processed_tokens: list[int] = [] | ||||
| logits: Optional[torch.Tensor] = None | ||||
| state: Optional[torch.Tensor] = None | ||||
| 
 | ||||
| logits, model_state = None, None | ||||
| def process_tokens(_tokens: list[int], new_line_logit_bias: float = 0.0) -> None: | ||||
|     global processed_tokens, logits, state | ||||
| 
 | ||||
| def process_tokens(_tokens: list[int], newline_adj: int = 0) -> torch.Tensor: | ||||
|     global model_tokens, model_state, logits | ||||
| 
 | ||||
|     _tokens = [int(x) for x in _tokens] | ||||
| 
 | ||||
|     model_tokens += _tokens | ||||
|     processed_tokens += _tokens | ||||
| 
 | ||||
|     for _token in _tokens: | ||||
|         logits, model_state = model.eval(_token, model_state, model_state, logits) | ||||
|         logits, state = model.eval(_token, state, state, logits) | ||||
| 
 | ||||
|     logits[END_OF_LINE_TOKEN] += newline_adj # adjust \n probability | ||||
| 
 | ||||
|     return logits | ||||
|     logits[END_OF_LINE_TOKEN] += new_line_logit_bias | ||||
| 
 | ||||
| state_by_thread: dict[str, dict] = {} | ||||
| 
 | ||||
| def save_thread_state(_thread: str, _logits: torch.Tensor) -> None: | ||||
|     state_by_thread[_thread] = {} | ||||
|     state_by_thread[_thread]['logits'] = copy.deepcopy(_logits) | ||||
|     state_by_thread[_thread]['rnn'] = copy.deepcopy(model_state) | ||||
|     state_by_thread[_thread]['token'] = copy.deepcopy(model_tokens) | ||||
| def save_thread_state(_thread: str) -> None: | ||||
|     state_by_thread[_thread] = { | ||||
|         'tokens': copy.deepcopy(processed_tokens), | ||||
|         'logits': copy.deepcopy(logits), | ||||
|         'state': copy.deepcopy(state) | ||||
|     } | ||||
| 
 | ||||
| def load_thread_state(_thread: str) -> torch.Tensor: | ||||
|     global model_tokens, model_state | ||||
|     model_state = copy.deepcopy(state_by_thread[_thread]['rnn']) | ||||
|     model_tokens = copy.deepcopy(state_by_thread[_thread]['token']) | ||||
|     return copy.deepcopy(state_by_thread[_thread]['logits']) | ||||
| def load_thread_state(_thread: str) -> None: | ||||
|     global processed_tokens, logits, state | ||||
| 
 | ||||
| ######################################################################################################## | ||||
|     thread_state = state_by_thread[_thread] | ||||
| 
 | ||||
|     processed_tokens = copy.deepcopy(thread_state['tokens']) | ||||
|     logits = copy.deepcopy(thread_state['logits']) | ||||
|     state = copy.deepcopy(thread_state['state']) | ||||
| 
 | ||||
| # ================================================================================================= | ||||
| 
 | ||||
| print(f'Processing {prompt_token_count} prompt tokens, may take a while') | ||||
| 
 | ||||
| logits = process_tokens(tokenizer.encode(init_prompt).ids) | ||||
| process_tokens(tokenizer.encode(init_prompt).ids) | ||||
| 
 | ||||
| save_thread_state('chat_init', logits) | ||||
| save_thread_state('chat', logits) | ||||
| save_thread_state('chat_init') | ||||
| save_thread_state('chat') | ||||
| 
 | ||||
| print(f'\nChat initialized! Your name is {user}. Write something and press Enter. Use \\n to add line breaks to your message.') | ||||
| 
 | ||||
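The hunk above replaces the old `model_tokens`/`model_state` globals and the return value of `process_tokens` with three module-level variables (`processed_tokens`, `logits`, `state`) that the thread-state helpers snapshot and restore wholesale. A self-contained sketch of that pattern; `FakeModel` and the 256-token vocabulary are stand-ins so it runs without a checkpoint, and 187 is the script's `END_OF_LINE_TOKEN`:

```python
# Sketch of the refactored state handling; FakeModel stands in for
# rwkv_cpp_model.RWKVModel so the flow can be followed without a real model file.
import copy
import torch

class FakeModel:
    def eval(self, token, state, state_out=None, logits_out=None):
        # A real RWKV model threads `state` between calls and returns new logits.
        new_state = (state if state is not None else torch.zeros(4)) + 1
        return torch.randn(256), new_state

model = FakeModel()
processed_tokens: list[int] = []
logits, state = None, None
state_by_thread: dict[str, dict] = {}

def process_tokens(tokens: list[int], new_line_logit_bias: float = 0.0) -> None:
    global processed_tokens, logits, state
    processed_tokens += tokens
    for token in tokens:
        logits, state = model.eval(token, state, state, logits)
    logits[187] += new_line_logit_bias   # 187 == END_OF_LINE_TOKEN in the script

def save_thread_state(thread: str) -> None:
    state_by_thread[thread] = {
        'tokens': copy.deepcopy(processed_tokens),
        'logits': copy.deepcopy(logits),
        'state': copy.deepcopy(state)
    }

def load_thread_state(thread: str) -> None:
    global processed_tokens, logits, state
    thread_state = state_by_thread[thread]
    processed_tokens = copy.deepcopy(thread_state['tokens'])
    logits = copy.deepcopy(thread_state['logits'])
    state = copy.deepcopy(thread_state['state'])

process_tokens([1, 2, 3])
save_thread_state('chat')    # snapshot after the initial prompt
process_tokens([4, 5])       # speculative continuation, e.g. one bot reply
load_thread_state('chat')    # roll back to the snapshot for an alternate reply
```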
|  | @ -117,7 +117,7 @@ while True: | |||
|     temperature = TEMPERATURE | ||||
|     top_p = TOP_P | ||||
| 
 | ||||
|     if "-temp=" in msg: | ||||
|     if '-temp=' in msg: | ||||
|         temperature = float(msg.split('-temp=')[1].split(' ')[0]) | ||||
| 
 | ||||
|         msg = msg.replace('-temp='+f'{temperature:g}', '') | ||||
|  | @ -128,7 +128,7 @@ while True: | |||
|         if temperature >= 5: | ||||
|             temperature = 5 | ||||
| 
 | ||||
|     if "-top_p=" in msg: | ||||
|     if '-top_p=' in msg: | ||||
|         top_p = float(msg.split('-top_p=')[1].split(' ')[0]) | ||||
| 
 | ||||
|         msg = msg.replace('-top_p='+f'{top_p:g}', '') | ||||
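Both flag handlers above follow the same pattern: a `-temp=` or `-top_p=` token inside the user message overrides the sampling settings for that turn and is then stripped from the text. A small illustration of that parsing, factored into a hypothetical helper (the helper name is not part of the script, and the clamping steps are omitted):

```python
# Hypothetical helper mirroring the in-message flag parsing shown above.
def parse_flag(msg: str, flag: str, default: float) -> tuple[str, float]:
    if flag not in msg:
        return msg, default
    value = float(msg.split(flag)[1].split(' ')[0])
    # Remove the flag the same way the script does, via %g formatting of the value.
    return msg.replace(flag + f'{value:g}', '').strip(), value

msg, temperature = parse_flag('-temp=0.7 tell me a story', '-temp=', 1.0)
msg, top_p = parse_flag(msg, '-top_p=', 0.5)
print(msg, temperature, top_p)   # tell me a story 0.7 0.5
```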
|  | @ -140,8 +140,8 @@ while True: | |||
| 
 | ||||
|     # + reset --> reset chat | ||||
|     if msg == '+reset': | ||||
|         logits = load_thread_state('chat_init') | ||||
|         save_thread_state('chat', logits) | ||||
|         load_thread_state('chat_init') | ||||
|         save_thread_state('chat') | ||||
|         print(f'{bot}{separator} Chat reset.\n') | ||||
|         continue | ||||
|     elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++': | ||||
|  | @ -149,11 +149,10 @@ while True: | |||
|         # +gen YOUR PROMPT --> free single-round generation with any prompt. Requires Novel model. | ||||
|         if msg[:5].lower() == '+gen ': | ||||
|             new = '\n' + msg[5:].strip() | ||||
|             # print(f'### prompt ###\n[{new}]') | ||||
|             model_state = None | ||||
|             model_tokens = [] | ||||
|             logits = process_tokens(tokenizer.encode(new).ids) | ||||
|             save_thread_state('gen_0', logits) | ||||
|             state = None | ||||
|             processed_tokens = [] | ||||
|             process_tokens(tokenizer.encode(new).ids) | ||||
|             save_thread_state('gen_0') | ||||
| 
 | ||||
|         # +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model. | ||||
|         elif msg[:3].lower() == '+i ': | ||||
|  | @ -165,37 +164,34 @@ Below is an instruction that describes a task. Write a response that appropriate | |||
| 
 | ||||
| # Response: | ||||
| ''' | ||||
|             # print(f'### prompt ###\n[{new}]') | ||||
|             model_state = None | ||||
|             model_tokens = [] | ||||
|             logits = process_tokens(tokenizer.encode(new).ids) | ||||
|             save_thread_state('gen_0', logits) | ||||
|             state = None | ||||
|             processed_tokens = [] | ||||
|             process_tokens(tokenizer.encode(new).ids) | ||||
|             save_thread_state('gen_0') | ||||
| 
 | ||||
|         # +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context). | ||||
|         elif msg[:4].lower() == '+qq ': | ||||
|             new = '\nQ: ' + msg[4:].strip() + '\nA:' | ||||
|             # print(f'### prompt ###\n[{new}]') | ||||
|             model_state = None | ||||
|             model_tokens = [] | ||||
|             logits = process_tokens(tokenizer.encode(new).ids) | ||||
|             save_thread_state('gen_0', logits) | ||||
|             state = None | ||||
|             processed_tokens = [] | ||||
|             process_tokens(tokenizer.encode(new).ids) | ||||
|             save_thread_state('gen_0') | ||||
| 
 | ||||
|         # +qa YOUR QUESTION --> answer an independent question (regardless of context). | ||||
|         elif msg[:4].lower() == '+qa ': | ||||
|             logits = load_thread_state('chat_init') | ||||
|             load_thread_state('chat_init') | ||||
| 
 | ||||
|             real_msg = msg[4:].strip() | ||||
|             new = f"{user}{separator} {real_msg}\n\n{bot}{separator}" | ||||
|             # print(f'### qa ###\n[{new}]') | ||||
|             new = f'{user}{separator} {real_msg}\n\n{bot}{separator}' | ||||
| 
 | ||||
|             logits = process_tokens(tokenizer.encode(new).ids) | ||||
|             save_thread_state('gen_0', logits) | ||||
|             process_tokens(tokenizer.encode(new).ids) | ||||
|             save_thread_state('gen_0') | ||||
| 
 | ||||
|         # +++ --> continue last free generation (only for +gen / +i) | ||||
|         elif msg.lower() == '+++': | ||||
|             try: | ||||
|                 logits = load_thread_state('gen_1') | ||||
|                 save_thread_state('gen_0', logits) | ||||
|                 load_thread_state('gen_1') | ||||
|                 save_thread_state('gen_0') | ||||
|             except Exception as e: | ||||
|                 print(e) | ||||
|                 continue | ||||
|  | @ -203,49 +199,52 @@ Below is an instruction that describes a task. Write a response that appropriate | |||
|         # ++ --> retry last free generation (only for +gen / +i) | ||||
|         elif msg.lower() == '++': | ||||
|             try: | ||||
|                 logits = load_thread_state('gen_0') | ||||
|                 load_thread_state('gen_0') | ||||
|             except Exception as e: | ||||
|                 print(e) | ||||
|                 continue | ||||
|         thread = "gen_1" | ||||
|         thread = 'gen_1' | ||||
| 
 | ||||
|     else: | ||||
|         # + --> alternate chat reply | ||||
|         if msg.lower() == '+': | ||||
|             try: | ||||
|                 logits = load_thread_state('chat_pre') | ||||
|                 load_thread_state('chat_pre') | ||||
|             except Exception as e: | ||||
|                 print(e) | ||||
|                 continue | ||||
|         # chat with bot | ||||
|         else: | ||||
|             logits = load_thread_state('chat') | ||||
|             new = f"{user}{separator} {msg}\n\n{bot}{separator}" | ||||
|             # print(f'### add ###\n[{new}]') | ||||
|             logits = process_tokens(tokenizer.encode(new).ids, newline_adj=-999999999) | ||||
|             save_thread_state('chat_pre', logits) | ||||
|             load_thread_state('chat') | ||||
|             new = f'{user}{separator} {msg}\n\n{bot}{separator}' | ||||
|             process_tokens(tokenizer.encode(new).ids, new_line_logit_bias=-999999999) | ||||
|             save_thread_state('chat_pre') | ||||
| 
 | ||||
|         thread = 'chat' | ||||
| 
 | ||||
|         # Print bot response | ||||
|         print(f"> {bot}{separator}", end='') | ||||
|         print(f'> {bot}{separator}', end='') | ||||
| 
 | ||||
|     start_index: int = len(model_tokens) | ||||
|     start_index: int = len(processed_tokens) | ||||
|     accumulated_tokens: list[int] = [] | ||||
|     occurrence: dict[int, int] = {} | ||||
|     token_counts: dict[int, int] = {} | ||||
| 
 | ||||
|     for i in range(MAX_GENERATION_LENGTH): | ||||
|         for n in occurrence: | ||||
|             logits[n] -= (PRESENCE_PENALTY + occurrence[n] * FREQUENCY_PENALTY) | ||||
|         for n in token_counts: | ||||
|             logits[n] -= PRESENCE_PENALTY + token_counts[n] * FREQUENCY_PENALTY | ||||
| 
 | ||||
|         token: int = sampling.sample_logits(logits, temperature, top_p) | ||||
| 
 | ||||
|         if token == END_OF_TEXT_TOKEN: | ||||
|             print() | ||||
|             break | ||||
|         if token not in occurrence: | ||||
|             occurrence[token] = 1 | ||||
| 
 | ||||
|         if token not in token_counts: | ||||
|             token_counts[token] = 1 | ||||
|         else: | ||||
|             occurrence[token] += 1 | ||||
|         logits: torch.Tensor = process_tokens([token]) | ||||
|             token_counts[token] += 1 | ||||
| 
 | ||||
|         process_tokens([token]) | ||||
| 
 | ||||
|         # Avoid UTF-8 display issues | ||||
|         accumulated_tokens += [token] | ||||
|  | @ -258,10 +257,10 @@ Below is an instruction that describes a task. Write a response that appropriate | |||
|             accumulated_tokens = [] | ||||
| 
 | ||||
|         if thread == 'chat': | ||||
|             if '\n\n' in tokenizer.decode(model_tokens[start_index:]): | ||||
|             if '\n\n' in tokenizer.decode(processed_tokens[start_index:]): | ||||
|                 break | ||||
| 
 | ||||
|         if i == MAX_GENERATION_LENGTH - 1: | ||||
|             print() | ||||
| 
 | ||||
|     save_thread_state(thread, logits) | ||||
|     save_thread_state(thread) | ||||
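In the generation loop above, every token that has already appeared in the current reply is penalized before sampling: a flat presence penalty plus a frequency penalty that grows with its count. A small numeric sketch of that adjustment, with stand-in logits and counts:

```python
# Repetition penalty as applied in the loop above: presence penalty once,
# frequency penalty scaled by how often the token has already been generated.
import torch

PRESENCE_PENALTY = 0.2
FREQUENCY_PENALTY = 0.2

logits = torch.zeros(10)        # stand-in logits for a tiny vocabulary
token_counts = {3: 1, 7: 4}     # token id -> occurrences so far in this reply

for n in token_counts:
    logits[n] -= PRESENCE_PENALTY + token_counts[n] * FREQUENCY_PENALTY

print(round(logits[3].item(), 2))   # -0.4  (one occurrence)
print(round(logits[7].item(), 2))   # -1.0  (four occurrences)
```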
|  |  | |||
|  | @ -3,7 +3,6 @@ | |||
| # Get model checkpoints from https://huggingface.co/BlinkDL | ||||
| # See FILE_FORMAT.md for the documentation on the file format. | ||||
| 
 | ||||
| import os | ||||
| import argparse | ||||
| import struct | ||||
| import torch | ||||
|  | @ -97,53 +96,5 @@ def main() -> None: | |||
| 
 | ||||
|     print('Done') | ||||
| 
 | ||||
| # --- Tests --- | ||||
| 
 | ||||
| def test() -> None: | ||||
|     test_file_path = 'convert_pytorch_rwkv_to_ggml_test.tmp' | ||||
| 
 | ||||
|     try: | ||||
|         state_dict: Dict[str, torch.Tensor] = { | ||||
|             'emb.weight': torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32), | ||||
|             'blocks.0.ln1.weight': torch.tensor([1], dtype=torch.float32) | ||||
|         } | ||||
| 
 | ||||
|         write_state_dict(state_dict, dest_path=test_file_path, data_type='float32') | ||||
| 
 | ||||
|         with open(test_file_path, 'rb') as input: | ||||
|             actual_bytes: bytes = input.read() | ||||
| 
 | ||||
|         expected_bytes: bytes = struct.pack( | ||||
|             '=iiiiii' + 'iiiii10sffffff' + 'iiii19sf', | ||||
|             0x67676d66, | ||||
|             100, | ||||
|             3, | ||||
|             2, | ||||
|             1, | ||||
|             0, | ||||
|             # emb.weight | ||||
|             2, | ||||
|             10, | ||||
|             0, | ||||
|             2, 3, | ||||
|             'emb.weight'.encode('utf-8'), | ||||
|             1.0, 2.0, 3.0, | ||||
|             4.0, 5.0, 6.0, | ||||
|             # blocks.0.ln1.weight | ||||
|             1, | ||||
|             19, | ||||
|             0, | ||||
|             1, | ||||
|             'blocks.0.ln1.weight'.encode('utf-8'), | ||||
|             1.0 | ||||
|         ) | ||||
| 
 | ||||
|         assert list(actual_bytes) == list(expected_bytes), f'\nActual: {list(actual_bytes)}\nExpected: {list(expected_bytes)}' | ||||
| 
 | ||||
|         print('All tests pass') | ||||
|     finally: | ||||
|         if os.path.isfile(test_file_path): | ||||
|             os.remove(test_file_path) | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
|  |  | |||
|  | @ -0,0 +1,54 @@ | |||
| import os | ||||
| import struct | ||||
| import torch | ||||
| import convert_pytorch_to_ggml | ||||
| from typing import Dict | ||||
| 
 | ||||
| def test() -> None: | ||||
|     test_file_path = 'convert_pytorch_rwkv_to_ggml_test.tmp' | ||||
| 
 | ||||
|     try: | ||||
|         state_dict: Dict[str, torch.Tensor] = { | ||||
|             'emb.weight': torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32), | ||||
|             'blocks.0.ln1.weight': torch.tensor([1], dtype=torch.float32) | ||||
|         } | ||||
| 
 | ||||
|         convert_pytorch_to_ggml.write_state_dict(state_dict, dest_path=test_file_path, data_type='float32') | ||||
| 
 | ||||
|         with open(test_file_path, 'rb') as input: | ||||
|             actual_bytes: bytes = input.read() | ||||
| 
 | ||||
|         expected_bytes: bytes = struct.pack( | ||||
|             '=iiiiii' + 'iiiii10sffffff' + 'iiii19sf', | ||||
|             0x67676d66, | ||||
|             100, | ||||
|             3, | ||||
|             2, | ||||
|             1, | ||||
|             0, | ||||
|             # emb.weight | ||||
|             2, | ||||
|             10, | ||||
|             0, | ||||
|             2, 3, | ||||
|             'emb.weight'.encode('utf-8'), | ||||
|             1.0, 2.0, 3.0, | ||||
|             4.0, 5.0, 6.0, | ||||
|             # blocks.0.ln1.weight | ||||
|             1, | ||||
|             19, | ||||
|             0, | ||||
|             1, | ||||
|             'blocks.0.ln1.weight'.encode('utf-8'), | ||||
|             1.0 | ||||
|         ) | ||||
| 
 | ||||
|         assert list(actual_bytes) == list(expected_bytes), f'\nActual: {list(actual_bytes)}\nExpected: {list(expected_bytes)}' | ||||
| 
 | ||||
|         print('All tests pass') | ||||
|     finally: | ||||
|         if os.path.isfile(test_file_path): | ||||
|             os.remove(test_file_path) | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     test() | ||||
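The test above packs a fixed six-integer header followed by one record per tensor (dimension count, name length, data type, dimensions, name bytes, raw values). A hedged sketch of decoding that header; the field names are inferred from the pack arguments and the FILE_FORMAT.md referenced earlier, not guaranteed by this diff:

```python
# Pack the six-int header exactly as the test does, then decode it again.
import struct

header = struct.pack('=iiiiii', 0x67676d66, 100, 3, 2, 1, 0)
magic, version, n_vocab, n_embed, n_layer, data_type = struct.unpack('=iiiiii', header)

assert magic == 0x67676d66                              # format magic written by the converter
print(version, n_vocab, n_embed, n_layer, data_type)    # 100 3 2 1 0
```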