Move library wrapper to separate file, refactor code
This commit is contained in:
		
							parent
							
								
									38f9d02d52
								
							
						
					
					
						commit
						935d16f5db
					
				
							
								
								
									
										12
									
								
								README.md
								
								
								
								
							
							
						
						
									
										12
									
								
								README.md
								
								
								
								
							|  | @ -47,10 +47,14 @@ python rwkv\convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pt | ||||||
| #### 3. Use the model in Python: | #### 3. Use the model in Python: | ||||||
| 
 | 
 | ||||||
| ```python | ```python | ||||||
| # This file is located at rwkv/rwkv_cpp.py | # These files are located in rwkv directory | ||||||
| import rwkv_cpp | import rwkv_cpp_model | ||||||
|  | import rwkv_cpp_shared_library | ||||||
| 
 | 
 | ||||||
| model = rwkv_cpp.RWKVModel(r'bin\Release\rwkv.dll', r'C:\rwkv.cpp-169M.bin') | model = rwkv_cpp_model.RWKVModel( | ||||||
|  |     rwkv_cpp_shared_library.load_rwkv_shared_library(), | ||||||
|  |     r'C:\rwkv.cpp-169M.bin' | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
| logits, state = None, None | logits, state = None, None | ||||||
| 
 | 
 | ||||||
|  | @ -59,7 +63,7 @@ for token in [1, 2, 3]: | ||||||
|      |      | ||||||
|     print(f'Output logits: {logits}') |     print(f'Output logits: {logits}') | ||||||
| 
 | 
 | ||||||
| # Don't forget to free memory after you've done working with the model | # Don't forget to free the memory after you're done working with the model | ||||||
| model.free() | model.free() | ||||||
| 
 | 
 | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
							
								
								
									
										10
									
								
								rwkv.cpp
								
								
								
								
							
							
						
						
									
										10
									
								
								rwkv.cpp
								
								
								
								
							|  | @ -163,7 +163,7 @@ struct rwkv_context { | ||||||
|     bool freed; |     bool freed; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads) { | struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_threads) { | ||||||
|     FILE * file = fopen(file_path, "rb"); |     FILE * file = fopen(file_path, "rb"); | ||||||
|     RWKV_ASSERT_NULL(file != NULL, "Failed to open file %s", file_path); |     RWKV_ASSERT_NULL(file != NULL, "Failed to open file %s", file_path); | ||||||
| 
 | 
 | ||||||
|  | @ -505,15 +505,15 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads) | ||||||
|     return rwkv_ctx; |     return rwkv_ctx; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| size_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx) { | uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx) { | ||||||
|     return ctx->model->n_layer * 5 * ctx->model->n_embed; |     return ctx->model->n_layer * 5 * ctx->model->n_embed; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| size_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx) { | uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx) { | ||||||
|     return ctx->model->n_vocab; |     return ctx->model->n_vocab; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| bool rwkv_eval(struct rwkv_context * ctx, long int token, float * state_in, float * state_out, float * logits_out) { | bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out) { | ||||||
|     RWKV_ASSERT_FALSE(state_out != NULL, "state_out is NULL"); |     RWKV_ASSERT_FALSE(state_out != NULL, "state_out is NULL"); | ||||||
|     RWKV_ASSERT_FALSE(logits_out != NULL, "logits_out is NULL"); |     RWKV_ASSERT_FALSE(logits_out != NULL, "logits_out is NULL"); | ||||||
| 
 | 
 | ||||||
|  | @ -564,7 +564,7 @@ void rwkv_free(struct rwkv_context * ctx) { | ||||||
|     delete ctx; |     delete ctx; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type) { | bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) { | ||||||
|     RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type); |     RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type); | ||||||
| 
 | 
 | ||||||
|     ggml_type type; |     ggml_type type; | ||||||
|  |  | ||||||
							
								
								
									
										10
									
								
								rwkv.h
								
								
								
								
							
							
						
						
									
										10
									
								
								rwkv.h
								
								
								
								
							|  | @ -33,7 +33,7 @@ extern "C" { | ||||||
|     // Returns NULL on any error. Error messages would be printed to stderr.
 |     // Returns NULL on any error. Error messages would be printed to stderr.
 | ||||||
|     // - model_file_path: path to model file in ggml format.
 |     // - model_file_path: path to model file in ggml format.
 | ||||||
|     // - n_threads: count of threads to use, must be positive.
 |     // - n_threads: count of threads to use, must be positive.
 | ||||||
|     RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads); |     RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, uint32_t n_threads); | ||||||
| 
 | 
 | ||||||
