diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7ff772f..38931e5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -39,6 +39,7 @@ option(RWKV_FMA "rwkv: enable FMA"
 # 3rd party libs
 option(RWKV_ACCELERATE "rwkv: enable Accelerate framework" ON)
 option(RWKV_OPENBLAS   "rwkv: use OpenBLAS"                OFF)
+option(RWKV_CUBLAS     "rwkv: use cuBLAS"                  OFF)
 
 #
 # Compile flags
@@ -97,6 +98,30 @@ if (RWKV_OPENBLAS)
     endif()
 endif()
 
+if (RWKV_CUBLAS)
+    cmake_minimum_required(VERSION 3.17)
+
+    find_package(CUDAToolkit)
+    if (CUDAToolkit_FOUND)
+        message(STATUS "cuBLAS found")
+
+        enable_language(CUDA)
+
+        set(GGML_CUDA_SOURCES ${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.cu ${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.h)
+
+        add_compile_definitions(GGML_USE_CUBLAS)
+
+        if (RWKV_STATIC)
+            set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+        else()
+            set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
+        endif()
+
+    else()
+        message(WARNING "cuBLAS not found")
+    endif()
+endif()
+
 if (RWKV_ALL_WARNINGS)
     if (NOT MSVC)
         set(c_flags
@@ -177,11 +202,18 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
     message(STATUS "x86 detected")
     if (MSVC)
         if (RWKV_AVX512)
-            add_compile_options(/arch:AVX512)
+            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
+            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
+            # MSVC has no compile-time flags enabling specific
+            # AVX512 extensions, nor does it define the
+            # macros corresponding to the extensions.
+            # Do it manually.
         elseif (RWKV_AVX2)
-            add_compile_options(/arch:AVX2)
+            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
+            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
         elseif (RWKV_AVX)
-            add_compile_options(/arch:AVX)
+            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
+            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
         endif()
     else()
         add_compile_options(-mf16c)
@@ -212,7 +244,19 @@ if (MSVC)
     add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
 endif()
 
-add_subdirectory(ggml)
+add_library(ggml OBJECT
+            ${CMAKE_SOURCE_DIR}/ggml/src/ggml.c
+            ${CMAKE_SOURCE_DIR}/ggml/include/ggml/ggml.h
+            ${GGML_CUDA_SOURCES})
+
+target_include_directories(ggml PUBLIC ${CMAKE_SOURCE_DIR}/ggml/include/ggml)
+target_compile_features(ggml PUBLIC c_std_11) # Don't bump
+
+if (MSVC)
+    target_link_libraries(ggml PUBLIC ${RWKV_EXTRA_LIBS} Threads::Threads)
+else()
+    target_link_libraries(ggml PUBLIC m ${RWKV_EXTRA_LIBS} Threads::Threads)
+endif()
 
 if (RWKV_BUILD_SHARED_LIBRARY)
     set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
@@ -233,5 +277,12 @@ if (RWKV_BUILD_SHARED_LIBRARY)
     target_compile_definitions(rwkv PRIVATE RWKV_SHARED RWKV_BUILD)
 endif()
 
+if (GGML_CUDA_SOURCES)
+    message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
+    set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF)
+    set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
+    set_property(TARGET rwkv PROPERTY CUDA_ARCHITECTURES OFF)
+endif()
+
 enable_testing()
 add_subdirectory(tests)
diff --git a/README.md b/README.md
index ad60d29..a03c7ae 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,21 @@ Below table is for reference only. Measurements were made on 4C/8T x86 CPU with
 | `FP16` | **15.623** | 117 | 2.82 |
 | `FP32` | **15.623** | 198 | 5.64 |
 
+#### With cuBLAS
+
+Measurements were made on an RTX 3060 Ti 8 GB + i7-13700K. Latency per token is shown.
+
+| Model                 | Layers on GPU | Format | 24 Threads | 8 Threads | 4 Threads | 2 Threads | 1 Thread |
+|-----------------------|---------------|--------|------------|-----------|-----------|-----------|----------|
+| `RWKV-4-Pile-169M`    | 12            | `Q4_0` | 20.6 ms    | 8.6 ms    | 6.9 ms    | 6.2 ms    | 7.9 ms   |
+| `RWKV-4-Pile-169M`    | 12            | `Q4_1` | 21.4 ms    | 8.6 ms    | 6.9 ms    | 6.7 ms    | 7.8 ms   |
+| `RWKV-4-Pile-169M`    | 12            | `Q5_1` | 22.2 ms    | 9.0 ms    | 6.9 ms    | 6.7 ms    | 8.1 ms   |
+| `RWKV-4-Raven-7B-v11` | 32            | `Q4_0` | 94.9 ms    | 54.3 ms   | 50.2 ms   | 51.6 ms   | 59.2 ms  |
+| `RWKV-4-Raven-7B-v11` | 32            | `Q4_1` | 94.5 ms    | 54.3 ms   | 49.7 ms   | 51.8 ms   | 59.2 ms  |
+| `RWKV-4-Raven-7B-v11` | 32            | `Q5_1` | 101.6 ms   | 72.3 ms   | 67.2 ms   | 69.3 ms   | 77.0 ms  |
+
+Note: since only `ggml_mul_mat()` is offloaded to cuBLAS, a few CPU threads still need to be assigned to execute the remaining operations.
+
 ## How to use
 
 ### 1. Clone the repo
@@ -62,6 +77,17 @@ cmake --build . --config Release
 
 If everything went OK, `bin\Release\rwkv.dll` file should appear.
 
+##### Windows + cuBLAS
+
+**Important**: since there are no static cuBLAS libraries for Windows, the build links against dynamic libraries; after compiling, the following DLLs should be copied from `{CUDA}/bin` into `build/bin/Release`: `cudart64_12.dll`, `cublas64_12.dll`, `cublasLt64_12.dll`.
+
+```commandline
+mkdir build
+cd build
+cmake .. -DRWKV_CUBLAS=ON
+cmake --build . --config Release
+```
+
 ##### Linux / MacOS
 
 **Requirements**: CMake (Linux: `sudo apt install cmake`, MacOS: `brew install cmake`, anaconda: [cmake package](https://anaconda.org/conda-forge/cmake)).
@@ -75,6 +101,16 @@ cmake --build . --config Release
 
 If everything went OK, `librwkv.so` (Linux) or `librwkv.dylib` (MacOS) file should appear in the base repo folder.
 
+##### Linux / MacOS + cuBLAS
+
+```commandline
+mkdir build
+cd build
+cmake .. -DRWKV_CUBLAS=ON
+cmake --build . --config Release
+```
+
+If everything went OK, `librwkv.so` (Linux) or `librwkv.dylib` (MacOS) file should appear in the `build` directory.
+
 ### 3. Get an RWKV model
@@ -152,14 +188,16 @@ model_path = r'C:\rwkv.cpp-169M.bin'
 
 model = rwkv_cpp_model.RWKVModel(
     rwkv_cpp_shared_library.load_rwkv_shared_library(),
-    model_path
+    model_path,
+    thread_count=4,      # Needs to be adjusted when using cuBLAS
+    gpu_layers_count=5   # Only takes effect when cuBLAS is enabled
 )
 
 logits, state = None, None
 
 for token in [1, 2, 3]:
     logits, state = model.eval(token, state)
-    
+
 print(f'Output logits: {logits}')
 
 # Don't forget to free the memory after you've done working with the model
diff --git a/rwkv.cpp b/rwkv.cpp
index 4268992..7e76abf 100644
--- a/rwkv.cpp
+++ b/rwkv.cpp
@@ -1,6 +1,10 @@
 #include "rwkv.h"
 #include "ggml.h"
 
+#ifdef GGML_USE_CUBLAS
+#include "ggml/src/ggml-cuda.h"
+#endif
+
 #include
 #include
 #include
@@ -274,6 +278,8 @@ struct rwkv_context {
     struct rwkv_graph graph;
     enum rwkv_error_flags last_error;
     bool print_errors;
+    size_t vram_total;
+    int gpu_layers;
 };
 
 void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) {
@@ -461,6 +467,10 @@ struct rwkv_ggml_guard {
 };
 
 struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
+    return rwkv_init_from_file(file_path, n_threads, 0);
+}
+
+struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads, const uint32_t n_gpu_layers) {
     global_last_error = RWKV_ERROR_NONE;
 
     FILE * file = fopen(file_path, "rb");
@@ -481,7 +491,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
 
     std::unique_ptr<struct rwkv_model> model(new(std::nothrow) struct rwkv_model());
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_ALLOC, model.get(), "Failed to allocate model");
-    
+
     RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_vocab, "n_vocab"));
     RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_embed, "n_embed"));
     RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_layer, "n_layer"));
@@ -600,6 +610,29 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
         RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.receptance.weight", &layer->ffn_receptance));
     }
 
+    int n_gpu = 0;
+    size_t vram_total = 0;
+
+#ifdef GGML_USE_CUBLAS
+    {
+        n_gpu = std::min(n_gpu_layers, model->n_layer);
+
+        for (int i = 0; i < n_gpu; ++i) {
+            const auto & layer = model->layers[i];
+
+            // Use cuBLAS only for heavy matrices; other operations are not supported for GPU at the moment
+            ggml_cuda_transform_tensor(layer.att_key);        vram_total += ggml_nbytes(layer.att_key);
+            ggml_cuda_transform_tensor(layer.att_value);      vram_total += ggml_nbytes(layer.att_value);
+            ggml_cuda_transform_tensor(layer.att_receptance); vram_total += ggml_nbytes(layer.att_receptance);
+            ggml_cuda_transform_tensor(layer.att_output);     vram_total += ggml_nbytes(layer.att_output);
+
+            ggml_cuda_transform_tensor(layer.ffn_key);        vram_total += ggml_nbytes(layer.ffn_key);
+            ggml_cuda_transform_tensor(layer.ffn_value);      vram_total += ggml_nbytes(layer.ffn_value);
+            ggml_cuda_transform_tensor(layer.ffn_receptance); vram_total += ggml_nbytes(layer.ffn_receptance);
+        }
+    }
+#endif
+
     RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "ln_out.weight", &model->ln_out_weight));
     RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "ln_out.bias", &model->ln_out_bias));
     RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "head.weight", &model->head));
@@ -621,6 +654,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
     rwkv_ctx->graph = std::move(graph);
     rwkv_ctx->last_error = RWKV_ERROR_NONE;
     rwkv_ctx->print_errors = global_print_errors;
+    rwkv_ctx->gpu_layers = n_gpu;
+    rwkv_ctx->vram_total = vram_total;
     // Don't free ggml context
     ggml_guard.ctx = NULL;
     return rwkv_ctx.release();
@@ -936,4 +971,4 @@ const char * rwkv_get_system_info_string(void) {
     s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
 
     return s.c_str();
-}
\ No newline at end of file
+}
diff --git a/rwkv.h b/rwkv.h
index 539e655..ff50447 100644
--- a/rwkv.h
+++ b/rwkv.h
@@ -83,7 +83,8 @@ extern "C" {
     // Returns NULL on any error. Error messages would be printed to stderr.
     // - model_file_path: path to model file in ggml format.
    // - n_threads: count of threads to use, must be positive.
-    RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
+    // - n_gpu_layers: count of layers to offload to the GPU (only takes effect when compiled with cuBLAS).
+    RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads, const uint32_t n_gpu_layers);
 
     // Evaluates the model for a single token.
     // Returns false on any error. Error messages would be printed to stderr.
diff --git a/rwkv/chat_with_bot.py b/rwkv/chat_with_bot.py
index 3ebef41..f1c65dc 100644
--- a/rwkv/chat_with_bot.py
+++ b/rwkv/chat_with_bot.py
@@ -13,6 +13,7 @@ import rwkv_cpp_model
 import rwkv_cpp_shared_library
 import json
 from typing import List, Dict, Optional
+import time
 
 # ======================================== Script settings ========================================
 
@@ -108,10 +109,13 @@ def split_last_end_of_line(tokens):
     return tokens
 
 # =================================================================================================
-
+T1 = time.time()
 print(f'Processing {prompt_token_count} prompt tokens, may take a while')
 process_tokens(split_last_end_of_line(tokenizer.encode(init_prompt).ids))
+T2 = time.time()
+print(f'Prompt processing time: {(T2 - T1) * 1000:.3f} ms')
+print(f'Prompt processing time per token: {(T2 - T1) * 1000 / prompt_token_count:.3f} ms')
 
 save_thread_state('chat_init')
 save_thread_state('chat')
diff --git a/rwkv/rwkv_cpp_model.py b/rwkv/rwkv_cpp_model.py
index 247002d..c108c8a 100644
--- a/rwkv/rwkv_cpp_model.py
+++ b/rwkv/rwkv_cpp_model.py
@@ -13,7 +13,8 @@ class RWKVModel:
             self,
             shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
             model_path: str,
-            thread_count: int = max(1, multiprocessing.cpu_count() // 2)
+            thread_count: int = max(1, multiprocessing.cpu_count() // 2),
+            gpu_layers_count: int = 4,
     ):
         """
         Loads the model and prepares it for inference.
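For reference, a minimal sketch of how the extended Python binding is expected to be called after this change; the model path and layer count here are illustrative, and `gpu_layers_count` is simply ignored when the library was built without `RWKV_CUBLAS`:

```python
# Usage sketch for the extended binding; the model path is illustrative.
import rwkv_cpp_model
import rwkv_cpp_shared_library

library = rwkv_cpp_shared_library.load_rwkv_shared_library()

model = rwkv_cpp_model.RWKVModel(
    library,
    'rwkv.cpp-169M.bin',  # illustrative path to a converted ggml model
    thread_count=4,       # CPU threads still run everything except ggml_mul_mat()
    gpu_layers_count=12   # layers to offload; ignored when built without RWKV_CUBLAS
)

logits, state = None, None

for token in [1, 2, 3]:
    logits, state = model.eval(token, state)

print(f'Output logits: {logits}')

model.free()
```

Keeping `thread_count` modest when most layers are offloaded matches the benchmark table above, where 4-8 threads outperformed 24.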
@@ -31,10 +32,11 @@ class RWKVModel:
 
         assert os.path.isfile(model_path), f'{model_path} is not a file'
         assert thread_count > 0, 'Thread count must be positive'
+        assert gpu_layers_count >= 0, 'GPU layer count must be non-negative'
 
         self._library = shared_library
 
-        self._ctx = self._library.rwkv_init_from_file(model_path, thread_count)
+        self._ctx = self._library.rwkv_init_from_file(model_path, thread_count, gpu_layers_count)
 
         self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx)
         self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
diff --git a/rwkv/rwkv_cpp_shared_library.py b/rwkv/rwkv_cpp_shared_library.py
index 2004361..56e4afb 100644
--- a/rwkv/rwkv_cpp_shared_library.py
+++ b/rwkv/rwkv_cpp_shared_library.py
@@ -37,7 +37,7 @@ class RWKVSharedLibrary:
 
         self.library = ctypes.cdll.LoadLibrary(shared_library_path)
 
-        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
+        self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32, ctypes.c_uint32]
         self.library.rwkv_init_from_file.restype = ctypes.c_void_p
 
         self.library.rwkv_eval.argtypes = [
@@ -67,7 +67,7 @@ class RWKVSharedLibrary:
         self.library.rwkv_get_system_info_string.argtypes = []
         self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p
 
-    def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
+    def rwkv_init_from_file(self, model_file_path: str, thread_count: int, gpu_layers_count: int) -> RWKVContext:
         """
         Loads the model from a file and prepares it for inference.
         Throws an exception in case of any error. Error messages would be printed to stderr.
@@ -78,9 +78,13 @@
             Path to model file in ggml format.
         thread_count : int
             Count of threads to use, must be positive.
+        gpu_layers_count : int
+            Count of layers to offload to the GPU; must be non-negative. Only takes effect when the library was compiled with cuBLAS.
         """
 
-        ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))
+        ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'),
+                                               ctypes.c_uint32(thread_count),
+                                               ctypes.c_uint32(gpu_layers_count))
 
         assert ptr is not None, 'rwkv_init_from_file failed, check stderr'
 
         return RWKVContext(ptr)
@@ -186,6 +190,7 @@ class RWKVSharedLibrary:
 
         return self.library.rwkv_get_system_info_string().decode('utf-8')
 
+
 def load_rwkv_shared_library() -> RWKVSharedLibrary:
     """
     Attempts to find rwkv.cpp shared library and load it.
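The search-path list extended below only covers a handful of layouts. If a build lands somewhere else, the library can also be loaded from an explicit path; a small sketch, assuming the `build/` layouts referenced in the README section above (paths are examples):

```python
# Sketch: load the shared library from an explicit path instead of relying on
# load_rwkv_shared_library(); the paths below mirror the new search locations.
import platform

import rwkv_cpp_shared_library

if platform.system() == 'Windows':
    library_path = 'build/bin/Release/rwkv.dll'  # multi-config (MSVC) layout
else:
    library_path = 'build/librwkv.so'            # single-config layout; .dylib on MacOS

library = rwkv_cpp_shared_library.RWKVSharedLibrary(library_path)
print(library.rwkv_get_system_info_string())
```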
@@ -208,6 +213,10 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
         f'../bin/Release/{file_name}',
         # If we are in repo root directory
         f'bin/Release/{file_name}',
+        # If we compiled in the build directory (multi-config generators, e.g. MSVC)
+        f'build/bin/Release/{file_name}',
+        # If we compiled in the build directory (single-config generators)
+        f'build/{file_name}',
         # Search relative to this file
         str(repo_root_dir / 'bin' / 'Release' / file_name),
         # Fallback
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index b9fe3df..2a4e4b3 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -1,6 +1,9 @@
 function(rwkv_add_test source)
     get_filename_component(TEST_TARGET ${source} NAME_WE)
     add_executable(${TEST_TARGET} ${source})
+    if (GGML_CUDA_SOURCES)
+        set_property(TARGET ${TEST_TARGET} PROPERTY CUDA_ARCHITECTURES OFF)
+    endif()
     target_link_libraries(${TEST_TARGET} PRIVATE ggml rwkv)
     add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
 endfunction()
diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c
index 1244877..e8085df 100644
--- a/tests/test_tiny_rwkv.c
+++ b/tests/test_tiny_rwkv.c
@@ -21,12 +21,12 @@
 #define N_VOCAB 256
 #define N_THREADS 2
+#define N_GPU_LAYERS 1
 
 void test_model(const char * model_path, const float * expected_logits, const float max_diff) {
     fprintf(stderr, "Testing %s\n", model_path);
 
-    struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS);
-
+    struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS, N_GPU_LAYERS);
     enum rwkv_error_flags error = rwkv_get_last_error(NULL);
     ASSERT(error == 0, "Unexpected error %d", error);
@@ -72,18 +72,27 @@ int main(void) {
     ASSERT(elements_read == N_VOCAB, "Failed to read expected_logits.bin, read %zd elements", elements_read);
     fclose(file);
 
+    // When built with cuBLAS, the Q4_1 results differ slightly from the CPU-only computation
     float expected_difference_sum[14] = {
         0.000000F,
         -0.005320F,
         -0.160030F,
+#ifdef GGML_USE_CUBLAS
+        -0.412408F,
+#else
         -0.370606F,
+#endif
         -0.170404F,
         0.278034F,
         0.071216F,
         0.154614F,
+#ifdef GGML_USE_CUBLAS
+        -0.405527F,
+#else
         -0.372169F,
+#endif
         -0.170043F,
         0.294953F,
         0.065571F,
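The relaxed Q4_1 expectations above exist because the offloaded `ggml_mul_mat()` calls run through cuBLAS with a different accumulation order than the CPU path, so logits drift slightly rather than matching exactly. A hypothetical cross-check between a CPU-only build and a cuBLAS build of the shared library; the library and model file names are illustrative, and `eval()` returning torch tensors follows the existing binding:

```python
# Hypothetical sanity check: compare logits from a CPU-only build against a
# cuBLAS build of the library. All file paths are illustrative.
import torch

import rwkv_cpp_model
import rwkv_cpp_shared_library

def first_token_logits(library_path: str, gpu_layers_count: int) -> torch.Tensor:
    # Load a specific build of the shared library and evaluate a single token.
    library = rwkv_cpp_shared_library.RWKVSharedLibrary(library_path)
    model = rwkv_cpp_model.RWKVModel(
        library,
        'tiny-rwkv-Q4_1.bin',  # illustrative test model path
        thread_count=2,
        gpu_layers_count=gpu_layers_count,
    )
    logits, _ = model.eval(0, None)
    model.free()
    return logits

cpu_logits = first_token_logits('librwkv-cpu.so', gpu_layers_count=0)
gpu_logits = first_token_logits('librwkv-cublas.so', gpu_layers_count=1)

# Expect small, bounded drift rather than bit-exact equality.
print('max abs diff:', float(torch.max(torch.abs(cpu_logits - gpu_logits))))
```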