Move library wrapper to separate file, refactor code
This commit is contained in:
parent
38f9d02d52
commit
935d16f5db
12
README.md
12
README.md
|
@ -47,10 +47,14 @@ python rwkv\convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pt
|
||||||
#### 3. Use the model in Python:
|
#### 3. Use the model in Python:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# This file is located at rwkv/rwkv_cpp.py
|
# These files are located in the rwkv directory
|
||||||
import rwkv_cpp
|
import rwkv_cpp_model
|
||||||
|
import rwkv_cpp_shared_library
|
||||||
|
|
||||||
model = rwkv_cpp.RWKVModel(r'bin\Release\rwkv.dll', r'C:\rwkv.cpp-169M.bin')
|
model = rwkv_cpp_model.RWKVModel(
|
||||||
|
rwkv_cpp_shared_library.load_rwkv_shared_library(),
|
||||||
|
r'C:\rwkv.cpp-169M.bin'
|
||||||
|
)
|
||||||
|
|
||||||
logits, state = None, None
|
logits, state = None, None
|
||||||
|
|
||||||
|
@ -59,7 +63,7 @@ for token in [1, 2, 3]:
|
||||||
|
|
||||||
print(f'Output logits: {logits}')
|
print(f'Output logits: {logits}')
|
||||||
|
|
||||||
# Don't forget to free memory after you've done working with the model
|
# Don't forget to free the memory once you're done working with the model
|
||||||
model.free()
|
model.free()
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
10
rwkv.cpp
10
rwkv.cpp
|
@ -163,7 +163,7 @@ struct rwkv_context {
|
||||||
bool freed;
|
bool freed;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads) {
|
struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_threads) {
|
||||||
FILE * file = fopen(file_path, "rb");
|
FILE * file = fopen(file_path, "rb");
|
||||||
RWKV_ASSERT_NULL(file != NULL, "Failed to open file %s", file_path);
|
RWKV_ASSERT_NULL(file != NULL, "Failed to open file %s", file_path);
|
||||||
|
|
||||||
|
@ -505,15 +505,15 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
|
||||||
return rwkv_ctx;
|
return rwkv_ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx) {
|
uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx) {
|
||||||
return ctx->model->n_layer * 5 * ctx->model->n_embed;
|
return ctx->model->n_layer * 5 * ctx->model->n_embed;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx) {
|
uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx) {
|
||||||
return ctx->model->n_vocab;
|
return ctx->model->n_vocab;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool rwkv_eval(struct rwkv_context * ctx, long int token, float * state_in, float * state_out, float * logits_out) {
|
bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out) {
|
||||||
RWKV_ASSERT_FALSE(state_out != NULL, "state_out is NULL");
|
RWKV_ASSERT_FALSE(state_out != NULL, "state_out is NULL");
|
||||||
RWKV_ASSERT_FALSE(logits_out != NULL, "logits_out is NULL");
|
RWKV_ASSERT_FALSE(logits_out != NULL, "logits_out is NULL");
|
||||||
|
|
||||||
|
@ -564,7 +564,7 @@ void rwkv_free(struct rwkv_context * ctx) {
|
||||||
delete ctx;
|
delete ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type) {
|
bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
|
||||||
RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type);
|
RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type);
|
||||||
|
|
||||||
ggml_type type;
|
ggml_type type;
|
||||||
|
|
10
rwkv.h
10
rwkv.h
|
@ -33,7 +33,7 @@ extern "C" {
|
||||||
// Returns NULL on any error. Error messages would be printed to stderr.
|
// Returns NULL on any error. Error messages would be printed to stderr.
|
||||||
// - model_file_path: path to model file in ggml format.
|
// - model_file_path: path to model file in ggml format.
|
||||||
// - n_threads: count of threads to use, must be positive.
|
// - n_threads: count of threads to use, must be positive.
|
||||||
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads);
|
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, uint32_t n_threads);
|
||||||
|
|
||||||
// Evaluates the model for a single token.
|
// Evaluates the model for a single token.
|
||||||
// Returns false on any error. Error messages would be printed to stderr.
|
// Returns false on any error. Error messages would be printed to stderr.
|
||||||
|
@ -41,13 +41,13 @@ extern "C" {
|
||||||
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
|
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
|
||||||
// - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
|
// - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
|
||||||
// - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
// - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
||||||
RWKV_API bool rwkv_eval(struct rwkv_context * ctx, long int token, float * state_in, float * state_out, float * logits_out);
|
RWKV_API bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out);
|
||||||
|
|
||||||
// Returns count of FP32 elements in state buffer.
|
// Returns count of FP32 elements in state buffer.
|
||||||
RWKV_API size_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx);
|
RWKV_API uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx);
|
||||||
|
|
||||||
// Returns count of FP32 elements in logits buffer.
|
// Returns count of FP32 elements in logits buffer.
|
||||||
RWKV_API size_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx);
|
RWKV_API uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx);
|
||||||
|
|
||||||
// Frees all allocated memory and the context.
|
// Frees all allocated memory and the context.
|
||||||
RWKV_API void rwkv_free(struct rwkv_context * ctx);
|
RWKV_API void rwkv_free(struct rwkv_context * ctx);
|
||||||
|
@ -57,7 +57,7 @@ extern "C" {
|
||||||
// - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
|
// - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
|
||||||
// - model_file_path_out: quantized model will be written here.
|
// - model_file_path_out: quantized model will be written here.
|
||||||
// - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
|
// - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
|
||||||
RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type);
|
RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type);
|
||||||
|
|
||||||
// Returns system information string.
|
// Returns system information string.
|
||||||
RWKV_API const char * rwkv_get_system_info_string(void);
|
RWKV_API const char * rwkv_get_system_info_string(void);
|
||||||
|
|
|
@ -1,20 +1,19 @@
|
||||||
# Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
|
# Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
|
||||||
# Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
|
# Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
|
||||||
# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
|
# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
|
||||||
# Usage: python compare_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
|
# Usage: python compare_with_reference_implementation.py C:\rwkv.cpp-169M.bin
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
import argparse
|
import argparse
|
||||||
import subprocess
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import rwkv_cpp
|
import rwkv_cpp_model
|
||||||
|
import rwkv_cpp_shared_library
|
||||||
from typing import List, Tuple, Any
|
from typing import List, Tuple, Any
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
|
parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
|
||||||
parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file or shared library')
|
|
||||||
parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
|
parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
@ -22,17 +21,12 @@ def main() -> None:
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# Don't want to depend on tokenizer here.
|
# Don't want to depend on tokenizer here.
|
||||||
# Exact string is:
|
tokens: List[int] = [4, 3631, 4420, 2412, 953, 432, 391, 30567, 87, 15, 14161, 7092, 273, 416, 27767, 55, 342,
|
||||||
# context = "1 In the beginning God created the heaven and the earth. " \
|
2412, 953, 432, 3806, 7092, 273, 416, 27767, 55, 15, 187, 4, 19039, 2412, 953, 497, 4561,
|
||||||
# "2 And the earth was without form, and void; and darkness was upon the face of the deep. And the Spirit of God moved upon the face of the waters. " \
|
342, 416, 27767, 55, 14, 21, 14, 49, 587, 14, 17809, 46, 14, 938, 14256, 28950, 14, 1438,
|
||||||
# "3 And God said, Let there be light: and there was light. " \
|
1508, 15, 81, 394, 1566, 275, 8462, 22097, 348, 15, 187, 4, 43825, 27, 15548, 7277, 64,
|
||||||
# "4 And God saw the light, that it was good: and God divided the light from the darkness."
|
3113, 64, 14005, 64, 39595, 15, 4789, 10269, 61, 18992, 61, 7265, 64, 30217, 39297, 15,
|
||||||
# The Bible was the first non-copyrighted public domain text that came to my mind.
|
20963, 330, 27, 190, 30567, 87, 15, 14161, 14, 17809, 46, 15, 4805]
|
||||||
tokens: List[int] = [18, 496, 253, 5068, 2656, 3562, 253, 13926, 285, 253, 6149, 15, 374, 1244, 253, 6149, 369, 1293, 830,
|
|
||||||
13, 285, 2991, 28, 285, 13862, 369, 2220, 253, 2454, 273, 253, 3676, 15, 1244, 253, 14559, 273, 2656,
|
|
||||||
4395, 2220, 253, 2454, 273, 253, 12685, 15, 495, 1244, 2656, 753, 13, 1281, 627, 320, 1708, 27, 285,
|
|
||||||
627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272,
|
|
||||||
253, 1708, 432, 253, 13862, 15]
|
|
||||||
|
|
||||||
threshold: float
|
threshold: float
|
||||||
|
|
||||||
|
@ -50,7 +44,7 @@ def main() -> None:
|
||||||
threshold = 0.000005
|
threshold = 0.000005
|
||||||
elif data_type == 1:
|
elif data_type == 1:
|
||||||
# FP16, lower precision, so higher threshold
|
# FP16, lower precision, so higher threshold
|
||||||
threshold = 0.003
|
threshold = 0.0032
|
||||||
elif data_type == 2:
|
elif data_type == 2:
|
||||||
# INT4 quantized, even lower precision, so even higher threshold
|
# INT4 quantized, even lower precision, so even higher threshold
|
||||||
# This threshold will let some bugs pass
|
# This threshold will let some bugs pass
|
||||||
|
@ -59,42 +53,24 @@ def main() -> None:
|
||||||
# This format stores more data, so error would be lower
|
# This format stores more data, so error would be lower
|
||||||
threshold = 1.2
|
threshold = 1.2
|
||||||
|
|
||||||
model = None
|
model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path)
|
||||||
|
|
||||||
if args.main_executable_path.lower().endswith('.dll') or args.main_executable_path.lower().endswith('.so'):
|
|
||||||
print('Testing shared library')
|
|
||||||
|
|
||||||
model = rwkv_cpp.RWKVModel(args.main_executable_path, args.ggml_model_path)
|
|
||||||
else:
|
|
||||||
print('Testing main_rwkv executable')
|
|
||||||
|
|
||||||
def compare_logits(tokens_subset: List[int]) -> None:
|
def compare_logits(tokens_subset: List[int]) -> None:
|
||||||
token_count: int = len(tokens_subset)
|
token_count: int = len(tokens_subset)
|
||||||
state_path: str = './state.bin'
|
|
||||||
logits_path: str = './logits.bin'
|
|
||||||
|
|
||||||
logits, state = None, None
|
logits, state = None, None
|
||||||
|
|
||||||
for i in range(token_count):
|
for i in range(token_count):
|
||||||
token: int = tokens_subset[i]
|
token: int = tokens_subset[i]
|
||||||
|
|
||||||
|
if token_count <= 10 or i % (token_count // 10) == 0:
|
||||||
print(f'{i + 1}/{token_count}')
|
print(f'{i + 1}/{token_count}')
|
||||||
|
|
||||||
if model is not None:
|
logits, state = model.eval(token, state, state, logits)
|
||||||
logits, state = model.eval(token, state)
|
|
||||||
else:
|
actual_logits = logits
|
||||||
subprocess.run(
|
|
||||||
[
|
# ---
|
||||||
args.main_executable_path,
|
|
||||||
args.ggml_model_path,
|
|
||||||
str(token),
|
|
||||||
# If this is the first token, let the script create a new state.
|
|
||||||
'' if i == 0 else state_path,
|
|
||||||
state_path,
|
|
||||||
logits_path
|
|
||||||
],
|
|
||||||
check=True
|
|
||||||
)
|
|
||||||
|
|
||||||
expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin'
|
expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin'
|
||||||
|
|
||||||
|
@ -104,11 +80,7 @@ def main() -> None:
|
||||||
with open(expected_logits_path, 'rb') as logits_file:
|
with open(expected_logits_path, 'rb') as logits_file:
|
||||||
expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
|
expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
|
||||||
|
|
||||||
if model is not None:
|
# ---
|
||||||
actual_logits = logits
|
|
||||||
else:
|
|
||||||
with open(logits_path, 'rb') as logits_file:
|
|
||||||
actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
|
|
||||||
|
|
||||||
difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
|
difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
|
||||||
|
|
||||||
|
@ -118,8 +90,6 @@ def main() -> None:
|
||||||
|
|
||||||
assert abs(difference) <= threshold, 'Difference is too big'
|
assert abs(difference) <= threshold, 'Difference is too big'
|
||||||
|
|
||||||
# Check small token amount first to avoid waiting too long before seeing that model is broken
|
|
||||||
compare_logits(tokens[:4])
|
|
||||||
compare_logits(tokens)
|
compare_logits(tokens)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,12 +1,11 @@
|
||||||
# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1.
|
# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1.
|
||||||
# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3
|
# Usage: python quantize.py C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3
|
||||||
|
|
||||||
import ctypes
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import rwkv_cpp_shared_library
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
|
parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
|
||||||
parser.add_argument('shared_library_path', help='Path to rwkv.cpp shared library')
|
|
||||||
parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
|
parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
|
||||||
parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
|
parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
|
||||||
parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3)
|
parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3)
|
||||||
|
@ -15,19 +14,14 @@ def parse_args():
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
library = ctypes.cdll.LoadLibrary(args.shared_library_path)
|
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
|
||||||
|
|
||||||
library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_int]
|
library.rwkv_quantize_model_file(
|
||||||
library.rwkv_quantize_model_file.restype = ctypes.c_bool
|
args.src_path,
|
||||||
|
args.dest_path,
|
||||||
result: bool = library.rwkv_quantize_model_file(
|
args.data_type
|
||||||
args.src_path.encode('utf-8'),
|
|
||||||
args.dest_path.encode('utf-8'),
|
|
||||||
ctypes.c_int(args.data_type)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result, 'Failed to quantize, check stderr'
|
|
||||||
|
|
||||||
print('Done')
|
print('Done')
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
127
rwkv/rwkv_cpp.py
127
rwkv/rwkv_cpp.py
|
@ -1,127 +0,0 @@
|
||||||
import os
|
|
||||||
import ctypes
|
|
||||||
import torch
|
|
||||||
import multiprocessing
|
|
||||||
from typing import Tuple, Optional
|
|
||||||
|
|
||||||
P_FLOAT = ctypes.POINTER(ctypes.c_float)
|
|
||||||
|
|
||||||
class RWKVModel:
|
|
||||||
"""
|
|
||||||
PyTorch wrapper around rwkv.cpp shared library.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
shared_library_path: str,
|
|
||||||
model_path: str,
|
|
||||||
thread_count: int = max(1, multiprocessing.cpu_count() // 2)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Loads the model and prepares it for inference.
|
|
||||||
In case of any error, this method will throw an exception.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
shared_library_path : str
|
|
||||||
Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
|
|
||||||
model_path : str
|
|
||||||
Path to RWKV model file in ggml format.
|
|
||||||
thread_count : int
|
|
||||||
Thread count to use. If not set, defaults to CPU count / 2.
|
|
||||||
"""
|
|
||||||
|
|
||||||
assert os.path.isfile(shared_library_path), f'{shared_library_path} is not a file'
|
|
||||||
assert os.path.isfile(model_path), f'{model_path} is not a file'
|
|
||||||
assert thread_count > 0, 'Thread count must be positive'
|
|
||||||
|
|
||||||
self.library = ctypes.cdll.LoadLibrary(shared_library_path)
|
|
||||||
|
|
||||||
self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int]
|
|
||||||
self.library.rwkv_init_from_file.restype = ctypes.c_void_p
|
|
||||||
|
|
||||||
self.library.rwkv_eval.argtypes = [
|
|
||||||
ctypes.c_void_p, # ctx
|
|
||||||
ctypes.c_long, # token
|
|
||||||
P_FLOAT, # state_in
|
|
||||||
P_FLOAT, # state_out
|
|
||||||
P_FLOAT # logits_out
|
|
||||||
]
|
|
||||||
self.library.rwkv_eval.restype = ctypes.c_bool
|
|
||||||
|
|
||||||
self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
|
|
||||||
self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_size_t
|
|
||||||
|
|
||||||
self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
|
|
||||||
self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_size_t
|
|
||||||
|
|
||||||
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
|
|
||||||
self.library.rwkv_free.restype = None
|
|
||||||
|
|
||||||
self.ctx = self.library.rwkv_init_from_file(model_path.encode('utf-8'), ctypes.c_int(thread_count))
|
|
||||||
|
|
||||||
assert self.ctx is not None, 'Failed to load the model, see stderr'
|
|
||||||
|
|
||||||
self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
|
|
||||||
self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)
|
|
||||||
|
|
||||||
self.valid = True
|
|
||||||
|
|
||||||
def eval(self, token: int, state_in: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Evaluates the model for a single token.
|
|
||||||
In case of any error, this method will throw an exception.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
token : int
|
|
||||||
Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
|
|
||||||
state_in : Optional[torch.Tensor]
|
|
||||||
State from previous call of this method. If this is a first pass, set it to None.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
logits, state
|
|
||||||
Logits vector of shape (n_vocab); state for the next step.
|
|
||||||
"""
|
|
||||||
|
|
||||||
assert self.valid, 'Model was freed'
|
|
||||||
|
|
||||||
if state_in is None:
|
|
||||||
state_in_ptr = 0
|
|
||||||
else:
|
|
||||||
expected_shape = (self.state_buffer_element_count,)
|
|
||||||
|
|
||||||
assert state_in.is_contiguous(), 'State tensor is not contiguous'
|
|
||||||
assert state_in.shape == expected_shape, f'Invalid state shape {state_in.shape}, expected {expected_shape}'
|
|
||||||
|
|
||||||
state_in_ptr = state_in.storage().data_ptr()
|
|
||||||
|
|
||||||
# TODO Probably these allocations can be optimized away
|
|
||||||
state_out: torch.Tensor = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
|
|
||||||
logits_out: torch.Tensor = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')
|
|
||||||
|
|
||||||
result = self.library.rwkv_eval(
|
|
||||||
self.ctx,
|
|
||||||
ctypes.c_long(token),
|
|
||||||
ctypes.cast(state_in_ptr, P_FLOAT),
|
|
||||||
ctypes.cast(state_out.storage().data_ptr(), P_FLOAT),
|
|
||||||
ctypes.cast(logits_out.storage().data_ptr(), P_FLOAT)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result, 'Inference failed, see stderr'
|
|
||||||
|
|
||||||
return logits_out, state_out
|
|
||||||
|
|
||||||
def free(self):
|
|
||||||
"""
|
|
||||||
Frees all allocated resources.
|
|
||||||
In case of any error, this method will throw an exception.
|
|
||||||
The object must not be used anymore after calling this method.
|
|
||||||
"""
|
|
||||||
|
|
||||||
assert self.valid, 'Already freed'
|
|
||||||
|
|
||||||
self.valid = False
|
|
||||||
|
|
||||||
self.library.rwkv_free(self.ctx)
|
|
|
@ -0,0 +1,117 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import multiprocessing
|
||||||
|
import rwkv_cpp_shared_library
|
||||||
|
from typing import Tuple, Optional
|
||||||
|
|
||||||
|
class RWKVModel:
|
||||||
|
"""
|
||||||
|
PyTorch wrapper around rwkv.cpp model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
|
||||||
|
model_path: str,
|
||||||
|
thread_count: int = max(1, multiprocessing.cpu_count() // 2)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Loads the model and prepares it for inference.
|
||||||
|
In case of any error, this method will throw an exception.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
shared_library : RWKVSharedLibrary
|
||||||
|
rwkv.cpp shared library.
|
||||||
|
model_path : str
|
||||||
|
Path to RWKV model file in ggml format.
|
||||||
|
thread_count : int
|
||||||
|
Thread count to use. If not set, defaults to half of the CPU count (rounded down), but at least 1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert os.path.isfile(model_path), f'{model_path} is not a file'
|
||||||
|
assert thread_count > 0, 'Thread count must be positive'
|
||||||
|
|
||||||
|
self.library = shared_library
|
||||||
|
|
||||||
|
self.ctx = self.library.rwkv_init_from_file(model_path, thread_count)
|
||||||
|
|
||||||
|
self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
|
||||||
|
self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)
|
||||||
|
|
||||||
|
self.valid = True
|
||||||
|
|
||||||
|
def eval(
|
||||||
|
self,
|
||||||
|
token: int,
|
||||||
|
state_in: Optional[torch.Tensor],
|
||||||
|
state_out: Optional[torch.Tensor] = None,
|
||||||
|
logits_out: Optional[torch.Tensor] = None
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Evaluates the model for a single token.
|
||||||
|
In case of any error, this method will throw an exception.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
token : int
|
||||||
|
Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
|
||||||
|
state_in : Optional[torch.Tensor]
|
||||||
|
State from previous call of this method. If this is a first pass, set it to None.
|
||||||
|
state_out : Optional[torch.Tensor]
|
||||||
|
Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
|
||||||
|
logits_out : Optional[torch.Tensor]
|
||||||
|
Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
logits, state
|
||||||
|
Logits vector of shape (n_vocab); state for the next step.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert self.valid, 'Model was freed'
|
||||||
|
|
||||||
|
def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None:
|
||||||
|
assert buf.dtype == torch.float32, f'{name} is not of type float32'
|
||||||
|
assert buf.is_contiguous(), f'{name} is not contiguous'
|
||||||
|
assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})'
|
||||||
|
|
||||||
|
if state_in is not None:
|
||||||
|
validate_buffer(state_in, 'state_in', self.state_buffer_element_count)
|
||||||
|
|
||||||
|
state_in_ptr = state_in.storage().data_ptr()
|
||||||
|
else:
|
||||||
|
state_in_ptr = 0
|
||||||
|
|
||||||
|
if state_out is not None:
|
||||||
|
validate_buffer(state_out, 'state_out', self.state_buffer_element_count)
|
||||||
|
else:
|
||||||
|
state_out = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
|
||||||
|
|
||||||
|
if logits_out is not None:
|
||||||
|
validate_buffer(logits_out, 'logits_out', self.logits_buffer_element_count)
|
||||||
|
else:
|
||||||
|
logits_out = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')
|
||||||
|
|
||||||
|
self.library.rwkv_eval(
|
||||||
|
self.ctx,
|
||||||
|
token,
|
||||||
|
state_in_ptr,
|
||||||
|
state_out.storage().data_ptr(),
|
||||||
|
logits_out.storage().data_ptr()
|
||||||
|
)
|
||||||
|
|
||||||
|
return logits_out, state_out
|
||||||
|
|
||||||
|
def free(self):
|
||||||
|
"""
|
||||||
|
Frees all allocated resources.
|
||||||
|
In case of any error, this method will throw an exception.
|
||||||
|
The object must not be used anymore after calling this method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert self.valid, 'Already freed'
|
||||||
|
|
||||||
|
self.valid = False
|
||||||
|
|
||||||
|
self.library.rwkv_free(self.ctx)
|
|
@ -0,0 +1,204 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import ctypes
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
P_FLOAT = ctypes.POINTER(ctypes.c_float)
|
||||||
|
|
||||||
|
class RWKVContext:
|
||||||
|
|
||||||
|
def __init__(self, ptr: ctypes.pointer):
|
||||||
|
self.ptr = ptr
|
||||||
|
|
||||||
|
class RWKVSharedLibrary:
|
||||||
|
"""
|
||||||
|
Python wrapper around rwkv.cpp shared library.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, shared_library_path: str):
|
||||||
|
"""
|
||||||
|
Loads the shared library from specified file.
|
||||||
|
In case of any error, this method will throw an exception.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
shared_library_path : str
|
||||||
|
Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.library = ctypes.cdll.LoadLibrary(shared_library_path)
|
||||||
|
|
||||||
|
self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
|
||||||
|
self.library.rwkv_init_from_file.restype = ctypes.c_void_p
|
||||||
|
|
||||||
|
self.library.rwkv_eval.argtypes = [
|
||||||
|
ctypes.c_void_p, # ctx
|
||||||
|
ctypes.c_int32, # token
|
||||||
|
P_FLOAT, # state_in
|
||||||
|
P_FLOAT, # state_out
|
||||||
|
P_FLOAT # logits_out
|
||||||
|
]
|
||||||
|
self.library.rwkv_eval.restype = ctypes.c_bool
|
||||||
|
|
||||||
|
self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
|
||||||
|
self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32
|
||||||
|
|
||||||
|
self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
|
||||||
|
self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32
|
||||||
|
|
||||||
|
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
|
||||||
|
self.library.rwkv_free.restype = None
|
||||||
|
|
||||||
|
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
|
||||||
|
self.library.rwkv_free.restype = None
|
||||||
|
|
||||||
|
self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_uint32]
|
||||||
|
self.library.rwkv_quantize_model_file.restype = ctypes.c_bool
|
||||||
|
|
||||||
|
self.library.rwkv_get_system_info_string.argtypes = []
|
||||||
|
self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p
|
||||||
|
|
||||||
|
def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
|
||||||
|
"""
|
||||||
|
Loads the model from a file and prepares it for inference.
|
||||||
|
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_file_path : str
|
||||||
|
Path to model file in ggml format.
|
||||||
|
thread_count : int
|
||||||
|
Count of threads to use, must be positive.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))
|
||||||
|
assert ptr is not None, 'rwkv_init_from_file failed, check stderr'
|
||||||
|
return RWKVContext(ptr)
|
||||||
|
|
||||||
|
def rwkv_eval(
|
||||||
|
self,
|
||||||
|
ctx: RWKVContext,
|
||||||
|
token: int,
|
||||||
|
state_in_address: Optional[int],
|
||||||
|
state_out_address: int,
|
||||||
|
logits_out_address: int
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Evaluates the model for a single token.
|
||||||
|
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ctx : RWKVContext
|
||||||
|
RWKV context obtained from rwkv_init_from_file.
|
||||||
|
token : int
|
||||||
|
Next token index, in range 0 <= token < n_vocab.
|
||||||
|
state_in_address : Optional[int]
|
||||||
|
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
|
||||||
|
state_out_address : int
|
||||||
|
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
|
||||||
|
logits_out_address : int
|
||||||
|
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert self.library.rwkv_eval(
|
||||||
|
ctx.ptr,
|
||||||
|
ctypes.c_int32(token),
|
||||||
|
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
|
||||||
|
ctypes.cast(state_out_address, P_FLOAT),
|
||||||
|
ctypes.cast(logits_out_address, P_FLOAT)
|
||||||
|
), 'rwkv_eval failed, check stderr'
|
||||||
|
|
||||||
|
def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int:
|
||||||
|
"""
|
||||||
|
Returns count of FP32 elements in state buffer.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ctx : RWKVContext
|
||||||
|
RWKV context obtained from rwkv_init_from_file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self.library.rwkv_get_state_buffer_element_count(ctx.ptr)
|
||||||
|
|
||||||
|
def rwkv_get_logits_buffer_element_count(self, ctx: RWKVContext) -> int:
|
||||||
|
"""
|
||||||
|
Returns count of FP32 elements in logits buffer.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ctx : RWKVContext
|
||||||
|
RWKV context obtained from rwkv_init_from_file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self.library.rwkv_get_logits_buffer_element_count(ctx.ptr)
|
||||||
|
|
||||||
|
def rwkv_free(self, ctx: RWKVContext) -> None:
|
||||||
|
"""
|
||||||
|
Frees all allocated memory and the context.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ctx : RWKVContext
|
||||||
|
RWKV context obtained from rwkv_init_from_file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.library.rwkv_free(ctx.ptr)
|
||||||
|
|
||||||
|
ctx.ptr = ctypes.cast(0, ctypes.c_void_p)
|
||||||
|
|
||||||
|
def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, q_type: int) -> None:
|
||||||
|
"""
|
||||||
|
Quantizes FP32 or FP16 model to one of INT4 formats.
|
||||||
|
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_file_path_in : str
|
||||||
|
Path to model file in ggml format, must be either FP32 or FP16.
|
||||||
|
model_file_path_out : str
|
||||||
|
Quantized model will be written here.
|
||||||
|
q_type : int
|
||||||
|
Set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert self.library.rwkv_quantize_model_file(
|
||||||
|
model_file_path_in.encode('utf-8'),
|
||||||
|
model_file_path_out.encode('utf-8'),
|
||||||
|
ctypes.c_uint32(q_type)
|
||||||
|
), 'rwkv_quantize_model_file failed, check stderr'
|
||||||
|
|
||||||
|
def rwkv_get_system_info_string(self) -> str:
|
||||||
|
"""
|
||||||
|
Returns system information string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self.library.rwkv_get_system_info_string()
|
||||||
|
|
||||||
|
def load_rwkv_shared_library() -> RWKVSharedLibrary:
|
||||||
|
"""
|
||||||
|
Attempts to find rwkv.cpp shared library and load it.
|
||||||
|
To specify the exact path to the library, create an instance of RWKVSharedLibrary explicitly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
file_name: str
|
||||||
|
|
||||||
|
if 'win32' in sys.platform or 'cygwin' in sys.platform:
|
||||||
|
file_name = 'rwkv.dll'
|
||||||
|
else:
|
||||||
|
file_name = 'rwkv.so'
|
||||||
|
|
||||||
|
paths = [
|
||||||
|
# If we are in "rwkv" directory
|
||||||
|
f'../bin/Release/{file_name}',
|
||||||
|
# If we are in repo root directory
|
||||||
|
f'bin/Release/{file_name}',
|
||||||
|
# Fallback
|
||||||
|
file_name
|
||||||
|
]
|
||||||
|
|
||||||
|
for path in paths:
|
||||||
|
if os.path.isfile(path):
|
||||||
|
return RWKVSharedLibrary(path)
|
||||||
|
|
||||||
|
return RWKVSharedLibrary(paths[-1])
|
Loading…
Reference in New Issue