|     // Evaluates the model for a single token.
 |     // Evaluates the model for a single token.
 | ||||||
|     // Returns false on any error. Error messages would be printed to stderr.
 |     // Returns false on any error. Error messages would be printed to stderr.
 | ||||||
|  | @ -41,13 +41,13 @@ extern "C" { | ||||||
|     // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
 |     // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
 | ||||||
|     // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
 |     // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
 | ||||||
|     // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
 |     // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
 | ||||||
|     RWKV_API bool rwkv_eval(struct rwkv_context * ctx, long int token, float * state_in, float * state_out, float * logits_out); |     RWKV_API bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out); | ||||||
| 
 | 
 | ||||||
|     // Returns count of FP32 elements in state buffer.
 |     // Returns count of FP32 elements in state buffer.
 | ||||||
|     RWKV_API size_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx); |     RWKV_API uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx); | ||||||
| 
 | 
 | ||||||
|     // Returns count of FP32 elements in logits buffer.
 |     // Returns count of FP32 elements in logits buffer.
 | ||||||
|     RWKV_API size_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx); |     RWKV_API uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx); | ||||||
| 
 | 
 | ||||||
|     // Frees all allocated memory and the context.
 |     // Frees all allocated memory and the context.
 | ||||||
|     RWKV_API void rwkv_free(struct rwkv_context * ctx); |     RWKV_API void rwkv_free(struct rwkv_context * ctx); | ||||||
|  | @ -57,7 +57,7 @@ extern "C" { | ||||||
|     // - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
 |     // - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
 | ||||||
|     // - model_file_path_out: quantized model will be written here.
 |     // - model_file_path_out: quantized model will be written here.
 | ||||||
|     // - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
 |     // - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
 | ||||||
|     RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type); |     RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type); | ||||||
| 
 | 
 | ||||||
|     // Returns system information string.
 |     // Returns system information string.
 | ||||||
|     RWKV_API const char * rwkv_get_system_info_string(void); |     RWKV_API const char * rwkv_get_system_info_string(void); | ||||||
|  |  | ||||||
|  | @ -1,20 +1,19 @@ | ||||||
| # Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV. | # Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV. | ||||||
| # Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch. | # Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch. | ||||||
| # Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py | # Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py | ||||||
| # Usage: python compare_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin | # Usage: python compare_with_reference_implementation.py C:\rwkv.cpp-169M.bin | ||||||
| 
 | 
 | ||||||
| import os | import os | ||||||
| import struct | import struct | ||||||
| import argparse | import argparse | ||||||
| import subprocess |  | ||||||
| import torch | import torch | ||||||
| import numpy as np | import numpy as np | ||||||
| import rwkv_cpp | import rwkv_cpp_model | ||||||
|  | import rwkv_cpp_shared_library | ||||||
| from typing import List, Tuple, Any | from typing import List, Tuple, Any | ||||||
| 
 | 
 | ||||||
| def parse_args(): | def parse_args(): | ||||||
|     parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV') |     parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV') | ||||||
|     parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file or shared library') |  | ||||||
|     parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file') |     parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file') | ||||||
|     return parser.parse_args() |     return parser.parse_args() | ||||||
| 
 | 
 | ||||||
|  | @ -22,17 +21,12 @@ def main() -> None: | ||||||
|     args = parse_args() |     args = parse_args() | ||||||
| 
 | 
 | ||||||
