Add Python wrapper for C library
parent 7130a89d1f
commit a1e1d34c93
CMakeLists.txt
@@ -247,9 +247,17 @@ add_library(rwkv
 target_include_directories(llama PUBLIC .)
 target_compile_features(llama PUBLIC cxx_std_11) # don't bump
 target_link_libraries(llama PRIVATE ggml ${LLAMA_EXTRA_LIBS})
+
+target_include_directories(rwkv PUBLIC .)
+target_compile_features(rwkv PUBLIC cxx_std_11) # don't bump
+target_link_libraries(rwkv PRIVATE ggml ${LLAMA_EXTRA_LIBS})
+
 if (BUILD_SHARED_LIBS)
     set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON)
     target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD)
+
+    set_target_properties(rwkv PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    target_compile_definitions(rwkv PRIVATE LLAMA_SHARED LLAMA_BUILD)
 endif()
 
 #
Makefile (3 changes)
@@ -243,6 +243,9 @@ main: examples/main/main.cpp ggml.o llama.o common.o
 	@echo '==== Run ./main -h for help. ===='
 	@echo
 
+main_rwkv: examples/main_rwkv/main_rwkv.cpp ggml.o rwkv.o common.o
+	$(CXX) $(CXXFLAGS) examples/main_rwkv/main_rwkv.cpp ggml.o rwkv.o common.o -o main_rwkv $(LDFLAGS)
+
 quantize: examples/quantize/quantize.cpp ggml.o llama.o
 	$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp ggml.o llama.o -o quantize $(LDFLAGS)
README.md (44 changes)
@@ -2,21 +2,21 @@
 
 This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml). The end goal is to allow 4-bit quantized inference on CPU.
 
-**WORK IN PROGRESS!** **Status**: FP32 and FP16 inference work correctly. Currently, I'm working on creating a usable C library interface and a Python wrapper for it.
+**WORK IN PROGRESS!** **Status**: There is now a Python wrapper; FP32 and FP16 inference work correctly. Currently, I'm working on INT4 quantization support.
 
 ## Plan
 
-1. Create a proper interface (probably, a C library)
-2. Create a Python wrapper with sampling and a simple chat interface
-3. Write a good `README.md` and publish links to this repo
-4. Make INT4 inference work
+1. Make INT4 inference work
+2. Create a Python script with sampling and a simple chat interface
+3. Clean up the repo (remove llama-related files and mentions)
+4. Write a good `README.md` and publish links to this repo
 5. Create a pull request to the main `ggml` repo with all improvements made here
 
 ## Structure
 
 This repo is based on the [llama.cpp repo](https://github.com/ggerganov/llama.cpp). RWKV-related code is in these directories:
 
-- `./rwkv`: directory containing Python scripts for conversion and validation
+- `./rwkv`: directory containing Python scripts for conversion, inference and validation
 - `./examples/main_rwkv`: directory containing a script that loads and runs an RWKV model
 
 Please do not change files in other directories — this will make pulling recent changes easier.
@@ -27,25 +27,39 @@ Please do not change files in other directories — this will make pulling recent changes easier.
 
 Requirements: [git](https://gitforwindows.org/), [CMake](https://cmake.org/download/), MSVC compiler, Python 3.x with PyTorch.
 
-Clone the repo and set it up for build:
+#### 1. Clone the repo and build it:
 
 ```commandline
 git clone https://github.com/saharNooby/rwkv.cpp.git
 cd rwkv.cpp
-cmake .
+cmake -DBUILD_SHARED_LIBS=ON -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF .
+cmake --build . --config Release
 ```
 
-Download an RWKV model from [Huggingface](https://huggingface.co/BlinkDL) and convert it into `ggml` format:
+If everything went OK, the file `bin\Release\rwkv.dll` should appear.
+
+#### 2. Download an RWKV model from [Huggingface](https://huggingface.co/BlinkDL) and convert it into `ggml` format:
 
 ```commandline
 python rwkv\convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
 ```
 
-Compile and run the script:
-
-```commandline
-cmake --build . --config Release
-bin\Release\main_rwkv.exe "C:\rwkv.cpp-169M.bin" 123 "C:\state_in.bin" "C:\state_out.bin" "C:\logits_out.bin"
-```
-
-The script will read the state from `state_in.bin`, do a single inference using the state and token `123` as input, save the new state into `state_out.bin` and the logits into `logits_out.bin`.
+#### 3. Use the model in Python:
+
+```python
+# The wrapper module is located at rwkv/rwkv_cpp.py
+import rwkv_cpp
+
+model = rwkv_cpp.RWKVModel(r'bin\Release\rwkv.dll', r'C:\rwkv.cpp-169M.bin')
+
+logits, state = None, None
+
+for token in [1, 2, 3]:
+    logits, state = model.eval(token, state)
+
+print(f'Output logits: {logits}')
+
+# Don't forget to free the memory after you are done working with the model
+model.free()
+```
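Plan item 2 above mentions a sampling script that does not exist yet. As a rough sketch of how sampling could sit on top of this wrapper (hypothetical, not part of this commit; real token IDs would come from a tokenizer, which the repo does not ship yet), greedy decoding simply feeds the argmax of the logits back into `eval`:

```python
# Hypothetical greedy-decoding sketch on top of rwkv_cpp.RWKVModel.
# Not part of this commit; prompt token IDs are hard-coded because
# no tokenizer is bundled with the repo yet.
import torch
import rwkv_cpp

model = rwkv_cpp.RWKVModel(r'bin\Release\rwkv.dll', r'C:\rwkv.cpp-169M.bin')

logits, state = None, None

# Feed the prompt tokens first, carrying the state forward.
for token in [1, 2, 3]:
    logits, state = model.eval(token, state)

# Then repeatedly pick the most likely next token and feed it back.
for _ in range(10):
    next_token = int(torch.argmax(logits).item())
    print(next_token)
    logits, state = model.eval(next_token, state)

model.free()
```

Swapping the argmax for temperature or top-p sampling is where the planned chat script would differ from this sketch.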
rwkv.h (6 changes)
@@ -29,13 +29,13 @@ extern "C" {
 
 struct rwkv_context;
 
-// Loads the model from a file and prepares it for inference by allocating memory and building a computation graph.
+// Loads the model from a file and prepares it for inference.
 // Returns NULL on any error. Error messages will be printed to stderr.
 RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads);
 
-// Evaluates the model for a single pass.
+// Evaluates the model for a single token.
 // Returns false on any error. Error messages will be printed to stderr.
-// - token: next token index, in range 0..n_vocab - 1.
+// - token: next token index, in range 0 <= token < n_vocab.
 // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
 // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
 // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
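For readers who want to exercise this C contract directly, without the PyTorch-based wrapper added later in this commit, a minimal ctypes-plus-numpy sketch follows. It assumes the shared library was built as in the README; the thread count and model path are placeholders. Note that `state_in` is NULL (`None`) on the first pass, exactly as the header comments describe:

```python
# Minimal sketch of the rwkv.h contract using only ctypes and numpy.
# The torch-based wrapper (rwkv/rwkv_cpp.py, below) is the full version.
import ctypes
import numpy as np

P_FLOAT = ctypes.POINTER(ctypes.c_float)

lib = ctypes.cdll.LoadLibrary(r'bin\Release\rwkv.dll')
lib.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int]
lib.rwkv_init_from_file.restype = ctypes.c_void_p
lib.rwkv_eval.argtypes = [ctypes.c_void_p, ctypes.c_long, P_FLOAT, P_FLOAT, P_FLOAT]
lib.rwkv_eval.restype = ctypes.c_bool
lib.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
lib.rwkv_get_state_buffer_element_count.restype = ctypes.c_size_t
lib.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
lib.rwkv_get_logits_buffer_element_count.restype = ctypes.c_size_t
lib.rwkv_free.argtypes = [ctypes.c_void_p]
lib.rwkv_free.restype = None

ctx = lib.rwkv_init_from_file(rb'C:\rwkv.cpp-169M.bin', 4)
assert ctx is not None, 'Failed to load the model, see stderr'

# Allocate FP32 buffers of the sizes the library reports.
state = np.zeros(lib.rwkv_get_state_buffer_element_count(ctx), dtype=np.float32)
logits = np.zeros(lib.rwkv_get_logits_buffer_element_count(ctx), dtype=np.float32)

# First pass: state_in is NULL; later passes would pass the state back in.
ok = lib.rwkv_eval(ctx, 123, None,
                   state.ctypes.data_as(P_FLOAT), logits.ctypes.data_as(P_FLOAT))
assert ok, 'Inference failed, see stderr'

lib.rwkv_free(ctx)
```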
rwkv/compare_with_reference_implementation.py
@@ -9,11 +9,12 @@ import argparse
 import subprocess
 import torch
 import numpy as np
+import rwkv_cpp
 from typing import List, Tuple, Any
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
-    parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file')
+    parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file or shared library')
     parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
     return parser.parse_args()
 
@@ -48,28 +49,42 @@ def main() -> None:
     # FP16, lower precision, so higher threshold
     threshold = 0.003
 
+    model = None
+
+    if args.main_executable_path.lower().endswith('.dll') or args.main_executable_path.lower().endswith('.so'):
+        print('Testing shared library')
+
+        model = rwkv_cpp.RWKVModel(args.main_executable_path, args.ggml_model_path)
+    else:
+        print('Testing main_rwkv executable')
+
     def compare_logits(tokens_subset: List[int]) -> None:
         token_count: int = len(tokens_subset)
         state_path: str = './state.bin'
         logits_path: str = './logits.bin'
 
+        logits, state = None, None
+
         for i in range(token_count):
             token: int = tokens_subset[i]
 
             print(f'{i + 1}/{token_count}')
 
-            subprocess.run(
-                [
-                    args.main_executable_path,
-                    args.ggml_model_path,
-                    str(token),
-                    # If this is the first token, let the script create a new state.
-                    '' if i == 0 else state_path,
-                    state_path,
-                    logits_path
-                ],
-                check=True
-            )
+            if model is not None:
+                logits, state = model.eval(token, state)
+            else:
+                subprocess.run(
+                    [
+                        args.main_executable_path,
+                        args.ggml_model_path,
+                        str(token),
+                        # If this is the first token, let the script create a new state.
+                        '' if i == 0 else state_path,
+                        state_path,
+                        logits_path
+                    ],
+                    check=True
+                )
 
         expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin'
 
@@ -79,8 +94,11 @@ def main() -> None:
         with open(expected_logits_path, 'rb') as logits_file:
             expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
 
-        with open(logits_path, 'rb') as logits_file:
-            actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
+        if model is not None:
+            actual_logits = logits
+        else:
+            with open(logits_path, 'rb') as logits_file:
+                actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
 
         difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
 
@@ -97,5 +115,8 @@ def main() -> None:
     print()
     print('Test passes')
 
+    if model is not None:
+        model.free()
+
 if __name__ == "__main__":
     main()
rwkv/rwkv_cpp.py (new file, +127 lines)
@@ -0,0 +1,127 @@
+import os
+import ctypes
+import torch
+import multiprocessing
+from typing import Tuple, Optional
+
+P_FLOAT = ctypes.POINTER(ctypes.c_float)
+
+class RWKVModel:
+    """
+    PyTorch wrapper around the rwkv.cpp shared library.
+    """
+
+    def __init__(
+            self,
+            shared_library_path: str,
+            model_path: str,
+            thread_count: int = max(1, multiprocessing.cpu_count() // 2)
+    ):
+        """
+        Loads the model and prepares it for inference.
+        In case of any error, this method will throw an exception.
+
+        Parameters
+        ----------
+        shared_library_path : str
+            Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
+        model_path : str
+            Path to RWKV model file in ggml format.
+        thread_count : int
+            Thread count to use. If not set, defaults to CPU count / 2.
+        """
+
+        assert os.path.isfile(shared_library_path), f'{shared_library_path} is not a file'
+        assert os.path.isfile(model_path), f'{model_path} is not a file'
+        assert thread_count > 0, 'Thread count must be positive'
+
+        self.library = ctypes.cdll.LoadLibrary(shared_library_path)
+
+        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int]
+        self.library.rwkv_init_from_file.restype = ctypes.c_void_p
+
+        self.library.rwkv_eval.argtypes = [
+            ctypes.c_void_p, # ctx
+            ctypes.c_long,   # token
+            P_FLOAT,         # state_in
+            P_FLOAT,         # state_out
+            P_FLOAT          # logits_out
+        ]
+        self.library.rwkv_eval.restype = ctypes.c_bool
+
+        self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
+        self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_size_t
+
+        self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
+        self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_size_t
+
+        self.library.rwkv_free.argtypes = [ctypes.c_void_p]
+        self.library.rwkv_free.restype = None
+
+        self.ctx = self.library.rwkv_init_from_file(model_path.encode('utf-8'), ctypes.c_int(thread_count))
+
+        assert self.ctx is not None, 'Failed to load the model, see stderr'
+
+        self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
+        self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)
+
+        self.valid = True
+
+    def eval(self, token: int, state_in: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Evaluates the model for a single token.
+        In case of any error, this method will throw an exception.
+
+        Parameters
+        ----------
+        token : int
+            Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
+        state_in : Optional[torch.Tensor]
+            State from the previous call of this method. If this is a first pass, set it to None.
+
+        Returns
+        -------
+        logits, state
+            Logits vector of shape (n_vocab); state for the next step.
+        """
+
+        assert self.valid, 'Model was freed'
+
+        if state_in is None:
+            state_in_ptr = 0
+        else:
+            expected_shape = (self.state_buffer_element_count,)
+
+            assert state_in.is_contiguous(), 'State tensor is not contiguous'
+            assert state_in.shape == expected_shape, f'Invalid state shape {state_in.shape}, expected {expected_shape}'
+
+            state_in_ptr = state_in.storage().data_ptr()
+
+        # TODO Probably these allocations can be optimized away
+        state_out: torch.Tensor = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
+        logits_out: torch.Tensor = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')
+
+        result = self.library.rwkv_eval(
+            self.ctx,
+            ctypes.c_long(token),
+            ctypes.cast(state_in_ptr, P_FLOAT),
+            ctypes.cast(state_out.storage().data_ptr(), P_FLOAT),
+            ctypes.cast(logits_out.storage().data_ptr(), P_FLOAT)
+        )
+
+        assert result, 'Inference failed, see stderr'
+
+        return logits_out, state_out
+
+    def free(self):
+        """
+        Frees all allocated resources.
+        In case of any error, this method will throw an exception.
+        The object must not be used anymore after calling this method.
+        """
+
+        assert self.valid, 'Already freed'
+
+        self.valid = False
+
+        self.library.rwkv_free(self.ctx)
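Since `free()` must be called exactly once and `eval` asserts on a freed model, callers may want to wrap the model's lifetime in try/finally. A small usage sketch (not part of this commit):

```python
# Usage sketch (not part of this commit): try/finally guarantees that
# rwkv_free runs even if eval raises partway through a generation loop.
import rwkv_cpp

model = rwkv_cpp.RWKVModel(r'bin\Release\rwkv.dll', r'C:\rwkv.cpp-169M.bin')

try:
    logits, state = None, None

    for token in [1, 2, 3]:
        logits, state = model.eval(token, state)
finally:
    model.free()
```

An alternative design would be implementing `__enter__`/`__exit__` on `RWKVModel`, so the model could be used in a `with` block and freed automatically.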