Various improvements (#47)
* Update ggml
* Pack only rwkv.dll for Windows releases; test executables are no longer packed.
* Move test code into a separate file
* Remove redundant zeroing
* Refactor chat script
parent 3621172428
commit 5eb8f09c14
@@ -230,7 +230,7 @@ jobs:
         id: pack_artifacts
         if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
         run: |
-          7z a rwkv-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip .\build\bin\Release\*
+          7z a rwkv-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip .\build\bin\Release\rwkv.dll

       - name: Upload artifacts
         if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
ggml
@@ -1 +1 @@
-Subproject commit b237714db49cc09b63a372aeb33ca83bc56b3977
+Subproject commit 9d7974c3cf1284b4ddb926d94552e9fe4c4ad483
rwkv.cpp
@@ -568,7 +568,6 @@ bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float

     RWKV_ASSERT_FALSE(token >= 0 && token < n_vocab, "Token is out of range 0..%d", n_vocab - 1);

-    ggml_set_i32(ctx->token_index, 0);
     ggml_set_i32_1d(ctx->token_index, 0, token);

     if (state_in == NULL) {
@@ -12,22 +12,15 @@ import tokenizers
 import rwkv_cpp_model
 import rwkv_cpp_shared_library
 import json
+from typing import Optional

 # ======================================== Script settings ========================================

 # English, Chinese, Japanese
 LANGUAGE: str = 'English'
-# QA: Question and Answer prompt
-# Chat: chat prompt (you need a large model for adequate quality, 7B+)
-PROMPT_TYPE: str = "Chat"
-
-PROMPT_FILE: str = f'./rwkv/prompt/{LANGUAGE}-{PROMPT_TYPE}.json'
-
-def load_prompt(PROMPT_FILE: str):
-    with open(PROMPT_FILE, 'r') as json_file:
-        variables = json.load(json_file)
-        user, bot, separator, prompt = variables['user'], variables['bot'], variables['separator'], variables['prompt']
-        return user, bot, separator, prompt
+# QA: Question and Answer prompt to talk to an AI assistant.
+# Chat: chat prompt (need a large model for adequate quality, 7B+).
+PROMPT_TYPE: str = 'QA'

 MAX_GENERATION_LENGTH: int = 250
@@ -39,6 +32,7 @@ TOP_P: float = 0.5
 PRESENCE_PENALTY: float = 0.2
 # Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
 FREQUENCY_PENALTY: float = 0.2

 END_OF_LINE_TOKEN: int = 187
 END_OF_TEXT_TOKEN: int = 0
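For context, the two penalty settings above are applied to the logits before every sampling step; the refactored generation loop later in this diff does exactly that with its token_counts dictionary. Below is a minimal standalone sketch of the adjustment, using a plain Python list and invented values in place of the script's torch tensor.

# Sketch: presence/frequency penalty applied to logits before sampling.
# The constants mirror the script's settings; logits and token_counts are made up.
PRESENCE_PENALTY: float = 0.2
FREQUENCY_PENALTY: float = 0.2

logits: list[float] = [1.5, 0.3, 2.0, 0.1]    # one logit per vocabulary token
token_counts: dict[int, int] = {0: 3, 2: 1}   # how often each token was already generated

for token_id, count in token_counts.items():
    # Flat penalty for having appeared at all, plus a per-occurrence penalty.
    logits[token_id] -= PRESENCE_PENALTY + count * FREQUENCY_PENALTY

print(logits)  # token 0 drops by 0.8, token 2 drops by 0.4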
@@ -48,11 +42,17 @@ parser = argparse.ArgumentParser(description='Provide terminal-based chat interf
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
 args = parser.parse_args()

-user, bot, separator, init_prompt = load_prompt(PROMPT_FILE)
+script_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent
+
+with open(script_dir / 'prompt' / f'{LANGUAGE}-{PROMPT_TYPE}.json', 'r') as json_file:
+    prompt_data = json.load(json_file)
+
+user, bot, separator, init_prompt = prompt_data['user'], prompt_data['bot'], prompt_data['separator'], prompt_data['prompt']

 assert init_prompt != '', 'Prompt must not be empty'

 print('Loading 20B tokenizer')
-tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
+tokenizer_path = script_dir / '20B_tokenizer.json'
 tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))

 library = rwkv_cpp_shared_library.load_rwkv_shared_library()
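The hunk above replaces load_prompt() with inline loading of the prompt JSON stored next to the script. The keys it reads ('user', 'bot', 'separator', 'prompt') come straight from the diff; the file written below is a hypothetical example of that shape, not the repository's actual English-QA prompt.

# Sketch: a prompt file of the shape the loading code above expects.
# Keys match the diff; the values are invented placeholders.
import json

example = {
    'user': 'User',
    'bot': 'Bot',
    'separator': ':',
    'prompt': 'The following is a conversation between User and Bot.\n\n'
}

with open('example_prompt.json', 'w') as f:
    json.dump(example, f)

with open('example_prompt.json', 'r') as json_file:
    prompt_data = json.load(json_file)

user, bot, separator, init_prompt = prompt_data['user'], prompt_data['bot'], prompt_data['separator'], prompt_data['prompt']
assert init_prompt != '', 'Prompt must not be empty'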
@@ -64,48 +64,48 @@ model = rwkv_cpp_model.RWKVModel(library, args.model_path)
 prompt_tokens = tokenizer.encode(init_prompt).ids
 prompt_token_count = len(prompt_tokens)

-########################################################################################################
+# =================================================================================================

-model_tokens: list[int] = []
-
-logits, model_state = None, None
+processed_tokens: list[int] = []
+logits: Optional[torch.Tensor] = None
+state: Optional[torch.Tensor] = None

-def process_tokens(_tokens: list[int], newline_adj: int = 0) -> torch.Tensor:
-    global model_tokens, model_state, logits
-
-    _tokens = [int(x) for x in _tokens]
-
-    model_tokens += _tokens
+def process_tokens(_tokens: list[int], new_line_logit_bias: float = 0.0) -> None:
+    global processed_tokens, logits, state
+
+    processed_tokens += _tokens

     for _token in _tokens:
-        logits, model_state = model.eval(_token, model_state, model_state, logits)
+        logits, state = model.eval(_token, state, state, logits)

-    logits[END_OF_LINE_TOKEN] += newline_adj # adjust \n probability
-
-    return logits
+    logits[END_OF_LINE_TOKEN] += new_line_logit_bias

 state_by_thread: dict[str, dict] = {}

-def save_thread_state(_thread: str, _logits: torch.Tensor) -> None:
-    state_by_thread[_thread] = {}
-    state_by_thread[_thread]['logits'] = copy.deepcopy(_logits)
-    state_by_thread[_thread]['rnn'] = copy.deepcopy(model_state)
-    state_by_thread[_thread]['token'] = copy.deepcopy(model_tokens)
+def save_thread_state(_thread: str) -> None:
+    state_by_thread[_thread] = {
+        'tokens': copy.deepcopy(processed_tokens),
+        'logits': copy.deepcopy(logits),
+        'state': copy.deepcopy(state)
+    }

-def load_thread_state(_thread: str) -> torch.Tensor:
-    global model_tokens, model_state
-    model_state = copy.deepcopy(state_by_thread[_thread]['rnn'])
-    model_tokens = copy.deepcopy(state_by_thread[_thread]['token'])
-    return copy.deepcopy(state_by_thread[_thread]['logits'])
+def load_thread_state(_thread: str) -> None:
+    global processed_tokens, logits, state
+
+    thread_state = state_by_thread[_thread]
+
+    processed_tokens = copy.deepcopy(thread_state['tokens'])
+    logits = copy.deepcopy(thread_state['logits'])
+    state = copy.deepcopy(thread_state['state'])

-########################################################################################################
+# =================================================================================================

 print(f'Processing {prompt_token_count} prompt tokens, may take a while')

-logits = process_tokens(tokenizer.encode(init_prompt).ids)
+process_tokens(tokenizer.encode(init_prompt).ids)

-save_thread_state('chat_init', logits)
-save_thread_state('chat', logits)
+save_thread_state('chat_init')
+save_thread_state('chat')

 print(f'\nChat initialized! Your name is {user}. Write something and press Enter. Use \\n to add line breaks to your message.')
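The refactor above replaces the logits-passing helpers with module-level globals plus per-thread snapshots. Below is a self-contained sketch of that snapshot/restore pattern, with plain Python values standing in for the script's torch tensors and RWKV state.

# Sketch: per-thread snapshot/restore, as in the refactored
# save_thread_state/load_thread_state. Plain lists stand in for torch tensors.
import copy
from typing import Optional

processed_tokens: list[int] = []
logits: Optional[list[float]] = None
state: Optional[list[float]] = None

state_by_thread: dict[str, dict] = {}

def save_thread_state(thread: str) -> None:
    # Deep copies, so later generation cannot mutate the saved snapshot.
    state_by_thread[thread] = {
        'tokens': copy.deepcopy(processed_tokens),
        'logits': copy.deepcopy(logits),
        'state': copy.deepcopy(state)
    }

def load_thread_state(thread: str) -> None:
    global processed_tokens, logits, state
    snapshot = state_by_thread[thread]
    processed_tokens = copy.deepcopy(snapshot['tokens'])
    logits = copy.deepcopy(snapshot['logits'])
    state = copy.deepcopy(snapshot['state'])

# Usage mirroring the script: snapshot after the initial prompt, restore on '+reset'.
processed_tokens, logits, state = [1, 2, 3], [0.1, 0.9], [0.0]
save_thread_state('chat_init')
processed_tokens.append(4)       # the 'chat' thread moves on
load_thread_state('chat_init')   # '+reset' brings everything back
assert processed_tokens == [1, 2, 3]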
@@ -117,7 +117,7 @@ while True:
     temperature = TEMPERATURE
     top_p = TOP_P

-    if "-temp=" in msg:
+    if '-temp=' in msg:
         temperature = float(msg.split('-temp=')[1].split(' ')[0])

         msg = msg.replace('-temp='+f'{temperature:g}', '')
@@ -128,7 +128,7 @@ while True:
     if temperature >= 5:
         temperature = 5

-    if "-top_p=" in msg:
+    if '-top_p=' in msg:
         top_p = float(msg.split('-top_p=')[1].split(' ')[0])

         msg = msg.replace('-top_p='+f'{top_p:g}', '')
@@ -140,8 +140,8 @@ while True:

     # + reset --> reset chat
     if msg == '+reset':
-        logits = load_thread_state('chat_init')
-        save_thread_state('chat', logits)
+        load_thread_state('chat_init')
+        save_thread_state('chat')
         print(f'{bot}{separator} Chat reset.\n')
         continue
     elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++':
@@ -149,11 +149,10 @@ while True:
         # +gen YOUR PROMPT --> free single-round generation with any prompt. Requires Novel model.
         if msg[:5].lower() == '+gen ':
             new = '\n' + msg[5:].strip()
-            # print(f'### prompt ###\n[{new}]')
-            model_state = None
-            model_tokens = []
-            logits = process_tokens(tokenizer.encode(new).ids)
-            save_thread_state('gen_0', logits)
+            state = None
+            processed_tokens = []
+            process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0')

         # +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model.
         elif msg[:3].lower() == '+i ':
@@ -165,37 +164,34 @@ Below is an instruction that describes a task. Write a response that appropriate

 # Response:
 '''
-            # print(f'### prompt ###\n[{new}]')
-            model_state = None
-            model_tokens = []
-            logits = process_tokens(tokenizer.encode(new).ids)
-            save_thread_state('gen_0', logits)
+            state = None
+            processed_tokens = []
+            process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0')

         # +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context).
         elif msg[:4].lower() == '+qq ':
             new = '\nQ: ' + msg[4:].strip() + '\nA:'
-            # print(f'### prompt ###\n[{new}]')
-            model_state = None
-            model_tokens = []
-            logits = process_tokens(tokenizer.encode(new).ids)
-            save_thread_state('gen_0', logits)
+            state = None
+            processed_tokens = []
+            process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0')

         # +qa YOUR QUESTION --> answer an independent question (regardless of context).
         elif msg[:4].lower() == '+qa ':
-            logits = load_thread_state('chat_init')
+            load_thread_state('chat_init')

             real_msg = msg[4:].strip()
-            new = f"{user}{separator} {real_msg}\n\n{bot}{separator}"
-            # print(f'### qa ###\n[{new}]')
-
-            logits = process_tokens(tokenizer.encode(new).ids)
-            save_thread_state('gen_0', logits)
+            new = f'{user}{separator} {real_msg}\n\n{bot}{separator}'
+
+            process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0')

         # +++ --> continue last free generation (only for +gen / +i)
         elif msg.lower() == '+++':
             try:
-                logits = load_thread_state('gen_1')
-                save_thread_state('gen_0', logits)
+                load_thread_state('gen_1')
+                save_thread_state('gen_0')
             except Exception as e:
                 print(e)
                 continue
@@ -203,49 +199,52 @@ Below is an instruction that describes a task. Write a response that appropriate
         # ++ --> retry last free generation (only for +gen / +i)
         elif msg.lower() == '++':
             try:
-                logits = load_thread_state('gen_0')
+                load_thread_state('gen_0')
             except Exception as e:
                 print(e)
                 continue
-        thread = "gen_1"
+        thread = 'gen_1'

     else:
         # + --> alternate chat reply
         if msg.lower() == '+':
             try:
-                logits = load_thread_state('chat_pre')
+                load_thread_state('chat_pre')
             except Exception as e:
                 print(e)
                 continue
         # chat with bot
         else:
-            logits = load_thread_state('chat')
-            new = f"{user}{separator} {msg}\n\n{bot}{separator}"
-            # print(f'### add ###\n[{new}]')
-            logits = process_tokens(tokenizer.encode(new).ids, newline_adj=-999999999)
-            save_thread_state('chat_pre', logits)
+            load_thread_state('chat')
+            new = f'{user}{separator} {msg}\n\n{bot}{separator}'
+            process_tokens(tokenizer.encode(new).ids, new_line_logit_bias=-999999999)
+            save_thread_state('chat_pre')

         thread = 'chat'

     # Print bot response
-    print(f"> {bot}{separator}", end='')
+    print(f'> {bot}{separator}', end='')

-    start_index: int = len(model_tokens)
+    start_index: int = len(processed_tokens)
     accumulated_tokens: list[int] = []
-    occurrence: dict[int, int] = {}
+    token_counts: dict[int, int] = {}

     for i in range(MAX_GENERATION_LENGTH):
-        for n in occurrence:
-            logits[n] -= (PRESENCE_PENALTY + occurrence[n] * FREQUENCY_PENALTY)
+        for n in token_counts:
+            logits[n] -= PRESENCE_PENALTY + token_counts[n] * FREQUENCY_PENALTY

         token: int = sampling.sample_logits(logits, temperature, top_p)

         if token == END_OF_TEXT_TOKEN:
             print()
             break
-        if token not in occurrence:
-            occurrence[token] = 1
+
+        if token not in token_counts:
+            token_counts[token] = 1
         else:
-            occurrence[token] += 1
-        logits: torch.Tensor = process_tokens([token])
+            token_counts[token] += 1
+
+        process_tokens([token])

         # Avoid UTF-8 display issues
         accumulated_tokens += [token]
@@ -258,10 +257,10 @@ Below is an instruction that describes a task. Write a response that appropriate
             accumulated_tokens = []

         if thread == 'chat':
-            if '\n\n' in tokenizer.decode(model_tokens[start_index:]):
+            if '\n\n' in tokenizer.decode(processed_tokens[start_index:]):
                 break

         if i == MAX_GENERATION_LENGTH - 1:
             print()

-    save_thread_state(thread, logits)
+    save_thread_state(thread)
@@ -3,7 +3,6 @@
 # Get model checkpoints from https://huggingface.co/BlinkDL
 # See FILE_FORMAT.md for the documentation on the file format.

-import os
 import argparse
 import struct
 import torch
@@ -97,53 +96,5 @@ def main() -> None:

     print('Done')

-# --- Tests ---
-
-def test() -> None:
-    test_file_path = 'convert_pytorch_rwkv_to_ggml_test.tmp'
-
-    try:
-        state_dict: Dict[str, torch.Tensor] = {
-            'emb.weight': torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32),
-            'blocks.0.ln1.weight': torch.tensor([1], dtype=torch.float32)
-        }
-
-        write_state_dict(state_dict, dest_path=test_file_path, data_type='float32')
-
-        with open(test_file_path, 'rb') as input:
-            actual_bytes: bytes = input.read()
-
-        expected_bytes: bytes = struct.pack(
-            '=iiiiii' + 'iiiii10sffffff' + 'iiii19sf',
-            0x67676d66,
-            100,
-            3,
-            2,
-            1,
-            0,
-            # emb.weight
-            2,
-            10,
-            0,
-            2, 3,
-            'emb.weight'.encode('utf-8'),
-            1.0, 2.0, 3.0,
-            4.0, 5.0, 6.0,
-            # blocks.0.ln1.weight
-            1,
-            19,
-            0,
-            1,
-            'blocks.0.ln1.weight'.encode('utf-8'),
-            1.0
-        )
-
-        assert list(actual_bytes) == list(expected_bytes), f'\nActual: {list(actual_bytes)}\nExpected: {list(expected_bytes)}'
-
-        print('All tests pass')
-    finally:
-        if os.path.isfile(test_file_path):
-            os.remove(test_file_path)
-
 if __name__ == "__main__":
     main()
@@ -0,0 +1,54 @@
+import os
+import struct
+import torch
+import convert_pytorch_to_ggml
+from typing import Dict
+
+def test() -> None:
+    test_file_path = 'convert_pytorch_rwkv_to_ggml_test.tmp'
+
+    try:
+        state_dict: Dict[str, torch.Tensor] = {
+            'emb.weight': torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32),
+            'blocks.0.ln1.weight': torch.tensor([1], dtype=torch.float32)
+        }
+
+        convert_pytorch_to_ggml.write_state_dict(state_dict, dest_path=test_file_path, data_type='float32')
+
+        with open(test_file_path, 'rb') as input:
+            actual_bytes: bytes = input.read()
+
+        expected_bytes: bytes = struct.pack(
+            '=iiiiii' + 'iiiii10sffffff' + 'iiii19sf',
+            0x67676d66,
+            100,
+            3,
+            2,
+            1,
+            0,
+            # emb.weight
+            2,
+            10,
+            0,
+            2, 3,
+            'emb.weight'.encode('utf-8'),
+            1.0, 2.0, 3.0,
+            4.0, 5.0, 6.0,
+            # blocks.0.ln1.weight
+            1,
+            19,
+            0,
+            1,
+            'blocks.0.ln1.weight'.encode('utf-8'),
+            1.0
+        )
+
+        assert list(actual_bytes) == list(expected_bytes), f'\nActual: {list(actual_bytes)}\nExpected: {list(expected_bytes)}'
+
+        print('All tests pass')
+    finally:
+        if os.path.isfile(test_file_path):
+            os.remove(test_file_path)
+
+if __name__ == "__main__":
+    test()