Move library wrapper to separate file, refactor code
This commit is contained in:
		
							parent
							
								
									38f9d02d52
								
							
						
					
					
						commit
						935d16f5db
					
				
							
								
								
									
										12
									
								
								README.md
								
								
								
								
							
							
						
						
									
										12
									
								
								README.md
								
								
								
								
							|  | @ -47,10 +47,14 @@ python rwkv\convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pt | |||
| #### 3. Use the model in Python: | ||||
| 
 | ||||
| ```python | ||||
| # This file is located at rwkv/rwkv_cpp.py | ||||
| import rwkv_cpp | ||||
| # These files are located in rwkv directory | ||||
| import rwkv_cpp_model | ||||
| import rwkv_cpp_shared_library | ||||
| 
 | ||||
| model = rwkv_cpp.RWKVModel(r'bin\Release\rwkv.dll', r'C:\rwkv.cpp-169M.bin') | ||||
| model = rwkv_cpp_model.RWKVModel( | ||||
|     rwkv_cpp_shared_library.load_rwkv_shared_library(), | ||||
|     r'C:\rwkv.cpp-169M.bin' | ||||
| ) | ||||
| 
 | ||||
| logits, state = None, None | ||||
| 
 | ||||
|  | @ -59,7 +63,7 @@ for token in [1, 2, 3]: | |||
|      | ||||
|     print(f'Output logits: {logits}') | ||||
| 
 | ||||
| # Don't forget to free memory after you've done working with the model | ||||
| # Don't forget to free the memory after you've done working with the model | ||||
| model.free() | ||||
| 
 | ||||
| ``` | ||||
|  |  | |||
							
								
								
									
										10
									
								
								rwkv.cpp
								
								
								
								
							
							
						
						
									
										10
									
								
								rwkv.cpp
								
								
								
								
							|  | @ -163,7 +163,7 @@ struct rwkv_context { | |||
|     bool freed; | ||||
| }; | ||||
| 
 | ||||
| struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads) { | ||||
| struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_threads) { | ||||
|     FILE * file = fopen(file_path, "rb"); | ||||
|     RWKV_ASSERT_NULL(file != NULL, "Failed to open file %s", file_path); | ||||
| 
 | ||||
|  | @ -505,15 +505,15 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads) | |||
|     return rwkv_ctx; | ||||
| } | ||||
| 
 | ||||
| size_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx) { | ||||
| uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx) { | ||||
|     return ctx->model->n_layer * 5 * ctx->model->n_embed; | ||||
| } | ||||
| 
 | ||||
| size_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx) { | ||||
| uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx) { | ||||
|     return ctx->model->n_vocab; | ||||
| } | ||||
| 
 | ||||
| bool rwkv_eval(struct rwkv_context * ctx, long int token, float * state_in, float * state_out, float * logits_out) { | ||||
| bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out) { | ||||
|     RWKV_ASSERT_FALSE(state_out != NULL, "state_out is NULL"); | ||||
|     RWKV_ASSERT_FALSE(logits_out != NULL, "logits_out is NULL"); | ||||
| 
 | ||||
|  | @ -564,7 +564,7 @@ void rwkv_free(struct rwkv_context * ctx) { | |||
|     delete ctx; | ||||
| } | ||||
| 
 | ||||
| bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type) { | ||||
| bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) { | ||||
|     RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type); | ||||
| 
 | ||||
|     ggml_type type; | ||||
|  |  | |||
							
								
								
									
										10
									
								
								rwkv.h
								
								
								
								
							
							
						
						
									
										10
									
								
								rwkv.h
								
								
								
								
							|  | @ -33,7 +33,7 @@ extern "C" { | |||
|     // Returns NULL on any error. Error messages would be printed to stderr.
 | ||||
|     // - model_file_path: path to model file in ggml format.
 | ||||
|     // - n_threads: count of threads to use, must be positive.
 | ||||
|     RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads); | ||||
|     RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, uint32_t n_threads); | ||||
| 
 | ||||
|     // Evaluates the model for a single token.
 | ||||