|     # Don't want to depend on tokenizer here. |     # Don't want to depend on tokenizer here. | ||||||
|     # Exact string is: |     tokens: List[int] = [4, 3631, 4420, 2412, 953, 432, 391, 30567, 87, 15, 14161, 7092, 273, 416, 27767, 55, 342, | ||||||
|     # context = "1 In the beginning God created the heaven and the earth. " \ |                          2412, 953, 432, 3806, 7092, 273, 416, 27767, 55, 15, 187, 4, 19039, 2412, 953, 497, 4561, | ||||||
|     #           "2 And the earth was without form, and void; and darkness was upon the face of the deep. And the Spirit of God moved upon the face of the waters. " \ |                          342, 416, 27767, 55, 14, 21, 14, 49, 587, 14, 17809, 46, 14, 938, 14256, 28950, 14, 1438, | ||||||
|     #           "3 And God said, Let there be light: and there was light. " \ |                          1508, 15, 81, 394, 1566, 275, 8462, 22097, 348, 15, 187, 4, 43825, 27, 15548, 7277, 64, | ||||||
|     #           "4 And God saw the light, that it was good: and God divided the light from the darkness." |                          3113, 64, 14005, 64, 39595, 15, 4789, 10269, 61, 18992, 61, 7265, 64, 30217, 39297, 15, | ||||||
|     # The Bible was the first non-copyrighted public domain text that came to my mind. |                          20963, 330, 27, 190, 30567, 87, 15, 14161, 14, 17809, 46, 15, 4805] | ||||||
|     tokens: List[int] = [18, 496, 253, 5068, 2656, 3562, 253, 13926, 285, 253, 6149, 15, 374, 1244, 253, 6149, 369, 1293, 830, |  | ||||||
|                          13, 285, 2991, 28, 285, 13862, 369, 2220, 253, 2454, 273, 253, 3676, 15, 1244, 253, 14559, 273, 2656, |  | ||||||
|                          4395, 2220, 253, 2454, 273, 253, 12685, 15, 495, 1244, 2656, 753, 13, 1281, 627, 320, 1708, 27, 285, |  | ||||||
|                          627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272, |  | ||||||
|                          253, 1708, 432, 253, 13862, 15] |  | ||||||
| 
 | 
 | ||||||
|     threshold: float |     threshold: float | ||||||
| 
 | 
 | ||||||
|  | @ -50,7 +44,7 @@ def main() -> None: | ||||||
|             threshold = 0.000005 |             threshold = 0.000005 | ||||||
|         elif data_type == 1: |         elif data_type == 1: | ||||||
|             # FP16, lower precision, so higher threshold |             # FP16, lower precision, so higher threshold | ||||||
|             threshold = 0.003 |             threshold = 0.0032 | ||||||
|         elif data_type == 2: |         elif data_type == 2: | ||||||
|             # INT4 quantized, even lower precision, so even higher threshold |             # INT4 quantized, even lower precision, so even higher threshold | ||||||
|             # This threshold will let some bugs pass |             # This threshold will let some bugs pass | ||||||
|  | @ -59,42 +53,24 @@ def main() -> None: | ||||||
|             # This format stores more data, so error would be lower |             # This format stores more data, so error would be lower | ||||||
|             threshold = 1.2 |             threshold = 1.2 | ||||||
| 
 | 
 | ||||||
|     model = None |     model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path) | ||||||
| 
 |  | ||||||
|     if args.main_executable_path.lower().endswith('.dll') or args.main_executable_path.lower().endswith('.so'): |  | ||||||
|         print('Testing shared library') |  | ||||||
| 
 |  | ||||||
|         model = rwkv_cpp.RWKVModel(args.main_executable_path, args.ggml_model_path) |  | ||||||
|     else: |  | ||||||
|         print('Testing main_rwkv executable') |  | ||||||
| 
 | 
 | ||||||
|     def compare_logits(tokens_subset: List[int]) -> None: |     def compare_logits(tokens_subset: List[int]) -> None: | ||||||
|         token_count: int = len(tokens_subset) |         token_count: int = len(tokens_subset) | ||||||
|         state_path: str = './state.bin' |  | ||||||
|         logits_path: str = './logits.bin' |  | ||||||
| 
 | 
 | ||||||
|         logits, state = None, None |         logits, state = None, None | ||||||
| 
 | 
 | ||||||
|         for i in range(token_count): |         for i in range(token_count): | ||||||
|             token: int = tokens_subset[i] |             token: int = tokens_subset[i] | ||||||
| 
 | 
 | ||||||
