import multiprocessing
import os
from typing import Optional, Tuple

import torch

import rwkv_cpp_shared_library


class RWKVModel:
    """
    PyTorch wrapper around rwkv.cpp model.

    The instance owns a native context created by the shared library. Call
    free() when done; if it is not called explicitly, the context is released
    when the object is garbage collected.

    Note: argument validation in this class uses `assert`, so failures surface
    as AssertionError and the checks are skipped when Python runs with -O.
    """

    def __init__(
            self,
            shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
            model_path: str,
            thread_count: int = max(1, multiprocessing.cpu_count() // 2)
    ):
        """
        Loads the model and prepares it for inference.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        shared_library : RWKVSharedLibrary
            rwkv.cpp shared library.
        model_path : str
            Path to RWKV model file in ggml format.
        thread_count : int
            Thread count to use. If not set, defaults to CPU count / 2
            (but never less than 1).
        """

        assert os.path.isfile(model_path), f'{model_path} is not a file'
        assert thread_count > 0, 'Thread count must be positive'

        self._library = shared_library

        self._ctx = self._library.rwkv_init_from_file(model_path, thread_count)

        # Buffer sizes are queried once; eval() validates caller-provided
        # tensors against them and uses them to allocate default outputs.
        self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx)
        self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx)

        # Set last: __del__ only frees the context when initialization fully succeeded.
        self._valid = True

    @staticmethod
    def _validate_buffer(buf: torch.Tensor, name: str, size: int) -> None:
        """
        Asserts that `buf` can be handed to the native library as a raw pointer:
        float32, contiguous, 1-D with exactly `size` elements, and located on the CPU.
        """

        assert buf.dtype == torch.float32, f'{name} is not of type float32'
        assert buf.is_contiguous(), f'{name} is not contiguous'
        assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})'
        # The raw address is dereferenced by CPU code; a device pointer would be invalid there.
        assert buf.device.type == 'cpu', f'{name} is not on CPU'

    def eval(
            self,
            token: int,
            state_in: Optional[torch.Tensor],
            state_out: Optional[torch.Tensor] = None,
            logits_out: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluates the model for a single token.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        token : int
            Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
        state_in : Optional[torch.Tensor]
            State from previous call of this method. If this is a first pass, set it to None.
            If provided, must be of type float32, contiguous, on CPU
            and of shape (state_buffer_element_count).
        state_out : Optional[torch.Tensor]
            Optional output tensor for state. If provided, must be of type float32, contiguous,
            on CPU and of shape (state_buffer_element_count).
        logits_out : Optional[torch.Tensor]
            Optional output tensor for logits. If provided, must be of type float32, contiguous,
            on CPU and of shape (logits_buffer_element_count).

        Returns
        -------
        logits, state
            Logits vector of shape (n_vocab); state for the next step.
            These are the provided `logits_out` / `state_out` tensors when given,
            otherwise freshly allocated CPU tensors.
        """

        assert self._valid, 'Model was freed'

        if state_in is None:
            # A null pointer is passed to the library on the first pass.
            state_in_ptr = 0
        else:
            self._validate_buffer(state_in, 'state_in', self._state_buffer_element_count)
            # Tensor.data_ptr() is the address of the tensor's first element, so it
            # stays correct when the tensor is a view into a larger storage.
            # storage().data_ptr() would return the storage base address instead.
            state_in_ptr = state_in.data_ptr()

        if state_out is None:
            state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu')
        else:
            self._validate_buffer(state_out, 'state_out', self._state_buffer_element_count)

        if logits_out is None:
            logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu')
        else:
            self._validate_buffer(logits_out, 'logits_out', self._logits_buffer_element_count)

        self._library.rwkv_eval(
            self._ctx,
            token,
            state_in_ptr,
            state_out.data_ptr(),
            logits_out.data_ptr()
        )

        return logits_out, state_out

    def free(self):
        """
        Frees all allocated resources.
        In case of any error, this method will throw an exception.
        The object must not be used anymore after calling this method.
        """

        assert self._valid, 'Already freed'

        # Mark as freed before the native call so that a failure inside
        # rwkv_free does not lead to a second free attempt from __del__.
        self._valid = False

        self._library.rwkv_free(self._ctx)

    def __del__(self):
        # Free the context on GC in case user forgot to call free() explicitly.
        # _valid is missing when __init__ raised before completing; treat that as "nothing to free".
        if getattr(self, '_valid', False):
            self.free()