From dea929f8cad90b7cf2f820c5a3d6653cfdd58c4e Mon Sep 17 00:00:00 2001
From: Alex
Date: Sat, 27 May 2023 16:02:24 +0500
Subject: [PATCH] Various improvements & upgrade ggml (#75)

* Use types from typing for better compatibility with older Python versions

* Split last double end of line token as per BlinkDL's suggestion

* Fix MSVC warnings

* Drop Q4_2 support

* Update ggml

* Bump file format version for quantization changes

* Apply suggestions
---
 CODE_STYLE.md                        |  34 +++++++
 FILE_FORMAT.md                       |  19 +++-
 README.md                            |  34 +++++--
 ggml                                 |   2 +-
 rwkv.cpp                             | 137 +++++++++++++++++++++------
 rwkv.h                               |  13 ++-
 rwkv/chat_with_bot.py                |  23 +++--
 rwkv/convert_pytorch_to_ggml.py      |   3 +-
 rwkv/convert_pytorch_to_ggml.test.py |   2 +-
 rwkv/merge_lora_into_ggml.py         |  11 ++-
 rwkv/quantize.py                     |   4 +-
 rwkv/rwkv_cpp_shared_library.py      |   1 -
 tests/test_tiny_rwkv.c               |  24 ++---
 13 files changed, 230 insertions(+), 77 deletions(-)
 create mode 100644 CODE_STYLE.md

diff --git a/CODE_STYLE.md b/CODE_STYLE.md
new file mode 100644
index 0000000..6ad8d90
--- /dev/null
+++ b/CODE_STYLE.md
@@ -0,0 +1,34 @@
+# Code Style
+
+Please follow this code style when contributing to `rwkv.cpp`.
+
+This list is not complete.
+
+## General
+
+Overall, keep new code in the same style as the existing code.
+
+- Keep lines at 180 characters or shorter.
+- Separate logically grouped pieces of code with empty lines.
+- Surround `if`, `for`, `while`, `do` and other similar statements with empty lines.
+- Write documentation for public functions intended for outside use.
+- Place single-line comments on the line before, not right after the code line.
+- Start comments with a capital letter, use correct grammar and punctuation.
+
+## C/C++
+
+- Use 4 spaces for indentation.
+- Use [The One True Brace Style](https://en.wikipedia.org/wiki/Indentation_style#Variant:_1TBS_(OTBS)):
+  - Place braces on the same line as the statement.
+  - Always add braces to `if`, `for`, `while`, `do` and other similar statements.
+
+## Python
+
+- Use 2 spaces for indentation.
+- Specify types for functions and parameters.
+  - For `void` functions, specify `-> None`.
+- Specifying types for variables:
+  - required, if they are global
+  - required, if they are compound (lists, dicts, optionals, etc.)
+  - optional otherwise.
+- Use types from `typing` (`List`, `Dict`) instead of built-in (`list`, `dict`).
diff --git a/FILE_FORMAT.md b/FILE_FORMAT.md
index 3d16d06..4a0f8d1 100644
--- a/FILE_FORMAT.md
+++ b/FILE_FORMAT.md
@@ -11,7 +11,8 @@ RWKVModelFile {
   // All ints and floats are in machine byte order.
   // Magic is "ggml" string bytes.
   int32 magic = 0x67676d66;
-  int32 version = 100;
+  // Can be either 100 or 101. See "File versions" section below for details.
+  int32 version = 101;
   int32 n_vocab;
   int32 n_embed;
   int32 n_layer;
@@ -39,6 +40,20 @@ Parameter {
 }
 ```
 
+## File versions
+
+### `100`
+
+Original version number, chosen so as not to clash with the `llama.cpp` file version number of `1`.
+
+### `101`
+
+Introduced on 2023-05-27, as `ggml` was updated to commit [00b49ec](https://github.com/ggerganov/ggml/commit/00b49ec707d73df0176e21630a6e23c2aa0e938c).
+
+All quantized formats (`QX_Y`) were changed in a backwards-incompatible way: the new version of `ggml` cannot load version `100` quantized models.
+
+`FP32` and `FP16` remain the same.
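+
+For illustration only, here is a minimal Python sketch (hypothetical, not a script shipped with this repository) that reads the header fields described above and applies the same version rules before loading a model:
+
+```python
+import struct
+
+def check_rwkv_header(path: str) -> None:
+    with open(path, 'rb') as f:
+        # Six int32 values in machine byte order: magic, version, n_vocab, n_embed, n_layer, data_type.
+        magic, version, n_vocab, n_embed, n_layer, data_type = struct.unpack('=iiiiii', f.read(6 * 4))
+
+    assert magic == 0x67676d66, 'Invalid magic value'
+    assert 100 <= version <= 101, f'Unsupported file version {version}'
+
+    # Data types 2 and above are quantized formats; a version 100 file holding quantized
+    # parameters uses the old layout and must be requantized from an FP32/FP16 checkpoint.
+    if version == 100 and data_type >= 2:
+        raise ValueError('Version 100 quantized model; requantize it or use an older rwkv.cpp')
+```
+
+The loader in `rwkv.cpp` performs equivalent checks and prints an explanation of what to do when a file fails them.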
+
 ## Data types
 
 - 0: `FP32`
 - 1: `FP16`
 - 2: `Q4_0`
 - 3: `Q4_1`
 - 4: *unused*
-- 5: `Q4_2`
+- 5: *unused*
 - 6: *unused*
 - 7: `Q5_0`
 - 8: `Q5_1`
diff --git a/README.md b/README.md
index b5c09f4..ad60d29 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 
 This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml).
 
-Besides the usual **FP32**, it supports **FP16**, **quantized INT4** and **quantized INT8** inference. This project is **CPU only**.
+Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT8** inference. This project is **CPU only**.
 
 This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrapper](rwkv%2Frwkv_cpp_model.py) for it.
 
@@ -20,7 +20,6 @@ The table below is for reference only. Measurements were made on 4C/8T x86 CPU with AVX2, 4 threads.
 |-----------|-------------------|--------------------|----------------------|
 | `Q4_0`    | 17.507            | *76*               | **1.53**             |
 | `Q4_1`    | 17.187            | **72**             | 1.68                 |
-| `Q4_2`    | 17.060            | 85                 | **1.53**             |
 | `Q5_0`    | 16.194            | 78                 | *1.60*               |
 | `Q5_1`    | 15.851            | 81                 | 1.68                 |
 | `Q8_0`    | *15.652*          | 89                 | 2.13                 |
@@ -105,10 +104,10 @@ python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-802
 
 ```commandline
 # Windows
-python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q4_2.bin Q4_2
+python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q5_1.bin Q5_1
 
 # Linux / MacOS
-python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q4_2.bin Q4_2
+python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q5_1.bin Q5_1
 ```
 
 ### 4. Run the model
@@ -121,20 +120,20 @@ To generate some text, run:
 
 ```commandline
 # Windows
-python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q4_2.bin
+python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q5_1.bin
 
 # Linux / MacOS
-python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q4_2.bin
+python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin
 ```
 
 To chat with a bot, run:
 
 ```commandline
 # Windows
-python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q4_2.bin
+python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q5_1.bin
 
 # Linux / MacOS
-python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q4_2.bin
+python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin
 ```
 
 Edit [generate_completions.py](rwkv%2Fgenerate_completions.py) or [chat_with_bot.py](rwkv%2Fchat_with_bot.py) to change prompts and sampling settings.
@@ -167,3 +166,22 @@ for token in [1, 2, 3]:
 
 model.free()
 ```
+
+## Compatibility
+
+`ggml` moves fast, and can occasionally break compatibility with older file formats.
+
+`rwkv.cpp` will try its best to explain why a model file can't be loaded and what next steps are available to the user.
+
+For reference only, here is a list of the latest versions of `rwkv.cpp` that supported older formats. **No support will be provided for these versions**.
+
+- `Q4_2`, old layout of quantized formats
+  - [commit 3ca9c7f](https://github.com/saharNooby/rwkv.cpp/commit/3ca9c7f7857a4b9f3de616ec938e71249cfb3f3f), [release with prebuilt binaries](https://github.com/saharNooby/rwkv.cpp/releases/tag/master-3ca9c7f)
+- `Q4_3`, `Q4_1_O`
+  - [commit c736ef5](https://github.com/saharNooby/rwkv.cpp/commit/c736ef5411606b529d3a74c139ee111ef1a28bb9), [release with prebuilt binaries](https://github.com/saharNooby/rwkv.cpp/releases/tag/master-1c363e6)
+
+See also [FILE_FORMAT.md](FILE_FORMAT.md) for version numbers of `rwkv.cpp` model files and their changelog.
+
+## Contributing
+
+There is no complete contributor guide yet, but we have [CODE_STYLE.md](CODE_STYLE.md).
diff --git a/ggml b/ggml
index ff6e03c..00b49ec 160000
--- a/ggml
+++ b/ggml
@@ -1 +1 @@
-Subproject commit ff6e03cbcd9bf6e9fa41d49f2495c042efae4dc6
+Subproject commit 00b49ec707d73df0176e21630a6e23c2aa0e938c
diff --git a/rwkv.cpp b/rwkv.cpp
index 27055fc..4268992 100644
--- a/rwkv.cpp
+++ b/rwkv.cpp
@@ -116,17 +116,28 @@ static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[FORMAT_TYPE_COUNT] = {
     GGML_TYPE_Q4_0,
     GGML_TYPE_Q4_1,
     GGML_TYPE_UNKNOWN, // Unused
-    GGML_TYPE_Q4_2,
+    GGML_TYPE_UNKNOWN, // Unused
     GGML_TYPE_UNKNOWN, // Unused
     GGML_TYPE_Q5_0,
     GGML_TYPE_Q5_1,
     GGML_TYPE_Q8_0
 };
 
+static bool is_non_quantized_format_type(int32_t format_type) {
+    return format_type == 0 || format_type == 1;
+}
+
+static bool is_quantized_format_type(int32_t format_type) {
+    return format_type == 2 ||
+        format_type == 3 ||
+        format_type == 7 ||
+        format_type == 8 ||
+        format_type == 9;
+}
+
 static int32_t format_name_to_format_type(const char * format_name) {
     if (strcmp(format_name, "Q4_0") == 0) return 2;
     if (strcmp(format_name, "Q4_1") == 0) return 3;
-    if (strcmp(format_name, "Q4_2") == 0) return 5;
     if (strcmp(format_name, "Q5_0") == 0) return 7;
     if (strcmp(format_name, "Q5_1") == 0) return 8;
     if (strcmp(format_name, "Q8_0") == 0) return 9;
@@ -371,7 +382,6 @@ bool rwkv_build_graph(struct ggml_context * ctx, struct rwkv_model * model, cons
         // state[5 * i + 2] = e1 * aa + e2 * v
         // state[5 * i + 3] = e1 * bb + e2
         // state[5 * i + 4] = qq
-
         state_parts[part_index + 1] = x0;
         state_parts[part_index + 2] = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v));
         state_parts[part_index + 3] = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2);
@@ -428,8 +438,9 @@ bool rwkv_build_graph(struct ggml_context * ctx, struct rwkv_model * model, cons
 
     ggml_build_forward_expand(cgraph.get(), logits);
 
-    for (uint32_t i = 0; i < n_layer * 5; i++)
+    for (uint32_t i = 0; i < n_layer * 5; i++) {
         ggml_build_forward_expand(cgraph.get(), state_parts[i]);
+    }
 
     out->state = state;
     out->state_parts = std::move(state_parts);
@@ -456,6 +467,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file, "Failed to open file %s", file_path);
     rwkv_file_guard file_guard { file };
 
+    // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to get the file length.
     struct stat64 file_stat;
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat64(fileno(file), &file_stat) == 0, "Failed to stat file %s", file_path);
@@ -465,7 +477,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
 
     int32_t version;
     RWKV_ASSERT_NULL(RWKV_ERROR_FILE, read_int32(file, &version, "version"));
-    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION, version == RWKV_FILE_VERSION, "Unsupported file version %d", version);
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION, version >= RWKV_FILE_VERSION_MIN && version <= RWKV_FILE_VERSION_MAX, "Unsupported file version %d", version);
 
     std::unique_ptr<struct rwkv_model> model(new(std::nothrow) struct rwkv_model());
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_ALLOC, model.get(), "Failed to allocate model");
@@ -475,10 +487,22 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
     RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_layer, "n_layer"));
     RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_int32(file, &model->data_type, "data_type"));
 
-    const char * unsupported_msg = "Models in %s format cannot be loaded anymore because the format was removed. You need to quantize the model into another format";
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_DATA_TYPE, model->data_type >= 0 && model->data_type < FORMAT_TYPE_COUNT, "Unsupported model data type %d", model->data_type);
-    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 4, unsupported_msg, "Q4_1_O");
-    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 6, unsupported_msg, "Q4_3");
+
+    const char * unsupported_type_msg = "Models in %s format cannot be loaded anymore because the format was removed.\n"
+        "You need to quantize the model into another format or use an older version of rwkv.cpp.\n"
+        "See https://github.com/saharNooby/rwkv.cpp#compatibility for more info";
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 4, unsupported_type_msg, "Q4_1_O");
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 5, unsupported_type_msg, "Q4_2");
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 6, unsupported_type_msg, "Q4_3");
+
+    RWKV_ASSERT_NULL_MSG(
+        RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED,
+        !is_quantized_format_type(model->data_type) || version >= RWKV_FILE_VERSION_1,
+        "The quantized model file was created with an old version of rwkv.cpp and cannot be loaded anymore.\n"
+        "You need to requantize the model or use an older version of rwkv.cpp.\n"
+        "See https://github.com/saharNooby/rwkv.cpp#compatibility for more info"
+    );
 
     size_t memory_required = file_stat.st_size +
         // Intermediary vectors for calculation; there are around 100 calls to ggml
@@ -499,8 +523,16 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
     while (true) {
         int32_t dim_count, key_length, data_type;
 
-        RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ, fread(&dim_count, sizeof(int32_t), 1, file) == 1 || feof(file), "Failed to read an int32 value from a file (dim_count)");
-        if (feof(file)) break;
+        RWKV_ASSERT_NULL_MSG(
+            RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ,
+            fread(&dim_count, sizeof(int32_t), 1, file) == 1 || feof(file),
+            "Failed to read an int32 value from a file (dim_count)"
+        );
+
+        if (feof(file)) {
+            break;
+        }
+
         RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, read_int32(file, &key_length, "key_length"));
         RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, read_int32(file, &data_type, "data_type"));
@@ -542,6 +574,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
     RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "blocks.0.ln0.bias", &model->ln0_bias));
 
     model->layers.resize(model->n_layer);
+
     for (uint32_t i = 0; i < model->n_layer; i++) {
         rwkv_layer * layer = &model->layers[i];
         RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ln1.weight", &layer->ln1_weight));
@@ -577,9 +610,6 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]);
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]);
 
-    size_t n_embed = model->n_embed;
-    size_t n_layer = model->n_layer;
-
     // Build graph
     struct rwkv_graph graph;
     RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_graph(ctx, model.get(), n_threads, &graph));
@@ -591,7 +621,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
     rwkv_ctx->graph = std::move(graph);
     rwkv_ctx->last_error = RWKV_ERROR_NONE;
     rwkv_ctx->print_errors = global_print_errors;
-    ggml_guard.ctx = NULL; // don't free ggml context
+    // Don't free ggml context
+    ggml_guard.ctx = NULL;
     return rwkv_ctx.release();
 }
@@ -676,16 +707,30 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
     RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_uint32(file_in, &magic, "magic"));
     RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_MAGIC, magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic);
+
     RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_uint32(file_in, &version, "version"));
-    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION, version == RWKV_FILE_VERSION, "Unsupported file version %d", version);
+    RWKV_ASSERT_FALSE_MSG(
+        RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION,
+        version >= RWKV_FILE_VERSION_MIN && version <= RWKV_FILE_VERSION_MAX,
+        "Unsupported file version %d",
+        version
+    );
+
     RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_vocab, "n_vocab"));
     RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_embed, "n_embed"));
     RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_layer, "n_layer"));
+
     RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &data_type, "data_type"));
-    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_DATA_TYPE, data_type == 0 || data_type == 1, "Unsupported data type %d, only FP32 and FP16 can be quantized", data_type);
+    RWKV_ASSERT_FALSE_MSG(
+        RWKV_ERROR_FILE | RWKV_ERROR_DATA_TYPE,
+        is_non_quantized_format_type(data_type),
+        "Unsupported data type %d, only FP32 and FP16 can be quantized",
+        data_type
+    );
 
     RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_uint32(file_out, magic, "magic"));
-    RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_uint32(file_out, version, "version"));
+    // Always write the latest version number when saving files
+    RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_uint32(file_out, RWKV_FILE_VERSION_MAX, "version"));
     RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_vocab, "n_vocab"));
     RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_embed, "n_embed"));
     RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_layer, "n_layer"));
@@ -706,16 +751,34 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
     while (true) {
         int32_t n_dims, key_length, parameter_data_type;
 
-        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ, fread(&n_dims, sizeof(int32_t), 1, file_in) == 1 || feof(file_in), "Failed to read an int32 value from a file (n_dims)");
-        if (feof(file_in)) break;
+        RWKV_ASSERT_FALSE_MSG(
+            RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ,
+            fread(&n_dims, sizeof(int32_t), 1, file_in) == 1 || feof(file_in),
+            "Failed to read an int32 value from a file (n_dims)"
+        );
+
+        if (feof(file_in)) {
+            break;
+        }
+
         RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &key_length, "key_length"));
         RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &parameter_data_type, "parameter_data_type"));
 
         RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, n_dims == 1 || n_dims == 2, "Unsupported dimension count %d", n_dims);
-        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, parameter_data_type >= 0 && parameter_data_type < FORMAT_TYPE_COUNT, "Unsupported parameter data type %d", parameter_data_type);
+        RWKV_ASSERT_FALSE_MSG(
+            RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED,
+            parameter_data_type >= 0 && parameter_data_type < FORMAT_TYPE_COUNT,
+            "Unsupported parameter data type %d",
+            parameter_data_type
+        );
 
         ggml_type parameter_ggml_type = FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type];
-        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, parameter_ggml_type != GGML_TYPE_UNKNOWN, "Unsupported parameter data type %d", parameter_data_type);
+        RWKV_ASSERT_FALSE_MSG(
+            RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED,
+            parameter_ggml_type != GGML_TYPE_UNKNOWN,
+            "Unsupported parameter data type %d",
+            parameter_data_type
+        );
 
         int32_t nelements, x, y;
@@ -752,12 +815,21 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
             if (parameter_data_type == GGML_TYPE_F16) {
                 data_f16.resize(nelements);
-                RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(data_f16.data(), nelements * sizeof(ggml_fp16_t), 1, file_in) == 1, "Failed to read parameter data");
+                RWKV_ASSERT_FALSE_MSG(
+                    RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA,
+                    fread(data_f16.data(), nelements * sizeof(ggml_fp16_t), 1, file_in) == 1,
+                    "Failed to read parameter data"
+                );
 
-                for (int i = 0; i < nelements; ++i)
+                for (int i = 0; i < nelements; ++i) {
                     data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
+                }
             } else {
-                RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(data_f32.data(), nelements * sizeof(float), 1, file_in) == 1, "Failed to read parameter data");
+                RWKV_ASSERT_FALSE_MSG(
+                    RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA,
+                    fread(data_f32.data(), nelements * sizeof(float), 1, file_in) == 1,
+                    "Failed to read parameter data"
+                );
             }
 
             parameter_data_type = format_data_type;
@@ -765,7 +837,11 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
         } else {
             const size_t element_size = ggml_type_size(parameter_ggml_type);
             data_u8.resize(nelements * element_size);
-            RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(data_u8.data(), nelements * element_size, 1, file_in) == 1, "Failed to read parameter data");
+            RWKV_ASSERT_FALSE_MSG(
+                RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA,
+                fread(data_u8.data(), nelements * element_size, 1, file_in) == 1,
+                "Failed to read parameter data"
+            );
         }
 
         RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_dims, "n_dims"));
@@ -774,28 +850,29 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
 
         RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, x, "x"));
 
-        if (n_dims == 2)
+        if (n_dims == 2) {
             RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, y, "y"));
+        }
 
         RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_WRITE, fwrite(&name[0], key_length, 1, file_out) == 1, "Failed to write parameter key");
 
         if (quantize) {
             printf("quantizing... ");
-            work.resize(nelements); // for quantization
+            // For quantization
+            work.resize(nelements);
 
-            // This is a histogramm of some values. If it shows single 1.0, then all 0.0, something went very wrong!
+            // This is a histogram of quantized values. If it shows a single 1.0 and then all 0.0, something went very wrong!
             std::vector<int64_t> hist_cur(1 << 4, 0);
 
             size_t (*f)(const float * src, void * dst, int n, int k, int64_t * hist) =
                 format_ggml_type == GGML_TYPE_Q4_0 ? ggml_quantize_q4_0 :
                 format_ggml_type == GGML_TYPE_Q4_1 ? ggml_quantize_q4_1 :
-                format_ggml_type == GGML_TYPE_Q4_2 ? ggml_quantize_q4_2 :
                 format_ggml_type == GGML_TYPE_Q5_0 ? ggml_quantize_q5_0 :
                 format_ggml_type == GGML_TYPE_Q5_1 ? ggml_quantize_q5_1 :
                 format_ggml_type == GGML_TYPE_Q8_0 ? ggml_quantize_q8_0 :
                 NULL;
 
-            RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_UNSUPPORTED, f, "unsupported quantization type %d\n", format_ggml_type);
+            RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_UNSUPPORTED, f, "Unsupported quantization type %d\n", format_ggml_type);
 
             size_t cur_size = (*f)(data_f32.data(), work.data(), nelements, x, hist_cur.data());
             total_size_new += cur_size;
diff --git a/rwkv.h b/rwkv.h
index b43cf77..539e655 100644
--- a/rwkv.h
+++ b/rwkv.h
@@ -21,7 +21,13 @@
 
 // 'ggmf' in hex.
 #define RWKV_FILE_MAGIC 0x67676d66
-#define RWKV_FILE_VERSION 100
+
+#define RWKV_FILE_VERSION_0 100
+#define RWKV_FILE_VERSION_1 101
+#define RWKV_FILE_VERSION_MIN RWKV_FILE_VERSION_0
+#define RWKV_FILE_VERSION_MAX RWKV_FILE_VERSION_1
+// Default file version is the latest version.
+#define RWKV_FILE_VERSION RWKV_FILE_VERSION_MAX
 
 #ifdef __cplusplus
 extern "C" {
@@ -55,6 +61,8 @@ extern "C" {
         RWKV_ERROR_PARAM_MISSING = 14
     };
 
+    struct rwkv_context;
+
     // Sets whether errors are automatically printed to stderr.
     // If this is set to false, you are responsible for calling rwkv_last_error manually if an operation fails.
     // - ctx: the context to suppress error messages for.
@@ -71,8 +79,6 @@ extern "C" {
     // - ctx: the context to retrieve the error for, or NULL for the global error.
     RWKV_API enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx);
 
-    struct rwkv_context;
-
     // Loads the model from a file and prepares it for inference.
     // Returns NULL on any error. Error messages will be printed to stderr.
     // - model_file_path: path to model file in ggml format.
@@ -104,7 +110,6 @@ extern "C" {
     // Available format names:
     // - Q4_0
     // - Q4_1
-    // - Q4_2
     // - Q5_0
     // - Q5_1
     // - Q8_0
diff --git a/rwkv/chat_with_bot.py b/rwkv/chat_with_bot.py
index 6ae0ce0..3ebef41 100644
--- a/rwkv/chat_with_bot.py
+++ b/rwkv/chat_with_bot.py
@@ -12,7 +12,7 @@ import tokenizers
 import rwkv_cpp_model
 import rwkv_cpp_shared_library
 import json
-from typing import Optional
+from typing import List, Dict, Optional
 
 # ======================================== Script settings ========================================
 
@@ -34,6 +34,7 @@ PRESENCE_PENALTY: float = 0.2
 FREQUENCY_PENALTY: float = 0.2
 
 END_OF_LINE_TOKEN: int = 187
+DOUBLE_END_OF_LINE_TOKEN: int = 535
 END_OF_TEXT_TOKEN: int = 0
 
 # =================================================================================================
@@ -66,11 +67,11 @@ prompt_token_count = len(prompt_tokens)
 
 # =================================================================================================
 
-processed_tokens: list[int] = []
+processed_tokens: List[int] = []
 logits: Optional[torch.Tensor] = None
 state: Optional[torch.Tensor] = None
 
-def process_tokens(_tokens: list[int], new_line_logit_bias: float = 0.0) -> None:
+def process_tokens(_tokens: List[int], new_line_logit_bias: float = 0.0) -> None:
     global processed_tokens, logits, state
 
     processed_tokens += _tokens
@@ -80,7 +81,7 @@ def process_tokens(_tokens: list[int], new_line_logit_bias: float = 0.0) -> None
 
     logits[END_OF_LINE_TOKEN] += new_line_logit_bias
 
-state_by_thread: dict[str, dict] = {}
+state_by_thread: Dict[str, Dict] = {}
 
 def save_thread_state(_thread: str) -> None:
     state_by_thread[_thread] = {
@@ -98,11 +99,19 @@ def load_thread_state(_thread: str) -> None:
     logits = copy.deepcopy(thread_state['logits'])
     state = copy.deepcopy(thread_state['state'])
 
+# The model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end of the text.
+# See https://github.com/BlinkDL/ChatRWKV/pull/110/files
+def split_last_end_of_line(tokens: List[int]) -> List[int]:
+    if len(tokens) > 0 and tokens[-1] == DOUBLE_END_OF_LINE_TOKEN:
+        tokens = tokens[:-1] + [END_OF_LINE_TOKEN, END_OF_LINE_TOKEN]
+
+    return tokens
+
 # =================================================================================================
 
 print(f'Processing {prompt_token_count} prompt tokens, may take a while')
 
-process_tokens(tokenizer.encode(init_prompt).ids)
+process_tokens(split_last_end_of_line(tokenizer.encode(init_prompt).ids))
 
 save_thread_state('chat_init')
 save_thread_state('chat')
@@ -226,8 +235,8 @@ Below is an instruction that describes a task. Write a response that appropriate
     print(f'> {bot}{separator}', end='')
 
     start_index: int = len(processed_tokens)
-    accumulated_tokens: list[int] = []
-    token_counts: dict[int, int] = {}
+    accumulated_tokens: List[int] = []
+    token_counts: Dict[int, int] = {}
 
     for i in range(MAX_GENERATION_LENGTH):
         for n in token_counts:
diff --git a/rwkv/convert_pytorch_to_ggml.py b/rwkv/convert_pytorch_to_ggml.py
index e132fd5..2ea4a48 100644
--- a/rwkv/convert_pytorch_to_ggml.py
+++ b/rwkv/convert_pytorch_to_ggml.py
@@ -38,8 +38,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
         '=iiiiii',
         # Magic: 'ggmf' in hex
         0x67676d66,
-        # llama.cpp uses file versions 1+, let's use 100+ for rwkv.cpp
-        100,
+        101,
         n_vocab,
         n_embed,
         n_layer,
diff --git a/rwkv/convert_pytorch_to_ggml.test.py b/rwkv/convert_pytorch_to_ggml.test.py
index 5578506..9ced1d0 100644
--- a/rwkv/convert_pytorch_to_ggml.test.py
+++ b/rwkv/convert_pytorch_to_ggml.test.py
@@ -21,7 +21,7 @@ def test() -> None:
     expected_bytes: bytes = struct.pack(
         '=iiiiii' + 'iiiii10sffffff' + 'iiii19sf',
         0x67676d66,
-        100,
+        101,
         3,
         2,
         1,
diff --git a/rwkv/merge_lora_into_ggml.py b/rwkv/merge_lora_into_ggml.py
index e7d9523..6f8ce2d 100644
--- a/rwkv/merge_lora_into_ggml.py
+++ b/rwkv/merge_lora_into_ggml.py
@@ -8,6 +8,7 @@ import argparse
 import struct
 import torch
 import numpy as np
+from typing import List, Dict, Tuple
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Merge a PyTorch LoRA checkpoint (.pth) into an rwkv.cpp model file')
@@ -45,16 +46,16 @@ def main() -> None:
 
     print(f'Reading {args.lora_path}')
 
-    lora_state_dict: dict[str, torch.Tensor] = torch.load(args.lora_path, map_location='cpu')
+    lora_state_dict: Dict[str, torch.Tensor] = torch.load(args.lora_path, map_location='cpu')
 
     print(f'Merging')
 
     with open(args.src_path, 'rb') as in_file, open(args.dest_path, 'wb') as out_file:
         # noinspection PyTypeChecker
-        header: tuple[int, int, int, int, int, int] = struct.unpack('=iiiiii', in_file.read(6 * 4))
+        header: Tuple[int, int, int, int, int, int] = struct.unpack('=iiiiii', in_file.read(6 * 4))
 
         assert header[0] == 0x67676d66, 'Invalid magic value'
-        assert header[1] == 100, 'Invalid version number'
+        assert 100 <= header[1] <= 101, 'Invalid version number'
         assert header[5] == 0 or header[5] == 1, 'Only FP32 and FP16 models are supported'
 
         out_file.write(struct.pack('=iiiiii', *header))
@@ -68,9 +69,9 @@ def main() -> None:
             dim_count, key_length, data_type = struct.unpack('=iii', parameter_header_bytes)
 
             # noinspection PyTypeChecker
-            shape: tuple[int] = struct.unpack('=' + 'i' * dim_count, in_file.read(dim_count * 4))
+            shape: Tuple[int, ...] = struct.unpack('=' + 'i' * dim_count, in_file.read(dim_count * 4))
             # ggml order to PyTorch
-            shape: list[int] = [d for d in reversed(shape)]
+            shape: List[int] = [d for d in reversed(shape)]
 
             key: str = in_file.read(key_length).decode('utf-8')
diff --git a/rwkv/quantize.py b/rwkv/quantize.py
index 239e576..305573b 100644
--- a/rwkv/quantize.py
+++ b/rwkv/quantize.py
@@ -1,6 +1,6 @@
 # Quantizes rwkv.cpp model file from FP32 or FP16.
 # Available format names are in rwkv_cpp_shared_library.QUANTIZED_FORMAT_NAMES
-# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-FP32.bin C:\rwkv.cpp-169M-Q4_2.bin Q4_2
+# Usage: python quantize.py C:\rwkv.cpp-169M-FP32.bin C:\rwkv.cpp-169M-Q5_1.bin Q5_1
 
 import argparse
 import rwkv_cpp_shared_library
 
@@ -11,7 +11,7 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16')
     parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
     parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
-    parser.add_argument('format_name', help='Format name, one of ' + ', '.join(format_names), type=str, choices=format_names, default='Q4_2')
+    parser.add_argument('format_name', help='Format name, one of ' + ', '.join(format_names), type=str, choices=format_names, default='Q5_1')
 
     return parser.parse_args()
 
 def main() -> None:
diff --git a/rwkv/rwkv_cpp_shared_library.py b/rwkv/rwkv_cpp_shared_library.py
index 1414710..2004361 100644
--- a/rwkv/rwkv_cpp_shared_library.py
+++ b/rwkv/rwkv_cpp_shared_library.py
@@ -7,7 +7,6 @@ from typing import Optional
 
 QUANTIZED_FORMAT_NAMES = (
     'Q4_0',
     'Q4_1',
-    'Q4_2',
     'Q5_0',
     'Q5_1',
     'Q8_0'
 )
diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c
index 750b1ed..1244877 100644
--- a/tests/test_tiny_rwkv.c
+++ b/tests/test_tiny_rwkv.c
@@ -26,7 +26,9 @@ void test_model(const char * model_path, const float * expected_logits, const fl
     fprintf(stderr, "Testing %s\n", model_path);
 
     struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS);
 
+    enum rwkv_error_flags error = rwkv_get_last_error(NULL);
+    ASSERT(error == 0, "Unexpected error %d", error);
 
     uint32_t n_vocab = rwkv_get_logits_buffer_element_count(model);
@@ -76,14 +78,12 @@ int main(void) {
         -0.160030F,
         -0.370606F,
-        0.661480F,
         -0.170404F,
         0.278034F,
         0.071216F,
         0.154614F,
         -0.372169F,
-        0.658310F,
         -0.170043F,
         0.294953F,
         0.065571F,
@@ -94,31 +94,27 @@ int main(void) {
 
     rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_0.bin", "Q4_0");
     rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_1.bin", "Q4_1");
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_2.bin", "Q4_2");
     rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q5_0.bin", "Q5_0");
     rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q5_1.bin", "Q5_1");
     rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q8_0.bin", "Q8_0");
 
     test_model("tiny-rwkv-660K-FP32-Q4_0.bin", expected_logits, expected_difference_sum[2]);
     test_model("tiny-rwkv-660K-FP32-Q4_1.bin", expected_logits, expected_difference_sum[3]);
-    test_model("tiny-rwkv-660K-FP32-Q4_2.bin", expected_logits, expected_difference_sum[4]);
-    test_model("tiny-rwkv-660K-FP32-Q5_0.bin", expected_logits, expected_difference_sum[5]);
-    test_model("tiny-rwkv-660K-FP32-Q5_1.bin", expected_logits, expected_difference_sum[6]);
-    test_model("tiny-rwkv-660K-FP32-Q8_0.bin", expected_logits, expected_difference_sum[7]);
+    test_model("tiny-rwkv-660K-FP32-Q5_0.bin", expected_logits, expected_difference_sum[4]);
+    test_model("tiny-rwkv-660K-FP32-Q5_1.bin", expected_logits, expected_difference_sum[5]);
+    test_model("tiny-rwkv-660K-FP32-Q8_0.bin", expected_logits, expected_difference_sum[6]);
 
     rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_0.bin", "Q4_0");
     rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_1.bin", "Q4_1");
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_2.bin", "Q4_2");
     rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q5_0.bin", "Q5_0");
     rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q5_1.bin", "Q5_1");
     rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q8_0.bin", "Q8_0");
 
-    test_model("tiny-rwkv-660K-FP16-Q4_0.bin", expected_logits, expected_difference_sum[8]);
-    test_model("tiny-rwkv-660K-FP16-Q4_1.bin", expected_logits, expected_difference_sum[9]);
-    test_model("tiny-rwkv-660K-FP16-Q4_2.bin", expected_logits, expected_difference_sum[10]);
-    test_model("tiny-rwkv-660K-FP16-Q5_0.bin", expected_logits, expected_difference_sum[11]);
-    test_model("tiny-rwkv-660K-FP16-Q5_1.bin", expected_logits, expected_difference_sum[12]);
-    test_model("tiny-rwkv-660K-FP16-Q8_0.bin", expected_logits, expected_difference_sum[13]);
+    test_model("tiny-rwkv-660K-FP16-Q4_0.bin", expected_logits, expected_difference_sum[7]);
+    test_model("tiny-rwkv-660K-FP16-Q4_1.bin", expected_logits, expected_difference_sum[8]);
+    test_model("tiny-rwkv-660K-FP16-Q5_0.bin", expected_logits, expected_difference_sum[9]);
+    test_model("tiny-rwkv-660K-FP16-Q5_1.bin", expected_logits, expected_difference_sum[10]);
+    test_model("tiny-rwkv-660K-FP16-Q8_0.bin", expected_logits, expected_difference_sum[11]);
 
     free(expected_logits);