diff --git a/rwkv.cpp b/rwkv.cpp index eff1807..f91fa4f 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -476,8 +476,33 @@ struct rwkv_graph { std::unique_ptr cgraph; }; -struct rwkv_context { +struct rwkv_ggml_guard { + struct ggml_context * ctx; + ~rwkv_ggml_guard() { if (ctx) { ggml_free(ctx); } } +}; + +// An instance of an RWKV model loaded into memory: +// Contains all the model weights. +// Shared by one or more contexts. +struct rwkv_instance { struct rwkv_model model; + struct rwkv_ggml_guard ctx; + std::unique_ptr scratch; + + // TODO come up with a better solution to estimate "work tensor" size. + // The ggml_cgraph allocates a "work tensor" the first time it is used. + // Currently, the height of blocks.0.ffn.key.weight is the bottleneck in our implementation of RWKV. + // Since it is the largest dimension used in any matrix multiply, it is the size used for the "work tensor". + // However, if ggml changes its implementation, or rwkv.cpp changes its own implementation, at any point, + // this may become outdated. We need to find a way not to hardcode a specific tensor, but to calculate accurately. + // This may come out of a ggml issue: https://github.com/ggerganov/ggml/issues/214 + size_t ffn_key_size; +}; + +// RWKV context for a specific instance. +// Contains the computation graph and is used for inference. +struct rwkv_context { + std::shared_ptr instance; struct ggml_context * ctx; std::unique_ptr scratch; struct rwkv_graph graph; @@ -860,11 +885,6 @@ struct rwkv_file_guard { ~rwkv_file_guard() { if (file) { fclose(file); } } }; -struct rwkv_ggml_guard { - struct ggml_context * ctx; - ~rwkv_ggml_guard() { if (ctx) { ggml_free(ctx); } } -}; - void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) { bool * ptr = ctx ? 
&ctx->print_errors : &global_print_errors; *ptr = print_errors; @@ -881,14 +901,12 @@ enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) { return value; } -struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) { - global_last_error = RWKV_ERROR_NONE; - +bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance) { FILE * file = fopen(file_path, "rb"); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file, "Failed to open file %s", file_path); rwkv_file_guard file_guard { file }; - // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to the get file length. + // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to get the file length. struct stat file_stat; RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(file), &file_stat) == 0, "Failed to stat file %s", file_path); @@ -897,9 +915,10 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t size_t tensors_start = ftell(file); struct rwkv_ctx_size ctx_size; - size_t ffn_key = 0; std::string name; + instance.ffn_key_size = 0; + while ((size_t) ftell(file) < (size_t) file_stat.st_size) { struct rwkv_tensor_header header; RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(file, header), "Invalid tensor header"); @@ -907,18 +926,12 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file, rwkv_tensor_size(header), SEEK_CUR) == 0, "Failed to read tensor data"); rwkv_ctx_size_add_tensor(ctx_size, 1, 0, header); - if (ffn_key == 0 && name == "blocks.0.ffn.key.weight") { - ffn_key = header.height; + if (instance.ffn_key_size == 0 && name == "blocks.0.ffn.key.weight") { + instance.ffn_key_size = header.height; } } - 
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, ffn_key, "Model is missing parameter blocks.0.ffn.key.weight"); - - rwkv_ctx_size_add(ctx_size, 1, rwkv_single_graph_size(header.n_vocab, header.n_embed, header.n_layer, ffn_key)); - // And finally to end it all off: the graph work tensor - enum ggml_type mul_mat_type = ggml_is_quantized(rwkv_type_to_ggml[header.data_type]) ? GGML_TYPE_Q8_1 : rwkv_type_to_ggml[header.data_type]; - rwkv_ctx_size_add_objects(ctx_size, 1, sizeof(struct ggml_tensor) + rwkv_tensor_size(GGML_TYPE_I8, rwkv_tensor_size(mul_mat_type, ffn_key) * n_threads + 64 * (n_threads - 1))); - + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, instance.ffn_key_size, "Model is missing parameter blocks.0.ffn.key.weight"); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file, tensors_start, SEEK_SET) == 0, "Failed to seek in file"); std::unique_ptr scratch(new(std::nothrow) uint8_t [ctx_size.scratch_size]); @@ -957,16 +970,46 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == header.n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == header.n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]); + // Don't free ggml context now + ggml_guard.ctx = NULL; + // Attach ggml context to instance + instance.ctx.ctx = ctx; + instance.model = std::move(model); + instance.scratch = std::move(scratch); + return true; +} + +struct rwkv_context * rwkv_new_context_impl(std::shared_ptr instance, const uint32_t n_threads) { + global_last_error = RWKV_ERROR_NONE; + + struct rwkv_file_header & header = instance->model.header; + + rwkv_ctx_size ctx_size; + rwkv_ctx_size_add(ctx_size, 1, rwkv_single_graph_size(header.n_vocab, header.n_embed, header.n_layer, 
instance->ffn_key_size)); + // And finally to end it all off: the graph work tensor + enum ggml_type mul_mat_type = ggml_is_quantized(rwkv_type_to_ggml[header.data_type]) ? GGML_TYPE_Q8_1 : rwkv_type_to_ggml[header.data_type]; + rwkv_ctx_size_add(ctx_size, 1, rwkv_tensor_size(GGML_TYPE_I8, rwkv_tensor_size(mul_mat_type, instance->ffn_key_size) * n_threads + 64 * (n_threads - 1))); + + std::unique_ptr scratch(new(std::nothrow) uint8_t [ctx_size.scratch_size]); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, scratch.get(), "Failed to allocate graph scratch space (%d)", ctx_size.scratch_size); + + struct ggml_context * ctx = ggml_init({ ctx_size.objects_size + ctx_size.objects_count * GGML_OBJECT_SIZE, NULL, false}); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, ctx, "Failed to create GGML context"); + rwkv_ggml_guard ggml_guard { ctx }; + + ggml_set_scratch(ctx, { 0, ctx_size.scratch_size, scratch.get() }); + // Build graph struct rwkv_graph graph; - RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_single_graph(ctx, model, n_threads, graph)); + RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_single_graph(ctx, instance->model, n_threads, graph)); std::unique_ptr rwkv_ctx(new(std::nothrow) struct rwkv_context()); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx.get(), "Failed to allocate context"); // Don't free ggml context ggml_guard.ctx = NULL; - rwkv_ctx->model = std::move(model); + + rwkv_ctx->instance = std::move(instance); rwkv_ctx->ctx = ctx; rwkv_ctx->scratch = std::move(scratch); rwkv_ctx->graph = std::move(graph); @@ -975,21 +1018,39 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t rwkv_ctx->gpu_layers = 0; rwkv_ctx->vram_total = 0; - ggml_set_scratch(ctx, { 0, 0, NULL }); - return rwkv_ctx.release(); } +struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) { + global_last_error = RWKV_ERROR_NONE; + + std::shared_ptr instance(new(std::nothrow) struct rwkv_instance); 
+ RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, instance.get(), "Failed to allocate instance"); + RWKV_ENSURE_OR_NULL(rwkv_instance_from_file(file_path, *instance.get())); + + return rwkv_new_context_impl(instance, n_threads); +} + +struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads) { + struct rwkv_context * clone = rwkv_new_context_impl(ctx->instance, n_threads); + + if (clone) { + clone->print_errors = ctx->print_errors; + } + + return clone; +} + bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers) { #ifdef GGML_USE_CUBLAS { - size_t n_gpu = std::min(n_gpu_layers, ctx->model.header.n_layer); + size_t n_gpu = std::min(n_gpu_layers, ctx->instance->model.header.n_layer); size_t gpu_layers = ((struct rwkv_context *) ctx)->gpu_layers; size_t vram_total = ((struct rwkv_context *) ctx)->vram_total; for (size_t i = 0; i < n_gpu; i++) { - const struct rwkv_layer & layer = ctx->model.layers[i]; + const struct rwkv_layer & layer = ctx->instance->model.layers[i]; // Use cuBLAS only for heavy matrices; other operations are not supported for GPU at the moment ggml_cuda_transform_tensor(layer.att_key); vram_total += ggml_nbytes(layer.att_key); @@ -1012,7 +1073,7 @@ bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_g bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) { ((struct rwkv_context *) ctx)->last_error = RWKV_ERROR_NONE; - const struct rwkv_file_header & header = ctx->model.header; + const struct rwkv_file_header & header = ctx->instance->model.header; RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < header.n_vocab, "Token is out of range 0..%d", header.n_vocab - 1); const struct rwkv_graph & graph = ctx->graph; @@ -1055,11 +1116,11 @@ bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const floa } uint32_t 
rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) { - return ctx->model.header.n_layer * 5 * ctx->model.header.n_embed; + return ctx->instance->model.header.n_layer * 5 * ctx->instance->model.header.n_embed; } uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) { - return ctx->model.header.n_vocab; + return ctx->instance->model.header.n_vocab; } void rwkv_free(struct rwkv_context * ctx) { diff --git a/rwkv.h b/rwkv.h index 5e8c756..b62b40c 100644 --- a/rwkv.h +++ b/rwkv.h @@ -61,6 +61,10 @@ extern "C" { RWKV_ERROR_PARAM_MISSING = 14 }; + // RWKV context that can be used for inference. + // All functions that operate on rwkv_context are thread-safe. + // rwkv_context can be sent to different threads between calls to rwkv_eval. + // There is no requirement for rwkv_context to be freed on the creating thread. struct rwkv_context; // Sets whether errors are automatically printed to stderr. @@ -85,11 +89,20 @@ extern "C" { // - n_threads: count of threads to use, must be positive. RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads); + // Creates a new context from an existing one. + // This can allow you to run multiple rwkv_eval's in parallel, without having to load a single model multiple times. + // Each rwkv_context can have one eval running at a time. + // Every rwkv_context must be freed using rwkv_free. + // - ctx: context to be cloned. + // - n_threads: count of threads to use, must be positive. + RWKV_API struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads); + // Offloads specified layers of context onto GPU using cuBLAS, if it is enabled. // If rwkv.cpp was compiled without cuBLAS support, this function is a no-op. RWKV_API bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers); // Evaluates the model for a single token. + // Not thread-safe. 
For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread. // Returns false on any error. Error messages would be printed to stderr. // - token: next token index, in range 0 <= token < n_vocab. // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass. @@ -104,6 +117,7 @@ extern "C" { RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx); // Frees all allocated memory and the context. + // Does not need to be the same thread that created the rwkv_context. RWKV_API void rwkv_free(struct rwkv_context * ctx); // Quantizes FP32 or FP16 model to one of quantized formats. diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 2a4e4b3..d176f7b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -14,3 +14,4 @@ file(COPY expected_logits.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) rwkv_add_test(test_ggml_basics.c) rwkv_add_test(test_tiny_rwkv.c) +rwkv_add_test(test_context_cloning.c) diff --git a/tests/test_context_cloning.c b/tests/test_context_cloning.c new file mode 100644 index 0000000..9585f16 --- /dev/null +++ b/tests/test_context_cloning.c @@ -0,0 +1,64 @@ +#include <rwkv.h> + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +int main() { + struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); + + if (!ctx) { + enum rwkv_error_flags error = rwkv_get_last_error(NULL); + fprintf(stderr, "Unexpected error 0x%.8X\n", error); + return EXIT_FAILURE; + } + + float * state = calloc(rwkv_get_state_buffer_element_count(ctx), sizeof(float)); + float * logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float)); + + if (!state || !logits) { + fprintf(stderr, "Failed to allocate state/logits\n"); + return EXIT_FAILURE; + } + + // 0xd1 or 209 is space (0x20 or \u0120 in tokenizer) + const unsigned char * prompt = "hello\xd1world"; + + rwkv_eval(ctx, prompt[0], NULL, state, logits); + + for (const unsigned char * token = prompt + 1; 
*token != 0; token++) { + rwkv_eval(ctx, *token, state, state, logits); + } + + float * expected_logits = logits; + logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float)); + + if (!logits) { + fprintf(stderr, "Failed to allocate state/logits\n"); + return EXIT_FAILURE; + } + + struct rwkv_context * ctx2 = rwkv_clone_context(ctx, 2); + + if (!ctx2) { + fprintf(stderr, "Failed to clone context\n"); + return EXIT_FAILURE; + } + + rwkv_eval(ctx2, prompt[0], NULL, state, logits); + + for (const unsigned char * token = prompt + 1; *token != 0; token++) { + rwkv_eval(ctx2, *token, state, state, logits); + } + + if (memcmp(expected_logits, logits, rwkv_get_logits_buffer_element_count(ctx) * sizeof(float))) { + fprintf(stderr, "results not identical :(\n"); + return EXIT_FAILURE; + } else { + fprintf(stdout, "Results identical, success!\n"); + } + + rwkv_free(ctx); + rwkv_free(ctx2); + + free(expected_logits); + free(logits); + free(state); + + return EXIT_SUCCESS; +} \ No newline at end of file