|     // Returns false on any error. Error messages would be printed to stderr.
 | ||||
|  | @ -41,13 +41,13 @@ extern "C" { | |||
|     // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
 | ||||
|     // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
 | ||||
|     // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
 | ||||
|     RWKV_API bool rwkv_eval(struct rwkv_context * ctx, long int token, float * state_in, float * state_out, float * logits_out); | ||||
|     RWKV_API bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out); | ||||
| 
 | ||||
|     // Returns count of FP32 elements in state buffer.
 | ||||
|     RWKV_API size_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx); | ||||
|     RWKV_API uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx); | ||||
| 
 | ||||
|     // Returns count of FP32 elements in logits buffer.
 | ||||
|     RWKV_API size_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx); | ||||
|     RWKV_API uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx); | ||||
| 
 | ||||
|     // Frees all allocated memory and the context.
 | ||||
|     RWKV_API void rwkv_free(struct rwkv_context * ctx); | ||||
|  | @ -57,7 +57,7 @@ extern "C" { | |||
|     // - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
 | ||||
|     // - model_file_path_out: quantized model will be written here.
 | ||||
|     // - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
 | ||||
|     RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type); | ||||
|     RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type); | ||||
| 
 | ||||
|     // Returns system information string.
 | ||||
|     RWKV_API const char * rwkv_get_system_info_string(void); | ||||
|  |  | |||
|  | @ -1,20 +1,19 @@ | |||
| # Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV. | ||||
| # Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch. | ||||
| # Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py | ||||
| # Usage: python compare_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin | ||||
| # Usage: python compare_with_reference_implementation.py C:\rwkv.cpp-169M.bin | ||||
| 
 | ||||
| import os | ||||
| import struct | ||||
| import argparse | ||||
| import subprocess | ||||
| import torch | ||||
| import numpy as np | ||||
| import rwkv_cpp | ||||
| import rwkv_cpp_model | ||||
| import rwkv_cpp_shared_library | ||||
| from typing import List, Tuple, Any | ||||
| 
 | ||||
| def parse_args(): | ||||
|     parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV') | ||||
|     parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file or shared library') | ||||
|     parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file') | ||||
|     return parser.parse_args() | ||||
| 
 | ||||
|  | @ -22,17 +21,12 @@ def main() -> None: | |||
|     args = parse_args() | ||||
| 
 | ||||
|     # Don't want to depend on tokenizer here. | ||||
|     # Exact string is: | ||||
|     # context = "1 In the beginning God created the heaven and the earth. " \ | ||||
|     #           "2 And the earth was without form, and void; and darkness was upon the face of the deep. And the Spirit of God moved upon the face of the waters. " \ | ||||
|     #           "3 And God said, Let there be light: and there was light. " \ | ||||
|     #           "4 And God saw the light, that it was good: and God divided the light from the darkness." | ||||
|     # The Bible was the first non-copyrighted public domain text that came to my mind. | ||||
|     tokens: List[int] = [18, 496, 253, 5068, 2656, 3562, 253, 13926, 285, 253, 6149, 15, 374, 1244, 253, 6149, 369, 1293, 830, | ||||
|                          13, 285, 2991, 28, 285, 13862, 369, 2220, 253, 2454, 273, 253, 3676, 15, 1244, 253, 14559, 273, 2656, | ||||
|                          4395, 2220, 253, 2454, 273, 253, 12685, 15, 495, 1244, 2656, 753, 13, 1281, 627, 320, 1708, 27, 285, | ||||
|                          627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272, | ||||
|                          253, 1708, 432, 253, 13862, 15] | ||||
|     tokens: List[int] = [4, 3631, 4420, 2412, 953, 432, 391, 30567, 87, 15, 14161, 7092, 273, 416, 27767, 55, 342, | ||||
|                          2412, 953, 432, 3806, 7092, 273, 416, 27767, 55, 15, 187, 4, 19039, 2412, 953, 497, 4561, | ||||
|                          342, 416, 27767, 55, 14, 21, 14, 49, 587, 14, 17809, 46, 14, 938, 14256, 28950, 14, 1438, | ||||
|                          1508, 15, 81, 394, 1566, 275, 8462, 22097, 348, 15, 187, 4, 43825, 27, 15548, 7277, 64, | ||||
|                          3113, 64, 14005, 64, 39595, 15, 4789, 10269, 61, 18992, 61, 7265, 64, 30217, 39297, 15, | ||||
|                          20963, 330, 27, 190, 30567, 87, 15, 14161, 14, 17809, 46, 15, 4805] | ||||
| 
 | ||||