|             print(f'{i + 1}/{token_count}') |             if token_count <= 10 or i % (token_count // 10) == 0: | ||||||
|  |                 print(f'{i + 1}/{token_count}') | ||||||
| 
 | 
 | ||||||
|             if model is not None: |             logits, state = model.eval(token, state, state, logits) | ||||||
|                 logits, state = model.eval(token, state) | 
 | ||||||
|             else: |         actual_logits = logits | ||||||
|                 subprocess.run( | 
 | ||||||
|                     [ |         # --- | ||||||
|                         args.main_executable_path, |  | ||||||
|                         args.ggml_model_path, |  | ||||||
|                         str(token), |  | ||||||
|                         # If this is the first token, let the script create a new state. |  | ||||||
|                         '' if i == 0 else state_path, |  | ||||||
|                         state_path, |  | ||||||
|                         logits_path |  | ||||||
|                     ], |  | ||||||
|                     check=True |  | ||||||
|                 ) |  | ||||||
| 
 | 
 | ||||||
|         expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin' |         expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin' | ||||||
| 
 | 
 | ||||||
|  | @ -104,11 +80,7 @@ def main() -> None: | ||||||
|         with open(expected_logits_path, 'rb') as logits_file: |         with open(expected_logits_path, 'rb') as logits_file: | ||||||
|             expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) |             expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) | ||||||
| 
 | 
 | ||||||
|         if model is not None: |         # --- | ||||||
|             actual_logits = logits |  | ||||||
|         else: |  | ||||||
|             with open(logits_path, 'rb') as logits_file: |  | ||||||
|                 actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) |  | ||||||
| 
 | 
 | ||||||
|         difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item() |         difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item() | ||||||
| 
 | 
 | ||||||
|  | @ -118,8 +90,6 @@ def main() -> None: | ||||||
| 
 | 
 | ||||||
|         assert abs(difference) <= threshold, 'Difference is too big' |         assert abs(difference) <= threshold, 'Difference is too big' | ||||||
| 
 | 
 | ||||||
|     # Check small token amount first to avoid waiting too long before seeing that model is broken |  | ||||||
|     compare_logits(tokens[:4]) |  | ||||||
|     compare_logits(tokens) |     compare_logits(tokens) | ||||||
| 
 | 
 | ||||||
|     print() |     print() | ||||||
|  |  | ||||||
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							|  | @ -1,12 +1,11 @@ | ||||||
| # Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1. | # Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1. | ||||||
| # Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3 | # Usage: python quantize.py C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3 | ||||||
| 
 | 
 | ||||||
| import ctypes |  | ||||||
| import argparse | import argparse | ||||||
|  | import rwkv_cpp_shared_library | ||||||
| 
 | 
 | ||||||
| def parse_args(): | def parse_args(): | ||||||
|     parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1') |     parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1') | ||||||
|     parser.add_argument('shared_library_path', help='Path to rwkv.cpp shared library') |  | ||||||
|     parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file') |     parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file') | ||||||
|     parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten') |     parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten') | ||||||
|     parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3) |     parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3) | ||||||
|  | @ -15,19 +14,14 @@ def parse_args(): | ||||||
| def main() -> None: | def main() -> None: | ||||||
|     args = parse_args() |     args = parse_args() | ||||||
| 
 | 
 | ||||||
|     library = ctypes.cdll.LoadLibrary(args.shared_library_path) |     library = rwkv_cpp_shared_library.load_rwkv_shared_library() | ||||||
| 
 | 
 | ||||||
