Add support for the world tokenizer (#86)

* Add support for the world tokenizer

* Move tokenizer logic to rwkv_tokenizer.py

* Added test for the tokenizer
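In short, the three scripts (chat, completion generation, perplexity measurement) gain an optional trailing positional argument that selects between the existing 20B tokenizer and the new trie-based world tokenizer, with the selection logic centralized in rwkv/rwkv_tokenizer.py. A minimal sketch of the new helper's contract, assuming the rwkv/ directory is on the import path (the prompt text is illustrative):

    from rwkv_tokenizer import get_tokenizer

    # 'world' selects the trie-based world tokenizer; '20B' (the default)
    # keeps the previous tokenizers-based behavior.
    tokenizer, tokenizer_encode = get_tokenizer('world')

    tokens = tokenizer_encode('Hello, world!')  # plain list of token ids
    print(tokens)
    print(tokenizer.decode(tokens))             # round-trips back to the text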
Mathmagician8191 authored 2023-06-08 23:37:18 +12:00, committed by GitHub
parent 09ec3145b3
commit 82c4ac78f4
6 changed files with 66394 additions and 30 deletions

rwkv/chat_with_bot.py

@@ -8,9 +8,9 @@ import pathlib
 import copy
 import torch
 import sampling
-import tokenizers
 import rwkv_cpp_model
 import rwkv_cpp_shared_library
+from rwkv_tokenizer import get_tokenizer
 import json
 from typing import List, Dict, Optional
 import time
@@ -42,6 +42,7 @@ END_OF_TEXT_TOKEN: int = 0
 parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
+parser.add_argument('tokenizer', help='Which tokenizer to use', nargs='?', type=str, default="20B")
 args = parser.parse_args()
 script_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent
@@ -53,9 +54,7 @@ with open(script_dir / 'prompt' / f'{LANGUAGE}-{PROMPT_TYPE}.json', 'r', encodin
 assert init_prompt != '', 'Prompt must not be empty'
-print('Loading 20B tokenizer')
-tokenizer_path = script_dir / '20B_tokenizer.json'
-tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
+tokenizer, tokenizer_encode = get_tokenizer(args.tokenizer)
 library = rwkv_cpp_shared_library.load_rwkv_shared_library()
 print(f'System info: {library.rwkv_get_system_info_string()}')
@@ -63,9 +62,6 @@ print(f'System info: {library.rwkv_get_system_info_string()}')
 print('Loading RWKV model')
 model = rwkv_cpp_model.RWKVModel(library, args.model_path)
-prompt_tokens = tokenizer.encode(init_prompt).ids
-prompt_token_count = len(prompt_tokens)
 # =================================================================================================
 processed_tokens: List[int] = []
@@ -110,9 +106,11 @@ def split_last_end_of_line(tokens):
 # =================================================================================================
 T1 = time.time()
+prompt_tokens = tokenizer_encode(init_prompt)
+prompt_token_count = len(prompt_tokens)
 print(f'Processing {prompt_token_count} prompt tokens, may take a while')
-process_tokens(split_last_end_of_line(tokenizer.encode(init_prompt).ids))
+process_tokens(split_last_end_of_line(prompt_tokens))
 T2 = time.time()
 print(f'Process time :{((T2 - T1)*1000)} ms')
 print(f'Process time per token :{(((T2 - T1)*1000)) / prompt_token_count} ms')
@@ -164,7 +162,7 @@ while True:
 new = '\n' + msg[5:].strip()
 state = None
 processed_tokens = []
-process_tokens(tokenizer.encode(new).ids)
+process_tokens(tokenizer_encode(new))
 save_thread_state('gen_0')
 # +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model.
@@ -179,7 +177,7 @@ Below is an instruction that describes a task. Write a response that appropriate
 '''
 state = None
 processed_tokens = []
-process_tokens(tokenizer.encode(new).ids)
+process_tokens(tokenizer_encode(new))
 save_thread_state('gen_0')
 # +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context).
@@ -187,7 +185,7 @@ Below is an instruction that describes a task. Write a response that appropriate
 new = '\nQ: ' + msg[4:].strip() + '\nA:'
 state = None
 processed_tokens = []
-process_tokens(tokenizer.encode(new).ids)
+process_tokens(tokenizer_encode(new))
 save_thread_state('gen_0')
 # +qa YOUR QUESTION --> answer an independent question (regardless of context).
@@ -197,7 +195,7 @@ Below is an instruction that describes a task. Write a response that appropriate
 real_msg = msg[4:].strip()
 new = f'{user}{separator} {real_msg}\n\n{bot}{separator}'
-process_tokens(tokenizer.encode(new).ids)
+process_tokens(tokenizer_encode(new))
 save_thread_state('gen_0')
 # +++ --> continue last free generation (only for +gen / +i)
@@ -230,7 +228,7 @@ Below is an instruction that describes a task. Write a response that appropriate
 else:
 load_thread_state('chat')
 new = f'{user}{separator} {msg}\n\n{bot}{separator}'
-process_tokens(tokenizer.encode(new).ids, new_line_logit_bias=-999999999)
+process_tokens(tokenizer_encode(new), new_line_logit_bias=-999999999)
 save_thread_state('chat_pre')
 thread = 'chat'
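The repeated one-line changes above all follow the same pattern: call sites stop reaching into the Hugging Face Encoding object via .ids and go through the tokenizer_encode callable instead, which hides the API difference between the two tokenizers. A sketch of the idea with illustrative stub classes (only the two lambdas mirror the diff):

    class HFStyle:                               # stands in for tokenizers.Tokenizer
        def encode(self, s):
            class Enc:
                ids = [1, 2, 3]
            return Enc()

    class TrieStyle:                             # stands in for TRIE_TOKENIZER
        def encode(self, s):
            return [1, 2, 3]                     # already a plain list of ids

    hf, trie = HFStyle(), TrieStyle()
    encode_20b   = lambda s: hf.encode(s).ids    # 20B branch of get_tokenizer
    encode_world = lambda s: trie.encode(s)      # world branch of get_tokenizer
    assert encode_20b('x') == encode_world('x')  # callers see one shape either way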

rwkv/generate_completions.py

@@ -2,13 +2,12 @@
 import argparse
 import os
 import pathlib
 import time
 import sampling
-import tokenizers
 import rwkv_cpp_model
 import rwkv_cpp_shared_library
+from rwkv_tokenizer import get_tokenizer
 from typing import List
# ======================================== Script settings ========================================
@@ -31,13 +30,14 @@ top_p: float = 0.5
 parser = argparse.ArgumentParser(description='Generate completions from RWKV model based on a prompt')
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
+parser.add_argument('tokenizer', help='Which tokenizer to use', nargs='?', type=str, default="20B")
 args = parser.parse_args()
 assert prompt != '', 'Prompt must not be empty'
-print('Loading 20B tokenizer')
-tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
-tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
+tokenizer, tokenizer_encode = get_tokenizer(args.tokenizer)
+prompt_tokens = tokenizer_encode(prompt)
 library = rwkv_cpp_shared_library.load_rwkv_shared_library()
 print(f'System info: {library.rwkv_get_system_info_string()}')
@@ -45,7 +45,6 @@ print(f'System info: {library.rwkv_get_system_info_string()}')
 print('Loading RWKV model')
 model = rwkv_cpp_model.RWKVModel(library, args.model_path)
-prompt_tokens = tokenizer.encode(prompt).ids
 prompt_token_count = len(prompt_tokens)
 print(f'{prompt_token_count} tokens in prompt')

rwkv/measure_pexplexity.py

@@ -4,13 +4,11 @@
 import os
 import time
 import pathlib
 import argparse
-import tokenizers
 import torch
 import rwkv_cpp_model
 import rwkv_cpp_shared_library
-from typing import List
+from rwkv_tokenizer import get_tokenizer

 def parse_args():
     parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
@@ -18,19 +16,18 @@ def parse_args():
     parser.add_argument('text_path', help='Path to text file in UTF-8 encoding', type=str)
     parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int)
     parser.add_argument('token_limit', help='How many tokens to process; set to -1 to process all text', nargs='?', type=int, default=-1)
+    parser.add_argument('tokenizer', help='Which tokenizer to use', nargs='?', type=str, default="20B")
     return parser.parse_args()
 args = parse_args()
 # ---
-print('Loading 20B tokenizer')
-tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
-tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
 print('Loading text')
 text: str = open(args.text_path, encoding='utf-8').read()
-tokens: List[int] = tokenizer.encode(text).ids
+tokenizer, tokenizer_encode = get_tokenizer(args.tokenizer)
+tokens = tokenizer_encode(text)
 token_count: int = len(tokens)
 print(f'{token_count} tokens in the text')
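Because both token_limit and tokenizer are optional positionals (nargs='?'), positionals fill left to right, so selecting a tokenizer requires spelling out token_limit as well. A self-contained check of that argparse behavior (the parser mirrors the diff above; file paths are illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('model_path')
    parser.add_argument('text_path')
    parser.add_argument('ignore_first_n_tokens', type=int)
    parser.add_argument('token_limit', nargs='?', type=int, default=-1)
    parser.add_argument('tokenizer', nargs='?', type=str, default='20B')

    # Defaults kick in when the trailing positionals are omitted:
    print(parser.parse_args(['model.bin', 'text.txt', '0']))
    # To reach 'tokenizer', 'token_limit' must be given too:
    print(parser.parse_args(['model.bin', 'text.txt', '0', '-1', 'world']))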

rwkv/rwkv_tokenizer.py (new file, 119 lines)

@@ -0,0 +1,119 @@
import os
import tokenizers
import pathlib

########################################################################################################
# Taken from https://github.com/BlinkDL/ChatRWKV/tree/main/tokenizer/rwkv_tokenizer.py
########################################################################################################

class TRIE:
    # Byte-level trie: each node has 256 children, one per possible byte value.
    __slots__ = tuple("ch,to,values,front".split(","))
    to: list
    values: set

    def __init__(self, front=None, ch=None):
        self.ch = ch
        self.to = [None for ch in range(256)]
        self.values = set()
        self.front = front

    def __repr__(self):
        fr = self
        ret = []
        while fr is not None:
            if fr.ch is not None:
                ret.append(fr.ch)
            fr = fr.front
        return "<TRIE %s %s>" % (ret[::-1], self.values)

    def add(self, key: bytes, idx: int = 0, val=None):
        if idx == len(key):
            if val is None:
                val = key
            self.values.add(val)
            return self
        ch = key[idx]
        if self.to[ch] is None:
            self.to[ch] = TRIE(front=self, ch=ch)
        return self.to[ch].add(key, idx=idx + 1, val=val)

    def find_longest(self, key: bytes, idx: int = 0):
        # Greedy longest match starting at key[idx]; relies on every single
        # byte being present in the vocabulary, so a match always exists.
        u: TRIE = self
        ch: int = key[idx]

        while u.to[ch] is not None:
            u = u.to[ch]
            idx += 1
            if u.values:
                ret = idx, u, u.values
            if idx == len(key):
                break
            ch = key[idx]

        return ret

class TRIE_TOKENIZER():
    def __init__(self, file_name):
        self.idx2token = {}
        sorted = []  # must be already sorted
        with open(file_name, "r", encoding="utf-8") as f:
            lines = f.readlines()
        for l in lines:
            # Each vocab line has the form '<id> <python literal> <byte length>'.
            idx = int(l[:l.index(' ')])
            x = eval(l[l.index(' '):l.rindex(' ')])
            x = x.encode("utf-8") if isinstance(x, str) else x
            assert isinstance(x, bytes)
            assert len(x) == int(l[l.rindex(' '):])
            sorted += [x]
            self.idx2token[idx] = x

        self.token2idx = {}
        for k, v in self.idx2token.items():
            self.token2idx[v] = int(k)

        self.root = TRIE()
        for t, i in self.token2idx.items():
            _ = self.root.add(t, val=(t, i))

    def encodeBytes(self, src: bytes) -> list[int]:
        idx: int = 0
        tokens: list[int] = []
        while idx < len(src):
            _idx: int = idx
            idx, _, values = self.root.find_longest(src, idx)
            assert idx != _idx
            _, token = next(iter(values))
            tokens.append(token)
        return tokens

    def decodeBytes(self, tokens):
        return b''.join(map(lambda i: self.idx2token[i], tokens))

    def encode(self, src):
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        return self.decodeBytes(tokens).decode('utf-8')

    def printTokens(self, tokens):
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode('utf-8')
            except:
                pass
            print(f'{repr(s)}{i}', end=' ')
        print()

def get_tokenizer(tokenizer="20B"):
    if tokenizer == "world":
        print('Loading world tokenizer')
        # Resolve the vocabulary relative to this file, like the 20B branch does.
        vocab_path = pathlib.Path(os.path.abspath(__file__)).parent / 'rwkv_vocab_v20230424.txt'
        tokenizer = TRIE_TOKENIZER(str(vocab_path))
        tokenizer_encode = lambda prompt: tokenizer.encode(prompt)
    elif tokenizer == "20B":
        print('Loading 20B tokenizer')
        tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
        tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
        tokenizer_encode = lambda prompt: tokenizer.encode(prompt).ids
    else:
        print(f'Unknown tokenizer: {tokenizer}')
        quit()
    return tokenizer, tokenizer_encode
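Since the trie drives the whole world tokenizer, a small self-contained demonstration of find_longest's greedy longest-match behavior may help. The toy vocabulary and ids below are made up; the real ones come from rwkv_vocab_v20230424.txt:

    root = TRIE()
    for token, idx in [(b'a', 1), (b'ab', 2), (b'abc', 3), (b'b', 4), (b'c', 5)]:
        root.add(token, val=(token, idx))

    src = b'abcab'
    pos = 0
    ids = []
    while pos < len(src):
        pos, _, values = root.find_longest(src, pos)  # jumps past the longest match
        ids.append(next(iter(values))[1])
    print(ids)  # [3, 2]: 'abc' then 'ab', never 'a' + 'b' + 'c'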

rwkv/rwkv_tokenizer_test.py (new file, 722 lines)

File diff suppressed because one or more lines are too long

rwkv/rwkv_vocab_v20230424.txt (new file, 65529 lines)

File diff suppressed because it is too large
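For context, the vocabulary file that TRIE_TOKENIZER.__init__ parses stores one token per line as '<id> <python literal> <byte length>'; the literal may be a str or a bytes object, which is why the loader uses eval and then normalizes to bytes. A sketch of the per-line parsing with a hypothetical line (the real id for a space may differ):

    line = "33 ' ' 1"  # hypothetical: token 33 is a single space
    idx = int(line[:line.index(' ')])                   # 33
    x = eval(line[line.index(' '):line.rindex(' ')])    # ' '
    x = x.encode('utf-8') if isinstance(x, str) else x  # normalize str -> bytes
    assert len(x) == int(line[line.rindex(' '):])       # byte-length sanity check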