Various improvements (#47)
* Update ggml
* Pack only rwkv.dll for Windows releases. Test executables are no longer packed.
* Move test code into a separate file
* Remove redundant zeroing
* Refactor chat script
parent 3621172428
commit 5eb8f09c14
GitHub Actions workflow:
@@ -230,7 +230,7 @@ jobs:
         id: pack_artifacts
         if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
         run: |
-          7z a rwkv-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip .\build\bin\Release\*
+          7z a rwkv-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip .\build\bin\Release\rwkv.dll
 
       - name: Upload artifacts
         if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
ggml:
@@ -1 +1 @@
-Subproject commit b237714db49cc09b63a372aeb33ca83bc56b3977
+Subproject commit 9d7974c3cf1284b4ddb926d94552e9fe4c4ad483
rwkv.cpp:
@@ -568,7 +568,6 @@ bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float
 
     RWKV_ASSERT_FALSE(token >= 0 && token < n_vocab, "Token is out of range 0..%d", n_vocab - 1);
 
-    ggml_set_i32(ctx->token_index, 0);
     ggml_set_i32_1d(ctx->token_index, 0, token);
 
     if (state_in == NULL) {
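Note on the "redundant zeroing" item: the removed ggml_set_i32(ctx->token_index, 0) call zeroed the one-element token_index tensor immediately before ggml_set_i32_1d overwrote that same element with the token id, so dropping it changes no behavior. The Python bindings drive this code unchanged; a minimal sketch of the eval loop, where only the model path is a placeholder:

import rwkv_cpp_model
import rwkv_cpp_shared_library

# Load the shared library and a ggml-format model (path is a placeholder).
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, '/path/to/model.bin')

# Each eval() call runs rwkv_eval() above: state is carried between tokens,
# and logits is overwritten with the prediction for the next token.
logits, state = None, None
for token in [0, 1, 2]:  # any token ids in range 0..n_vocab-1
    logits, state = model.eval(token, state, state, logits)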
chat script:
@@ -12,22 +12,15 @@ import tokenizers
 import rwkv_cpp_model
 import rwkv_cpp_shared_library
 import json
+from typing import Optional
 
 # ======================================== Script settings ========================================
 
 # English, Chinese, Japanese
 LANGUAGE: str = 'English'
-# QA: Question and Answer prompt
-# Chat: chat prompt (you need a large model for adequate quality, 7B+)
-PROMPT_TYPE: str = "Chat"
-
-PROMPT_FILE: str = f'./rwkv/prompt/{LANGUAGE}-{PROMPT_TYPE}.json'
-
-def load_prompt(PROMPT_FILE: str):
-    with open(PROMPT_FILE, 'r') as json_file:
-        variables = json.load(json_file)
-        user, bot, separator, prompt = variables['user'], variables['bot'], variables['separator'], variables['prompt']
-        return user, bot, separator, prompt
+# QA: Question and Answer prompt to talk to an AI assistant.
+# Chat: chat prompt (need a large model for adequate quality, 7B+).
+PROMPT_TYPE: str = 'QA'
 
 MAX_GENERATION_LENGTH: int = 250
 
@@ -39,6 +32,7 @@ TOP_P: float = 0.5
 PRESENCE_PENALTY: float = 0.2
 # Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
 FREQUENCY_PENALTY: float = 0.2
+
 END_OF_LINE_TOKEN: int = 187
 END_OF_TEXT_TOKEN: int = 0
 
@@ -48,11 +42,17 @@ parser = argparse.ArgumentParser(description='Provide terminal-based chat interf
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
 args = parser.parse_args()
 
-user, bot, separator, init_prompt = load_prompt(PROMPT_FILE)
+script_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent
+
+with open(script_dir / 'prompt' / f'{LANGUAGE}-{PROMPT_TYPE}.json', 'r') as json_file:
+    prompt_data = json.load(json_file)
+
+    user, bot, separator, init_prompt = prompt_data['user'], prompt_data['bot'], prompt_data['separator'], prompt_data['prompt']
 
 assert init_prompt != '', 'Prompt must not be empty'
 
 print('Loading 20B tokenizer')
-tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
+tokenizer_path = script_dir / '20B_tokenizer.json'
 tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
 
 library = rwkv_cpp_shared_library.load_rwkv_shared_library()
@@ -64,48 +64,48 @@ model = rwkv_cpp_model.RWKVModel(library, args.model_path)
 prompt_tokens = tokenizer.encode(init_prompt).ids
 prompt_token_count = len(prompt_tokens)
 
-########################################################################################################
+# =================================================================================================
 
-model_tokens: list[int] = []
+processed_tokens: list[int] = []
+logits: Optional[torch.Tensor] = None
+state: Optional[torch.Tensor] = None
 
-logits, model_state = None, None
+def process_tokens(_tokens: list[int], new_line_logit_bias: float = 0.0) -> None:
+    global processed_tokens, logits, state
 
-def process_tokens(_tokens: list[int], newline_adj: int = 0) -> torch.Tensor:
-    global model_tokens, model_state, logits
-
-    _tokens = [int(x) for x in _tokens]
-
-    model_tokens += _tokens
+    processed_tokens += _tokens
 
     for _token in _tokens:
-        logits, model_state = model.eval(_token, model_state, model_state, logits)
+        logits, state = model.eval(_token, state, state, logits)
 
-    logits[END_OF_LINE_TOKEN] += newline_adj # adjust \n probability
-
-    return logits
+    logits[END_OF_LINE_TOKEN] += new_line_logit_bias
 
 state_by_thread: dict[str, dict] = {}
 
-def save_thread_state(_thread: str, _logits: torch.Tensor) -> None:
-    state_by_thread[_thread] = {}
-    state_by_thread[_thread]['logits'] = copy.deepcopy(_logits)
-    state_by_thread[_thread]['rnn'] = copy.deepcopy(model_state)
-    state_by_thread[_thread]['token'] = copy.deepcopy(model_tokens)
+def save_thread_state(_thread: str) -> None:
+    state_by_thread[_thread] = {
+        'tokens': copy.deepcopy(processed_tokens),
+        'logits': copy.deepcopy(logits),
+        'state': copy.deepcopy(state)
+    }
 
-def load_thread_state(_thread: str) -> torch.Tensor:
-    global model_tokens, model_state
-    model_state = copy.deepcopy(state_by_thread[_thread]['rnn'])
-    model_tokens = copy.deepcopy(state_by_thread[_thread]['token'])
-    return copy.deepcopy(state_by_thread[_thread]['logits'])
+def load_thread_state(_thread: str) -> None:
+    global processed_tokens, logits, state
 
-########################################################################################################
+    thread_state = state_by_thread[_thread]
+
+    processed_tokens = copy.deepcopy(thread_state['tokens'])
+    logits = copy.deepcopy(thread_state['logits'])
+    state = copy.deepcopy(thread_state['state'])
+
+# =================================================================================================
 
 print(f'Processing {prompt_token_count} prompt tokens, may take a while')
 
-logits = process_tokens(tokenizer.encode(init_prompt).ids)
+process_tokens(tokenizer.encode(init_prompt).ids)
 
-save_thread_state('chat_init', logits)
-save_thread_state('chat', logits)
+save_thread_state('chat_init')
+save_thread_state('chat')
 
 print(f'\nChat initialized! Your name is {user}. Write something and press Enter. Use \\n to add line breaks to your message.')
 
@@ -117,7 +117,7 @@ while True:
     temperature = TEMPERATURE
     top_p = TOP_P
 
-    if "-temp=" in msg:
+    if '-temp=' in msg:
         temperature = float(msg.split('-temp=')[1].split(' ')[0])
 
         msg = msg.replace('-temp='+f'{temperature:g}', '')
@@ -128,7 +128,7 @@ while True:
         if temperature >= 5:
             temperature = 5
 
-    if "-top_p=" in msg:
+    if '-top_p=' in msg:
         top_p = float(msg.split('-top_p=')[1].split(' ')[0])
 
         msg = msg.replace('-top_p='+f'{top_p:g}', '')
@@ -140,8 +140,8 @@ while True:
 
     # + reset --> reset chat
     if msg == '+reset':
-        logits = load_thread_state('chat_init')
-        save_thread_state('chat', logits)
+        load_thread_state('chat_init')
+        save_thread_state('chat')
         print(f'{bot}{separator} Chat reset.\n')
         continue
     elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++':
@@ -149,11 +149,10 @@ while True:
         # +gen YOUR PROMPT --> free single-round generation with any prompt. Requires Novel model.
         if msg[:5].lower() == '+gen ':
             new = '\n' + msg[5:].strip()
-            # print(f'### prompt ###\n[{new}]')
-            model_state = None
-            model_tokens = []
-            logits = process_tokens(tokenizer.encode(new).ids)
-            save_thread_state('gen_0', logits)
+            state = None
+            processed_tokens = []
+            process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0')
 
         # +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model.
         elif msg[:3].lower() == '+i ':
@@ -165,37 +164,34 @@ Below is an instruction that describes a task. Write a response that appropriate
 
 # Response:
 '''
-            # print(f'### prompt ###\n[{new}]')
-            model_state = None
-            model_tokens = []
-            logits = process_tokens(tokenizer.encode(new).ids)
-            save_thread_state('gen_0', logits)
+            state = None
+            processed_tokens = []
+            process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0')
 
         # +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context).
         elif msg[:4].lower() == '+qq ':
             new = '\nQ: ' + msg[4:].strip() + '\nA:'
-            # print(f'### prompt ###\n[{new}]')
-            model_state = None
-            model_tokens = []
-            logits = process_tokens(tokenizer.encode(new).ids)
-            save_thread_state('gen_0', logits)
+            state = None
+            processed_tokens = []
+            process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0')
 
         # +qa YOUR QUESTION --> answer an independent question (regardless of context).
         elif msg[:4].lower() == '+qa ':
-            logits = load_thread_state('chat_init')
+            load_thread_state('chat_init')
 
             real_msg = msg[4:].strip()
-            new = f"{user}{separator} {real_msg}\n\n{bot}{separator}"
-            # print(f'### qa ###\n[{new}]')
+            new = f'{user}{separator} {real_msg}\n\n{bot}{separator}'
 
-            logits = process_tokens(tokenizer.encode(new).ids)
-            save_thread_state('gen_0', logits)
+            process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0')
 
         # +++ --> continue last free generation (only for +gen / +i)
         elif msg.lower() == '+++':
             try:
-                logits = load_thread_state('gen_1')
-                save_thread_state('gen_0', logits)
+                load_thread_state('gen_1')
+                save_thread_state('gen_0')
             except Exception as e:
                 print(e)
                 continue
@@ -203,49 +199,52 @@ Below is an instruction that describes a task. Write a response that appropriate
         # ++ --> retry last free generation (only for +gen / +i)
         elif msg.lower() == '++':
             try:
-                logits = load_thread_state('gen_0')
+                load_thread_state('gen_0')
             except Exception as e:
                 print(e)
                 continue
-        thread = "gen_1"
+        thread = 'gen_1'
 
     else:
         # + --> alternate chat reply
        if msg.lower() == '+':
             try:
-                logits = load_thread_state('chat_pre')
+                load_thread_state('chat_pre')
             except Exception as e:
                 print(e)
                 continue
         # chat with bot
         else:
-            logits = load_thread_state('chat')
-            new = f"{user}{separator} {msg}\n\n{bot}{separator}"
-            # print(f'### add ###\n[{new}]')
-            logits = process_tokens(tokenizer.encode(new).ids, newline_adj=-999999999)
-            save_thread_state('chat_pre', logits)
+            load_thread_state('chat')
+            new = f'{user}{separator} {msg}\n\n{bot}{separator}'
+            process_tokens(tokenizer.encode(new).ids, new_line_logit_bias=-999999999)
+            save_thread_state('chat_pre')
 
         thread = 'chat'
 
         # Print bot response
-        print(f"> {bot}{separator}", end='')
+        print(f'> {bot}{separator}', end='')
 
-    start_index: int = len(model_tokens)
+    start_index: int = len(processed_tokens)
     accumulated_tokens: list[int] = []
-    occurrence: dict[int, int] = {}
+    token_counts: dict[int, int] = {}
 
     for i in range(MAX_GENERATION_LENGTH):
-        for n in occurrence:
-            logits[n] -= (PRESENCE_PENALTY + occurrence[n] * FREQUENCY_PENALTY)
+        for n in token_counts:
+            logits[n] -= PRESENCE_PENALTY + token_counts[n] * FREQUENCY_PENALTY
+
         token: int = sampling.sample_logits(logits, temperature, top_p)
+
         if token == END_OF_TEXT_TOKEN:
             print()
             break
-        if token not in occurrence:
-            occurrence[token] = 1
+
+        if token not in token_counts:
+            token_counts[token] = 1
         else:
-            occurrence[token] += 1
-        logits: torch.Tensor = process_tokens([token])
+            token_counts[token] += 1
+
+        process_tokens([token])
 
         # Avoid UTF-8 display issues
         accumulated_tokens += [token]
@@ -258,10 +257,10 @@ Below is an instruction that describes a task. Write a response that appropriate
             accumulated_tokens = []
 
         if thread == 'chat':
-            if '\n\n' in tokenizer.decode(model_tokens[start_index:]):
+            if '\n\n' in tokenizer.decode(processed_tokens[start_index:]):
                 break
 
         if i == MAX_GENERATION_LENGTH - 1:
             print()
 
-    save_thread_state(thread, logits)
+    save_thread_state(thread)
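One behavior worth spelling out from the refactored generation loop: before each sample, a token that has already appeared has its logit reduced by a flat PRESENCE_PENALTY plus FREQUENCY_PENALTY scaled by its occurrence count. A worked sketch with the script's defaults (both 0.2; the counts are hypothetical):

PRESENCE_PENALTY = 0.2
FREQUENCY_PENALTY = 0.2

token_counts = {187: 3, 42: 1}  # hypothetical token id -> occurrences so far

for n, count in token_counts.items():
    print(f'token {n}: logit reduced by {PRESENCE_PENALTY + count * FREQUENCY_PENALTY:g}')
# token 187: logit reduced by 0.8
# token 42: logit reduced by 0.4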
model conversion script:
@@ -3,7 +3,6 @@
 # Get model checkpoints from https://huggingface.co/BlinkDL
 # See FILE_FORMAT.md for the documentation on the file format.
 
-import os
 import argparse
 import struct
 import torch
@@ -97,53 +96,5 @@ def main() -> None:
 
     print('Done')
 
-# --- Tests ---
-
-def test() -> None:
-    test_file_path = 'convert_pytorch_rwkv_to_ggml_test.tmp'
-
-    try:
-        state_dict: Dict[str, torch.Tensor] = {
-            'emb.weight': torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32),
-            'blocks.0.ln1.weight': torch.tensor([1], dtype=torch.float32)
-        }
-
-        write_state_dict(state_dict, dest_path=test_file_path, data_type='float32')
-
-        with open(test_file_path, 'rb') as input:
-            actual_bytes: bytes = input.read()
-
-        expected_bytes: bytes = struct.pack(
-            '=iiiiii' + 'iiiii10sffffff' + 'iiii19sf',
-            0x67676d66,
-            100,
-            3,
-            2,
-            1,
-            0,
-            # emb.weight
-            2,
-            10,
-            0,
-            2, 3,
-            'emb.weight'.encode('utf-8'),
-            1.0, 2.0, 3.0,
-            4.0, 5.0, 6.0,
-            # blocks.0.ln1.weight
-            1,
-            19,
-            0,
-            1,
-            'blocks.0.ln1.weight'.encode('utf-8'),
-            1.0
-        )
-
-        assert list(actual_bytes) == list(expected_bytes), f'\nActual: {list(actual_bytes)}\nExpected: {list(expected_bytes)}'
-
-        print('All tests pass')
-    finally:
-        if os.path.isfile(test_file_path):
-            os.remove(test_file_path)
-
 if __name__ == "__main__":
     main()
conversion test (new file):
@@ -0,0 +1,54 @@
+import os
+import struct
+import torch
+import convert_pytorch_to_ggml
+from typing import Dict
+
+def test() -> None:
+    test_file_path = 'convert_pytorch_rwkv_to_ggml_test.tmp'
+
+    try:
+        state_dict: Dict[str, torch.Tensor] = {
+            'emb.weight': torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32),
+            'blocks.0.ln1.weight': torch.tensor([1], dtype=torch.float32)
+        }
+
+        convert_pytorch_to_ggml.write_state_dict(state_dict, dest_path=test_file_path, data_type='float32')
+
+        with open(test_file_path, 'rb') as input:
+            actual_bytes: bytes = input.read()
+
+        expected_bytes: bytes = struct.pack(
+            '=iiiiii' + 'iiiii10sffffff' + 'iiii19sf',
+            0x67676d66,
+            100,
+            3,
+            2,
+            1,
+            0,
+            # emb.weight
+            2,
+            10,
+            0,
+            2, 3,
+            'emb.weight'.encode('utf-8'),
+            1.0, 2.0, 3.0,
+            4.0, 5.0, 6.0,
+            # blocks.0.ln1.weight
+            1,
+            19,
+            0,
+            1,
+            'blocks.0.ln1.weight'.encode('utf-8'),
+            1.0
+        )
+
+        assert list(actual_bytes) == list(expected_bytes), f'\nActual: {list(actual_bytes)}\nExpected: {list(expected_bytes)}'
+
+        print('All tests pass')
+    finally:
+        if os.path.isfile(test_file_path):
+            os.remove(test_file_path)
+
+if __name__ == "__main__":
+    test()
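The expected_bytes blob mirrors the serialization that write_state_dict produces: a fixed-size file header followed by one record per tensor. A short annotated decomposition; the field names are my reading of FILE_FORMAT.md rather than quotes from it:

import struct

# Header ('=iiiiii'): magic, format version, n_vocab, n_embed, n_layer, data type.
# 0x67676d66 is composed of the ASCII codes of 'g', 'g', 'm', 'f'.
header = struct.pack('=iiiiii', 0x67676d66, 100, 3, 2, 1, 0)

# Per-tensor record ('iiiii10sffffff' for emb.weight): dimension count,
# name length in bytes, data type (0 = float32), one int per dimension,
# the name itself, then the raw float32 values.
emb_weight = struct.pack(
    '=iiiii10sffffff',
    2, 10, 0,                       # 2 dims, 10-byte name, float32
    2, 3,                           # dimensions of the 3x2 embedding matrix
    'emb.weight'.encode('utf-8'),
    1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
)

assert len(header) == 6 * 4                   # six int32 fields
assert len(emb_weight) == 5 * 4 + 10 + 6 * 4  # ints + name + floats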