Feature: add cuBLAS support (#65)
* chore: add ggml import at the head of rwkv.h
* feat: add cuBLAS support
* feat: update rwkv.cpp
* feat: remove unused change
* chore: fix linux build issue
* chore: sync ggml and offload tensors to gpu
* chore: comment out tensors which cause errors on GPU
* chore: update comment and readme
* chore: update ggml to recent
* chore: add more performance test results
* chore: fix problem of reading files larger than 2 GB
* chore: merge master
* chore: remove unused comment
* chore: fix for comments
* Update README.md
* Update rwkv.cpp

Co-authored-by: Alex <saharNooby@users.noreply.github.com>
Parent: dea929f8ca
Commit: 241350fde6
CMakeLists.txt

@@ -39,6 +39,7 @@ option(RWKV_FMA "rwkv: enable FMA"
 # 3rd party libs
 option(RWKV_ACCELERATE "rwkv: enable Accelerate framework" ON)
 option(RWKV_OPENBLAS "rwkv: use OpenBLAS" OFF)
+option(RWKV_CUBLAS "rwkv: use cuBLAS" OFF)

 #
 # Compile flags
@@ -97,6 +98,30 @@ if (RWKV_OPENBLAS)
     endif()
 endif()

+if (RWKV_CUBLAS)
+    cmake_minimum_required(VERSION 3.17)
+
+    find_package(CUDAToolkit)
+    if (CUDAToolkit_FOUND)
+        message(STATUS "cuBLAS found")
+
+        enable_language(CUDA)
+
+        set(GGML_CUDA_SOURCES ${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.cu ${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.h)
+
+        add_compile_definitions(GGML_USE_CUBLAS)
+
+        if (RWKV_STATIC)
+            set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+        else()
+            set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
+        endif()
+    else()
+        message(WARNING "cuBLAS not found")
+    endif()
+endif()
+
 if (RWKV_ALL_WARNINGS)
     if (NOT MSVC)
         set(c_flags
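With the new option wired in, a typical configure-and-build looks like this (a sketch; `RWKV_STATIC` is the repo's existing static-link option, and static cuBLAS linking is only available where NVIDIA ships static libraries):

```commandline
cmake .. -DRWKV_CUBLAS=ON
cmake --build . --config Release
```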
@@ -177,11 +202,18 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
     message(STATUS "x86 detected")
     if (MSVC)
         if (RWKV_AVX512)
-            add_compile_options(/arch:AVX512)
+            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
+            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
+            # MSVC has no compile-time flags enabling specific
+            # AVX512 extensions, nor does it define the
+            # macros corresponding to the extensions.
+            # Do it manually.
         elseif (RWKV_AVX2)
-            add_compile_options(/arch:AVX2)
+            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
+            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
         elseif (RWKV_AVX)
-            add_compile_options(/arch:AVX)
+            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
+            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
         endif()
     else()
         add_compile_options(-mf16c)
@@ -212,7 +244,19 @@ if (MSVC)
     add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
 endif()

-add_subdirectory(ggml)
+add_library(ggml OBJECT
+            ${CMAKE_SOURCE_DIR}/ggml/src/ggml.c
+            ${CMAKE_SOURCE_DIR}/ggml/include/ggml/ggml.h
+            ${GGML_CUDA_SOURCES})
+
+target_include_directories(ggml PUBLIC ${CMAKE_SOURCE_DIR}/ggml/include/ggml)
+target_compile_features(ggml PUBLIC c_std_11) # Don't bump
+
+if (MSVC)
+    target_link_libraries(ggml PUBLIC ${RWKV_EXTRA_LIBS} Threads::Threads)
+else()
+    target_link_libraries(ggml PUBLIC m ${RWKV_EXTRA_LIBS} Threads::Threads)
+endif()

 if (RWKV_BUILD_SHARED_LIBRARY)
     set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
@@ -233,5 +277,12 @@ if (RWKV_BUILD_SHARED_LIBRARY)
     target_compile_definitions(rwkv PRIVATE RWKV_SHARED RWKV_BUILD)
 endif()

+if (GGML_CUDA_SOURCES)
+    message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
+    set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF)
+    set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
+    set_property(TARGET rwkv PROPERTY CUDA_ARCHITECTURES OFF)
+endif()
+
 enable_testing()
 add_subdirectory(tests)
README.md (40 changed lines)
@@ -26,6 +26,21 @@ Below table is for reference only. Measurements were made on 4C/8T x86 CPU with
 | `FP16` | **15.623** | 117 | 2.82 |
 | `FP32` | **15.623** | 198 | 5.64 |

+#### With cuBLAS
+
+Measurements were made on a 3060 Ti 8 GB + i7-13700K. Latency per token is shown.
+
+| Model                 | Layers on GPU | Format | 24 Threads | 8 Threads | 4 Threads | 2 Threads | 1 Thread |
+|-----------------------|---------------|--------|------------|-----------|-----------|-----------|----------|
+| `RWKV-4-Pile-169M`    | 12            | `Q4_0` | 20.6 ms    | 8.6 ms    | 6.9 ms    | 6.2 ms    | 7.9 ms   |
+| `RWKV-4-Pile-169M`    | 12            | `Q4_1` | 21.4 ms    | 8.6 ms    | 6.9 ms    | 6.7 ms    | 7.8 ms   |
+| `RWKV-4-Pile-169M`    | 12            | `Q5_1` | 22.2 ms    | 9.0 ms    | 6.9 ms    | 6.7 ms    | 8.1 ms   |
+| `RWKV-4-Raven-7B-v11` | 32            | `Q4_0` | 94.9 ms    | 54.3 ms   | 50.2 ms   | 51.6 ms   | 59.2 ms  |
+| `RWKV-4-Raven-7B-v11` | 32            | `Q4_1` | 94.5 ms    | 54.3 ms   | 49.7 ms   | 51.8 ms   | 59.2 ms  |
+| `RWKV-4-Raven-7B-v11` | 32            | `Q5_1` | 101.6 ms   | 72.3 ms   | 67.2 ms   | 69.3 ms   | 77.0 ms  |
+
+Note: since only `ggml_mul_mat()` is supported with cuBLAS, a few CPU threads still need to be assigned to execute the remaining operations.
+
 ## How to use

 ### 1. Clone the repo
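For reproducing latency numbers like these from Python, here is a minimal sketch (the model path and token ids are placeholders; assumes the `rwkv_cpp_model` wrapper from this repo with the new `gpu_layers_count` parameter):

```python
import time

import rwkv_cpp_model
import rwkv_cpp_shared_library

# Placeholder model path; convert and quantize a model first.
model = rwkv_cpp_model.RWKVModel(
    rwkv_cpp_shared_library.load_rwkv_shared_library(),
    'rwkv-169m-q4_0.bin',
    thread_count=4,       # a few CPU threads are still needed (see the note above)
    gpu_layers_count=12,  # layers offloaded via cuBLAS
)

logits, state = None, None
start = time.time()
for token in [0] * 100:  # 100 dummy token ids
    logits, state = model.eval(token, state)
print(f'{(time.time() - start) * 1000 / 100:.1f} ms per token')
```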
@@ -62,6 +77,17 @@ cmake --build . --config Release

 If everything went OK, `bin\Release\rwkv.dll` file should appear.

+##### Windows + cuBLAS
+
+**Important**: since there are no static cuBLAS libraries for Windows, after compiling with dynamic libraries the following DLLs should be copied from `{CUDA}/bin` into `build/bin/Release`: `cudart64_12.dll`, `cublas64_12.dll`, `cublasLt64_12.dll`.
+
+```commandline
+mkdir build
+cd build
+cmake .. -DRWKV_CUBLAS=ON
+cmake --build . --config Release
+```
+
 ##### Linux / MacOS

 **Requirements**: CMake (Linux: `sudo apt install cmake`, MacOS: `brew install cmake`, anaconda: [cmake package](https://anaconda.org/conda-forge/cmake)).
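As an example, on a default CUDA 12.x install the copy step might look like this (assumes the installer set the `CUDA_PATH` environment variable; the `64_12` suffixes differ for other CUDA versions):

```commandline
copy "%CUDA_PATH%\bin\cudart64_12.dll" build\bin\Release
copy "%CUDA_PATH%\bin\cublas64_12.dll" build\bin\Release
copy "%CUDA_PATH%\bin\cublasLt64_12.dll" build\bin\Release
```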
@@ -75,6 +101,16 @@ cmake --build . --config Release

 If everything went OK, `librwkv.so` (Linux) or `librwkv.dylib` (MacOS) file should appear in the base repo folder.

+##### Linux / MacOS + cuBLAS
+
+```commandline
+mkdir build
+cd build
+cmake .. -DRWKV_CUBLAS=ON
+cmake --build . --config Release
+```
+
+If everything went OK, `librwkv.so` (Linux) or `librwkv.dylib` (MacOS) file should appear in the base repo folder.
+
 ### 3. Get an RWKV model
@@ -152,7 +188,9 @@ model_path = r'C:\rwkv.cpp-169M.bin'

 model = rwkv_cpp_model.RWKVModel(
     rwkv_cpp_shared_library.load_rwkv_shared_library(),
-    model_path
+    model_path,
+    thread_count=4,     # needs to be adjusted when using cuBLAS
+    gpu_layers_count=5  # only used when cuBLAS is enabled
 )

 logits, state = None, None
rwkv.cpp (35 changed lines)
@@ -1,6 +1,10 @@
 #include "rwkv.h"
 #include "ggml.h"

+#ifdef GGML_USE_CUBLAS
+#include "ggml/src/ggml-cuda.h"
+#endif
+
 #include <string>
 #include <vector>
 #include <thread>
@@ -274,6 +278,8 @@ struct rwkv_context {
     struct rwkv_graph graph;
     enum rwkv_error_flags last_error;
     bool print_errors;
+    size_t vram_total;
+    int gpu_layers;
 };

 void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) {
@@ -461,6 +467,10 @@ struct rwkv_ggml_guard {
 };

 struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
+    return rwkv_init_from_file(file_path, n_threads, 0);
+}
+
+struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads, const uint32_t n_gpu_layers) {
     global_last_error = RWKV_ERROR_NONE;

     FILE * file = fopen(file_path, "rb");
@@ -600,6 +610,29 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
         RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.receptance.weight", &layer->ffn_receptance));
     }

+    int n_gpu = 0;
+    size_t vram_total = 0;
+
+#ifdef GGML_USE_CUBLAS
+    {
+        n_gpu = std::min(n_gpu_layers, model->n_layer);
+
+        for (int i = 0; i < n_gpu; ++i) {
+            const auto & layer = model->layers[i];
+
+            // Use cuBLAS only for heavy matrices; other operations are not supported for GPU at the moment
+            ggml_cuda_transform_tensor(layer.att_key);        vram_total += ggml_nbytes(layer.att_key);
+            ggml_cuda_transform_tensor(layer.att_value);      vram_total += ggml_nbytes(layer.att_value);
+            ggml_cuda_transform_tensor(layer.att_receptance); vram_total += ggml_nbytes(layer.att_receptance);
+            ggml_cuda_transform_tensor(layer.att_output);     vram_total += ggml_nbytes(layer.att_output);
+
+            ggml_cuda_transform_tensor(layer.ffn_key);        vram_total += ggml_nbytes(layer.ffn_key);
+            ggml_cuda_transform_tensor(layer.ffn_value);      vram_total += ggml_nbytes(layer.ffn_value);
+            ggml_cuda_transform_tensor(layer.ffn_receptance); vram_total += ggml_nbytes(layer.ffn_receptance);
+        }
+    }
+#endif
+
 RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "ln_out.weight", &model->ln_out_weight));
 RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "ln_out.bias", &model->ln_out_bias));
 RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "head.weight", &model->head));
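Seven weight matrices are offloaded per layer, so the VRAM cost of a given `n_gpu_layers` can be estimated with the same per-tensor accounting. A hedged sketch, with placeholder sizes rather than real `ggml_nbytes()` values:

```python
# Rough VRAM estimate for offloading n layers. The per-tensor sizes are
# illustrative placeholders standing in for ggml_nbytes() of the seven
# matrices (att_key/value/receptance/output, ffn_key/value/receptance).
def estimate_vram_bytes(tensor_bytes, n_layers):
    return sum(tensor_bytes) * n_layers

# Example: seven 16 MiB matrices per layer, 32 layers -> 3.5 GiB.
print(estimate_vram_bytes([16 * 1024 * 1024] * 7, 32) / 2 ** 30, 'GiB')
```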
@@ -621,6 +654,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
     rwkv_ctx->graph = std::move(graph);
     rwkv_ctx->last_error = RWKV_ERROR_NONE;
     rwkv_ctx->print_errors = global_print_errors;
+    rwkv_ctx->gpu_layers = n_gpu;
+    rwkv_ctx->vram_total = vram_total;
     // Don't free ggml context
     ggml_guard.ctx = NULL;
     return rwkv_ctx.release();
rwkv.h (3 changed lines)
@@ -83,7 +83,8 @@ extern "C" {
 // Returns NULL on any error. Error messages would be printed to stderr.
 // - model_file_path: path to model file in ggml format.
 // - n_threads: count of threads to use, must be positive.
-RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
+// - n_gpu_layers: count of layers to offload to the GPU (only takes effect when cuBLAS is on)
+RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads, const uint32_t n_gpu_layers);

 // Evaluates the model for a single token.
 // Returns false on any error. Error messages would be printed to stderr.
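Through the repo's ctypes wrapper the extended signature is called like this (a sketch; the model path is a placeholder, and the library must have been built with `-DRWKV_CUBLAS=ON` for the layer count to matter):

```python
import rwkv_cpp_shared_library

library = rwkv_cpp_shared_library.load_rwkv_shared_library()

# 4 threads, 5 GPU layers; mirrors the updated Python binding below.
ctx = library.rwkv_init_from_file('path/to/model.bin', 4, 5)
print(library.rwkv_get_system_info_string())
library.rwkv_free(ctx)
```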
rwkv/chat_with_bot.py

@@ -13,6 +13,7 @@ import rwkv_cpp_model
 import rwkv_cpp_shared_library
 import json
 from typing import List, Dict, Optional
+import time

 # ======================================== Script settings ========================================

@@ -108,10 +109,13 @@ def split_last_end_of_line(tokens):
     return tokens

 # =================================================================================================

+T1 = time.time()
 print(f'Processing {prompt_token_count} prompt tokens, may take a while')

 process_tokens(split_last_end_of_line(tokenizer.encode(init_prompt).ids))
+T2 = time.time()
+print(f'Process time: {(T2 - T1) * 1000} ms')
+print(f'Process time per token: {(T2 - T1) * 1000 / prompt_token_count} ms')

 save_thread_state('chat_init')
 save_thread_state('chat')
rwkv/rwkv_cpp_model.py

@@ -13,7 +13,8 @@ class RWKVModel:
         self,
         shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
         model_path: str,
-        thread_count: int = max(1, multiprocessing.cpu_count() // 2)
+        thread_count: int = max(1, multiprocessing.cpu_count() // 2),
+        gpu_layers_count: int = 4,
     ):
         """
         Loads the model and prepares it for inference.
@@ -31,10 +32,11 @@ class RWKVModel:

         assert os.path.isfile(model_path), f'{model_path} is not a file'
         assert thread_count > 0, 'Thread count must be positive'
+        assert gpu_layers_count > 0, 'GPU layers count must be positive'

         self._library = shared_library

-        self._ctx = self._library.rwkv_init_from_file(model_path, thread_count)
+        self._ctx = self._library.rwkv_init_from_file(model_path, thread_count, gpu_layers_count)

         self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx)
         self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
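Because `rwkv.cpp` clamps the requested count with `std::min(n_gpu_layers, model->n_layer)`, an oversized `gpu_layers_count` simply offloads every layer, and on CPU-only builds the value passes the assert and is ignored. A small usage sketch (the model path is a placeholder):

```python
import rwkv_cpp_model
import rwkv_cpp_shared_library

model = rwkv_cpp_model.RWKVModel(
    rwkv_cpp_shared_library.load_rwkv_shared_library(),
    'model.bin',
    thread_count=4,
    gpu_layers_count=999,  # clamped to the model's layer count when cuBLAS is on
)
logits, state = model.eval(0, None)  # evaluate a single token id
```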
rwkv/rwkv_cpp_shared_library.py

@@ -37,7 +37,7 @@ class RWKVSharedLibrary:

         self.library = ctypes.cdll.LoadLibrary(shared_library_path)

-        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
+        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32, ctypes.c_uint32]
         self.library.rwkv_init_from_file.restype = ctypes.c_void_p

         self.library.rwkv_eval.argtypes = [
@@ -67,7 +67,7 @@ class RWKVSharedLibrary:
         self.library.rwkv_get_system_info_string.argtypes = []
         self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p

-    def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
+    def rwkv_init_from_file(self, model_file_path: str, thread_count: int, gpu_layers_count: int) -> RWKVContext:
         """
         Loads the model from a file and prepares it for inference.
         Throws an exception in case of any error. Error messages would be printed to stderr.
@@ -78,9 +78,13 @@ class RWKVSharedLibrary:
             Path to model file in ggml format.
         thread_count : int
             Count of threads to use, must be positive.
+        gpu_layers_count : int
+            Count of layers to load on the GPU; must be positive. Only takes effect when cuBLAS is enabled.
         """

-        ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))
+        ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'),
+                                               ctypes.c_uint32(thread_count),
+                                               ctypes.c_uint32(gpu_layers_count))
         assert ptr is not None, 'rwkv_init_from_file failed, check stderr'
         return RWKVContext(ptr)
@@ -186,6 +190,7 @@ class RWKVSharedLibrary:

         return self.library.rwkv_get_system_info_string().decode('utf-8')

+
 def load_rwkv_shared_library() -> RWKVSharedLibrary:
     """
     Attempts to find rwkv.cpp shared library and load it.

@@ -208,6 +213,10 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
         f'../bin/Release/{file_name}',
         # If we are in repo root directory
         f'bin/Release/{file_name}',
+        # If we compiled in build directory
+        f'build/bin/Release/{file_name}',
+        # If we compiled in build directory
+        f'build/{file_name}',
         # Search relative to this file
         str(repo_root_dir / 'bin' / 'Release' / file_name),
         # Fallback
tests/CMakeLists.txt

@@ -1,6 +1,9 @@
 function(rwkv_add_test source)
     get_filename_component(TEST_TARGET ${source} NAME_WE)
     add_executable(${TEST_TARGET} ${source})
+    if (GGML_CUDA_SOURCES)
+        set_property(TARGET ${TEST_TARGET} PROPERTY CUDA_ARCHITECTURES OFF)
+    endif()
     target_link_libraries(${TEST_TARGET} PRIVATE ggml rwkv)
     add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
 endfunction()
tests/test_tiny_rwkv.c

@@ -21,12 +21,12 @@

 #define N_VOCAB 256
 #define N_THREADS 2
+#define N_GPU_LAYERS 1

 void test_model(const char * model_path, const float * expected_logits, const float max_diff) {
     fprintf(stderr, "Testing %s\n", model_path);

-    struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS);
+    struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS, N_GPU_LAYERS);

     enum rwkv_error_flags error = rwkv_get_last_error(NULL);
     ASSERT(error == 0, "Unexpected error %d", error);
@@ -72,18 +72,27 @@ int main(void) {
     ASSERT(elements_read == N_VOCAB, "Failed to read expected_logits.bin, read %zd elements", elements_read);
     fclose(file);

+    // When using cuBLAS, the Q4_1 results may differ slightly from the CPU-only calculation
     float expected_difference_sum[14] = {
         0.000000F,
         -0.005320F,
         -0.160030F,
+#ifdef GGML_USE_CUBLAS
+        -0.412408F,
+#else
         -0.370606F,
+#endif
         -0.170404F,
         0.278034F,
         0.071216F,
         0.154614F,
+#ifdef GGML_USE_CUBLAS
+        -0.405527F,
+#else
         -0.372169F,
+#endif
         -0.170043F,
         0.294953F,
         0.065571F,
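The test compares the summed difference between actual and expected logits against per-format reference values, which is why cuBLAS builds need their own constants. A sketch of that check in Python, with made-up numbers:

```python
# Made-up logits standing in for N_VOCAB-sized arrays.
actual   = [0.10, -0.20, 0.30]
expected = [0.12, -0.25, 0.28]

diff_sum = sum(a - e for a, e in zip(actual, expected))

# Hypothetical per-backend reference; GPU matrix multiplications round
# differently, so cuBLAS and CPU runs each get their own expected sum.
reference_sum = 0.05
assert abs(diff_sum - reference_sum) < 0.02
```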