Move library wrapper to separate file, refactor code
This commit is contained in:
parent
38f9d02d52
commit
935d16f5db
12
README.md
12
README.md
|
@ -47,10 +47,14 @@ python rwkv\convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pt
|
|||
#### 3. Use the model in Python:
|
||||
|
||||
```python
|
||||
# This file is located at rwkv/rwkv_cpp.py
|
||||
import rwkv_cpp
|
||||
# These files are located in rwkv directory
|
||||
import rwkv_cpp_model
|
||||
import rwkv_cpp_shared_library
|
||||
|
||||
model = rwkv_cpp.RWKVModel(r'bin\Release\rwkv.dll', r'C:\rwkv.cpp-169M.bin')
|
||||
model = rwkv_cpp_model.RWKVModel(
|
||||
rwkv_cpp_shared_library.load_rwkv_shared_library(),
|
||||
r'C:\rwkv.cpp-169M.bin'
|
||||
)
|
||||
|
||||
logits, state = None, None
|
||||
|
||||
|
@ -59,7 +63,7 @@ for token in [1, 2, 3]:
|
|||
|
||||
print(f'Output logits: {logits}')
|
||||
|
||||
# Don't forget to free memory after you've done working with the model
|
||||
# Don't forget to free the memory after you've done working with the model
|
||||
model.free()
|
||||
|
||||
```
|
||||
|
|
10
rwkv.cpp
10
rwkv.cpp
|
@ -163,7 +163,7 @@ struct rwkv_context {
|
|||
bool freed;
|
||||
};
|
||||
|
||||
struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads) {
|
||||
struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_threads) {
|
||||
FILE * file = fopen(file_path, "rb");
|
||||
RWKV_ASSERT_NULL(file != NULL, "Failed to open file %s", file_path);
|
||||
|
||||
|
@ -505,15 +505,15 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
|
|||
return rwkv_ctx;
|
||||
}
|
||||
|
||||
size_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx) {
|
||||
uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx) {
|
||||
return ctx->model->n_layer * 5 * ctx->model->n_embed;
|
||||
}
|
||||
|
||||
size_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx) {
|
||||
uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx) {
|
||||
return ctx->model->n_vocab;
|
||||
}
|
||||
|
||||
bool rwkv_eval(struct rwkv_context * ctx, long int token, float * state_in, float * state_out, float * logits_out) {
|
||||
bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out) {
|
||||
RWKV_ASSERT_FALSE(state_out != NULL, "state_out is NULL");
|
||||
RWKV_ASSERT_FALSE(logits_out != NULL, "logits_out is NULL");
|
||||
|
||||
|
@ -564,7 +564,7 @@ void rwkv_free(struct rwkv_context * ctx) {
|
|||
delete ctx;
|
||||
}
|
||||
|
||||
bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type) {
|
||||
bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
|
||||
RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type);
|
||||
|
||||
ggml_type type;
|
||||
|
|
10
rwkv.h
10
rwkv.h
|
@ -33,7 +33,7 @@ extern "C" {
|
|||
// Returns NULL on any error. Error messages would be printed to stderr.
|
||||
// - model_file_path: path to model file in ggml format.
|
||||
// - n_threads: count of threads to use, must be positive.
|
||||
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads);
|
||||
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, uint32_t n_threads);
|
||||
|
||||
// Evaluates the model for a single token.
|
||||
// Returns false on any error. Error messages would be printed to stderr.
|
||||
|
@ -41,13 +41,13 @@ extern "C" {
|
|||
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
|
||||
// - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
|
||||
// - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
||||
RWKV_API bool rwkv_eval(struct rwkv_context * ctx, long int token, float * state_in, float * state_out, float * logits_out);
|
||||
RWKV_API bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out);
|
||||
|
||||
// Returns count of FP32 elements in state buffer.
|
||||
RWKV_API size_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx);
|
||||
RWKV_API uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx);
|
||||
|
||||
// Returns count of FP32 elements in logits buffer.
|
||||
RWKV_API size_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx);
|
||||
RWKV_API uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx);
|
||||
|
||||
// Frees all allocated memory and the context.
|
||||
RWKV_API void rwkv_free(struct rwkv_context * ctx);
|
||||
|
@ -57,7 +57,7 @@ extern "C" {
|
|||
// - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
|
||||
// - model_file_path_out: quantized model will be written here.
|
||||
// - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
|
||||
RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type);
|
||||
RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type);
|
||||
|
||||
// Returns system information string.
|
||||
RWKV_API const char * rwkv_get_system_info_string(void);
|
||||
|
|
|
@ -1,20 +1,19 @@
|
|||
# Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
|
||||
# Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
|
||||
# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
|
||||
# Usage: python compare_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
|
||||
# Usage: python compare_with_reference_implementation.py C:\rwkv.cpp-169M.bin
|
||||
|
||||
import os
|
||||
import struct
|
||||
import argparse
|
||||
import subprocess
|
||||
import torch
|
||||
import numpy as np
|
||||
import rwkv_cpp
|
||||
import rwkv_cpp_model
|
||||
import rwkv_cpp_shared_library
|
||||
from typing import List, Tuple, Any
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
|
||||
parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file or shared library')
|
||||
parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
|
||||
return parser.parse_args()
|
||||
|
||||
|
@ -22,17 +21,12 @@ def main() -> None:
|
|||
args = parse_args()
|
||||
|
||||
# Don't want to depend on tokenizer here.
|
||||
# Exact string is:
|
||||
# context = "1 In the beginning God created the heaven and the earth. " \
|
||||
# "2 And the earth was without form, and void; and darkness was upon the face of the deep. And the Spirit of God moved upon the face of the waters. " \
|
||||
# "3 And God said, Let there be light: and there was light. " \
|
||||
# "4 And God saw the light, that it was good: and God divided the light from the darkness."
|
||||
# The Bible was the first non-copyrighted public domain text that came to my mind.
|
||||
tokens: List[int] = [18, 496, 253, 5068, 2656, 3562, 253, 13926, 285, 253, 6149, 15, 374, 1244, 253, 6149, 369, 1293, 830,
|
||||
13, 285, 2991, 28, 285, 13862, 369, 2220, 253, 2454, 273, 253, 3676, 15, 1244, 253, 14559, 273, 2656,
|
||||
4395, 2220, 253, 2454, 273, 253, 12685, 15, 495, 1244, 2656, 753, 13, 1281, 627, 320, 1708, 27, 285,
|
||||
627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272,
|
||||
253, 1708, 432, 253, 13862, 15]
|
||||
tokens: List[int] = [4, 3631, 4420, 2412, 953, 432, 391, 30567, 87, 15, 14161, 7092, 273, 416, 27767, 55, 342,
|
||||
2412, 953, 432, 3806, 7092, 273, 416, 27767, 55, 15, 187, 4, 19039, 2412, 953, 497, 4561,
|
||||
342, 416, 27767, 55, 14, 21, 14, 49, 587, 14, 17809, 46, 14, 938, 14256, 28950, 14, 1438,
|
||||
1508, 15, 81, 394, 1566, 275, 8462, 22097, 348, 15, 187, 4, 43825, 27, 15548, 7277, 64,
|
||||
3113, 64, 14005, 64, 39595, 15, 4789, 10269, 61, 18992, 61, 7265, 64, 30217, 39297, 15,
|
||||
20963, 330, 27, 190, 30567, 87, 15, 14161, 14, 17809, 46, 15, 4805]
|
||||
|
||||
threshold: float
|
||||
|
||||
|
@ -50,7 +44,7 @@ def main() -> None:
|
|||
threshold = 0.000005
|
||||
elif data_type == 1:
|
||||
# FP16, lower precision, so higher threshold
|
||||
threshold = 0.003
|
||||
threshold = 0.0032
|
||||
elif data_type == 2:
|
||||
# INT4 quantized, even lower precision, so even higher threshold
|
||||
# This threshold will let some bugs pass
|
||||
|
@ -59,42 +53,24 @@ def main() -> None:
|
|||
# This format stores more data, so error would be lower
|
||||
threshold = 1.2
|
||||
|
||||
model = None
|
||||
|
||||
if args.main_executable_path.lower().endswith('.dll') or args.main_executable_path.lower().endswith('.so'):
|
||||
print('Testing shared library')
|
||||
|
||||
model = rwkv_cpp.RWKVModel(args.main_executable_path, args.ggml_model_path)
|
||||
else:
|
||||
print('Testing main_rwkv executable')
|
||||
model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path)
|
||||
|
||||
def compare_logits(tokens_subset: List[int]) -> None:
|
||||
token_count: int = len(tokens_subset)
|
||||
state_path: str = './state.bin'
|
||||
logits_path: str = './logits.bin'
|
||||
|
||||
logits, state = None, None
|
||||
|
||||
for i in range(token_count):
|
||||
token: int = tokens_subset[i]
|
||||
|
||||
print(f'{i + 1}/{token_count}')
|
||||
if token_count <= 10 or i % (token_count // 10) == 0:
|
||||
print(f'{i + 1}/{token_count}')
|
||||
|
||||
if model is not None:
|
||||
logits, state = model.eval(token, state)
|
||||
else:
|
||||
subprocess.run(
|
||||
[
|
||||
args.main_executable_path,
|
||||
args.ggml_model_path,
|
||||
str(token),
|
||||
# If this is the first token, let the script create a new state.
|
||||
'' if i == 0 else state_path,
|
||||
state_path,
|
||||
logits_path
|
||||
],
|
||||
check=True
|
||||
)
|
||||
logits, state = model.eval(token, state, state, logits)
|
||||
|
||||
actual_logits = logits
|
||||
|
||||
# ---
|
||||
|
||||
expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin'
|
||||
|
||||
|
@ -104,11 +80,7 @@ def main() -> None:
|
|||
with open(expected_logits_path, 'rb') as logits_file:
|
||||
expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
|
||||
|
||||
if model is not None:
|
||||
actual_logits = logits
|
||||
else:
|
||||
with open(logits_path, 'rb') as logits_file:
|
||||
actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
|
||||
# ---
|
||||
|
||||
difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
|
||||
|
||||
|
@ -118,8 +90,6 @@ def main() -> None:
|
|||
|
||||
assert abs(difference) <= threshold, 'Difference is too big'
|
||||
|
||||
# Check small token amount first to avoid waiting too long before seeing that model is broken
|
||||
compare_logits(tokens[:4])
|
||||
compare_logits(tokens)
|
||||
|
||||
print()
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,12 +1,11 @@
|
|||
# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1.
|
||||
# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3
|
||||
|
||||
import ctypes
|
||||
import argparse
|
||||
import rwkv_cpp_shared_library
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
|
||||
parser.add_argument('shared_library_path', help='Path to rwkv.cpp shared library')
|
||||
parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
|
||||
parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
|
||||
parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3)
|
||||
|
@ -15,19 +14,14 @@ def parse_args():
|
|||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
library = ctypes.cdll.LoadLibrary(args.shared_library_path)
|
||||
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
|
||||
|
||||
library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_int]
|
||||
library.rwkv_quantize_model_file.restype = ctypes.c_bool
|
||||
|
||||
result: bool = library.rwkv_quantize_model_file(
|
||||
args.src_path.encode('utf-8'),
|
||||
args.dest_path.encode('utf-8'),
|
||||
ctypes.c_int(args.data_type)
|
||||
library.rwkv_quantize_model_file(
|
||||
args.src_path,
|
||||
args.dest_path,
|
||||
args.data_type
|
||||
)
|
||||
|
||||
assert result, 'Failed to quantize, check stderr'
|
||||
|
||||
print('Done')
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
127
rwkv/rwkv_cpp.py
127
rwkv/rwkv_cpp.py
|
@ -1,127 +0,0 @@
|
|||
import os
|
||||
import ctypes
|
||||
import torch
|
||||
import multiprocessing
|
||||
from typing import Tuple, Optional
|
||||
|
||||
P_FLOAT = ctypes.POINTER(ctypes.c_float)
|
||||
|
||||
class RWKVModel:
|
||||
"""
|
||||
PyTorch wrapper around rwkv.cpp shared library.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shared_library_path: str,
|
||||
model_path: str,
|
||||
thread_count: int = max(1, multiprocessing.cpu_count() // 2)
|
||||
):
|
||||
"""
|
||||
Loads the model and prepares it for inference.
|
||||
In case of any error, this method will throw an exception.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shared_library_path : str
|
||||
Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
|
||||
model_path : str
|
||||
Path to RWKV model file in ggml format.
|
||||
thread_count : int
|
||||
Thread count to use. If not set, defaults to CPU count / 2.
|
||||
"""
|
||||
|
||||
assert os.path.isfile(shared_library_path), f'{shared_library_path} is not a file'
|
||||
assert os.path.isfile(model_path), f'{model_path} is not a file'
|
||||
assert thread_count > 0, 'Thread count must be positive'
|
||||
|
||||
self.library = ctypes.cdll.LoadLibrary(shared_library_path)
|
||||
|
||||
self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int]
|
||||
self.library.rwkv_init_from_file.restype = ctypes.c_void_p
|
||||
|
||||
self.library.rwkv_eval.argtypes = [
|
||||
ctypes.c_void_p, # ctx
|
||||
ctypes.c_long, # token
|
||||
P_FLOAT, # state_in
|
||||
P_FLOAT, # state_out
|
||||
P_FLOAT # logits_out
|
||||
]
|
||||
self.library.rwkv_eval.restype = ctypes.c_bool
|
||||
|
||||
self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_size_t
|
||||
|
||||
self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_size_t
|
||||
|
||||
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_free.restype = None
|
||||
|
||||
self.ctx = self.library.rwkv_init_from_file(model_path.encode('utf-8'), ctypes.c_int(thread_count))
|
||||
|
||||
assert self.ctx is not None, 'Failed to load the model, see stderr'
|
||||
|
||||
self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
|
||||
self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)
|
||||
|
||||
self.valid = True
|
||||
|
||||
def eval(self, token: int, state_in: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Evaluates the model for a single token.
|
||||
In case of any error, this method will throw an exception.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
token : int
|
||||
Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
|
||||
state_in : Optional[torch.Tensor]
|
||||
State from previous call of this method. If this is a first pass, set it to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
logits, state
|
||||
Logits vector of shape (n_vocab); state for the next step.
|
||||
"""
|
||||
|
||||
assert self.valid, 'Model was freed'
|
||||
|
||||
if state_in is None:
|
||||
state_in_ptr = 0
|
||||
else:
|
||||
expected_shape = (self.state_buffer_element_count,)
|
||||
|
||||
assert state_in.is_contiguous(), 'State tensor is not contiguous'
|
||||
assert state_in.shape == expected_shape, f'Invalid state shape {state_in.shape}, expected {expected_shape}'
|
||||
|
||||
state_in_ptr = state_in.storage().data_ptr()
|
||||
|
||||
# TODO Probably these allocations can be optimized away
|
||||
state_out: torch.Tensor = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
|
||||
logits_out: torch.Tensor = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')
|
||||
|
||||
result = self.library.rwkv_eval(
|
||||
self.ctx,
|
||||
ctypes.c_long(token),
|
||||
ctypes.cast(state_in_ptr, P_FLOAT),
|
||||
ctypes.cast(state_out.storage().data_ptr(), P_FLOAT),
|
||||
ctypes.cast(logits_out.storage().data_ptr(), P_FLOAT)
|
||||
)
|
||||
|
||||
assert result, 'Inference failed, see stderr'
|
||||
|
||||
return logits_out, state_out
|
||||
|
||||
def free(self):
|
||||
"""
|
||||
Frees all allocated resources.
|
||||
In case of any error, this method will throw an exception.
|
||||
The object must not be used anymore after calling this method.
|
||||
"""
|
||||
|
||||
assert self.valid, 'Already freed'
|
||||
|
||||
self.valid = False
|
||||
|
||||
self.library.rwkv_free(self.ctx)
|
|
@ -0,0 +1,117 @@
|
|||
import os
|
||||
import torch
|
||||
import multiprocessing
|
||||
import rwkv_cpp_shared_library
|
||||
from typing import Tuple, Optional
|
||||
|
||||
class RWKVModel:
|
||||
"""
|
||||
PyTorch wrapper around rwkv.cpp model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
|
||||
model_path: str,
|
||||
thread_count: int = max(1, multiprocessing.cpu_count() // 2)
|
||||
):
|
||||
"""
|
||||
Loads the model and prepares it for inference.
|
||||
In case of any error, this method will throw an exception.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shared_library : RWKVSharedLibrary
|
||||
rwkv.cpp shared library.
|
||||
model_path : str
|
||||
Path to RWKV model file in ggml format.
|
||||
thread_count : int
|
||||
Thread count to use. If not set, defaults to CPU count / 2.
|
||||
"""
|
||||
|
||||
assert os.path.isfile(model_path), f'{model_path} is not a file'
|
||||
assert thread_count > 0, 'Thread count must be positive'
|
||||
|
||||
self.library = shared_library
|
||||
|
||||
self.ctx = self.library.rwkv_init_from_file(model_path, thread_count)
|
||||
|
||||
self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
|
||||
self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)
|
||||
|
||||
self.valid = True
|
||||
|
||||
def eval(
|
||||
self,
|
||||
token: int,
|
||||
state_in: Optional[torch.Tensor],
|
||||
state_out: Optional[torch.Tensor] = None,
|
||||
logits_out: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Evaluates the model for a single token.
|
||||
In case of any error, this method will throw an exception.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
token : int
|
||||
Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
|
||||
state_in : Optional[torch.Tensor]
|
||||
State from previous call of this method. If this is a first pass, set it to None.
|
||||
state_out : Optional[torch.Tensor]
|
||||
Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
|
||||
logits_out : Optional[torch.Tensor]
|
||||
Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
|
||||
|
||||
Returns
|
||||
-------
|
||||
logits, state
|
||||
Logits vector of shape (n_vocab); state for the next step.
|
||||
"""
|
||||
|
||||
assert self.valid, 'Model was freed'
|
||||
|
||||
def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None:
|
||||
assert buf.dtype == torch.float32, f'{name} is not of type float32'
|
||||
assert buf.is_contiguous(), f'{name} is not contiguous'
|
||||
assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})'
|
||||
|
||||
if state_in is not None:
|
||||
validate_buffer(state_in, 'state_in', self.state_buffer_element_count)
|
||||
|
||||
state_in_ptr = state_in.storage().data_ptr()
|
||||
else:
|
||||
state_in_ptr = 0
|
||||
|
||||
if state_out is not None:
|
||||
validate_buffer(state_out, 'state_out', self.state_buffer_element_count)
|
||||
else:
|
||||
state_out = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
|
||||
|
||||
if logits_out is not None:
|
||||
validate_buffer(logits_out, 'logits_out', self.logits_buffer_element_count)
|
||||
else:
|
||||
logits_out = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')
|
||||
|
||||
self.library.rwkv_eval(
|
||||
self.ctx,
|
||||
token,
|
||||
state_in_ptr,
|
||||
state_out.storage().data_ptr(),
|
||||
logits_out.storage().data_ptr()
|
||||
)
|
||||
|
||||
return logits_out, state_out
|
||||
|
||||
def free(self):
|
||||
"""
|
||||
Frees all allocated resources.
|
||||
In case of any error, this method will throw an exception.
|
||||
The object must not be used anymore after calling this method.
|
||||
"""
|
||||
|
||||
assert self.valid, 'Already freed'
|
||||
|
||||
self.valid = False
|
||||
|
||||
self.library.rwkv_free(self.ctx)
|
|
@ -0,0 +1,204 @@
|
|||
import os
|
||||
import sys
|
||||
import ctypes
|
||||
from typing import Optional
|
||||
|
||||
P_FLOAT = ctypes.POINTER(ctypes.c_float)
|
||||
|
||||
class RWKVContext:
|
||||
|
||||
def __init__(self, ptr: ctypes.pointer):
|
||||
self.ptr = ptr
|
||||
|
||||
class RWKVSharedLibrary:
|
||||
"""
|
||||
Python wrapper around rwkv.cpp shared library.
|
||||
"""
|
||||
|
||||
def __init__(self, shared_library_path: str):
|
||||
"""
|
||||
Loads the shared library from specified file.
|
||||
In case of any error, this method will throw an exception.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shared_library_path : str
|
||||
Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
|
||||
"""
|
||||
|
||||
self.library = ctypes.cdll.LoadLibrary(shared_library_path)
|
||||
|
||||
self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
|
||||
self.library.rwkv_init_from_file.restype = ctypes.c_void_p
|
||||
|
||||
self.library.rwkv_eval.argtypes = [
|
||||
ctypes.c_void_p, # ctx
|
||||
ctypes.c_int32, # token
|
||||
P_FLOAT, # state_in
|
||||
P_FLOAT, # state_out
|
||||
P_FLOAT # logits_out
|
||||
]
|
||||
self.library.rwkv_eval.restype = ctypes.c_bool
|
||||
|
||||
self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32
|
||||
|
||||
self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32
|
||||
|
||||
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_free.restype = None
|
||||
|
||||
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_free.restype = None
|
||||
|
||||
self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_uint32]
|
||||
self.library.rwkv_quantize_model_file.restype = ctypes.c_bool
|
||||
|
||||
self.library.rwkv_get_system_info_string.argtypes = []
|
||||
self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p
|
||||
|
||||
def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
|
||||
"""
|
||||
Loads the model from a file and prepares it for inference.
|
||||
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_file_path : str
|
||||
Path to model file in ggml format.
|
||||
thread_count : int
|
||||
Count of threads to use, must be positive.
|
||||
"""
|
||||
|
||||
ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))
|
||||
assert ptr is not None, 'rwkv_init_from_file failed, check stderr'
|
||||
return RWKVContext(ptr)
|
||||
|
||||
def rwkv_eval(
|
||||
self,
|
||||
ctx: RWKVContext,
|
||||
token: int,
|
||||
state_in_address: Optional[int],
|
||||
state_out_address: int,
|
||||
logits_out_address: int
|
||||
) -> None:
|
||||
"""
|
||||
Evaluates the model for a single token.
|
||||
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
token : int
|
||||
Next token index, in range 0 <= token < n_vocab.
|
||||
state_in_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
|
||||
state_out_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
|
||||
logits_out_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
||||
"""
|
||||
|
||||
assert self.library.rwkv_eval(
|
||||
ctx.ptr,
|
||||
ctypes.c_int32(token),
|
||||
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
|
||||
ctypes.cast(state_out_address, P_FLOAT),
|
||||
ctypes.cast(logits_out_address, P_FLOAT)
|
||||
), 'rwkv_eval failed, check stderr'
|
||||
|
||||
def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int:
|
||||
"""
|
||||
Returns count of FP32 elements in state buffer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
"""
|
||||
|
||||
return self.library.rwkv_get_state_buffer_element_count(ctx.ptr)
|
||||
|
||||
def rwkv_get_logits_buffer_element_count(self, ctx: RWKVContext) -> int:
|
||||
"""
|
||||
Returns count of FP32 elements in logits buffer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
"""
|
||||
|
||||
return self.library.rwkv_get_logits_buffer_element_count(ctx.ptr)
|
||||
|
||||
def rwkv_free(self, ctx: RWKVContext) -> None:
|
||||
"""
|
||||
Frees all allocated memory and the context.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
"""
|
||||
|
||||
self.library.rwkv_free(ctx.ptr)
|
||||
|
||||
ctx.ptr = ctypes.cast(0, ctypes.c_void_p)
|
||||
|
||||
def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, q_type: int) -> None:
|
||||
"""
|
||||
Quantizes FP32 or FP16 model to one of INT4 formats.
|
||||
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_file_path_in : str
|
||||
Path to model file in ggml format, must be either FP32 or FP16.
|
||||
model_file_path_out : str
|
||||
Quantized model will be written here.
|
||||
q_type : int
|
||||
Set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
|
||||
"""
|
||||
|
||||
assert self.library.rwkv_quantize_model_file(
|
||||
model_file_path_in.encode('utf-8'),
|
||||
model_file_path_out.encode('utf-8'),
|
||||
ctypes.c_uint32(q_type)
|
||||
), 'rwkv_quantize_model_file failed, check stderr'
|
||||
|
||||
def rwkv_get_system_info_string(self) -> str:
|
||||
"""
|
||||
Returns system information string.
|
||||
"""
|
||||
|
||||
return self.library.rwkv_get_system_info_string()
|
||||
|
||||
def load_rwkv_shared_library() -> RWKVSharedLibrary:
|
||||
"""
|
||||
Attempts to find rwkv.cpp shared library and load it.
|
||||
To specify exact path to the library, create an instance of RWKVSharedLibrary explicitly.
|
||||
"""
|
||||
|
||||
file_name: str
|
||||
|
||||
if 'win32' in sys.platform or 'cygwin' in sys.platform:
|
||||
file_name = 'rwkv.dll'
|
||||
else:
|
||||
file_name = 'rwkv.so'
|
||||
|
||||
paths = [
|
||||
# If we are in "rwkv" directory
|
||||
f'../bin/Release/{file_name}',
|
||||
# If we are in repo root directory
|
||||
f'bin/Release/{file_name}',
|
||||
# Fallback
|
||||
file_name
|
||||
]
|
||||
|
||||
for path in paths:
|
||||
if os.path.isfile(path):
|
||||
return RWKVSharedLibrary(path)
|
||||
|
||||
return RWKVSharedLibrary(paths[-1])
|
Loading…
Reference in New Issue