|     threshold: float | ||||
| 
 | ||||
|  | @ -50,7 +44,7 @@ def main() -> None: | |||
|             threshold = 0.000005 | ||||
|         elif data_type == 1: | ||||
|             # FP16, lower precision, so higher threshold | ||||
|             threshold = 0.003 | ||||
|             threshold = 0.0032 | ||||
|         elif data_type == 2: | ||||
|             # INT4 quantized, even lower precision, so even higher threshold | ||||
|             # This threshold will let some bugs pass | ||||
|  | @ -59,42 +53,24 @@ def main() -> None: | |||
|             # This format stores more data, so error would be lower | ||||
|             threshold = 1.2 | ||||
| 
 | ||||
|     model = None | ||||
| 
 | ||||
|     if args.main_executable_path.lower().endswith('.dll') or args.main_executable_path.lower().endswith('.so'): | ||||
|         print('Testing shared library') | ||||
| 
 | ||||
|         model = rwkv_cpp.RWKVModel(args.main_executable_path, args.ggml_model_path) | ||||
|     else: | ||||
|         print('Testing main_rwkv executable') | ||||
|     model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path) | ||||
| 
 | ||||
|     def compare_logits(tokens_subset: List[int]) -> None: | ||||
|         token_count: int = len(tokens_subset) | ||||
|         state_path: str = './state.bin' | ||||
|         logits_path: str = './logits.bin' | ||||
| 
 | ||||
|         logits, state = None, None | ||||
| 
 | ||||
|         for i in range(token_count): | ||||
|             token: int = tokens_subset[i] | ||||
| 
 | ||||
|             print(f'{i + 1}/{token_count}') | ||||
|             if token_count <= 10 or i % (token_count // 10) == 0: | ||||
|                 print(f'{i + 1}/{token_count}') | ||||
| 
 | ||||
|             if model is not None: | ||||
|                 logits, state = model.eval(token, state) | ||||
|             else: | ||||
|                 subprocess.run( | ||||
|                     [ | ||||
|                         args.main_executable_path, | ||||
|                         args.ggml_model_path, | ||||
|                         str(token), | ||||
|                         # If this is the first token, let the script create a new state. | ||||
|                         '' if i == 0 else state_path, | ||||
|                         state_path, | ||||
|                         logits_path | ||||
|                     ], | ||||
|                     check=True | ||||
|                 ) | ||||
|             logits, state = model.eval(token, state, state, logits) | ||||
| 
 | ||||
|         actual_logits = logits | ||||
| 
 | ||||
|         # --- | ||||
| 
 | ||||
|         expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin' | ||||
| 
 | ||||
|  | @ -104,11 +80,7 @@ def main() -> None: | |||
|         with open(expected_logits_path, 'rb') as logits_file: | ||||
|             expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) | ||||
| 
 | ||||
|         if model is not None: | ||||
|             actual_logits = logits | ||||
|         else: | ||||
|             with open(logits_path, 'rb') as logits_file: | ||||
|                 actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) | ||||
|         # --- | ||||
| 
 | ||||
|         difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item() | ||||
| 
 | ||||
|  | @ -118,8 +90,6 @@ def main() -> None: | |||
| 
 | ||||
|         assert abs(difference) <= threshold, 'Difference is too big' | ||||
| 
 | ||||
|     # Check small token amount first to avoid waiting too long before seeing that model is broken | ||||
|     compare_logits(tokens[:4]) | ||||
|     compare_logits(tokens) | ||||
| 
 | ||||
