Feature add cublas support (#65)

* chore: add ggml import in the head of rwkv.h

* chore: add ggml import in the head of rwkv.h

* feat: add cublas support

* feat: update rwkv.cpp

* feat: remove unused change

* chore: fix linux build issue

* chore: sync ggml and offload tensor to gpu

* chore: comment out tensors which occurs error on GPU

* chore: update comment and readme

* chore: update ggml to recent

* chore: add more performance test results

* chore: add more performance test results

* chore: fix problem of reading file more than 2 gb

* chore: merge master

* chore: remove unused comment

* chore: fix for comments

* Update README.md

* Update rwkv.cpp

---------

Co-authored-by: Alex <saharNooby@users.noreply.github.com>
This commit is contained in:
YorkZero 2023-05-29 21:10:19 +09:00 committed by GitHub
parent dea929f8ca
commit 241350fde6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 169 additions and 17 deletions

View File

@ -39,6 +39,7 @@ option(RWKV_FMA "rwkv: enable FMA"
# 3rd party libs
option(RWKV_ACCELERATE "rwkv: enable Accelerate framework" ON)
option(RWKV_OPENBLAS "rwkv: use OpenBLAS" OFF)
option(RWKV_CUBLAS "rwkv: use cuBLAS" OFF)
#
# Compile flags
@ -97,6 +98,30 @@ if (RWKV_OPENBLAS)
endif()
endif()
if (RWKV_CUBLAS)
cmake_minimum_required(VERSION 3.17)
find_package(CUDAToolkit)
if (CUDAToolkit_FOUND)
message(STATUS "cuBLAS found")
enable_language(CUDA)
set(GGML_CUDA_SOURCES ${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.cu ${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.h)
add_compile_definitions(GGML_USE_CUBLAS)
if (RWKV_STATIC)
set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()
set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()
else()
message(WARNING "cuBLAS not found")
endif()
endif()
if (RWKV_ALL_WARNINGS)
if (NOT MSVC)
set(c_flags
@ -177,11 +202,18 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
message(STATUS "x86 detected")
if (MSVC)
if (RWKV_AVX512)
add_compile_options(/arch:AVX512)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
# MSVC has no compile-time flags enabling specific
# AVX512 extensions, neither it defines the
# macros corresponding to the extensions.
# Do it manually.
elseif (RWKV_AVX2)
add_compile_options(/arch:AVX2)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
elseif (RWKV_AVX)
add_compile_options(/arch:AVX)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
endif()
else()
add_compile_options(-mf16c)
@ -212,7 +244,19 @@ if (MSVC)
add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
endif()
add_subdirectory(ggml)
add_library(ggml OBJECT
${CMAKE_SOURCE_DIR}/ggml/src/ggml.c
${CMAKE_SOURCE_DIR}/ggml/include/ggml/ggml.h
${GGML_CUDA_SOURCES})
target_include_directories(ggml PUBLIC ${CMAKE_SOURCE_DIR}/ggml/include/ggml)
target_compile_features(ggml PUBLIC c_std_11) # Don't bump
if (MSVC)
target_link_libraries(ggml PUBLIC ${RWKV_EXTRA_LIBS} Threads::Threads)
else()
target_link_libraries(ggml PUBLIC m ${RWKV_EXTRA_LIBS} Threads::Threads)
endif()
if (RWKV_BUILD_SHARED_LIBRARY)
set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
@ -233,5 +277,12 @@ if (RWKV_BUILD_SHARED_LIBRARY)
target_compile_definitions(rwkv PRIVATE RWKV_SHARED RWKV_BUILD)
endif()
if (GGML_CUDA_SOURCES)
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF)
set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
set_property(TARGET rwkv PROPERTY CUDA_ARCHITECTURES OFF)
endif()
enable_testing()
add_subdirectory(tests)

View File

@ -26,6 +26,21 @@ Below table is for reference only. Measurements were made on 4C/8T x86 CPU with
| `FP16` | **15.623** | 117 | 2.82 |
| `FP32` | **15.623** | 198 | 5.64 |
#### With cuBLAS
Measurements were made on 3060Ti 8G + i7 13700K. Latency per token shown.
| Model | Layers on GPU | Format | 24 Threads | 8 Threads | 4 Threads | 2 Threads | 1 Threads |
|-----------------------|---------------|--------|-------------|------------|------------|------------|------------|
| `RWKV-4-Pile-169M` | 12 | `Q4_0` | 20.6 ms | 8.6 ms | 6.9 ms | 6.2 ms | 7.9 ms |
| `RWKV-4-Pile-169M` | 12 | `Q4_1` | 21.4 ms | 8.6 ms | 6.9 ms | 6.7 ms | 7.8 ms |
| `RWKV-4-Pile-169M` | 12 | `Q5_1` | 22.2 ms | 9.0 ms | 6.9 ms | 6.7 ms | 8.1 ms |
| `RWKV-4-Raven-7B-v11` | 32 | `Q4_0` | 94.9 ms | 54.3 ms | 50.2 ms | 51.6 ms | 59.2 ms |
| `RWKV-4-Raven-7B-v11` | 32 | `Q4_1` | 94.5 ms | 54.3 ms | 49.7 ms | 51.8 ms | 59.2 ms |
| `RWKV-4-Raven-7B-v11` | 32 | `Q5_1` | 101.6 ms | 72.3 ms | 67.2 ms | 69.3 ms | 77.0 ms |
Note: since there is only `ggml_mul_mat()` supported with cuBLAS, we still need to assign few CPU resources to execute remaining operations.
## How to use
### 1. Clone the repo
@ -62,6 +77,17 @@ cmake --build . --config Release
If everything went OK, `bin\Release\rwkv.dll` file should appear.
##### Windows + cuBLAS
**Important**: Since there is no cuBLAS static libraries for Windows, after compiling with dynamic libraries following DLLs should be copied from `{CUDA}/bin` into `build/bin/Release`: `cudart64_12.dll`, `cublas64_12.dll`, `cublasLt64_12.dll`.
```commandline
mkdir build
cd build
cmake .. -DRWKV_CUBLAS=ON
cmake --build . --config Release
```
##### Linux / MacOS
**Requirements**: CMake (Linux: `sudo apt install cmake`, MacOS: `brew install cmake`, anaconoda: [cmake package](https://anaconda.org/conda-forge/cmake)).
@ -75,6 +101,16 @@ cmake --build . --config Release
If everything went OK, `librwkv.so` (Linux) or `librwkv.dylib` (MacOS) file should appear in the base repo folder.
##### Linux / MacOS + cuBLAS
```commandline
mkdir build
cd build
cmake .. -DRWKV_CUBLAS=ON
cmake --build . --config Release
```
If everything went OK, `librwkv.so` (Linux) or `librwkv.dylib` (MacOS) file should appear in the base repo folder.
### 3. Get an RWKV model
@ -152,14 +188,16 @@ model_path = r'C:\rwkv.cpp-169M.bin'
model = rwkv_cpp_model.RWKVModel(
rwkv_cpp_shared_library.load_rwkv_shared_library(),
model_path
model_path,
thread_count=4, #need to adjust when use cuBLAS
gpu_layers_count=5 #only enabled when use cuBLAS
)
logits, state = None, None
for token in [1, 2, 3]:
logits, state = model.eval(token, state)
print(f'Output logits: {logits}')
# Don't forget to free the memory after you've done working with the model

View File

@ -1,6 +1,10 @@
#include "rwkv.h"
#include "ggml.h"
#ifdef GGML_USE_CUBLAS
#include "ggml/src/ggml-cuda.h"
#endif
#include <string>
#include <vector>
#include <thread>
@ -274,6 +278,8 @@ struct rwkv_context {
struct rwkv_graph graph;
enum rwkv_error_flags last_error;
bool print_errors;
size_t vram_total;
int gpu_layers;
};
void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) {
@ -461,6 +467,10 @@ struct rwkv_ggml_guard {
};
struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
return rwkv_init_from_file(file_path, n_threads, 0);
}
struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads, const uint32_t n_gpu_layers) {
global_last_error = RWKV_ERROR_NONE;
FILE * file = fopen(file_path, "rb");
@ -481,7 +491,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
std::unique_ptr<rwkv_model> model(new(std::nothrow) struct rwkv_model());
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_ALLOC, model.get(), "Failed to allocate model");
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_vocab, "n_vocab"));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_embed, "n_embed"));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_layer, "n_layer"));
@ -600,6 +610,29 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.receptance.weight", &layer->ffn_receptance));
}
int n_gpu = 0;
size_t vram_total = 0;
#ifdef GGML_USE_CUBLAS
{
n_gpu = std::min(n_gpu_layers, model->n_layer);
for (int i = 0; i < n_gpu; ++i) {
const auto & layer = model->layers[i];
// Use cuBLAS only for heavy matrices; other operations are not supported for GPU at the moment
ggml_cuda_transform_tensor(layer.att_key); vram_total += ggml_nbytes(layer.att_key);
ggml_cuda_transform_tensor(layer.att_value); vram_total += ggml_nbytes(layer.att_value);
ggml_cuda_transform_tensor(layer.att_receptance); vram_total += ggml_nbytes(layer.att_receptance);
ggml_cuda_transform_tensor(layer.att_output); vram_total += ggml_nbytes(layer.att_output);
ggml_cuda_transform_tensor(layer.ffn_key); vram_total += ggml_nbytes(layer.ffn_key);
ggml_cuda_transform_tensor(layer.ffn_value); vram_total += ggml_nbytes(layer.ffn_value);
ggml_cuda_transform_tensor(layer.ffn_receptance); vram_total += ggml_nbytes(layer.ffn_receptance);
}
}
#endif
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "ln_out.weight", &model->ln_out_weight));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "ln_out.bias", &model->ln_out_bias));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "head.weight", &model->head));
@ -621,6 +654,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
rwkv_ctx->graph = std::move(graph);
rwkv_ctx->last_error = RWKV_ERROR_NONE;
rwkv_ctx->print_errors = global_print_errors;
rwkv_ctx->gpu_layers = n_gpu;
rwkv_ctx->vram_total = vram_total;
// Don't free ggml context
ggml_guard.ctx = NULL;
return rwkv_ctx.release();
@ -936,4 +971,4 @@ const char * rwkv_get_system_info_string(void) {
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
return s.c_str();
}
}

