Feature add cublas support (#65)
* chore: add ggml import in the head of rwkv.h * chore: add ggml import in the head of rwkv.h * feat: add cublas support * feat: update rwkv.cpp * feat: remove unused change * chore: fix linux build issue * chore: sync ggml and offload tensor to gpu * chore: comment out tensors which occurs error on GPU * chore: update comment and readme * chore: update ggml to recent * chore: add more performance test results * chore: add more performance test results * chore: fix problem of reading file more than 2 gb * chore: merge master * chore: remove unused comment * chore: fix for comments * Update README.md * Update rwkv.cpp --------- Co-authored-by: Alex <saharNooby@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									dea929f8ca
								
							
						
					
					
						commit
						241350fde6
					
				| 
						 | 
				
			
			@ -39,6 +39,7 @@ option(RWKV_FMA                    "rwkv: enable FMA"
 | 
			
		|||
# 3rd party libs
 | 
			
		||||
option(RWKV_ACCELERATE             "rwkv: enable Accelerate framework"                    ON)
 | 
			
		||||
option(RWKV_OPENBLAS               "rwkv: use OpenBLAS"                                   OFF)
 | 
			
		||||
option(RWKV_CUBLAS                 "rwkv: use cuBLAS"                                     OFF)
 | 
			
		||||
 | 
			
		||||
#
 | 
			
		||||
# Compile flags
 | 
			
		||||
| 
						 | 
				
			
			@ -97,6 +98,30 @@ if (RWKV_OPENBLAS)
 | 
			
		|||
    endif()
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
if (RWKV_CUBLAS)
 | 
			
		||||
    cmake_minimum_required(VERSION 3.17)
 | 
			
		||||
 | 
			
		||||
    find_package(CUDAToolkit)
 | 
			
		||||
    if (CUDAToolkit_FOUND)
 | 
			
		||||
        message(STATUS "cuBLAS found")
 | 
			
		||||
 | 
			
		||||
        enable_language(CUDA)
 | 
			
		||||
 | 
			
		||||
        set(GGML_CUDA_SOURCES ${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.cu ${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.h)
 | 
			
		||||
 | 
			
		||||
        add_compile_definitions(GGML_USE_CUBLAS)
 | 
			
		||||
 | 
			
		||||
        if (RWKV_STATIC)
 | 
			
		||||
            set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
 | 
			
		||||
        else()
 | 
			
		||||
            set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
 | 
			
		||||
        endif()
 | 
			
		||||
 | 
			
		||||
    else()
 | 
			
		||||
        message(WARNING "cuBLAS not found")
 | 
			
		||||
    endif()
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
if (RWKV_ALL_WARNINGS)
 | 
			
		||||
    if (NOT MSVC)
 | 
			
		||||
        set(c_flags
 | 
			
		||||
| 
						 | 
				
			
			@ -177,11 +202,18 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
 | 
			
		|||
    message(STATUS "x86 detected")
 | 
			
		||||
    if (MSVC)
 | 
			
		||||
        if (RWKV_AVX512)
 | 
			
		||||
            add_compile_options(/arch:AVX512)
 | 
			
		||||
            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
 | 
			
		||||
            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
 | 
			
		||||
            # MSVC has no compile-time flags enabling specific
 | 
			
		||||
            # AVX512 extensions, neither it defines the
 | 
			
		||||
            # macros corresponding to the extensions.
 | 
			
		||||
            # Do it manually.
 | 
			
		||||
        elseif (RWKV_AVX2)
 | 
			
		||||
            add_compile_options(/arch:AVX2)
 | 
			
		||||
            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
 | 
			
		||||
            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
 | 
			
		||||
        elseif (RWKV_AVX)
 | 
			
		||||
            add_compile_options(/arch:AVX)
 | 
			
		||||
            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
 | 
			
		||||
            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
 | 
			
		||||
        endif()
 | 
			
		||||
    else()
 | 
			
		||||
        add_compile_options(-mf16c)
 | 
			
		||||
| 
						 | 
				
			
			@ -212,7 +244,19 @@ if (MSVC)
 | 
			
		|||
    add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
add_subdirectory(ggml)
 | 
			
		||||
add_library(ggml OBJECT
 | 
			
		||||
            ${CMAKE_SOURCE_DIR}/ggml/src/ggml.c
 | 
			
		||||
            ${CMAKE_SOURCE_DIR}/ggml/include/ggml/ggml.h
 | 
			
		||||
            ${GGML_CUDA_SOURCES})
 | 
			
		||||
 | 
			
		||||
target_include_directories(ggml PUBLIC ${CMAKE_SOURCE_DIR}/ggml/include/ggml)
 | 
			
		||||
target_compile_features(ggml PUBLIC c_std_11) # Don't bump
 | 
			
		||||
 | 
			
		||||
if (MSVC)
 | 
			
		||||
    target_link_libraries(ggml PUBLIC ${RWKV_EXTRA_LIBS} Threads::Threads)
 | 
			
		||||
else()
 | 
			
		||||
    target_link_libraries(ggml PUBLIC m ${RWKV_EXTRA_LIBS} Threads::Threads)
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
if (RWKV_BUILD_SHARED_LIBRARY)
 | 
			
		||||
    set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
 | 
			
		||||
| 
						 | 
				
			
			@ -233,5 +277,12 @@ if (RWKV_BUILD_SHARED_LIBRARY)
 | 
			
		|||
    target_compile_definitions(rwkv PRIVATE RWKV_SHARED RWKV_BUILD)
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
if (GGML_CUDA_SOURCES)
 | 
			
		||||
    message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
 | 
			
		||||
    set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF)
 | 
			
		||||
    set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
 | 
			
		||||
    set_property(TARGET rwkv PROPERTY CUDA_ARCHITECTURES OFF)
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
enable_testing()
 | 
			
		||||
add_subdirectory(tests)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										42
									
								
								README.md
								
								
								
								
							
							
						
						
									
										42
									
								
								README.md
								
								
								
								
							| 
						 | 
				
			
			@ -26,6 +26,21 @@ Below table is for reference only. Measurements were made on 4C/8T x86 CPU with
 | 
			
		|||
| `FP16`    | **15.623**        | 117                | 2.82                 |
 | 
			
		||||
| `FP32`    | **15.623**        | 198                | 5.64                 |
 | 
			
		||||
 | 
			
		||||
#### With cuBLAS
 | 
			
		||||
 | 
			
		||||
Measurements were made on 3060Ti 8G + i7 13700K. Latency per token shown.
 | 
			
		||||
 | 
			
		||||
| Model                 | Layers on GPU | Format | 24 Threads  | 8 Threads  | 4 Threads  | 2 Threads  | 1 Threads  |
 | 
			
		||||
|-----------------------|---------------|--------|-------------|------------|------------|------------|------------|
 | 
			
		||||
| `RWKV-4-Pile-169M`    | 12            | `Q4_0` | 20.6 ms     | 8.6 ms     | 6.9 ms     | 6.2 ms     | 7.9 ms     |
 | 
			
		||||
| `RWKV-4-Pile-169M`    | 12            | `Q4_1` | 21.4 ms     | 8.6 ms     | 6.9 ms     | 6.7 ms     | 7.8 ms     |
 | 
			
		||||
| `RWKV-4-Pile-169M`    | 12            | `Q5_1` | 22.2 ms     | 9.0 ms     | 6.9 ms     | 6.7 ms     | 8.1 ms     |
 | 
			
		||||
| `RWKV-4-Raven-7B-v11` | 32            | `Q4_0` | 94.9 ms     | 54.3 ms    | 50.2 ms    | 51.6 ms    | 59.2 ms    |
 | 
			
		||||
| `RWKV-4-Raven-7B-v11` | 32            | `Q4_1` | 94.5 ms     | 54.3 ms    | 49.7 ms    | 51.8 ms    | 59.2 ms    |
 | 
			
		||||
| `RWKV-4-Raven-7B-v11` | 32            | `Q5_1` | 101.6 ms    | 72.3 ms    | 67.2 ms    | 69.3 ms    | 77.0 ms    |
 | 
			
		||||
 | 
			
		||||
Note: since there is only `ggml_mul_mat()` supported with cuBLAS, we still need to assign few CPU resources to execute remaining operations.
 | 
			
		||||
 | 
			
		||||
## How to use
 | 
			
		||||
 | 
			
		||||
### 1. Clone the repo
 | 
			
		||||
| 
						 | 
				
			
			@ -62,6 +77,17 @@ cmake --build . --config Release
 | 
			
		|||
 | 
			
		||||
If everything went OK, `bin\Release\rwkv.dll` file should appear.
 | 
			
		||||
 | 
			
		||||
##### Windows + cuBLAS
 | 
			
		||||
 | 
			
		||||
**Important**: Since there is no cuBLAS static libraries for Windows, after compiling with dynamic libraries following DLLs should be copied from `{CUDA}/bin` into `build/bin/Release`: `cudart64_12.dll`, `cublas64_12.dll`, `cublasLt64_12.dll`.
 | 
			
		||||
 | 
			
		||||
```commandline
 | 
			
		||||
mkdir build
 | 
			
		||||
cd build
 | 
			
		||||
cmake .. -DRWKV_CUBLAS=ON
 | 
			
		||||
cmake --build . --config Release
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
##### Linux / MacOS
 | 
			
		||||
 | 
			
		||||
**Requirements**: CMake (Linux: `sudo apt install cmake`, MacOS: `brew install cmake`, anaconoda: [cmake package](https://anaconda.org/conda-forge/cmake)).
 | 
			
		||||
| 
						 | 
				
			
			@ -75,6 +101,16 @@ cmake --build . --config Release
 | 
			
		|||
 | 
			
		||||
If everything went OK, `librwkv.so` (Linux) or `librwkv.dylib` (MacOS) file should appear in the base repo folder.
 | 
			
		||||
 | 
			
		||||
##### Linux / MacOS + cuBLAS
 | 
			
		||||
 | 
			
		||||
```commandline
 | 
			
		||||
mkdir build
 | 
			
		||||
cd build
 | 
			
		||||
cmake .. -DRWKV_CUBLAS=ON
 | 
			
		||||
cmake --build . --config Release
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
If everything went OK, `librwkv.so` (Linux) or `librwkv.dylib` (MacOS) file should appear in the base repo folder.
 | 
			
		||||
 | 
			
		||||
### 3. Get an RWKV model
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -152,14 +188,16 @@ model_path = r'C:\rwkv.cpp-169M.bin'
 | 
			
		|||
 | 
			
		||||
model = rwkv_cpp_model.RWKVModel(
 | 
			
		||||
    rwkv_cpp_shared_library.load_rwkv_shared_library(),
 | 
			
		||||
    model_path
 | 
			
		||||
    model_path,
 | 
			
		||||
    thread_count=4,    #need to adjust when use cuBLAS
 | 
			
		||||
    gpu_layers_count=5 #only enabled when use cuBLAS
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
logits, state = None, None
 | 
			
		||||
 | 
			
		||||
for token in [1, 2, 3]:
 | 
			
		||||
    logits, state = model.eval(token, state)
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    print(f'Output logits: {logits}')
 | 
			
		||||
 | 
			
		||||
# Don't forget to free the memory after you've done working with the model
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										39
									
								
								rwkv.cpp
								
								
								
								
							
							
						
						
									
										39
									
								
								rwkv.cpp
								
								
								
								
							| 
						 | 
				
			
			@ -1,6 +1,10 @@
 | 
			
		|||
#include "rwkv.h"
 | 
			
		||||
#include "ggml.h"
 | 
			
		||||
 | 
			
		||||
#ifdef GGML_USE_CUBLAS
 | 
			
		||||
#include "ggml/src/ggml-cuda.h"
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <vector>
 | 
			
		||||
#include <thread>
 | 
			
		||||
| 
						 | 
				
			
			@ -274,6 +278,8 @@ struct rwkv_context {
 | 
			
		|||
    struct rwkv_graph graph;
 | 
			
		||||
    enum rwkv_error_flags last_error;
 | 
			
		||||
    bool print_errors;
 | 
			
		||||
    size_t vram_total;
 | 
			
		||||
    int gpu_layers;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) {
 | 
			
		||||
| 
						 | 
				
			
			@ -461,6 +467,10 @@ struct rwkv_ggml_guard {
 | 
			
		|||
};
 | 
			
		||||
 | 
			
		||||
struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
 | 
			
		||||
    return rwkv_init_from_file(file_path, n_threads, 0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads, const uint32_t n_gpu_layers) {
 | 
			
		||||
    global_last_error = RWKV_ERROR_NONE;
 | 
			
		||||
 | 
			
		||||
    FILE * file = fopen(file_path, "rb");
 | 
			
		||||
| 
						 | 
				
			
			@ -481,7 +491,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
 | 
			
		|||
 | 
			
		||||
    std::unique_ptr<rwkv_model> model(new(std::nothrow) struct rwkv_model());
 | 
			
		||||
    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_ALLOC, model.get(), "Failed to allocate model");
 | 
			
		||||
 
 | 
			
		||||
 | 
			
		||||
    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_vocab, "n_vocab"));
 | 
			
		||||
    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_embed, "n_embed"));
 | 
			
		||||
    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_layer, "n_layer"));
 | 
			
		||||