|     print() | ||||
|  |  | |||
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							|  | @ -1,12 +1,11 @@ | |||
| # Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1. | ||||
| # Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3 | ||||
| 
 | ||||
| import ctypes | ||||
| import argparse | ||||
| import rwkv_cpp_shared_library | ||||
| 
 | ||||
| def parse_args(): | ||||
|     parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1') | ||||
|     parser.add_argument('shared_library_path', help='Path to rwkv.cpp shared library') | ||||
|     parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file') | ||||
|     parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten') | ||||
|     parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3) | ||||
|  | @ -15,19 +14,14 @@ def parse_args(): | |||
| def main() -> None: | ||||
|     args = parse_args() | ||||
| 
 | ||||
|     library = ctypes.cdll.LoadLibrary(args.shared_library_path) | ||||
|     library = rwkv_cpp_shared_library.load_rwkv_shared_library() | ||||
| 
 | ||||
|     library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_int] | ||||
|     library.rwkv_quantize_model_file.restype = ctypes.c_bool | ||||
| 
 | ||||
|     result: bool = library.rwkv_quantize_model_file( | ||||
|         args.src_path.encode('utf-8'), | ||||
|         args.dest_path.encode('utf-8'), | ||||
|         ctypes.c_int(args.data_type) | ||||
|     library.rwkv_quantize_model_file( | ||||
|         args.src_path, | ||||
|         args.dest_path, | ||||
|         args.data_type | ||||
|     ) | ||||
| 
 | ||||
|     assert result, 'Failed to quantize, check stderr' | ||||
| 
 | ||||
|     print('Done') | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|  |  | |||
							
								
								
									
										127
									
								
								rwkv/rwkv_cpp.py
								
								
								
								
							
							
						
						
									
										127
									
								
								rwkv/rwkv_cpp.py
								
								
								
								
							|  | @ -1,127 +0,0 @@ | |||
| import os | ||||
| import ctypes | ||||
| import torch | ||||
| import multiprocessing | ||||
| from typing import Tuple, Optional | ||||
| 
 | ||||
| P_FLOAT = ctypes.POINTER(ctypes.c_float) | ||||
| 
 | ||||
| class RWKVModel: | ||||
|     """ | ||||
|     PyTorch wrapper around rwkv.cpp shared library. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|             self, | ||||
|             shared_library_path: str, | ||||
|             model_path: str, | ||||
|             thread_count: int = max(1, multiprocessing.cpu_count() // 2) | ||||
|     ): | ||||
|         """ | ||||
|         Loads the model and prepares it for inference. | ||||
|         In case of any error, this method will throw an exception. | ||||
| 
 | ||||
|         Parameters | ||||
|         ---------- | ||||
|         shared_library_path : str | ||||
|             Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'. | ||||
|         model_path : str | ||||
|             Path to RWKV model file in ggml format. | ||||
|         thread_count : int | ||||
|             Thread count to use. If not set, defaults to CPU count / 2. | ||||
|         """ | ||||
| 
 | ||||
|         assert os.path.isfile(shared_library_path), f'{shared_library_path} is not a file' | ||||
|         assert os.path.isfile(model_path), f'{model_path} is not a file' | ||||
|         assert thread_count > 0, 'Thread count must be positive' | ||||
| 
 | ||||
|         self.library = ctypes.cdll.LoadLibrary(shared_library_path) | ||||
| 
 | ||||
|         self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int] | ||||
|         self.library.rwkv_init_from_file.restype = ctypes.c_void_p | ||||
| 
 | ||||
|         self.library.rwkv_eval.argtypes = [ | ||||
|             ctypes.c_void_p, # ctx | ||||
|             ctypes.c_long, # token | ||||
|             P_FLOAT, # state_in | ||||
|             P_FLOAT, # state_out | ||||
|             P_FLOAT  # logits_out | ||||
|         ] | ||||
|         self.library.rwkv_eval.restype = ctypes.c_bool | ||||
| 
 | ||||
|         self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p] | ||||
|         self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_size_t | ||||
| 
 | ||||
|         self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p] | ||||
|         self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_size_t | ||||
| 
 | ||||
|         self.library.rwkv_free.argtypes = [ctypes.c_void_p] | ||||
|         self.library.rwkv_free.restype = None | ||||
| 
 | ||||
|         self.ctx = self.library.rwkv_init_from_file(model_path.encode('utf-8'), ctypes.c_int(thread_count)) | ||||
| 
 | ||||
|         assert self.ctx is not None, 'Failed to load the model, see stderr' | ||||
| 
 | ||||
|         self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx) | ||||
|         self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx) | ||||
| 
 | ||||
|         self.valid = True | ||||
| 
 | ||||
|     def eval(self, token: int, state_in: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: | ||||
|         """ | ||||
|         Evaluates the model for a single token. | ||||
|         In case of any error, this method will throw an exception. | ||||
| 
 | ||||
|         Parameters | ||||
|         ---------- | ||||
|         token : int | ||||
|             Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab. | ||||
|         state_in : Optional[torch.Tensor] | ||||
|             State from previous call of this method. If this is a first pass, set it to None. | ||||
| 
 | ||||
|         Returns | ||||
|         ------- | ||||
|         logits, state | ||||
|             Logits vector of shape (n_vocab); state for the next step. | ||||
|         """ | ||||
| 
 | ||||
|         assert self.valid, 'Model was freed' | ||||
| 
 | ||||
|         if state_in is None: | ||||
|             state_in_ptr = 0 | ||||
|         else: | ||||
|             expected_shape = (self.state_buffer_element_count,) | ||||
| 
 | ||||
|             assert state_in.is_contiguous(), 'State tensor is not contiguous' | ||||
|             assert state_in.shape == expected_shape, f'Invalid state shape {state_in.shape}, expected {expected_shape}' | ||||
| 
 | ||||
|             state_in_ptr = state_in.storage().data_ptr() | ||||
| 
 | ||||
|         # TODO Probably these allocations can be optimized away | ||||
|         state_out: torch.Tensor = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu') | ||||
|         logits_out: torch.Tensor = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu') | ||||
| 
 | ||||
|         result = self.library.rwkv_eval( | ||||
|             self.ctx, | ||||
|             ctypes.c_long(token), | ||||
|             ctypes.cast(state_in_ptr, P_FLOAT), | ||||
|             ctypes.cast(state_out.storage().data_ptr(), P_FLOAT), | ||||
|             ctypes.cast(logits_out.storage().data_ptr(), P_FLOAT) | ||||
|         ) | ||||
| 
 | ||||
|         assert result, 'Inference failed, see stderr' | ||||
| 
 | ||||
|         return logits_out, state_out | ||||
| 
 | ||||
|     def free(self): | ||||
|         """ | ||||
|         Frees all allocated resources. | ||||
|         In case of any error, this method will throw an exception. | ||||
|         The object must not be used anymore after calling this method. | ||||
|         """ | ||||
| 
 | ||||
|         assert self.valid, 'Already freed' | ||||
| 
 | ||||
|         self.valid = False | ||||
| 
 | ||||
|         self.library.rwkv_free(self.ctx) | ||||
|  | @ -0,0 +1,117 @@ | |||
| import os | ||||
| import torch | ||||
| import multiprocessing | ||||
| import rwkv_cpp_shared_library | ||||
| from typing import Tuple, Optional | ||||
| 
 | ||||
| class RWKVModel: | ||||
|     """ | ||||
|     PyTorch wrapper around rwkv.cpp model. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|             self, | ||||
|             shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary, | ||||
|             model_path: str, | ||||
|             thread_count: int = max(1, multiprocessing.cpu_count() // 2) | ||||
|     ): | ||||
|         """ | ||||
|         Loads the model and prepares it for inference. | ||||
|         In case of any error, this method will throw an exception. | ||||
| 
 | ||||
|         Parameters | ||||
|         ---------- | ||||
|         shared_library : RWKVSharedLibrary | ||||
|             rwkv.cpp shared library. | ||||
|         model_path : str | ||||
|             Path to RWKV model file in ggml format. | ||||
|         thread_count : int | ||||
|             Thread count to use. If not set, defaults to CPU count / 2. | ||||
|         """ | ||||
| 
 | ||||
|         assert os.path.isfile(model_path), f'{model_path} is not a file' | ||||
|         assert thread_count > 0, 'Thread count must be positive' | ||||
| 
 | ||||
|         self.library = shared_library | ||||
| 
 | ||||
|         self.ctx = self.library.rwkv_init_from_file(model_path, thread_count) | ||||
| 
 | ||||
|         self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx) | ||||
|         self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx) | ||||
| 
 | ||||
|         self.valid = True | ||||
| 
 | ||||
|     def eval( | ||||
|             self, | ||||
|             token: int, | ||||
|             state_in: Optional[torch.Tensor], | ||||
|             state_out: Optional[torch.Tensor] = None, | ||||
|             logits_out: Optional[torch.Tensor] = None | ||||
|     ) -> Tuple[torch.Tensor, torch.Tensor]: | ||||
|         """ | ||||
|         Evaluates the model for a single token. | ||||
|         In case of any error, this method will throw an exception. | ||||
| 
 | ||||
|         Parameters | ||||
|         ---------- | ||||
|         token : int | ||||
|             Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab. | ||||
|         state_in : Optional[torch.Tensor] | ||||
|             State from previous call of this method. If this is a first pass, set it to None. | ||||
|         state_out : Optional[torch.Tensor] | ||||
|             Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count). | ||||
|         logits_out : Optional[torch.Tensor] | ||||
|             Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count). | ||||
| 
 | ||||
|         Returns | ||||
|         ------- | ||||
|         logits, state | ||||
|             Logits vector of shape (n_vocab); state for the next step. | ||||
|         """ | ||||
| 
 | ||||
|         assert self.valid, 'Model was freed' | ||||
| 
 | ||||
|         def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None: | ||||
|             assert buf.dtype == torch.float32, f'{name} is not of type float32' | ||||
|             assert buf.is_contiguous(), f'{name} is not contiguous' | ||||
|             assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})' | ||||
| 
 | ||||
