Add Python wrapper for C library
commit a1e1d34c93
parent 7130a89d1f
CMakeLists.txt
@@ -247,9 +247,17 @@ add_library(rwkv
target_include_directories(llama PUBLIC .)
target_compile_features(llama PUBLIC cxx_std_11) # don't bump
target_link_libraries(llama PRIVATE ggml ${LLAMA_EXTRA_LIBS})

+target_include_directories(rwkv PUBLIC .)
+target_compile_features(rwkv PUBLIC cxx_std_11) # don't bump
+target_link_libraries(rwkv PRIVATE ggml ${LLAMA_EXTRA_LIBS})
+
if (BUILD_SHARED_LIBS)
    set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON)
    target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD)
+
+    set_target_properties(rwkv PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    target_compile_definitions(rwkv PRIVATE LLAMA_SHARED LLAMA_BUILD)
endif()

#
Makefile
@@ -243,6 +243,9 @@ main: examples/main/main.cpp ggml.o llama.o common.o
	@echo '==== Run ./main -h for help. ===='
	@echo

+main_rwkv: examples/main_rwkv/main_rwkv.cpp ggml.o rwkv.o common.o
+	$(CXX) $(CXXFLAGS) examples/main_rwkv/main_rwkv.cpp ggml.o rwkv.o common.o -o main_rwkv $(LDFLAGS)
+
quantize: examples/quantize/quantize.cpp ggml.o llama.o
	$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp ggml.o llama.o -o quantize $(LDFLAGS)
README.md
@@ -2,21 +2,21 @@

This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml). The end goal is to allow 4-bit quantized inference on CPU.

-**WORK IN PROGRESS!** **Status**: FP32 and FP16 inference work correctly. Currently, I'm working on creating a usable C library interface and a Python wrapper for it.
+**WORK IN PROGRESS!** **Status**: There is a Python wrapper; FP32 and FP16 inference work correctly. Currently, I'm working on INT4 quantization support.

## Plan

-1. Create a proper interface (probably a C library)
-2. Create a Python wrapper with sampling and a simple chat interface
-3. Write a good `README.md` and publish links to this repo
-4. Make INT4 inference work
+1. Make INT4 inference work
+2. Create a Python script with sampling and a simple chat interface
+3. Clean up the repo (remove llama-related files and mentions)
+4. Write a good `README.md` and publish links to this repo
5. Create a pull request to the main `ggml` repo with all improvements made here

## Structure

This repo is based on the [llama.cpp repo](https://github.com/ggerganov/llama.cpp). RWKV-related code is in these directories:

-- `./rwkv`: directory containing Python scripts for conversion and validation
+- `./rwkv`: directory containing Python scripts for conversion, inference and validation
- `./examples/main_rwkv`: directory containing a script that loads and runs inference on an RWKV model

Please do not change files in other directories — this will make pulling recent changes easier.
@@ -27,25 +27,39 @@ Please do not change files in other directories — this will make pulling recen

Requirements: [git](https://gitforwindows.org/), [CMake](https://cmake.org/download/), MSVC compiler, Python 3.x with PyTorch.

-Clone the repo and set it up for build:
+#### 1. Clone the repo and build it:

```commandline
git clone https://github.com/saharNooby/rwkv.cpp.git
cd rwkv.cpp
-cmake .
+cmake -DBUILD_SHARED_LIBS=ON -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF .
+cmake --build . --config Release
```

-Download an RWKV model from [Huggingface](https://huggingface.co/BlinkDL) and convert it into `ggml` format:
+If everything went OK, the file `bin\Release\rwkv.dll` should appear.
+
+#### 2. Download an RWKV model from [Huggingface](https://huggingface.co/BlinkDL) and convert it into `ggml` format:

```commandline
python rwkv\convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
```

-Compile and run the script:
+#### 3. Use the model in Python:

+```python
+# This file is located at rwkv/rwkv_cpp.py
+import rwkv_cpp
+
+model = rwkv_cpp.RWKVModel(r'bin\Release\rwkv.dll', r'C:\rwkv.cpp-169M.bin')
+
+logits, state = None, None
+
+for token in [1, 2, 3]:
+    logits, state = model.eval(token, state)
+
+print(f'Output logits: {logits}')
+
+# Don't forget to free the memory after you're done working with the model
+model.free()
+```

-```commandline
-cmake --build . --config Release
-bin\Release\main_rwkv.exe "C:\rwkv.cpp-169M.bin" 123 "C:\state_in.bin" "C:\state_out.bin" "C:\logits_out.bin"
-```
-
-The script will read the state from `state_in.bin`, do a single inference pass using the state and token `123` as input, save the new state into `state_out.bin` and the logits into `logits_out.bin`.
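A note on the usage example: it feeds a fixed token list and prints raw logits. Sampling is still on the plan above, so the following is only a sketch of how greedy continuation could look on top of `RWKVModel.eval`; the token list and paths are placeholders, not part of this commit.

```python
# Sketch only: greedy continuation using the wrapper from rwkv/rwkv_cpp.py.
# Assumes the paths from the build steps; a real tokenizer is not in this repo yet.
import torch
import rwkv_cpp

model = rwkv_cpp.RWKVModel(r'bin\Release\rwkv.dll', r'C:\rwkv.cpp-169M.bin')

logits, state = None, None

# Feed a placeholder prompt, then extend it greedily for 10 tokens.
for token in [1, 2, 3]:
    logits, state = model.eval(token, state)

for _ in range(10):
    next_token = int(torch.argmax(logits).item())
    print(next_token)
    logits, state = model.eval(next_token, state)

model.free()
```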
rwkv.h
@@ -29,13 +29,13 @@ extern "C" {

struct rwkv_context;

-// Loads the model from a file and prepares it for inference by allocating memory and building the computation graph.
+// Loads the model from a file and prepares it for inference.
// Returns NULL on any error. Error messages will be printed to stderr.
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads);

-// Evaluates the model for a single pass.
+// Evaluates the model for a single token.
// Returns false on any error. Error messages will be printed to stderr.
-// - token: next token index, in range 0..n_vocab - 1.
+// - token: next token index, in range 0 <= token < n_vocab.
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is the first pass.
// - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
// - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
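For orientation, here is a minimal ctypes sketch of the call sequence this header describes: load the model, call `rwkv_eval` repeatedly (passing NULL as `state_in` on the first pass), then free. Buffer sizes come from the two getter functions. The paths are placeholders, and the full wrapper in `rwkv/rwkv_cpp.py` below does the same with PyTorch tensors.

```python
# Minimal sketch of driving the C API directly; see rwkv/rwkv_cpp.py for the full wrapper.
import ctypes

P_FLOAT = ctypes.POINTER(ctypes.c_float)

lib = ctypes.cdll.LoadLibrary(r'bin\Release\rwkv.dll')  # placeholder path

lib.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int]
lib.rwkv_init_from_file.restype = ctypes.c_void_p
lib.rwkv_eval.argtypes = [ctypes.c_void_p, ctypes.c_long, P_FLOAT, P_FLOAT, P_FLOAT]
lib.rwkv_eval.restype = ctypes.c_bool
lib.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
lib.rwkv_get_state_buffer_element_count.restype = ctypes.c_size_t
lib.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
lib.rwkv_get_logits_buffer_element_count.restype = ctypes.c_size_t
lib.rwkv_free.argtypes = [ctypes.c_void_p]
lib.rwkv_free.restype = None

ctx = lib.rwkv_init_from_file(rb'C:\rwkv.cpp-169M.bin', 4)
assert ctx is not None, 'Failed to load the model, see stderr'

state = (ctypes.c_float * lib.rwkv_get_state_buffer_element_count(ctx))()
logits = (ctypes.c_float * lib.rwkv_get_logits_buffer_element_count(ctx))()

state_in = ctypes.cast(None, P_FLOAT)  # NULL on the first pass
for token in [1, 2, 3]:
    # Assumes in-place reuse of the state buffer is OK; the wrapper below
    # allocates a fresh output buffer on every call instead.
    assert lib.rwkv_eval(ctx, token, state_in, state, logits), 'Inference failed, see stderr'
    state_in = ctypes.cast(state, P_FLOAT)  # feed the state back in on later passes

lib.rwkv_free(ctx)
```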
rwkv/compare_with_reference_implementation.py
@@ -9,11 +9,12 @@ import argparse
import subprocess
import torch
import numpy as np
+import rwkv_cpp
from typing import List, Tuple, Any

def parse_args():
    parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
-    parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file')
+    parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file or shared library')
    parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
    return parser.parse_args()
@@ -48,28 +49,42 @@ def main() -> None:
    # FP16, lower precision, so higher threshold
    threshold = 0.003

+    model = None
+
+    if args.main_executable_path.lower().endswith('.dll') or args.main_executable_path.lower().endswith('.so'):
+        print('Testing shared library')
+
+        model = rwkv_cpp.RWKVModel(args.main_executable_path, args.ggml_model_path)
+    else:
+        print('Testing main_rwkv executable')
+
    def compare_logits(tokens_subset: List[int]) -> None:
        token_count: int = len(tokens_subset)
        state_path: str = './state.bin'
        logits_path: str = './logits.bin'

+        logits, state = None, None
+
        for i in range(token_count):
            token: int = tokens_subset[i]

            print(f'{i + 1}/{token_count}')

-            subprocess.run(
-                [
-                    args.main_executable_path,
-                    args.ggml_model_path,
-                    str(token),
-                    # If this is the first token, let the script create a new state.
-                    '' if i == 0 else state_path,
-                    state_path,
-                    logits_path
-                ],
-                check=True
-            )
+            if model is not None:
+                logits, state = model.eval(token, state)
+            else:
+                subprocess.run(
+                    [
+                        args.main_executable_path,
+                        args.ggml_model_path,
+                        str(token),
+                        # If this is the first token, let the script create a new state.
+                        '' if i == 0 else state_path,
+                        state_path,
+                        logits_path
+                    ],
+                    check=True
+                )

        expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin'
@@ -79,8 +94,11 @@ def main() -> None:
        with open(expected_logits_path, 'rb') as logits_file:
            expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))

-        with open(logits_path, 'rb') as logits_file:
-            actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
+        if model is not None:
+            actual_logits = logits
+        else:
+            with open(logits_path, 'rb') as logits_file:
+                actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))

        difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
@@ -97,5 +115,8 @@ def main() -> None:
    print()
    print('Test passes')

+    if model is not None:
+        model.free()
+
if __name__ == "__main__":
    main()
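With these changes, the comparison script exercises either backend, chosen by the extension of the first argument. Assuming the paths from the README build steps, the two modes would be invoked like this:

```commandline
python rwkv\compare_with_reference_implementation.py bin\Release\rwkv.dll C:\rwkv.cpp-169M.bin
python rwkv\compare_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
```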
rwkv/rwkv_cpp.py (new file)
@@ -0,0 +1,127 @@
import os
import ctypes
import torch
import multiprocessing
from typing import Tuple, Optional

P_FLOAT = ctypes.POINTER(ctypes.c_float)

class RWKVModel:
    """
    PyTorch wrapper around rwkv.cpp shared library.
    """

    def __init__(
            self,
            shared_library_path: str,
            model_path: str,
            thread_count: int = max(1, multiprocessing.cpu_count() // 2)
    ):
        """
        Loads the model and prepares it for inference.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        shared_library_path : str
            Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
        model_path : str
            Path to RWKV model file in ggml format.
        thread_count : int
            Thread count to use. If not set, defaults to CPU count / 2.
        """

        assert os.path.isfile(shared_library_path), f'{shared_library_path} is not a file'
        assert os.path.isfile(model_path), f'{model_path} is not a file'
        assert thread_count > 0, 'Thread count must be positive'

        self.library = ctypes.cdll.LoadLibrary(shared_library_path)

        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int]
        self.library.rwkv_init_from_file.restype = ctypes.c_void_p

        self.library.rwkv_eval.argtypes = [
            ctypes.c_void_p, # ctx
            ctypes.c_long, # token
            P_FLOAT, # state_in
            P_FLOAT, # state_out
            P_FLOAT # logits_out
        ]
        self.library.rwkv_eval.restype = ctypes.c_bool

        self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_size_t

        self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_size_t

        self.library.rwkv_free.argtypes = [ctypes.c_void_p]
        self.library.rwkv_free.restype = None

        self.ctx = self.library.rwkv_init_from_file(model_path.encode('utf-8'), ctypes.c_int(thread_count))

        assert self.ctx is not None, 'Failed to load the model, see stderr'

        self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
        self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)

        self.valid = True

    def eval(self, token: int, state_in: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluates the model for a single token.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        token : int
            Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
        state_in : Optional[torch.Tensor]
            State from the previous call of this method. If this is the first pass, set it to None.

        Returns
        -------
        logits, state
            Logits vector of shape (n_vocab); state for the next step.
        """

        assert self.valid, 'Model was freed'

        if state_in is None:
            state_in_ptr = 0
        else:
            expected_shape = (self.state_buffer_element_count,)

            assert state_in.is_contiguous(), 'State tensor is not contiguous'
            assert state_in.shape == expected_shape, f'Invalid state shape {state_in.shape}, expected {expected_shape}'

            state_in_ptr = state_in.storage().data_ptr()

        # TODO Probably these allocations can be optimized away
        state_out: torch.Tensor = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
        logits_out: torch.Tensor = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')

        result = self.library.rwkv_eval(
            self.ctx,
            ctypes.c_long(token),
            ctypes.cast(state_in_ptr, P_FLOAT),
            ctypes.cast(state_out.storage().data_ptr(), P_FLOAT),
            ctypes.cast(logits_out.storage().data_ptr(), P_FLOAT)
        )

        assert result, 'Inference failed, see stderr'

        return logits_out, state_out

    def free(self):
        """
        Frees all allocated resources.
        In case of any error, this method will throw an exception.
        The object must not be used anymore after calling this method.
        """

        assert self.valid, 'Already freed'

        self.valid = False

        self.library.rwkv_free(self.ctx)
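Plan item 2 above calls for sampling. As a hedged preview of what that could look like on top of the logits returned by `RWKVModel.eval` (the helper and the temperature value are illustrative, not part of this commit):

```python
# Illustrative only: temperature sampling over the logits returned by RWKVModel.eval.
import torch

def sample_token(logits: torch.Tensor, temperature: float = 0.8) -> int:
    # Scale the logits, turn them into a probability distribution, draw one token index.
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())
```

Greedy decoding, as in the README sketch above, corresponds to the limit of very low temperature.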