Compare commits

..

5 Commits

Author SHA1 Message Date
Mathmagician8191 b88ae59604
Fix bug in world tokenizer (#93) 2023-06-11 11:46:54 +05:00
Mathmagician8191 82c4ac78f4
Add support for the world tokenizer (#86)
* Add support for the world tokenizer

* Move tokenizer logic to rwkv_tokenizer.py

* Added test for the tokenizer
2023-06-08 16:37:18 +05:00
LoganDark 09ec3145b3
Fix visual bug in quantization (#92)
It didn't calculate the compression ratio properly because of a
copy/paste error :(
2023-06-07 16:45:21 +05:00
LoganDark 5b41cd7e5d
Add capability for extra binaries to be built with rwkv.cpp (#87)
* Add capability for examples

This also adds a quantizer that works without python.
in the future, we might be able to convert from pytorch as well,
without python.

* example implied code style

* rename examples to tools

* rename cpuinfo.c to cpu_info.c

* include ggml header again

* Return EXIT_FAILURE on help

* done with this

* final name: extras

* going To have a seizure

* wait literal double n
2023-06-03 15:44:50 +05:00
LoganDark fb6708b555
Fix pytorch storage warnings, fixes #80 (#88)
we seriously don't care what type of storage we get, pytorch sucks
2023-06-03 15:09:51 +05:00
12 changed files with 66475 additions and 34 deletions

View File

@ -286,3 +286,4 @@ endif()
enable_testing()
add_subdirectory(tests)
add_subdirectory(extras)

10
extras/CMakeLists.txt Normal file
View File

@ -0,0 +1,10 @@
function(rwkv_add_extra source)
get_filename_component(EXTRA_TARGET ${source} NAME_WE)
add_executable(rwkv_${EXTRA_TARGET} ${source})
target_link_libraries(rwkv_${EXTRA_TARGET} PRIVATE ggml rwkv)
endfunction()
file(GLOB extras *.c)
foreach (extra ${extras})
rwkv_add_extra(${extra})
endforeach()

7
extras/cpu_info.c Normal file
View File

@ -0,0 +1,7 @@
#include "rwkv.h"
#include <stdio.h>
int main() {
printf("%s", rwkv_get_system_info_string());
}

58
extras/quantize.c Normal file
View File

@ -0,0 +1,58 @@
#include "ggml.h"
#include "rwkv.h"
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
enum ggml_type type_from_string(const char* string) {
if (strcmp(string, "Q4_0") == 0) return GGML_TYPE_Q4_0;
if (strcmp(string, "Q4_1") == 0) return GGML_TYPE_Q4_1;
if (strcmp(string, "Q5_0") == 0) return GGML_TYPE_Q5_0;
if (strcmp(string, "Q5_1") == 0) return GGML_TYPE_Q5_1;
if (strcmp(string, "Q8_0") == 0) return GGML_TYPE_Q8_0;
return GGML_TYPE_COUNT;
}
#ifdef _WIN32
bool QueryPerformanceFrequency(uint64_t* lpFrequency);
bool QueryPerformanceCounter(uint64_t* lpPerformanceCount);
#define time_t uint64_t
#define time_calibrate(freq) do { QueryPerformanceFrequency(&freq); freq /= 1000; } while (0)
#define time_measure(x) QueryPerformanceCounter(&x)
#define TIME_DIFF(freq, start, end) (double) ((end - start) / freq) / 1000.
#else
#include <time.h>
#define time_t struct timespec
#define time_calibrate(freq) (void) freq
#define time_measure(x) clock_gettime(CLOCK_MONOTONIC, &x)
#define TIME_DIFF(freq, start, end) (double) ((end.tv_nsec - start.tv_nsec) / 1000000) / 1000
#endif
int main(int argc, char* argv[]) {
if (argc != 4 || type_from_string(argv[3]) == GGML_TYPE_COUNT) {
fprintf(stderr, "Usage: %s INPUT OUTPUT FORMAT\n\nAvailable formats: Q4_0 Q4_1 Q5_0 Q5_1 Q8_0\n", argv[0]);
return EXIT_FAILURE;
}
time_t freq, start, end;
time_calibrate(freq);
fprintf(stderr, "Quantizing ...\n");
time_measure(start);
bool success = rwkv_quantize_model_file(argv[1], argv[2], argv[3]);
time_measure(end);
double diff = TIME_DIFF(freq, start, end);
if (success) {
fprintf(stderr, "Succeeded in %.3fs\n", diff);
return EXIT_SUCCESS;
} else {
fprintf(stderr, "Error in %.3fs: 0x%.8X\n", diff, rwkv_get_last_error(NULL));
return EXIT_FAILURE;
}
}

View File

@ -1264,7 +1264,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_tensor(out_file, tensor), "Failed to write tensor %s", name_str);
orig_total_size += orig_size;
new_total_size += orig_size;
new_total_size += new_size;
}
RWKV_MSG("original size = %8.2f MB\n", orig_total_size / 1024.0 / 1024.0);

View File

@ -8,9 +8,9 @@ import pathlib
import copy
import torch
import sampling
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library
from rwkv_tokenizer import get_tokenizer
import json
from typing import List, Dict, Optional
import time
@ -42,6 +42,7 @@ END_OF_TEXT_TOKEN: int = 0
parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
parser.add_argument('tokenizer', help='Which tokenizer to use', nargs='?', type=str, default="20B")
args = parser.parse_args()
script_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent
@ -53,9 +54,7 @@ with open(script_dir / 'prompt' / f'{LANGUAGE}-{PROMPT_TYPE}.json', 'r', encodin
assert init_prompt != '', 'Prompt must not be empty'
print('Loading 20B tokenizer')
tokenizer_path = script_dir / '20B_tokenizer.json'
tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
tokenizer, tokenizer_encode = get_tokenizer(args.tokenizer)
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')
@ -63,9 +62,6 @@ print(f'System info: {library.rwkv_get_system_info_string()}')
print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)
prompt_tokens = tokenizer.encode(init_prompt).ids
prompt_token_count = len(prompt_tokens)
# =================================================================================================
processed_tokens: List[int] = []
@ -110,9 +106,11 @@ def split_last_end_of_line(tokens):
# =================================================================================================
T1 = time.time()
prompt_tokens = tokenizer_encode(init_prompt)
prompt_token_count = len(prompt_tokens)
print(f'Processing {prompt_token_count} prompt tokens, may take a while')
process_tokens(split_last_end_of_line(tokenizer.encode(init_prompt).ids))
process_tokens(split_last_end_of_line(prompt_tokens))
T2 = time.time()
print(f'Process time :{((T2 - T1)*1000)} ms')
print(f'Process time per token :{(((T2 - T1)*1000)) / prompt_token_count} ms')
@ -164,7 +162,7 @@ while True:
new = '\n' + msg[5:].strip()
state = None
processed_tokens = []
process_tokens(tokenizer.encode(new).ids)
process_tokens(tokenizer_encode(new))
save_thread_state('gen_0')
# +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model.
@ -179,7 +177,7 @@ Below is an instruction that describes a task. Write a response that appropriate
'''
state = None
processed_tokens = []
process_tokens(tokenizer.encode(new).ids)
process_tokens(tokenizer_encode(new))
save_thread_state('gen_0')
# +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context).
@ -187,7 +185,7 @@ Below is an instruction that describes a task. Write a response that appropriate
new = '\nQ: ' + msg[4:].strip() + '\nA:'
state = None
processed_tokens = []
process_tokens(tokenizer.encode(new).ids)
process_tokens(tokenizer_encode(new))
save_thread_state('gen_0')
# +qa YOUR QUESTION --> answer an independent question (regardless of context).
@ -197,7 +195,7 @@ Below is an instruction that describes a task. Write a response that appropriate
real_msg = msg[4:].strip()
new = f'{user}{separator} {real_msg}\n\n{bot}{separator}'
process_tokens(tokenizer.encode(new).ids)
process_tokens(tokenizer_encode(new))
save_thread_state('gen_0')
# +++ --> continue last free generation (only for +gen / +i)
@ -230,7 +228,7 @@ Below is an instruction that describes a task. Write a response that appropriate
else:
load_thread_state('chat')
new = f'{user}{separator} {msg}\n\n{bot}{separator}'
process_tokens(tokenizer.encode(new).ids, new_line_logit_bias=-999999999)
process_tokens(tokenizer_encode(new), new_line_logit_bias=-999999999)
save_thread_state('chat_pre')
thread = 'chat'

View File

@ -2,13 +2,12 @@
import argparse
import os
import pathlib
import time
import sampling
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library
from rwkv_tokenizer import get_tokenizer
from typing import List
# ======================================== Script settings ========================================
@ -31,13 +30,14 @@ top_p: float = 0.5
parser = argparse.ArgumentParser(description='Generate completions from RWKV model based on a prompt')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
parser.add_argument('tokenizer', help='Which tokenizer to use', nargs='?', type=str, default="20B")
args = parser.parse_args()
assert prompt != '', 'Prompt must not be empty'
print('Loading 20B tokenizer')
tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
tokenizer, tokenizer_encode = get_tokenizer(args.tokenizer)
prompt_tokens = tokenizer_encode(prompt)
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')
@ -45,7 +45,6 @@ print(f'System info: {library.rwkv_get_system_info_string()}')
print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)
prompt_tokens = tokenizer.encode(prompt).ids
prompt_token_count = len(prompt_tokens)
print(f'{prompt_token_count} tokens in prompt')

View File

@ -4,13 +4,11 @@
import os
import time
import pathlib
import argparse
import tokenizers
import torch
import rwkv_cpp_model
import rwkv_cpp_shared_library
from typing import List
from rwkv_tokenizer import get_tokenizer
def parse_args():
parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
@ -18,19 +16,18 @@ def parse_args():
parser.add_argument('text_path', help='Path to text file in UTF-8 encoding', type=str)
parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int)
parser.add_argument('token_limit', help='How many tokens to process; set to -1 to process all text', nargs='?', type=int, default=-1)
parser.add_argument('tokenizer', help='Which tokenizer to use', nargs='?', type=str, default="20B")
return parser.parse_args()
args = parse_args()
# ---
print('Loading 20B tokenizer')
tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
print('Loading text')
text: str = open(args.text_path, encoding='utf-8').read()
tokens: List[int] = tokenizer.encode(text).ids
tokenizer, tokenizer_encode = get_tokenizer(args.tokenizer)
tokens = tokenizer_encode(text)
token_count: int = len(tokens)
print(f'{token_count} tokens in the text')

View File

@ -84,7 +84,7 @@ class RWKVModel:
if state_in is not None:
validate_buffer(state_in, 'state_in', self._state_buffer_element_count)
state_in_ptr = state_in.untyped_storage().data_ptr()
state_in_ptr = state_in.data_ptr()
else:
state_in_ptr = 0
@ -102,8 +102,8 @@ class RWKVModel:
self._ctx,
token,
state_in_ptr,
state_out.untyped_storage().data_ptr(),
logits_out.untyped_storage().data_ptr()
state_out.data_ptr(),
logits_out.data_ptr()
)
return logits_out, state_out

120
rwkv/rwkv_tokenizer.py Normal file
View File

@ -0,0 +1,120 @@
import os
import tokenizers
import pathlib
########################################################################################################
# Taken from https://github.com/BlinkDL/ChatRWKV/tree/main/tokenizer/rwkv_tokenizer.py
########################################################################################################
class TRIE:
__slots__ = tuple("ch,to,values,front".split(","))
to:list
values:set
def __init__(self, front=None, ch=None):
self.ch = ch
self.to = [None for ch in range(256)]
self.values = set()
self.front = front
def __repr__(self):
fr = self
ret = []
while(fr!=None):
if(fr.ch!=None):
ret.append(fr.ch)
fr = fr.front
return "<TRIE %s %s>"%(ret[::-1], self.values)
def add(self, key:bytes, idx:int=0, val=None):
if(idx == len(key)):
if(val is None):
val = key
self.values.add(val)
return self
ch = key[idx]
if(self.to[ch] is None):
self.to[ch] = TRIE(front=self, ch=ch)
return self.to[ch].add(key, idx=idx+1, val=val)
def find_longest(self, key:bytes, idx:int=0):
u:TRIE = self
ch:int = key[idx]
while(u.to[ch] is not None):
u = u.to[ch]
idx += 1
if(u.values):
ret = idx, u, u.values
if(idx==len(key)):
break
ch = key[idx]
return ret
class TRIE_TOKENIZER():
def __init__(self, file_name):
self.idx2token = {}
sorted = [] # must be already sorted
with open(file_name, "r", encoding="utf-8") as f:
lines = f.readlines()
for l in lines:
idx = int(l[:l.index(' ')])
x = eval(l[l.index(' '):l.rindex(' ')])
x = x.encode("utf-8") if isinstance(x, str) else x
assert isinstance(x, bytes)
assert len(x) == int(l[l.rindex(' '):])
sorted += [x]
self.idx2token[idx] = x
self.token2idx = {}
for k,v in self.idx2token.items():
self.token2idx[v] = int(k)
self.root = TRIE()
for t, i in self.token2idx.items():
_ = self.root.add(t, val=(t, i))
def encodeBytes(self, src:bytes) -> list[int]:
idx:int = 0
tokens:list[int] = []
while (idx < len(src)):
_idx:int = idx
idx, _, values = self.root.find_longest(src, idx)
assert(idx != _idx)
_, token = next(iter(values))
tokens.append(token)
return tokens
def decodeBytes(self, tokens):
return b''.join(map(lambda i: self.idx2token[i], tokens))
def encode(self, src):
return self.encodeBytes(src.encode("utf-8"))
def decode(self, tokens):
return self.decodeBytes(tokens).decode('utf-8')
def printTokens(self, tokens):
for i in tokens:
s = self.idx2token[i]
try:
s = s.decode('utf-8')
except:
pass
print(f'{repr(s)}{i}', end=' ')
print()
def get_tokenizer(tokenizer="20B"):
if tokenizer == "world":
print('Loading world tokenizer')
tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / 'rwkv_vocab_v20230424.txt'
tokenizer = TRIE_TOKENIZER(tokenizer_path)
tokenizer_encode = lambda prompt: tokenizer.encode(prompt)
elif tokenizer == "20B":
print('Loading 20B tokenizer')
tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
tokenizer_encode = lambda prompt: tokenizer.encode(prompt).ids
else:
print(f"Unknown tokenizer: {args.tokenizer}")
quit()
return tokenizer, tokenizer_encode

722
rwkv/rwkv_tokenizer_test.py Normal file

File diff suppressed because one or more lines are too long

65529
rwkv/rwkv_vocab_v20230424.txt Normal file

File diff suppressed because it is too large Load Diff