Add Python wrapper for C library

This commit is contained in:
saharNooby 2023-04-01 16:02:22 +04:00
parent 7130a89d1f
commit a1e1d34c93
6 changed files with 206 additions and 33 deletions

View File

@ -247,9 +247,17 @@ add_library(rwkv
target_include_directories(llama PUBLIC .) target_include_directories(llama PUBLIC .)
target_compile_features(llama PUBLIC cxx_std_11) # don't bump target_compile_features(llama PUBLIC cxx_std_11) # don't bump
target_link_libraries(llama PRIVATE ggml ${LLAMA_EXTRA_LIBS}) target_link_libraries(llama PRIVATE ggml ${LLAMA_EXTRA_LIBS})
target_include_directories(rwkv PUBLIC .)
target_compile_features(rwkv PUBLIC cxx_std_11) # don't bump
target_link_libraries(rwkv PRIVATE ggml ${LLAMA_EXTRA_LIBS})
if (BUILD_SHARED_LIBS) if (BUILD_SHARED_LIBS)
set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD) target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD)
set_target_properties(rwkv PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(rwkv PRIVATE LLAMA_SHARED LLAMA_BUILD)
endif() endif()
# #

View File

@ -243,6 +243,9 @@ main: examples/main/main.cpp ggml.o llama.o common.o
@echo '==== Run ./main -h for help. ====' @echo '==== Run ./main -h for help. ===='
@echo @echo
main_rwkv: examples/main_rwkv/main_rwkv.cpp ggml.o rwkv.o common.o
$(CXX) $(CXXFLAGS) examples/main_rwkv/main_rwkv.cpp ggml.o rwkv.o common.o -o main_rwkv $(LDFLAGS)
quantize: examples/quantize/quantize.cpp ggml.o llama.o quantize: examples/quantize/quantize.cpp ggml.o llama.o
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp ggml.o llama.o -o quantize $(LDFLAGS) $(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp ggml.o llama.o -o quantize $(LDFLAGS)

View File

@ -2,21 +2,21 @@
This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml). The end goal is to allow 4-bit quantized inference on CPU. This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml). The end goal is to allow 4-bit quantized inference on CPU.
**WORK IN PROGRESS!** **Status**: FP32 and FP16 inference work correctly. Currently, I'm working on creating usable C library interface and Python wrapper for it. **WORK IN PROGRESS!** **Status**: There is a Python wrapper, FP32 and FP16 inference work correctly. Currently, I'm working on INT4 quantization support.
## Plan ## Plan
1. Create proper interface (probably, C library) 1. Make INT4 inference work
2. Create Python wrapper with sampling and simple chat interface 2. Create Python script with sampling and simple chat interface
3. Write a good `README.md` and publish links to this repo 3. Clean up the repo (remove llama related files and mentions)
4. Make INT4 inference work 4. Write a good `README.md` and publish links to this repo
5. Create pull request to main `ggml` repo with all improvements made here 5. Create pull request to main `ggml` repo with all improvements made here
## Structure ## Structure
This repo is based on the [llama.cpp repo](https://github.com/ggerganov/llama.cpp). RWKV-related code is in these directories: This repo is based on the [llama.cpp repo](https://github.com/ggerganov/llama.cpp). RWKV-related code is in these directories:
- `./rwkv`: directory containing Python scripts for conversion and validation - `./rwkv`: directory containing Python scripts for conversion, inference and validation
- `./examples/main_rwkv`: directory containing script that loads and infers RWKV model - `./examples/main_rwkv`: directory containing script that loads and infers RWKV model
Please do not change files in other directories — this will make pulling recent changes easier. Please do not change files in other directories — this will make pulling recent changes easier.
@ -27,25 +27,39 @@ Please do not change files in other directories — this will make pulling recen
Requirements: [git](https://gitforwindows.org/), [CMake](https://cmake.org/download/), MSVC compiler, Python 3.x with PyTorch. Requirements: [git](https://gitforwindows.org/), [CMake](https://cmake.org/download/), MSVC compiler, Python 3.x with PyTorch.
Clone the repo and set it up for build: #### 1. Clone the repo and build it:
```commandline ```commandline
git clone https://github.com/saharNooby/rwkv.cpp.git git clone https://github.com/saharNooby/rwkv.cpp.git
cd rwkv.cpp cd rwkv.cpp
cmake . cmake -DBUILD_SHARED_LIBS=ON -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF .
cmake --build . --config Release
``` ```
Download an RWKV model from [Huggingface](https://huggingface.co/BlinkDL) and convert it into `ggml` format: If everything went OK, `bin\Release\rwkv.dll` file should appear.
#### 2. Download an RWKV model from [Huggingface](https://huggingface.co/BlinkDL) and convert it into `ggml` format:
```commandline ```commandline
python rwkv\convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32 python rwkv\convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
``` ```
Compile and run the script: #### 3. Use the model in Python:
```python
# This file is located at rwkv/rwkv_cpp.py
import rwkv_cpp
model = rwkv_cpp.RWKVModel(r'bin\Release\rwkv.dll', r'C:\rwkv.cpp-169M.bin')
logits, state = None, None
for token in [1, 2, 3]:
logits, state = model.eval(token, state)
print(f'Output logits: {logits}')
# Don't forget to free memory after you've done working with the model
model.free()
```commandline
cmake --build . --config Release
bin\Release\main_rwkv.exe "C:\rwkv.cpp-169M.bin" 123 "C:\state_in.bin" "C:\state_out.bin" "C:\logits_out.bin"
``` ```
The script will read state from `state_in.bin`, do single inference using the state and token `123` as an input, save new state into `state_out.bin` and logits into `logits_out.bin`.

6
rwkv.h
View File

@ -29,13 +29,13 @@ extern "C" {
struct rwkv_context; struct rwkv_context;
// Loads the model from a file and prepares it for inference by allocating memory and building computation graph. // Loads the model from a file and prepares it for inference.
// Returns NULL on any error. Error messages would be printed to stderr. // Returns NULL on any error. Error messages would be printed to stderr.
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads); RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads);
// Evaluates the model for a single pass. // Evaluates the model for a single token.
// Returns false on any error. Error messages would be printed to stderr. // Returns false on any error. Error messages would be printed to stderr.
// - token: next token index, in range 0..n_vocab - 1. // - token: next token index, in range 0 <= token < n_vocab.
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass. // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to. // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
// - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.

View File

@ -9,11 +9,12 @@ import argparse
import subprocess import subprocess
import torch import torch
import numpy as np import numpy as np
import rwkv_cpp
from typing import List, Tuple, Any from typing import List, Tuple, Any
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV') parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file') parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file or shared library')
parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file') parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
return parser.parse_args() return parser.parse_args()
@ -48,28 +49,42 @@ def main() -> None:
# FP16, lower precision, so higher threshold # FP16, lower precision, so higher threshold
threshold = 0.003 threshold = 0.003
model = None
if args.main_executable_path.lower().endswith('.dll') or args.main_executable_path.lower().endswith('.so'):
print('Testing shared library')
model = rwkv_cpp.RWKVModel(args.main_executable_path, args.ggml_model_path)
else:
print('Testing main_rwkv executable')
def compare_logits(tokens_subset: List[int]) -> None: def compare_logits(tokens_subset: List[int]) -> None:
token_count: int = len(tokens_subset) token_count: int = len(tokens_subset)
state_path: str = './state.bin' state_path: str = './state.bin'
logits_path: str = './logits.bin' logits_path: str = './logits.bin'
logits, state = None, None
for i in range(token_count): for i in range(token_count):
token: int = tokens_subset[i] token: int = tokens_subset[i]
print(f'{i + 1}/{token_count}') print(f'{i + 1}/{token_count}')
subprocess.run( if model is not None:
[ logits, state = model.eval(token, state)
args.main_executable_path, else:
args.ggml_model_path, subprocess.run(
str(token), [
# If this is the first token, let the script create a new state. args.main_executable_path,
'' if i == 0 else state_path, args.ggml_model_path,
state_path, str(token),
logits_path # If this is the first token, let the script create a new state.
], '' if i == 0 else state_path,
check=True state_path,
) logits_path
],
check=True
)
expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin' expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin'
@ -79,8 +94,11 @@ def main() -> None:
with open(expected_logits_path, 'rb') as logits_file: with open(expected_logits_path, 'rb') as logits_file:
expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
with open(logits_path, 'rb') as logits_file: if model is not None:
actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) actual_logits = logits
else:
with open(logits_path, 'rb') as logits_file:
actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item() difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
@ -97,5 +115,8 @@ def main() -> None:
print() print()
print('Test passes') print('Test passes')
if model is not None:
model.free()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

