Move library wrapper to separate file, refactor code

saharNooby 2023-04-02 12:24:40 +04:00
parent 38f9d02d52
commit 935d16f5db
11 changed files with 364 additions and 202 deletions

README.md

@@ -47,10 +47,14 @@ python rwkv\convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pt
 #### 3. Use the model in Python:
 ```python
-# This file is located at rwkv/rwkv_cpp.py
-import rwkv_cpp
+# These files are located in rwkv directory
+import rwkv_cpp_model
+import rwkv_cpp_shared_library
-model = rwkv_cpp.RWKVModel(r'bin\Release\rwkv.dll', r'C:\rwkv.cpp-169M.bin')
+model = rwkv_cpp_model.RWKVModel(
+    rwkv_cpp_shared_library.load_rwkv_shared_library(),
+    r'C:\rwkv.cpp-169M.bin'
+)
 logits, state = None, None
@@ -59,7 +63,7 @@ for token in [1, 2, 3]:
     print(f'Output logits: {logits}')
-# Don't forget to free memory after you've done working with the model
+# Don't forget to free the memory after you've done working with the model
 model.free()
 ```
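For context, a minimal sketch of what a caller might do with the logits the README example returns: greedy decoding via argmax. This is not part of the README or the commit; the model path and token IDs are placeholders, and a real tokenizer (e.g. the GPT-NeoX 20B tokenizer RWKV models use) would be needed to turn IDs into text.

```python
# Hypothetical continuation of the README usage example: greedy decoding
# from the logits returned by RWKVModel.eval().
import torch
import rwkv_cpp_model
import rwkv_cpp_shared_library

model = rwkv_cpp_model.RWKVModel(
    rwkv_cpp_shared_library.load_rwkv_shared_library(),
    r'C:\rwkv.cpp-169M.bin'  # placeholder path
)

logits, state = None, None

for token in [1, 2, 3]:  # placeholder prompt token IDs
    logits, state = model.eval(token, state)

for _ in range(16):  # greedily pick 16 continuation tokens
    next_token = int(torch.argmax(logits).item())
    print(next_token)
    logits, state = model.eval(next_token, state)

model.free()
```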

rwkv.cpp

@@ -163,7 +163,7 @@ struct rwkv_context {
     bool freed;
 };
-struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads) {
+struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_threads) {
     FILE * file = fopen(file_path, "rb");
     RWKV_ASSERT_NULL(file != NULL, "Failed to open file %s", file_path);
@@ -505,15 +505,15 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
     return rwkv_ctx;
 }
-size_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx) {
+uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx) {
     return ctx->model->n_layer * 5 * ctx->model->n_embed;
 }
-size_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx) {
+uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx) {
     return ctx->model->n_vocab;
 }
-bool rwkv_eval(struct rwkv_context * ctx, long int token, float * state_in, float * state_out, float * logits_out) {
+bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out) {
     RWKV_ASSERT_FALSE(state_out != NULL, "state_out is NULL");
     RWKV_ASSERT_FALSE(logits_out != NULL, "logits_out is NULL");
@@ -564,7 +564,7 @@ void rwkv_free(struct rwkv_context * ctx) {
     delete ctx;
 }
-bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type) {
+bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
     RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type);
     ggml_type type;
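The signatures above move from `int`/`long int`/`size_t` to fixed-width `uint32_t`/`int32_t`, which lines up with the ctypes declarations added later in this commit (rwkv_cpp_shared_library.py). A small illustrative sketch of that assumed correspondence, not part of the commit itself:

```python
# Illustrative mapping between the fixed-width C types in rwkv.h/rwkv.cpp and
# the ctypes used by the Python bindings shown later in this commit.
import ctypes

C_TO_CTYPES = {
    'uint32_t': ctypes.c_uint32,                  # n_threads, q_type, element counts
    'int32_t':  ctypes.c_int32,                   # token
    'float *':  ctypes.POINTER(ctypes.c_float),   # state_in / state_out / logits_out
    'char *':   ctypes.c_char_p,                  # file paths, passed as UTF-8 bytes
}
```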

rwkv.h

@@ -33,7 +33,7 @@ extern "C" {
     // Returns NULL on any error. Error messages would be printed to stderr.
     // - model_file_path: path to model file in ggml format.
     // - n_threads: count of threads to use, must be positive.
-    RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads);
+    RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, uint32_t n_threads);
     // Evaluates the model for a single token.
     // Returns false on any error. Error messages would be printed to stderr.
@@ -41,13 +41,13 @@ extern "C" {
     // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
     // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
     // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
-    RWKV_API bool rwkv_eval(struct rwkv_context * ctx, long int token, float * state_in, float * state_out, float * logits_out);
+    RWKV_API bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out);
     // Returns count of FP32 elements in state buffer.
-    RWKV_API size_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx);
+    RWKV_API uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx);
     // Returns count of FP32 elements in logits buffer.
-    RWKV_API size_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx);
+    RWKV_API uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx);
     // Frees all allocated memory and the context.
     RWKV_API void rwkv_free(struct rwkv_context * ctx);
@@ -57,7 +57,7 @@ extern "C" {
     // - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
     // - model_file_path_out: quantized model will be written here.
     // - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
-    RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type);
+    RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type);
     // Returns system information string.
     RWKV_API const char * rwkv_get_system_info_string(void);
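The header documents the full call sequence: init, query the two buffer sizes, eval in a loop (NULL state on the first pass), free. A minimal sketch of that sequence driven directly through ctypes, mirroring the argtypes this commit sets up in rwkv_cpp_shared_library.py; the library name, model path and token IDs are placeholders.

```python
# Raw-ctypes walk through the C API documented in rwkv.h:
# init -> get buffer sizes -> eval loop -> free.
import ctypes

P_FLOAT = ctypes.POINTER(ctypes.c_float)

lib = ctypes.cdll.LoadLibrary('rwkv.dll')  # placeholder; 'rwkv.so' on UNIX
lib.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
lib.rwkv_init_from_file.restype = ctypes.c_void_p
lib.rwkv_eval.argtypes = [ctypes.c_void_p, ctypes.c_int32, P_FLOAT, P_FLOAT, P_FLOAT]
lib.rwkv_eval.restype = ctypes.c_bool
lib.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
lib.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32
lib.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
lib.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32
lib.rwkv_free.argtypes = [ctypes.c_void_p]
lib.rwkv_free.restype = None

ctx = lib.rwkv_init_from_file(b'C:\\rwkv.cpp-169M.bin', 4)  # placeholder path
assert ctx, 'rwkv_init_from_file failed, check stderr'

state = (ctypes.c_float * lib.rwkv_get_state_buffer_element_count(ctx))()
logits = (ctypes.c_float * lib.rwkv_get_logits_buffer_element_count(ctx))()

for i, token in enumerate([1, 2, 3]):  # placeholder token IDs
    # NULL state_in on the first pass, then feed the state back in.
    state_in = ctypes.cast(0, P_FLOAT) if i == 0 else state
    assert lib.rwkv_eval(ctx, token, state_in, state, logits), 'rwkv_eval failed'

lib.rwkv_free(ctx)
```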

compare_with_reference_implementation.py

@@ -1,20 +1,19 @@
 # Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
 # Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
 # Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
-# Usage: python compare_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
+# Usage: python compare_with_reference_implementation.py C:\rwkv.cpp-169M.bin
 import os
 import struct
 import argparse
-import subprocess
 import torch
 import numpy as np
-import rwkv_cpp
+import rwkv_cpp_model
+import rwkv_cpp_shared_library
 from typing import List, Tuple, Any
 def parse_args():
     parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
-    parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file or shared library')
     parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
     return parser.parse_args()
@@ -22,17 +21,12 @@ def main() -> None:
     args = parse_args()
     # Don't want to depend on tokenizer here.
-    # Exact string is:
-    # context = "1 In the beginning God created the heaven and the earth. " \
-    #           "2 And the earth was without form, and void; and darkness was upon the face of the deep. And the Spirit of God moved upon the face of the waters. " \
-    #           "3 And God said, Let there be light: and there was light. " \
-    #           "4 And God saw the light, that it was good: and God divided the light from the darkness."
-    # The Bible was the first non-copyrighted public domain text that came to my mind.
-    tokens: List[int] = [18, 496, 253, 5068, 2656, 3562, 253, 13926, 285, 253, 6149, 15, 374, 1244, 253, 6149, 369, 1293, 830,
-                         13, 285, 2991, 28, 285, 13862, 369, 2220, 253, 2454, 273, 253, 3676, 15, 1244, 253, 14559, 273, 2656,
-                         4395, 2220, 253, 2454, 273, 253, 12685, 15, 495, 1244, 2656, 753, 13, 1281, 627, 320, 1708, 27, 285,
-                         627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272,
-                         253, 1708, 432, 253, 13862, 15]
+    tokens: List[int] = [4, 3631, 4420, 2412, 953, 432, 391, 30567, 87, 15, 14161, 7092, 273, 416, 27767, 55, 342,
+                         2412, 953, 432, 3806, 7092, 273, 416, 27767, 55, 15, 187, 4, 19039, 2412, 953, 497, 4561,
+                         342, 416, 27767, 55, 14, 21, 14, 49, 587, 14, 17809, 46, 14, 938, 14256, 28950, 14, 1438,
+                         1508, 15, 81, 394, 1566, 275, 8462, 22097, 348, 15, 187, 4, 43825, 27, 15548, 7277, 64,
+                         3113, 64, 14005, 64, 39595, 15, 4789, 10269, 61, 18992, 61, 7265, 64, 30217, 39297, 15,
+                         20963, 330, 27, 190, 30567, 87, 15, 14161, 14, 17809, 46, 15, 4805]
     threshold: float
@@ -50,7 +44,7 @@ def main() -> None:
         threshold = 0.000005
     elif data_type == 1:
         # FP16, lower precision, so higher threshold
-        threshold = 0.003
+        threshold = 0.0032
     elif data_type == 2:
         # INT4 quantized, even lower precision, so even higher threshold
         # This threshold will let some bugs pass
@@ -59,42 +53,24 @@ def main() -> None:
         # This format stores more data, so error would be lower
         threshold = 1.2
-    model = None
-    if args.main_executable_path.lower().endswith('.dll') or args.main_executable_path.lower().endswith('.so'):
-        print('Testing shared library')
-        model = rwkv_cpp.RWKVModel(args.main_executable_path, args.ggml_model_path)
-    else:
-        print('Testing main_rwkv executable')
+    model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path)
     def compare_logits(tokens_subset: List[int]) -> None:
         token_count: int = len(tokens_subset)
-        state_path: str = './state.bin'
-        logits_path: str = './logits.bin'
         logits, state = None, None
         for i in range(token_count):
             token: int = tokens_subset[i]
-            print(f'{i + 1}/{token_count}')
+            if token_count <= 10 or i % (token_count // 10) == 0:
+                print(f'{i + 1}/{token_count}')
-            if model is not None:
-                logits, state = model.eval(token, state)
-            else:
-                subprocess.run(
-                    [
-                        args.main_executable_path,
-                        args.ggml_model_path,
-                        str(token),
-                        # If this is the first token, let the script create a new state.
-                        '' if i == 0 else state_path,
-                        state_path,
-                        logits_path
-                    ],
-                    check=True
-                )
+            logits, state = model.eval(token, state, state, logits)
+        actual_logits = logits
+        # ---
         expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin'
@@ -104,11 +80,7 @@ def main() -> None:
         with open(expected_logits_path, 'rb') as logits_file:
             expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
-        if model is not None:
-            actual_logits = logits
-        else:
-            with open(logits_path, 'rb') as logits_file:
-                actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
+        # ---
         difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
@@ -118,8 +90,6 @@ def main() -> None:
         assert abs(difference) <= threshold, 'Difference is too big'
-    # Check small token amount first to avoid waiting too long before seeing that model is broken
-    compare_logits(tokens[:4])
     compare_logits(tokens)
     print()
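The script reads the reference file with `np.frombuffer(..., dtype=np.single)`, i.e. a flat, headerless float32 dump. A hedged sketch of the matching read/write pair, assuming only that layout; how the reference logits themselves are produced (the PyTorch reference model run) is outside the scope of this diff.

```python
# Illustrative helpers for the flat float32 logits files the comparison
# script consumes; the write side is simply the inverse of the read above.
import numpy as np
import torch

def save_expected_logits(logits: torch.Tensor, path: str) -> None:
    # One float32 per vocabulary entry, no header.
    logits.to(torch.float32).numpy().astype(np.single).tofile(path)

def load_expected_logits(path: str) -> torch.Tensor:
    with open(path, 'rb') as logits_file:
        return torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
```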

Binary file not shown.

quantize.py

@@ -1,12 +1,11 @@
 # Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1.
 # Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3
-import ctypes
 import argparse
+import rwkv_cpp_shared_library
 def parse_args():
     parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
-    parser.add_argument('shared_library_path', help='Path to rwkv.cpp shared library')
     parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
     parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
     parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3)
@@ -15,19 +14,14 @@ def parse_args():
 def main() -> None:
     args = parse_args()
-    library = ctypes.cdll.LoadLibrary(args.shared_library_path)
-    library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_int]
-    library.rwkv_quantize_model_file.restype = ctypes.c_bool
-    result: bool = library.rwkv_quantize_model_file(
-        args.src_path.encode('utf-8'),
-        args.dest_path.encode('utf-8'),
-        ctypes.c_int(args.data_type)
+    library = rwkv_cpp_shared_library.load_rwkv_shared_library()
+    library.rwkv_quantize_model_file(
+        args.src_path,
+        args.dest_path,
+        args.data_type
     )
-    assert result, 'Failed to quantize, check stderr'
     print('Done')
 if __name__ == "__main__":
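A sketch of the same flow used programmatically: quantize a checkpoint through the shared-library wrapper, then load the result with the model wrapper as a quick smoke test. Paths are placeholders, and the snippet assumes it runs next to the two modules introduced in this commit (the rwkv directory); it is illustrative, not part of the script.

```python
# Hypothetical helper combining the two wrappers added in this commit.
import rwkv_cpp_model
import rwkv_cpp_shared_library

def quantize_and_check(src_path: str, dest_path: str, data_type: int = 3) -> None:
    library = rwkv_cpp_shared_library.load_rwkv_shared_library()

    # 2 = GGML_TYPE_Q4_0, 3 = GGML_TYPE_Q4_1; the wrapper asserts on failure.
    library.rwkv_quantize_model_file(src_path, dest_path, data_type)

    # Smoke test: the quantized file should load and evaluate one token.
    model = rwkv_cpp_model.RWKVModel(library, dest_path)
    logits, state = model.eval(0, None)
    assert logits is not None
    model.free()

# quantize_and_check(r'C:\rwkv.cpp-169M-float32.bin', r'C:\rwkv.cpp-169M-q4_1.bin', 3)
```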

rwkv/rwkv_cpp.py (deleted)

@@ -1,127 +0,0 @@
import os
import ctypes
import torch
import multiprocessing
from typing import Tuple, Optional

P_FLOAT = ctypes.POINTER(ctypes.c_float)

class RWKVModel:
    """
    PyTorch wrapper around rwkv.cpp shared library.
    """

    def __init__(
            self,
            shared_library_path: str,
            model_path: str,
            thread_count: int = max(1, multiprocessing.cpu_count() // 2)
    ):
        """
        Loads the model and prepares it for inference.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        shared_library_path : str
            Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
        model_path : str
            Path to RWKV model file in ggml format.
        thread_count : int
            Thread count to use. If not set, defaults to CPU count / 2.
        """

        assert os.path.isfile(shared_library_path), f'{shared_library_path} is not a file'
        assert os.path.isfile(model_path), f'{model_path} is not a file'
        assert thread_count > 0, 'Thread count must be positive'

        self.library = ctypes.cdll.LoadLibrary(shared_library_path)

        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int]
        self.library.rwkv_init_from_file.restype = ctypes.c_void_p

        self.library.rwkv_eval.argtypes = [
            ctypes.c_void_p, # ctx
            ctypes.c_long, # token
            P_FLOAT, # state_in
            P_FLOAT, # state_out
            P_FLOAT # logits_out
        ]
        self.library.rwkv_eval.restype = ctypes.c_bool

        self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_size_t

        self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_size_t

        self.library.rwkv_free.argtypes = [ctypes.c_void_p]
        self.library.rwkv_free.restype = None

        self.ctx = self.library.rwkv_init_from_file(model_path.encode('utf-8'), ctypes.c_int(thread_count))

        assert self.ctx is not None, 'Failed to load the model, see stderr'

        self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
        self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)

        self.valid = True

    def eval(self, token: int, state_in: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluates the model for a single token.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        token : int
            Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
        state_in : Optional[torch.Tensor]
            State from previous call of this method. If this is a first pass, set it to None.

        Returns
        -------
        logits, state
            Logits vector of shape (n_vocab); state for the next step.
        """

        assert self.valid, 'Model was freed'

        if state_in is None:
            state_in_ptr = 0
        else:
            expected_shape = (self.state_buffer_element_count,)

            assert state_in.is_contiguous(), 'State tensor is not contiguous'
            assert state_in.shape == expected_shape, f'Invalid state shape {state_in.shape}, expected {expected_shape}'

            state_in_ptr = state_in.storage().data_ptr()

        # TODO Probably these allocations can be optimized away
        state_out: torch.Tensor = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
        logits_out: torch.Tensor = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')

        result = self.library.rwkv_eval(
            self.ctx,
            ctypes.c_long(token),
            ctypes.cast(state_in_ptr, P_FLOAT),
            ctypes.cast(state_out.storage().data_ptr(), P_FLOAT),
            ctypes.cast(logits_out.storage().data_ptr(), P_FLOAT)
        )

        assert result, 'Inference failed, see stderr'

        return logits_out, state_out

    def free(self):
        """
        Frees all allocated resources.
        In case of any error, this method will throw an exception.
        The object must not be used anymore after calling this method.
        """

        assert self.valid, 'Already freed'

        self.valid = False

        self.library.rwkv_free(self.ctx)

rwkv/rwkv_cpp_model.py (new file)

@@ -0,0 +1,117 @@
import os
import torch
import multiprocessing
import rwkv_cpp_shared_library
from typing import Tuple, Optional

class RWKVModel:
    """
    PyTorch wrapper around rwkv.cpp model.
    """

    def __init__(
            self,
            shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
            model_path: str,
            thread_count: int = max(1, multiprocessing.cpu_count() // 2)
    ):
        """
        Loads the model and prepares it for inference.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        shared_library : RWKVSharedLibrary
            rwkv.cpp shared library.
        model_path : str
            Path to RWKV model file in ggml format.
        thread_count : int
            Thread count to use. If not set, defaults to CPU count / 2.
        """

        assert os.path.isfile(model_path), f'{model_path} is not a file'
        assert thread_count > 0, 'Thread count must be positive'

        self.library = shared_library

        self.ctx = self.library.rwkv_init_from_file(model_path, thread_count)

        self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
        self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)

        self.valid = True

    def eval(
            self,
            token: int,
            state_in: Optional[torch.Tensor],
            state_out: Optional[torch.Tensor] = None,
            logits_out: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluates the model for a single token.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        token : int
            Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
        state_in : Optional[torch.Tensor]
            State from previous call of this method. If this is a first pass, set it to None.
        state_out : Optional[torch.Tensor]
            Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
        logits_out : Optional[torch.Tensor]
            Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).

        Returns
        -------
        logits, state
            Logits vector of shape (n_vocab); state for the next step.
        """

        assert self.valid, 'Model was freed'

        def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None:
            assert buf.dtype == torch.float32, f'{name} is not of type float32'
            assert buf.is_contiguous(), f'{name} is not contiguous'
            assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})'

        if state_in is not None:
            validate_buffer(state_in, 'state_in', self.state_buffer_element_count)

            state_in_ptr = state_in.storage().data_ptr()
        else:
            state_in_ptr = 0

        if state_out is not None:
            validate_buffer(state_out, 'state_out', self.state_buffer_element_count)
        else:
            state_out = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')

        if logits_out is not None:
            validate_buffer(logits_out, 'logits_out', self.logits_buffer_element_count)
        else:
            logits_out = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')

        self.library.rwkv_eval(
            self.ctx,
            token,
            state_in_ptr,
            state_out.storage().data_ptr(),
            logits_out.storage().data_ptr()
        )

        return logits_out, state_out

    def free(self):
        """
        Frees all allocated resources.
        In case of any error, this method will throw an exception.
        The object must not be used anymore after calling this method.
        """

        assert self.valid, 'Already freed'

        self.valid = False

        self.library.rwkv_free(self.ctx)
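A short usage sketch of the new optional state_out/logits_out parameters: preallocating the two buffers once and letting eval() write into them in place avoids the per-token allocations the old wrapper made. Model path and token IDs are placeholders.

```python
# Illustrative use of RWKVModel with caller-owned output buffers.
import torch
import rwkv_cpp_model
import rwkv_cpp_shared_library

model = rwkv_cpp_model.RWKVModel(
    rwkv_cpp_shared_library.load_rwkv_shared_library(),
    r'C:\rwkv.cpp-169M.bin'  # placeholder path
)

state = torch.zeros(model.state_buffer_element_count, dtype=torch.float32)
logits = torch.zeros(model.logits_buffer_element_count, dtype=torch.float32)

for i, token in enumerate([1, 2, 3]):  # placeholder token IDs
    # No input state on the first pass; afterwards reuse `state` as both
    # input and output, exactly as the comparison script does.
    logits, state = model.eval(token, None if i == 0 else state, state, logits)

model.free()
```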

rwkv/rwkv_cpp_shared_library.py (new file)

@@ -0,0 +1,204 @@
import os
import sys
import ctypes
from typing import Optional

P_FLOAT = ctypes.POINTER(ctypes.c_float)

class RWKVContext:

    def __init__(self, ptr: ctypes.pointer):
        self.ptr = ptr

class RWKVSharedLibrary:
    """
    Python wrapper around rwkv.cpp shared library.
    """

    def __init__(self, shared_library_path: str):
        """
        Loads the shared library from specified file.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        shared_library_path : str
            Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
        """

        self.library = ctypes.cdll.LoadLibrary(shared_library_path)

        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
        self.library.rwkv_init_from_file.restype = ctypes.c_void_p

        self.library.rwkv_eval.argtypes = [
            ctypes.c_void_p, # ctx
            ctypes.c_int32, # token
            P_FLOAT, # state_in
            P_FLOAT, # state_out
            P_FLOAT # logits_out
        ]
        self.library.rwkv_eval.restype = ctypes.c_bool

        self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32

        self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32

        self.library.rwkv_free.argtypes = [ctypes.c_void_p]
        self.library.rwkv_free.restype = None
        self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_uint32]
        self.library.rwkv_quantize_model_file.restype = ctypes.c_bool

        self.library.rwkv_get_system_info_string.argtypes = []
        self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p

    def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
        """
        Loads the model from a file and prepares it for inference.
        Throws an exception in case of any error. Error messages would be printed to stderr.

        Parameters
        ----------
        model_file_path : str
            Path to model file in ggml format.
        thread_count : int
            Count of threads to use, must be positive.
        """

        ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))

        assert ptr is not None, 'rwkv_init_from_file failed, check stderr'

        return RWKVContext(ptr)

    def rwkv_eval(
            self,
            ctx: RWKVContext,
            token: int,
            state_in_address: Optional[int],
            state_out_address: int,
            logits_out_address: int
    ) -> None:
        """
        Evaluates the model for a single token.
        Throws an exception in case of any error. Error messages would be printed to stderr.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        token : int
            Next token index, in range 0 <= token < n_vocab.
        state_in_address : int
            Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
        state_out_address : int
            Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
        logits_out_address : int
            Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
        """

        assert self.library.rwkv_eval(
            ctx.ptr,
            ctypes.c_int32(token),
            ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
            ctypes.cast(state_out_address, P_FLOAT),
            ctypes.cast(logits_out_address, P_FLOAT)
        ), 'rwkv_eval failed, check stderr'

    def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int:
        """
        Returns count of FP32 elements in state buffer.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        """

        return self.library.rwkv_get_state_buffer_element_count(ctx.ptr)

    def rwkv_get_logits_buffer_element_count(self, ctx: RWKVContext) -> int:
        """
        Returns count of FP32 elements in logits buffer.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        """

        return self.library.rwkv_get_logits_buffer_element_count(ctx.ptr)

    def rwkv_free(self, ctx: RWKVContext) -> None:
        """
        Frees all allocated memory and the context.

        Parameters
        ----------
        ctx : RWKVContext
            RWKV context obtained from rwkv_init_from_file.
        """

        self.library.rwkv_free(ctx.ptr)

        ctx.ptr = ctypes.cast(0, ctypes.c_void_p)

    def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, q_type: int) -> None:
        """
        Quantizes FP32 or FP16 model to one of INT4 formats.
        Throws an exception in case of any error. Error messages would be printed to stderr.

        Parameters
        ----------
        model_file_path_in : str
            Path to model file in ggml format, must be either FP32 or FP16.
        model_file_path_out : str
            Quantized model will be written here.
        q_type : int
            Set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
        """

        assert self.library.rwkv_quantize_model_file(
            model_file_path_in.encode('utf-8'),
            model_file_path_out.encode('utf-8'),
            ctypes.c_uint32(q_type)
        ), 'rwkv_quantize_model_file failed, check stderr'

    def rwkv_get_system_info_string(self) -> str:
        """
        Returns system information string.
        """

        return self.library.rwkv_get_system_info_string()

def load_rwkv_shared_library() -> RWKVSharedLibrary:
    """
    Attempts to find rwkv.cpp shared library and load it.
    To specify exact path to the library, create an instance of RWKVSharedLibrary explicitly.
    """

    file_name: str

    if 'win32' in sys.platform or 'cygwin' in sys.platform:
        file_name = 'rwkv.dll'
    else:
        file_name = 'rwkv.so'

    paths = [
        # If we are in "rwkv" directory
        f'../bin/Release/{file_name}',
        # If we are in repo root directory
        f'bin/Release/{file_name}',
        # Fallback
        file_name
    ]

    for path in paths:
        if os.path.isfile(path):
            return RWKVSharedLibrary(path)

    return RWKVSharedLibrary(paths[-1])
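A usage sketch of the low-level wrapper on its own, without RWKVModel. The addresses passed to rwkv_eval are assumed here to come from contiguous float32 torch tensors, which is how rwkv_cpp_model.py calls it; the model path and token IDs are placeholders.

```python
# Driving RWKVSharedLibrary directly, mirroring what RWKVModel does internally.
import torch
import rwkv_cpp_shared_library

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
ctx = library.rwkv_init_from_file(r'C:\rwkv.cpp-169M.bin', thread_count=4)  # placeholder path

state = torch.zeros(library.rwkv_get_state_buffer_element_count(ctx), dtype=torch.float32)
logits = torch.zeros(library.rwkv_get_logits_buffer_element_count(ctx), dtype=torch.float32)

for i, token in enumerate([1, 2, 3]):  # placeholder token IDs
    library.rwkv_eval(
        ctx,
        token,
        None if i == 0 else state.storage().data_ptr(),  # no input state on first pass
        state.storage().data_ptr(),
        logits.storage().data_ptr()
    )

library.rwkv_free(ctx)
```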