|     library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_int] |     library.rwkv_quantize_model_file( | ||||||
|     library.rwkv_quantize_model_file.restype = ctypes.c_bool |         args.src_path, | ||||||
| 
 |         args.dest_path, | ||||||
|     result: bool = library.rwkv_quantize_model_file( |         args.data_type | ||||||
|         args.src_path.encode('utf-8'), |  | ||||||
|         args.dest_path.encode('utf-8'), |  | ||||||
|         ctypes.c_int(args.data_type) |  | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     assert result, 'Failed to quantize, check stderr' |  | ||||||
| 
 |  | ||||||
|     print('Done') |     print('Done') | ||||||
| 
 | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|  |  | ||||||
							
								
								
									
										127
									
								
								rwkv/rwkv_cpp.py
								
								
								
								
							
							
						
						
									
										127
									
								
								rwkv/rwkv_cpp.py
								
								
								
								
							|  | @ -1,127 +0,0 @@ | ||||||
| import os |  | ||||||
| import ctypes |  | ||||||
| import torch |  | ||||||
| import multiprocessing |  | ||||||
| from typing import Tuple, Optional |  | ||||||
| 
 |  | ||||||
| P_FLOAT = ctypes.POINTER(ctypes.c_float) |  | ||||||
| 
 |  | ||||||
| class RWKVModel: |  | ||||||
|     """ |  | ||||||
|     PyTorch wrapper around rwkv.cpp shared library. |  | ||||||
|     """ |  | ||||||
| 
 |  | ||||||
|     def __init__( |  | ||||||
|             self, |  | ||||||
|             shared_library_path: str, |  | ||||||
|             model_path: str, |  | ||||||
|             thread_count: int = max(1, multiprocessing.cpu_count() // 2) |  | ||||||
|     ): |  | ||||||
|         """ |  | ||||||
|         Loads the model and prepares it for inference. |  | ||||||
|         In case of any error, this method will throw an exception. |  | ||||||
| 
 |  | ||||||
|         Parameters |  | ||||||
|         ---------- |  | ||||||
|         shared_library_path : str |  | ||||||
|             Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'. |  | ||||||
|         model_path : str |  | ||||||
|             Path to RWKV model file in ggml format. |  | ||||||
|         thread_count : int |  | ||||||
|             Thread count to use. If not set, defaults to CPU count / 2. |  | ||||||
|         """ |  | ||||||
| 
 |  | ||||||
|         assert os.path.isfile(shared_library_path), f'{shared_library_path} is not a file' |  | ||||||
|         assert os.path.isfile(model_path), f'{model_path} is not a file' |  | ||||||
|         assert thread_count > 0, 'Thread count must be positive' |  | ||||||
| 
 |  | ||||||
|         self.library = ctypes.cdll.LoadLibrary(shared_library_path) |  | ||||||
| 
 |  | ||||||
|         self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int] |  | ||||||
|         self.library.rwkv_init_from_file.restype = ctypes.c_void_p |  | ||||||
| 
 |  | ||||||
|         self.library.rwkv_eval.argtypes = [ |  | ||||||
|             ctypes.c_void_p, # ctx |  | ||||||
|             ctypes.c_long, # token |  | ||||||
|             P_FLOAT, # state_in |  | ||||||
|             P_FLOAT, # state_out |  | ||||||
|             P_FLOAT  # logits_out |  | ||||||
|         ] |  | ||||||
|         self.library.rwkv_eval.restype = ctypes.c_bool |  | ||||||
| 
 |  | ||||||
|         self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p] |  | ||||||
|         self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_size_t |  | ||||||
| 
 |  | ||||||
|         self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p] |  | ||||||
|         self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_size_t |  | ||||||
| 
 |  | ||||||
|         self.library.rwkv_free.argtypes = [ctypes.c_void_p] |  | ||||||
|         self.library.rwkv_free.restype = None |  | ||||||
| 
 |  | ||||||
|         self.ctx = self.library.rwkv_init_from_file(model_path.encode('utf-8'), ctypes.c_int(thread_count)) |  | ||||||
| 
 |  | ||||||
|         assert self.ctx is not None, 'Failed to load the model, see stderr' |  | ||||||
| 
 |  | ||||||
|         self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx) |  | ||||||
|         self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx) |  | ||||||
| 
 |  | ||||||
|         self.valid = True |  | ||||||
| 
 |  | ||||||