127
rwkv/rwkv_cpp.py Normal file
View File

@ -0,0 +1,127 @@
import os
import ctypes
import torch
import multiprocessing
from typing import Tuple, Optional
# ctypes type for `float *` arguments of the C API.
P_FLOAT = ctypes.POINTER(ctypes.c_float)


def _check(condition: bool, message: str) -> None:
    """
    Raises AssertionError with the given message if the condition is false.

    The checks in this module guard calls into native code, so they must not
    disappear under `python -O` the way bare `assert` statements do.
    AssertionError is kept as the exception type for backward compatibility.
    """
    if not condition:
        raise AssertionError(message)


class RWKVModel:
    """
    PyTorch wrapper around rwkv.cpp shared library.

    The model owns native resources; call `free()` when done, or use the
    object as a context manager (`with RWKVModel(...) as model:`).
    """

    def __init__(
            self,
            shared_library_path: str,
            model_path: str,
            thread_count: int = max(1, multiprocessing.cpu_count() // 2)
    ):
        """
        Loads the model and prepares it for inference.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        shared_library_path : str
            Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'.
            On UNIX-like systems it is a '.so' file (typically named like 'librwkv.so').
        model_path : str
            Path to RWKV model file in ggml format.
        thread_count : int
            Thread count to use. If not set, defaults to CPU count / 2.
        """

        _check(os.path.isfile(shared_library_path), f'{shared_library_path} is not a file')
        _check(os.path.isfile(model_path), f'{model_path} is not a file')
        _check(thread_count > 0, 'Thread count must be positive')

        self.library = ctypes.cdll.LoadLibrary(shared_library_path)

        # Declare C signatures explicitly: without them ctypes assumes `int`
        # everywhere, which would truncate pointers on 64-bit platforms.
        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int]
        self.library.rwkv_init_from_file.restype = ctypes.c_void_p

        self.library.rwkv_eval.argtypes = [
            ctypes.c_void_p,  # ctx
            ctypes.c_long,  # token
            P_FLOAT,  # state_in
            P_FLOAT,  # state_out
            P_FLOAT  # logits_out
        ]
        self.library.rwkv_eval.restype = ctypes.c_bool

        self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_size_t

        self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_size_t

        self.library.rwkv_free.argtypes = [ctypes.c_void_p]
        self.library.rwkv_free.restype = None

        self.ctx = self.library.rwkv_init_from_file(model_path.encode('utf-8'), ctypes.c_int(thread_count))

        # A `c_void_p` restype maps a NULL return value to None.
        _check(bool(self.ctx), 'Failed to load the model, see stderr')

        self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
        self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)

        self.valid = True

    def __enter__(self) -> 'RWKVModel':
        """Allows `with RWKVModel(...) as model:`; returns the model itself."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Frees the model on leaving the `with` block, unless it was already freed."""
        if self.valid:
            self.free()

    def eval(self, token: int, state_in: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluates the model for a single token.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        token : int
            Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
        state_in : Optional[torch.Tensor]
            State from previous call of this method. If this is a first pass, set it to None.
            Must be a contiguous FP32 CPU tensor of shape (state_buffer_element_count,).

        Returns
        -------
        logits, state
            Logits vector of shape (n_vocab); state for the next step.
        """

        _check(self.valid, 'Model was freed')

        if state_in is None:
            # NULL pointer tells the library that this is a first pass.
            state_in_ptr = 0
        else:
            expected_shape = (self.state_buffer_element_count,)

            # The native code reads the buffer as raw `float *`, so anything
            # other than a contiguous FP32 CPU tensor would be misinterpreted.
            _check(state_in.dtype == torch.float32, f'Invalid state dtype {state_in.dtype}, expected {torch.float32}')
            _check(state_in.device.type == 'cpu', f'State tensor must be on CPU, but it is on {state_in.device}')
            _check(state_in.is_contiguous(), 'State tensor is not contiguous')
            _check(state_in.shape == expected_shape, f'Invalid state shape {state_in.shape}, expected {expected_shape}')

            # Tensor.data_ptr() points at the first element of *this* tensor;
            # storage().data_ptr() would ignore the storage offset of a view.
            state_in_ptr = state_in.data_ptr()

        # TODO Probably these allocations can be optimized away
        state_out: torch.Tensor = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
        logits_out: torch.Tensor = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')

        result = self.library.rwkv_eval(
            self.ctx,
            ctypes.c_long(token),
            ctypes.cast(state_in_ptr, P_FLOAT),
            ctypes.cast(state_out.data_ptr(), P_FLOAT),
            ctypes.cast(logits_out.data_ptr(), P_FLOAT)
        )

        _check(bool(result), 'Inference failed, see stderr')

        return logits_out, state_out

    def free(self) -> None:
        """
        Frees all allocated resources.
        In case of any error, this method will throw an exception.
        The object must not be used anymore after calling this method.
        """

        _check(self.valid, 'Already freed')

        # Mark as freed first, so the context is never released twice.
        self.valid = False

        self.library.rwkv_free(self.ctx)