diff --git a/rwkv.cpp b/rwkv.cpp index 7e76abf..eff1807 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -7,23 +7,36 @@ #include #include -#include -#include +#include #include #include -#include -#include #include -#include #include #include -#include // fstat +#define _FILE_OFFSET_BITS 64 +#define RWKV_MAYBE_BREAK -#ifdef WIN32 -#define stat64 _stat64 -#define fstat64 _fstat64 +#ifdef _MSC_BUILD +#define stat _stat64 +#define fstat _fstat64 +#define ftell _ftelli64 +#define fseek _fseeki64 + +#ifndef NDEBUG +#include +#define RWKV_MAYBE_BREAK __debugbreak() #endif +#else +#include +#if !defined(__APPLE__) +#define ftell ftello +#define fseek fseeko +#endif +#endif + +static_assert(sizeof(stat::st_size) >= 8, "File offsets should be 64-bit or else rwkv.cpp will not be able to load model files over 2GB"); +static_assert(sizeof(decltype(ftell(NULL))) >= 8, "File offsets should be 64-bit or else rwkv.cpp will not be able to load model files over 2GB"); // --- Error handling --- @@ -38,41 +51,73 @@ inline enum rwkv_error_flags operator|=(enum rwkv_error_flags & a, enum rwkv_err return a = a | b; } -// If the condition x is false, adds ERR_VAL to the last error, prints a message to stderr, and returns RET_VAL. -#define RWKV_ASSERT_MSG(ERR_VAL, RET_VAL, x, ...) \ - if (!(x)) { \ - global_last_error |= ERR_VAL; \ - if (global_print_errors) { \ - fprintf(stderr, __VA_ARGS__); \ - fprintf(stderr, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ - } \ - return RET_VAL; \ - } +#define RWKV_MSG(...) do { if (global_print_errors) fprintf(stderr, __VA_ARGS__); } while (0) +#define RWKV_CTX_MSG(ctx, ...) do { if (ctx->print_errors) fprintf(stderr, __VA_ARGS__); } while (0) // If the condition x is false, adds ERR_VAL to the last error, and returns RET_VAL. -#define RWKV_ASSERT(ERR_VAL, RET_VAL, x) \ +#define RWKV_ASSERT(ERR_VAL, RET_VAL, x) do { \ if (!(x)) { \ global_last_error |= ERR_VAL; \ + RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ + RWKV_MAYBE_BREAK; \ return RET_VAL; \ - } + } } while (0) + +// If the condition x is false, adds ERR_VAL to the last error, prints a message to stderr, and returns RET_VAL. +#define RWKV_ASSERT_MSG(ERR_VAL, RET_VAL, x, ...) do { \ + if (!(x)) { \ + global_last_error |= ERR_VAL; \ + RWKV_MSG(__VA_ARGS__); \ + RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ + RWKV_MAYBE_BREAK; \ + return RET_VAL; \ + } } while (0) // If the condition x is false, adds ERR_VAL to the ctx's last error, prints a message to stderr, and returns RET_VAL. -#define RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, RET_VAL, x, ...) \ +#define RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, RET_VAL, x, ...) do { \ if (!(x)) { \ ((struct rwkv_context *) ctx)->last_error |= ERR_VAL; \ - if (ctx->print_errors) { \ - fprintf(stderr, __VA_ARGS__); \ - fprintf(stderr, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ - } \ + RWKV_CTX_MSG(ctx, __VA_ARGS__); \ + RWKV_CTX_MSG(ctx, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ + RWKV_MAYBE_BREAK; \ return RET_VAL; \ - } + } } while (0) // If the condition x is false, adds ERR_VAL to the ctx's last error, and returns RET_VAL. -#define RWKV_CTX_ASSERT(ctx, ERR_VAL, RET_VAL, x) \ +#define RWKV_CTX_ASSERT(ctx, ERR_VAL, RET_VAL, x) do { \ if (!(x)) { \ - ctx->last_error |= ERR_VAL; \ + ((struct rwkv_context *) ctx)->last_error |= ERR_VAL; \ + RWKV_CTX_MSG(ctx, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ + RWKV_MAYBE_BREAK; \ return RET_VAL; \ - } + } } while (0) + +// If the condition x is false, returns RET_VAL. +#define RWKV_ENSURE(RET_VAL, x) do { \ + if (!(x)) { \ + RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ + RWKV_MAYBE_BREAK; \ + return RET_VAL; \ + } } while (0) + +// If the condition x is false, prints a message to stderr, and returns RET_VAL. +#define RWKV_ENSURE_MSG(RET_VAL, x, ...) do { \ + if (!(x)) { \ + RWKV_MSG(__VA_ARGS__); \ + RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ + RWKV_MAYBE_BREAK; \ + return RET_VAL; \ + } } while (0) + +// If the condition x is false, prints a message to stderr, and returns RET_VAL. +#define RWKV_CTX_ENSURE_MSG(ctx, RET_VAL, x, ...) do { \ + if (!(x)) { \ + ((struct rwkv_context *) ctx)->last_error |= ERR_VAL; \ + RWKV_CTX_MSG(ctx, __VA_ARGS__); \ + RWKV_CTX_MSG(ctx, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ + RWKV_MAYBE_BREAK; \ + return RET_VAL; \ + } } while (0) #define RWKV_ASSERT_FALSE_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, false, x, __VA_ARGS__) #define RWKV_ASSERT_NULL_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, NULL, x, __VA_ARGS__) @@ -84,72 +129,243 @@ inline enum rwkv_error_flags operator|=(enum rwkv_error_flags & a, enum rwkv_err #define RWKV_CTX_ASSERT_FALSE(ctx, ERR_VAL, x) RWKV_CTX_ASSERT(ctx, ERR_VAL, false, x) #define RWKV_CTX_ASSERT_NULL(ctx, ERR_VAL, x) RWKV_CTX_ASSERT(ctx, ERR_VAL, NULL, x) +#define RWKV_ENSURE_OR_FALSE(x) RWKV_ENSURE(false, x) +#define RWKV_ENSURE_OR_NULL(x) RWKV_ENSURE(NULL, x) +#define RWKV_ENSURE_OR_FALSE_MSG(x, ...) RWKV_ENSURE_MSG(false, x, __VA_ARGS__) +#define RWKV_ENSURE_OR_NULL_MSG(x, ...) RWKV_ENSURE_MSG(NULL, x, __VA_ARGS__) +#define RWKV_CTX_ENSURE_OR_FALSE_MSG(ctx, x, ...) RWKV_CTX_ENSURE_MSG(ctx, false, x, __VA_ARGS__) +#define RWKV_CTX_ENSURE_OR_NULL_MSG(ctx, x, ...) RWKV_CTX_ENSURE_MSG(ctx, NULL, x, __VA_ARGS__) + // --- Utilities --- -// Reads single int32 value from a file. -bool read_int32(FILE * file, int32_t * dest, const char * name) { - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, fread(dest, sizeof(int32_t), 1, file) == 1, "Failed to read an int32 value from a file (%s)", name); - return true; +// Reads a single uint32 value from a file. +bool rwkv_fread_uint32(FILE * file, uint32_t & dest) { + return fread((void *) &dest, sizeof(uint32_t), 1, file) == 1; } -// Reads single uint32 value from a file. -bool read_uint32(FILE * file, uint32_t * dest, const char * name) { - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, fread(dest, sizeof(uint32_t), 1, file) == 1, "Failed to read a uint32 value from a file (%s)", name); - return true; +// Reads a single string value from a file. +bool rwkv_fread_string(FILE * file, size_t length, std::string & dest) { + dest.resize(length); + return fread((void *) dest.data(), length, 1, file) == 1; } -// Writes single int32 value to a file. -bool write_int32(FILE * file, int32_t value, const char * name) { - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_WRITE, fwrite((void *) &value, sizeof(int32_t), 1, file), "Failed to write an int32 value to a file (%s)", name); - return true; +// Reads a single data buffer from a file. +bool rwkv_fread_data(FILE * file, size_t length, void * dest) { + return fread(dest, length, 1, file) == 1; } -// Writes single uint32 value to a file. -bool write_uint32(FILE * file, uint32_t value, const char * name) { - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_WRITE, fwrite((void *) &value, sizeof(uint32_t), 1, file), "Failed to write a uint32 value to a file (%s)", name); - return true; +// Writes a single uint32 value to a file. +bool rwkv_fwrite_uint32(FILE * file, const uint32_t value) { + return fwrite((const void *) &value, sizeof(uint32_t), 1, file); } +// Writes a single string value to a file. +bool rwkv_fwrite_string(FILE * file, const std::string & value) { + return fwrite((const void *) value.data(), value.length(), 1, file) == 1; +} + +// Writes a single data buffer to a file. +bool rwkv_fwrite_data(FILE * file, const void * data, const size_t length) { + return fwrite(data, length, 1, file) == 1; +} + +// --- File data structures --- + +#define TYPE_UNKNOWN TYPE_COUNT + +enum rwkv_type { + TYPE_F32, + TYPE_F16, + TYPE_Q4_0, + TYPE_Q4_1, + TYPE_Q4_1_O, // Unsupported + TYPE_Q4_2, // Unsupported + TYPE_Q4_3, // Unsupported + TYPE_Q5_0, + TYPE_Q5_1, + TYPE_Q8_0, + TYPE_COUNT +}; + #define GGML_TYPE_UNKNOWN GGML_TYPE_COUNT -#define FORMAT_TYPE_COUNT 10 - -static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[FORMAT_TYPE_COUNT] = { - GGML_TYPE_F32, - GGML_TYPE_F16, - GGML_TYPE_Q4_0, - GGML_TYPE_Q4_1, - GGML_TYPE_UNKNOWN, // Unused - GGML_TYPE_UNKNOWN, // Unused - GGML_TYPE_UNKNOWN, // Unused - GGML_TYPE_Q5_0, - GGML_TYPE_Q5_1, - GGML_TYPE_Q8_0 +extern const enum ggml_type rwkv_type_to_ggml[TYPE_COUNT + 1] = { + GGML_TYPE_F32, /* F32 */ + GGML_TYPE_F16, /* F16 */ + GGML_TYPE_Q4_0, /* Q4_0 */ + GGML_TYPE_Q4_1, /* Q4_1 */ + GGML_TYPE_UNKNOWN, /* Q4_1_O */ + GGML_TYPE_UNKNOWN, /* Q4_2 */ + GGML_TYPE_UNKNOWN, /* Q4_3 */ + GGML_TYPE_Q5_0, /* Q5_0 */ + GGML_TYPE_Q5_1, /* Q5_1 */ + GGML_TYPE_Q8_0, /* Q8_0 */ + GGML_TYPE_COUNT /* COUNT */ }; -static bool is_non_quantized_format_type(int32_t format_type) { - return format_type == 0 || format_type == 1; +extern const enum rwkv_type rwkv_type_from_ggml[GGML_TYPE_COUNT + 1] = { + TYPE_F32, /* F32 */ + TYPE_F16, /* F16 */ + TYPE_Q4_0, /* Q4_0 */ + TYPE_Q4_1, /* Q4_1 */ + TYPE_Q4_2, /* Q4_2 */ + TYPE_Q4_3, /* Q4_3 */ + TYPE_Q5_0, /* Q5_0 */ + TYPE_Q5_1, /* Q5_1 */ + TYPE_Q8_0, /* Q8_0 */ + TYPE_COUNT, /* Q8_1 */ + TYPE_COUNT, /* I8 */ + TYPE_COUNT, /* I16 */ + TYPE_COUNT, /* I32 */ + TYPE_COUNT, /* COUNT */ +}; + +extern const char * rwkv_type_to_string[TYPE_COUNT + 1] = {"float32", "float16", "Q4_0", "Q4_1", "Q4_1_O", "Q4_2", "Q4_3", "Q5_0", "Q5_1", "Q8_0", "unknown"}; + +enum rwkv_type rwkv_type_from_string(const char * str) { + for (int ord = 0; ord < TYPE_COUNT; ord++) { + if (strcmp(str, rwkv_type_to_string[ord]) == 0) { + return (enum rwkv_type) ord; + } + } + + return TYPE_UNKNOWN; } -static bool is_quantized_format_type(int32_t format_type) { - return format_type == 2 || - format_type == 3 || - format_type == 7 || - format_type == 8 || - format_type == 9; +struct rwkv_file_header { + uint32_t magic; + uint32_t version; + uint32_t n_vocab; + uint32_t n_embed; + uint32_t n_layer; + uint32_t data_type; +}; + +bool rwkv_is_file_version_in_range(uint32_t version) { + return version >= RWKV_FILE_VERSION_MIN && version <= RWKV_FILE_VERSION_MAX; } -static int32_t format_name_to_format_type(const char * format_name) { - if (strcmp(format_name, "Q4_0") == 0) return 2; - if (strcmp(format_name, "Q4_1") == 0) return 3; - if (strcmp(format_name, "Q5_0") == 0) return 7; - if (strcmp(format_name, "Q5_1") == 0) return 8; - if (strcmp(format_name, "Q8_0") == 0) return 9; +bool rwkv_fread_file_header(FILE * file, struct rwkv_file_header & header, bool verify_data_type = true) { + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, sizeof(struct rwkv_file_header), &header)); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_MAGIC, header.magic == RWKV_FILE_MAGIC); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_VERSION, rwkv_is_file_version_in_range(header.version), "Unsupported file version %" PRId32, header.version); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_DATA_TYPE, header.data_type < TYPE_COUNT, "Model data type out of range (%" PRId32 " > %" PRId32 ")", header.data_type, TYPE_COUNT - 1); - return -1; + if (verify_data_type) { + enum ggml_type ggml_type = rwkv_type_to_ggml[header.data_type]; + + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_DATA_TYPE, + ggml_type != GGML_TYPE_UNKNOWN, + "Models in %s format cannot be loaded anymore because the format was removed.\n" + "You need to quantize the model into another format or use an older version of rwkv.cpp.\n" + "See https://github.com/saharNooby/rwkv.cpp#compatibility for more info", + rwkv_type_to_string[header.data_type] + ); + + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_DATA_TYPE, + (!ggml_is_quantized(ggml_type) || header.version == RWKV_FILE_VERSION_1), + "The quantized model file in %s format was created with an old version of rwkv.cpp and can not be loaded anymore.\n" + "You need to requantize the model or use an older version of rwkv.cpp.\n" + "See https://github.com/saharNooby/rwkv.cpp#compatibility for more info", + rwkv_type_to_string[header.data_type] + ); + } + + return true; } -// --- Model definition and loading utilities --- +bool rwkv_fwrite_file_header(FILE * file, const struct rwkv_file_header & header) { + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_data(file, &header, sizeof(struct rwkv_file_header))); + return true; +} + +struct rwkv_tensor_header { + uint32_t dim_count; + uint32_t key_length; + uint32_t data_type; + uint32_t width; + uint32_t height; +}; + +struct rwkv_tensor { + struct rwkv_tensor_header header; + std::string name; + uint8_t * data; +}; + +bool rwkv_fread_tensor_header(FILE * file, struct rwkv_tensor_header & header) { + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, sizeof(struct rwkv_tensor_header) - sizeof(uint32_t), &header)); + header.height = 1; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_SHAPE, header.dim_count == 1 || header.dim_count == 2, "Tensor has an invalid shape (%" PRId32 " dimensions)", header.dim_count); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_DATA_TYPE, header.data_type < TYPE_COUNT, "Tensor data type out of range (%" PRId32 " > %" PRId32 ")", header.data_type, TYPE_COUNT - 1); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_DATA_TYPE, rwkv_type_to_ggml[header.data_type] != GGML_TYPE_UNKNOWN, "Tensor data type (%s) is no longer supported", rwkv_type_to_string[header.data_type]); + + if (header.dim_count == 2) { + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_uint32(file, header.height)); + } + + return true; +} + +bool rwkv_fwrite_tensor_header(FILE * file, const struct rwkv_tensor_header & header) { + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_data(file, &header, sizeof(struct rwkv_tensor_header) - (header.dim_count == 1 ? sizeof(uint32_t) : 0))); + return true; +} + +size_t rwkv_tensor_size(enum ggml_type type, const int64_t width, const int64_t height = 1) { + struct ggml_tensor decoy {}; + decoy.type = type; + decoy.ne[0] = width; + decoy.ne[1] = height; + decoy.ne[2] = 1; + decoy.ne[3] = 1; + return ggml_nbytes(&decoy); +} + +size_t rwkv_tensor_size(const struct rwkv_tensor_header & header) { + return rwkv_tensor_size(rwkv_type_to_ggml[header.data_type], header.width, header.height); +} + +bool rwkv_fskip_tensor_data(FILE * file, const struct rwkv_tensor_header & header) { + return fseek(file, header.key_length + rwkv_tensor_size(header), SEEK_CUR) == 0; +} + +bool rwkv_fread_tensor_header_and_skip(FILE * file, struct rwkv_tensor_header & header) { + RWKV_ENSURE_OR_FALSE(rwkv_fread_tensor_header(file, header)); + RWKV_ASSERT_FALSE(RWKV_ERROR_DATA, rwkv_fskip_tensor_data(file, header)); + return true; +} + +bool rwkv_fread_tensor_data(FILE * file, struct rwkv_tensor & output, void * buffer = NULL) { + size_t data_size = rwkv_tensor_size(output.header); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_string(file, output.header.key_length, output.name)); + + if (buffer) { + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, data_size, buffer)); + } else { + output.data = NULL; + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fskip_tensor_data(file, output.header)); + } + + return true; +} + +bool rwkv_fread_tensor(FILE * file, struct rwkv_tensor & output, void * buffer = NULL) { + RWKV_ENSURE_OR_FALSE(rwkv_fread_tensor_header(file, output.header)); + RWKV_ENSURE_OR_FALSE(rwkv_fread_tensor_data(file, output, buffer)); + return true; +} + +bool rwkv_fwrite_tensor(FILE * file, const struct rwkv_tensor & tensor) { + RWKV_ENSURE_OR_FALSE(rwkv_fwrite_tensor_header(file, tensor.header)); + RWKV_ENSURE_OR_FALSE(rwkv_fwrite_string(file, tensor.name)); + RWKV_ENSURE_OR_FALSE(rwkv_fwrite_data(file, tensor.data, rwkv_tensor_size(tensor.header))); + return true; +} + +// --- Model definition --- struct rwkv_layer { struct ggml_tensor * ln1_weight; @@ -178,18 +394,14 @@ struct rwkv_layer { }; struct rwkv_model { - uint32_t n_vocab; - uint32_t n_layer; - uint32_t n_embed; - // 0 for float32, 1 for float16. - int32_t data_type; + struct rwkv_file_header header; struct ggml_tensor * emb; struct ggml_tensor * ln0_weight; struct ggml_tensor * ln0_bias; - std::vector layers; + std::unique_ptr layers; struct ggml_tensor * ln_out_weight; struct ggml_tensor * ln_out_bias; @@ -197,23 +409,6 @@ struct rwkv_model { struct ggml_tensor * head; }; -// Finds model parameter by key and sets it into dest. -// If the parameter was not found, returns false. -bool set_parameter(std::unordered_map * parameters, std::string key, struct ggml_tensor ** dest) { - struct ggml_tensor * parameter = (*parameters)[key]; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_PARAM_MISSING, parameter != NULL, "Parameter %s not found in model file", key.c_str()); - *dest = parameter; - return true; -} - -// Finds block parameter by block index and key and sets it into dest. -// If the parameter was not found, returns false. -bool set_block_parameter(std::unordered_map * parameters, int32_t block_index, std::string key, struct ggml_tensor ** dest) { - char full_key[128]; - sprintf(full_key, "blocks.%d.%s", block_index, key.c_str()); - return set_parameter(parameters, full_key, dest); -} - // --- Operators --- void rwkv_exp_impl(const int n_cols, float * dest, const float * src) { @@ -264,22 +459,410 @@ struct ggml_tensor * rwkv_layer_norm(ggml_context * ctx, struct ggml_tensor * x, // --- Implementation --- +struct rwkv_layer_state { + struct ggml_tensor * ffn_xx; + struct ggml_tensor * att_xx; + struct ggml_tensor * att_aa; + struct ggml_tensor * att_bb; + struct ggml_tensor * att_pp; +}; + struct rwkv_graph { - struct ggml_tensor * state; - std::unique_ptr state_parts; + struct ggml_tensor * input_state; + std::unique_ptr input_layers; + std::unique_ptr output_layers; struct ggml_tensor * token_index; struct ggml_tensor * logits; std::unique_ptr cgraph; }; struct rwkv_context { - std::unique_ptr model; + struct rwkv_model model; struct ggml_context * ctx; + std::unique_ptr scratch; struct rwkv_graph graph; enum rwkv_error_flags last_error; bool print_errors; + size_t gpu_layers; size_t vram_total; - int gpu_layers; +}; + +bool rwkv_fread_ggml_tensor_data(FILE * file, const struct rwkv_tensor_header & header, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) { + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_string(file, header.key_length, name), "Failed to read tensor name"); + + enum ggml_type ggml_type = rwkv_type_to_ggml[header.data_type]; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_UNSUPPORTED, ggml_type != GGML_TYPE_UNKNOWN, "Unsupported tensor data type %s from %s", rwkv_type_to_string[header.data_type], name.c_str()); + + tensor = header.dim_count == 1 + ? ggml_new_tensor_1d(ctx, ggml_type, header.width) + : ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height); + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor"); + ggml_set_name(tensor, name.c_str()); + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, ggml_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str()); + return true; +} + +bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) { + struct rwkv_tensor_header header; + RWKV_ENSURE_OR_FALSE_MSG(rwkv_fread_tensor_header(file, header), "Invalid tensor header"); + return rwkv_fread_ggml_tensor_data(file, header, ctx, name, tensor); +} + +template // https://stackoverflow.com/a/6458689 +bool rwkv_set_params(struct rwkv_model & model, F callback) { + RWKV_ENSURE_OR_FALSE(callback("emb.weight", model.emb)); + RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.weight", model.ln0_weight)); + RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.bias", model.ln0_bias)); + + uint32_t n_layer = model.header.n_layer; + std::unique_ptr layers(new(std::nothrow) struct rwkv_layer [n_layer]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, layers.get(), "Failed to allocate model layers"); + model.layers = std::move(layers); + + for (uint32_t i = 0; i < n_layer; i++) { + char buffer[128]; + size_t offset = sprintf(buffer, "blocks.%" PRId32 ".", i); + + rwkv_layer & layer = model.layers[i]; + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.weight"), buffer), layer.ln1_weight)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.bias"), buffer), layer.ln1_bias)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), layer.ln2_weight)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_k"), buffer), layer.ffn_time_mix_k)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_r"), buffer), layer.ffn_time_mix_r)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.key.weight"), buffer), layer.ffn_key)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.value.weight"), buffer), layer.ffn_value)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.receptance.weight"), buffer), layer.ffn_receptance)); + } + + RWKV_ENSURE_OR_FALSE(callback("ln_out.weight", model.ln_out_weight)); + RWKV_ENSURE_OR_FALSE(callback("ln_out.bias", model.ln_out_bias)); + RWKV_ENSURE_OR_FALSE(callback("head.weight", model.head)); + return true; +} + +struct rwkv_ctx_size { + size_t objects_count = 0; + size_t objects_size = 0; + size_t scratch_size = 0; +}; + +void rwkv_ctx_size_add_objects(struct rwkv_ctx_size & ctx_size, size_t objects, size_t object_size = sizeof(struct ggml_tensor)) { + ctx_size.objects_count += objects; + ctx_size.objects_size += ((object_size + 15) & ~15) * objects; +} + +void rwkv_ctx_size_add_scratch(struct rwkv_ctx_size & ctx_size, size_t length, size_t count = 1) { + ctx_size.scratch_size += ((length + 15) & ~15) * count; +} + +void rwkv_ctx_size_add(struct rwkv_ctx_size & ctx_size, size_t objects, size_t scratch = 0, size_t scratches = 1) { + rwkv_ctx_size_add_objects(ctx_size, objects); + rwkv_ctx_size_add_scratch(ctx_size, scratch, scratches); +} + +void rwkv_ctx_size_add(struct rwkv_ctx_size & ctx_size, size_t count, const struct rwkv_ctx_size & other) { + ctx_size.objects_count += other.objects_count * count; + ctx_size.objects_size += other.objects_size * count; + ctx_size.scratch_size += other.scratch_size * count; +} + +void rwkv_ctx_size_add_tensor(struct rwkv_ctx_size & ctx_size, const uint64_t tensors, const uint64_t views, const enum ggml_type type, const uint64_t width, const uint64_t height = 1) { + rwkv_ctx_size_add_objects(ctx_size, tensors + views); + rwkv_ctx_size_add_scratch(ctx_size, rwkv_tensor_size(type, width, height), tensors); +} + +void rwkv_ctx_size_add_tensor(struct rwkv_ctx_size & size, const uint64_t tensors, const uint64_t views, const struct rwkv_tensor_header & header) { + rwkv_ctx_size_add_tensor(size, tensors, views, rwkv_type_to_ggml[header.data_type], header.width, header.height); +} + +struct rwkv_ctx_size rwkv_single_att_size(const size_t n_embed = 0) { + size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); + + struct rwkv_ctx_size ctx_size; + + /* x0 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); + + /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 3, 1, GGML_TYPE_F32, n_embed); + /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); + /* xv */ rwkv_ctx_size_add_tensor(ctx_size, 3, 1, GGML_TYPE_F32, n_embed); + /* xv */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); + /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 3, 1, GGML_TYPE_F32, n_embed); + /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); + + /* r */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); + /* r */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); + /* k */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); + /* v */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); + + /* ww */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); + /* qq */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); + /* qq */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); + /* e1 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); + /* e1 */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); + /* e2 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); + /* e2 */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); + + /* a */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); + /* b */ rwkv_ctx_size_add_tensor(ctx_size, 1, 1, GGML_TYPE_F32, n_embed); + + /* ww */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); + /* qq */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); + /* qq */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); + /* e1 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); + /* e1 */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); + /* e2 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); + /* e2 */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); + + /* xx */ rwkv_ctx_size_add_tensor(ctx_size, 0, 0, GGML_TYPE_F32, n_embed); + /* aa */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); + /* bb */ rwkv_ctx_size_add_tensor(ctx_size, 1, 1, GGML_TYPE_F32, n_embed); + /* pp */ rwkv_ctx_size_add_tensor(ctx_size, 0, 0, GGML_TYPE_F32, n_embed); + + /* wkv */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); + /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); + + return ctx_size; +} + +struct ggml_tensor * rwkv_single_att(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer & layer, struct rwkv_layer_state & state) { + // self.layer_norm(x, self.w.blocks[i].ln1) + struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); + + // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) + struct ggml_tensor * xk = ggml_add_inplace(ctx, + ggml_mul(ctx, x0, layer.att_time_mix_k), + ggml_mul(ctx, state.att_xx, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) + ); + + // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) + struct ggml_tensor * xv = ggml_add_inplace(ctx, + ggml_mul(ctx, x0, layer.att_time_mix_v), + ggml_mul(ctx, state.att_xx, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) + ); + + // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) + struct ggml_tensor * xr = ggml_add_inplace(ctx, + ggml_mul(ctx, x0, layer.att_time_mix_r), + ggml_mul(ctx, state.att_xx, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) + ); + + // r = torch.sigmoid(rw @ xr) + struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr)); + // k = kw @ xk + struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk); + // v = vw @ xv + struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv); + + // ww = time_first + k + struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k); + // qq = torch.maximum(pp, ww) + struct ggml_tensor * qq = rwkv_max(ctx, state.att_pp, ww); + // e1 = torch.exp(pp - qq) + struct ggml_tensor * e1 = rwkv_exp(ctx, ggml_sub(ctx, state.att_pp, qq)); + // e2 = torch.exp(ww - qq) + struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); + + // a = e1 * aa + e2 * v + struct ggml_tensor * a = ggml_add_inplace(ctx, ggml_mul(ctx, e1, state.att_aa), ggml_mul(ctx, e2, v)); + // b = e1 * bb + e2 + struct ggml_tensor * b = ggml_add_inplace(ctx, ggml_mul(ctx, e1, state.att_bb), e2); + + // ww = pp + time_decay + ww = ggml_add(ctx, state.att_pp, layer.att_time_decay); + // qq = torch.maximum(ww, k) + qq = rwkv_max(ctx, ww, k); + // e1 = torch.exp(ww - qq) + e1 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); + // e2 = torch.exp(k - qq) + e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq)); + + // state[5 * i + 1] = x0 + // state[5 * i + 2] = e1 * aa + e2 * v + // state[5 * i + 3] = e1 * bb + e2 + // state[5 * i + 4] = qq + state.att_xx = x0; + state.att_aa = ggml_add_inplace(ctx, ggml_mul(ctx, e1, state.att_aa), ggml_mul(ctx, e2, v)); + state.att_bb = ggml_add_inplace(ctx, ggml_mul(ctx, e1, state.att_bb), e2); + state.att_pp = qq; + + // wkv = a / b + struct ggml_tensor * wkv = ggml_div(ctx, a, b); + + // ow @ (r * wkv) + return ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv))); +} + +struct rwkv_ctx_size rwkv_single_ffn_size(const size_t n_embed = 0, const size_t ffn_key = 0) { + size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); + + struct rwkv_ctx_size ctx_size; + + /* x0 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); + + /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 3, 1, GGML_TYPE_F32, n_embed); + /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); + /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 3, 1, GGML_TYPE_F32, n_embed); + /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); + + /* xx */ rwkv_ctx_size_add_tensor(ctx_size, 0, 0, GGML_TYPE_F32, n_embed); + + /* r */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); + /* r */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); + /* k */ rwkv_ctx_size_add_tensor(ctx_size, 3, 0, GGML_TYPE_F32, ffn_key); + + /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); + + return ctx_size; +} + +struct ggml_tensor * rwkv_single_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer & layer, struct rwkv_layer_state & state) { + // self.layer_norm(x, self.w.blocks[i].ln2) + struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias); + + // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) + struct ggml_tensor * xk = ggml_add_inplace( + ctx, + ggml_mul(ctx, x0, layer.ffn_time_mix_k), + ggml_mul(ctx, state.ffn_xx, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) + ); + + // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) + struct ggml_tensor * xr = ggml_add_inplace( + ctx, + ggml_mul(ctx, x0, layer.ffn_time_mix_r), + ggml_mul(ctx, state.ffn_xx, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) + ); + + // state[5 * i + 0] = x + state.ffn_xx = x0; + + // r = torch.sigmoid(rw @ xr) + struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); + + // k = torch.square(torch.relu(kw @ xk)) + struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); + + // r * (vw @ k) + return ggml_add_inplace(ctx, x, ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k))); +} + +struct rwkv_ctx_size rwkv_single_graph_size(const size_t n_vocab = 0, const size_t n_embed = 0, const size_t n_layer = 0, const size_t ffn_key = 0) { + size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); + + struct rwkv_ctx_size ctx_size; + + /* state */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_layer * 5 * n_embed); + /* token */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, 1); + /* x */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); + /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); + + /* ffn_xx */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); + /* att_xx */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); + /* att_aa */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); + /* att_bb */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); + /* att_pp */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); + + /* att */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_single_att_size(n_embed)); + /* ffn */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_single_ffn_size(n_embed, ffn_key)); + + /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); + /* logits */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_vocab); + + return ctx_size; +} + +bool rwkv_single_graph(struct ggml_context * ctx, struct rwkv_model & model, const uint32_t n_threads, struct rwkv_graph & out) { + std::unique_ptr cgraph(new(std::nothrow) struct ggml_cgraph()); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, cgraph.get(), "Failed to allocate graph"); + cgraph->n_threads = n_threads; + + size_t n_embed = model.header.n_embed; + size_t n_layer = model.header.n_layer; + + struct ggml_tensor * input_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed); + size_t output_part_size = n_embed * sizeof(float); + + // We collect parts of input state here. Each part is (n_embed) vector. + std::unique_ptr input_layers(new(std::nothrow) struct rwkv_layer_state [n_layer]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, input_layers.get(), "Failed to allocate input state parts"); + + // We collect parts of output state here. Each part is (n_embed) vector. + std::unique_ptr output_layers(new(std::nothrow) struct rwkv_layer_state [n_layer]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, output_layers.get(), "Failed to allocate output state parts"); + + // x = self.w.emb.weight[token] + struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); + struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, token_index); + + // x = self.layer_norm(x, self.w.blocks[0].ln0) + x = rwkv_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias); + + for (size_t i = 0; i < n_layer; i++) { + struct rwkv_layer & layer = model.layers[i]; + struct rwkv_layer_state & input_layer = input_layers[i]; + struct rwkv_layer_state & output_layer = output_layers[i]; + + size_t state_index = i * 5; + input_layer.ffn_xx = ggml_view_1d(ctx, input_state, n_embed, output_part_size * (state_index + 0)); + input_layer.att_xx = ggml_view_1d(ctx, input_state, n_embed, output_part_size * (state_index + 1)); + input_layer.att_aa = ggml_view_1d(ctx, input_state, n_embed, output_part_size * (state_index + 2)); + input_layer.att_bb = ggml_view_1d(ctx, input_state, n_embed, output_part_size * (state_index + 3)); + input_layer.att_pp = ggml_view_1d(ctx, input_state, n_embed, output_part_size * (state_index + 4)); + output_layer = input_layer; + + x = rwkv_single_att(ctx, x, layer, output_layer); + x = rwkv_single_ffn(ctx, x, layer, output_layer); + } + + // x = self.layer_norm(x, self.w.ln_out) + x = rwkv_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias); + + // x = (self.w.head.weight @ x).float() + struct ggml_tensor * logits = ggml_mul_mat(ctx, model.head, x); + + ggml_build_forward_expand(cgraph.get(), logits); + + for (uint32_t i = 0; i < n_layer; i++) { + struct rwkv_layer_state & output_layer = output_layers[i]; + ggml_build_forward_expand(cgraph.get(), output_layer.ffn_xx); + ggml_build_forward_expand(cgraph.get(), output_layer.att_xx); + ggml_build_forward_expand(cgraph.get(), output_layer.att_aa); + ggml_build_forward_expand(cgraph.get(), output_layer.att_bb); + ggml_build_forward_expand(cgraph.get(), output_layer.att_pp); + } + + out.input_state = input_state; + out.input_layers = std::move(input_layers); + out.output_layers = std::move(output_layers); + out.token_index = token_index; + out.logits = logits; + out.cgraph = std::move(cgraph); + return true; +} + +struct rwkv_file_guard { + FILE * file; + ~rwkv_file_guard() { if (file) { fclose(file); } } +}; + +struct rwkv_ggml_guard { + struct ggml_context * ctx; + ~rwkv_ggml_guard() { if (ctx) { ggml_free(ctx); } } }; void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) { @@ -298,179 +881,7 @@ enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) { return value; } -bool rwkv_build_graph(struct ggml_context * ctx, struct rwkv_model * model, const uint32_t n_threads, struct rwkv_graph * out) { - std::unique_ptr cgraph(new(std::nothrow) struct ggml_cgraph()); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, cgraph.get(), "Failed to allocate graph"); - cgraph->n_threads = n_threads; - - size_t n_embed = model->n_embed, n_layer = model->n_layer; - struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed); - - // We collect parts of new state here. Each part is (n_embed) vector. - std::unique_ptr state_parts(new(std::nothrow) ggml_tensor * [n_layer * 5]); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, state_parts.get(), "Failed to allocate state parts"); - - // x = self.w.emb.weight[token] - struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); - struct ggml_tensor * x = ggml_get_rows(ctx, model->emb, token_index); - - // x = self.layer_norm(x, self.w.blocks[0].ln0) - x = rwkv_layer_norm(ctx, x, model->ln0_weight, model->ln0_bias); - - for (size_t i = 0; i < n_layer; i++) { - struct rwkv_layer layer = model->layers[i]; - size_t part_index = i * 5; - size_t state_part_size = n_embed * sizeof(float); - - // RWKV/time mixing - { - // self.layer_norm(x, self.w.blocks[i].ln1) - struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); - - // x0 = state[5 * i + 1] - struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (part_index + 1) * state_part_size); - // aa = state[5 * i + 2] - struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (part_index + 2) * state_part_size); - // bb = state[5 * i + 3] - struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (part_index + 3) * state_part_size); - // pp = state[5 * i + 4] - struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (part_index + 4) * state_part_size); - - // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) - struct ggml_tensor * xk = ggml_add_inplace(ctx, - ggml_mul(ctx, x0, layer.att_time_mix_k), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) - ); - - // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) - struct ggml_tensor * xv = ggml_add_inplace(ctx, - ggml_mul(ctx, x0, layer.att_time_mix_v), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) - ); - - // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) - struct ggml_tensor * xr = ggml_add_inplace(ctx, - ggml_mul(ctx, x0, layer.att_time_mix_r), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) - ); - - // r = torch.sigmoid(rw @ xr) - struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr)); - // k = kw @ xk - struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk); - // v = vw @ xv - struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv); - - // ww = time_first + k - struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k); - // qq = torch.maximum(pp, ww) - struct ggml_tensor * qq = rwkv_max(ctx, pp, ww); - // e1 = torch.exp(pp - qq) - struct ggml_tensor * e1 = rwkv_exp(ctx, ggml_sub(ctx, pp, qq)); - // e2 = torch.exp(ww - qq) - struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); - - // a = e1 * aa + e2 * v - struct ggml_tensor * a = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v)); - // b = e1 * bb + e2 - struct ggml_tensor * b = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2); - - // ww = pp + time_decay - ww = ggml_add_inplace(ctx, pp, layer.att_time_decay); - // qq = torch.maximum(ww, k) - qq = rwkv_max(ctx, ww, k); - // e1 = torch.exp(ww - qq) - e1 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); - // e2 = torch.exp(k - qq) - e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq)); - - // state[5 * i + 1] = x0 - // state[5 * i + 2] = e1 * aa + e2 * v - // state[5 * i + 3] = e1 * bb + e2 - // state[5 * i + 4] = qq - state_parts[part_index + 1] = x0; - state_parts[part_index + 2] = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v)); - state_parts[part_index + 3] = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2); - state_parts[part_index + 4] = qq; - - // wkv = a / b - struct ggml_tensor * wkv = ggml_div(ctx, a, b); - - // ow @ (r * wkv) - x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv))); - } - - // FFN/channel mixing - { - // self.layer_norm(x, self.w.blocks[i].ln2) - struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias); - - // x_prev = state[5 * i + 0] - struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, part_index * state_part_size); - - // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) - struct ggml_tensor * xk = ggml_add_inplace( - ctx, - ggml_mul(ctx, x0, layer.ffn_time_mix_k), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) - ); - - // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) - struct ggml_tensor * xr = ggml_add_inplace( - ctx, - ggml_mul(ctx, x0, layer.ffn_time_mix_r), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) - ); - - // state[5 * i + 0] = x - state_parts[part_index] = x0; - - // r = torch.sigmoid(rw @ xr) - struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); - - // k = torch.square(torch.relu(kw @ xk)) - struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); - - // r * (vw @ k) - x = ggml_add_inplace(ctx, x, ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k))); - } - } - - // x = self.layer_norm(x, self.w.ln_out) - x = rwkv_layer_norm(ctx, x, model->ln_out_weight, model->ln_out_bias); - - // x = (self.w.head.weight @ x).float() - struct ggml_tensor * logits = ggml_mul_mat(ctx, model->head, x); - - ggml_build_forward_expand(cgraph.get(), logits); - - for (uint32_t i = 0; i < n_layer * 5; i++) { - ggml_build_forward_expand(cgraph.get(), state_parts[i]); - } - - out->state = state; - out->state_parts = std::move(state_parts); - out->token_index = token_index; - out->logits = logits; - out->cgraph = std::move(cgraph); - return true; -} - -struct rwkv_file_guard { - FILE * file; - ~rwkv_file_guard() { if (file) fclose(file); } -}; - -struct rwkv_ggml_guard { - struct ggml_context * ctx; - ~rwkv_ggml_guard() { if (ctx) ggml_free(ctx); } -}; - struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) { - return rwkv_init_from_file(file_path, n_threads, 0); -} - -struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads, const uint32_t n_gpu_layers) { global_last_error = RWKV_ERROR_NONE; FILE * file = fopen(file_path, "rb"); @@ -478,147 +889,107 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t rwkv_file_guard file_guard { file }; // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to the get file length. - struct stat64 file_stat; - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat64(fileno(file), &file_stat) == 0, "Failed to stat file %s", file_path); + struct stat file_stat; + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(file), &file_stat) == 0, "Failed to stat file %s", file_path); - int32_t magic; - RWKV_ASSERT_NULL(RWKV_ERROR_FILE, read_int32(file, &magic, "magic")); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_MAGIC, magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic); + struct rwkv_file_header header; + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE, rwkv_fread_file_header(file, header), "Invalid file header"); - int32_t version; - RWKV_ASSERT_NULL(RWKV_ERROR_FILE, read_int32(file, &version, "version")); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION, version >= RWKV_FILE_VERSION_MIN && version <= RWKV_FILE_VERSION_MAX, "Unsupported file version %d", version); + size_t tensors_start = ftell(file); + struct rwkv_ctx_size ctx_size; + size_t ffn_key = 0; - std::unique_ptr model(new(std::nothrow) struct rwkv_model()); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_ALLOC, model.get(), "Failed to allocate model"); + std::string name; + while ((size_t) ftell(file) < (size_t) file_stat.st_size) { + struct rwkv_tensor_header header; + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(file, header), "Invalid tensor header"); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(file, header.key_length, name), "Failed to read tensor name"); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file, rwkv_tensor_size(header), SEEK_CUR) == 0, "Failed to read tensor data"); + rwkv_ctx_size_add_tensor(ctx_size, 1, 0, header); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_vocab, "n_vocab")); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_embed, "n_embed")); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_layer, "n_layer")); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_int32(file, &model->data_type, "data_type")); + if (ffn_key == 0 && name == "blocks.0.ffn.key.weight") { + ffn_key = header.height; + } + } - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_DATA_TYPE, model->data_type >= 0 && model->data_type < FORMAT_TYPE_COUNT, "Unsupported model data type %d", model->data_type); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, ffn_key, "Model is missing parameter blocks.0.ffn.key.weight"); - const char * unsupported_type_msg = "Models in %s format cannot be loaded anymore because the format was removed.\n" - "You need to quantize the model into another format or use an older version of rwkv.cpp.\n" - "See https://github.com/saharNooby/rwkv.cpp#compatibility for more info"; - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 4, unsupported_type_msg, "Q4_1_O"); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 5, unsupported_type_msg, "Q4_2"); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 6, unsupported_type_msg, "Q4_3"); + rwkv_ctx_size_add(ctx_size, 1, rwkv_single_graph_size(header.n_vocab, header.n_embed, header.n_layer, ffn_key)); + // And finally to end it all off: the graph work tensor + enum ggml_type mul_mat_type = ggml_is_quantized(rwkv_type_to_ggml[header.data_type]) ? GGML_TYPE_Q8_1 : rwkv_type_to_ggml[header.data_type]; + rwkv_ctx_size_add_objects(ctx_size, 1, sizeof(struct ggml_tensor) + rwkv_tensor_size(GGML_TYPE_I8, rwkv_tensor_size(mul_mat_type, ffn_key) * n_threads + 64 * (n_threads - 1))); - RWKV_ASSERT_NULL_MSG( - RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, - !is_quantized_format_type(model->data_type) || version >= RWKV_FILE_VERSION_1, - "The quantized model file was created with an old version of rwkv.cpp and can not be loaded anymore.\n" - "You need to requantize the model or use an older version of rwkv.cpp.\n" - "See https://github.com/saharNooby/rwkv.cpp#compatibility for more info" - ); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file, tensors_start, SEEK_SET) == 0, "Failed to seek in file"); - size_t memory_required = file_stat.st_size + - // Intermediary vectors for calculation; there are around 100 calls to ggml - size_t(100) * model->n_embed * sizeof(float) + - // State, in and out - size_t(2) * 5 * model->n_layer * model->n_embed * sizeof(float) + - // Logits - size_t(model->n_vocab) * sizeof(float) + - // +256 MB just for any overhead - // TODO This is too much for smaller models; need a more proper and robust way of measuring required memory - size_t(256) * 1024 * 1024; + std::unique_ptr scratch(new(std::nothrow) uint8_t [ctx_size.scratch_size]); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, scratch.get(), "Failed to allocate model scratch space"); - struct ggml_context * ctx = ggml_init({ memory_required, NULL, false }); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_ALLOC, ctx, "Failed to allocate GGML context"); + struct ggml_context * ctx = ggml_init({ ctx_size.objects_size + ctx_size.objects_count * GGML_OBJECT_SIZE, NULL, false}); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, ctx, "Failed to create GGML context"); rwkv_ggml_guard ggml_guard { ctx }; std::unordered_map parameters; + ggml_set_scratch(ctx, { 0, ctx_size.scratch_size, scratch.get() }); - while (true) { - int32_t dim_count, key_length, data_type; - RWKV_ASSERT_NULL_MSG( - RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ, - fread(&dim_count, sizeof(int32_t), 1, file) == 1 || feof(file), - "Failed to read an int32 value from a file (dim_count)" - ); - - if (feof(file)) { - break; - } - - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, read_int32(file, &key_length, "key_length")); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, read_int32(file, &data_type, "data_type")); - - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, dim_count == 1 || dim_count == 2, "Unsupported dimension count %d", dim_count); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_KEY, key_length > 0, "Non-positive key length %d", key_length); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, data_type >= 0 && data_type < FORMAT_TYPE_COUNT, "Unsupported parameter data type %d", data_type); - - ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type]; - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, ggml_data_type != GGML_TYPE_UNKNOWN, "Unsupported parameter data type %d", data_type); - + while ((size_t) ftell(file) < (size_t) file_stat.st_size) { + std::string name; struct ggml_tensor * tensor; - - if (dim_count == 1) { - int32_t x; - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, read_int32(file, &x, "x"), "Failed to read parameter length"); - tensor = ggml_new_tensor_1d(ctx, ggml_data_type, x); - } else { - int32_t x, y; - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, read_int32(file, &x, "x"), "Failed to read parameter width"); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, read_int32(file, &y, "y"), "Failed to read parameter height"); - tensor = ggml_new_tensor_2d(ctx, ggml_data_type, x, y); - } - - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor"); - - std::string key(key_length, 0); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_KEY, fread(&key[0], key_length, 1, file) == 1, "Failed to read parameter key"); - - size_t nbytes = ggml_nbytes(tensor); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(tensor->data, nbytes, 1, file) == 1, "Failed to read parameter data"); - - parameters[key] = tensor; + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file, ctx, name, tensor), "Failed to read model params"); + parameters[std::move(name)] = tensor; } - file_guard = { NULL }; // close file + file = NULL; + file_guard = { NULL }; - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(¶meters, "emb.weight", &model->emb)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(¶meters, "blocks.0.ln0.weight", &model->ln0_weight)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(¶meters, "blocks.0.ln0.bias", &model->ln0_bias)); + struct rwkv_model model { header }; - model->layers.resize(model->n_layer); + std::unordered_map & parameters_ref = parameters; + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params(model, [&](const char * key, struct ggml_tensor *& dest) { + struct ggml_tensor * tensor = parameters_ref[key]; + RWKV_ENSURE_OR_FALSE_MSG(tensor, "Model parameter %s not found", key); + dest = tensor; + return true; + })); - for (uint32_t i = 0; i < model->n_layer; i++) { - rwkv_layer * layer = &model->layers[i]; - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "ln1.weight", &layer->ln1_weight)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "ln1.bias", &layer->ln1_bias)); + // Verify order of dimensions + struct ggml_tensor * emb = model.emb; + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == header.n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == header.n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.time_mix_k", &layer->att_time_mix_k)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.time_mix_v", &layer->att_time_mix_v)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.time_mix_r", &layer->att_time_mix_r)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.time_first", &layer->att_time_first)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.time_decay", &layer->att_time_decay)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.key.weight", &layer->att_key)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.value.weight", &layer->att_value)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.receptance.weight", &layer->att_receptance)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "att.output.weight", &layer->att_output)); + // Build graph + struct rwkv_graph graph; + RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_single_graph(ctx, model, n_threads, graph)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "ln2.weight", &layer->ln2_weight)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "ln2.bias", &layer->ln2_bias)); + std::unique_ptr rwkv_ctx(new(std::nothrow) struct rwkv_context()); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx.get(), "Failed to allocate context"); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "ffn.time_mix_k", &layer->ffn_time_mix_k)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "ffn.time_mix_r", &layer->ffn_time_mix_r)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "ffn.key.weight", &layer->ffn_key)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "ffn.value.weight", &layer->ffn_value)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(¶meters, i, "ffn.receptance.weight", &layer->ffn_receptance)); - } + // Don't free ggml context + ggml_guard.ctx = NULL; + rwkv_ctx->model = std::move(model); + rwkv_ctx->ctx = ctx; + rwkv_ctx->scratch = std::move(scratch); + rwkv_ctx->graph = std::move(graph); + rwkv_ctx->last_error = RWKV_ERROR_NONE; + rwkv_ctx->print_errors = global_print_errors; + rwkv_ctx->gpu_layers = 0; + rwkv_ctx->vram_total = 0; - int n_gpu = 0; - size_t vram_total = 0; + ggml_set_scratch(ctx, { 0, 0, NULL }); + return rwkv_ctx.release(); +} + +bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers) { #ifdef GGML_USE_CUBLAS { - n_gpu = std::min(n_gpu_layers, model->n_layer); + size_t n_gpu = std::min(n_gpu_layers, ctx->model.header.n_layer); - for (int i = 0; i < n_gpu; ++i) { - const auto & layer = model->layers[i]; + size_t gpu_layers = ((struct rwkv_context *) ctx)->gpu_layers; + size_t vram_total = ((struct rwkv_context *) ctx)->vram_total; + + for (size_t i = 0; i < n_gpu; i++) { + const struct rwkv_layer & layer = ctx->model.layers[i]; // Use cuBLAS only for heavy matrices; other operations are not supported for GPU at the moment ggml_cuda_transform_tensor(layer.att_key); vram_total += ggml_nbytes(layer.att_key); @@ -629,326 +1000,229 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t ggml_cuda_transform_tensor(layer.ffn_key); vram_total += ggml_nbytes(layer.ffn_key); ggml_cuda_transform_tensor(layer.ffn_value); vram_total += ggml_nbytes(layer.ffn_value); ggml_cuda_transform_tensor(layer.ffn_receptance); vram_total += ggml_nbytes(layer.ffn_receptance); + + gpu_layers++; } } #endif - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(¶meters, "ln_out.weight", &model->ln_out_weight)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(¶meters, "ln_out.bias", &model->ln_out_bias)); - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(¶meters, "head.weight", &model->head)); - - // Verify order of dimensions - struct ggml_tensor * emb = model->emb; - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]); - - // Build graph - struct rwkv_graph graph; - RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_graph(ctx, model.get(), n_threads, &graph)); - - std::unique_ptr rwkv_ctx(new(std::nothrow) struct rwkv_context()); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx.get(), "Failed to allocate context"); - rwkv_ctx->model = std::move(model); - rwkv_ctx->ctx = ctx; - rwkv_ctx->graph = std::move(graph); - rwkv_ctx->last_error = RWKV_ERROR_NONE; - rwkv_ctx->print_errors = global_print_errors; - rwkv_ctx->gpu_layers = n_gpu; - rwkv_ctx->vram_total = vram_total; - // Don't free ggml context - ggml_guard.ctx = NULL; - return rwkv_ctx.release(); -} - -uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) { - return ctx->model->n_layer * 5 * ctx->model->n_embed; -} - -uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) { - return ctx->model->n_vocab; + return true; } bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) { ((struct rwkv_context *) ctx)->last_error = RWKV_ERROR_NONE; - RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, state_out != NULL, "state_out is NULL"); - RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, logits_out != NULL, "logits_out is NULL"); - RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < ctx->model->n_vocab, "Token is out of range 0..%d", ctx->model->n_vocab - 1); + const struct rwkv_file_header & header = ctx->model.header; + RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < header.n_vocab, "Token is out of range 0..%d", header.n_vocab - 1); - const struct rwkv_graph * graph = &ctx->graph; - size_t n_layer = ctx->model->n_layer; - size_t n_embed = ctx->model->n_embed; - - ggml_set_i32_1d(graph->token_index, 0, token); + const struct rwkv_graph & graph = ctx->graph; + ggml_set_i32_1d(graph.token_index, 0, token); if (state_in == NULL) { - ggml_set_f32(graph->state, 0.0F); - - for (size_t i = 0; i < n_layer; i++) { - // state[5 * i + 4] = -1e30 - ggml_set_f32( - ggml_view_1d(ctx->ctx, graph->state, n_embed, (5 * i + 4) * n_embed * sizeof(float)), - -1e30F - ); + for (size_t i = 0; i < header.n_layer; i++) { + struct rwkv_layer_state & layer = graph.input_layers[i]; + ggml_set_f32(layer.ffn_xx, 0.0F); + ggml_set_f32(layer.att_xx, 0.0F); + ggml_set_f32(layer.att_aa, 0.0F); + ggml_set_f32(layer.att_bb, 0.0F); + ggml_set_f32(layer.att_pp, -1e30F); } } else { - memcpy(graph->state->data, state_in, graph->state->ne[0] * sizeof(float)); + memcpy(graph.input_state->data, state_in, ggml_nbytes(graph.input_state)); } - ggml_graph_compute(ctx->ctx, graph->cgraph.get()); + ggml_graph_compute(ctx->ctx, graph.cgraph.get()); - for (size_t i = 0; i < n_layer * 5; i++) { - struct ggml_tensor * part = graph->state_parts[i]; - memcpy(state_out + i * n_embed, part->data, part->ne[0] * sizeof(float)); + if (state_out) { + size_t part_size = rwkv_tensor_size(GGML_TYPE_F32, header.n_embed); + for (size_t i = 0; i < header.n_layer; i++) { + struct rwkv_layer_state & layer = graph.output_layers[i]; + + float * dest = state_out + i * header.n_embed * 5; + memcpy(dest + header.n_embed * 0, layer.ffn_xx->data, part_size); + memcpy(dest + header.n_embed * 1, layer.att_xx->data, part_size); + memcpy(dest + header.n_embed * 2, layer.att_aa->data, part_size); + memcpy(dest + header.n_embed * 3, layer.att_bb->data, part_size); + memcpy(dest + header.n_embed * 4, layer.att_pp->data, part_size); + } } - memcpy(logits_out, graph->logits->data, graph->logits->ne[0] * sizeof(float)); + if (logits_out) { + memcpy(logits_out, graph.logits->data, ggml_nbytes(graph.logits)); + } return true; } +uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) { + return ctx->model.header.n_layer * 5 * ctx->model.header.n_embed; +} + +uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) { + return ctx->model.header.n_vocab; +} + void rwkv_free(struct rwkv_context * ctx) { std::unique_ptr rwkv_ctx(ctx); ggml_free(ctx->ctx); } -bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name) { +bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const char * type_name) { global_last_error = RWKV_ERROR_NONE; - int32_t format_data_type = format_name_to_format_type(format_name); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, format_data_type != -1, "Unsupported format \"%s\"", format_name); + enum ggml_type out_type = rwkv_type_to_ggml[rwkv_type_from_string(type_name)]; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, ggml_is_quantized(out_type), "Unsupported output data type (%s)", rwkv_type_to_string[rwkv_type_from_ggml[out_type]]); - ggml_type format_ggml_type = FORMAT_TYPE_TO_GGML_TYPE[format_data_type]; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, format_ggml_type != GGML_TYPE_UNKNOWN, "Unsupported format \"%s\"", format_name); + RWKV_MSG("Loading model from '%s'\n", in_path); - // Needed to initialize FP16 lookup table - ggml_free(ggml_init({ 0, NULL, false })); + struct stat in_stat; + FILE * in_file = fopen(in_path, "rb"); + rwkv_file_guard in_guard { in_file }; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, in_file, "Failed to open %s for reading", in_path); - printf("Loading model from '%s'\n", model_file_path_in); + FILE * out_file = fopen(out_path, "wb"); + rwkv_file_guard out_guard { out_file }; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, out_file, "Failed to open %s for writing", out_path); - FILE * file_in = fopen(model_file_path_in, "rb"); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file_in, "Failed to open %s for reading", model_file_path_in); - FILE * file_out = fopen(model_file_path_out, "wb"); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file_out, "Failed to open %s for writing", model_file_path_out); + // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to the get file length. + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(in_file), &in_stat) == 0, "failed to stat file %s", in_path); - rwkv_file_guard file_in_guard { file_in }; - rwkv_file_guard file_out_guard { file_out }; + struct rwkv_file_header in_header; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, rwkv_fread_file_header(in_file, in_header), "Invalid file header"); - // Process header - { - uint32_t magic, version; - int32_t n_vocab, n_embed, n_layer, data_type; + enum ggml_type in_type = rwkv_type_to_ggml[in_header.data_type]; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, in_type == GGML_TYPE_F32 || in_type == GGML_TYPE_F16, "Unsupported input data type (%s); needs to be f32 or f16", rwkv_type_to_string[rwkv_type_from_ggml[in_type]]); - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_uint32(file_in, &magic, "magic")); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_MAGIC, magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic); - - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_uint32(file_in, &version, "version")); - RWKV_ASSERT_FALSE_MSG( - RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION, - version >= RWKV_FILE_VERSION_MIN && version <= RWKV_FILE_VERSION_MAX, - "Unsupported file version %d", - version - ); - - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_vocab, "n_vocab")); - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_embed, "n_embed")); - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_layer, "n_layer")); - - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &data_type, "data_type")); - RWKV_ASSERT_FALSE_MSG( - RWKV_ERROR_FILE | RWKV_ERROR_DATA_TYPE, - is_non_quantized_format_type(data_type), - "Unsupported data type %d, only FP32 and FP16 can be quantized", - data_type - ); - - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_uint32(file_out, magic, "magic")); - // Always write latest version number when saving files - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_uint32(file_out, RWKV_FILE_VERSION_MAX, "version")); - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_vocab, "n_vocab")); - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_embed, "n_embed")); - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_layer, "n_layer")); - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, format_data_type, "data_type")); - } + struct rwkv_file_header out_header = in_header; + out_header.version = RWKV_FILE_VERSION; + out_header.data_type = rwkv_type_from_ggml[out_type]; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, rwkv_fwrite_file_header(out_file, out_header), "Failed to write file header"); // Process parameters - size_t total_size_orig = 0; - size_t total_size_new = 0; + size_t orig_total_size = 0; + size_t new_total_size = 0; - std::vector work; + // Required to init the fp16 tables + // Doesn't crash if ggml_init fails + ggml_free(ggml_init({ 0, NULL, true })); - std::vector data_u8; - std::vector data_f16; - std::vector data_f32; + size_t max_in_size = 0; + size_t max_out_size = 0; + size_t max_key_length = 0; - std::vector hist_all(1 << 4, 0); + while (ftell(in_file) < in_stat.st_size) { + struct rwkv_tensor_header header; + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, rwkv_fread_tensor_header_and_skip(in_file, header)); - while (true) { - int32_t n_dims, key_length, parameter_data_type; - RWKV_ASSERT_FALSE_MSG( - RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ, - fread(&n_dims, sizeof(int32_t), 1, file_in) == 1 || feof(file_in), - "Failed to read an int32 value from a file (n_dims)" - ); + size_t in_size = rwkv_tensor_size(header); - if (feof(file_in)) { - break; + if (in_size > max_in_size) { + max_in_size = in_size; } - RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &key_length, "key_length")); - RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, ¶meter_data_type, "parameter_data_type")); + // f16 type tensors get relocated to out and then converted into f32 at in + if (header.data_type == TYPE_F16) { + if (in_size > max_out_size) { + max_out_size = in_size; + } - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, n_dims == 1 || n_dims == 2, "Unsupported dimension count %d", n_dims); - RWKV_ASSERT_FALSE_MSG( - RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, - parameter_data_type >= 0 && parameter_data_type < FORMAT_TYPE_COUNT, - "Unsupported parameter data type %d", - parameter_data_type - ); + size_t f32_size = rwkv_tensor_size(GGML_TYPE_F32, header.width, header.height); - ggml_type parameter_ggml_type = FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type]; - RWKV_ASSERT_FALSE_MSG( - RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, - parameter_ggml_type != GGML_TYPE_UNKNOWN, - "Unsupported parameter data type %d", - parameter_data_type - ); - - int32_t nelements, x, y; - - if (n_dims == 1) { - RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &x, "x")); - y = 1; - nelements = x; - } else { - RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &x, "x")); - RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &y, "y")); - nelements = x * y; + if (f32_size > max_in_size) { + max_in_size = f32_size; + } } - std::string name(key_length, 0); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_KEY, fread(&name[0], key_length, 1, file_in) == 1, "Failed to read parameter key"); + size_t out_size = rwkv_tensor_size(out_type, header.width, header.height); - printf("%48s - [%5d, %5d], type = %6s ", name.data(), x, y, ggml_type_name(parameter_ggml_type)); - total_size_orig += (size_t) (nelements * ggml_type_sizef(parameter_ggml_type)); + if (out_size > max_out_size) { + max_out_size = out_size; + } + + if (header.key_length > max_key_length) { + max_key_length = header.key_length; + } + } + + rewind(in_file); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(in_file, sizeof(struct rwkv_file_header), SEEK_CUR) == 0); + + // This is a histogram of quantized values. If it shows single 1.0, then all 0.0, something went very wrong! + int64_t hist_all[16] {}; + + std::unique_ptr scratch(new(std::nothrow) uint8_t [max_in_size + max_out_size]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, scratch.get(), "Failed to allocate buffer"); + + uint8_t * in_buf = scratch.get(); + uint8_t * out_buf = in_buf + max_in_size; + + struct rwkv_tensor tensor; + struct rwkv_tensor_header & header = tensor.header; + std::string & name = tensor.name; + uint8_t *& data = tensor.data; + + while (ftell(in_file) < in_stat.st_size) { + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(in_file, header), "Failed to read tensor header"); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(in_file, header.key_length, name), "Failed to read tensor name"); + + const char * name_str = name.c_str(); + RWKV_MSG("%*s - [%5" PRId32 ", %5" PRId32 "], type = %6s ", (int) max_key_length, name_str, header.width, header.height, rwkv_type_to_string[header.data_type]); + + data = header.data_type == TYPE_F16 ? out_buf : in_buf; + size_t orig_size = rwkv_tensor_size(header), new_size = orig_size; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_data(in_file, orig_size, data), "\nFailed to read tensor data of %s", name_str); // Quantize only 2D tensors, except embedding and head matrices. // Embedding and head take not too much space, especially in bigger models; // but they significantly increase perplexity when quantized. - bool quantize = n_dims == 2 && name != "emb.weight" && name != "head.weight"; + if ((header.data_type == TYPE_F32 || header.data_type == TYPE_F16) && header.dim_count == 2 && name != "emb.weight" && name != "head.weight") { + RWKV_MSG("quantizing... "); - if (quantize) { - RWKV_ASSERT_FALSE_MSG( - RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA_TYPE, - parameter_ggml_type == GGML_TYPE_F32 || parameter_data_type == GGML_TYPE_F16, - "Unsupported parameter data type %d, only FP32 and FP16 can be quantized", - parameter_ggml_type - ); + size_t nelements = (size_t) header.width * (size_t) header.height; - data_f32.resize(nelements); - - if (parameter_data_type == GGML_TYPE_F16) { - data_f16.resize(nelements); - RWKV_ASSERT_FALSE_MSG( - RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, - fread(data_f16.data(), nelements * sizeof(ggml_fp16_t), 1, file_in) == 1, - "Failed to read parameter data" - ); - - for (int i = 0; i < nelements; ++i) { - data_f32[i] = ggml_fp16_to_fp32(data_f16[i]); - } - } else { - RWKV_ASSERT_FALSE_MSG( - RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, - fread(data_f32.data(), nelements * sizeof(float), 1, file_in) == 1, - "Failed to read parameter data" - ); + if (header.data_type == TYPE_F16) { + ggml_fp16_to_fp32_row((const ggml_fp16_t *) out_buf, (float *) in_buf, nelements); } - parameter_data_type = format_data_type; - parameter_ggml_type = format_ggml_type; - } else { - const size_t element_size = ggml_type_size(parameter_ggml_type); - data_u8.resize(nelements * element_size); - RWKV_ASSERT_FALSE_MSG( - RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, - fread(data_u8.data(), nelements * element_size, 1, file_in) == 1, - "Failed to read parameter data" - ); - } + int64_t hist_cur[16] {}; + new_size = ggml_quantize_chunk(out_type, (const float *) in_buf, out_buf, 0, nelements, hist_cur); + header.data_type = rwkv_type_from_ggml[out_type]; + data = out_buf; - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_dims, "n_dims")); - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, key_length, "key_length")); - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, parameter_data_type, "parameter_data_type")); - - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, x, "x")); + RWKV_MSG("size = %8.2f MB -> %8.2f MB | hist: ", orig_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); - if (n_dims == 2) { - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, y, "y")); - } - - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_WRITE, fwrite(&name[0], key_length, 1, file_out) == 1, "Failed to write parameter key"); - - if (quantize) { - printf("quantizing... "); - // For quantization - work.resize(nelements); - - // This is a histogram of quantized values. If it shows single 1.0, then all 0.0, something went very wrong! - std::vector hist_cur(1 << 4, 0); - - size_t (*f)(const float * src, void * dst, int n, int k, int64_t * hist) = - format_ggml_type == GGML_TYPE_Q4_0 ? ggml_quantize_q4_0 : - format_ggml_type == GGML_TYPE_Q4_1 ? ggml_quantize_q4_1 : - format_ggml_type == GGML_TYPE_Q5_0 ? ggml_quantize_q5_0 : - format_ggml_type == GGML_TYPE_Q5_1 ? ggml_quantize_q5_1 : - format_ggml_type == GGML_TYPE_Q8_0 ? ggml_quantize_q8_0 : - NULL; - - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_UNSUPPORTED, f, "Unsupported quantization type %d\n", format_ggml_type); - - size_t cur_size = (*f)(data_f32.data(), work.data(), nelements, x, hist_cur.data()); - total_size_new += cur_size; - - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_WRITE, fwrite(work.data(), cur_size, 1, file_out) == 1, "Failed to write parameter data"); - - printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float) / 1024.0 / 1024.0, cur_size / 1024.0 / 1024.0); - - for (int i = 0; i < (int) hist_cur.size(); ++i) { + for (int i = 0; i < 16; i++) { + RWKV_MSG("%5.3f ", hist_cur[i] / (float) nelements); hist_all[i] += hist_cur[i]; } - for (int i = 0; i < (int) hist_cur.size(); ++i) { - printf("%5.3f ", hist_cur[i] / float(nelements)); - } - - printf("\n"); + RWKV_MSG("\n"); } else { - printf("size = %8.3f MB\n", data_u8.size() / 1024.0 / 1024.0); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_WRITE, fwrite(data_u8.data(), data_u8.size(), 1, file_out) == 1, "Failed to write parameter data"); - total_size_new += data_u8.size(); + RWKV_MSG("size = %8.3f MB\n", orig_size / 1024.0 / 1024.0); } + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_tensor(out_file, tensor), "Failed to write tensor %s", name_str); + orig_total_size += orig_size; + new_total_size += orig_size; } - printf("original size = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0); - printf("quantized size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0); - printf("compression ratio = %8.2f\n", 1.0 * total_size_orig / total_size_new); + RWKV_MSG("original size = %8.2f MB\n", orig_total_size / 1024.0 / 1024.0); + RWKV_MSG("quantized size = %8.2f MB\n", new_total_size / 1024.0 / 1024.0); + RWKV_MSG("compression ratio = %8.2f\n", orig_total_size / float(new_total_size)); int64_t sum_all = 0; - for (int i = 0; i < (int) hist_all.size(); ++i) { + for (int i = 0; i < 16; i++) { sum_all += hist_all[i]; } - printf("hist: "); + RWKV_MSG("hist: "); - for (int i = 0; i < (int) hist_all.size(); ++i) { + for (int i = 0; i < 16; ++i) { printf("%5.3f ", hist_all[i] / float(sum_all)); } - printf("\n"); + RWKV_MSG("\n"); return true; } @@ -957,18 +1231,18 @@ const char * rwkv_get_system_info_string(void) { static std::string s; s = ""; - s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; - s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; - s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; - s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; - s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; - s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; - s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; - s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; - s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; - s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; - s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; - s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; + s += "AVX=" + std::to_string(ggml_cpu_has_avx()) + " "; + s += "AVX2=" + std::to_string(ggml_cpu_has_avx2()) + " "; + s += "AVX512=" + std::to_string(ggml_cpu_has_avx512()) + " "; + s += "FMA=" + std::to_string(ggml_cpu_has_fma()) + " "; + s += "NEON=" + std::to_string(ggml_cpu_has_neon()) + " "; + s += "ARM_FMA=" + std::to_string(ggml_cpu_has_arm_fma()) + " "; + s += "F16C=" + std::to_string(ggml_cpu_has_f16c()) + " "; + s += "FP16_VA=" + std::to_string(ggml_cpu_has_fp16_va()) + " "; + s += "WASM_SIMD=" + std::to_string(ggml_cpu_has_wasm_simd()) + " "; + s += "BLAS=" + std::to_string(ggml_cpu_has_blas()) + " "; + s += "SSE3=" + std::to_string(ggml_cpu_has_sse3()) + " "; + s += "VSX=" + std::to_string(ggml_cpu_has_vsx()); return s.c_str(); } diff --git a/rwkv.h b/rwkv.h index ff50447..5e8c756 100644 --- a/rwkv.h +++ b/rwkv.h @@ -83,8 +83,11 @@ extern "C" { // Returns NULL on any error. Error messages would be printed to stderr. // - model_file_path: path to model file in ggml format. // - n_threads: count of threads to use, must be positive. - // - n_gpu_layer: count of layers need to load to gpu (only works when cuBLAS is on) - RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads, const uint32_t n_gpu_layers); + RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads); + + // Offloads specified layers of context onto GPU using cuBLAS, if it is enabled. + // If rwkv.cpp was compiled without cuBLAS support, this function is a no-op. + RWKV_API bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers); // Evaluates the model for a single token. // Returns false on any error. Error messages would be printed to stderr. diff --git a/rwkv/rwkv_cpp_model.py b/rwkv/rwkv_cpp_model.py index c108c8a..dd195ef 100644 --- a/rwkv/rwkv_cpp_model.py +++ b/rwkv/rwkv_cpp_model.py @@ -32,11 +32,14 @@ class RWKVModel: assert os.path.isfile(model_path), f'{model_path} is not a file' assert thread_count > 0, 'Thread count must be positive' - assert gpu_layers_count > 0, 'GPU layers count must be positive' + assert gpu_layers_count >= 0, 'GPU layers count must be >= 0' self._library = shared_library - self._ctx = self._library.rwkv_init_from_file(model_path, thread_count, gpu_layers_count) + self._ctx = self._library.rwkv_init_from_file(model_path, thread_count) + + if gpu_layers_count > 0: + self._library.rwkv_gpu_offload_layers(self._ctx, gpu_layers_count) self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx) self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx) diff --git a/rwkv/rwkv_cpp_shared_library.py b/rwkv/rwkv_cpp_shared_library.py index 56e4afb..0641ec4 100644 --- a/rwkv/rwkv_cpp_shared_library.py +++ b/rwkv/rwkv_cpp_shared_library.py @@ -37,9 +37,12 @@ class RWKVSharedLibrary: self.library = ctypes.cdll.LoadLibrary(shared_library_path) - self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32, ctypes.c_uint32] + self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32] self.library.rwkv_init_from_file.restype = ctypes.c_void_p + self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32] + self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool + self.library.rwkv_eval.argtypes = [ ctypes.c_void_p, # ctx ctypes.c_int32, # token @@ -67,7 +70,7 @@ class RWKVSharedLibrary: self.library.rwkv_get_system_info_string.argtypes = [] self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p - def rwkv_init_from_file(self, model_file_path: str, thread_count: int, gpu_layers_count: int) -> RWKVContext: + def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext: """ Loads the model from a file and prepares it for inference. Throws an exception in case of any error. Error messages would be printed to stderr. @@ -83,11 +86,23 @@ class RWKVSharedLibrary: """ ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), - ctypes.c_uint32(thread_count), - ctypes.c_uint32(gpu_layers_count)) + ctypes.c_uint32(thread_count)) assert ptr is not None, 'rwkv_init_from_file failed, check stderr' return RWKVContext(ptr) + def rwkv_gpu_offload_layers(self, ctx: RWKVContext, gpu_layers_count: int) -> None: + """ + Offloads specified layers of context onto GPU using cuBLAS, if it is enabled. + If rwkv.cpp was compiled without cuBLAS support, this function is a no-op. + + Parameters + ---------- + gpu_layers_count : int + Count of layers to load onto gpu, must be >= 0, only enabled with cuBLAS. + """ + + assert self.library.rwkv_gpu_offload_layers(ctx.ptr, ctypes.c_uint32(gpu_layers_count)), 'rwkv_gpu_offload_layers failed, check stderr' + def rwkv_eval( self, ctx: RWKVContext, diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c index e8085df..286e528 100644 --- a/tests/test_tiny_rwkv.c +++ b/tests/test_tiny_rwkv.c @@ -26,9 +26,12 @@ void test_model(const char * model_path, const float * expected_logits, const float max_diff) { fprintf(stderr, "Testing %s\n", model_path); - struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS, N_GPU_LAYERS); + struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS); enum rwkv_error_flags error = rwkv_get_last_error(NULL); ASSERT(error == 0, "Unexpected error %d", error); +#ifdef GGML_USE_CUBLAS + ASSERT(rwkv_gpu_offload_layers(model, N_GPU_LAYERS), "Unexpected error %d", rwkv_get_last_error(model)); +#endif uint32_t n_vocab = rwkv_get_logits_buffer_element_count(model);