Add rwkv_set_print_errors and rwkv_get_last_error (#68)

* Add rwkv_set_print_errors and rwkv_get_last_error

Fixes #63

This allows retrieving errors from the library without having to
pipe stderr. Also it was annoying that rwkv.cpp assumed control of
the caller process by doing things like calling abort() when it
shouldn't, so I also fixed that.

The basic way this works is:

1. by default, not much is different, except more errors are caught,
   and rwkv.cpp should never abort the process or throw a C++
   exception.

2. the difference comes when you call rwkv_set_print_errors
   (working title):

   1. errors will no longer be printed to stderr automatically
   2. errors will be assigned to a thread-local variable (during
      init/quantization) or a context-local variable (during eval)
   3. the last error can be retrieved using rwkv_get_last_error

I also overhauled the assert macros so more error cases are
handled:

- the file is now closed if rwkv_init_from_file exits early
- the ggml context is freed if rwkv_init_from_file exits early
- if parameters cannot be found an error will be set about it

I also made some optimizations:

- just use fstat instead of opening the file twice
- deduplicated some code / removed edge cases that do not exist
- switched to ggml inplace operations where they exist

test_tiny_rwkv.c seems to run perfectly fine. The Python scripts
also.

The built DLL is perfectly backwards compatible with existing API
consumers like the python library, because it does not remove or
change any functions, only adds some optional ones.

The sad thing is that this will break every PR because the error
handling in this library was terrible and needed to be totally
redone. But I think it is worth it.

* Fix typo

Co-authored-by: Alex <saharNooby@users.noreply.github.com>

* Visual Studio lied and _fileno is incorrect

* Fix trailing comma in assert macros

This was an accident left over from something that didn't pan out,
some compilers do not like when function arguments have a trailing
comma.

* Include header file for fstat

* Remove uses of std::make_unique

* Fix width of format string argument on all platforms

* Use C free for smart pointers

* Revert "Use C free for smart pointers" and try nothrow

* Initialize cgraph to zero

* Fix ggml_cgraph initialization

* Zero-initialize allocations

---------

Co-authored-by: Alex <saharNooby@users.noreply.github.com>
This commit is contained in:
LoganDark 2023-05-24 04:06:52 -07:00 committed by GitHub
parent 1c363e6d5f
commit 9e2a0de843
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 418 additions and 351 deletions

588
rwkv.cpp
View File

@ -12,34 +12,92 @@
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <unordered_map> #include <unordered_map>
#include <memory>
#include <sys/stat.h> // fstat
// --- Error handling ---
thread_local enum rwkv_error_flags global_last_error = RWKV_ERROR_NONE;
thread_local bool global_print_errors = true;
inline enum rwkv_error_flags operator|(enum rwkv_error_flags a, enum rwkv_error_flags b) {
return static_cast<enum rwkv_error_flags>(static_cast<int>(a) | static_cast<int>(b));
}
inline enum rwkv_error_flags operator|=(enum rwkv_error_flags & a, enum rwkv_error_flags b) {
return a = a | b;
}
// If the condition x is false, adds ERR_VAL to the last error, prints a message to stderr, and returns RET_VAL.
#define RWKV_ASSERT_MSG(ERR_VAL, RET_VAL, x, ...) \
if (!(x)) { \
global_last_error |= ERR_VAL; \
if (global_print_errors) { \
fprintf(stderr, __VA_ARGS__); \
fprintf(stderr, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
} \
return RET_VAL; \
}
// If the condition x is false, adds ERR_VAL to the last error, and returns RET_VAL.
#define RWKV_ASSERT(ERR_VAL, RET_VAL, x) \
if (!(x)) { \
global_last_error |= ERR_VAL; \
return RET_VAL; \
}
// If the condition x is false, adds ERR_VAL to the ctx's last error, prints a message to stderr, and returns RET_VAL.
#define RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, RET_VAL, x, ...) \
if (!(x)) { \
((struct rwkv_context *) ctx)->last_error |= ERR_VAL; \
if (ctx->print_errors) { \
fprintf(stderr, __VA_ARGS__); \
fprintf(stderr, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
} \
return RET_VAL; \
}
// If the condition x is false, adds ERR_VAL to the ctx's last error, and returns RET_VAL.
#define RWKV_CTX_ASSERT(ctx, ERR_VAL, RET_VAL, x) \
if (!(x)) { \
ctx->last_error |= ERR_VAL; \
return RET_VAL; \
}
#define RWKV_ASSERT_FALSE_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, false, x, __VA_ARGS__)
#define RWKV_ASSERT_NULL_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, NULL, x, __VA_ARGS__)
#define RWKV_CTX_ASSERT_FALSE_MSG(ctx, ERR_VAL, x, ...) RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, false, x, __VA_ARGS__)
#define RWKV_CTX_ASSERT_NULL_MSG(ctx, ERR_VAL, x, ...) RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, NULL, x, __VA_ARGS__)
#define RWKV_ASSERT_FALSE(ERR_VAL, x) RWKV_ASSERT(ERR_VAL, false, x)
#define RWKV_ASSERT_NULL(ERR_VAL, x) RWKV_ASSERT(ERR_VAL, NULL, x)
#define RWKV_CTX_ASSERT_FALSE(ctx, ERR_VAL, x) RWKV_CTX_ASSERT(ctx, ERR_VAL, false, x)
#define RWKV_CTX_ASSERT_NULL(ctx, ERR_VAL, x) RWKV_CTX_ASSERT(ctx, ERR_VAL, NULL, x)
// --- Utilities --- // --- Utilities ---
// Checks that x is not false. If x is false, prints fancy message to stderr and returns RET_VAL.
#define RWKV_ASSERT(RET_VAL, x, ...) \
{ \
if (!(x)) { \
fprintf(stderr, __VA_ARGS__); \
fprintf(stderr, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
return RET_VAL; \
} \
}
// Checks that x is not false. If x is false, prints fancy message to stderr and returns false.
#define RWKV_ASSERT_FALSE(x, ...) RWKV_ASSERT(false, x, __VA_ARGS__)
// Checks that x is not false. If x is false, prints fancy message to stderr and returns NULL.
#define RWKV_ASSERT_NULL(x, ...) RWKV_ASSERT(NULL, x, __VA_ARGS__)
// Reads single int32 value from a file. // Reads single int32 value from a file.
bool read_int32(FILE * file, int32_t * dest) { bool read_int32(FILE * file, int32_t * dest, const char * name) {
RWKV_ASSERT_FALSE(fread(dest, sizeof(int32_t), 1, file) == 1, "Failed to read an int32 value from a file"); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, fread(dest, sizeof(int32_t), 1, file) == 1, "Failed to read an int32 value from a file (%s)", name);
return true; return true;
} }
// Reads single uint32 value from a file. // Reads single uint32 value from a file.
bool read_uint32(FILE * file, uint32_t * dest) { bool read_uint32(FILE * file, uint32_t * dest, const char * name) {
RWKV_ASSERT_FALSE(fread(dest, sizeof(uint32_t), 1, file) == 1, "Failed to read a uint32 value from a file"); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, fread(dest, sizeof(uint32_t), 1, file) == 1, "Failed to read a uint32 value from a file (%s)", name);
return true;
}
// Writes single int32 value to a file.
bool write_int32(FILE * file, int32_t value, const char * name) {
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_WRITE, fwrite((void *) &value, sizeof(int32_t), 1, file), "Failed to write an int32 value to a file (%s)", name);
return true;
}
// Writes single uint32 value to a file.
bool write_uint32(FILE * file, uint32_t value, const char * name) {
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_WRITE, fwrite((void *) &value, sizeof(uint32_t), 1, file), "Failed to write a uint32 value to a file (%s)", name);
return true; return true;
} }
@ -123,7 +181,7 @@ struct rwkv_model {
// If the parameter was not found, returns false. // If the parameter was not found, returns false.
bool set_parameter(std::unordered_map<std::string, struct ggml_tensor *> * parameters, std::string key, struct ggml_tensor ** dest) { bool set_parameter(std::unordered_map<std::string, struct ggml_tensor *> * parameters, std::string key, struct ggml_tensor ** dest) {
struct ggml_tensor * parameter = (*parameters)[key]; struct ggml_tensor * parameter = (*parameters)[key];
RWKV_ASSERT_FALSE(parameter != NULL, "Parameter %s not found in model file", key.c_str()); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_PARAM_MISSING, parameter != NULL, "Parameter %s not found in model file", key.c_str());
*dest = parameter; *dest = parameter;
return true; return true;
} }
@ -183,66 +241,82 @@ struct ggml_tensor * rwkv_layer_norm(ggml_context * ctx, struct ggml_tensor * x,
// Looks like ggml_norm does the first part, we only need to apply weight & bias. // Looks like ggml_norm does the first part, we only need to apply weight & bias.
x = ggml_norm(ctx, x); x = ggml_norm(ctx, x);
x = ggml_mul(ctx, x, weight); x = ggml_mul(ctx, x, weight);
x = ggml_add(ctx, x, bias); x = ggml_add_inplace(ctx, x, bias);
return x; return x;
} }
// --- Implementation --- // --- Implementation ---
struct rwkv_context { struct rwkv_context {
struct rwkv_model * model; std::unique_ptr<struct rwkv_model> model;
struct ggml_tensor * token_index; struct ggml_tensor * token_index;
struct ggml_tensor * state; struct ggml_tensor * state;
struct ggml_tensor ** state_parts; struct ggml_tensor ** state_parts;
struct ggml_tensor * logits; struct ggml_tensor * logits;
struct ggml_context * ctx; struct ggml_context * ctx;
struct ggml_cgraph * graph; std::unique_ptr<struct ggml_cgraph> graph;
bool freed; enum rwkv_error_flags last_error;
bool print_errors;
};
void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) {
bool * ptr = ctx ? &ctx->print_errors : &global_print_errors;
*ptr = print_errors;
}
bool rwkv_get_print_errors(struct rwkv_context * ctx) {
return ctx ? ctx->print_errors : global_print_errors;
}
enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) {
enum rwkv_error_flags * ptr = ctx ? &ctx->last_error : &global_last_error;
enum rwkv_error_flags value = *ptr;
*ptr = RWKV_ERROR_NONE;
return value;
}
struct rwkv_file_guard {
FILE * file;
~rwkv_file_guard() { if (file) fclose(file); }
};
struct rwkv_ggml_guard {
struct ggml_context * ctx;
~rwkv_ggml_guard() { if (ctx) ggml_free(ctx); }
}; };
struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) { struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
global_last_error = RWKV_ERROR_NONE;
FILE* file = fopen(file_path, "rb"); FILE* file = fopen(file_path, "rb");
RWKV_ASSERT_NULL(file != NULL, "Failed to open file %s", file_path); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file, "Failed to open file %s", file_path);
rwkv_file_guard file_guard { file };
struct stat file_stat;
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(file), &file_stat) == 0, "Failed to stat file %s", file_path);
int32_t magic; int32_t magic;
read_int32(file, &magic); RWKV_ASSERT_NULL(RWKV_ERROR_FILE, read_int32(file, &magic, "magic"));
RWKV_ASSERT_NULL(magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_MAGIC, magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic);
int32_t version; int32_t version;
read_int32(file, &version); RWKV_ASSERT_NULL(RWKV_ERROR_FILE, read_int32(file, &version, "version"));
RWKV_ASSERT_NULL(version == RWKV_FILE_VERSION, "Unsupported file version %d", version); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION, version == RWKV_FILE_VERSION, "Unsupported file version %d", version);
struct rwkv_model * model = (struct rwkv_model *) calloc(1, sizeof(struct rwkv_model)); std::unique_ptr<rwkv_model> model(new(std::nothrow) struct rwkv_model());
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_ALLOC, model.get(), "Failed to allocate model");
read_uint32(file, &(model->n_vocab)); RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_vocab, "n_vocab"));
read_uint32(file, &(model->n_embed)); RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_embed, "n_embed"));
read_uint32(file, &(model->n_layer)); RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_layer, "n_layer"));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_int32(file, &model->data_type, "data_type"));
read_int32(file, &(model->data_type)); const char* unsupported_msg = "Models in %s format cannot be loaded anymore because the format was removed. You need to quantize the model into another format";
RWKV_ASSERT_NULL(model->data_type >= 0 && model->data_type < FORMAT_TYPE_COUNT, "Unsupported model data type %d", model->data_type); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_DATA_TYPE, model->data_type >= 0 && model->data_type < FORMAT_TYPE_COUNT, "Unsupported model data type %d", model->data_type);
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 4, unsupported_msg, "Q4_1_O");
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 6, unsupported_msg, "Q4_3");
RWKV_ASSERT_NULL( size_t memory_required = file_stat.st_size +
model->data_type != 4,
"Models in Q4_1_O format cannot be loaded anymore because the format was removed. You need to quantize the model into another format"
);
RWKV_ASSERT_NULL(
model->data_type != 6,
"Models in Q4_3 format cannot be loaded anymore because the format was removed. You need to quantize the model into another format"
);
// Parameter tensors would take at least this amount in memory.
size_t file_size;
{
auto fin = std::ifstream(file_path, std::ios::binary);
RWKV_ASSERT_NULL(fin, "Failed to open file %s", file_path);
fin.seekg(0, fin.end);
file_size = fin.tellg();
fin.close();
}
size_t memory_required = file_size +
// Intermediary vectors for calculation; there are around 100 calls to ggml // Intermediary vectors for calculation; there are around 100 calls to ggml
size_t(100) * model->n_embed * sizeof(float) + size_t(100) * model->n_embed * sizeof(float) +
// State, in and out // State, in and out
@ -253,109 +327,91 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
// TODO This is too much for smaller models; need a more proper and robust way of measuring required memory // TODO This is too much for smaller models; need a more proper and robust way of measuring required memory
size_t(256) * 1024 * 1024; size_t(256) * 1024 * 1024;
// Initialize ggml struct ggml_context * ctx = ggml_init({ memory_required, NULL, false });
struct ggml_init_params params; RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_ALLOC, ctx, "Failed to allocate GGML context");
params.mem_size = memory_required; rwkv_ggml_guard ggml_guard { ctx };
params.mem_buffer = NULL;
params.no_alloc = false;
struct ggml_context * ctx = ggml_init(params);
std::unordered_map<std::string, struct ggml_tensor *> parameters; std::unordered_map<std::string, struct ggml_tensor *> parameters;
while (true) { while (true) {
int32_t dim_count; int32_t dim_count, key_length, data_type;
size_t elements_read = fread(&dim_count, 4, 1, file); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ, fread(&dim_count, sizeof(int32_t), 1, file) == 1 || feof(file), "Failed to read an int32 value from a file (dim_count)");
if (feof(file)) break;
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, read_int32(file, &key_length, "key_length"));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, read_int32(file, &data_type, "data_type"));
if (feof(file)) { RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, dim_count == 1 || dim_count == 2, "Unsupported dimension count %d", dim_count);
break; RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_KEY, key_length > 0, "Non-positive key length %d", key_length);
} RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, data_type >= 0 && data_type < FORMAT_TYPE_COUNT, "Unsupported parameter data type %d", data_type);
RWKV_ASSERT_NULL(elements_read == 1, "Failed to read dimension count");
RWKV_ASSERT_NULL(dim_count == 1 || dim_count == 2, "Unsupported dimension count %d", dim_count);
int32_t key_length;
read_int32(file, &key_length);
RWKV_ASSERT_NULL(key_length > 0, "Non-positive key length %d", key_length);
int32_t data_type;
read_int32(file, &data_type);
RWKV_ASSERT_NULL(data_type >= 0 && data_type < FORMAT_TYPE_COUNT, "Unsupported parameter data type %d", data_type);
ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type]; ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type];
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, ggml_data_type != GGML_TYPE_UNKNOWN, "Unsupported parameter data type %d", data_type);
RWKV_ASSERT_NULL(ggml_data_type != GGML_TYPE_UNKNOWN, "Unsupported parameter data type %d", data_type);
struct ggml_tensor * tensor; struct ggml_tensor * tensor;
int32_t x = -1;
int32_t y = -1;
if (dim_count == 1) { if (dim_count == 1) {
read_int32(file, &x); int32_t x;
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, read_int32(file, &x, "x"), "Failed to read parameter length");
tensor = ggml_new_tensor_1d(ctx, ggml_data_type, x); tensor = ggml_new_tensor_1d(ctx, ggml_data_type, x);
} else if (dim_count == 2) {
read_int32(file, &x);
read_int32(file, &y);
tensor = ggml_new_tensor_2d(ctx, ggml_data_type, x, y);
} else { } else {
abort(); int32_t x, y;
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, read_int32(file, &x, "x"), "Failed to read parameter width");
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, read_int32(file, &y, "y"), "Failed to read parameter height");
tensor = ggml_new_tensor_2d(ctx, ggml_data_type, x, y);
} }
std::string key(key_length, 0); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor");
RWKV_ASSERT_NULL(fread(&key[0], 1, key_length, file) == uint32_t(key_length), "Failed to read parameter key");
RWKV_ASSERT_NULL(fread(tensor->data, 1, ggml_nbytes(tensor), file) == ggml_nbytes(tensor), "Failed to read parameter data"); std::string key(key_length, 0);
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_KEY, fread(&key[0], key_length, 1, file) == 1, "Failed to read parameter key");
size_t nbytes = ggml_nbytes(tensor);
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(tensor->data, nbytes, 1, file) == 1, "Failed to read parameter data");
parameters[key] = tensor; parameters[key] = tensor;
} }
fclose(file); file_guard = { NULL }; // close file
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "emb.weight", &model->emb));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "blocks.0.ln0.weight", &model->ln0_weight));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "blocks.0.ln0.bias", &model->ln0_bias));
model->layers.resize(model->n_layer); model->layers.resize(model->n_layer);
set_parameter(&parameters, "emb.weight", &(model->emb));
set_parameter(&parameters, "blocks.0.ln0.weight", &(model->ln0_weight));
set_parameter(&parameters, "blocks.0.ln0.bias", &(model->ln0_bias));
for (uint32_t i = 0; i < model->n_layer; i++) { for (uint32_t i = 0; i < model->n_layer; i++) {
rwkv_layer layer = model->layers[i]; rwkv_layer * layer = &model->layers[i];
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ln1.weight", &layer->ln1_weight));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ln1.bias", &layer->ln1_bias));
set_block_parameter(&parameters, i, "ln1.weight", &(layer.ln1_weight)); RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.time_mix_k", &layer->att_time_mix_k));
set_block_parameter(&parameters, i, "ln1.bias", &(layer.ln1_bias)); RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.time_mix_v", &layer->att_time_mix_v));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.time_mix_r", &layer->att_time_mix_r));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.time_first", &layer->att_time_first));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.time_decay", &layer->att_time_decay));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.key.weight", &layer->att_key));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.value.weight", &layer->att_value));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.receptance.weight", &layer->att_receptance));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "att.output.weight", &layer->att_output));
set_block_parameter(&parameters, i, "att.time_mix_k", &(layer.att_time_mix_k)); RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ln2.weight", &layer->ln2_weight));
set_block_parameter(&parameters, i, "att.time_mix_v", &(layer.att_time_mix_v)); RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ln2.bias", &layer->ln2_bias));
set_block_parameter(&parameters, i, "att.time_mix_r", &(layer.att_time_mix_r));
set_block_parameter(&parameters, i, "att.time_first", &(layer.att_time_first));
set_block_parameter(&parameters, i, "att.time_decay", &(layer.att_time_decay));
set_block_parameter(&parameters, i, "att.key.weight", &(layer.att_key));
set_block_parameter(&parameters, i, "att.value.weight", &(layer.att_value));
set_block_parameter(&parameters, i, "att.receptance.weight", &(layer.att_receptance));
set_block_parameter(&parameters, i, "att.output.weight", &(layer.att_output));
set_block_parameter(&parameters, i, "ln2.weight", &(layer.ln2_weight)); RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.time_mix_k", &layer->ffn_time_mix_k));
set_block_parameter(&parameters, i, "ln2.bias", &(layer.ln2_bias)); RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.time_mix_r", &layer->ffn_time_mix_r));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.key.weight", &layer->ffn_key));
set_block_parameter(&parameters, i, "ffn.time_mix_k", &(layer.ffn_time_mix_k)); RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.value.weight", &layer->ffn_value));
set_block_parameter(&parameters, i, "ffn.time_mix_r", &(layer.ffn_time_mix_r)); RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.receptance.weight", &layer->ffn_receptance));
set_block_parameter(&parameters, i, "ffn.key.weight", &(layer.ffn_key));
set_block_parameter(&parameters, i, "ffn.value.weight", &(layer.ffn_value));
set_block_parameter(&parameters, i, "ffn.receptance.weight", &(layer.ffn_receptance));
model->layers[i] = layer;
} }
set_parameter(&parameters, "ln_out.weight", &(model->ln_out_weight)); RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "ln_out.weight", &model->ln_out_weight));
set_parameter(&parameters, "ln_out.bias", &(model->ln_out_bias)); RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "ln_out.bias", &model->ln_out_bias));
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "head.weight", &model->head));
set_parameter(&parameters, "head.weight", &(model->head));
// Verify order of dimensions // Verify order of dimensions
struct ggml_tensor * emb = model->emb; struct ggml_tensor * emb = model->emb;
RWKV_ASSERT_NULL(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims);
RWKV_ASSERT_NULL(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %lld", emb->ne[0]); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]);
RWKV_ASSERT_NULL(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %lld", emb->ne[1]); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]);
uint32_t n_embed = model->n_embed; uint32_t n_embed = model->n_embed;
uint32_t n_layer = model->n_layer; uint32_t n_layer = model->n_layer;
@ -385,17 +441,17 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
// xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
// xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
// xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
struct ggml_tensor * xk = ggml_add( struct ggml_tensor * xk = ggml_add_inplace(
ctx, ctx,
ggml_mul(ctx, x0, layer.att_time_mix_k), ggml_mul(ctx, x0, layer.att_time_mix_k),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k))
); );
struct ggml_tensor * xv = ggml_add( struct ggml_tensor * xv = ggml_add_inplace(
ctx, ctx,
ggml_mul(ctx, x0, layer.att_time_mix_v), ggml_mul(ctx, x0, layer.att_time_mix_v),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v))
); );
struct ggml_tensor * xr = ggml_add( struct ggml_tensor * xr = ggml_add_inplace(
ctx, ctx,
ggml_mul(ctx, x0, layer.att_time_mix_r), ggml_mul(ctx, x0, layer.att_time_mix_r),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r))
@ -429,13 +485,13 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
// e2 = torch.exp(ww - qq) // e2 = torch.exp(ww - qq)
struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq));
// a = e1 * aa + e2 * v // a = e1 * aa + e2 * v
struct ggml_tensor * a = ggml_add( struct ggml_tensor * a = ggml_add_inplace(
ctx, ctx,
ggml_mul(ctx, e1, aa), ggml_mul(ctx, e1, aa),
ggml_mul(ctx, e2, v) ggml_mul(ctx, e2, v)
); );
// b = e1 * bb + e2 // b = e1 * bb + e2
struct ggml_tensor * b = ggml_add( struct ggml_tensor * b = ggml_add_inplace(
ctx, ctx,
ggml_mul(ctx, e1, bb), ggml_mul(ctx, e1, bb),
e2 e2
@ -451,13 +507,13 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
// e2 = torch.exp(k - qq) // e2 = torch.exp(k - qq)
e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq)); e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq));
// state[5 * i + 2] = e1 * aa + e2 * v // state[5 * i + 2] = e1 * aa + e2 * v
state_parts[5 * i + 2] = ggml_add( state_parts[5 * i + 2] = ggml_add_inplace(
ctx, ctx,
ggml_mul(ctx, e1, aa), ggml_mul(ctx, e1, aa),
ggml_mul(ctx, e2, v) ggml_mul(ctx, e2, v)
); );
// state[5 * i + 3] = e1 * bb + e2 // state[5 * i + 3] = e1 * bb + e2
state_parts[5 * i + 3] = ggml_add( state_parts[5 * i + 3] = ggml_add_inplace(
ctx, ctx,
ggml_mul(ctx, e1, bb), ggml_mul(ctx, e1, bb),
e2 e2
@ -465,7 +521,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
// state[5 * i + 4] = qq // state[5 * i + 4] = qq
state_parts[5 * i + 4] = qq; state_parts[5 * i + 4] = qq;
// ow @ (r * wkv) // ow @ (r * wkv)
x = ggml_add( x = ggml_add_inplace(
ctx, ctx,
x, x,
ggml_mul_mat( ggml_mul_mat(
@ -484,12 +540,12 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * sizeof(float)); struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * sizeof(float));
// xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
// xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
struct ggml_tensor * xk = ggml_add( struct ggml_tensor * xk = ggml_add_inplace(
ctx, ctx,
ggml_mul(ctx, x0, layer.ffn_time_mix_k), ggml_mul(ctx, x0, layer.ffn_time_mix_k),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k))
); );
struct ggml_tensor * xr = ggml_add( struct ggml_tensor * xr = ggml_add_inplace(
ctx, ctx,
ggml_mul(ctx, x0, layer.ffn_time_mix_r), ggml_mul(ctx, x0, layer.ffn_time_mix_r),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r))
@ -508,7 +564,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
ggml_mul_mat(ctx, layer.ffn_key, xk) ggml_mul_mat(ctx, layer.ffn_key, xk)
)); ));
// r * (vw @ k) // r * (vw @ k)
x = ggml_add( x = ggml_add_inplace(
ctx, ctx,
x, x,
ggml_mul( ggml_mul(
@ -526,25 +582,29 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
// x = (self.w.head.weight @ x).float() // x = (self.w.head.weight @ x).float()
struct ggml_tensor * logits = ggml_mul_mat(ctx, model->head, x); struct ggml_tensor * logits = ggml_mul_mat(ctx, model->head, x);
struct ggml_cgraph * graph = (struct ggml_cgraph *) calloc(1, sizeof(struct ggml_cgraph)); std::unique_ptr<struct ggml_cgraph> graph(new(std::nothrow) struct ggml_cgraph());
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_GRAPH | RWKV_ERROR_ALLOC, graph.get(), "Failed to allocate graph");
*graph = ggml_build_forward(logits); ggml_build_forward_expand(graph.get(), logits);
for (uint32_t i = 0; i < n_layer * 5; i++) { for (uint32_t i = 0; i < n_layer * 5; i++)
ggml_build_forward_expand(graph, state_parts[i]); ggml_build_forward_expand(graph.get(), state_parts[i]);
}
graph->n_threads = n_threads; graph->n_threads = n_threads;
struct rwkv_context * rwkv_ctx = (struct rwkv_context *) calloc(1, sizeof(struct rwkv_context)); std::unique_ptr<struct rwkv_context> rwkv_ctx(new(std::nothrow) struct rwkv_context());
rwkv_ctx->model = model; RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx.get(), "Failed to allocate context");
rwkv_ctx->model = std::move(model);
rwkv_ctx->token_index = token_index; rwkv_ctx->token_index = token_index;
rwkv_ctx->state = state; rwkv_ctx->state = state;
rwkv_ctx->state_parts = state_parts; rwkv_ctx->state_parts = state_parts;
rwkv_ctx->logits = logits; rwkv_ctx->logits = logits;
rwkv_ctx->ctx = ctx; rwkv_ctx->ctx = ctx;
rwkv_ctx->graph = graph; rwkv_ctx->graph = std::move(graph);
return rwkv_ctx; rwkv_ctx->last_error = RWKV_ERROR_NONE;
rwkv_ctx->print_errors = global_print_errors;
ggml_guard.ctx = NULL; // don't free ggml context
return rwkv_ctx.release();
} }
uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) { uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) {
@ -556,21 +616,21 @@ uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) {
} }
bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) { bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) {
RWKV_ASSERT_FALSE(state_out != NULL, "state_out is NULL"); ((struct rwkv_context *) ctx)->last_error = RWKV_ERROR_NONE;
RWKV_ASSERT_FALSE(logits_out != NULL, "logits_out is NULL");
RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, state_out != NULL, "state_out is NULL");
RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, logits_out != NULL, "logits_out is NULL");
RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < ctx->model->n_vocab, "Token is out of range 0..%d", ctx->model->n_vocab - 1);
uint32_t n_layer = ctx->model->n_layer; uint32_t n_layer = ctx->model->n_layer;
uint32_t n_embed = ctx->model->n_embed; uint32_t n_embed = ctx->model->n_embed;
uint32_t n_vocab = ctx->model->n_vocab;
RWKV_ASSERT_FALSE(token < (uint32_t) n_vocab, "Token is out of range 0..%d", n_vocab - 1);
ggml_set_i32_1d(ctx->token_index, 0, token); ggml_set_i32_1d(ctx->token_index, 0, token);
if (state_in == NULL) { if (state_in == NULL) {
ggml_set_f32(ctx->state, 0.0F); ggml_set_f32(ctx->state, 0.0F);
for (uint32_t i = 0; i < n_layer; i++) { for (uint64_t i = 0; i < n_layer; i++) {
// state[5 * i + 4] = -1e30 // state[5 * i + 4] = -1e30
ggml_set_f32( ggml_set_f32(
ggml_view_1d(ctx->ctx, ctx->state, n_embed, (5 * i + 4) * n_embed * sizeof(float)), ggml_view_1d(ctx->ctx, ctx->state, n_embed, (5 * i + 4) * n_embed * sizeof(float)),
@ -581,11 +641,10 @@ bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const floa
memcpy(ctx->state->data, state_in, ctx->state->ne[0] * sizeof(float)); memcpy(ctx->state->data, state_in, ctx->state->ne[0] * sizeof(float));
} }
ggml_graph_compute(ctx->ctx, ctx->graph); ggml_graph_compute(ctx->ctx, ctx->graph.get());
for (uint32_t i = 0; i < n_layer * 5; i++) { for (uint32_t i = 0; i < n_layer * 5; i++) {
struct ggml_tensor * part = ctx->state_parts[i]; struct ggml_tensor * part = ctx->state_parts[i];
memcpy(state_out + i * n_embed, part->data, part->ne[0] * sizeof(float)); memcpy(state_out + i * n_embed, part->data, part->ne[0] * sizeof(float));
} }
@ -595,72 +654,57 @@ bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const floa
} }
void rwkv_free(struct rwkv_context * ctx) { void rwkv_free(struct rwkv_context * ctx) {
ctx->model->layers.~vector(); std::unique_ptr<struct rwkv_context> rwkv_ctx(ctx);
free(ctx->model);
delete[] ctx->state_parts; delete[] ctx->state_parts;
ggml_free(ctx->ctx); ggml_free(ctx->ctx);
free(ctx->graph);
free(ctx);
} }
bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name) { bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name) {
int32_t format_type = format_name_to_format_type(format_name); global_last_error = RWKV_ERROR_NONE;
RWKV_ASSERT_FALSE(format_type != -1, "Unsupported format \"%s\"", format_name); int32_t format_data_type = format_name_to_format_type(format_name);
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, format_data_type != -1, "Unsupported format \"%s\"", format_name);
ggml_type type = FORMAT_TYPE_TO_GGML_TYPE[format_type]; ggml_type format_ggml_type = FORMAT_TYPE_TO_GGML_TYPE[format_data_type];
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, format_ggml_type != GGML_TYPE_UNKNOWN, "Unsupported format \"%s\"", format_name);
RWKV_ASSERT_FALSE(type != GGML_TYPE_UNKNOWN, "Unsupported format \"%s\"", format_name);
// Needed to initialize FP16 lookup table // Needed to initialize FP16 lookup table
{ ggml_free(ggml_init({ 0, NULL, false }));
struct ggml_init_params params = { 0, NULL, false };
struct ggml_context * ctx = ggml_init(params);
ggml_free(ctx);
}
printf("Loading model from '%s'\n", model_file_path_in); printf("Loading model from '%s'\n", model_file_path_in);
auto finp = std::ifstream(model_file_path_in, std::ios::binary); FILE * file_in = fopen(model_file_path_in, "rb");
RWKV_ASSERT_FALSE(finp, "Failed to open %s for reading", model_file_path_in); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file_in, "Failed to open %s for reading", model_file_path_in);
FILE * file_out = fopen(model_file_path_out, "wb");
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file_out, "Failed to open %s for writing", model_file_path_out);
auto fout = std::ofstream(model_file_path_out, std::ios::binary); rwkv_file_guard file_in_guard { file_in };
RWKV_ASSERT_FALSE(fout, "Failed to open %s for writing", model_file_path_out); rwkv_file_guard file_out_guard { file_out };
// Process header // Process header
{ {
uint32_t magic; uint32_t magic, version;
finp.read((char *) &magic, sizeof(magic)); int32_t n_vocab, n_embed, n_layer, data_type;
RWKV_ASSERT_FALSE(magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic);
fout.write((char *) &magic, sizeof(magic));
uint32_t format_version; RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_uint32(file_in, &magic, "magic"));
finp.read((char *) &format_version, sizeof(format_version)); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_MAGIC, magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic);
RWKV_ASSERT_FALSE(format_version == RWKV_FILE_VERSION, "Unsupported file version %d", format_version); RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_uint32(file_in, &version, "version"));
fout.write((char *) &format_version, sizeof(format_version)); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION, version == RWKV_FILE_VERSION, "Unsupported file version %d", version);
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_vocab, "n_vocab"));
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_embed, "n_embed"));
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_layer, "n_layer"));
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &data_type, "data_type"));
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_DATA_TYPE, data_type == 0 || data_type == 1, "Unsupported data type %d, only FP32 and FP16 can be quantized", data_type);
int32_t n_vocab; RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_uint32(file_out, magic, "magic"));
int32_t n_embed; RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_uint32(file_out, version, "version"));
int32_t n_layer; RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_vocab, "n_vocab"));
int32_t data_type; RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_embed, "n_embed"));
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_layer, "n_layer"));
finp.read((char *) &n_vocab, sizeof(n_vocab)); RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, format_data_type, "data_type"));
finp.read((char *) &n_embed, sizeof(n_embed));
finp.read((char *) &n_layer, sizeof(n_layer));
finp.read((char *) &data_type, sizeof(data_type));
RWKV_ASSERT_FALSE(data_type == 0 || data_type == 1, "Unsupported data type %d, only FP32 and FP16 can be quantized", data_type);
data_type = format_type;
fout.write((char *) &n_vocab, sizeof(n_vocab));
fout.write((char *) &n_embed, sizeof(n_embed));
fout.write((char *) &n_layer, sizeof(n_layer));
fout.write((char *) &data_type, sizeof(data_type));
} }
// Process parameters // Process parameters
{
size_t total_size_orig = 0; size_t total_size_orig = 0;
size_t total_size_new = 0; size_t total_size_new = 0;
@ -673,119 +717,103 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
std::vector<int64_t> hist_all(1 << 4, 0); std::vector<int64_t> hist_all(1 << 4, 0);
while (true) { while (true) {
int32_t n_dims; int32_t n_dims, key_length, parameter_data_type;
int32_t key_length; RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ, fread(&n_dims, sizeof(int32_t), 1, file_in) == 1 || feof(file_in), "Failed to read an int32 value from a file (n_dims)");
int32_t parameter_data_type; if (feof(file_in)) break;
RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &key_length, "key_length"));
RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &parameter_data_type, "parameter_data_type"));
finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims)); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, n_dims == 1 || n_dims == 2, "Unsupported dimension count %d", n_dims);
finp.read(reinterpret_cast<char *>(&key_length), sizeof(key_length)); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, parameter_data_type >= 0 && parameter_data_type < FORMAT_TYPE_COUNT, "Unsupported parameter data type %d", parameter_data_type);
finp.read(reinterpret_cast<char *>(&parameter_data_type), sizeof(parameter_data_type));
if (finp.eof()) {
break;
}
RWKV_ASSERT_FALSE(parameter_data_type >= 0 && parameter_data_type < FORMAT_TYPE_COUNT, "Invalid parameter data type %d", parameter_data_type);
ggml_type parameter_ggml_type = FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type]; ggml_type parameter_ggml_type = FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type];
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, parameter_ggml_type != GGML_TYPE_UNKNOWN, "Unsupported parameter data type %d", parameter_data_type);
RWKV_ASSERT_FALSE(parameter_ggml_type != GGML_TYPE_UNKNOWN, "Invalid parameter data type %d", parameter_data_type); int32_t nelements, x, y;
int32_t nelements = 1; if (n_dims == 1) {
int32_t ne[2] = { 1, 1 }; RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &x, "x"));
for (int i = 0; i < n_dims; ++i) { y = 1;
finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i])); nelements = x;
nelements *= ne[i]; } else {
RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &x, "x"));
RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &y, "y"));
nelements = x * y;
} }
std::string name(key_length, 0); std::string name(key_length, 0);
finp.read(&name[0], key_length); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_KEY, fread(&name[0], key_length, 1, file_in) == 1, "Failed to read parameter key");
{
printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ggml_type_name(parameter_ggml_type));
printf("%48s - [%5d, %5d], type = %6s ", name.data(), x, y, ggml_type_name(parameter_ggml_type));
total_size_orig += (size_t) (nelements * ggml_type_sizef(parameter_ggml_type)); total_size_orig += (size_t) (nelements * ggml_type_sizef(parameter_ggml_type));
}
// Quantize only 2D tensors, except embedding and head matrices. // Quantize only 2D tensors, except embedding and head matrices.
// Embedding and head take not too much space, especially in bigger models; // Embedding and head take not too much space, especially in bigger models;
// but they significantly increase perplexity when quantized. // but they significantly increase perplexity when quantized.
bool quantize = n_dims == 2 && bool quantize = n_dims == 2 && name != "emb.weight" && name != "head.weight";
name != std::string("emb.weight") &&
name != std::string("head.weight");
if (quantize) { if (quantize) {
RWKV_ASSERT_FALSE( RWKV_ASSERT_FALSE_MSG(
parameter_data_type == 0 || parameter_data_type == 1, RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA_TYPE,
parameter_ggml_type == GGML_TYPE_F32 || parameter_data_type == GGML_TYPE_F16,
"Unsupported parameter data type %d, only FP32 and FP16 can be quantized", "Unsupported parameter data type %d, only FP32 and FP16 can be quantized",
parameter_data_type parameter_ggml_type
); );
if (parameter_data_type == 1) { data_f32.resize(nelements);
if (parameter_data_type == GGML_TYPE_F16) {
data_f16.resize(nelements); data_f16.resize(nelements);
finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t)); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(data_f16.data(), nelements * sizeof(ggml_fp16_t), 1, file_in) == 1, "Failed to read parameter data");
data_f32.resize(nelements);
for (int i = 0; i < nelements; ++i) { for (int i = 0; i < nelements; ++i)
data_f32[i] = ggml_fp16_to_fp32(data_f16[i]); data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
}
} else { } else {
data_f32.resize(nelements); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(data_f32.data(), nelements * sizeof(float), 1, file_in) == 1, "Failed to read parameter data");
finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
} }
parameter_data_type = format_type; parameter_data_type = format_data_type;
parameter_ggml_type = format_ggml_type;
} else { } else {
const int bytes_per_element = (parameter_data_type == 0) ? sizeof(float) : sizeof(uint16_t); const size_t element_size = ggml_type_size(parameter_ggml_type);
data_u8.resize(nelements * bytes_per_element); data_u8.resize(nelements * element_size);
finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bytes_per_element); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(data_u8.data(), nelements * element_size, 1, file_in) == 1, "Failed to read parameter data");
} }
fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims)); RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_dims, "n_dims"));
fout.write(reinterpret_cast<char *>(&key_length), sizeof(key_length)); RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, key_length, "key_length"));
fout.write(reinterpret_cast<char *>(&parameter_data_type), sizeof(parameter_data_type)); RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, parameter_data_type, "parameter_data_type"));
for (int i = 0; i < n_dims; ++i) { RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, x, "x"));
fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
}
fout.write(&name[0], key_length); if (n_dims == 2)
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, y, "y"));
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_WRITE, fwrite(&name[0], key_length, 1, file_out) == 1, "Failed to write parameter key");
if (quantize) { if (quantize) {
printf("quantizing... "); printf("quantizing... ");
work.resize(nelements); // for quantization work.resize(nelements); // for quantization
size_t cur_size = 0;
// This is a histogramm of some values. If it shows single 1.0, then all 0.0, something went very wrong! // This is a histogramm of some values. If it shows single 1.0, then all 0.0, something went very wrong!
std::vector<int64_t> hist_cur(1 << 4, 0); std::vector<int64_t> hist_cur(1 << 4, 0);
switch (type) { size_t (*f)(const float* src, void* dst, int n, int k, int64_t* hist) =
case GGML_TYPE_Q4_0: format_ggml_type == GGML_TYPE_Q4_0 ? ggml_quantize_q4_0 :
cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); format_ggml_type == GGML_TYPE_Q4_1 ? ggml_quantize_q4_1 :
break; format_ggml_type == GGML_TYPE_Q4_2 ? ggml_quantize_q4_2 :
case GGML_TYPE_Q4_1: format_ggml_type == GGML_TYPE_Q5_0 ? ggml_quantize_q5_0 :
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); format_ggml_type == GGML_TYPE_Q5_1 ? ggml_quantize_q5_1 :
break; format_ggml_type == GGML_TYPE_Q8_0 ? ggml_quantize_q8_0 :
case GGML_TYPE_Q4_2: NULL;
cur_size = ggml_quantize_q4_2(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
break;
case GGML_TYPE_Q5_0:
cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
break;
case GGML_TYPE_Q5_1:
cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
break;
case GGML_TYPE_Q8_0:
cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
break;
default: {
fprintf(stderr, "unsupported quantization type %d\n", type);
return false;
}
}
fout.write(reinterpret_cast<char *>(work.data()), cur_size); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_UNSUPPORTED, f, "unsupported quantization type %d\n", format_ggml_type);
size_t cur_size = (*f)(data_f32.data(), work.data(), nelements, x, hist_cur.data());
total_size_new += cur_size; total_size_new += cur_size;
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_WRITE, fwrite(work.data(), cur_size, 1, file_out) == 1, "Failed to write parameter data");
printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float) / 1024.0 / 1024.0, cur_size / 1024.0 / 1024.0); printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float) / 1024.0 / 1024.0, cur_size / 1024.0 / 1024.0);
for (int i = 0; i < (int) hist_cur.size(); ++i) { for (int i = 0; i < (int) hist_cur.size(); ++i) {
@ -799,7 +827,7 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
printf("\n"); printf("\n");
} else { } else {
printf("size = %8.3f MB\n", data_u8.size() / 1024.0 / 1024.0); printf("size = %8.3f MB\n", data_u8.size() / 1024.0 / 1024.0);
fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size()); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_WRITE, fwrite(data_u8.data(), data_u8.size(), 1, file_out) == 1, "Failed to write parameter data");
total_size_new += data_u8.size(); total_size_new += data_u8.size();
} }
} }
@ -808,7 +836,6 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
printf("quantized size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0); printf("quantized size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
printf("compression ratio = %8.2f\n", 1.0 * total_size_orig / total_size_new); printf("compression ratio = %8.2f\n", 1.0 * total_size_orig / total_size_new);
{
int64_t sum_all = 0; int64_t sum_all = 0;
for (int i = 0; i < (int) hist_all.size(); ++i) { for (int i = 0; i < (int) hist_all.size(); ++i) {
@ -822,11 +849,6 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
} }
printf("\n"); printf("\n");
}
}
finp.close();
fout.close();
return true; return true;
} }

44
rwkv.h
View File

@ -27,6 +27,50 @@
extern "C" { extern "C" {
#endif #endif
// Represents an error encountered during a function call.
// These are flags, so an actual value might contain multiple errors.
enum rwkv_error_flags {
RWKV_ERROR_NONE = 0,
RWKV_ERROR_ARGS = 1 << 4,
RWKV_ERROR_FILE = 2 << 4,
RWKV_ERROR_MODEL = 3 << 4,
RWKV_ERROR_MODEL_PARAMS = 4 << 4,
RWKV_ERROR_GRAPH = 5 << 4,
RWKV_ERROR_CTX = 6 << 4,
RWKV_ERROR_ALLOC = 1,
RWKV_ERROR_FILE_OPEN = 2,
RWKV_ERROR_FILE_STAT = 3,
RWKV_ERROR_FILE_READ = 4,
RWKV_ERROR_FILE_WRITE = 5,
RWKV_ERROR_FILE_MAGIC = 6,
RWKV_ERROR_FILE_VERSION = 7,
RWKV_ERROR_DATA_TYPE = 8,
RWKV_ERROR_UNSUPPORTED = 9,
RWKV_ERROR_SHAPE = 10,
RWKV_ERROR_DIMENSION = 11,
RWKV_ERROR_KEY = 12,
RWKV_ERROR_DATA = 13,
RWKV_ERROR_PARAM_MISSING = 14
};
// Sets whether errors are automatically printed to stderr.
// If this is set to false, you are responsible for calling rwkv_last_error manually if an operation fails.
// - ctx: the context to suppress error messages for.
// If NULL, affects model load (rwkv_init_from_file) and quantization (rwkv_quantize_model_file) errors,
// as well as the default for new context.
// - print_errors: whether error messages should be automatically printed.
RWKV_API void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors);
// Gets whether errors are automatically printed to stderr.
// - ctx: the context to retrieve the setting for, or NULL for the global setting.
RWKV_API bool rwkv_get_print_errors(struct rwkv_context * ctx);
// Retrieves and clears the error flags.
// - ctx: the context the retrieve the error for, or NULL for the global error.
RWKV_API enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx);
struct rwkv_context; struct rwkv_context;
// Loads the model from a file and prepares it for inference. // Loads the model from a file and prepares it for inference.

View File

@ -26,6 +26,7 @@ void test_model(const char * model_path, const float * expected_logits, const fl
fprintf(stderr, "Testing %s\n", model_path); fprintf(stderr, "Testing %s\n", model_path);
struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS); struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS);
enum rwkv_error_flags error = rwkv_get_last_error(NULL);
uint32_t n_vocab = rwkv_get_logits_buffer_element_count(model); uint32_t n_vocab = rwkv_get_logits_buffer_element_count(model);