| 
						 | 
				
			
			@ -600,6 +610,29 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
 | 
			
		|||
        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "ffn.receptance.weight", &layer->ffn_receptance));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    int n_gpu = 0;
 | 
			
		||||
    size_t vram_total = 0;
 | 
			
		||||
 | 
			
		||||
#ifdef GGML_USE_CUBLAS
 | 
			
		||||
    {
 | 
			
		||||
        n_gpu = std::min(n_gpu_layers, model->n_layer);
 | 
			
		||||
 | 
			
		||||
        for (int i = 0; i < n_gpu; ++i) {
 | 
			
		||||
            const auto & layer = model->layers[i];
 | 
			
		||||
 | 
			
		||||
            // Use cuBLAS only for heavy matrices; other operations are not supported for GPU at the moment
 | 
			
		||||
            ggml_cuda_transform_tensor(layer.att_key); vram_total += ggml_nbytes(layer.att_key);
 | 
			
		||||
            ggml_cuda_transform_tensor(layer.att_value); vram_total += ggml_nbytes(layer.att_value);
 | 
			
		||||
            ggml_cuda_transform_tensor(layer.att_receptance); vram_total += ggml_nbytes(layer.att_receptance);
 | 
			
		||||
            ggml_cuda_transform_tensor(layer.att_output); vram_total += ggml_nbytes(layer.att_output);
 | 
			
		||||
 | 
			
		||||
            ggml_cuda_transform_tensor(layer.ffn_key); vram_total += ggml_nbytes(layer.ffn_key);
 | 
			
		||||
            ggml_cuda_transform_tensor(layer.ffn_value); vram_total += ggml_nbytes(layer.ffn_value);
 | 
			
		||||
            ggml_cuda_transform_tensor(layer.ffn_receptance); vram_total += ggml_nbytes(layer.ffn_receptance);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(¶meters, "ln_out.weight", &model->ln_out_weight));
 | 
			
		||||
    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(¶meters, "ln_out.bias", &model->ln_out_bias));
 | 
			
		||||
    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(¶meters, "head.weight", &model->head));
 | 
			
		||||
