rwkv.cpp/rwkv/rwkv_cpp.py

128 lines
4.5 KiB
Python

import os
import ctypes
import torch
import multiprocessing
from typing import Tuple, Optional
P_FLOAT = ctypes.POINTER(ctypes.c_float)
class RWKVModel:
"""
PyTorch wrapper around rwkv.cpp shared library.
"""
def __init__(
self,
shared_library_path: str,
model_path: str,
thread_count: int = max(1, multiprocessing.cpu_count() // 2)
):
"""
Loads the model and prepares it for inference.
In case of any error, this method will throw an exception.
Parameters
----------
shared_library_path : str
Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
model_path : str
Path to RWKV model file in ggml format.
thread_count : int
Thread count to use. If not set, defaults to CPU count / 2.
"""
assert os.path.isfile(shared_library_path), f'{shared_library_path} is not a file'
assert os.path.isfile(model_path), f'{model_path} is not a file'
assert thread_count > 0, 'Thread count must be positive'
self.library = ctypes.cdll.LoadLibrary(shared_library_path)
self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int]
self.library.rwkv_init_from_file.restype = ctypes.c_void_p
self.library.rwkv_eval.argtypes = [
ctypes.c_void_p, # ctx
ctypes.c_long, # token
P_FLOAT, # state_in
P_FLOAT, # state_out
P_FLOAT # logits_out
]
self.library.rwkv_eval.restype = ctypes.c_bool
self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_size_t
self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_size_t
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
self.library.rwkv_free.restype = None
self.ctx = self.library.rwkv_init_from_file(model_path.encode('utf-8'), ctypes.c_int(thread_count))
assert self.ctx is not None, 'Failed to load the model, see stderr'
self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)
self.valid = True
def eval(self, token: int, state_in: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Evaluates the model for a single token.
In case of any error, this method will throw an exception.
Parameters
----------
token : int
Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
state_in : Optional[torch.Tensor]
State from previous call of this method. If this is a first pass, set it to None.
Returns
-------
logits, state
Logits vector of shape (n_vocab); state for the next step.
"""
assert self.valid, 'Model was freed'
if state_in is None:
state_in_ptr = 0
else:
expected_shape = (self.state_buffer_element_count,)
assert state_in.is_contiguous(), 'State tensor is not contiguous'
assert state_in.shape == expected_shape, f'Invalid state shape {state_in.shape}, expected {expected_shape}'
state_in_ptr = state_in.storage().data_ptr()
# TODO Probably these allocations can be optimized away
state_out: torch.Tensor = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
logits_out: torch.Tensor = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')
result = self.library.rwkv_eval(
self.ctx,
ctypes.c_long(token),
ctypes.cast(state_in_ptr, P_FLOAT),
ctypes.cast(state_out.storage().data_ptr(), P_FLOAT),
ctypes.cast(logits_out.storage().data_ptr(), P_FLOAT)
)
assert result, 'Inference failed, see stderr'
return logits_out, state_out
def free(self):
"""
Frees all allocated resources.
In case of any error, this method will throw an exception.
The object must not be used anymore after calling this method.
"""
assert self.valid, 'Already freed'
self.valid = False
self.library.rwkv_free(self.ctx)