Allow creating multiple contexts per model (#83)
* Allow creating multiple contexts per model This allows for parallel inference and I am preparing to support sequence mode using a method similar to this * Fix cuBLAS * Update rwkv.h Co-authored-by: Alex <saharNooby@users.noreply.github.com> * Update rwkv.cpp Co-authored-by: Alex <saharNooby@users.noreply.github.com> * Inherit print_errors from parent ctx when cloning * Add context cloning test * Free * Free ggml context when last rwkv_context is freed * Free before exit * int main * add explanation of ffn_key_size * Update rwkv_instance and rwkv_context comments * Thread safety notes --------- Co-authored-by: Alex <saharNooby@users.noreply.github.com>
This commit is contained in:
parent
363dfb1a06
commit
3f8bb2c080
119
rwkv.cpp
119
rwkv.cpp
|
@ -476,8 +476,33 @@ struct rwkv_graph {
|
|||
std::unique_ptr<struct ggml_cgraph> cgraph;
|
||||
};
|
||||
|
||||
struct rwkv_context {
|
||||
struct rwkv_ggml_guard {
|
||||
struct ggml_context * ctx;
|
||||
~rwkv_ggml_guard() { if (ctx) { ggml_free(ctx); } }
|
||||
};
|
||||
|
||||
// An instance of an RWKV model loaded into memory:
|
||||
// Contains all the model weights.
|
||||
// Shared by one or more contexts.
|
||||
struct rwkv_instance {
|
||||
struct rwkv_model model;
|
||||
struct rwkv_ggml_guard ctx;
|
||||
std::unique_ptr<uint8_t []> scratch;
|
||||
|
||||
// TODO come up with a better solution to estimate "work tensor" size.
|
||||
// The ggml_cgraph allocates a "work tensor" the first time it is used.
|
||||
// Currently, the height of blocks.0.ffn.key.weight is the bottleneck in our implementation of RWKV.
|
||||
// Since it is the largest dimension used in any matrix multiply, it is the size used for the "work tensor".
|
||||
// However, if ggml changes its implementation, or rwkv.cpp changes its own implementation, at any point,
|
||||
// this may become outdated. We need to find a way not to hardcode a specific tensor, but to calculate accurately.
|
||||
// This may come out of a ggml issue: https://github.com/ggerganov/ggml/issues/214
|
||||
size_t ffn_key_size;
|
||||
};
|
||||
|
||||
// RWKV context for a specific instance.
|
||||
// Contains the computation graph and is used for inference.
|
||||
struct rwkv_context {
|
||||
std::shared_ptr<struct rwkv_instance> instance;
|
||||
struct ggml_context * ctx;
|
||||
std::unique_ptr<uint8_t []> scratch;
|
||||
struct rwkv_graph graph;
|
||||
|
@ -860,11 +885,6 @@ struct rwkv_file_guard {
|
|||
~rwkv_file_guard() { if (file) { fclose(file); } }
|
||||
};
|
||||
|
||||
struct rwkv_ggml_guard {
|
||||
struct ggml_context * ctx;
|
||||
~rwkv_ggml_guard() { if (ctx) { ggml_free(ctx); } }
|
||||
};
|
||||
|
||||
void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) {
|
||||
bool * ptr = ctx ? &ctx->print_errors : &global_print_errors;
|
||||
*ptr = print_errors;
|
||||
|
@ -881,14 +901,12 @@ enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) {
|
|||
return value;
|
||||
}
|
||||
|
||||
struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
|
||||
global_last_error = RWKV_ERROR_NONE;
|
||||
|
||||
bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance) {
|
||||
FILE * file = fopen(file_path, "rb");
|
||||
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file, "Failed to open file %s", file_path);
|
||||
rwkv_file_guard file_guard { file };
|
||||
|
||||
// Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to the get file length.
|
||||
// Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to get the file length.
|
||||
struct stat file_stat;
|
||||
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(file), &file_stat) == 0, "Failed to stat file %s", file_path);
|
||||
|
||||
|
@ -897,9 +915,10 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
|
|||
|
||||
size_t tensors_start = ftell(file);
|
||||
struct rwkv_ctx_size ctx_size;
|
||||
size_t ffn_key = 0;
|
||||
|
||||
std::string name;
|
||||
instance.ffn_key_size = 0;
|
||||
|
||||
while ((size_t) ftell(file) < (size_t) file_stat.st_size) {
|
||||
struct rwkv_tensor_header header;
|
||||
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(file, header), "Invalid tensor header");
|
||||
|
@ -907,18 +926,12 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
|
|||
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file, rwkv_tensor_size(header), SEEK_CUR) == 0, "Failed to read tensor data");
|
||||
rwkv_ctx_size_add_tensor(ctx_size, 1, 0, header);
|
||||
|
||||
if (ffn_key == 0 && name == "blocks.0.ffn.key.weight") {
|
||||
ffn_key = header.height;
|
||||
if (instance.ffn_key_size == 0 && name == "blocks.0.ffn.key.weight") {
|
||||
instance.ffn_key_size = header.height;
|
||||
}
|
||||
}
|
||||
|
||||
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, ffn_key, "Model is missing parameter blocks.0.ffn.key.weight");
|
||||
|
||||
rwkv_ctx_size_add(ctx_size, 1, rwkv_single_graph_size(header.n_vocab, header.n_embed, header.n_layer, ffn_key));
|
||||
// And finally to end it all off: the graph work tensor
|
||||
enum ggml_type mul_mat_type = ggml_is_quantized(rwkv_type_to_ggml[header.data_type]) ? GGML_TYPE_Q8_1 : rwkv_type_to_ggml[header.data_type];
|
||||
rwkv_ctx_size_add_objects(ctx_size, 1, sizeof(struct ggml_tensor) + rwkv_tensor_size(GGML_TYPE_I8, rwkv_tensor_size(mul_mat_type, ffn_key) * n_threads + 64 * (n_threads - 1)));
|
||||
|
||||
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, instance.ffn_key_size, "Model is missing parameter blocks.0.ffn.key.weight");
|
||||
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file, tensors_start, SEEK_SET) == 0, "Failed to seek in file");
|
||||
|
||||
std::unique_ptr<uint8_t []> scratch(new(std::nothrow) uint8_t [ctx_size.scratch_size]);
|
||||
|
@ -957,16 +970,46 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
|
|||
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == header.n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]);
|
||||
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == header.n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]);
|
||||
|
||||
// Don't free ggml context now
|
||||
ggml_guard.ctx = NULL;
|
||||
// Attach ggml context to instance
|
||||
instance.ctx.ctx = ctx;
|
||||
instance.model = std::move(model);
|
||||
instance.scratch = std::move(scratch);
|
||||
return true;
|
||||
}
|
||||
|
||||
struct rwkv_context * rwkv_new_context_impl(std::shared_ptr<struct rwkv_instance> instance, const uint32_t n_threads) {
|
||||
global_last_error = RWKV_ERROR_NONE;
|
||||
|
||||
struct rwkv_file_header & header = instance->model.header;
|
||||
|
||||
rwkv_ctx_size ctx_size;
|
||||
rwkv_ctx_size_add(ctx_size, 1, rwkv_single_graph_size(header.n_vocab, header.n_embed, header.n_layer, instance->ffn_key_size));
|
||||
// And finally to end it all off: the graph work tensor
|
||||
enum ggml_type mul_mat_type = ggml_is_quantized(rwkv_type_to_ggml[header.data_type]) ? GGML_TYPE_Q8_1 : rwkv_type_to_ggml[header.data_type];
|
||||
rwkv_ctx_size_add(ctx_size, 1, rwkv_tensor_size(GGML_TYPE_I8, rwkv_tensor_size(mul_mat_type, instance->ffn_key_size) * n_threads + 64 * (n_threads - 1)));
|
||||
|
||||
std::unique_ptr<uint8_t []> scratch(new(std::nothrow) uint8_t [ctx_size.scratch_size]);
|
||||
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, scratch.get(), "Failed to allocate graph scratch space (%d)", ctx_size.scratch_size);
|
||||
|
||||
struct ggml_context * ctx = ggml_init({ ctx_size.objects_size + ctx_size.objects_count * GGML_OBJECT_SIZE, NULL, false});
|
||||
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, ctx, "Failed to create GGML context");
|
||||
rwkv_ggml_guard ggml_guard { ctx };
|
||||
|
||||
ggml_set_scratch(ctx, { 0, ctx_size.scratch_size, scratch.get() });
|
||||
|
||||
// Build graph
|
||||
struct rwkv_graph graph;
|
||||
RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_single_graph(ctx, model, n_threads, graph));
|
||||
RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_single_graph(ctx, instance->model, n_threads, graph));
|
||||
|
||||
std::unique_ptr<struct rwkv_context> rwkv_ctx(new(std::nothrow) struct rwkv_context());
|
||||
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx.get(), "Failed to allocate context");
|
||||
|
||||
// Don't free ggml context
|
||||
ggml_guard.ctx = NULL;
|
||||
rwkv_ctx->model = std::move(model);
|
||||
|
||||
rwkv_ctx->instance = std::move(instance);
|
||||
rwkv_ctx->ctx = ctx;
|
||||
rwkv_ctx->scratch = std::move(scratch);
|
||||
rwkv_ctx->graph = std::move(graph);
|
||||
|
@ -975,21 +1018,39 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
|
|||
rwkv_ctx->gpu_layers = 0;
|
||||
rwkv_ctx->vram_total = 0;
|
||||
|
||||
ggml_set_scratch(ctx, { 0, 0, NULL });
|
||||
|
||||
return rwkv_ctx.release();
|
||||
}
|
||||
|
||||
struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
|
||||
global_last_error = RWKV_ERROR_NONE;
|
||||
|
||||
std::shared_ptr<struct rwkv_instance> instance(new(std::nothrow) struct rwkv_instance);
|
||||
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, instance.get(), "Failed to allocate instance");
|
||||
RWKV_ENSURE_OR_NULL(rwkv_instance_from_file(file_path, *instance.get()));
|
||||
|
||||
return rwkv_new_context_impl(instance, n_threads);
|
||||
}
|
||||
|
||||
struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads) {
|
||||
struct rwkv_context * clone = rwkv_new_context_impl(ctx->instance, n_threads);
|
||||
|
||||
if (clone) {
|
||||
clone->print_errors = ctx->print_errors;
|
||||
}
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers) {
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
{
|
||||
size_t n_gpu = std::min(n_gpu_layers, ctx->model.header.n_layer);
|
||||
size_t n_gpu = std::min(n_gpu_layers, ctx->instance->model.header.n_layer);
|
||||
|
||||
size_t gpu_layers = ((struct rwkv_context *) ctx)->gpu_layers;
|
||||
size_t vram_total = ((struct rwkv_context *) ctx)->vram_total;
|
||||
|
||||
for (size_t i = 0; i < n_gpu; i++) {
|
||||
const struct rwkv_layer & layer = ctx->model.layers[i];
|
||||
const struct rwkv_layer & layer = ctx->instance->model.layers[i];
|
||||
|
||||
// Use cuBLAS only for heavy matrices; other operations are not supported for GPU at the moment
|
||||
ggml_cuda_transform_tensor(layer.att_key); vram_total += ggml_nbytes(layer.att_key);
|
||||
|
@ -1012,7 +1073,7 @@ bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_g
|
|||
bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) {
|
||||
((struct rwkv_context *) ctx)->last_error = RWKV_ERROR_NONE;
|
||||
|
||||
const struct rwkv_file_header & header = ctx->model.header;
|
||||
const struct rwkv_file_header & header = ctx->instance->model.header;
|
||||
RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < header.n_vocab, "Token is out of range 0..%d", header.n_vocab - 1);
|
||||
|
||||
const struct rwkv_graph & graph = ctx->graph;
|
||||
|
@ -1055,11 +1116,11 @@ bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const floa
|
|||
}
|
||||
|
||||
uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) {
|
||||
return ctx->model.header.n_layer * 5 * ctx->model.header.n_embed;
|
||||
return ctx->instance->model.header.n_layer * 5 * ctx->instance->model.header.n_embed;
|
||||
}
|
||||
|
||||
uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) {
|
||||
return ctx->model.header.n_vocab;
|
||||
return ctx->instance->model.header.n_vocab;
|
||||
}
|
||||
|
||||
void rwkv_free(struct rwkv_context * ctx) {
|
||||
|
|
14
rwkv.h
14
rwkv.h
|
@ -61,6 +61,10 @@ extern "C" {
|
|||
RWKV_ERROR_PARAM_MISSING = 14
|
||||
};
|
||||
|
||||
// RWKV context that can be used for inference.
|
||||
// All functions that operate on rwkv_context are thread-safe.
|
||||
// rwkv_context can be sent to different threads between calls to rwkv_eval.
|
||||
// There is no requirement for rwkv_context to be freed on the creating thread.
|
||||
struct rwkv_context;
|
||||
|
||||
// Sets whether errors are automatically printed to stderr.
|
||||
|
@ -85,11 +89,20 @@ extern "C" {
|
|||
// - n_threads: count of threads to use, must be positive.
|
||||
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
|
||||
|
||||
// Creates a new context from an existing one.
|
||||
// This can allow you to run multiple rwkv_eval's in parallel, without having to load a single model multiple times.
|
||||
// Each rwkv_context can have one eval running at a time.
|
||||
// Every rwkv_context must be freed using rwkv_free.
|
||||
// - ctx: context to be cloned.
|
||||
// - n_threads: count of threads to use, must be positive.
|
||||
RWKV_API struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads);
|
||||
|
||||
// Offloads specified layers of context onto GPU using cuBLAS, if it is enabled.
|
||||
// If rwkv.cpp was compiled without cuBLAS support, this function is a no-op.
|
||||
RWKV_API bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers);
|
||||
|
||||
// Evaluates the model for a single token.
|
||||
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
|
||||
// Returns false on any error. Error messages would be printed to stderr.
|
||||
// - token: next token index, in range 0 <= token < n_vocab.
|
||||
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
|
||||
|
@ -104,6 +117,7 @@ extern "C" {
|
|||
RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx);
|
||||
|
||||
// Frees all allocated memory and the context.
|
||||
// Does not need to be the same thread that created the rwkv_context.
|
||||
RWKV_API void rwkv_free(struct rwkv_context * ctx);
|
||||
|
||||
// Quantizes FP32 or FP16 model to one of quantized formats.
|
||||
|
|
|
@ -14,3 +14,4 @@ file(COPY expected_logits.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
|
|||
|
||||
rwkv_add_test(test_ggml_basics.c)
|
||||
rwkv_add_test(test_tiny_rwkv.c)
|
||||
rwkv_add_test(test_context_cloning.c)
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
#include <rwkv.h>
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
int main() {
|
||||
struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2);
|
||||
|
||||
if (!ctx) {
|
||||
enum rwkv_error_flags error = rwkv_get_last_error(NULL);
|
||||
fprintf(stderr, "Unexpected error 0x%.8X\n", error);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
float * state = calloc(rwkv_get_state_buffer_element_count(ctx), sizeof(float));
|
||||
float * logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));
|
||||
|
||||
if (!state || !logits) {
|
||||
fprintf(stderr, "Failed to allocate state/logits\n");
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
// 0xd1 or 209 is space (0x20 or \u0120 in tokenizer)
|
||||
const unsigned char * prompt = "hello\xd1world";
|
||||
|
||||
rwkv_eval(ctx, prompt[0], NULL, state, logits);
|
||||
|
||||
for (const unsigned char * token = prompt + 1; *token != 0; token++) {
|
||||
rwkv_eval(ctx, *token, state, state, logits);
|
||||
}
|
||||
|
||||
float * expected_logits = logits;
|
||||
logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));
|
||||
|
||||
if (!logits) {
|
||||
fprintf(stderr, "Failed to allocate state/logits\n");
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
struct rwkv_context * ctx2 = rwkv_clone_context(ctx, 2);
|
||||
|
||||
rwkv_eval(ctx, prompt[0], NULL, state, logits);
|
||||
|
||||
for (const unsigned char * token = prompt + 1; *token != 0; token++) {
|
||||
rwkv_eval(ctx, *token, state, state, logits);
|
||||
}
|
||||
|
||||
if (memcmp(expected_logits, logits, rwkv_get_logits_buffer_element_count(ctx) * sizeof(float))) {
|
||||
fprintf(stderr, "results not identical :(\n");
|
||||
return EXIT_FAILURE;
|
||||
} else {
|
||||
fprintf(stdout, "Results identical, success!\n");
|
||||
}
|
||||
|
||||
rwkv_free(ctx);
|
||||
rwkv_free(ctx2);
|
||||
|
||||
free(expected_logits);
|
||||
free(logits);
|
||||
free(state);
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
}
|
Loading…
Reference in New Issue