diff --git a/rwkv.cpp b/rwkv.cpp
index a7c2ee4..9119d13 100644
--- a/rwkv.cpp
+++ b/rwkv.cpp
@@ -12,34 +12,92 @@
 #include <cmath>
 #include <fstream>
 #include <unordered_map>
+#include <memory>
+
+#include <sys/stat.h> // fstat
+
+// --- Error handling ---
+
+thread_local enum rwkv_error_flags global_last_error = RWKV_ERROR_NONE;
+thread_local bool global_print_errors = true;
+
+inline enum rwkv_error_flags operator|(enum rwkv_error_flags a, enum rwkv_error_flags b) {
+    return static_cast<enum rwkv_error_flags>(static_cast<int>(a) | static_cast<int>(b));
+}
+
+inline enum rwkv_error_flags operator|=(enum rwkv_error_flags & a, enum rwkv_error_flags b) {
+    return a = a | b;
+}
+
+// If the condition x is false, adds ERR_VAL to the last error, prints a message to stderr, and returns RET_VAL.
+#define RWKV_ASSERT_MSG(ERR_VAL, RET_VAL, x, ...) \
+    if (!(x)) { \
+        global_last_error |= ERR_VAL; \
+        if (global_print_errors) { \
+            fprintf(stderr, __VA_ARGS__); \
+            fprintf(stderr, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
+        } \
+        return RET_VAL; \
+    }
+
+// If the condition x is false, adds ERR_VAL to the last error, and returns RET_VAL.
+#define RWKV_ASSERT(ERR_VAL, RET_VAL, x) \
+    if (!(x)) { \
+        global_last_error |= ERR_VAL; \
+        return RET_VAL; \
+    }
+
+// If the condition x is false, adds ERR_VAL to the ctx's last error, prints a message to stderr, and returns RET_VAL.
+#define RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, RET_VAL, x, ...) \
+    if (!(x)) { \
+        ((struct rwkv_context *) ctx)->last_error |= ERR_VAL; \
+        if (ctx->print_errors) { \
+            fprintf(stderr, __VA_ARGS__); \
+            fprintf(stderr, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
+        } \
+        return RET_VAL; \
+    }
+
+// If the condition x is false, adds ERR_VAL to the ctx's last error, and returns RET_VAL.
+#define RWKV_CTX_ASSERT(ctx, ERR_VAL, RET_VAL, x) \
+    if (!(x)) { \
+        ((struct rwkv_context *) ctx)->last_error |= ERR_VAL; \
+        return RET_VAL; \
+    }
+
+#define RWKV_ASSERT_FALSE_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, false, x, __VA_ARGS__)
+#define RWKV_ASSERT_NULL_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, NULL, x, __VA_ARGS__)
+#define RWKV_CTX_ASSERT_FALSE_MSG(ctx, ERR_VAL, x, ...) RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, false, x, __VA_ARGS__)
+#define RWKV_CTX_ASSERT_NULL_MSG(ctx, ERR_VAL, x, ...) RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, NULL, x, __VA_ARGS__)
+
+#define RWKV_ASSERT_FALSE(ERR_VAL, x) RWKV_ASSERT(ERR_VAL, false, x)
+#define RWKV_ASSERT_NULL(ERR_VAL, x) RWKV_ASSERT(ERR_VAL, NULL, x)
+#define RWKV_CTX_ASSERT_FALSE(ctx, ERR_VAL, x) RWKV_CTX_ASSERT(ctx, ERR_VAL, false, x)
+#define RWKV_CTX_ASSERT_NULL(ctx, ERR_VAL, x) RWKV_CTX_ASSERT(ctx, ERR_VAL, NULL, x)
 
 // --- Utilities ---
 
-// Checks that x is not false. If x is false, prints fancy message to stderr and returns RET_VAL.
-#define RWKV_ASSERT(RET_VAL, x, ...) \
-    { \
-        if (!(x)) { \
-            fprintf(stderr, __VA_ARGS__); \
-            fprintf(stderr, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
-            return RET_VAL; \
-        } \
-    }
-
-// Checks that x is not false. If x is false, prints fancy message to stderr and returns false.
-#define RWKV_ASSERT_FALSE(x, ...) RWKV_ASSERT(false, x, __VA_ARGS__)
-
-// Checks that x is not false. If x is false, prints fancy message to stderr and returns NULL.
-#define RWKV_ASSERT_NULL(x, ...) RWKV_ASSERT(NULL, x, __VA_ARGS__)
-
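For illustration (not part of the patch): the `_FALSE`/`_NULL` suffix picks the early-return value, so a `bool`-returning helper pairs with `RWKV_ASSERT_FALSE_MSG` while a pointer-returning function pairs with `RWKV_ASSERT_NULL_MSG`. A hedged sketch of intended usage; the helper name `read_magic` is made up for the example:

    // Hypothetical helper, for illustration only. On a short read this adds
    // RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ to global_last_error, optionally
    // prints the message plus "file:line: condition" to stderr, and returns false.
    bool read_magic(FILE * file, int32_t * magic) {
        RWKV_ASSERT_FALSE_MSG(
            RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ,
            fread(magic, sizeof(int32_t), 1, file) == 1,
            "Failed to read magic value"
        );
        return true;
    }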
 // Reads single int32 value from a file.
-bool read_int32(FILE * file, int32_t * dest) {
-    RWKV_ASSERT_FALSE(fread(dest, sizeof(int32_t), 1, file) == 1, "Failed to read an int32 value from a file");
+bool read_int32(FILE * file, int32_t * dest, const char * name) {
+    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, fread(dest, sizeof(int32_t), 1, file) == 1, "Failed to read an int32 value from a file (%s)", name);
     return true;
 }
 
 // Reads single uint32 value from a file.
-bool read_uint32(FILE * file, uint32_t * dest) {
-    RWKV_ASSERT_FALSE(fread(dest, sizeof(uint32_t), 1, file) == 1, "Failed to read a uint32 value from a file");
+bool read_uint32(FILE * file, uint32_t * dest, const char * name) {
+    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, fread(dest, sizeof(uint32_t), 1, file) == 1, "Failed to read a uint32 value from a file (%s)", name);
+    return true;
+}
+
+// Writes single int32 value to a file.
+bool write_int32(FILE * file, int32_t value, const char * name) {
+    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_WRITE, fwrite((void *) &value, sizeof(int32_t), 1, file), "Failed to write an int32 value to a file (%s)", name);
+    return true;
+}
+
+// Writes single uint32 value to a file.
+bool write_uint32(FILE * file, uint32_t value, const char * name) {
+    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_WRITE, fwrite((void *) &value, sizeof(uint32_t), 1, file), "Failed to write a uint32 value to a file (%s)", name);
     return true;
 }
 
@@ -123,7 +181,7 @@ struct rwkv_model {
 // If the parameter was not found, returns false.
 bool set_parameter(std::unordered_map<std::string, struct ggml_tensor *> * parameters, std::string key, struct ggml_tensor ** dest) {
     struct ggml_tensor * parameter = (*parameters)[key];
-    RWKV_ASSERT_FALSE(parameter != NULL, "Parameter %s not found in model file", key.c_str());
+    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_PARAM_MISSING, parameter != NULL, "Parameter %s not found in model file", key.c_str());
     *dest = parameter;
     return true;
 }
@@ -183,66 +241,82 @@ struct ggml_tensor * rwkv_layer_norm(ggml_context * ctx, struct ggml_tensor * x,
     // Looks like ggml_norm does the first part, we only need to apply weight & bias.
     x = ggml_norm(ctx, x);
     x = ggml_mul(ctx, x, weight);
-    x = ggml_add(ctx, x, bias);
+    x = ggml_add_inplace(ctx, x, bias);
     return x;
 }
 
 // --- Implementation ---
 
 struct rwkv_context {
-    struct rwkv_model * model;
+    std::unique_ptr<struct rwkv_model> model;
     struct ggml_tensor * token_index;
     struct ggml_tensor * state;
     struct ggml_tensor ** state_parts;
     struct ggml_tensor * logits;
     struct ggml_context * ctx;
-    struct ggml_cgraph * graph;
-    bool freed;
+    std::unique_ptr<struct ggml_cgraph> graph;
+    enum rwkv_error_flags last_error;
+    bool print_errors;
 };
 
+void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) {
+    bool * ptr = ctx ? &ctx->print_errors : &global_print_errors;
+    *ptr = print_errors;
+}
+
+bool rwkv_get_print_errors(struct rwkv_context * ctx) {
+    return ctx ? ctx->print_errors : global_print_errors;
+}
+
+enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) {
+    enum rwkv_error_flags * ptr = ctx ? &ctx->last_error : &global_last_error;
+    enum rwkv_error_flags value = *ptr;
+    *ptr = RWKV_ERROR_NONE;
+    return value;
+}
+
+struct rwkv_file_guard {
+    FILE * file;
+    ~rwkv_file_guard() { if (file) fclose(file); }
+};
+
+struct rwkv_ggml_guard {
+    struct ggml_context * ctx;
+    ~rwkv_ggml_guard() { if (ctx) ggml_free(ctx); }
+};
+
 struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
-    FILE * file = fopen(file_path, "rb");
-    RWKV_ASSERT_NULL(file != NULL, "Failed to open file %s", file_path);
+    global_last_error = RWKV_ERROR_NONE;
+
+    FILE* file = fopen(file_path, "rb");
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file, "Failed to open file %s", file_path);
+    rwkv_file_guard file_guard { file };
+
+    struct stat file_stat;
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(file), &file_stat) == 0, "Failed to stat file %s", file_path);
 
     int32_t magic;
-    read_int32(file, &magic);
-    RWKV_ASSERT_NULL(magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic);
+    RWKV_ASSERT_NULL(RWKV_ERROR_FILE, read_int32(file, &magic, "magic"));
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_MAGIC, magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic);
 
     int32_t version;
-    read_int32(file, &version);
-    RWKV_ASSERT_NULL(version == RWKV_FILE_VERSION, "Unsupported file version %d", version);
+    RWKV_ASSERT_NULL(RWKV_ERROR_FILE, read_int32(file, &version, "version"));
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION, version == RWKV_FILE_VERSION, "Unsupported file version %d", version);
 
-    struct rwkv_model * model = (struct rwkv_model *) calloc(1, sizeof(struct rwkv_model));
+    std::unique_ptr<struct rwkv_model> model(new(std::nothrow) struct rwkv_model());
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_ALLOC, model.get(), "Failed to allocate model");
+
+    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_vocab, "n_vocab"));
+    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_embed, "n_embed"));
+    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_layer, "n_layer"));
+    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_int32(file, &model->data_type, "data_type"));
 
-    read_uint32(file, &(model->n_vocab));
-    read_uint32(file, &(model->n_embed));
-    read_uint32(file, &(model->n_layer));
+    const char* unsupported_msg = "Models in %s format cannot be loaded anymore because the format was removed. You need to quantize the model into another format";
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_DATA_TYPE, model->data_type >= 0 && model->data_type < FORMAT_TYPE_COUNT, "Unsupported model data type %d", model->data_type);
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 4, unsupported_msg, "Q4_1_O");
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 6, unsupported_msg, "Q4_3");
 
-    read_int32(file, &(model->data_type));
-    RWKV_ASSERT_NULL(model->data_type >= 0 && model->data_type < FORMAT_TYPE_COUNT, "Unsupported model data type %d", model->data_type);
-
-    RWKV_ASSERT_NULL(
-        model->data_type != 4,
-        "Models in Q4_1_O format cannot be loaded anymore because the format was removed. You need to quantize the model into another format"
-    );
-
-    RWKV_ASSERT_NULL(
-        model->data_type != 6,
-        "Models in Q4_3 format cannot be loaded anymore because the format was removed. You need to quantize the model into another format"
-    );
-
     // Parameter tensors would take at least this amount in memory.
-    size_t file_size;
-
-    {
-        auto fin = std::ifstream(file_path, std::ios::binary);
-        RWKV_ASSERT_NULL(fin, "Failed to open file %s", file_path);
-        fin.seekg(0, fin.end);
-        file_size = fin.tellg();
-        fin.close();
-    }
-
-    size_t memory_required = file_size +
+    size_t memory_required = file_stat.st_size +
         // Intermediary vectors for calculation; there are around 100 calls to ggml
         size_t(100) * model->n_embed * sizeof(float) +
         // State, in and out
@@ -253,109 +327,91 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
         // TODO This is too much for smaller models; need a more proper and robust way of measuring required memory
         size_t(256) * 1024 * 1024;
 
-    // Initialize ggml
-    struct ggml_init_params params;
-    params.mem_size = memory_required;
-    params.mem_buffer = NULL;
-    params.no_alloc = false;
-    struct ggml_context * ctx = ggml_init(params);
+    struct ggml_context * ctx = ggml_init({ memory_required, NULL, false });
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_ALLOC, ctx, "Failed to allocate GGML context");
+    rwkv_ggml_guard ggml_guard { ctx };
 
     std::unordered_map<std::string, struct ggml_tensor *> parameters;
 
     while (true) {
-        int32_t dim_count;
-        size_t elements_read = fread(&dim_count, 4, 1, file);
+        int32_t dim_count, key_length, data_type;
+        RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ, fread(&dim_count, sizeof(int32_t), 1, file) == 1 || feof(file), "Failed to read an int32 value from a file (dim_count)");
+        if (feof(file)) break;
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, read_int32(file, &key_length, "key_length"));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, read_int32(file, &data_type, "data_type"));
 
-        if (feof(file)) {
-            break;
-        }
-
-        RWKV_ASSERT_NULL(elements_read == 1, "Failed to read dimension count");
-        RWKV_ASSERT_NULL(dim_count == 1 || dim_count == 2, "Unsupported dimension count %d", dim_count);
-
-        int32_t key_length;
-        read_int32(file, &key_length);
-        RWKV_ASSERT_NULL(key_length > 0, "Non-positive key length %d", key_length);
-
-        int32_t data_type;
-        read_int32(file, &data_type);
-        RWKV_ASSERT_NULL(data_type >= 0 && data_type < FORMAT_TYPE_COUNT, "Unsupported parameter data type %d", data_type);
+        RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, dim_count == 1 || dim_count == 2, "Unsupported dimension count %d", dim_count);
+        RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_KEY, key_length > 0, "Non-positive key length %d", key_length);
+        RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, data_type >= 0 && data_type < FORMAT_TYPE_COUNT, "Unsupported parameter data type %d", data_type);
 
         ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type];
-
-        RWKV_ASSERT_NULL(ggml_data_type != GGML_TYPE_UNKNOWN, "Unsupported parameter data type %d", data_type);
+        RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, ggml_data_type != GGML_TYPE_UNKNOWN, "Unsupported parameter data type %d", data_type);
 
         struct ggml_tensor * tensor;
 
-        int32_t x = -1;
-        int32_t y = -1;
-
         if (dim_count == 1) {
-            read_int32(file, &x);
+            int32_t x;
+            RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, read_int32(file, &x, "x"), "Failed to read parameter length");
            tensor = ggml_new_tensor_1d(ctx, ggml_data_type, x);
-        } else if (dim_count == 2) {
-            read_int32(file, &x);
-            read_int32(file, &y);
-            tensor = ggml_new_tensor_2d(ctx, ggml_data_type, x, y);
         } else {
-            abort();
+            int32_t x, y;
parameter width"); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, read_int32(file, &y, "y"), "Failed to read parameter height"); + tensor = ggml_new_tensor_2d(ctx, ggml_data_type, x, y); } - std::string key(key_length, 0); - RWKV_ASSERT_NULL(fread(&key[0], 1, key_length, file) == uint32_t(key_length), "Failed to read parameter key"); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor"); - RWKV_ASSERT_NULL(fread(tensor->data, 1, ggml_nbytes(tensor), file) == ggml_nbytes(tensor), "Failed to read parameter data"); + std::string key(key_length, 0); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_KEY, fread(&key[0], key_length, 1, file) == 1, "Failed to read parameter key"); + + size_t nbytes = ggml_nbytes(tensor); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(tensor->data, nbytes, 1, file) == 1, "Failed to read parameter data"); parameters[key] = tensor; } - fclose(file); + file_guard = { NULL }; // close file + + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(¶meters, "emb.weight", &model->emb)); + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(¶meters, "blocks.0.ln0.weight", &model->ln0_weight)); + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(¶meters, "blocks.0.ln0.bias", &model->ln0_bias)); model->layers.resize(model->n_layer); - - set_parameter(¶meters, "emb.weight", &(model->emb)); - - set_parameter(¶meters, "blocks.0.ln0.weight", &(model->ln0_weight)); - set_parameter(¶meters, "blocks.0.ln0.bias", &(model->ln0_bias)); - for (uint32_t i = 0; i < model->n_layer; i++) { - rwkv_layer layer = model->layers[i]; + rwkv_layer * layer = &model->layers[i]; + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "ln1.weight", &layer->ln1_weight)); + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "ln1.bias", &layer->ln1_bias)); - set_block_parameter(¶meters, i, "ln1.weight", &(layer.ln1_weight)); - set_block_parameter(¶meters, i, "ln1.bias", &(layer.ln1_bias)); + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.time_mix_k", &layer->att_time_mix_k)); + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.time_mix_v", &layer->att_time_mix_v)); + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.time_mix_r", &layer->att_time_mix_r)); + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.time_first", &layer->att_time_first)); + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.time_decay", &layer->att_time_decay)); + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.key.weight", &layer->att_key)); + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.value.weight", &layer->att_value)); + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.receptance.weight", &layer->att_receptance)); + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.output.weight", &layer->att_output)); - set_block_parameter(¶meters, i, "att.time_mix_k", &(layer.att_time_mix_k)); - set_block_parameter(¶meters, i, "att.time_mix_v", &(layer.att_time_mix_v)); - set_block_parameter(¶meters, i, "att.time_mix_r", &(layer.att_time_mix_r)); - set_block_parameter(¶meters, i, "att.time_first", &(layer.att_time_first)); - set_block_parameter(¶meters, i, "att.time_decay", &(layer.att_time_decay)); - 
-        set_block_parameter(&parameters, i, "att.key.weight", &(layer.att_key));
-        set_block_parameter(&parameters, i, "att.value.weight", &(layer.att_value));
-        set_block_parameter(&parameters, i, "att.receptance.weight", &(layer.att_receptance));
-        set_block_parameter(&parameters, i, "att.output.weight", &(layer.att_output));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ln2.weight", &layer->ln2_weight));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ln2.bias", &layer->ln2_bias));
 
-        set_block_parameter(&parameters, i, "ln2.weight", &(layer.ln2_weight));
-        set_block_parameter(&parameters, i, "ln2.bias", &(layer.ln2_bias));
-
-        set_block_parameter(&parameters, i, "ffn.time_mix_k", &(layer.ffn_time_mix_k));
-        set_block_parameter(&parameters, i, "ffn.time_mix_r", &(layer.ffn_time_mix_r));
-        set_block_parameter(&parameters, i, "ffn.key.weight", &(layer.ffn_key));
-        set_block_parameter(&parameters, i, "ffn.value.weight", &(layer.ffn_value));
-        set_block_parameter(&parameters, i, "ffn.receptance.weight", &(layer.ffn_receptance));
-
-        model->layers[i] = layer;
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.time_mix_k", &layer->ffn_time_mix_k));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.time_mix_r", &layer->ffn_time_mix_r));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.key.weight", &layer->ffn_key));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.value.weight", &layer->ffn_value));
+        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ffn.receptance.weight", &layer->ffn_receptance));
     }
 
-    set_parameter(&parameters, "ln_out.weight", &(model->ln_out_weight));
-    set_parameter(&parameters, "ln_out.bias", &(model->ln_out_bias));
-
-    set_parameter(&parameters, "head.weight", &(model->head));
+    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "ln_out.weight", &model->ln_out_weight));
+    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "ln_out.bias", &model->ln_out_bias));
+    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "head.weight", &model->head));
 
     // Verify order of dimensions
     struct ggml_tensor * emb = model->emb;
-    RWKV_ASSERT_NULL(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims);
-    RWKV_ASSERT_NULL(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %lld", emb->ne[0]);
-    RWKV_ASSERT_NULL(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %lld", emb->ne[1]);
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims);
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]);
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]);
 
     uint32_t n_embed = model->n_embed;
     uint32_t n_layer = model->n_layer;
@@ -385,17 +441,17 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
         // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
         // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
         // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
-        struct ggml_tensor * xk = ggml_add(
+        struct ggml_tensor * xk = ggml_add_inplace(
             ctx,
             ggml_mul(ctx, x0, layer.att_time_mix_k),
             ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k))
         );
-        struct ggml_tensor * xv = ggml_add(
+        struct ggml_tensor * xv = ggml_add_inplace(
             ctx,
             ggml_mul(ctx, x0, layer.att_time_mix_v),
             ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v))
         );
-        struct ggml_tensor * xr = ggml_add(
+        struct ggml_tensor * xr = ggml_add_inplace(
             ctx,
             ggml_mul(ctx, x0, layer.att_time_mix_r),
             ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r))
@@ -429,13 +485,13 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
         // e2 = torch.exp(ww - qq)
         struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq));
         // a = e1 * aa + e2 * v
-        struct ggml_tensor * a = ggml_add(
+        struct ggml_tensor * a = ggml_add_inplace(
             ctx,
             ggml_mul(ctx, e1, aa),
             ggml_mul(ctx, e2, v)
         );
         // b = e1 * bb + e2
-        struct ggml_tensor * b = ggml_add(
+        struct ggml_tensor * b = ggml_add_inplace(
             ctx,
             ggml_mul(ctx, e1, bb),
             e2
@@ -451,13 +507,13 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
         // e2 = torch.exp(k - qq)
         e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq));
         // state[5 * i + 2] = e1 * aa + e2 * v
-        state_parts[5 * i + 2] = ggml_add(
+        state_parts[5 * i + 2] = ggml_add_inplace(
             ctx,
             ggml_mul(ctx, e1, aa),
             ggml_mul(ctx, e2, v)
         );
         // state[5 * i + 3] = e1 * bb + e2
-        state_parts[5 * i + 3] = ggml_add(
+        state_parts[5 * i + 3] = ggml_add_inplace(
             ctx,
             ggml_mul(ctx, e1, bb),
             e2
@@ -465,7 +521,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
         // state[5 * i + 4] = qq
         state_parts[5 * i + 4] = qq;
         // ow @ (r * wkv)
-        x = ggml_add(
+        x = ggml_add_inplace(
             ctx,
             x,
             ggml_mul_mat(
@@ -484,12 +540,12 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
         struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * sizeof(float));
         // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
         // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
-        struct ggml_tensor * xk = ggml_add(
+        struct ggml_tensor * xk = ggml_add_inplace(
             ctx,
             ggml_mul(ctx, x0, layer.ffn_time_mix_k),
             ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k))
         );
-        struct ggml_tensor * xr = ggml_add(
+        struct ggml_tensor * xr = ggml_add_inplace(
             ctx,
             ggml_mul(ctx, x0, layer.ffn_time_mix_r),
             ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r))
@@ -508,7 +564,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
             ggml_mul_mat(ctx, layer.ffn_key, xk)
         ));
         // r * (vw @ k)
-        x = ggml_add(
+        x = ggml_add_inplace(
             ctx,
             x,
             ggml_mul(
@@ -526,25 +582,29 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
     // x = (self.w.head.weight @ x).float()
     struct ggml_tensor * logits = ggml_mul_mat(ctx, model->head, x);
 
-    struct ggml_cgraph * graph = (struct ggml_cgraph *) calloc(1, sizeof(struct ggml_cgraph));
+    std::unique_ptr<struct ggml_cgraph> graph(new(std::nothrow) struct ggml_cgraph());
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_GRAPH | RWKV_ERROR_ALLOC, graph.get(), "Failed to allocate graph");
 
-    *graph = ggml_build_forward(logits);
+    ggml_build_forward_expand(graph.get(), logits);
 
-    for (uint32_t i = 0; i < n_layer * 5; i++) {
-        ggml_build_forward_expand(graph, state_parts[i]);
-    }
+    for (uint32_t i = 0; i < n_layer * 5; i++)
+        ggml_build_forward_expand(graph.get(), state_parts[i]);
 
     graph->n_threads = n_threads;
 
-    struct rwkv_context * rwkv_ctx = (struct rwkv_context *) calloc(1, sizeof(struct rwkv_context));
-    rwkv_ctx->model = model;
+    std::unique_ptr<struct rwkv_context> rwkv_ctx(new(std::nothrow) struct rwkv_context());
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx.get(), "Failed to allocate context");
+    rwkv_ctx->model = std::move(model);
     rwkv_ctx->token_index = token_index;
     rwkv_ctx->state = state;
     rwkv_ctx->state_parts = state_parts;
     rwkv_ctx->logits = logits;
     rwkv_ctx->ctx = ctx;
-    rwkv_ctx->graph = graph;
-
-    return rwkv_ctx;
+    rwkv_ctx->graph = std::move(graph);
+    rwkv_ctx->last_error = RWKV_ERROR_NONE;
+    rwkv_ctx->print_errors = global_print_errors;
+    ggml_guard.ctx = NULL; // don't free the ggml context; ownership passes to rwkv_ctx
+    return rwkv_ctx.release();
 }
 
 uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) {
@@ -556,21 +616,21 @@ uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) {
 }
 
 bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) {
-    RWKV_ASSERT_FALSE(state_out != NULL, "state_out is NULL");
-    RWKV_ASSERT_FALSE(logits_out != NULL, "logits_out is NULL");
+    ((struct rwkv_context *) ctx)->last_error = RWKV_ERROR_NONE;
+
+    RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, state_out != NULL, "state_out is NULL");
+    RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, logits_out != NULL, "logits_out is NULL");
+    RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < ctx->model->n_vocab, "Token is out of range 0..%d", ctx->model->n_vocab - 1);
 
     uint32_t n_layer = ctx->model->n_layer;
     uint32_t n_embed = ctx->model->n_embed;
-    uint32_t n_vocab = ctx->model->n_vocab;
-
-    RWKV_ASSERT_FALSE(token < (uint32_t) n_vocab, "Token is out of range 0..%d", n_vocab - 1);
 
     ggml_set_i32_1d(ctx->token_index, 0, token);
 
     if (state_in == NULL) {
         ggml_set_f32(ctx->state, 0.0F);
 
-        for (uint32_t i = 0; i < n_layer; i++) {
+        for (uint64_t i = 0; i < n_layer; i++) {
             // state[5 * i + 4] = -1e30
             ggml_set_f32(
                 ggml_view_1d(ctx->ctx, ctx->state, n_embed, (5 * i + 4) * n_embed * sizeof(float)),
@@ -581,11 +641,10 @@ bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const floa
         memcpy(ctx->state->data, state_in, ctx->state->ne[0] * sizeof(float));
     }
 
-    ggml_graph_compute(ctx->ctx, ctx->graph);
+    ggml_graph_compute(ctx->ctx, ctx->graph.get());
 
     for (uint32_t i = 0; i < n_layer * 5; i++) {
         struct ggml_tensor * part = ctx->state_parts[i];
-
         memcpy(state_out + i * n_embed, part->data, part->ne[0] * sizeof(float));
     }
 
@@ -595,238 +654,201 @@ bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const floa
 }
 
 void rwkv_free(struct rwkv_context * ctx) {
-    ctx->model->layers.~vector();
-    free(ctx->model);
+    std::unique_ptr<struct rwkv_context> rwkv_ctx(ctx);
     delete[] ctx->state_parts;
     ggml_free(ctx->ctx);
-    free(ctx->graph);
-    free(ctx);
 }
 
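For illustration (not part of the patch): the rewritten `rwkv_free` and loader lean on the same RAII idea as `rwkv_file_guard`. A minimal sketch of the idiom, assuming the guard and read helpers defined earlier in this file; `parse_header` itself is hypothetical:

    // Hypothetical usage of rwkv_file_guard, for illustration only: the guard's
    // destructor runs on every return path, including the early returns hidden
    // inside the RWKV_ASSERT_* macros, so the FILE handle cannot leak.
    bool parse_header(const char * path, int32_t * magic) {
        FILE * file = fopen(path, "rb");
        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file, "Failed to open file %s", path);
        rwkv_file_guard file_guard { file };

        // A failed read returns false here; the guard still closes the file.
        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file, magic, "magic"));

        return true; // the guard closes the file on this path too
    }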
 bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name) {
-    int32_t format_type = format_name_to_format_type(format_name);
+    global_last_error = RWKV_ERROR_NONE;
 
-    RWKV_ASSERT_FALSE(format_type != -1, "Unsupported format \"%s\"", format_name);
+    int32_t format_data_type = format_name_to_format_type(format_name);
+    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, format_data_type != -1, "Unsupported format \"%s\"", format_name);
 
-    ggml_type type = FORMAT_TYPE_TO_GGML_TYPE[format_type];
-
-    RWKV_ASSERT_FALSE(type != GGML_TYPE_UNKNOWN, "Unsupported format \"%s\"", format_name);
+    ggml_type format_ggml_type = FORMAT_TYPE_TO_GGML_TYPE[format_data_type];
+    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, format_ggml_type != GGML_TYPE_UNKNOWN, "Unsupported format \"%s\"", format_name);
 
     // Needed to initialize FP16 lookup table
-    {
-        struct ggml_init_params params = { 0, NULL, false };
-        struct ggml_context * ctx = ggml_init(params);
-        ggml_free(ctx);
-    }
+    ggml_free(ggml_init({ 0, NULL, false }));
 
     printf("Loading model from '%s'\n", model_file_path_in);
 
-    auto finp = std::ifstream(model_file_path_in, std::ios::binary);
-    RWKV_ASSERT_FALSE(finp, "Failed to open %s for reading", model_file_path_in);
+    FILE * file_in = fopen(model_file_path_in, "rb");
+    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file_in, "Failed to open %s for reading", model_file_path_in);
+    FILE * file_out = fopen(model_file_path_out, "wb");
+    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file_out, "Failed to open %s for writing", model_file_path_out);
 
-    auto fout = std::ofstream(model_file_path_out, std::ios::binary);
-    RWKV_ASSERT_FALSE(fout, "Failed to open %s for writing", model_file_path_out);
+    rwkv_file_guard file_in_guard { file_in };
+    rwkv_file_guard file_out_guard { file_out };
 
     // Process header
     {
-        uint32_t magic;
-        finp.read((char *) &magic, sizeof(magic));
-        RWKV_ASSERT_FALSE(magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic);
-        fout.write((char *) &magic, sizeof(magic));
+        uint32_t magic, version;
+        int32_t n_vocab, n_embed, n_layer, data_type;
 
-        uint32_t format_version;
-        finp.read((char *) &format_version, sizeof(format_version));
-        RWKV_ASSERT_FALSE(format_version == RWKV_FILE_VERSION, "Unsupported file version %d", format_version);
-        fout.write((char *) &format_version, sizeof(format_version));
+        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_uint32(file_in, &magic, "magic"));
+        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_MAGIC, magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic);
+        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_uint32(file_in, &version, "version"));
+        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION, version == RWKV_FILE_VERSION, "Unsupported file version %d", version);
+        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_vocab, "n_vocab"));
+        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_embed, "n_embed"));
+        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_layer, "n_layer"));
+        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &data_type, "data_type"));
+        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_DATA_TYPE, data_type == 0 || data_type == 1, "Unsupported data type %d, only FP32 and FP16 can be quantized", data_type);
 
-        int32_t n_vocab;
-        int32_t n_embed;
-        int32_t n_layer;
-        int32_t data_type;
-
-        finp.read((char *) &n_vocab, sizeof(n_vocab));
-        finp.read((char *) &n_embed, sizeof(n_embed));
-        finp.read((char *) &n_layer, sizeof(n_layer));
-        finp.read((char *) &data_type, sizeof(data_type));
-
-        RWKV_ASSERT_FALSE(data_type == 0 || data_type == 1, "Unsupported data type %d, only FP32 and FP16 can be quantized", data_type);
-
-        data_type = format_type;
-
-        fout.write((char *) &n_vocab, sizeof(n_vocab));
-        fout.write((char *) &n_embed, sizeof(n_embed));
-        fout.write((char *) &n_layer, sizeof(n_layer));
-        fout.write((char *) &data_type, sizeof(data_type));
+        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_uint32(file_out, magic, "magic"));
+        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_uint32(file_out, version, "version"));
+        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_vocab, "n_vocab"));
+        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_embed, "n_embed"));
+        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_layer, "n_layer"));
+        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, format_data_type, "data_type"));
     }
 
     // Process parameters
-    {
-        size_t total_size_orig = 0;
-        size_t total_size_new = 0;
+    size_t total_size_orig = 0;
+    size_t total_size_new = 0;
 
-        std::vector<float> work;
+    std::vector<float> work;
 
-        std::vector<uint8_t> data_u8;
-        std::vector<ggml_fp16_t> data_f16;
-        std::vector<float> data_f32;
+    std::vector<uint8_t> data_u8;
+    std::vector<ggml_fp16_t> data_f16;
+    std::vector<float> data_f32;
 
-        std::vector<int64_t> hist_all(1 << 4, 0);
+    std::vector<int64_t> hist_all(1 << 4, 0);
 
-        while (true) {
-            int32_t n_dims;
-            int32_t key_length;
-            int32_t parameter_data_type;
+    while (true) {
+        int32_t n_dims, key_length, parameter_data_type;
+        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ, fread(&n_dims, sizeof(int32_t), 1, file_in) == 1 || feof(file_in), "Failed to read an int32 value from a file (n_dims)");
+        if (feof(file_in)) break;
+        RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &key_length, "key_length"));
+        RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &parameter_data_type, "parameter_data_type"));
 
-            finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
-            finp.read(reinterpret_cast<char *>(&key_length), sizeof(key_length));
-            finp.read(reinterpret_cast<char *>(&parameter_data_type), sizeof(parameter_data_type));
+        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, n_dims == 1 || n_dims == 2, "Unsupported dimension count %d", n_dims);
+        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, parameter_data_type >= 0 && parameter_data_type < FORMAT_TYPE_COUNT, "Unsupported parameter data type %d", parameter_data_type);
 
-            if (finp.eof()) {
-                break;
-            }
+        ggml_type parameter_ggml_type = FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type];
+        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, parameter_ggml_type != GGML_TYPE_UNKNOWN, "Unsupported parameter data type %d", parameter_data_type);
 
-            RWKV_ASSERT_FALSE(parameter_data_type >= 0 && parameter_data_type < FORMAT_TYPE_COUNT, "Invalid parameter data type %d", parameter_data_type);
+        int32_t nelements, x, y;
 
-            ggml_type parameter_ggml_type = FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type];
-
-            RWKV_ASSERT_FALSE(parameter_ggml_type != GGML_TYPE_UNKNOWN, "Invalid parameter data type %d", parameter_data_type);
-
-            int32_t nelements = 1;
-            int32_t ne[2] = { 1, 1 };
-            for (int i = 0; i < n_dims; ++i) {
-                finp.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
-                nelements *= ne[i];
-            }
-
-            std::string name(key_length, 0);
-            finp.read(&name[0], key_length);
-
-            {
-                printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ggml_type_name(parameter_ggml_type));
-
-                total_size_orig += (size_t) (nelements * ggml_type_sizef(parameter_ggml_type));
-            }
-
-            // Quantize only 2D tensors, except embedding and head matrices.
-            // Embedding and head take not too much space, especially in bigger models;
-            // but they significantly increase perplexity when quantized.
-            bool quantize = n_dims == 2 &&
-                name != std::string("emb.weight") &&
-                name != std::string("head.weight");
-
-            if (quantize) {
-                RWKV_ASSERT_FALSE(
-                    parameter_data_type == 0 || parameter_data_type == 1,
-                    "Unsupported parameter data type %d, only FP32 and FP16 can be quantized",
-                    parameter_data_type
-                );
-
-                if (parameter_data_type == 1) {
-                    data_f16.resize(nelements);
-                    finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
-                    data_f32.resize(nelements);
-                    for (int i = 0; i < nelements; ++i) {
-                        data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
-                    }
-                } else {
-                    data_f32.resize(nelements);
-                    finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
-                }
-
-                parameter_data_type = format_type;
-            } else {
-                const int bytes_per_element = (parameter_data_type == 0) ? sizeof(float) : sizeof(uint16_t);
-                data_u8.resize(nelements * bytes_per_element);
-                finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bytes_per_element);
-            }
-
-            fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
-            fout.write(reinterpret_cast<char *>(&key_length), sizeof(key_length));
-            fout.write(reinterpret_cast<char *>(&parameter_data_type), sizeof(parameter_data_type));
-
-            for (int i = 0; i < n_dims; ++i) {
-                fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
-            }
-
-            fout.write(&name[0], key_length);
-
-            if (quantize) {
-                printf("quantizing... ");
-                work.resize(nelements); // for quantization
-
-                size_t cur_size = 0;
-                // This is a histogramm of some values. If it shows single 1.0, then all 0.0, something went very wrong!
-                std::vector<int64_t> hist_cur(1 << 4, 0);
-
-                switch (type) {
-                    case GGML_TYPE_Q4_0:
-                        cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                        break;
-                    case GGML_TYPE_Q4_1:
-                        cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                        break;
-                    case GGML_TYPE_Q4_2:
-                        cur_size = ggml_quantize_q4_2(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                        break;
-                    case GGML_TYPE_Q5_0:
-                        cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                        break;
-                    case GGML_TYPE_Q5_1:
-                        cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                        break;
-                    case GGML_TYPE_Q8_0:
-                        cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                        break;
-                    default: {
-                        fprintf(stderr, "unsupported quantization type %d\n", type);
-                        return false;
-                    }
-                }
-
-                fout.write(reinterpret_cast<char *>(work.data()), cur_size);
-                total_size_new += cur_size;
-
-                printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float) / 1024.0 / 1024.0, cur_size / 1024.0 / 1024.0);
-
-                for (int i = 0; i < (int) hist_cur.size(); ++i) {
-                    hist_all[i] += hist_cur[i];
-                }
-
-                for (int i = 0; i < (int) hist_cur.size(); ++i) {
-                    printf("%5.3f ", hist_cur[i] / float(nelements));
-                }
-
-                printf("\n");
-            } else {
-                printf("size = %8.3f MB\n", data_u8.size() / 1024.0 / 1024.0);
-                fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
-                total_size_new += data_u8.size();
-            }
+        if (n_dims == 1) {
+            RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &x, "x"));
+            y = 1;
+            nelements = x;
+        } else {
+            RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &x, "x"));
+            RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &y, "y"));
+            nelements = x * y;
         }
 
-        printf("original size     = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0);
-        printf("quantized size    = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
printf("compression ratio = %8.2f\n", 1.0 * total_size_orig / total_size_new); + std::string name(key_length, 0); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_KEY, fread(&name[0], key_length, 1, file_in) == 1, "Failed to read parameter key"); - { - int64_t sum_all = 0; + printf("%48s - [%5d, %5d], type = %6s ", name.data(), x, y, ggml_type_name(parameter_ggml_type)); + total_size_orig += (size_t) (nelements * ggml_type_sizef(parameter_ggml_type)); - for (int i = 0; i < (int) hist_all.size(); ++i) { - sum_all += hist_all[i]; + // Quantize only 2D tensors, except embedding and head matrices. + // Embedding and head take not too much space, especially in bigger models; + // but they significantly increase perplexity when quantized. + bool quantize = n_dims == 2 && name != "emb.weight" && name != "head.weight"; + + if (quantize) { + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA_TYPE, + parameter_ggml_type == GGML_TYPE_F32 || parameter_data_type == GGML_TYPE_F16, + "Unsupported parameter data type %d, only FP32 and FP16 can be quantized", + parameter_ggml_type + ); + + data_f32.resize(nelements); + + if (parameter_data_type == GGML_TYPE_F16) { + data_f16.resize(nelements); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(data_f16.data(), nelements * sizeof(ggml_fp16_t), 1, file_in) == 1, "Failed to read parameter data"); + + for (int i = 0; i < nelements; ++i) + data_f32[i] = ggml_fp16_to_fp32(data_f16[i]); + } else { + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(data_f32.data(), nelements * sizeof(float), 1, file_in) == 1, "Failed to read parameter data"); } - printf("hist: "); + parameter_data_type = format_data_type; + parameter_ggml_type = format_ggml_type; + } else { + const size_t element_size = ggml_type_size(parameter_ggml_type); + data_u8.resize(nelements * element_size); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(data_u8.data(), nelements * element_size, 1, file_in) == 1, "Failed to read parameter data"); + } - for (int i = 0; i < (int) hist_all.size(); ++i) { - printf("%5.3f ", hist_all[i] / float(sum_all)); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_dims, "n_dims")); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, key_length, "key_length")); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, parameter_data_type, "parameter_data_type")); + + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, x, "x")); + + if (n_dims == 2) + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, y, "y")); + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_WRITE, fwrite(&name[0], key_length, 1, file_out) == 1, "Failed to write parameter key"); + + if (quantize) { + printf("quantizing... "); + work.resize(nelements); // for quantization + + // This is a histogramm of some values. If it shows single 1.0, then all 0.0, something went very wrong! + std::vector hist_cur(1 << 4, 0); + + size_t (*f)(const float* src, void* dst, int n, int k, int64_t* hist) = + format_ggml_type == GGML_TYPE_Q4_0 ? ggml_quantize_q4_0 : + format_ggml_type == GGML_TYPE_Q4_1 ? ggml_quantize_q4_1 : + format_ggml_type == GGML_TYPE_Q4_2 ? ggml_quantize_q4_2 : + format_ggml_type == GGML_TYPE_Q5_0 ? ggml_quantize_q5_0 : + format_ggml_type == GGML_TYPE_Q5_1 ? ggml_quantize_q5_1 : + format_ggml_type == GGML_TYPE_Q8_0 ? 
+                format_ggml_type == GGML_TYPE_Q8_0 ? ggml_quantize_q8_0 :
+                NULL;
+
+            RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_UNSUPPORTED, f, "unsupported quantization type %d", format_ggml_type);
+
+            size_t cur_size = (*f)(data_f32.data(), work.data(), nelements, x, hist_cur.data());
+            total_size_new += cur_size;
+
+            RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_WRITE, fwrite(work.data(), cur_size, 1, file_out) == 1, "Failed to write parameter data");
+
+            printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float) / 1024.0 / 1024.0, cur_size / 1024.0 / 1024.0);
+
+            for (int i = 0; i < (int) hist_cur.size(); ++i) {
+                hist_all[i] += hist_cur[i];
+            }
+
+            for (int i = 0; i < (int) hist_cur.size(); ++i) {
+                printf("%5.3f ", hist_cur[i] / float(nelements));
+            }
+
+            printf("\n");
+        } else {
+            printf("size = %8.3f MB\n", data_u8.size() / 1024.0 / 1024.0);
+            RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_WRITE, fwrite(data_u8.data(), data_u8.size(), 1, file_out) == 1, "Failed to write parameter data");
+            total_size_new += data_u8.size();
+        }
+    }
 
-    finp.close();
-    fout.close();
+    printf("original size     = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0);
+    printf("quantized size    = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
+    printf("compression ratio = %8.2f\n", 1.0 * total_size_orig / total_size_new);
+
+    int64_t sum_all = 0;
+
+    for (int i = 0; i < (int) hist_all.size(); ++i) {
+        sum_all += hist_all[i];
+    }
+
+    printf("hist: ");
+
+    for (int i = 0; i < (int) hist_all.size(); ++i) {
+        printf("%5.3f ", hist_all[i] / float(sum_all));
+    }
+
+    printf("\n");
 
     return true;
 }
@@ -849,4 +871,4 @@ const char * rwkv_get_system_info_string(void) {
     s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
 
     return s.c_str();
-}
+}
\ No newline at end of file
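For illustration (not part of the patch): the rwkv.h changes below encode each error as a coarse category in the upper bits (multiples of 1 << 4) plus a detail code in the low four bits, and note that values may contain multiple OR'd errors. For a single error, a caller can split a value by masking the low nibble; `describe_error` is a hypothetical helper sketched under that assumption:

    #include <stdio.h>

    // Hypothetical helper, for illustration only: splits one error value into
    // its category bits and detail code. For example, RWKV_ERROR_FILE |
    // RWKV_ERROR_FILE_OPEN ((2 << 4) | 2 == 34) yields category 0x20, detail 2.
    void describe_error(enum rwkv_error_flags value) {
        unsigned category = ((unsigned) value) & ~0xFu; // RWKV_ERROR_ARGS, _FILE, _MODEL, ...
        unsigned detail   = ((unsigned) value) &  0xFu; // RWKV_ERROR_ALLOC, _FILE_OPEN, ...
        fprintf(stderr, "rwkv error: category 0x%02x, detail %u\n", category, detail);
    }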
diff --git a/rwkv.h b/rwkv.h
index 46abb61..beb6a81 100644
--- a/rwkv.h
+++ b/rwkv.h
@@ -27,6 +27,50 @@
 extern "C" {
 #endif
 
+    // Represents an error encountered during a function call.
+    // These are flags, so an actual value might contain multiple errors.
+    enum rwkv_error_flags {
+        RWKV_ERROR_NONE = 0,
+
+        RWKV_ERROR_ARGS = 1 << 4,
+        RWKV_ERROR_FILE = 2 << 4,
+        RWKV_ERROR_MODEL = 3 << 4,
+        RWKV_ERROR_MODEL_PARAMS = 4 << 4,
+        RWKV_ERROR_GRAPH = 5 << 4,
+        RWKV_ERROR_CTX = 6 << 4,
+
+        RWKV_ERROR_ALLOC = 1,
+        RWKV_ERROR_FILE_OPEN = 2,
+        RWKV_ERROR_FILE_STAT = 3,
+        RWKV_ERROR_FILE_READ = 4,
+        RWKV_ERROR_FILE_WRITE = 5,
+        RWKV_ERROR_FILE_MAGIC = 6,
+        RWKV_ERROR_FILE_VERSION = 7,
+        RWKV_ERROR_DATA_TYPE = 8,
+        RWKV_ERROR_UNSUPPORTED = 9,
+        RWKV_ERROR_SHAPE = 10,
+        RWKV_ERROR_DIMENSION = 11,
+        RWKV_ERROR_KEY = 12,
+        RWKV_ERROR_DATA = 13,
+        RWKV_ERROR_PARAM_MISSING = 14
+    };
+
+    // Sets whether errors are automatically printed to stderr.
+    // If this is set to false, you are responsible for calling rwkv_get_last_error manually if an operation fails.
+    // - ctx: the context to suppress error messages for.
+    //   If NULL, affects model load (rwkv_init_from_file) and quantization (rwkv_quantize_model_file) errors,
+    //   as well as the default for new contexts.
+    // - print_errors: whether error messages should be automatically printed.
+    RWKV_API void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors);
+
+    // Gets whether errors are automatically printed to stderr.
+    // - ctx: the context to retrieve the setting for, or NULL for the global setting.
+    RWKV_API bool rwkv_get_print_errors(struct rwkv_context * ctx);
+
+    // Retrieves and clears the error flags.
+    // - ctx: the context to retrieve the error for, or NULL for the global error.
+    RWKV_API enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx);
+
     struct rwkv_context;
 
     // Loads the model from a file and prepares it for inference.
diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c
index bfb726e..750b1ed 100644
--- a/tests/test_tiny_rwkv.c
+++ b/tests/test_tiny_rwkv.c
@@ -26,6 +26,7 @@ void test_model(const char * model_path, const float * expected_logits, const fl
     fprintf(stderr, "Testing %s\n", model_path);
 
     struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS);
+    enum rwkv_error_flags error = rwkv_get_last_error(NULL);
 
     uint32_t n_vocab = rwkv_get_logits_buffer_element_count(model);
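For illustration (not part of the patch): taken together, the pieces above are meant to be consumed roughly like this. A hedged end-to-end sketch; the model path and thread count are placeholders:

    #include <stdio.h>
    #include "rwkv.h"

    int main(void) {
        // Suppress automatic stderr output; we will inspect the flags ourselves.
        rwkv_set_print_errors(NULL, false);

        struct rwkv_context * ctx = rwkv_init_from_file("path/to/model.bin", 4);
        if (!ctx) {
            // Load failed: the global slot holds the flags, and reading clears them.
            enum rwkv_error_flags err = rwkv_get_last_error(NULL);
            fprintf(stderr, "load failed, error flags = 0x%x\n", (unsigned) err);
            return 1;
        }

        // ... rwkv_eval() calls; their failures land in rwkv_get_last_error(ctx) ...

        rwkv_free(ctx);
        return 0;
    }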