|     def eval(self, token: int, state_in: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: |  | ||||||
|         """ |  | ||||||
|         Evaluates the model for a single token. |  | ||||||
|         In case of any error, this method will throw an exception. |  | ||||||
| 
 |  | ||||||
|         Parameters |  | ||||||
|         ---------- |  | ||||||
|         token : int |  | ||||||
|             Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab. |  | ||||||
|         state_in : Optional[torch.Tensor] |  | ||||||
|             State from previous call of this method. If this is a first pass, set it to None. |  | ||||||
| 
 |  | ||||||
|         Returns |  | ||||||
|         ------- |  | ||||||
|         logits, state |  | ||||||
|             Logits vector of shape (n_vocab); state for the next step. |  | ||||||
|         """ |  | ||||||
| 
 |  | ||||||
|         assert self.valid, 'Model was freed' |  | ||||||
| 
 |  | ||||||
|         if state_in is None: |  | ||||||
|             state_in_ptr = 0 |  | ||||||
|         else: |  | ||||||
|             expected_shape = (self.state_buffer_element_count,) |  | ||||||
| 
 |  | ||||||
|             assert state_in.is_contiguous(), 'State tensor is not contiguous' |  | ||||||
|             assert state_in.shape == expected_shape, f'Invalid state shape {state_in.shape}, expected {expected_shape}' |  | ||||||
| 
 |  | ||||||
|             state_in_ptr = state_in.storage().data_ptr() |  | ||||||
| 
 |  | ||||||
|         # TODO Probably these allocations can be optimized away |  | ||||||
|         state_out: torch.Tensor = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu') |  | ||||||
|         logits_out: torch.Tensor = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu') |  | ||||||
| 
 |  | ||||||
|         result = self.library.rwkv_eval( |  | ||||||
|             self.ctx, |  | ||||||
|             ctypes.c_long(token), |  | ||||||
|             ctypes.cast(state_in_ptr, P_FLOAT), |  | ||||||
|             ctypes.cast(state_out.storage().data_ptr(), P_FLOAT), |  | ||||||
|             ctypes.cast(logits_out.storage().data_ptr(), P_FLOAT) |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         assert result, 'Inference failed, see stderr' |  | ||||||
| 
 |  | ||||||
|         return logits_out, state_out |  | ||||||
| 
 |  | ||||||
|     def free(self): |  | ||||||
|         """ |  | ||||||
|         Frees all allocated resources. |  | ||||||
|         In case of any error, this method will throw an exception. |  | ||||||
|         The object must not be used anymore after calling this method. |  | ||||||
|         """ |  | ||||||
| 
 |  | ||||||
|         assert self.valid, 'Already freed' |  | ||||||
| 
 |  | ||||||
|         self.valid = False |  | ||||||
| 
 |  | ||||||
|         self.library.rwkv_free(self.ctx) |  | ||||||
|  | @ -0,0 +1,117 @@ | ||||||
|  | import os | ||||||
|  | import torch | ||||||
|  | import multiprocessing | ||||||
|  | import rwkv_cpp_shared_library | ||||||
|  | from typing import Tuple, Optional | ||||||
|  | 
 | ||||||
|  | class RWKVModel: | ||||||
|  |     """ | ||||||
|  |     PyTorch wrapper around rwkv.cpp model. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__( | ||||||
|  |             self, | ||||||
|  |             shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary, | ||||||
|  |             model_path: str, | ||||||
|  |             thread_count: int = max(1, multiprocessing.cpu_count() // 2) | ||||||
|  |     ): | ||||||
|  |         """ | ||||||
|  |         Loads the model and prepares it for inference. | ||||||
|  |         In case of any error, this method will throw an exception. | ||||||
|  | 
 | ||||||
|  |         Parameters | ||||||
|  |         ---------- | ||||||
|  |         shared_library : RWKVSharedLibrary | ||||||
|  |             rwkv.cpp shared library. | ||||||
|  |         model_path : str | ||||||
|  |             Path to RWKV model file in ggml format. | ||||||
|  |         thread_count : int | ||||||
|  |             Thread count to use. If not set, defaults to CPU count / 2. | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         assert os.path.isfile(model_path), f'{model_path} is not a file' | ||||||
|  |         assert thread_count > 0, 'Thread count must be positive' | ||||||
|  | 
 | ||||||
|  |         self.library = shared_library | ||||||
|  | 
 | ||||||
|  |         self.ctx = self.library.rwkv_init_from_file(model_path, thread_count) | ||||||
|  | 
 | ||||||
|  |         self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx) | ||||||
|  |         self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx) | ||||||
|  | 
 | ||||||
|  |         self.valid = True | ||||||
|  | 
 | ||||||
