Compare commits

5 Commits: 3f8bb2c080 ... b88ae59604

| Author | SHA1 | Date |
| --- | --- | --- |
|  | b88ae59604 |  |
|  | 82c4ac78f4 |  |
|  | 09ec3145b3 |  |
|  | 5b41cd7e5d |  |
|  | fb6708b555 |  |
@@ -286,3 +286,4 @@ endif()
 enable_testing()
 add_subdirectory(tests)
+add_subdirectory(extras)
 
@@ -0,0 +1,10 @@
+function(rwkv_add_extra source)
+    get_filename_component(EXTRA_TARGET ${source} NAME_WE)
+    add_executable(rwkv_${EXTRA_TARGET} ${source})
+    target_link_libraries(rwkv_${EXTRA_TARGET} PRIVATE ggml rwkv)
+endfunction()
+
+file(GLOB extras *.c)
+foreach (extra ${extras})
+    rwkv_add_extra(${extra})
+endforeach()
@@ -0,0 +1,7 @@
+#include "rwkv.h"
+
+#include <stdio.h>
+
+int main() {
+    printf("%s", rwkv_get_system_info_string());
+}
@@ -0,0 +1,58 @@
+#include "ggml.h"
+#include "rwkv.h"
+
+#include <stdlib.h>
+#include <stdio.h>
+#include <string.h>
+
+enum ggml_type type_from_string(const char* string) {
+    if (strcmp(string, "Q4_0") == 0) return GGML_TYPE_Q4_0;
+    if (strcmp(string, "Q4_1") == 0) return GGML_TYPE_Q4_1;
+    if (strcmp(string, "Q5_0") == 0) return GGML_TYPE_Q5_0;
+    if (strcmp(string, "Q5_1") == 0) return GGML_TYPE_Q5_1;
+    if (strcmp(string, "Q8_0") == 0) return GGML_TYPE_Q8_0;
+    return GGML_TYPE_COUNT;
+}
+
+#ifdef _WIN32
+bool QueryPerformanceFrequency(uint64_t* lpFrequency);
+bool QueryPerformanceCounter(uint64_t* lpPerformanceCount);
+
+#define time_t uint64_t
+#define time_calibrate(freq) do { QueryPerformanceFrequency(&freq); freq /= 1000; } while (0)
+#define time_measure(x) QueryPerformanceCounter(&x)
+#define TIME_DIFF(freq, start, end) (double) ((end - start) / freq) / 1000.
+#else
+#include <time.h>
+
+#define time_t struct timespec
+#define time_calibrate(freq) (void) freq
+#define time_measure(x) clock_gettime(CLOCK_MONOTONIC, &x)
+#define TIME_DIFF(freq, start, end) (double) ((end.tv_nsec - start.tv_nsec) / 1000000) / 1000
+#endif
+
+int main(int argc, char* argv[]) {
+    if (argc != 4 || type_from_string(argv[3]) == GGML_TYPE_COUNT) {
+        fprintf(stderr, "Usage: %s INPUT OUTPUT FORMAT\n\nAvailable formats: Q4_0 Q4_1 Q5_0 Q5_1 Q8_0\n", argv[0]);
+        return EXIT_FAILURE;
+    }
+
+    time_t freq, start, end;
+    time_calibrate(freq);
+
+    fprintf(stderr, "Quantizing ...\n");
+
+    time_measure(start);
+    bool success = rwkv_quantize_model_file(argv[1], argv[2], argv[3]);
+    time_measure(end);
+
+    double diff = TIME_DIFF(freq, start, end);
+
+    if (success) {
+        fprintf(stderr, "Succeeded in %.3fs\n", diff);
+        return EXIT_SUCCESS;
+    } else {
+        fprintf(stderr, "Error in %.3fs: 0x%.8X\n", diff, rwkv_get_last_error(NULL));
+        return EXIT_FAILURE;
+    }
+}
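The new extra above is a stand-alone quantization tool. Per the rwkv_add_extra helper earlier in this comparison, each extras/*.c file is built into a binary named rwkv_<basename>, so a source file called quantize.c would produce rwkv_quantize; that binary name and the build path used below are assumptions, not something stated in this diff. A minimal Python sketch of driving the tool from a script:

```python
# Hypothetical wrapper around the quantization extra shown above.
# The binary path is a guess (rwkv_<basename> per the CMake helper); adjust to your build layout.
import subprocess
import sys

FORMATS = ('Q4_0', 'Q4_1', 'Q5_0', 'Q5_1', 'Q8_0')  # names accepted by type_from_string() above

def quantize(src: str, dst: str, fmt: str, binary: str = './build/extras/rwkv_quantize') -> None:
    if fmt not in FORMATS:
        raise ValueError(f'Unsupported format: {fmt}')
    # The tool's usage string is "INPUT OUTPUT FORMAT"; it exits non-zero on failure.
    subprocess.run([binary, src, dst, fmt], check=True)

if __name__ == '__main__':
    quantize(sys.argv[1], sys.argv[2], sys.argv[3])
```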
rwkv.cpp
@@ -1264,7 +1264,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const
 
         RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_tensor(out_file, tensor), "Failed to write tensor %s", name_str);
         orig_total_size += orig_size;
-        new_total_size += orig_size;
+        new_total_size += new_size;
     }
 
     RWKV_MSG("original size = %8.2f MB\n", orig_total_size / 1024.0 / 1024.0);
@@ -8,9 +8,9 @@ import pathlib
 import copy
 import torch
 import sampling
-import tokenizers
 import rwkv_cpp_model
 import rwkv_cpp_shared_library
+from rwkv_tokenizer import get_tokenizer
 import json
 from typing import List, Dict, Optional
 import time

@@ -42,6 +42,7 @@ END_OF_TEXT_TOKEN: int = 0
 
 parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
+parser.add_argument('tokenizer', help='Which tokenizer to use', nargs='?', type=str, default="20B")
 args = parser.parse_args()
 
 script_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent

@@ -53,9 +54,7 @@ with open(script_dir / 'prompt' / f'{LANGUAGE}-{PROMPT_TYPE}.json', 'r', encodin
 
 assert init_prompt != '', 'Prompt must not be empty'
 
-print('Loading 20B tokenizer')
-tokenizer_path = script_dir / '20B_tokenizer.json'
-tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
+tokenizer, tokenizer_encode = get_tokenizer(args.tokenizer)
 
 library = rwkv_cpp_shared_library.load_rwkv_shared_library()
 print(f'System info: {library.rwkv_get_system_info_string()}')

@@ -63,9 +62,6 @@ print(f'System info: {library.rwkv_get_system_info_string()}')
 print('Loading RWKV model')
 model = rwkv_cpp_model.RWKVModel(library, args.model_path)
 
-prompt_tokens = tokenizer.encode(init_prompt).ids
-prompt_token_count = len(prompt_tokens)
-
 # =================================================================================================
 
 processed_tokens: List[int] = []

@@ -110,9 +106,11 @@ def split_last_end_of_line(tokens):
 
 # =================================================================================================
 
+T1 = time.time()
+prompt_tokens = tokenizer_encode(init_prompt)
+prompt_token_count = len(prompt_tokens)
 print(f'Processing {prompt_token_count} prompt tokens, may take a while')
 
-process_tokens(split_last_end_of_line(tokenizer.encode(init_prompt).ids))
+process_tokens(split_last_end_of_line(prompt_tokens))
+T2 = time.time()
+print(f'Process time :{((T2 - T1)*1000)} ms')
+print(f'Process time per token :{(((T2 - T1)*1000)) / prompt_token_count} ms')
 

@@ -164,7 +162,7 @@ while True:
         new = '\n' + msg[5:].strip()
         state = None
         processed_tokens = []
-        process_tokens(tokenizer.encode(new).ids)
+        process_tokens(tokenizer_encode(new))
         save_thread_state('gen_0')
 
     # +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model.

@@ -179,7 +177,7 @@ Below is an instruction that describes a task. Write a response that appropriate
 '''
         state = None
         processed_tokens = []
-        process_tokens(tokenizer.encode(new).ids)
+        process_tokens(tokenizer_encode(new))
         save_thread_state('gen_0')
 
     # +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context).

@@ -187,7 +185,7 @@ Below is an instruction that describes a task. Write a response that appropriate
         new = '\nQ: ' + msg[4:].strip() + '\nA:'
         state = None
         processed_tokens = []
-        process_tokens(tokenizer.encode(new).ids)
+        process_tokens(tokenizer_encode(new))
         save_thread_state('gen_0')
 
     # +qa YOUR QUESTION --> answer an independent question (regardless of context).

@@ -197,7 +195,7 @@ Below is an instruction that describes a task. Write a response that appropriate
         real_msg = msg[4:].strip()
         new = f'{user}{separator} {real_msg}\n\n{bot}{separator}'
 
-        process_tokens(tokenizer.encode(new).ids)
+        process_tokens(tokenizer_encode(new))
         save_thread_state('gen_0')
 
     # +++ --> continue last free generation (only for +gen / +i)

@@ -230,7 +228,7 @@ Below is an instruction that describes a task. Write a response that appropriate
     else:
         load_thread_state('chat')
         new = f'{user}{separator} {msg}\n\n{bot}{separator}'
-        process_tokens(tokenizer.encode(new).ids, new_line_logit_bias=-999999999)
+        process_tokens(tokenizer_encode(new), new_line_logit_bias=-999999999)
         save_thread_state('chat_pre')
 
         thread = 'chat'
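Every hunk above makes the same substitution: tokenizer.encode(text).ids becomes tokenizer_encode(text), the callable returned by get_tokenizer, so the chat loop no longer depends on whether the Hugging Face 20B tokenizer or the trie-based world tokenizer is loaded. A minimal sketch of the interface difference this hides; the classes here are illustrative stand-ins, not code from this comparison:

```python
# Illustration of the two encode() shapes unified by tokenizer_encode (stand-in classes, not real tokenizers).
from typing import Callable, List

class HFLikeEncoding:
    # Stand-in for tokenizers.Encoding: token ids live behind an .ids attribute.
    def __init__(self, ids: List[int]):
        self.ids = ids

class HFLikeTokenizer:
    def encode(self, text: str) -> HFLikeEncoding:
        return HFLikeEncoding([ord(c) for c in text])

class TrieLikeTokenizer:
    # Stand-in for TRIE_TOKENIZER: encode() already returns a plain list of ids.
    def encode(self, text: str) -> List[int]:
        return [ord(c) for c in text]

def make_encode(tokenizer) -> Callable[[str], List[int]]:
    # Same idea as get_tokenizer(): hide the .ids difference behind one callable.
    if isinstance(tokenizer, HFLikeTokenizer):
        return lambda text: tokenizer.encode(text).ids
    return lambda text: tokenizer.encode(text)

for tok in (HFLikeTokenizer(), TrieLikeTokenizer()):
    tokenizer_encode = make_encode(tok)
    assert tokenizer_encode('hi') == [104, 105]
```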
@@ -2,13 +2,12 @@
 
 import argparse
 import os
 import pathlib
 import time
 import sampling
-import tokenizers
 import rwkv_cpp_model
 import rwkv_cpp_shared_library
 
+from rwkv_tokenizer import get_tokenizer
 from typing import List
 
 # ======================================== Script settings ========================================

@@ -31,13 +30,14 @@ top_p: float = 0.5
 
 parser = argparse.ArgumentParser(description='Generate completions from RWKV model based on a prompt')
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
+parser.add_argument('tokenizer', help='Which tokenizer to use', nargs='?', type=str, default="20B")
 args = parser.parse_args()
 
 assert prompt != '', 'Prompt must not be empty'
 
-print('Loading 20B tokenizer')
-tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
-tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
+tokenizer, tokenizer_encode = get_tokenizer(args.tokenizer)
+
+prompt_tokens = tokenizer_encode(prompt)
 
 library = rwkv_cpp_shared_library.load_rwkv_shared_library()
 print(f'System info: {library.rwkv_get_system_info_string()}')

@@ -45,7 +45,6 @@ print(f'System info: {library.rwkv_get_system_info_string()}')
 print('Loading RWKV model')
 model = rwkv_cpp_model.RWKVModel(library, args.model_path)
 
-prompt_tokens = tokenizer.encode(prompt).ids
 prompt_token_count = len(prompt_tokens)
 print(f'{prompt_token_count} tokens in prompt')
 
@@ -4,13 +4,11 @@
 
-import os
 import time
-import pathlib
 import argparse
-import tokenizers
 import torch
 import rwkv_cpp_model
 import rwkv_cpp_shared_library
 from typing import List
+from rwkv_tokenizer import get_tokenizer
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')

@@ -18,19 +16,18 @@ def parse_args():
     parser.add_argument('text_path', help='Path to text file in UTF-8 encoding', type=str)
     parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int)
     parser.add_argument('token_limit', help='How many tokens to process; set to -1 to process all text', nargs='?', type=int, default=-1)
+    parser.add_argument('tokenizer', help='Which tokenizer to use', nargs='?', type=str, default="20B")
     return parser.parse_args()
 
 args = parse_args()
 
 # ---
 
-print('Loading 20B tokenizer')
-tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
-tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
-
 print('Loading text')
 text: str = open(args.text_path, encoding='utf-8').read()
-tokens: List[int] = tokenizer.encode(text).ids
+
+tokenizer, tokenizer_encode = get_tokenizer(args.tokenizer)
+
+tokens = tokenizer_encode(text)
 
 token_count: int = len(tokens)
 print(f'{token_count} tokens in the text')
 
@@ -84,7 +84,7 @@ class RWKVModel:
         if state_in is not None:
             validate_buffer(state_in, 'state_in', self._state_buffer_element_count)
 
-            state_in_ptr = state_in.untyped_storage().data_ptr()
+            state_in_ptr = state_in.data_ptr()
         else:
             state_in_ptr = 0
 

@@ -102,8 +102,8 @@ class RWKVModel:
             self._ctx,
             token,
             state_in_ptr,
-            state_out.untyped_storage().data_ptr(),
-            logits_out.untyped_storage().data_ptr()
+            state_out.data_ptr(),
+            logits_out.data_ptr()
         )
 
         return logits_out, state_out
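Both hunks above replace tensor.untyped_storage().data_ptr() with tensor.data_ptr(); for the contiguous, freshly allocated state and logits tensors the two expressions yield the same address, and data_ptr() also exists on PyTorch versions that predate untyped_storage(). A small sketch, independent of rwkv.cpp's own FFI code, of what handing such a pointer to C code looks like through ctypes (the tensor shape and dtype here are arbitrary assumptions):

```python
# Minimal sketch: expose a contiguous float32 tensor's memory to C via its raw address.
import ctypes
import torch

logits = torch.zeros(8, dtype=torch.float32)   # freshly allocated => contiguous, storage offset 0
ptr = logits.data_ptr()                        # integer address of the first element

# A C function expecting `float *` can be handed this address; here we just alias it from Python.
buf = (ctypes.c_float * logits.numel()).from_address(ptr)
buf[0] = 1.5                                   # writes straight into the tensor's memory

assert logits[0].item() == 1.5
```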
@@ -0,0 +1,120 @@
+import os
+import tokenizers
+import pathlib
+
+########################################################################################################
+# Taken from https://github.com/BlinkDL/ChatRWKV/tree/main/tokenizer/rwkv_tokenizer.py
+########################################################################################################
+
+class TRIE:
+    __slots__ = tuple("ch,to,values,front".split(","))
+    to:list
+    values:set
+    def __init__(self, front=None, ch=None):
+        self.ch = ch
+        self.to = [None for ch in range(256)]
+        self.values = set()
+        self.front = front
+
+    def __repr__(self):
+        fr = self
+        ret = []
+        while(fr!=None):
+            if(fr.ch!=None):
+                ret.append(fr.ch)
+            fr = fr.front
+        return "<TRIE %s %s>"%(ret[::-1], self.values)
+
+    def add(self, key:bytes, idx:int=0, val=None):
+        if(idx == len(key)):
+            if(val is None):
+                val = key
+            self.values.add(val)
+            return self
+        ch = key[idx]
+        if(self.to[ch] is None):
+            self.to[ch] = TRIE(front=self, ch=ch)
+        return self.to[ch].add(key, idx=idx+1, val=val)
+
+    def find_longest(self, key:bytes, idx:int=0):
+        u:TRIE = self
+        ch:int = key[idx]
+
+        while(u.to[ch] is not None):
+            u = u.to[ch]
+            idx += 1
+            if(u.values):
+                ret = idx, u, u.values
+            if(idx==len(key)):
+                break
+            ch = key[idx]
+        return ret
+
+class TRIE_TOKENIZER():
+    def __init__(self, file_name):
+        self.idx2token = {}
+        sorted = [] # must be already sorted
+        with open(file_name, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        for l in lines:
+            idx = int(l[:l.index(' ')])
+            x = eval(l[l.index(' '):l.rindex(' ')])
+            x = x.encode("utf-8") if isinstance(x, str) else x
+            assert isinstance(x, bytes)
+            assert len(x) == int(l[l.rindex(' '):])
+            sorted += [x]
+            self.idx2token[idx] = x
+
+        self.token2idx = {}
+        for k,v in self.idx2token.items():
+            self.token2idx[v] = int(k)
+
+        self.root = TRIE()
+        for t, i in self.token2idx.items():
+            _ = self.root.add(t, val=(t, i))
+
+    def encodeBytes(self, src:bytes) -> list[int]:
+        idx:int = 0
+        tokens:list[int] = []
+        while (idx < len(src)):
+            _idx:int = idx
+            idx, _, values = self.root.find_longest(src, idx)
+            assert(idx != _idx)
+            _, token = next(iter(values))
+            tokens.append(token)
+        return tokens
+
+    def decodeBytes(self, tokens):
+        return b''.join(map(lambda i: self.idx2token[i], tokens))
+
+    def encode(self, src):
+        return self.encodeBytes(src.encode("utf-8"))
+
+    def decode(self, tokens):
+        return self.decodeBytes(tokens).decode('utf-8')
+
+    def printTokens(self, tokens):
+        for i in tokens:
+            s = self.idx2token[i]
+            try:
+                s = s.decode('utf-8')
+            except:
+                pass
+            print(f'{repr(s)}{i}', end=' ')
+        print()
+
+def get_tokenizer(tokenizer="20B"):
+    if tokenizer == "world":
+        print('Loading world tokenizer')
+        tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / 'rwkv_vocab_v20230424.txt'
+        tokenizer = TRIE_TOKENIZER(tokenizer_path)
+        tokenizer_encode = lambda prompt: tokenizer.encode(prompt)
+    elif tokenizer == "20B":
+        print('Loading 20B tokenizer')
+        tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
+        tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
+        tokenizer_encode = lambda prompt: tokenizer.encode(prompt).ids
+    else:
+        print(f"Unknown tokenizer: {args.tokenizer}")
+        quit()
+    return tokenizer, tokenizer_encode
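For reference, a minimal usage sketch of the get_tokenizer helper defined above. It assumes rwkv_vocab_v20230424.txt (for 'world') and 20B_tokenizer.json (for '20B') sit next to rwkv_tokenizer.py, which is where the code above looks for them. Note that, as written, the else branch refers to args.tokenizer, which is not defined in this module, so an unknown tokenizer name would raise a NameError rather than print the message.

```python
# Usage sketch for rwkv_tokenizer.get_tokenizer (vocab files assumed to be present next to the module).
from rwkv_tokenizer import get_tokenizer

# '20B' -> Hugging Face tokenizers wrapper, 'world' -> TRIE_TOKENIZER
tokenizer, tokenizer_encode = get_tokenizer('world')

tokens = tokenizer_encode('Hello RWKV')   # list[int], same shape regardless of which tokenizer is loaded
text = tokenizer.decode(tokens)           # round-trip back to the original string

assert text == 'Hello RWKV'
print(tokens)
```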
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large