| 
						 | 
				
			
			@ -621,6 +654,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
 | 
			
		|||
    rwkv_ctx->graph = std::move(graph);
 | 
			
		||||
    rwkv_ctx->last_error = RWKV_ERROR_NONE;
 | 
			
		||||
    rwkv_ctx->print_errors = global_print_errors;
 | 
			
		||||
    rwkv_ctx->gpu_layers = n_gpu;
 | 
			
		||||
    rwkv_ctx->vram_total = vram_total;
 | 
			
		||||
    // Don't free ggml context
 | 
			
		||||
    ggml_guard.ctx = NULL;
 | 
			
		||||
    return rwkv_ctx.release();
 | 
			
		||||
| 
						 | 
				
			
			@ -936,4 +971,4 @@ const char * rwkv_get_system_info_string(void) {
 | 
			
		|||
    s += "VSX = "       + std::to_string(ggml_cpu_has_vsx())       + " | ";
 | 
			
		||||
 | 
			
		||||
    return s.c_str();
 | 
			
		||||
}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										3
									
								
								rwkv.h
								
								
								
								
							
							
						
						
									
										3
									
								
								rwkv.h
								
								
								
								
							| 
						 | 
				
			
			@ -83,7 +83,8 @@ extern "C" {
 | 
			
		|||
    // Returns NULL on any error. Error messages would be printed to stderr.
 | 
			
		||||
    // - model_file_path: path to model file in ggml format.
 | 
			
		||||
    // - n_threads: count of threads to use, must be positive.
 | 
			
		||||
    RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
 | 
			
		||||
    // - n_gpu_layer: count of layers need to load to gpu (only works when cuBLAS is on)
 | 
			
		||||
    RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads, const uint32_t n_gpu_layers);
 | 
			
		||||
 | 
			
		||||
    // Evaluates the model for a single token.
 | 
			
		||||
    // Returns false on any error. Error messages would be printed to stderr.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,6 +13,7 @@ import rwkv_cpp_model
 | 
			
		|||