|  |     def eval( | ||||||
|  |             self, | ||||||
|  |             token: int, | ||||||
|  |             state_in: Optional[torch.Tensor], | ||||||
|  |             state_out: Optional[torch.Tensor] = None, | ||||||
|  |             logits_out: Optional[torch.Tensor] = None | ||||||
|  |     ) -> Tuple[torch.Tensor, torch.Tensor]: | ||||||
|  |         """ | ||||||
|  |         Evaluates the model for a single token. | ||||||
|  |         In case of any error, this method will throw an exception. | ||||||
|  | 
 | ||||||
|  |         Parameters | ||||||
|  |         ---------- | ||||||
|  |         token : int | ||||||
|  |             Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab. | ||||||
|  |         state_in : Optional[torch.Tensor] | ||||||
|  |             State from previous call of this method. If this is a first pass, set it to None. | ||||||
|  |         state_out : Optional[torch.Tensor] | ||||||
|  |             Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count). | ||||||
|  |         logits_out : Optional[torch.Tensor] | ||||||
|  |             Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count). | ||||||
|  | 
 | ||||||
|  |         Returns | ||||||
|  |         ------- | ||||||
|  |         logits, state | ||||||
|  |             Logits vector of shape (n_vocab); state for the next step. | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         assert self.valid, 'Model was freed' | ||||||
|  | 
 | ||||||
|  |         def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None: | ||||||
|  |             assert buf.dtype == torch.float32, f'{name} is not of type float32' | ||||||
|  |             assert buf.is_contiguous(), f'{name} is not contiguous' | ||||||
|  |             assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})' | ||||||
|  | 
 | ||||||
|  |         if state_in is not None: | ||||||
|  |             validate_buffer(state_in, 'state_in', self.state_buffer_element_count) | ||||||
|  | 
 | ||||||
|  |             state_in_ptr = state_in.storage().data_ptr() | ||||||
|  |         else: | ||||||
|  |             state_in_ptr = 0 | ||||||
|  | 
 | ||||||
|  |         if state_out is not None: | ||||||
|  |             validate_buffer(state_out, 'state_out', self.state_buffer_element_count) | ||||||
|  |         else: | ||||||
|  |             state_out = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu') | ||||||
|  | 
 | ||||||
|  |         if logits_out is not None: | ||||||
|  |             validate_buffer(logits_out, 'logits_out', self.logits_buffer_element_count) | ||||||
|  |         else: | ||||||
|  |             logits_out = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu') | ||||||
|  | 
 | ||||||
|  |         self.library.rwkv_eval( | ||||||
|  |             self.ctx, | ||||||
|  |             token, | ||||||
|  |             state_in_ptr, | ||||||
|  |             state_out.storage().data_ptr(), | ||||||
|  |             logits_out.storage().data_ptr() | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         return logits_out, state_out | ||||||
|  | 
 | ||||||
|  |     def free(self): | ||||||
|  |         """ | ||||||
|  |         Frees all allocated resources. | ||||||
|  |         In case of any error, this method will throw an exception. | ||||||
|  |         The object must not be used anymore after calling this method. | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         assert self.valid, 'Already freed' | ||||||
|  | 
 | ||||||
|  |         self.valid = False | ||||||
|  | 
 | ||||||
|  |         self.library.rwkv_free(self.ctx) | ||||||
|  | @ -0,0 +1,204 @@ | ||||||
|  | import os | ||||||
|  | import sys | ||||||
|  | import ctypes | ||||||
|  | from typing import Optional | ||||||
|  | 
 | ||||||
# ctypes type for a C `float *`; used for the state and logits buffer arguments of rwkv_eval.
P_FLOAT = ctypes.POINTER(ctypes.c_float)
|  | 
 | ||||||
class RWKVContext:
    """
    Thin holder for the native context handle produced by rwkv_init_from_file.
    """

    def __init__(self, ptr: ctypes.pointer):
        # Stored as a plain mutable attribute: rwkv_free overwrites it with a null pointer.
        self.ptr = ptr
|  | 
 | ||||||
