Allow creating multiple contexts per model (#83)
* Allow creating multiple contexts per model. This allows for parallel inference, and prepares for sequence-mode support using a similar method.
* Fix cuBLAS
* Update rwkv.h
* Update rwkv.cpp
* Inherit print_errors from parent ctx when cloning
* Add context cloning test
* Free ggml context when the last rwkv_context is freed
* Free before exit
* int main
* Add explanation of ffn_key_size
* Update rwkv_instance and rwkv_context comments
* Thread safety notes

---------

Co-authored-by: Alex <saharNooby@users.noreply.github.com>
This commit is contained in:
parent 363dfb1a06
commit 3f8bb2c080

Changed files: rwkv.cpp (119), rwkv.h (14), tests/CMakeLists.txt, tests/test_context_cloning.c (new)
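To make the new workflow concrete before the diff, here is a minimal lifecycle sketch of the API this commit introduces (rwkv_clone_context alongside the existing rwkv_init_from_file/rwkv_eval/rwkv_free). It is illustrative only, not code from the commit; the model path, token value, and thread counts are placeholders.

// Minimal lifecycle sketch; "model.bin", token 0, and the thread counts
// are placeholders, not values from the commit.
#include <rwkv.h>

#include <vector>

int main() {
    // Load the model once; the weights live in a shared instance.
    struct rwkv_context * ctx = rwkv_init_from_file("model.bin", 4);
    if (!ctx) {
        return 1;
    }

    // A clone shares the weights but owns its own graph and scratch space.
    struct rwkv_context * clone = rwkv_clone_context(ctx, 4);
    if (!clone) {
        rwkv_free(ctx);
        return 1;
    }

    std::vector<float> state(rwkv_get_state_buffer_element_count(ctx));
    std::vector<float> logits(rwkv_get_logits_buffer_element_count(ctx));

    // First pass: state_in is NULL.
    rwkv_eval(ctx, 0, NULL, state.data(), logits.data());

    // Every context must be freed; the weights are released with the last one.
    rwkv_free(ctx);
    rwkv_free(clone);
    return 0;
}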
rwkv.cpp:

@@ -476,8 +476,33 @@ struct rwkv_graph {
     std::unique_ptr<struct ggml_cgraph> cgraph;
 };
 
-struct rwkv_context {
+struct rwkv_ggml_guard {
+    struct ggml_context * ctx;
+    ~rwkv_ggml_guard() { if (ctx) { ggml_free(ctx); } }
+};
+
+// An instance of an RWKV model loaded into memory:
+// Contains all the model weights.
+// Shared by one or more contexts.
+struct rwkv_instance {
     struct rwkv_model model;
+    struct rwkv_ggml_guard ctx;
+    std::unique_ptr<uint8_t []> scratch;
+
+    // TODO: come up with a better solution to estimate the "work tensor" size.
+    // The ggml_cgraph allocates a "work tensor" the first time it is used.
+    // Currently, the height of blocks.0.ffn.key.weight is the bottleneck in our implementation of RWKV.
+    // Since it is the largest dimension used in any matrix multiply, it is the size used for the "work tensor".
+    // However, if ggml changes its implementation, or rwkv.cpp changes its own implementation, at any point,
+    // this may become outdated. We need to find a way not to hardcode a specific tensor, but to calculate the size accurately.
+    // This may come out of a ggml issue: https://github.com/ggerganov/ggml/issues/214
+    size_t ffn_key_size;
+};
+
+// RWKV context for a specific instance.
+// Contains the computation graph and is used for inference.
+struct rwkv_context {
+    std::shared_ptr<struct rwkv_instance> instance;
     struct ggml_context * ctx;
     std::unique_ptr<uint8_t []> scratch;
     struct rwkv_graph graph;
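The split above — weights in a shared rwkv_instance, per-context graph state in rwkv_context — boils down to shared_ptr ownership plus an RAII guard. A standalone sketch of just that pattern (the Weights and Runner names are invented stand-ins, not rwkv.cpp types):

// Illustrative sketch of the ownership model introduced above.
#include <cstdio>
#include <memory>

struct Weights {              // plays the role of rwkv_instance
    ~Weights() { std::puts("weights freed"); }
};

struct Runner {               // plays the role of rwkv_context
    std::shared_ptr<Weights> weights;
};

int main() {
    auto weights = std::make_shared<Weights>();
    Runner a { weights };
    Runner b { weights };     // "cloning": shares the same weights
    weights.reset();

    // The Weights destructor runs only when the last Runner holding the
    // shared_ptr is destroyed — mirroring "free ggml context when the
    // last rwkv_context is freed" from the commit message.
    return 0;
}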
@@ -860,11 +885,6 @@ struct rwkv_file_guard {
     ~rwkv_file_guard() { if (file) { fclose(file); } }
 };
 
-struct rwkv_ggml_guard {
-    struct ggml_context * ctx;
-    ~rwkv_ggml_guard() { if (ctx) { ggml_free(ctx); } }
-};
-
 void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) {
     bool * ptr = ctx ? &ctx->print_errors : &global_print_errors;
     *ptr = print_errors;
@@ -881,14 +901,12 @@ enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) {
     return value;
 }
 
-struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
-    global_last_error = RWKV_ERROR_NONE;
-
+bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance) {
     FILE * file = fopen(file_path, "rb");
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file, "Failed to open file %s", file_path);
     rwkv_file_guard file_guard { file };
 
-    // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to the get file length.
+    // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to get the file length.
     struct stat file_stat;
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(file), &file_stat) == 0, "Failed to stat file %s", file_path);
 
@@ -897,9 +915,10 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
 
     size_t tensors_start = ftell(file);
     struct rwkv_ctx_size ctx_size;
-    size_t ffn_key = 0;
+
     std::string name;
+    instance.ffn_key_size = 0;
 
     while ((size_t) ftell(file) < (size_t) file_stat.st_size) {
         struct rwkv_tensor_header header;
         RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(file, header), "Invalid tensor header");
@@ -907,18 +926,12 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
         RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file, rwkv_tensor_size(header), SEEK_CUR) == 0, "Failed to read tensor data");
         rwkv_ctx_size_add_tensor(ctx_size, 1, 0, header);
 
-        if (ffn_key == 0 && name == "blocks.0.ffn.key.weight") {
-            ffn_key = header.height;
+        if (instance.ffn_key_size == 0 && name == "blocks.0.ffn.key.weight") {
+            instance.ffn_key_size = header.height;
         }
     }
 
-    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, ffn_key, "Model is missing parameter blocks.0.ffn.key.weight");
-
-    rwkv_ctx_size_add(ctx_size, 1, rwkv_single_graph_size(header.n_vocab, header.n_embed, header.n_layer, ffn_key));
-    // And finally to end it all off: the graph work tensor
-    enum ggml_type mul_mat_type = ggml_is_quantized(rwkv_type_to_ggml[header.data_type]) ? GGML_TYPE_Q8_1 : rwkv_type_to_ggml[header.data_type];
-    rwkv_ctx_size_add_objects(ctx_size, 1, sizeof(struct ggml_tensor) + rwkv_tensor_size(GGML_TYPE_I8, rwkv_tensor_size(mul_mat_type, ffn_key) * n_threads + 64 * (n_threads - 1)));
-
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, instance.ffn_key_size, "Model is missing parameter blocks.0.ffn.key.weight");
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file, tensors_start, SEEK_SET) == 0, "Failed to seek in file");
 
     std::unique_ptr<uint8_t []> scratch(new(std::nothrow) uint8_t [ctx_size.scratch_size]);
@@ -957,16 +970,46 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == header.n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]);
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == header.n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]);
 
+    // Don't free ggml context now
+    ggml_guard.ctx = NULL;
+    // Attach ggml context to instance
+    instance.ctx.ctx = ctx;
+    instance.model = std::move(model);
+    instance.scratch = std::move(scratch);
+    return true;
+}
+
+struct rwkv_context * rwkv_new_context_impl(std::shared_ptr<struct rwkv_instance> instance, const uint32_t n_threads) {
+    global_last_error = RWKV_ERROR_NONE;
+
+    struct rwkv_file_header & header = instance->model.header;
+
+    rwkv_ctx_size ctx_size;
+    rwkv_ctx_size_add(ctx_size, 1, rwkv_single_graph_size(header.n_vocab, header.n_embed, header.n_layer, instance->ffn_key_size));
+    // And finally to end it all off: the graph work tensor
+    enum ggml_type mul_mat_type = ggml_is_quantized(rwkv_type_to_ggml[header.data_type]) ? GGML_TYPE_Q8_1 : rwkv_type_to_ggml[header.data_type];
+    rwkv_ctx_size_add(ctx_size, 1, rwkv_tensor_size(GGML_TYPE_I8, rwkv_tensor_size(mul_mat_type, instance->ffn_key_size) * n_threads + 64 * (n_threads - 1)));
+
+    std::unique_ptr<uint8_t []> scratch(new(std::nothrow) uint8_t [ctx_size.scratch_size]);
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, scratch.get(), "Failed to allocate graph scratch space (%d)", ctx_size.scratch_size);
+
+    struct ggml_context * ctx = ggml_init({ ctx_size.objects_size + ctx_size.objects_count * GGML_OBJECT_SIZE, NULL, false });
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, ctx, "Failed to create GGML context");
+    rwkv_ggml_guard ggml_guard { ctx };
+
+    ggml_set_scratch(ctx, { 0, ctx_size.scratch_size, scratch.get() });
+
     // Build graph
     struct rwkv_graph graph;
-    RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_single_graph(ctx, model, n_threads, graph));
+    RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_single_graph(ctx, instance->model, n_threads, graph));
 
     std::unique_ptr<struct rwkv_context> rwkv_ctx(new(std::nothrow) struct rwkv_context());
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx.get(), "Failed to allocate context");
 
     // Don't free ggml context
     ggml_guard.ctx = NULL;
-    rwkv_ctx->model = std::move(model);
+    rwkv_ctx->instance = std::move(instance);
     rwkv_ctx->ctx = ctx;
     rwkv_ctx->scratch = std::move(scratch);
     rwkv_ctx->graph = std::move(graph);
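The work-tensor estimate above sizes a per-context buffer from instance->ffn_key_size. To make the arithmetic auditable, here is a standalone sketch of the same formula; tensor_bytes is a deliberate simplification (flat bytes-per-element, FP32 model assumed), whereas ggml's real accounting handles quantized block layouts:

// Standalone sketch of the graph work-buffer estimate used above.
#include <cstdio>
#include <cstddef>

static size_t tensor_bytes(size_t element_count, size_t bytes_per_element) {
    // Simplification: ggml's real size accounting is block-based for
    // quantized types; this assumes a flat per-element cost.
    return element_count * bytes_per_element;
}

int main() {
    const size_t ffn_key_size = 8192; // height of blocks.0.ffn.key.weight (illustrative)
    const size_t n_threads = 4;

    // Per-thread result rows for the largest matrix multiply, plus 64 bytes
    // of padding between threads, stored as an I8 (1 byte/element) buffer.
    size_t mul_mat_bytes = tensor_bytes(ffn_key_size, sizeof(float)); // assuming an FP32 model
    size_t work_bytes = tensor_bytes(mul_mat_bytes * n_threads + 64 * (n_threads - 1), 1);

    std::printf("estimated work buffer: %zu bytes\n", work_bytes);
    return 0;
}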
@@ -975,21 +1018,39 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
     rwkv_ctx->gpu_layers = 0;
     rwkv_ctx->vram_total = 0;
 
-    ggml_set_scratch(ctx, { 0, 0, NULL });
-
     return rwkv_ctx.release();
 }
 
+struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
+    global_last_error = RWKV_ERROR_NONE;
+
+    std::shared_ptr<struct rwkv_instance> instance(new(std::nothrow) struct rwkv_instance);
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, instance.get(), "Failed to allocate instance");
+    RWKV_ENSURE_OR_NULL(rwkv_instance_from_file(file_path, *instance.get()));
+
+    return rwkv_new_context_impl(instance, n_threads);
+}
+
+struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads) {
+    struct rwkv_context * clone = rwkv_new_context_impl(ctx->instance, n_threads);
+
+    if (clone) {
+        clone->print_errors = ctx->print_errors;
+    }
+
+    return clone;
+}
+
 bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers) {
 #ifdef GGML_USE_CUBLAS
     {
-        size_t n_gpu = std::min(n_gpu_layers, ctx->model.header.n_layer);
+        size_t n_gpu = std::min(n_gpu_layers, ctx->instance->model.header.n_layer);
 
         size_t gpu_layers = ((struct rwkv_context *) ctx)->gpu_layers;
         size_t vram_total = ((struct rwkv_context *) ctx)->vram_total;
 
         for (size_t i = 0; i < n_gpu; i++) {
-            const struct rwkv_layer & layer = ctx->model.layers[i];
+            const struct rwkv_layer & layer = ctx->instance->model.layers[i];
 
             // Use cuBLAS only for heavy matrices; other operations are not supported for GPU at the moment
             ggml_cuda_transform_tensor(layer.att_key); vram_total += ggml_nbytes(layer.att_key);
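A design note on the cloning path above: rwkv_clone_context reuses rwkv_new_context_impl with the parent's shared instance, so a clone costs one graph plus scratch allocation rather than a model reload, and only print_errors is copied explicitly. As an illustration, a hypothetical pool helper (make_pool is not part of the rwkv.h API) that batch-creates contexts with rollback on partial failure:

// Hypothetical helper, not part of rwkv.h: create n contexts backed by
// one loaded model, cleaning up on partial failure.
#include <rwkv.h>

#include <vector>

static std::vector<rwkv_context *> make_pool(const char * path, size_t n, uint32_t threads) {
    std::vector<rwkv_context *> pool;
    rwkv_context * first = rwkv_init_from_file(path, threads);
    if (!first) return pool;           // empty pool signals failure
    pool.push_back(first);

    for (size_t i = 1; i < n; i++) {
        rwkv_context * c = rwkv_clone_context(first, threads);
        if (!c) {                      // roll back everything on failure
            for (rwkv_context * p : pool) rwkv_free(p);
            pool.clear();
            return pool;
        }
        pool.push_back(c);
    }
    return pool;
}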
@@ -1012,7 +1073,7 @@ bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers) {
 bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) {
     ((struct rwkv_context *) ctx)->last_error = RWKV_ERROR_NONE;
 
-    const struct rwkv_file_header & header = ctx->model.header;
+    const struct rwkv_file_header & header = ctx->instance->model.header;
     RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < header.n_vocab, "Token is out of range 0..%d", header.n_vocab - 1);
 
     const struct rwkv_graph & graph = ctx->graph;
@@ -1055,11 +1116,11 @@ bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) {
 }
 
 uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) {
-    return ctx->model.header.n_layer * 5 * ctx->model.header.n_embed;
+    return ctx->instance->model.header.n_layer * 5 * ctx->instance->model.header.n_embed;
 }
 
 uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) {
-    return ctx->model.header.n_vocab;
+    return ctx->instance->model.header.n_vocab;
 }
 
 void rwkv_free(struct rwkv_context * ctx) {
rwkv.h:
|
@ -61,6 +61,10 @@ extern "C" {
|
||||||
RWKV_ERROR_PARAM_MISSING = 14
|
RWKV_ERROR_PARAM_MISSING = 14
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// RWKV context that can be used for inference.
|
||||||
|
// All functions that operate on rwkv_context are thread-safe.
|
||||||
|
// rwkv_context can be sent to different threads between calls to rwkv_eval.
|
||||||
|
// There is no requirement for rwkv_context to be freed on the creating thread.
|
||||||
struct rwkv_context;
|
struct rwkv_context;
|
||||||
|
|
||||||
// Sets whether errors are automatically printed to stderr.
|
// Sets whether errors are automatically printed to stderr.
|
||||||
|
@ -85,11 +89,20 @@ extern "C" {
|
||||||
// - n_threads: count of threads to use, must be positive.
|
// - n_threads: count of threads to use, must be positive.
|
||||||
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
|
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
|
||||||
|
|
||||||
|
// Creates a new context from an existing one.
|
||||||
|
// This can allow you to run multiple rwkv_eval's in parallel, without having to load a single model multiple times.
|
||||||
|
// Each rwkv_context can have one eval running at a time.
|
||||||
|
// Every rwkv_context must be freed using rwkv_free.
|
||||||
|
// - ctx: context to be cloned.
|
||||||
|
// - n_threads: count of threads to use, must be positive.
|
||||||
|
RWKV_API struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads);
|
||||||
|
|
||||||
// Offloads specified layers of context onto GPU using cuBLAS, if it is enabled.
|
// Offloads specified layers of context onto GPU using cuBLAS, if it is enabled.
|
||||||
// If rwkv.cpp was compiled without cuBLAS support, this function is a no-op.
|
// If rwkv.cpp was compiled without cuBLAS support, this function is a no-op.
|
||||||
RWKV_API bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers);
|
RWKV_API bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers);
|
||||||
|
|
||||||
// Evaluates the model for a single token.
|
// Evaluates the model for a single token.
|
||||||
|
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
|
||||||
// Returns false on any error. Error messages would be printed to stderr.
|
// Returns false on any error. Error messages would be printed to stderr.
|
||||||
// - token: next token index, in range 0 <= token < n_vocab.
|
// - token: next token index, in range 0 <= token < n_vocab.
|
||||||
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
|
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
|
||||||
|
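Following the thread-safety comments above ("one rwkv_context for each thread"), a minimal parallel-inference sketch; it assumes std::thread, and the model path, token value, and thread counts are placeholders:

// One rwkv_context per thread, all sharing one loaded model.
#include <rwkv.h>

#include <thread>
#include <vector>

static void run(struct rwkv_context * ctx) {
    std::vector<float> state(rwkv_get_state_buffer_element_count(ctx));
    std::vector<float> logits(rwkv_get_logits_buffer_element_count(ctx));
    // Each thread owns its context, so this eval never races with the other.
    rwkv_eval(ctx, 0, NULL, state.data(), logits.data());
}

int main() {
    struct rwkv_context * a = rwkv_init_from_file("model.bin", 2);
    if (!a) return 1;
    struct rwkv_context * b = rwkv_clone_context(a, 2);
    if (!b) { rwkv_free(a); return 1; }

    std::thread ta(run, a);
    std::thread tb(run, b);
    ta.join();
    tb.join();

    rwkv_free(a);
    rwkv_free(b);
    return 0;
}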
@ -104,6 +117,7 @@ extern "C" {
|
||||||
RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx);
|
RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx);
|
||||||
|
|
||||||
// Frees all allocated memory and the context.
|
// Frees all allocated memory and the context.
|
||||||
|
// Does not need to be the same thread that created the rwkv_context.
|
||||||
RWKV_API void rwkv_free(struct rwkv_context * ctx);
|
RWKV_API void rwkv_free(struct rwkv_context * ctx);
|
||||||
|
|
||||||
// Quantizes FP32 or FP16 model to one of quantized formats.
|
// Quantizes FP32 or FP16 model to one of quantized formats.
|
||||||
|
|
|
tests/CMakeLists.txt:

@@ -14,3 +14,4 @@ file(COPY expected_logits.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
 
 rwkv_add_test(test_ggml_basics.c)
 rwkv_add_test(test_tiny_rwkv.c)
+rwkv_add_test(test_context_cloning.c)
tests/test_context_cloning.c (new file):

@@ -0,0 +1,64 @@
+#include <rwkv.h>
+
+#include <stdlib.h>
+#include <stdio.h>
+#include <string.h>
+
+int main() {
+    struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2);
+
+    if (!ctx) {
+        enum rwkv_error_flags error = rwkv_get_last_error(NULL);
+        fprintf(stderr, "Unexpected error 0x%.8X\n", error);
+        return EXIT_FAILURE;
+    }
+
+    float * state = calloc(rwkv_get_state_buffer_element_count(ctx), sizeof(float));
+    float * logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));
+
+    if (!state || !logits) {
+        fprintf(stderr, "Failed to allocate state/logits\n");
+        return EXIT_FAILURE;
+    }
+
+    // 0xd1 or 209 is space (0x20 or \u0120 in tokenizer)
+    const unsigned char * prompt = (const unsigned char *) "hello\xd1world";
+
+    rwkv_eval(ctx, prompt[0], NULL, state, logits);
+
+    for (const unsigned char * token = prompt + 1; *token != 0; token++) {
+        rwkv_eval(ctx, *token, state, state, logits);
+    }
+
+    float * expected_logits = logits;
+    logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));
+
+    if (!logits) {
+        fprintf(stderr, "Failed to allocate logits\n");
+        return EXIT_FAILURE;
+    }
+
+    struct rwkv_context * ctx2 = rwkv_clone_context(ctx, 2);
+
+    rwkv_eval(ctx2, prompt[0], NULL, state, logits);
+
+    for (const unsigned char * token = prompt + 1; *token != 0; token++) {
+        rwkv_eval(ctx2, *token, state, state, logits);
+    }
+
+    if (memcmp(expected_logits, logits, rwkv_get_logits_buffer_element_count(ctx) * sizeof(float))) {
+        fprintf(stderr, "Results not identical :(\n");
+        return EXIT_FAILURE;
+    } else {
+        fprintf(stdout, "Results identical, success!\n");
+    }
+
+    rwkv_free(ctx);
+    rwkv_free(ctx2);
+
+    free(expected_logits);
+    free(logits);
+    free(state);
+
+    return EXIT_SUCCESS;
+}