import rwkv_cpp_shared_library
 | 
			
		||||
import json
 | 
			
		||||
from typing import List, Dict, Optional
 | 
			
		||||
import time
 | 
			
		||||
 | 
			
		||||
# ======================================== Script settings ========================================
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -108,10 +109,13 @@ def split_last_end_of_line(tokens):
 | 
			
		|||
    return tokens
 | 
			
		||||
 | 
			
		||||
# =================================================================================================
 | 
			
		||||
 | 
			
		||||
T1 = time.time()
 | 
			
		||||
print(f'Processing {prompt_token_count} prompt tokens, may take a while')
 | 
			
		||||
 | 
			
		||||
process_tokens(split_last_end_of_line(tokenizer.encode(init_prompt).ids))
 | 
			
		||||
T2 = time.time()
 | 
			
		||||
print(f'Process time :{((T2 - T1)*1000)} ms')
 | 
			
		||||
print(f'Process time per token :{(((T2 - T1)*1000)) / prompt_token_count} ms')
 | 
			
		||||
 | 
			
		||||
save_thread_state('chat_init')
 | 
			
		||||
save_thread_state('chat')
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,7 +13,8 @@ class RWKVModel:
 | 
			
		|||
            self,
 | 
			
		||||
            shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
 | 
			
		||||
            model_path: str,
 | 
			
		||||
            thread_count: int = max(1, multiprocessing.cpu_count() // 2)
 | 
			
		||||
            thread_count: int = max(1, multiprocessing.cpu_count() // 2),
 | 
			
		||||
            gpu_layers_count: int = 4,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Loads the model and prepares it for inference.
 | 
			
		||||
| 
						 | 
				
			
			@ -31,10 +32,11 @@ class RWKVModel:
 | 
			
		|||
 | 
			
		||||
        assert os.path.isfile(model_path), f'{model_path} is not a file'
 | 
			
		||||
        assert thread_count > 0, 'Thread count must be positive'
 | 
			
		||||
        assert gpu_layers_count > 0, 'GPU layers count must be positive'
 | 
			
		||||
 | 
			
		||||
        self._library = shared_library
 | 
			
		||||
 | 
			
		||||
        self._ctx = self._library.rwkv_init_from_file(model_path, thread_count)
 | 
			
		||||
        self._ctx = self._library.rwkv_init_from_file(model_path, thread_count, gpu_layers_count)
 | 
			
		||||
 | 
			
		||||
        self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx)
 | 
			
		||||
        self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -37,7 +37,7 @@ class RWKVSharedLibrary:
 | 
			
		|||
 | 
			
		||||
        self.library = ctypes.cdll.LoadLibrary(shared_library_path)
 | 
			
		||||
 | 
			
		||||
        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
 | 
			
		||||
        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32, ctypes.c_uint32]
 | 
			
		||||
        self.library.rwkv_init_from_file.restype = ctypes.c_void_p
 | 
			
		||||
 | 
			
		||||
        self.library.rwkv_eval.argtypes = [
 | 
			
		||||
| 
						 | 
				
			
			@ -67,7 +67,7 @@ class RWKVSharedLibrary:
 | 
			
		|||
        self.library.rwkv_get_system_info_string.argtypes = []
 | 
			
		||||
        self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p
 | 
			
		||||
 | 
			
		||||
    def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
 | 
			
		||||
    def rwkv_init_from_file(self, model_file_path: str, thread_count: int, gpu_layers_count: int) -> RWKVContext:
 | 
			
		||||
        """
 | 
			
		||||
        Loads the model from a file and prepares it for inference.
 | 
			
		||||
        Throws an exception in case of any error. Error messages would be printed to stderr.
 | 
			
		||||
| 
						 | 
				
			
			@ -78,9 +78,13 @@ class RWKVSharedLibrary:
 | 
			
		|||
            Path to model file in ggml format.
 | 
			
		||||
        thread_count : int
 | 
			
		||||
            Count of threads to use, must be positive.
 | 
			
		||||
        gpu_layers_count : int
 | 
			
		||||
            Count of layers to load on gpu, must be positive only enabled with cuBLAS.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))
 | 
			
		||||
        ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'),
 | 
			
		||||
                                               ctypes.c_uint32(thread_count),
 | 
			
		||||
                                               ctypes.c_uint32(gpu_layers_count))
 | 
			
		||||
        assert ptr is not None, 'rwkv_init_from_file failed, check stderr'
 | 
			
		||||
        return RWKVContext(ptr)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -186,6 +190,7 @@ class RWKVSharedLibrary:
 | 
			
		|||
 | 
			
		||||
        return self.library.rwkv_get_system_info_string().decode('utf-8')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_rwkv_shared_library() -> RWKVSharedLibrary:
 | 
			
		||||
    """
 | 
			
		||||
    Attempts to find rwkv.cpp shared library and load it.
 | 
			
		||||
| 
						 | 
				
			
			@ -208,6 +213,10 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
 | 
			
		|||
        f'../bin/Release/{file_name}',
 | 
			
		||||
        # If we are in repo root directory
 | 
			
		||||
        f'bin/Release/{file_name}',
 | 
			
		||||
        # If we compiled in build directory
 | 
			
		||||
        f'build/bin/Release/{file_name}',
 | 
			
		||||
        # If we compiled in build directory
 | 
			
		||||
        f'build/{file_name}',
 | 
			
		||||
        # Search relative to this file
 | 
			
		||||
        str(repo_root_dir / 'bin' / 'Release' / file_name),
 | 
			
		||||
        # Fallback
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,9 @@
 | 
			
		|||
function(rwkv_add_test source)
 | 
			
		||||
    get_filename_component(TEST_TARGET ${source} NAME_WE)
 | 
			
		||||
    add_executable(${TEST_TARGET} ${source})
 | 
			
		||||
    if (GGML_CUDA_SOURCES)
 | 
			
		||||
        set_property(TARGET ${TEST_TARGET} PROPERTY CUDA_ARCHITECTURES OFF)
 | 
			
		||||
    endif()
 | 
			
		||||
    target_link_libraries(${TEST_TARGET} PRIVATE ggml rwkv)
 | 
			
		||||
    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
 | 
			
		||||
endfunction()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -21,12 +21,12 @@
 | 
			
		|||
 | 
			
		||||
#define N_VOCAB 256
 | 
			
		||||
#define N_THREADS 2
 | 
			
		||||
#define N_GPU_LAYERS 1
 | 
			
		||||
 | 
			
		||||
void test_model(const char * model_path, const float * expected_logits, const float max_diff) {
 | 
			
		||||
    fprintf(stderr, "Testing %s\n", model_path);
 | 
			
		||||
 | 
			
		||||
    struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS);
 | 
			
		||||
 | 
			
		||||
    struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS, N_GPU_LAYERS);
 | 
			
		||||
    enum rwkv_error_flags error = rwkv_get_last_error(NULL);
 | 
			
		||||
    ASSERT(error == 0, "Unexpected error %d", error);
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -72,18 +72,27 @@ int main(void) {
 | 
			
		|||
    ASSERT(elements_read == N_VOCAB, "Failed to read expected_logits.bin, read %zd elements", elements_read);
 | 
			
		||||
    fclose(file);
 | 
			
		||||
 | 
			
		||||
    // Somehow when using cuBLAS the calculation of Q4_1 may different from cpu only
 | 
			
		||||
    float expected_difference_sum[14] = {
 | 
			
		||||
        0.000000F,
 | 
			
		||||
        -0.005320F,
 | 
			
		||||
 | 
			
		||||
        -0.160030F,
 | 
			
		||||
#ifdef GGML_USE_CUBLAS
 | 
			
		||||
        -0.412408F,
 | 
			
		||||
#else
 | 
			
		||||
        -0.370606F,
 | 
			
		||||
#endif
 | 
			
		||||
        -0.170404F,
 | 
			
		||||
        0.278034F,
 | 
			
		||||
        0.071216F,
 | 
			
		||||
 | 
			
		||||
        0.154614F,
 | 
			
		||||
#ifdef GGML_USE_CUBLAS
 | 
			
		||||
        -0.405527F,
 | 
			
		||||
#else
 | 
			
		||||
        -0.372169F,
 | 
			
		||||
#endif
 | 
			
		||||
        -0.170043F,
 | 
			
		||||
        0.294953F,
 | 
			
		||||
        0.065571F,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue