128 lines
4.5 KiB
Python
128 lines
4.5 KiB
Python
import os
|
|
import ctypes
|
|
import torch
|
|
import multiprocessing
|
|
from typing import Tuple, Optional
|
|
|
|
# ctypes type for a C `float *`. Used to declare the buffer parameters of
# rwkv_eval and to cast raw tensor addresses before passing them to it.
P_FLOAT = ctypes.POINTER(ctypes.c_float)
|
|
|
|
class RWKVModel:
    """
    PyTorch wrapper around rwkv.cpp shared library.

    The wrapper loads the shared library through ctypes, declares the
    signatures of the functions it calls, creates a native context from a
    model file and exposes single-token evaluation on CPU float32 tensors.

    The native context is released by `free()`. If `free()` was never called,
    it is released when the object is garbage collected.
    """

    def __init__(
            self,
            shared_library_path: str,
            model_path: str,
            thread_count: int = max(1, multiprocessing.cpu_count() // 2)
    ):
        """
        Loads the model and prepares it for inference.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        shared_library_path : str
            Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
        model_path : str
            Path to RWKV model file in ggml format.
        thread_count : int
            Thread count to use. If not set, defaults to CPU count / 2.

        Raises
        ------
        AssertionError
            If either path is not a file, the thread count is not positive,
            or the library fails to create a context for the model.
        """

        assert os.path.isfile(shared_library_path), f'{shared_library_path} is not a file'
        assert os.path.isfile(model_path), f'{model_path} is not a file'
        assert thread_count > 0, 'Thread count must be positive'

        self.library = ctypes.cdll.LoadLibrary(shared_library_path)

        # Declare the C signatures explicitly: without them ctypes would
        # treat every argument and return value as a C int, which truncates
        # pointers on 64-bit platforms.
        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int]
        self.library.rwkv_init_from_file.restype = ctypes.c_void_p

        # NOTE(review): c_long is 32-bit on Windows and 64-bit on most 64-bit
        # UNIX ABIs -- confirm it matches the C declaration of `token`.
        self.library.rwkv_eval.argtypes = [
            ctypes.c_void_p,  # ctx
            ctypes.c_long,  # token
            P_FLOAT,  # state_in
            P_FLOAT,  # state_out
            P_FLOAT  # logits_out
        ]
        self.library.rwkv_eval.restype = ctypes.c_bool

        self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_size_t

        self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_size_t

        self.library.rwkv_free.argtypes = [ctypes.c_void_p]
        self.library.rwkv_free.restype = None

        self.ctx = self.library.rwkv_init_from_file(model_path.encode('utf-8'), ctypes.c_int(thread_count))

        # A c_void_p restype maps a NULL return value to None.
        assert self.ctx is not None, 'Failed to load the model, see stderr'

        # Buffer sizes are queried once and reused for every eval() call.
        self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
        self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)

        self.valid = True

    def eval(self, token: int, state_in: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluates the model for a single token.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        token : int
            Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
        state_in : Optional[torch.Tensor]
            State from previous call of this method. If this is a first pass, set it to None.
            When given, it must be a contiguous CPU float32 tensor of shape
            (state_buffer_element_count,).

        Returns
        -------
        logits, state
            Logits vector of shape (n_vocab); state for the next step.

        Raises
        ------
        AssertionError
            If the model was freed, `state_in` is not a contiguous CPU float32
            tensor of the expected shape, or the library reports a failure.
        """

        assert self.valid, 'Model was freed'

        if state_in is None:
            # Address 0 is cast to a NULL `float *` below.
            state_in_ptr = 0
        else:
            expected_shape = (self.state_buffer_element_count,)

            assert state_in.is_contiguous(), 'State tensor is not contiguous'
            assert state_in.shape == expected_shape, f'Invalid state shape {state_in.shape}, expected {expected_shape}'
            # The buffer is handed to C as a raw `float *`, so any other
            # dtype or a non-CPU device would be silently misinterpreted.
            assert state_in.dtype == torch.float32, f'Invalid state dtype {state_in.dtype}, expected {torch.float32}'
            assert state_in.device.type == 'cpu', f'State tensor must be on CPU, got {state_in.device}'

            # Tensor.data_ptr() points at the tensor's first element, unlike
            # storage().data_ptr(), which ignores the storage offset of views.
            state_in_ptr = state_in.data_ptr()

        # TODO Probably these allocations can be optimized away
        state_out: torch.Tensor = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
        logits_out: torch.Tensor = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')

        result = self.library.rwkv_eval(
            self.ctx,
            ctypes.c_long(token),
            ctypes.cast(state_in_ptr, P_FLOAT),
            ctypes.cast(state_out.data_ptr(), P_FLOAT),
            ctypes.cast(logits_out.data_ptr(), P_FLOAT)
        )

        assert result, 'Inference failed, see stderr'

        return logits_out, state_out

    def free(self):
        """
        Frees all allocated resources.
        In case of any error, this method will throw an exception.
        The object must not be used anymore after calling this method.

        Raises
        ------
        AssertionError
            If the model was already freed.
        """

        assert self.valid, 'Already freed'

        # Flip the flag before the native call so the context is never
        # released twice, even if the call itself raises.
        self.valid = False

        self.library.rwkv_free(self.ctx)

    def __del__(self):
        # Release the native context if the caller forgot to call free().
        # getattr guards against __init__ having failed before `valid` was set.
        if getattr(self, 'valid', False):
            self.free()
|