3
rwkv.h
View File

@ -83,7 +83,8 @@ extern "C" {
// Returns NULL on any error. Error messages would be printed to stderr.
// - model_file_path: path to model file in ggml format.
// - n_threads: count of threads to use, must be positive.
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
// - n_gpu_layer: count of layers need to load to gpu (only works when cuBLAS is on)
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads, const uint32_t n_gpu_layers);
// Evaluates the model for a single token.
// Returns false on any error. Error messages would be printed to stderr.

View File

@ -13,6 +13,7 @@ import rwkv_cpp_model
import rwkv_cpp_shared_library
import json
from typing import List, Dict, Optional
import time
# ======================================== Script settings ========================================
@ -108,10 +109,13 @@ def split_last_end_of_line(tokens):
return tokens
# =================================================================================================
T1 = time.time()
print(f'Processing {prompt_token_count} prompt tokens, may take a while')
process_tokens(split_last_end_of_line(tokenizer.encode(init_prompt).ids))
T2 = time.time()
print(f'Process time :{((T2 - T1)*1000)} ms')
print(f'Process time per token :{(((T2 - T1)*1000)) / prompt_token_count} ms')
save_thread_state('chat_init')
save_thread_state('chat')

View File

@ -13,7 +13,8 @@ class RWKVModel:
self,
shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
model_path: str,
thread_count: int = max(1, multiprocessing.cpu_count() // 2)
thread_count: int = max(1, multiprocessing.cpu_count() // 2),
gpu_layers_count: int = 4,
):
"""
Loads the model and prepares it for inference.
@ -31,10 +32,11 @@ class RWKVModel:
assert os.path.isfile(model_path), f'{model_path} is not a file'
assert thread_count > 0, 'Thread count must be positive'
assert gpu_layers_count > 0, 'GPU layers count must be positive'
self._library = shared_library
self._ctx = self._library.rwkv_init_from_file(model_path, thread_count)
self._ctx = self._library.rwkv_init_from_file(model_path, thread_count, gpu_layers_count)
self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx)
self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx)

View File

@ -37,7 +37,7 @@ class RWKVSharedLibrary:
self.library = ctypes.cdll.LoadLibrary(shared_library_path)
self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32, ctypes.c_uint32]
self.library.rwkv_init_from_file.restype = ctypes.c_void_p
self.library.rwkv_eval.argtypes = [
@ -67,7 +67,7 @@ class RWKVSharedLibrary:
self.library.rwkv_get_system_info_string.argtypes = []
self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p
def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
def rwkv_init_from_file(self, model_file_path: str, thread_count: int, gpu_layers_count: int) -> RWKVContext:
"""
Loads the model from a file and prepares it for inference.
Throws an exception in case of any error. Error messages would be printed to stderr.
@ -78,9 +78,13 @@ class RWKVSharedLibrary:
Path to model file in ggml format.
thread_count : int
Count of threads to use, must be positive.
gpu_layers_count : int
Count of layers to load on gpu, must be positive only enabled with cuBLAS.
"""
ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))
ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'),
ctypes.c_uint32(thread_count),
ctypes.c_uint32(gpu_layers_count))
assert ptr is not None, 'rwkv_init_from_file failed, check stderr'
return RWKVContext(ptr)
@ -186,6 +190,7 @@ class RWKVSharedLibrary:
return self.library.rwkv_get_system_info_string().decode('utf-8')
def load_rwkv_shared_library() -> RWKVSharedLibrary:
"""
Attempts to find rwkv.cpp shared library and load it.
@ -208,6 +213,10 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
f'../bin/Release/{file_name}',
# If we are in repo root directory
f'bin/Release/{file_name}',
# If we compiled in build directory
f'build/bin/Release/{file_name}',
# If we compiled in build directory
f'build/{file_name}',
# Search relative to this file
str(repo_root_dir / 'bin' / 'Release' / file_name),
# Fallback

View File

@ -1,6 +1,9 @@
function(rwkv_add_test source)
get_filename_component(TEST_TARGET ${source} NAME_WE)
add_executable(${TEST_TARGET} ${source})
if (GGML_CUDA_SOURCES)
set_property(TARGET ${TEST_TARGET} PROPERTY CUDA_ARCHITECTURES OFF)
endif()
target_link_libraries(${TEST_TARGET} PRIVATE ggml rwkv)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
endfunction()

View File

@ -21,12 +21,12 @@
#define N_VOCAB 256
#define N_THREADS 2
#define N_GPU_LAYERS 1
void test_model(const char * model_path, const float * expected_logits, const float max_diff) {
fprintf(stderr, "Testing %s\n", model_path);
struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS);
struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS, N_GPU_LAYERS);
enum rwkv_error_flags error = rwkv_get_last_error(NULL);
ASSERT(error == 0, "Unexpected error %d", error);
@ -72,18 +72,27 @@ int main(void) {
ASSERT(elements_read == N_VOCAB, "Failed to read expected_logits.bin, read %zd elements", elements_read);
fclose(file);
// Somehow when using cuBLAS the calculation of Q4_1 may different from cpu only
float expected_difference_sum[14] = {
0.000000F,
-0.005320F,
-0.160030F,
#ifdef GGML_USE_CUBLAS
-0.412408F,
#else
-0.370606F,
#endif
-0.170404F,
0.278034F,
0.071216F,
0.154614F,
#ifdef GGML_USE_CUBLAS
-0.405527F,
#else
-0.372169F,
#endif
-0.170043F,
0.294953F,
0.065571F,