From a1e1d34c936a62b32900baf5656a36dddb2aa478 Mon Sep 17 00:00:00 2001
From: saharNooby
Date: Sat, 1 Apr 2023 16:02:22 +0400
Subject: [PATCH] Add Python wrapper for C library

---
 CMakeLists.txt                                |   8 ++
 Makefile                                      |   3 +
 README.md                                     |  44 +++---
 rwkv.h                                        |   6 +-
 ...mpare_cpp_with_reference_implementation.py |  51 ++++---
 rwkv/rwkv_cpp.py                              | 127 ++++++++++++++++++
 6 files changed, 206 insertions(+), 33 deletions(-)
 create mode 100644 rwkv/rwkv_cpp.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 79448d4..793c017 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -247,9 +247,17 @@ add_library(rwkv
 target_include_directories(llama PUBLIC .)
 target_compile_features(llama PUBLIC cxx_std_11) # don't bump
 target_link_libraries(llama PRIVATE ggml ${LLAMA_EXTRA_LIBS})
+
+target_include_directories(rwkv PUBLIC .)
+target_compile_features(rwkv PUBLIC cxx_std_11) # don't bump
+target_link_libraries(rwkv PRIVATE ggml ${LLAMA_EXTRA_LIBS})
+
 if (BUILD_SHARED_LIBS)
     set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON)
     target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD)
+
+    set_target_properties(rwkv PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    target_compile_definitions(rwkv PRIVATE LLAMA_SHARED LLAMA_BUILD)
 endif()

 #
diff --git a/Makefile b/Makefile
index 035875c..e4e63b0 100644
--- a/Makefile
+++ b/Makefile
@@ -243,6 +243,9 @@ main: examples/main/main.cpp ggml.o llama.o common.o
 	@echo '==== Run ./main -h for help. ===='
 	@echo

+main_rwkv: examples/main_rwkv/main_rwkv.cpp ggml.o rwkv.o common.o
+	$(CXX) $(CXXFLAGS) examples/main_rwkv/main_rwkv.cpp ggml.o rwkv.o common.o -o main_rwkv $(LDFLAGS)
+
 quantize: examples/quantize/quantize.cpp ggml.o llama.o
 	$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp ggml.o llama.o -o quantize $(LDFLAGS)

diff --git a/README.md b/README.md
index 8f14573..b0ea4ea 100644
--- a/README.md
+++ b/README.md
@@ -2,21 +2,21 @@

 This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml). The end goal is to allow 4-bit quantized inference on CPU.

-**WORK IN PROGRESS!** **Status**: FP32 and FP16 inference work correctly. Currently, I'm working on creating usable C library interface and Python wrapper for it.
+**WORK IN PROGRESS!** **Status**: There is a Python wrapper, and FP32 and FP16 inference work correctly. Currently, I'm working on INT4 quantization support.

 ## Plan

-1. Create proper interface (probably, C library)
-2. Create Python wrapper with sampling and simple chat interface
-3. Write a good `README.md` and publish links to this repo
-4. Make INT4 inference work
+1. Make INT4 inference work
+2. Create Python script with sampling and simple chat interface
+3. Clean up the repo (remove llama-related files and mentions)
+4. Write a good `README.md` and publish links to this repo
 5. Create pull request to main `ggml` repo with all improvements made here

 ## Structure

 This repo is based on the [llama.cpp repo](https://github.com/ggerganov/llama.cpp). RWKV-related code is in these directories:

-- `./rwkv`: directory containing Python scripts for conversion and validation
+- `./rwkv`: directory containing Python scripts for conversion, inference and validation
 - `./examples/main_rwkv`: directory containing script that loads and infers RWKV model

 Please do not change files in other directories — this will make pulling recent changes easier.
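On UNIX-like systems, the new Makefile target builds the same test program that the Windows walkthrough below produces with CMake. A minimal sketch of building and invoking it; the paths here are illustrative, the argument order is model path, token, input state, output state, output logits, and an empty input-state argument starts from a fresh state, as the comparison script does for the first token:

```commandline
make main_rwkv
./main_rwkv rwkv.cpp-169M.bin 123 "" state_out.bin logits_out.bin
```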
@@ -27,25 +27,39 @@ Please do not change files in other directories — this will make pulling recen

 Requirements: [git](https://gitforwindows.org/), [CMake](https://cmake.org/download/), MSVC compiler, Python 3.x with PyTorch.

-Clone the repo and set it up for build:
+#### 1. Clone the repo and build it:

 ```commandline
 git clone https://github.com/saharNooby/rwkv.cpp.git
 cd rwkv.cpp
-cmake .
+cmake -DBUILD_SHARED_LIBS=ON -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF .
+cmake --build . --config Release
 ```

-Download an RWKV model from [Huggingface](https://huggingface.co/BlinkDL) and convert it into `ggml` format:
+If everything went OK, the file `bin\Release\rwkv.dll` should appear.
+
+#### 2. Download an RWKV model from [Huggingface](https://huggingface.co/BlinkDL) and convert it into `ggml` format:

 ```commandline
 python rwkv\convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
 ```

-Compile and run the script:
+#### 3. Use the model in Python:
+
+```python
+# The rwkv_cpp module is located at rwkv/rwkv_cpp.py
+import rwkv_cpp
+
+model = rwkv_cpp.RWKVModel(r'bin\Release\rwkv.dll', r'C:\rwkv.cpp-169M.bin')
+
+logits, state = None, None
+
+for token in [1, 2, 3]:
+    logits, state = model.eval(token, state)
+
+    print(f'Output logits: {logits}')
+
+# Don't forget to free the memory when you are done working with the model
+model.free()

-```commandline
-cmake --build . --config Release
-bin\Release\main_rwkv.exe "C:\rwkv.cpp-169M.bin" 123 "C:\state_in.bin" "C:\state_out.bin" "C:\logits_out.bin"
 ```
-
-The script will read state from `state_in.bin`, do single inference using the state and token `123` as an input, save new state into `state_out.bin` and logits into `logits_out.bin`.
diff --git a/rwkv.h b/rwkv.h
index ff916e7..b44fbb0 100644
--- a/rwkv.h
+++ b/rwkv.h
@@ -29,13 +29,13 @@ extern "C" {

     struct rwkv_context;

-    // Loads the model from a file and prepares it for inference by allocating memory and building computation graph.
+    // Loads the model from a file and prepares it for inference.
     // Returns NULL on any error. Error messages would be printed to stderr.
     RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads);

-    // Evaluates the model for a single pass.
+    // Evaluates the model for a single token.
     // Returns false on any error. Error messages would be printed to stderr.
-    // - token: next token index, in range 0..n_vocab - 1.
+    // - token: next token index, in range 0 <= token < n_vocab.
     // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
     // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
     // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
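To make the buffer contract above concrete, here is a minimal ctypes sketch of a single eval pass against the raw C API. It mirrors the declarations that `rwkv/rwkv_cpp.py` (next) sets up; the library path, model path, and token value are illustrative:

```python
import ctypes

P_FLOAT = ctypes.POINTER(ctypes.c_float)

# Illustrative path; point this at the shared library produced by the build step above.
library = ctypes.cdll.LoadLibrary('bin/Release/rwkv.dll')

library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int]
library.rwkv_init_from_file.restype = ctypes.c_void_p
library.rwkv_eval.argtypes = [ctypes.c_void_p, ctypes.c_long, P_FLOAT, P_FLOAT, P_FLOAT]
library.rwkv_eval.restype = ctypes.c_bool
library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
library.rwkv_get_state_buffer_element_count.restype = ctypes.c_size_t
library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_size_t
library.rwkv_free.argtypes = [ctypes.c_void_p]
library.rwkv_free.restype = None

ctx = library.rwkv_init_from_file(b'C:/rwkv.cpp-169M.bin', 4)
assert ctx is not None, 'Failed to load the model, see stderr'

# Buffer sizes are model-dependent, so ask the library for them.
state = (ctypes.c_float * library.rwkv_get_state_buffer_element_count(ctx))()
logits = (ctypes.c_float * library.rwkv_get_logits_buffer_element_count(ctx))()

# First pass: state_in is NULL, so the model starts from a fresh state.
# On later passes, the previous output state would be passed back in.
assert library.rwkv_eval(ctx, 123, None, state, logits), 'Inference failed, see stderr'

library.rwkv_free(ctx)
```

The `rwkv_cpp.py` wrapper below wraps exactly this sequence in a class and exposes the state and logits buffers as PyTorch tensors.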
diff --git a/rwkv/compare_cpp_with_reference_implementation.py b/rwkv/compare_cpp_with_reference_implementation.py index d6b6e6d..f754fb1 100644 --- a/rwkv/compare_cpp_with_reference_implementation.py +++ b/rwkv/compare_cpp_with_reference_implementation.py @@ -9,11 +9,12 @@ import argparse import subprocess import torch import numpy as np +import rwkv_cpp from typing import List, Tuple, Any def parse_args(): parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV') - parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file') + parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file or shared library') parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file') return parser.parse_args() @@ -48,28 +49,42 @@ def main() -> None: # FP16, lower precision, so higher threshold threshold = 0.003 + model = None + + if args.main_executable_path.lower().endswith('.dll') or args.main_executable_path.lower().endswith('.so'): + print('Testing shared library') + + model = rwkv_cpp.RWKVModel(args.main_executable_path, args.ggml_model_path) + else: + print('Testing main_rwkv executable') + def compare_logits(tokens_subset: List[int]) -> None: token_count: int = len(tokens_subset) state_path: str = './state.bin' logits_path: str = './logits.bin' + logits, state = None, None + for i in range(token_count): token: int = tokens_subset[i] print(f'{i + 1}/{token_count}') - subprocess.run( - [ - args.main_executable_path, - args.ggml_model_path, - str(token), - # If this is the first token, let the script create a new state. - '' if i == 0 else state_path, - state_path, - logits_path - ], - check=True - ) + if model is not None: + logits, state = model.eval(token, state) + else: + subprocess.run( + [ + args.main_executable_path, + args.ggml_model_path, + str(token), + # If this is the first token, let the script create a new state. + '' if i == 0 else state_path, + state_path, + logits_path + ], + check=True + ) expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin' @@ -79,8 +94,11 @@ def main() -> None: with open(expected_logits_path, 'rb') as logits_file: expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) - with open(logits_path, 'rb') as logits_file: - actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) + if model is not None: + actual_logits = logits + else: + with open(logits_path, 'rb') as logits_file: + actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item() @@ -97,5 +115,8 @@ def main() -> None: print() print('Test passes') + if model is not None: + model.free() + if __name__ == "__main__": main() diff --git a/rwkv/rwkv_cpp.py b/rwkv/rwkv_cpp.py new file mode 100644 index 0000000..88b422a --- /dev/null +++ b/rwkv/rwkv_cpp.py @@ -0,0 +1,127 @@ +import os +import ctypes +import torch +import multiprocessing +from typing import Tuple, Optional + +P_FLOAT = ctypes.POINTER(ctypes.c_float) + +class RWKVModel: + """ + PyTorch wrapper around rwkv.cpp shared library. + """ + + def __init__( + self, + shared_library_path: str, + model_path: str, + thread_count: int = max(1, multiprocessing.cpu_count() // 2) + ): + """ + Loads the model and prepares it for inference. 
+        In case of any error, this method will throw an exception.
+
+        Parameters
+        ----------
+        shared_library_path : str
+            Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
+        model_path : str
+            Path to RWKV model file in ggml format.
+        thread_count : int
+            Thread count to use. If not set, defaults to CPU count / 2.
+        """
+
+        assert os.path.isfile(shared_library_path), f'{shared_library_path} is not a file'
+        assert os.path.isfile(model_path), f'{model_path} is not a file'
+        assert thread_count > 0, 'Thread count must be positive'
+
+        self.library = ctypes.cdll.LoadLibrary(shared_library_path)
+
+        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int]
+        self.library.rwkv_init_from_file.restype = ctypes.c_void_p
+
+        self.library.rwkv_eval.argtypes = [
+            ctypes.c_void_p, # ctx
+            ctypes.c_long, # token
+            P_FLOAT, # state_in
+            P_FLOAT, # state_out
+            P_FLOAT # logits_out
+        ]
+        self.library.rwkv_eval.restype = ctypes.c_bool
+
+        self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
+        self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_size_t
+
+        self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
+        self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_size_t
+
+        self.library.rwkv_free.argtypes = [ctypes.c_void_p]
+        self.library.rwkv_free.restype = None
+
+        self.ctx = self.library.rwkv_init_from_file(model_path.encode('utf-8'), ctypes.c_int(thread_count))
+
+        assert self.ctx is not None, 'Failed to load the model, see stderr'
+
+        self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
+        self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)
+
+        self.valid = True
+
+    def eval(self, token: int, state_in: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Evaluates the model for a single token.
+        In case of any error, this method will throw an exception.
+
+        Parameters
+        ----------
+        token : int
+            Index of the next token to be seen by the model. Must be in range 0 <= token < n_vocab.
+        state_in : Optional[torch.Tensor]
+            State from the previous call of this method. If this is the first pass, set it to None.
+
+        Returns
+        -------
+        logits, state
+            Logits vector of shape (n_vocab); state for the next step.
+        """
+
+        assert self.valid, 'Model was freed'
+
+        if state_in is None:
+            state_in_ptr = 0
+        else:
+            expected_shape = (self.state_buffer_element_count,)
+
+            assert state_in.is_contiguous(), 'State tensor is not contiguous'
+            assert state_in.shape == expected_shape, f'Invalid state shape {state_in.shape}, expected {expected_shape}'
+
+            state_in_ptr = state_in.storage().data_ptr()
+
+        # TODO Probably these allocations can be optimized away
+        state_out: torch.Tensor = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
+        logits_out: torch.Tensor = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')
+
+        result = self.library.rwkv_eval(
+            self.ctx,
+            ctypes.c_long(token),
+            ctypes.cast(state_in_ptr, P_FLOAT),
+            ctypes.cast(state_out.storage().data_ptr(), P_FLOAT),
+            ctypes.cast(logits_out.storage().data_ptr(), P_FLOAT)
+        )
+
+        assert result, 'Inference failed, see stderr'
+
+        return logits_out, state_out
+
+    def free(self):
+        """
+        Frees all allocated resources.
+        In case of any error, this method will throw an exception.
+        The object must not be used anymore after calling this method.
+ """ + + assert self.valid, 'Already freed' + + self.valid = False + + self.library.rwkv_free(self.ctx)
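Putting the wrapper to work, here is a minimal sketch of greedy (argmax) generation on top of `RWKVModel`, in the direction of the sampling script from the plan; the paths are illustrative and the prompt token IDs are hard-coded, since tokenization is out of scope for this patch:

```python
# Hypothetical usage sketch: greedy generation on top of RWKVModel.
# A real script would encode/decode text with the model's tokenizer.
import torch
import rwkv_cpp

model = rwkv_cpp.RWKVModel('bin/Release/rwkv.dll', 'C:/rwkv.cpp-169M.bin')

# Feed the prompt one token at a time, carrying the state forward.
logits, state = None, None

for token in [1, 2, 3]:  # Illustrative prompt token IDs
    logits, state = model.eval(token, state)

# Generate a few tokens by always picking the most likely next one.
generated = []

for _ in range(10):
    next_token = int(torch.argmax(logits).item())
    generated.append(next_token)
    logits, state = model.eval(next_token, state)

print(f'Generated token IDs: {generated}')

# Free the context once done; the object must not be used afterwards.
model.free()
```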