class RWKVSharedLibrary:
    """
    Python wrapper around rwkv.cpp shared library.

    Each public method mirrors one exported C function. Failures reported by the
    library are surfaced as AssertionError; details are printed to stderr by the library.
    """

    def __init__(self, shared_library_path: str):
        """
        Loads the shared library from specified file.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        shared_library_path : str
            Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
        """

        self.library = ctypes.cdll.LoadLibrary(shared_library_path)

        # Declare the C signatures explicitly so ctypes converts arguments and
        # return values correctly instead of assuming C int everywhere.
        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
        self.library.rwkv_init_from_file.restype = ctypes.c_void_p

        self.library.rwkv_eval.argtypes = [
            ctypes.c_void_p, # ctx
            ctypes.c_int32, # token
            P_FLOAT, # state_in
            P_FLOAT, # state_out
            P_FLOAT  # logits_out
        ]
        self.library.rwkv_eval.restype = ctypes.c_bool

        self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32

        self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32

        self.library.rwkv_free.argtypes = [ctypes.c_void_p]
        self.library.rwkv_free.restype = None

        self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_uint32]
        self.library.rwkv_quantize_model_file.restype = ctypes.c_bool

        self.library.rwkv_get_system_info_string.argtypes = []
        self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p

    def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
        """
        Loads the model from a file and prepares it for inference.
        Throws an exception in case of any error. Error messages would be printed to stderr.

        Parameters
        ----------
        model_file_path : str
            Path to model file in ggml format.
        thread_count : int
            Count of threads to use, must be positive.

        Raises
        ------
        AssertionError
            If the library returned a null context.
        """

        ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))

        # ctypes maps a NULL c_void_p result to None. Raise explicitly so the check
        # is not stripped when Python runs with -O.
        if ptr is None:
            raise AssertionError('rwkv_init_from_file failed, check stderr')

        return RWKVContext(ptr)

    def rwkv_eval(
            self,
            ctx: RWKVContext,
            token: int,
            state_in_address: Optional[int],
            state_out_address: int,
            logits_out_address: int
    ) -> None:
        """
        Evaluates the model for a single token.
        Throws an exception in case of any error. Error messages would be printed to stderr.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        token : int
            Next token index, in range 0 <= token < n_vocab.
        state_in_address : Optional[int]
            Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
        state_out_address : int
            Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
        logits_out_address : int
            Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.

        Raises
        ------
        AssertionError
            If the library reported a failure.
        """

        # The call must not live inside an `assert` statement: with -O the whole
        # statement, including the native call, would be removed.
        succeeded = self.library.rwkv_eval(
            ctx.ptr,
            ctypes.c_int32(token),
            ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
            ctypes.cast(state_out_address, P_FLOAT),
            ctypes.cast(logits_out_address, P_FLOAT)
        )

        if not succeeded:
            raise AssertionError('rwkv_eval failed, check stderr')

    def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int:
        """
        Returns count of FP32 elements in state buffer.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        """

        return self.library.rwkv_get_state_buffer_element_count(ctx.ptr)

    def rwkv_get_logits_buffer_element_count(self, ctx: RWKVContext) -> int:
        """
        Returns count of FP32 elements in logits buffer.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        """

        return self.library.rwkv_get_logits_buffer_element_count(ctx.ptr)

    def rwkv_free(self, ctx: RWKVContext) -> None:
        """
        Frees all allocated memory and the context.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        """

        self.library.rwkv_free(ctx.ptr)

        # Replace the stale handle with a null pointer so it is not reused after the free.
        ctx.ptr = ctypes.cast(0, ctypes.c_void_p)

    def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, q_type: int) -> None:
        """
        Quantizes FP32 or FP16 model to one of INT4 formats.
        Throws an exception in case of any error. Error messages would be printed to stderr.

        Parameters
        ----------
        model_file_path_in : str
            Path to model file in ggml format, must be either FP32 or FP16.
        model_file_path_out : str
            Quantized model will be written here.
        q_type : int
            Set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.

        Raises
        ------
        AssertionError
            If the library reported a failure.
        """

        # Same reasoning as in rwkv_eval: keep the native call outside of `assert`.
        succeeded = self.library.rwkv_quantize_model_file(
            model_file_path_in.encode('utf-8'),
            model_file_path_out.encode('utf-8'),
            ctypes.c_uint32(q_type)
        )

        if not succeeded:
            raise AssertionError('rwkv_quantize_model_file failed, check stderr')

    def rwkv_get_system_info_string(self) -> str:
        """
        Returns system information string.
        """

        # A c_char_p restype yields bytes; decode to match the declared str return type.
        return self.library.rwkv_get_system_info_string().decode('utf-8')
|  | 
 | ||||||
def load_rwkv_shared_library() -> RWKVSharedLibrary:
    """
    Locates the rwkv.cpp shared library in a few well-known places and loads it.
    If the library lives elsewhere, construct RWKVSharedLibrary with an explicit path instead.
    """

    on_windows_like = 'win32' in sys.platform or 'cygwin' in sys.platform
    file_name = 'rwkv.dll' if on_windows_like else 'rwkv.so'

    # Candidates are probed in order; all of them are relative to the current working directory.
    candidates = (
        # If we are in "rwkv" directory
        f'../bin/Release/{file_name}',
        # If we are in repo root directory
        f'bin/Release/{file_name}',
        # Fallback
        file_name,
    )

    # First existing file wins; when none exists, the bare file name is still
    # handed to the loader.
    chosen = next(
        (candidate for candidate in candidates if os.path.isfile(candidate)),
        candidates[-1]
    )

    return RWKVSharedLibrary(chosen)
		Loading…
	
		Reference in New Issue