|         if state_in is not None: | ||||
|             validate_buffer(state_in, 'state_in', self.state_buffer_element_count) | ||||
| 
 | ||||
|             state_in_ptr = state_in.storage().data_ptr() | ||||
|         else: | ||||
|             state_in_ptr = 0 | ||||
| 
 | ||||
|         if state_out is not None: | ||||
|             validate_buffer(state_out, 'state_out', self.state_buffer_element_count) | ||||
|         else: | ||||
|             state_out = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu') | ||||
| 
 | ||||
|         if logits_out is not None: | ||||
|             validate_buffer(logits_out, 'logits_out', self.logits_buffer_element_count) | ||||
|         else: | ||||
|             logits_out = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu') | ||||
| 
 | ||||
|         self.library.rwkv_eval( | ||||
|             self.ctx, | ||||
|             token, | ||||
|             state_in_ptr, | ||||
|             state_out.storage().data_ptr(), | ||||
|             logits_out.storage().data_ptr() | ||||
|         ) | ||||
| 
 | ||||
|         return logits_out, state_out | ||||
| 
 | ||||
|     def free(self): | ||||
|         """ | ||||
|         Frees all allocated resources. | ||||
|         In case of any error, this method will throw an exception. | ||||
|         The object must not be used anymore after calling this method. | ||||
|         """ | ||||
| 
 | ||||
|         assert self.valid, 'Already freed' | ||||
| 
 | ||||
|         self.valid = False | ||||
| 
 | ||||
|         self.library.rwkv_free(self.ctx) | ||||
|  | @ -0,0 +1,204 @@ | |||
| import os | ||||
| import sys | ||||
| import ctypes | ||||
| from typing import Optional | ||||
| 
 | ||||
| P_FLOAT = ctypes.POINTER(ctypes.c_float) | ||||
| 
 | ||||
