Allow creating multiple contexts per model (#83)

* Allow creating multiple contexts per model

This allows for parallel inference, and I am preparing to support
sequence mode using a similar method (see the usage sketch below)

* Fix cuBLAS

* Update rwkv.h

Co-authored-by: Alex <saharNooby@users.noreply.github.com>

* Update rwkv.cpp

Co-authored-by: Alex <saharNooby@users.noreply.github.com>

* Inherit print_errors from parent ctx when cloning

* Add context cloning test

* Free

* Free ggml context when last rwkv_context is freed

* Free before exit

* int main

* Add explanation of ffn_key_size

* Update rwkv_instance and rwkv_context comments

* Thread safety notes

---------

Co-authored-by: Alex <saharNooby@users.noreply.github.com>
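For reviewers, here is a minimal usage sketch of the new API (not part of this commit). The model path, token values, and thread counts are placeholders, and error handling is reduced to the bare minimum:

// Hypothetical usage sketch: two contexts cloned from one loaded model,
// evaluating in parallel on separate threads. Each context runs at most
// one eval at a time; the weights are shared between both contexts.
#include <rwkv.h>

#include <cstdint>
#include <cstdlib>
#include <thread>
#include <vector>

static void generate(struct rwkv_context * ctx, uint32_t first_token) {
    std::vector<float> state(rwkv_get_state_buffer_element_count(ctx));
    std::vector<float> logits(rwkv_get_logits_buffer_element_count(ctx));

    // First pass uses NULL state_in; later passes update the state in place.
    rwkv_eval(ctx, first_token, NULL, state.data(), logits.data());
    rwkv_eval(ctx, first_token + 1, state.data(), state.data(), logits.data());
}

int main() {
    struct rwkv_context * a = rwkv_init_from_file("model.bin", 2); // placeholder path
    if (!a) return EXIT_FAILURE;

    // The clone shares the weights of `a`; only the graph and scratch are new.
    struct rwkv_context * b = rwkv_clone_context(a, 2);
    if (!b) { rwkv_free(a); return EXIT_FAILURE; }

    std::thread t1(generate, a, 1u); // token values are placeholders
    std::thread t2(generate, b, 2u);
    t1.join();
    t2.join();

    rwkv_free(a); // weights are freed together with the last context sharing them
    rwkv_free(b);
    return EXIT_SUCCESS;
}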
LoganDark 2023-06-03 03:06:24 -07:00 committed by GitHub
parent 363dfb1a06
commit 3f8bb2c080
4 changed files with 169 additions and 29 deletions

rwkv.cpp (119 changed lines)

@@ -476,8 +476,33 @@ struct rwkv_graph {
     std::unique_ptr<struct ggml_cgraph> cgraph;
 };

-struct rwkv_context {
+struct rwkv_ggml_guard {
+    struct ggml_context * ctx;
+    ~rwkv_ggml_guard() { if (ctx) { ggml_free(ctx); } }
+};
+
+// An instance of an RWKV model loaded into memory.
+// Contains all the model weights.
+// Shared by one or more contexts.
+struct rwkv_instance {
+    struct rwkv_model model;
+    struct rwkv_ggml_guard ctx;
+    std::unique_ptr<uint8_t []> scratch;
+
+    // TODO: come up with a better solution to estimate the "work tensor" size.
+    // The ggml_cgraph allocates a "work tensor" the first time it is used.
+    // Currently, the height of blocks.0.ffn.key.weight is the bottleneck in our implementation of RWKV.
+    // Since it is the largest dimension used in any matrix multiply, it is the size used for the "work tensor".
+    // However, if ggml or rwkv.cpp changes its implementation at any point, this may become outdated.
+    // We need to find a way not to hardcode a specific tensor, but to calculate the size accurately.
+    // A solution may come out of this ggml issue: https://github.com/ggerganov/ggml/issues/214
+    size_t ffn_key_size;
+};
+
+// RWKV context for a specific instance.
+// Contains the computation graph and is used for inference.
+struct rwkv_context {
+    std::shared_ptr<struct rwkv_instance> instance;
     struct ggml_context * ctx;
     std::unique_ptr<uint8_t []> scratch;
     struct rwkv_graph graph;
@@ -860,11 +885,6 @@ struct rwkv_file_guard {
     ~rwkv_file_guard() { if (file) { fclose(file); } }
 };

-struct rwkv_ggml_guard {
-    struct ggml_context * ctx;
-    ~rwkv_ggml_guard() { if (ctx) { ggml_free(ctx); } }
-};
-
 void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) {
     bool * ptr = ctx ? &ctx->print_errors : &global_print_errors;
     *ptr = print_errors;
@@ -881,14 +901,12 @@ enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) {
     return value;
 }

-struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
-    global_last_error = RWKV_ERROR_NONE;
-
+bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance) {
     FILE * file = fopen(file_path, "rb");
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file, "Failed to open file %s", file_path);
     rwkv_file_guard file_guard { file };

-    // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to the get file length.
+    // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to get the file length.
     struct stat file_stat;
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(file), &file_stat) == 0, "Failed to stat file %s", file_path);
@@ -897,9 +915,10 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
     size_t tensors_start = ftell(file);

     struct rwkv_ctx_size ctx_size;
-    size_t ffn_key = 0;
     std::string name;
+    instance.ffn_key_size = 0;

     while ((size_t) ftell(file) < (size_t) file_stat.st_size) {
         struct rwkv_tensor_header header;
         RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(file, header), "Invalid tensor header");
@@ -907,18 +926,12 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
         RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file, rwkv_tensor_size(header), SEEK_CUR) == 0, "Failed to read tensor data");
         rwkv_ctx_size_add_tensor(ctx_size, 1, 0, header);

-        if (ffn_key == 0 && name == "blocks.0.ffn.key.weight") {
-            ffn_key = header.height;
+        if (instance.ffn_key_size == 0 && name == "blocks.0.ffn.key.weight") {
+            instance.ffn_key_size = header.height;
         }
     }

-    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, ffn_key, "Model is missing parameter blocks.0.ffn.key.weight");
-    rwkv_ctx_size_add(ctx_size, 1, rwkv_single_graph_size(header.n_vocab, header.n_embed, header.n_layer, ffn_key));
-
-    // And finally to end it all off: the graph work tensor
-    enum ggml_type mul_mat_type = ggml_is_quantized(rwkv_type_to_ggml[header.data_type]) ? GGML_TYPE_Q8_1 : rwkv_type_to_ggml[header.data_type];
-    rwkv_ctx_size_add_objects(ctx_size, 1, sizeof(struct ggml_tensor) + rwkv_tensor_size(GGML_TYPE_I8, rwkv_tensor_size(mul_mat_type, ffn_key) * n_threads + 64 * (n_threads - 1)));
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, instance.ffn_key_size, "Model is missing parameter blocks.0.ffn.key.weight");

     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file, tensors_start, SEEK_SET) == 0, "Failed to seek in file");
     std::unique_ptr<uint8_t []> scratch(new(std::nothrow) uint8_t [ctx_size.scratch_size]);
@@ -957,16 +970,46 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == header.n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]);
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == header.n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]);

+    // Don't free ggml context now
+    ggml_guard.ctx = NULL;
+
+    // Attach ggml context to instance
+    instance.ctx.ctx = ctx;
+
+    instance.model = std::move(model);
+    instance.scratch = std::move(scratch);
+    return true;
+}
+
+struct rwkv_context * rwkv_new_context_impl(std::shared_ptr<struct rwkv_instance> instance, const uint32_t n_threads) {
+    global_last_error = RWKV_ERROR_NONE;
+
+    struct rwkv_file_header & header = instance->model.header;
+
+    rwkv_ctx_size ctx_size;
+    rwkv_ctx_size_add(ctx_size, 1, rwkv_single_graph_size(header.n_vocab, header.n_embed, header.n_layer, instance->ffn_key_size));
+
+    // And finally to end it all off: the graph work tensor
+    enum ggml_type mul_mat_type = ggml_is_quantized(rwkv_type_to_ggml[header.data_type]) ? GGML_TYPE_Q8_1 : rwkv_type_to_ggml[header.data_type];
+    rwkv_ctx_size_add(ctx_size, 1, rwkv_tensor_size(GGML_TYPE_I8, rwkv_tensor_size(mul_mat_type, instance->ffn_key_size) * n_threads + 64 * (n_threads - 1)));
+
+    std::unique_ptr<uint8_t []> scratch(new(std::nothrow) uint8_t [ctx_size.scratch_size]);
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, scratch.get(), "Failed to allocate graph scratch space (%d)", ctx_size.scratch_size);
+
+    struct ggml_context * ctx = ggml_init({ ctx_size.objects_size + ctx_size.objects_count * GGML_OBJECT_SIZE, NULL, false });
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, ctx, "Failed to create GGML context");
+    rwkv_ggml_guard ggml_guard { ctx };
+
+    ggml_set_scratch(ctx, { 0, ctx_size.scratch_size, scratch.get() });

     // Build graph
     struct rwkv_graph graph;
-    RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_single_graph(ctx, model, n_threads, graph));
+    RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_single_graph(ctx, instance->model, n_threads, graph));

     std::unique_ptr<struct rwkv_context> rwkv_ctx(new(std::nothrow) struct rwkv_context());
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx.get(), "Failed to allocate context");

     // Don't free ggml context
     ggml_guard.ctx = NULL;

-    rwkv_ctx->model = std::move(model);
+    rwkv_ctx->instance = std::move(instance);
     rwkv_ctx->ctx = ctx;
     rwkv_ctx->scratch = std::move(scratch);
     rwkv_ctx->graph = std::move(graph);
@@ -975,21 +1018,39 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
     rwkv_ctx->gpu_layers = 0;
     rwkv_ctx->vram_total = 0;

     ggml_set_scratch(ctx, { 0, 0, NULL });
     return rwkv_ctx.release();
 }

+struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
+    global_last_error = RWKV_ERROR_NONE;
+
+    std::shared_ptr<struct rwkv_instance> instance(new(std::nothrow) struct rwkv_instance);
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, instance.get(), "Failed to allocate instance");
+    RWKV_ENSURE_OR_NULL(rwkv_instance_from_file(file_path, *instance.get()));
+    return rwkv_new_context_impl(instance, n_threads);
+}
+
+struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads) {
+    struct rwkv_context * clone = rwkv_new_context_impl(ctx->instance, n_threads);
+
+    if (clone) {
+        clone->print_errors = ctx->print_errors;
+    }
+
+    return clone;
+}
+
 bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers) {
 #ifdef GGML_USE_CUBLAS
     {
-        size_t n_gpu = std::min(n_gpu_layers, ctx->model.header.n_layer);
+        size_t n_gpu = std::min(n_gpu_layers, ctx->instance->model.header.n_layer);
         size_t gpu_layers = ((struct rwkv_context *) ctx)->gpu_layers;
         size_t vram_total = ((struct rwkv_context *) ctx)->vram_total;

         for (size_t i = 0; i < n_gpu; i++) {
-            const struct rwkv_layer & layer = ctx->model.layers[i];
+            const struct rwkv_layer & layer = ctx->instance->model.layers[i];

             // Use cuBLAS only for heavy matrices; other operations are not supported for GPU at the moment
             ggml_cuda_transform_tensor(layer.att_key); vram_total += ggml_nbytes(layer.att_key);
@@ -1012,7 +1073,7 @@ bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_g
 bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) {
     ((struct rwkv_context *) ctx)->last_error = RWKV_ERROR_NONE;

-    const struct rwkv_file_header & header = ctx->model.header;
+    const struct rwkv_file_header & header = ctx->instance->model.header;
     RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < header.n_vocab, "Token is out of range 0..%d", header.n_vocab - 1);

     const struct rwkv_graph & graph = ctx->graph;
@@ -1055,11 +1116,11 @@ bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const floa
 }

 uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) {
-    return ctx->model.header.n_layer * 5 * ctx->model.header.n_embed;
+    return ctx->instance->model.header.n_layer * 5 * ctx->instance->model.header.n_embed;
 }

 uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) {
-    return ctx->model.header.n_vocab;
+    return ctx->instance->model.header.n_vocab;
 }

 void rwkv_free(struct rwkv_context * ctx) {
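The "free ggml context when the last rwkv_context is freed" behavior above is plain shared_ptr reference counting combined with the rwkv_ggml_guard destructor. A self-contained sketch of the pattern, using stand-in demo types rather than code from this commit:

// Demo types only: shows why the guard's destructor (which would call
// ggml_free) runs exactly once, when the last shared owner is released.
#include <cstdio>
#include <memory>

struct guard_demo {
    void * ctx = nullptr;
    ~guard_demo() { if (ctx) { std::printf("ggml_free would run here\n"); } }
};

struct instance_demo { guard_demo guard; };
struct context_demo { std::shared_ptr<instance_demo> instance; };

int main() {
    auto shared = std::make_shared<instance_demo>();
    shared->guard.ctx = reinterpret_cast<void *>(0x1); // stand-in for ggml_init()

    context_demo first { shared };   // like rwkv_init_from_file: refcount = 2
    context_demo second { shared };  // like rwkv_clone_context:  refcount = 3
    shared.reset();                  // loader's local handle released: 2 owners left

    first.instance.reset();          // "rwkv_free(first)": 1 owner left, nothing freed
    second.instance.reset();         // "rwkv_free(second)": destructor fires here
    return 0;
}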

rwkv.h (14 changed lines)

@@ -61,6 +61,10 @@ extern "C" {
         RWKV_ERROR_PARAM_MISSING = 14
     };

+    // RWKV context that can be used for inference.
+    // All functions that operate on rwkv_context are thread-safe.
+    // rwkv_context can be sent to different threads between calls to rwkv_eval.
+    // There is no requirement for rwkv_context to be freed on the creating thread.
     struct rwkv_context;
// Sets whether errors are automatically printed to stderr.
@@ -85,11 +89,20 @@ extern "C" {
     // - n_threads: count of threads to use, must be positive.
     RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);

+    // Creates a new context from an existing one.
+    // This allows you to run multiple rwkv_eval calls in parallel without loading a single model multiple times.
+    // Each rwkv_context may have at most one eval running at a time.
+    // Every rwkv_context must be freed using rwkv_free.
+    // - ctx: context to be cloned.
+    // - n_threads: count of threads to use, must be positive.
+    RWKV_API struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads);
+
     // Offloads specified layers of context onto GPU using cuBLAS, if it is enabled.
     // If rwkv.cpp was compiled without cuBLAS support, this function is a no-op.
     RWKV_API bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers);

     // Evaluates the model for a single token.
+    // Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
     // Returns false on any error. Error messages will be printed to stderr.
     // - token: next token index, in range 0 <= token < n_vocab.
     // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is the first pass.
@@ -104,6 +117,7 @@ extern "C" {
     RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx);

     // Frees all allocated memory and the context.
+    // Does not need to be called from the same thread that created the rwkv_context.
     RWKV_API void rwkv_free(struct rwkv_context * ctx);
// Quantizes FP32 or FP16 model to one of quantized formats.
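A small sketch of the hand-off rule documented above (not part of this commit; the model path and token values are placeholders): a single context may migrate between threads as long as calls to rwkv_eval never overlap.

// One context, two threads, strictly sequential evals: allowed by the
// thread-safety notes above. "model.bin" is a placeholder path.
#include <rwkv.h>

#include <cstdlib>
#include <thread>
#include <vector>

int main() {
    struct rwkv_context * ctx = rwkv_init_from_file("model.bin", 2);
    if (!ctx) return EXIT_FAILURE;

    std::vector<float> state(rwkv_get_state_buffer_element_count(ctx));
    std::vector<float> logits(rwkv_get_logits_buffer_element_count(ctx));

    std::thread first([&] { rwkv_eval(ctx, 1, NULL, state.data(), logits.data()); });
    first.join(); // eval finished; the context may now be used on another thread

    std::thread second([&] { rwkv_eval(ctx, 2, state.data(), state.data(), logits.data()); });
    second.join();

    rwkv_free(ctx); // does not have to happen on the creating thread
    return EXIT_SUCCESS;
}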

tests/CMakeLists.txt (1 changed line)

@@ -14,3 +14,4 @@ file(COPY expected_logits.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
 rwkv_add_test(test_ggml_basics.c)
 rwkv_add_test(test_tiny_rwkv.c)
+rwkv_add_test(test_context_cloning.c)

tests/test_context_cloning.c (new file)

@@ -0,0 +1,64 @@
+#include <rwkv.h>
+
+#include <stdlib.h>
+#include <stdio.h>
+#include <string.h>
+
+int main() {
+    struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2);
+
+    if (!ctx) {
+        enum rwkv_error_flags error = rwkv_get_last_error(NULL);
+        fprintf(stderr, "Unexpected error 0x%.8X\n", error);
+        return EXIT_FAILURE;
+    }
+
+    float * state = calloc(rwkv_get_state_buffer_element_count(ctx), sizeof(float));
+    float * logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));
+
+    if (!state || !logits) {
+        fprintf(stderr, "Failed to allocate state/logits\n");
+        return EXIT_FAILURE;
+    }
+
+    // 0xd1 or 209 is space (0x20 or \u0120 in tokenizer)
+    const unsigned char * prompt = (const unsigned char *) "hello\xd1world";
+
+    rwkv_eval(ctx, prompt[0], NULL, state, logits);
+
+    for (const unsigned char * token = prompt + 1; *token != 0; token++) {
+        rwkv_eval(ctx, *token, state, state, logits);
+    }
+
+    float * expected_logits = logits;
+    logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));
+
+    if (!logits) {
+        fprintf(stderr, "Failed to allocate logits\n");
+        return EXIT_FAILURE;
+    }
+
+    struct rwkv_context * ctx2 = rwkv_clone_context(ctx, 2);
+
+    if (!ctx2) {
+        fprintf(stderr, "Failed to clone context\n");
+        return EXIT_FAILURE;
+    }
+
+    // Run the same prompt through the cloned context; its logits must match.
+    rwkv_eval(ctx2, prompt[0], NULL, state, logits);
+
+    for (const unsigned char * token = prompt + 1; *token != 0; token++) {
+        rwkv_eval(ctx2, *token, state, state, logits);
+    }
+
+    if (memcmp(expected_logits, logits, rwkv_get_logits_buffer_element_count(ctx) * sizeof(float))) {
+        fprintf(stderr, "results not identical :(\n");
+        return EXIT_FAILURE;
+    } else {
+        fprintf(stdout, "Results identical, success!\n");
+    }
+
+    rwkv_free(ctx);
+    rwkv_free(ctx2);
+
+    free(expected_logits);
+    free(logits);
+    free(state);
+
+    return EXIT_SUCCESS;
+}