246 lines
8.4 KiB
Python
246 lines
8.4 KiB
Python
import os
|
|
import sys
|
|
import ctypes
|
|
import pathlib
|
|
from typing import Optional
|
|
|
|
QUANTIZED_FORMAT_NAMES = (
|
|
'Q4_0',
|
|
'Q4_1',
|
|
'Q5_0',
|
|
'Q5_1',
|
|
'Q8_0'
|
|
)
|
|
|
|
P_FLOAT = ctypes.POINTER(ctypes.c_float)
|
|
|
|
class RWKVContext:
|
|
|
|
def __init__(self, ptr: ctypes.pointer):
|
|
self.ptr = ptr
|
|
|
|
class RWKVSharedLibrary:
|
|
"""
|
|
Python wrapper around rwkv.cpp shared library.
|
|
"""
|
|
|
|
def __init__(self, shared_library_path: str):
|
|
"""
|
|
Loads the shared library from specified file.
|
|
In case of any error, this method will throw an exception.
|
|
|
|
Parameters
|
|
----------
|
|
shared_library_path : str
|
|
Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
|
|
"""
|
|
|
|
self.library = ctypes.cdll.LoadLibrary(shared_library_path)
|
|
|
|
self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
|
|
self.library.rwkv_init_from_file.restype = ctypes.c_void_p
|
|
|
|
self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32]
|
|
self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool
|
|
|
|
self.library.rwkv_eval.argtypes = [
|
|
ctypes.c_void_p, # ctx
|
|
ctypes.c_int32, # token
|
|
P_FLOAT, # state_in
|
|
P_FLOAT, # state_out
|
|
P_FLOAT # logits_out
|
|
]
|
|
self.library.rwkv_eval.restype = ctypes.c_bool
|
|
|
|
self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
|
|
self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32
|
|
|
|
self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
|
|
self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32
|
|
|
|
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
|
|
self.library.rwkv_free.restype = None
|
|
|
|
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
|
|
self.library.rwkv_free.restype = None
|
|
|
|
self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p]
|
|
self.library.rwkv_quantize_model_file.restype = ctypes.c_bool
|
|
|
|
self.library.rwkv_get_system_info_string.argtypes = []
|
|
self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p
|
|
|
|
def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
|
|
"""
|
|
Loads the model from a file and prepares it for inference.
|
|
Throws an exception in case of any error. Error messages would be printed to stderr.
|
|
|
|
Parameters
|
|
----------
|
|
model_file_path : str
|
|
Path to model file in ggml format.
|
|
thread_count : int
|
|
Count of threads to use, must be positive.
|
|
gpu_layers_count : int
|
|
Count of layers to load on gpu, must be positive only enabled with cuBLAS.
|
|
"""
|
|
|
|
ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'),
|
|
ctypes.c_uint32(thread_count))
|
|
assert ptr is not None, 'rwkv_init_from_file failed, check stderr'
|
|
return RWKVContext(ptr)
|
|
|
|
def rwkv_gpu_offload_layers(self, ctx: RWKVContext, gpu_layers_count: int) -> None:
|
|
"""
|
|
Offloads specified layers of context onto GPU using cuBLAS, if it is enabled.
|
|
If rwkv.cpp was compiled without cuBLAS support, this function is a no-op.
|
|
|
|
Parameters
|
|
----------
|
|
gpu_layers_count : int
|
|
Count of layers to load onto gpu, must be >= 0, only enabled with cuBLAS.
|
|
"""
|
|
|
|
assert self.library.rwkv_gpu_offload_layers(ctx.ptr, ctypes.c_uint32(gpu_layers_count)), 'rwkv_gpu_offload_layers failed, check stderr'
|
|
|
|
def rwkv_eval(
|
|
self,
|
|
ctx: RWKVContext,
|
|
token: int,
|
|
state_in_address: Optional[int],
|
|
state_out_address: int,
|
|
logits_out_address: int
|
|
) -> None:
|
|
"""
|
|
Evaluates the model for a single token.
|
|
Throws an exception in case of any error. Error messages would be printed to stderr.
|
|
|
|
Parameters
|
|
----------
|
|
ctx : RWKVContext
|
|
RWKV context obtained from rwkv_init_from_file.
|
|
token : int
|
|
Next token index, in range 0 <= token < n_vocab.
|
|
state_in_address : int
|
|
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
|
|
state_out_address : int
|
|
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
|
|
logits_out_address : int
|
|
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
|
"""
|
|
|
|
assert self.library.rwkv_eval(
|
|
ctx.ptr,
|
|
ctypes.c_int32(token),
|
|
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
|
|
ctypes.cast(state_out_address, P_FLOAT),
|
|
ctypes.cast(logits_out_address, P_FLOAT)
|
|
), 'rwkv_eval failed, check stderr'
|
|
|
|
def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int:
|
|
"""
|
|
Returns count of FP32 elements in state buffer.
|
|
|
|
Parameters
|
|
----------
|
|
ctx : RWKVContext
|
|
RWKV context obtained from rwkv_init_from_file.
|
|
"""
|
|
|
|
return self.library.rwkv_get_state_buffer_element_count(ctx.ptr)
|
|
|
|
def rwkv_get_logits_buffer_element_count(self, ctx: RWKVContext) -> int:
|
|
"""
|
|
Returns count of FP32 elements in logits buffer.
|
|
|
|
Parameters
|
|
----------
|
|
ctx : RWKVContext
|
|
RWKV context obtained from rwkv_init_from_file.
|
|
"""
|
|
|
|
return self.library.rwkv_get_logits_buffer_element_count(ctx.ptr)
|
|
|
|
def rwkv_free(self, ctx: RWKVContext) -> None:
|
|
"""
|
|
Frees all allocated memory and the context.
|
|
|
|
Parameters
|
|
----------
|
|
ctx : RWKVContext
|
|
RWKV context obtained from rwkv_init_from_file.
|
|
"""
|
|
|
|
self.library.rwkv_free(ctx.ptr)
|
|
|
|
ctx.ptr = ctypes.cast(0, ctypes.c_void_p)
|
|
|
|
def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, format_name: str) -> None:
|
|
"""
|
|
Quantizes FP32 or FP16 model to one of INT4 formats.
|
|
Throws an exception in case of any error. Error messages would be printed to stderr.
|
|
|
|
Parameters
|
|
----------
|
|
model_file_path_in : str
|
|
Path to model file in ggml format, must be either FP32 or FP16.
|
|
model_file_path_out : str
|
|
Quantized model will be written here.
|
|
format_name : str
|
|
One of QUANTIZED_FORMAT_NAMES.
|
|
"""
|
|
|
|
assert format_name in QUANTIZED_FORMAT_NAMES, f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}'
|
|
|
|
assert self.library.rwkv_quantize_model_file(
|
|
model_file_path_in.encode('utf-8'),
|
|
model_file_path_out.encode('utf-8'),
|
|
format_name.encode('utf-8')
|
|
), 'rwkv_quantize_model_file failed, check stderr'
|
|
|
|
def rwkv_get_system_info_string(self) -> str:
|
|
"""
|
|
Returns system information string.
|
|
"""
|
|
|
|
return self.library.rwkv_get_system_info_string().decode('utf-8')
|
|
|
|
|
|
def load_rwkv_shared_library() -> RWKVSharedLibrary:
|
|
"""
|
|
Attempts to find rwkv.cpp shared library and load it.
|
|
To specify exact path to the library, create an instance of RWKVSharedLibrary explicitly.
|
|
"""
|
|
|
|
file_name: str
|
|
|
|
if 'win32' in sys.platform or 'cygwin' in sys.platform:
|
|
file_name = 'rwkv.dll'
|
|
elif 'darwin' in sys.platform:
|
|
file_name = 'librwkv.dylib'
|
|
else:
|
|
file_name = 'librwkv.so'
|
|
|
|
repo_root_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent.parent
|
|
|
|
paths = [
|
|
# If we are in "rwkv" directory
|
|
f'../bin/Release/{file_name}',
|
|
# If we are in repo root directory
|
|
f'bin/Release/{file_name}',
|
|
# If we compiled in build directory
|
|
f'build/bin/Release/{file_name}',
|
|
# If we compiled in build directory
|
|
f'build/{file_name}',
|
|
# Search relative to this file
|
|
str(repo_root_dir / 'bin' / 'Release' / file_name),
|
|
# Fallback
|
|
str(repo_root_dir / file_name)
|
|
]
|
|
|
|
for path in paths:
|
|
if os.path.isfile(path):
|
|
return RWKVSharedLibrary(path)
|
|
|
|
return RWKVSharedLibrary(paths[-1])
|