| class RWKVContext: | ||||
| 
 | ||||
|     def __init__(self, ptr: ctypes.pointer): | ||||
|         self.ptr = ptr | ||||
| 
 | ||||
| class RWKVSharedLibrary: | ||||
|     """ | ||||
|     Python wrapper around rwkv.cpp shared library. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, shared_library_path: str): | ||||
|         """ | ||||
|         Loads the shared library from specified file. | ||||
|         In case of any error, this method will throw an exception. | ||||
| 
 | ||||
|         Parameters | ||||
|         ---------- | ||||
|         shared_library_path : str | ||||
|             Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'. | ||||
|         """ | ||||
| 
 | ||||
|         self.library = ctypes.cdll.LoadLibrary(shared_library_path) | ||||
| 
 | ||||
|         self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32] | ||||
|         self.library.rwkv_init_from_file.restype = ctypes.c_void_p | ||||
| 
 | ||||
|         self.library.rwkv_eval.argtypes = [ | ||||
|             ctypes.c_void_p, # ctx | ||||
|             ctypes.c_int32, # token | ||||
|             P_FLOAT, # state_in | ||||
|             P_FLOAT, # state_out | ||||
|             P_FLOAT  # logits_out | ||||
|         ] | ||||
|         self.library.rwkv_eval.restype = ctypes.c_bool | ||||
| 
 | ||||
|         self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p] | ||||
|         self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32 | ||||
| 
 | ||||
|         self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p] | ||||
|         self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32 | ||||
| 
 | ||||
|         self.library.rwkv_free.argtypes = [ctypes.c_void_p] | ||||
|         self.library.rwkv_free.restype = None | ||||
| 
 | ||||
|         self.library.rwkv_free.argtypes = [ctypes.c_void_p] | ||||
|         self.library.rwkv_free.restype = None | ||||
| 
 | ||||
|         self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_uint32] | ||||
|         self.library.rwkv_quantize_model_file.restype = ctypes.c_bool | ||||
| 
 | ||||
|         self.library.rwkv_get_system_info_string.argtypes = [] | ||||
|         self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p | ||||
| 
 | ||||
|     def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext: | ||||
|         """ | ||||
|         Loads the model from a file and prepares it for inference. | ||||
|         Throws an exception in case of any error. Error messages would be printed to stderr. | ||||
| 
 | ||||
|         Parameters | ||||
|         ---------- | ||||
|         model_file_path : str | ||||
|             Path to model file in ggml format. | ||||
|         thread_count : int | ||||
|             Count of threads to use, must be positive. | ||||
|         """ | ||||
| 
 | ||||
|         ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count)) | ||||
|         assert ptr is not None, 'rwkv_init_from_file failed, check stderr' | ||||
|         return RWKVContext(ptr) | ||||
| 
 | ||||
|     def rwkv_eval( | ||||
|             self, | ||||
|             ctx: RWKVContext, | ||||
|             token: int, | ||||
|             state_in_address: Optional[int], | ||||
|             state_out_address: int, | ||||
|             logits_out_address: int | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Evaluates the model for a single token. | ||||
|         Throws an exception in case of any error. Error messages would be printed to stderr. | ||||
| 
 | ||||
|         Parameters | ||||
|         ---------- | ||||
|         ctx : RWKVContext | ||||
|             RWKV context obtained from rwkv_init_from_file. | ||||
|         token : int | ||||
|             Next token index, in range 0 <= token < n_vocab. | ||||
|         state_in_address : int | ||||
|             Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass. | ||||
|         state_out_address : int | ||||
|             Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to. | ||||
|         logits_out_address : int | ||||
|             Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. | ||||
|         """ | ||||
| 
 | ||||
|         assert self.library.rwkv_eval( | ||||
|             ctx.ptr, | ||||
|             ctypes.c_int32(token), | ||||
|             ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), | ||||
|             ctypes.cast(state_out_address, P_FLOAT), | ||||
|             ctypes.cast(logits_out_address, P_FLOAT) | ||||
|         ), 'rwkv_eval failed, check stderr' | ||||
| 
 | ||||
