diff --git a/README.md b/README.md index 6cc84b4..d3c9fd0 100644 --- a/README.md +++ b/README.md @@ -47,10 +47,14 @@ python rwkv\convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pt #### 3. Use the model in Python: ```python -# This file is located at rwkv/rwkv_cpp.py -import rwkv_cpp +# These files are located in rwkv directory +import rwkv_cpp_model +import rwkv_cpp_shared_library -model = rwkv_cpp.RWKVModel(r'bin\Release\rwkv.dll', r'C:\rwkv.cpp-169M.bin') +model = rwkv_cpp_model.RWKVModel( + rwkv_cpp_shared_library.load_rwkv_shared_library(), + r'C:\rwkv.cpp-169M.bin' +) logits, state = None, None @@ -59,7 +63,7 @@ for token in [1, 2, 3]: print(f'Output logits: {logits}') -# Don't forget to free memory after you've done working with the model +# Don't forget to free the memory after you've done working with the model model.free() ``` diff --git a/rwkv.cpp b/rwkv.cpp index 8f36f88..c0140f8 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -163,7 +163,7 @@ struct rwkv_context { bool freed; }; -struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads) { +struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_threads) { FILE * file = fopen(file_path, "rb"); RWKV_ASSERT_NULL(file != NULL, "Failed to open file %s", file_path); @@ -505,15 +505,15 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads) return rwkv_ctx; } -size_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx) { +uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx) { return ctx->model->n_layer * 5 * ctx->model->n_embed; } -size_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx) { +uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx) { return ctx->model->n_vocab; } -bool rwkv_eval(struct rwkv_context * ctx, long int token, float * state_in, float * state_out, float * logits_out) { +bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out) { RWKV_ASSERT_FALSE(state_out != NULL, "state_out is NULL"); RWKV_ASSERT_FALSE(logits_out != NULL, "logits_out is NULL"); @@ -564,7 +564,7 @@ void rwkv_free(struct rwkv_context * ctx) { delete ctx; } -bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type) { +bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) { RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type); ggml_type type; diff --git a/rwkv.h b/rwkv.h index f7dbfb4..e56e290 100644 --- a/rwkv.h +++ b/rwkv.h @@ -33,7 +33,7 @@ extern "C" { // Returns NULL on any error. Error messages would be printed to stderr. // - model_file_path: path to model file in ggml format. // - n_threads: count of threads to use, must be positive. - RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads); + RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, uint32_t n_threads); // Evaluates the model for a single token. // Returns false on any error. Error messages would be printed to stderr. @@ -41,13 +41,13 @@ extern "C" { // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass. // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to. // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. - RWKV_API bool rwkv_eval(struct rwkv_context * ctx, long int token, float * state_in, float * state_out, float * logits_out); + RWKV_API bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out); // Returns count of FP32 elements in state buffer. - RWKV_API size_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx); + RWKV_API uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx); // Returns count of FP32 elements in logits buffer. - RWKV_API size_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx); + RWKV_API uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx); // Frees all allocated memory and the context. RWKV_API void rwkv_free(struct rwkv_context * ctx); @@ -57,7 +57,7 @@ extern "C" { // - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16. // - model_file_path_out: quantized model will be written here. // - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1. - RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type); + RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type); // Returns system information string. RWKV_API const char * rwkv_get_system_info_string(void); diff --git a/rwkv/compare_with_reference_implementation.py b/rwkv/compare_with_reference_implementation.py index 7bd3ee8..69a5828 100644 --- a/rwkv/compare_with_reference_implementation.py +++ b/rwkv/compare_with_reference_implementation.py @@ -1,20 +1,19 @@ # Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV. # Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch. # Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py -# Usage: python compare_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin +# Usage: python compare_with_reference_implementation.py C:\rwkv.cpp-169M.bin import os import struct import argparse -import subprocess import torch import numpy as np -import rwkv_cpp +import rwkv_cpp_model +import rwkv_cpp_shared_library from typing import List, Tuple, Any def parse_args(): parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV') - parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file or shared library') parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file') return parser.parse_args() @@ -22,17 +21,12 @@ def main() -> None: args = parse_args() # Don't want to depend on tokenizer here. - # Exact string is: - # context = "1 In the beginning God created the heaven and the earth. " \ - # "2 And the earth was without form, and void; and darkness was upon the face of the deep. And the Spirit of God moved upon the face of the waters. " \ - # "3 And God said, Let there be light: and there was light. " \ - # "4 And God saw the light, that it was good: and God divided the light from the darkness." - # The Bible was the first non-copyrighted public domain text that came to my mind. - tokens: List[int] = [18, 496, 253, 5068, 2656, 3562, 253, 13926, 285, 253, 6149, 15, 374, 1244, 253, 6149, 369, 1293, 830, - 13, 285, 2991, 28, 285, 13862, 369, 2220, 253, 2454, 273, 253, 3676, 15, 1244, 253, 14559, 273, 2656, - 4395, 2220, 253, 2454, 273, 253, 12685, 15, 495, 1244, 2656, 753, 13, 1281, 627, 320, 1708, 27, 285, - 627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272, - 253, 1708, 432, 253, 13862, 15] + tokens: List[int] = [4, 3631, 4420, 2412, 953, 432, 391, 30567, 87, 15, 14161, 7092, 273, 416, 27767, 55, 342, + 2412, 953, 432, 3806, 7092, 273, 416, 27767, 55, 15, 187, 4, 19039, 2412, 953, 497, 4561, + 342, 416, 27767, 55, 14, 21, 14, 49, 587, 14, 17809, 46, 14, 938, 14256, 28950, 14, 1438, + 1508, 15, 81, 394, 1566, 275, 8462, 22097, 348, 15, 187, 4, 43825, 27, 15548, 7277, 64, + 3113, 64, 14005, 64, 39595, 15, 4789, 10269, 61, 18992, 61, 7265, 64, 30217, 39297, 15, + 20963, 330, 27, 190, 30567, 87, 15, 14161, 14, 17809, 46, 15, 4805] threshold: float @@ -50,7 +44,7 @@ def main() -> None: threshold = 0.000005 elif data_type == 1: # FP16, lower precision, so higher threshold - threshold = 0.003 + threshold = 0.0032 elif data_type == 2: # INT4 quantized, even lower precision, so even higher threshold # This threshold will let some bugs pass @@ -59,42 +53,24 @@ def main() -> None: # This format stores more data, so error would be lower threshold = 1.2 - model = None - - if args.main_executable_path.lower().endswith('.dll') or args.main_executable_path.lower().endswith('.so'): - print('Testing shared library') - - model = rwkv_cpp.RWKVModel(args.main_executable_path, args.ggml_model_path) - else: - print('Testing main_rwkv executable') + model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path) def compare_logits(tokens_subset: List[int]) -> None: token_count: int = len(tokens_subset) - state_path: str = './state.bin' - logits_path: str = './logits.bin' logits, state = None, None for i in range(token_count): token: int = tokens_subset[i] - print(f'{i + 1}/{token_count}') + if token_count <= 10 or i % (token_count // 10) == 0: + print(f'{i + 1}/{token_count}') - if model is not None: - logits, state = model.eval(token, state) - else: - subprocess.run( - [ - args.main_executable_path, - args.ggml_model_path, - str(token), - # If this is the first token, let the script create a new state. - '' if i == 0 else state_path, - state_path, - logits_path - ], - check=True - ) + logits, state = model.eval(token, state, state, logits) + + actual_logits = logits + + # --- expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin' @@ -104,11 +80,7 @@ def main() -> None: with open(expected_logits_path, 'rb') as logits_file: expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) - if model is not None: - actual_logits = logits - else: - with open(logits_path, 'rb') as logits_file: - actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) + # --- difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item() @@ -118,8 +90,6 @@ def main() -> None: assert abs(difference) <= threshold, 'Difference is too big' - # Check small token amount first to avoid waiting too long before seeing that model is broken - compare_logits(tokens[:4]) compare_logits(tokens) print() diff --git a/rwkv/expected_logits_169M_20220807_8023_4_tokens.bin b/rwkv/expected_logits_169M_20220807_8023_4_tokens.bin deleted file mode 100644 index e1ddfc0..0000000 Binary files a/rwkv/expected_logits_169M_20220807_8023_4_tokens.bin and /dev/null differ diff --git a/rwkv/expected_logits_169M_20220807_8023_82_tokens.bin b/rwkv/expected_logits_169M_20220807_8023_82_tokens.bin deleted file mode 100644 index 9ce6ca1..0000000 Binary files a/rwkv/expected_logits_169M_20220807_8023_82_tokens.bin and /dev/null differ diff --git a/rwkv/expected_logits_169M_20220807_8023_98_tokens.bin b/rwkv/expected_logits_169M_20220807_8023_98_tokens.bin new file mode 100644 index 0000000..e0409d2 Binary files /dev/null and b/rwkv/expected_logits_169M_20220807_8023_98_tokens.bin differ diff --git a/rwkv/quantize.py b/rwkv/quantize.py index e76359c..e798855 100644 --- a/rwkv/quantize.py +++ b/rwkv/quantize.py @@ -1,12 +1,11 @@ # Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1. # Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3 -import ctypes import argparse +import rwkv_cpp_shared_library def parse_args(): parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1') - parser.add_argument('shared_library_path', help='Path to rwkv.cpp shared library') parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file') parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten') parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3) @@ -15,19 +14,14 @@ def parse_args(): def main() -> None: args = parse_args() - library = ctypes.cdll.LoadLibrary(args.shared_library_path) + library = rwkv_cpp_shared_library.load_rwkv_shared_library() - library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_int] - library.rwkv_quantize_model_file.restype = ctypes.c_bool - - result: bool = library.rwkv_quantize_model_file( - args.src_path.encode('utf-8'), - args.dest_path.encode('utf-8'), - ctypes.c_int(args.data_type) + library.rwkv_quantize_model_file( + args.src_path, + args.dest_path, + args.data_type ) - assert result, 'Failed to quantize, check stderr' - print('Done') if __name__ == "__main__": diff --git a/rwkv/rwkv_cpp.py b/rwkv/rwkv_cpp.py deleted file mode 100644 index 88b422a..0000000 --- a/rwkv/rwkv_cpp.py +++ /dev/null @@ -1,127 +0,0 @@ -import os -import ctypes -import torch -import multiprocessing -from typing import Tuple, Optional - -P_FLOAT = ctypes.POINTER(ctypes.c_float) - -class RWKVModel: - """ - PyTorch wrapper around rwkv.cpp shared library. - """ - - def __init__( - self, - shared_library_path: str, - model_path: str, - thread_count: int = max(1, multiprocessing.cpu_count() // 2) - ): - """ - Loads the model and prepares it for inference. - In case of any error, this method will throw an exception. - - Parameters - ---------- - shared_library_path : str - Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'. - model_path : str - Path to RWKV model file in ggml format. - thread_count : int - Thread count to use. If not set, defaults to CPU count / 2. - """ - - assert os.path.isfile(shared_library_path), f'{shared_library_path} is not a file' - assert os.path.isfile(model_path), f'{model_path} is not a file' - assert thread_count > 0, 'Thread count must be positive' - - self.library = ctypes.cdll.LoadLibrary(shared_library_path) - - self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int] - self.library.rwkv_init_from_file.restype = ctypes.c_void_p - - self.library.rwkv_eval.argtypes = [ - ctypes.c_void_p, # ctx - ctypes.c_long, # token - P_FLOAT, # state_in - P_FLOAT, # state_out - P_FLOAT # logits_out - ] - self.library.rwkv_eval.restype = ctypes.c_bool - - self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p] - self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_size_t - - self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p] - self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_size_t - - self.library.rwkv_free.argtypes = [ctypes.c_void_p] - self.library.rwkv_free.restype = None - - self.ctx = self.library.rwkv_init_from_file(model_path.encode('utf-8'), ctypes.c_int(thread_count)) - - assert self.ctx is not None, 'Failed to load the model, see stderr' - - self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx) - self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx) - - self.valid = True - - def eval(self, token: int, state_in: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Evaluates the model for a single token. - In case of any error, this method will throw an exception. - - Parameters - ---------- - token : int - Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab. - state_in : Optional[torch.Tensor] - State from previous call of this method. If this is a first pass, set it to None. - - Returns - ------- - logits, state - Logits vector of shape (n_vocab); state for the next step. - """ - - assert self.valid, 'Model was freed' - - if state_in is None: - state_in_ptr = 0 - else: - expected_shape = (self.state_buffer_element_count,) - - assert state_in.is_contiguous(), 'State tensor is not contiguous' - assert state_in.shape == expected_shape, f'Invalid state shape {state_in.shape}, expected {expected_shape}' - - state_in_ptr = state_in.storage().data_ptr() - - # TODO Probably these allocations can be optimized away - state_out: torch.Tensor = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu') - logits_out: torch.Tensor = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu') - - result = self.library.rwkv_eval( - self.ctx, - ctypes.c_long(token), - ctypes.cast(state_in_ptr, P_FLOAT), - ctypes.cast(state_out.storage().data_ptr(), P_FLOAT), - ctypes.cast(logits_out.storage().data_ptr(), P_FLOAT) - ) - - assert result, 'Inference failed, see stderr' - - return logits_out, state_out - - def free(self): - """ - Frees all allocated resources. - In case of any error, this method will throw an exception. - The object must not be used anymore after calling this method. - """ - - assert self.valid, 'Already freed' - - self.valid = False - - self.library.rwkv_free(self.ctx) diff --git a/rwkv/rwkv_cpp_model.py b/rwkv/rwkv_cpp_model.py new file mode 100644 index 0000000..4f089ad --- /dev/null +++ b/rwkv/rwkv_cpp_model.py @@ -0,0 +1,117 @@ +import os +import torch +import multiprocessing +import rwkv_cpp_shared_library +from typing import Tuple, Optional + +class RWKVModel: + """ + PyTorch wrapper around rwkv.cpp model. + """ + + def __init__( + self, + shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary, + model_path: str, + thread_count: int = max(1, multiprocessing.cpu_count() // 2) + ): + """ + Loads the model and prepares it for inference. + In case of any error, this method will throw an exception. + + Parameters + ---------- + shared_library : RWKVSharedLibrary + rwkv.cpp shared library. + model_path : str + Path to RWKV model file in ggml format. + thread_count : int + Thread count to use. If not set, defaults to CPU count / 2. + """ + + assert os.path.isfile(model_path), f'{model_path} is not a file' + assert thread_count > 0, 'Thread count must be positive' + + self.library = shared_library + + self.ctx = self.library.rwkv_init_from_file(model_path, thread_count) + + self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx) + self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx) + + self.valid = True + + def eval( + self, + token: int, + state_in: Optional[torch.Tensor], + state_out: Optional[torch.Tensor] = None, + logits_out: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Evaluates the model for a single token. + In case of any error, this method will throw an exception. + + Parameters + ---------- + token : int + Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab. + state_in : Optional[torch.Tensor] + State from previous call of this method. If this is a first pass, set it to None. + state_out : Optional[torch.Tensor] + Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count). + logits_out : Optional[torch.Tensor] + Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count). + + Returns + ------- + logits, state + Logits vector of shape (n_vocab); state for the next step. + """ + + assert self.valid, 'Model was freed' + + def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None: + assert buf.dtype == torch.float32, f'{name} is not of type float32' + assert buf.is_contiguous(), f'{name} is not contiguous' + assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})' + + if state_in is not None: + validate_buffer(state_in, 'state_in', self.state_buffer_element_count) + + state_in_ptr = state_in.storage().data_ptr() + else: + state_in_ptr = 0 + + if state_out is not None: + validate_buffer(state_out, 'state_out', self.state_buffer_element_count) + else: + state_out = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu') + + if logits_out is not None: + validate_buffer(logits_out, 'logits_out', self.logits_buffer_element_count) + else: + logits_out = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu') + + self.library.rwkv_eval( + self.ctx, + token, + state_in_ptr, + state_out.storage().data_ptr(), + logits_out.storage().data_ptr() + ) + + return logits_out, state_out + + def free(self): + """ + Frees all allocated resources. + In case of any error, this method will throw an exception. + The object must not be used anymore after calling this method. + """ + + assert self.valid, 'Already freed' + + self.valid = False + + self.library.rwkv_free(self.ctx) diff --git a/rwkv/rwkv_cpp_shared_library.py b/rwkv/rwkv_cpp_shared_library.py new file mode 100644 index 0000000..1bacef2 --- /dev/null +++ b/rwkv/rwkv_cpp_shared_library.py @@ -0,0 +1,204 @@ +import os +import sys +import ctypes +from typing import Optional + +P_FLOAT = ctypes.POINTER(ctypes.c_float) + +class RWKVContext: + + def __init__(self, ptr: ctypes.pointer): + self.ptr = ptr + +class RWKVSharedLibrary: + """ + Python wrapper around rwkv.cpp shared library. + """ + + def __init__(self, shared_library_path: str): + """ + Loads the shared library from specified file. + In case of any error, this method will throw an exception. + + Parameters + ---------- + shared_library_path : str + Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'. + """ + + self.library = ctypes.cdll.LoadLibrary(shared_library_path) + + self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32] + self.library.rwkv_init_from_file.restype = ctypes.c_void_p + + self.library.rwkv_eval.argtypes = [ + ctypes.c_void_p, # ctx + ctypes.c_int32, # token + P_FLOAT, # state_in + P_FLOAT, # state_out + P_FLOAT # logits_out + ] + self.library.rwkv_eval.restype = ctypes.c_bool + + self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p] + self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32 + + self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p] + self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32 + + self.library.rwkv_free.argtypes = [ctypes.c_void_p] + self.library.rwkv_free.restype = None + + self.library.rwkv_free.argtypes = [ctypes.c_void_p] + self.library.rwkv_free.restype = None + + self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_uint32] + self.library.rwkv_quantize_model_file.restype = ctypes.c_bool + + self.library.rwkv_get_system_info_string.argtypes = [] + self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p + + def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext: + """ + Loads the model from a file and prepares it for inference. + Throws an exception in case of any error. Error messages would be printed to stderr. + + Parameters + ---------- + model_file_path : str + Path to model file in ggml format. + thread_count : int + Count of threads to use, must be positive. + """ + + ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count)) + assert ptr is not None, 'rwkv_init_from_file failed, check stderr' + return RWKVContext(ptr) + + def rwkv_eval( + self, + ctx: RWKVContext, + token: int, + state_in_address: Optional[int], + state_out_address: int, + logits_out_address: int + ) -> None: + """ + Evaluates the model for a single token. + Throws an exception in case of any error. Error messages would be printed to stderr. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + token : int + Next token index, in range 0 <= token < n_vocab. + state_in_address : int + Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass. + state_out_address : int + Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to. + logits_out_address : int + Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. + """ + + assert self.library.rwkv_eval( + ctx.ptr, + ctypes.c_int32(token), + ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), + ctypes.cast(state_out_address, P_FLOAT), + ctypes.cast(logits_out_address, P_FLOAT) + ), 'rwkv_eval failed, check stderr' + + def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int: + """ + Returns count of FP32 elements in state buffer. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + """ + + return self.library.rwkv_get_state_buffer_element_count(ctx.ptr) + + def rwkv_get_logits_buffer_element_count(self, ctx: RWKVContext) -> int: + """ + Returns count of FP32 elements in logits buffer. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + """ + + return self.library.rwkv_get_logits_buffer_element_count(ctx.ptr) + + def rwkv_free(self, ctx: RWKVContext) -> None: + """ + Frees all allocated memory and the context. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + """ + + self.library.rwkv_free(ctx.ptr) + + ctx.ptr = ctypes.cast(0, ctypes.c_void_p) + + def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, q_type: int) -> None: + """ + Quantizes FP32 or FP16 model to one of INT4 formats. + Throws an exception in case of any error. Error messages would be printed to stderr. + + Parameters + ---------- + model_file_path_in : str + Path to model file in ggml format, must be either FP32 or FP16. + model_file_path_out : str + Quantized model will be written here. + q_type : int + Set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1. + """ + + assert self.library.rwkv_quantize_model_file( + model_file_path_in.encode('utf-8'), + model_file_path_out.encode('utf-8'), + ctypes.c_uint32(q_type) + ), 'rwkv_quantize_model_file failed, check stderr' + + def rwkv_get_system_info_string(self) -> str: + """ + Returns system information string. + """ + + return self.library.rwkv_get_system_info_string() + +def load_rwkv_shared_library() -> RWKVSharedLibrary: + """ + Attempts to find rwkv.cpp shared library and load it. + To specify exact path to the library, create an instance of RWKVSharedLibrary explicitly. + """ + + file_name: str + + if 'win32' in sys.platform or 'cygwin' in sys.platform: + file_name = 'rwkv.dll' + else: + file_name = 'rwkv.so' + + paths = [ + # If we are in "rwkv" directory + f'../bin/Release/{file_name}', + # If we are in repo root directory + f'bin/Release/{file_name}', + # Fallback + file_name + ] + + for path in paths: + if os.path.isfile(path): + return RWKVSharedLibrary(path) + + return RWKVSharedLibrary(paths[-1])