|     def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int: | ||||
|         """ | ||||
|         Returns count of FP32 elements in state buffer. | ||||
| 
 | ||||
|         Parameters | ||||
|         ---------- | ||||
|         ctx : RWKVContext | ||||
|             RWKV context obtained from rwkv_init_from_file. | ||||
|         """ | ||||
| 
 | ||||
|         return self.library.rwkv_get_state_buffer_element_count(ctx.ptr) | ||||
| 
 | ||||
|     def rwkv_get_logits_buffer_element_count(self, ctx: RWKVContext) -> int: | ||||
|         """ | ||||
|         Returns count of FP32 elements in logits buffer. | ||||
| 
 | ||||
|         Parameters | ||||
|         ---------- | ||||
|         ctx : RWKVContext | ||||
|             RWKV context obtained from rwkv_init_from_file. | ||||
|         """ | ||||
| 
 | ||||
|         return self.library.rwkv_get_logits_buffer_element_count(ctx.ptr) | ||||
| 
 | ||||
|     def rwkv_free(self, ctx: RWKVContext) -> None: | ||||
|         """ | ||||
|         Frees all allocated memory and the context. | ||||
| 
 | ||||
|         Parameters | ||||
|         ---------- | ||||
|         ctx : RWKVContext | ||||
|             RWKV context obtained from rwkv_init_from_file. | ||||
|         """ | ||||
| 
 | ||||
|         self.library.rwkv_free(ctx.ptr) | ||||
| 
 | ||||
|         ctx.ptr = ctypes.cast(0, ctypes.c_void_p) | ||||
| 
 | ||||
|     def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, q_type: int) -> None: | ||||
|         """ | ||||
|         Quantizes FP32 or FP16 model to one of INT4 formats. | ||||
|         Throws an exception in case of any error. Error messages would be printed to stderr. | ||||
| 
 | ||||
|         Parameters | ||||
|         ---------- | ||||
|         model_file_path_in : str | ||||
|             Path to model file in ggml format, must be either FP32 or FP16. | ||||
|         model_file_path_out : str | ||||
|             Quantized model will be written here. | ||||
|         q_type : int | ||||
|             Set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1. | ||||
|         """ | ||||
| 
 | ||||
|         assert self.library.rwkv_quantize_model_file( | ||||
|             model_file_path_in.encode('utf-8'), | ||||
|             model_file_path_out.encode('utf-8'), | ||||
|             ctypes.c_uint32(q_type) | ||||
|         ), 'rwkv_quantize_model_file failed, check stderr' | ||||
| 
 | ||||
|     def rwkv_get_system_info_string(self) -> str: | ||||
|         """ | ||||
|         Returns system information string. | ||||
|         """ | ||||
| 
 | ||||
|         return self.library.rwkv_get_system_info_string() | ||||
| 
 | ||||
| def load_rwkv_shared_library() -> RWKVSharedLibrary: | ||||
|     """ | ||||
|     Attempts to find rwkv.cpp shared library and load it. | ||||
|     To specify exact path to the library, create an instance of RWKVSharedLibrary explicitly. | ||||
|     """ | ||||
| 
 | ||||
|     file_name: str | ||||
| 
 | ||||
|     if 'win32' in sys.platform or 'cygwin' in sys.platform: | ||||
|         file_name = 'rwkv.dll' | ||||
|     else: | ||||
|         file_name = 'rwkv.so' | ||||
| 
 | ||||
|     paths = [ | ||||
|         # If we are in "rwkv" directory | ||||
|         f'../bin/Release/{file_name}', | ||||
|         # If we are in repo root directory | ||||
|         f'bin/Release/{file_name}', | ||||
|         # Fallback | ||||
|         file_name | ||||
|     ] | ||||
| 
 | ||||
|     for path in paths: | ||||
|         if os.path.isfile(path): | ||||
|             return RWKVSharedLibrary(path) | ||||
| 
 | ||||
|     return RWKVSharedLibrary(paths[-1]) | ||||
		Loading…
	
		Reference in New Issue