Various improvements & upgrade ggml (#75)

* Use types from typing for better compatibility with older Python versions

* Split last double end of line token as per BlinkDL's suggestion

* Fix MSVC warnings

* Drop Q4_2 support

* Update ggml

* Bump file format version for quantization changes

* Apply suggestions
Author: Alex, 2023-05-27 16:02:24 +05:00, committed by GitHub
Parent: 3ca9c7f785
Commit: dea929f8ca
13 changed files with 230 additions and 77 deletions

CODE_STYLE.md (new file)

@@ -0,0 +1,34 @@
# Code Style
Please follow this code style when contributing to `rwkv.cpp`.
This list is not complete.
## General
Overall, keep new code in a style similar to the existing code.
- Keep lines at 180 characters or shorter.
- Separate logically grouped pieces of code with empty lines.
- Surround `if`, `for`, `while`, `do` and other similar statements with empty lines.
- Write documentation for public functions intended for outside use.
- Place single-line comments on the line before, not right after the code line.
- Start comments with a capital letter, use correct grammar and punctuation.
## C/C++
- Use 4 spaces for indentation.
- Use [The One True Brace Style](https://en.wikipedia.org/wiki/Indentation_style#Variant:_1TBS_(OTBS)):
- Place braces on the same line as the statement.
- Always add braces to `if`, `for`, `while`, `do` and other similar statements.
## Python
- Use 2 spaces for indentation.
- Specify types for functions and parameters.
- For `void` functions, specify `-> None`.
- Specifying types for variables:
  - required, if they are global
  - required, if they are compound (lists, dicts, optionals, etc.)
  - optional otherwise.
- Use types from `typing` (`List`, `Dict`) instead of built-in (`list`, `dict`).
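
A minimal sketch of these typing conventions (hypothetical names, not code from the repository):

```python
from typing import Dict, List, Optional

# Global variables require explicit types.
GLOBAL_TOKEN_COUNTS: Dict[int, int] = {}

def count_tokens(tokens: List[int], ignored: Optional[List[int]] = None) -> Dict[int, int]:
  # Compound local variables (lists, dicts, optionals) require explicit types.
  counts: Dict[int, int] = {}

  for token in tokens:
    if ignored is not None and token in ignored:
      continue

    counts[token] = counts.get(token, 0) + 1

  return counts
```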

FILE_FORMAT.md

@@ -11,7 +11,8 @@ RWKVModelFile {
// All ints and floats are in machine byte order.
// Magic is "ggmf" string bytes.
int32 magic = 0x67676d66;
-int32 version = 100;
+// Can be either 100 or 101. See "File versions" section below for details.
+int32 version = 101;
int32 n_vocab;
int32 n_embed;
int32 n_layer;
@@ -39,6 +40,20 @@ Parameter {
}
```

+## File versions
+### `100`
+Original version number, chosen so that it does not collide with the `llama.cpp` file version number of `1`.
+### `101`
+Introduced on 2023-05-27, when `ggml` was updated to commit [00b49ec](https://github.com/ggerganov/ggml/commit/00b49ec707d73df0176e21630a6e23c2aa0e938c).
+All quantized formats (`QX_Y`) were changed in a backwards-incompatible way: the new version of `ggml` cannot load version `100` quantized models.
+`FP32` and `FP16` remain the same.

## Data types
- 0: `FP32`
@ -46,7 +61,7 @@ Parameter {
- 2: `Q4_0`
- 3: `Q4_1`
- 4: *unused*
-- 5: `Q4_2`
+- 5: *unused*
- 6: *unused*
- 7: `Q5_0`
- 8: `Q5_1`
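
The header above can be validated with a few lines of Python. This is an illustrative sketch (not a script from the repository), using the same `struct` layout that `merge_lora_into_ggml.py` uses to read the header:

```python
import struct

# Data type IDs from the table above; 4, 5 and 6 are unused as of version 101.
DATA_TYPE_NAMES = {0: 'FP32', 1: 'FP16', 2: 'Q4_0', 3: 'Q4_1', 7: 'Q5_0', 8: 'Q5_1', 9: 'Q8_0'}

def describe_model_header(path: str) -> None:
    # Six little-endian int32 fields: magic, version, n_vocab, n_embed, n_layer, data_type.
    with open(path, 'rb') as model_file:
        magic, version, n_vocab, n_embed, n_layer, data_type = struct.unpack('=iiiiii', model_file.read(6 * 4))

    assert magic == 0x67676d66, f'Invalid magic value {magic:#x}'
    assert 100 <= version <= 101, f'Unsupported file version {version}'

    # Quantized models saved with version 100 can no longer be loaded by current rwkv.cpp.
    if version == 100 and data_type not in (0, 1):
        print('Warning: old quantized format; requantize the model')

    data_type_name = DATA_TYPE_NAMES.get(data_type, 'unknown')
    print(f'version {version}, n_vocab {n_vocab}, n_embed {n_embed}, n_layer {n_layer}, data type {data_type_name}')
```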

README.md

@@ -2,7 +2,7 @@
This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml).

-Besides the usual **FP32**, it supports **FP16**, **quantized INT4** and **quantized INT8** inference. This project is **CPU only**.
+Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT8** inference. This project is **CPU only**.

This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrapper](rwkv%2Frwkv_cpp_model.py) for it.
@@ -20,7 +20,6 @@ Below table is for reference only. Measurements were made on 4C/8T x86 CPU with
|-----------|-------------------|--------------------|----------------------|
| `Q4_0` | 17.507 | *76* | **1.53** |
| `Q4_1` | 17.187 | **72** | 1.68 |
-| `Q4_2` | 17.060 | 85 | **1.53** |
| `Q5_0` | 16.194 | 78 | *1.60* |
| `Q5_1` | 15.851 | 81 | 1.68 |
| `Q8_0` | *15.652* | 89 | 2.13 |
@@ -105,10 +104,10 @@ python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-802
```commandline
# Windows
-python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q4_2.bin Q4_2
+python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q5_1.bin Q5_1

# Linux / MacOS
-python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q4_2.bin Q4_2
+python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q5_1.bin Q5_1
```

### 4. Run the model
@@ -121,20 +120,20 @@ To generate some text, run:
```commandline
# Windows
-python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q4_2.bin
+python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q5_1.bin

# Linux / MacOS
-python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q4_2.bin
+python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin
```

To chat with a bot, run:

```commandline
# Windows
-python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q4_2.bin
+python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q5_1.bin

# Linux / MacOS
-python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q4_2.bin
+python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin
```

Edit [generate_completions.py](rwkv%2Fgenerate_completions.py) or [chat_with_bot.py](rwkv%2Fchat_with_bot.py) to change prompts and sampling settings.
@@ -167,3 +166,22 @@ for token in [1, 2, 3]:
model.free()
```

+## Compatibility
+`ggml` moves fast and can occasionally break compatibility with older file formats.
+`rwkv.cpp` will do its best to explain why a model file can't be loaded and what next steps are available to the user.
+For reference only, below is a list of the last versions of `rwkv.cpp` that supported older formats. **No support will be provided for these versions**.
+- `Q4_2`, old layout of quantized formats
+  - [commit 3ca9c7f](https://github.com/saharNooby/rwkv.cpp/commit/3ca9c7f7857a4b9f3de616ec938e71249cfb3f3f), [release with prebuilt binaries](https://github.com/saharNooby/rwkv.cpp/releases/tag/master-3ca9c7f)
+- `Q4_3`, `Q4_1_O`
+  - [commit c736ef5](https://github.com/saharNooby/rwkv.cpp/commit/c736ef5411606b529d3a74c139ee111ef1a28bb9), [release with prebuilt binaries](https://github.com/saharNooby/rwkv.cpp/releases/tag/master-1c363e6)
+See also [FILE_FORMAT.md](FILE_FORMAT.md) for version numbers of `rwkv.cpp` model files and their changelog.
+## Contributing
+There is no complete contributor guide yet, but we have [CODE_STYLE.md](CODE_STYLE.md).

ggml (submodule)

@@ -1 +1 @@
-Subproject commit ff6e03cbcd9bf6e9fa41d49f2495c042efae4dc6
+Subproject commit 00b49ec707d73df0176e21630a6e23c2aa0e938c

rwkv.cpp

@@ -116,17 +116,28 @@ static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[FORMAT_TYPE_COUNT] = {
    GGML_TYPE_Q4_0,
    GGML_TYPE_Q4_1,
    GGML_TYPE_UNKNOWN, // Unused
-    GGML_TYPE_Q4_2,
+    GGML_TYPE_UNKNOWN, // Unused
    GGML_TYPE_UNKNOWN, // Unused
    GGML_TYPE_Q5_0,
    GGML_TYPE_Q5_1,
    GGML_TYPE_Q8_0
};

+static bool is_non_quantized_format_type(int32_t format_type) {
+    return format_type == 0 || format_type == 1;
+}
+
+static bool is_quantized_format_type(int32_t format_type) {
+    return format_type == 2 ||
+        format_type == 3 ||
+        format_type == 7 ||
+        format_type == 8 ||
+        format_type == 9;
+}
+
static int32_t format_name_to_format_type(const char * format_name) {
    if (strcmp(format_name, "Q4_0") == 0) return 2;
    if (strcmp(format_name, "Q4_1") == 0) return 3;
-    if (strcmp(format_name, "Q4_2") == 0) return 5;
    if (strcmp(format_name, "Q5_0") == 0) return 7;
    if (strcmp(format_name, "Q5_1") == 0) return 8;
    if (strcmp(format_name, "Q8_0") == 0) return 9;
@@ -371,7 +382,6 @@ bool rwkv_build_graph(struct ggml_context * ctx, struct rwkv_model * model, cons
        // state[5 * i + 2] = e1 * aa + e2 * v
        // state[5 * i + 3] = e1 * bb + e2
        // state[5 * i + 4] = qq
        state_parts[part_index + 1] = x0;
        state_parts[part_index + 2] = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v));
        state_parts[part_index + 3] = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2);
@@ -428,8 +438,9 @@ bool rwkv_build_graph(struct ggml_context * ctx, struct rwkv_model * model, cons
    ggml_build_forward_expand(cgraph.get(), logits);

-    for (uint32_t i = 0; i < n_layer * 5; i++)
+    for (uint32_t i = 0; i < n_layer * 5; i++) {
        ggml_build_forward_expand(cgraph.get(), state_parts[i]);
+    }

    out->state = state;
    out->state_parts = std::move(state_parts);
@@ -456,6 +467,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file, "Failed to open file %s", file_path);
    rwkv_file_guard file_guard { file };

+    // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to get the file length.
    struct stat64 file_stat;
    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat64(fileno(file), &file_stat) == 0, "Failed to stat file %s", file_path);
@@ -465,7 +477,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
    int32_t version;
    RWKV_ASSERT_NULL(RWKV_ERROR_FILE, read_int32(file, &version, "version"));
-    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION, version == RWKV_FILE_VERSION, "Unsupported file version %d", version);
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION, version >= RWKV_FILE_VERSION_MIN && version <= RWKV_FILE_VERSION_MAX, "Unsupported file version %d", version);

    std::unique_ptr<rwkv_model> model(new(std::nothrow) struct rwkv_model());
    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_ALLOC, model.get(), "Failed to allocate model");
@@ -475,10 +487,22 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_layer, "n_layer"));
    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_int32(file, &model->data_type, "data_type"));

-    const char * unsupported_msg = "Models in %s format cannot be loaded anymore because the format was removed. You need to quantize the model into another format";
    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_DATA_TYPE, model->data_type >= 0 && model->data_type < FORMAT_TYPE_COUNT, "Unsupported model data type %d", model->data_type);
-    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 4, unsupported_msg, "Q4_1_O");
-    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 6, unsupported_msg, "Q4_3");
+
+    const char * unsupported_type_msg = "Models in %s format cannot be loaded anymore because the format was removed.\n"
+        "You need to quantize the model into another format or use an older version of rwkv.cpp.\n"
+        "See https://github.com/saharNooby/rwkv.cpp#compatibility for more info";
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 4, unsupported_type_msg, "Q4_1_O");
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 5, unsupported_type_msg, "Q4_2");
+    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 6, unsupported_type_msg, "Q4_3");
+
+    RWKV_ASSERT_NULL_MSG(
+        RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED,
+        !is_quantized_format_type(model->data_type) || version >= RWKV_FILE_VERSION_1,
+        "The quantized model file was created with an old version of rwkv.cpp and can not be loaded anymore.\n"
+        "You need to requantize the model or use an older version of rwkv.cpp.\n"
+        "See https://github.com/saharNooby/rwkv.cpp#compatibility for more info"
+    );

    size_t memory_required = file_stat.st_size +
        // Intermediary vectors for calculation; there are around 100 calls to ggml
@@ -499,8 +523,16 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
    while (true) {
        int32_t dim_count, key_length, data_type;

-        RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ, fread(&dim_count, sizeof(int32_t), 1, file) == 1 || feof(file), "Failed to read an int32 value from a file (dim_count)");
-        if (feof(file)) break;
+        RWKV_ASSERT_NULL_MSG(
+            RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ,
+            fread(&dim_count, sizeof(int32_t), 1, file) == 1 || feof(file),
+            "Failed to read an int32 value from a file (dim_count)"
+        );
+
+        if (feof(file)) {
+            break;
+        }

        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, read_int32(file, &key_length, "key_length"));
        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, read_int32(file, &data_type, "data_type"));
@@ -542,6 +574,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_parameter(&parameters, "blocks.0.ln0.bias", &model->ln0_bias));

    model->layers.resize(model->n_layer);

    for (uint32_t i = 0; i < model->n_layer; i++) {
        rwkv_layer * layer = &model->layers[i];
        RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS, set_block_parameter(&parameters, i, "ln1.weight", &layer->ln1_weight));
@@ -577,9 +610,6 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]);
    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]);

-    size_t n_embed = model->n_embed;
-    size_t n_layer = model->n_layer;

    // Build graph
    struct rwkv_graph graph;
    RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_graph(ctx, model.get(), n_threads, &graph));
@@ -591,7 +621,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
    rwkv_ctx->graph = std::move(graph);
    rwkv_ctx->last_error = RWKV_ERROR_NONE;
    rwkv_ctx->print_errors = global_print_errors;

-    ggml_guard.ctx = NULL; // don't free ggml context
+    // Don't free ggml context
+    ggml_guard.ctx = NULL;

    return rwkv_ctx.release();
}
@@ -676,16 +707,30 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
    RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_uint32(file_in, &magic, "magic"));
    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_MAGIC, magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic);

    RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_uint32(file_in, &version, "version"));
-    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION, version == RWKV_FILE_VERSION, "Unsupported file version %d", version);
+    RWKV_ASSERT_FALSE_MSG(
+        RWKV_ERROR_FILE | RWKV_ERROR_FILE_VERSION,
+        version >= RWKV_FILE_VERSION_MIN && version <= RWKV_FILE_VERSION_MAX,
+        "Unsupported file version %d",
+        version
+    );

    RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_vocab, "n_vocab"));
    RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_embed, "n_embed"));
    RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &n_layer, "n_layer"));

    RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, read_int32(file_in, &data_type, "data_type"));
-    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_DATA_TYPE, data_type == 0 || data_type == 1, "Unsupported data type %d, only FP32 and FP16 can be quantized", data_type);
+    RWKV_ASSERT_FALSE_MSG(
+        RWKV_ERROR_FILE | RWKV_ERROR_DATA_TYPE,
+        is_non_quantized_format_type(data_type),
+        "Unsupported data type %d, only FP32 and FP16 can be quantized",
+        data_type
+    );

    RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_uint32(file_out, magic, "magic"));
-    RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_uint32(file_out, version, "version"));
+    // Always write latest version number when saving files
+    RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_uint32(file_out, RWKV_FILE_VERSION_MAX, "version"));
    RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_vocab, "n_vocab"));
    RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_embed, "n_embed"));
    RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_layer, "n_layer"));
@@ -706,16 +751,34 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
    while (true) {
        int32_t n_dims, key_length, parameter_data_type;

-        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ, fread(&n_dims, sizeof(int32_t), 1, file_in) == 1 || feof(file_in), "Failed to read an int32 value from a file (n_dims)");
-        if (feof(file_in)) break;
+        RWKV_ASSERT_FALSE_MSG(
+            RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_FILE_READ,
+            fread(&n_dims, sizeof(int32_t), 1, file_in) == 1 || feof(file_in),
+            "Failed to read an int32 value from a file (n_dims)"
+        );
+
+        if (feof(file_in)) {
+            break;
+        }

        RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &key_length, "key_length"));
        RWKV_ASSERT_FALSE(RWKV_ERROR_MODEL_PARAMS, read_int32(file_in, &parameter_data_type, "parameter_data_type"));

        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, n_dims == 1 || n_dims == 2, "Unsupported dimension count %d", n_dims);
-        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, parameter_data_type >= 0 && parameter_data_type < FORMAT_TYPE_COUNT, "Unsupported parameter data type %d", parameter_data_type);
+        RWKV_ASSERT_FALSE_MSG(
+            RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED,
+            parameter_data_type >= 0 && parameter_data_type < FORMAT_TYPE_COUNT,
+            "Unsupported parameter data type %d",
+            parameter_data_type
+        );

        ggml_type parameter_ggml_type = FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type];
-        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED, parameter_ggml_type != GGML_TYPE_UNKNOWN, "Unsupported parameter data type %d", parameter_data_type);
+        RWKV_ASSERT_FALSE_MSG(
+            RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_UNSUPPORTED,
+            parameter_ggml_type != GGML_TYPE_UNKNOWN,
+            "Unsupported parameter data type %d",
+            parameter_data_type
+        );

        int32_t nelements, x, y;
@@ -752,12 +815,21 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
        if (parameter_data_type == GGML_TYPE_F16) {
            data_f16.resize(nelements);
-            RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(data_f16.data(), nelements * sizeof(ggml_fp16_t), 1, file_in) == 1, "Failed to read parameter data");
+            RWKV_ASSERT_FALSE_MSG(
+                RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA,
+                fread(data_f16.data(), nelements * sizeof(ggml_fp16_t), 1, file_in) == 1,
+                "Failed to read parameter data"
+            );

-            for (int i = 0; i < nelements; ++i)
+            for (int i = 0; i < nelements; ++i) {
                data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
+            }
        } else {
-            RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(data_f32.data(), nelements * sizeof(float), 1, file_in) == 1, "Failed to read parameter data");
+            RWKV_ASSERT_FALSE_MSG(
+                RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA,
+                fread(data_f32.data(), nelements * sizeof(float), 1, file_in) == 1,
+                "Failed to read parameter data"
+            );
        }

        parameter_data_type = format_data_type;
@@ -765,7 +837,11 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
        } else {
            const size_t element_size = ggml_type_size(parameter_ggml_type);
            data_u8.resize(nelements * element_size);
-            RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA, fread(data_u8.data(), nelements * element_size, 1, file_in) == 1, "Failed to read parameter data");
+            RWKV_ASSERT_FALSE_MSG(
+                RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DATA,
+                fread(data_u8.data(), nelements * element_size, 1, file_in) == 1,
+                "Failed to read parameter data"
+            );
        }

        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, n_dims, "n_dims"));
@@ -774,28 +850,29 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, x, "x"));

-        if (n_dims == 2)
+        if (n_dims == 2) {
            RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, write_int32(file_out, y, "y"));
+        }

        RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_WRITE, fwrite(&name[0], key_length, 1, file_out) == 1, "Failed to write parameter key");

        if (quantize) {
            printf("quantizing... ");

-            work.resize(nelements); // for quantization
+            // For quantization
+            work.resize(nelements);

-            // This is a histogramm of some values. If it shows single 1.0, then all 0.0, something went very wrong!
+            // This is a histogram of quantized values. If it shows single 1.0, then all 0.0, something went very wrong!
            std::vector<int64_t> hist_cur(1 << 4, 0);

            size_t (*f)(const float * src, void * dst, int n, int k, int64_t * hist) =
                format_ggml_type == GGML_TYPE_Q4_0 ? ggml_quantize_q4_0 :
                format_ggml_type == GGML_TYPE_Q4_1 ? ggml_quantize_q4_1 :
-                format_ggml_type == GGML_TYPE_Q4_2 ? ggml_quantize_q4_2 :
                format_ggml_type == GGML_TYPE_Q5_0 ? ggml_quantize_q5_0 :
                format_ggml_type == GGML_TYPE_Q5_1 ? ggml_quantize_q5_1 :
                format_ggml_type == GGML_TYPE_Q8_0 ? ggml_quantize_q8_0 :
                NULL;

-            RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_UNSUPPORTED, f, "unsupported quantization type %d\n", format_ggml_type);
+            RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_UNSUPPORTED, f, "Unsupported quantization type %d\n", format_ggml_type);

            size_t cur_size = (*f)(data_f32.data(), work.data(), nelements, x, hist_cur.data());
            total_size_new += cur_size;

rwkv.h

@@ -21,7 +21,13 @@
// 'ggmf' in hex.
#define RWKV_FILE_MAGIC 0x67676d66

-#define RWKV_FILE_VERSION 100
+#define RWKV_FILE_VERSION_0 100
+#define RWKV_FILE_VERSION_1 101
+#define RWKV_FILE_VERSION_MIN RWKV_FILE_VERSION_0
+#define RWKV_FILE_VERSION_MAX RWKV_FILE_VERSION_1
+// Default file version is the latest version.
+#define RWKV_FILE_VERSION RWKV_FILE_VERSION_MAX

#ifdef __cplusplus
extern "C" {
@@ -55,6 +61,8 @@ extern "C" {
    RWKV_ERROR_PARAM_MISSING = 14
};

+struct rwkv_context;
+
// Sets whether errors are automatically printed to stderr.
// If this is set to false, you are responsible for calling rwkv_last_error manually if an operation fails.
// - ctx: the context to suppress error messages for.
@@ -71,8 +79,6 @@ extern "C" {
// - ctx: the context to retrieve the error for, or NULL for the global error.
RWKV_API enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx);

-struct rwkv_context;

// Loads the model from a file and prepares it for inference.
// Returns NULL on any error. Error messages will be printed to stderr.
// - model_file_path: path to model file in ggml format.
@@ -104,7 +110,6 @@ extern "C" {
// Available format names:
// - Q4_0
// - Q4_1
-// - Q4_2
// - Q5_0
// - Q5_1
// - Q8_0

rwkv/chat_with_bot.py

@@ -12,7 +12,7 @@ import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library
import json
-from typing import Optional
+from typing import List, Dict, Optional

# ======================================== Script settings ========================================
@@ -34,6 +34,7 @@ PRESENCE_PENALTY: float = 0.2
FREQUENCY_PENALTY: float = 0.2

END_OF_LINE_TOKEN: int = 187
+DOUBLE_END_OF_LINE_TOKEN: int = 535
END_OF_TEXT_TOKEN: int = 0

# =================================================================================================
@@ -66,11 +67,11 @@ prompt_token_count = len(prompt_tokens)
# =================================================================================================

-processed_tokens: list[int] = []
+processed_tokens: List[int] = []
logits: Optional[torch.Tensor] = None
state: Optional[torch.Tensor] = None

-def process_tokens(_tokens: list[int], new_line_logit_bias: float = 0.0) -> None:
+def process_tokens(_tokens: List[int], new_line_logit_bias: float = 0.0) -> None:
    global processed_tokens, logits, state

    processed_tokens += _tokens
@@ -80,7 +81,7 @@ def process_tokens(_tokens: list[int], new_line_logit_bias: float = 0.0) -> None
        logits[END_OF_LINE_TOKEN] += new_line_logit_bias

-state_by_thread: dict[str, dict] = {}
+state_by_thread: Dict[str, Dict] = {}

def save_thread_state(_thread: str) -> None:
    state_by_thread[_thread] = {
@@ -98,11 +99,19 @@ def load_thread_state(_thread: str) -> None:
    logits = copy.deepcopy(thread_state['logits'])
    state = copy.deepcopy(thread_state['state'])

+# Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end.
+# See https://github.com/BlinkDL/ChatRWKV/pull/110/files
+def split_last_end_of_line(tokens: List[int]) -> List[int]:
+    if len(tokens) > 0 and tokens[-1] == DOUBLE_END_OF_LINE_TOKEN:
+        tokens = tokens[:-1] + [END_OF_LINE_TOKEN, END_OF_LINE_TOKEN]
+
+    return tokens
+
# =================================================================================================

print(f'Processing {prompt_token_count} prompt tokens, may take a while')

-process_tokens(tokenizer.encode(init_prompt).ids)
+process_tokens(split_last_end_of_line(tokenizer.encode(init_prompt).ids))

save_thread_state('chat_init')
save_thread_state('chat')
@@ -226,8 +235,8 @@ Below is an instruction that describes a task. Write a response that appropriate
print(f'> {bot}{separator}', end='')

start_index: int = len(processed_tokens)
-accumulated_tokens: list[int] = []
-token_counts: dict[int, int] = {}
+accumulated_tokens: List[int] = []
+token_counts: Dict[int, int] = {}

for i in range(MAX_GENERATION_LENGTH):
    for n in token_counts:

rwkv/convert_pytorch_to_ggml.py

@@ -38,8 +38,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
            '=iiiiii',
            # Magic: 'ggmf' in hex
            0x67676d66,
-            # llama.cpp uses file versions 1+, let's use 100+ for rwkv.cpp
-            100,
+            101,
            n_vocab,
            n_embed,
            n_layer,

rwkv/convert_pytorch_to_ggml_test.py

@@ -21,7 +21,7 @@ def test() -> None:
    expected_bytes: bytes = struct.pack(
        '=iiiiii' + 'iiiii10sffffff' + 'iiii19sf',
        0x67676d66,
-        100,
+        101,
        3,
        2,
        1,

rwkv/merge_lora_into_ggml.py

@@ -8,6 +8,7 @@ import argparse
import struct
import torch
import numpy as np
+from typing import List, Dict, Tuple

def parse_args():
    parser = argparse.ArgumentParser(description='Merge a PyTorch LoRA checkpoint (.pth) into an rwkv.cpp model file')
@@ -45,16 +46,16 @@ def main() -> None:
    print(f'Reading {args.lora_path}')

-    lora_state_dict: dict[str, torch.Tensor] = torch.load(args.lora_path, map_location='cpu')
+    lora_state_dict: Dict[str, torch.Tensor] = torch.load(args.lora_path, map_location='cpu')

    print(f'Merging')

    with open(args.src_path, 'rb') as in_file, open(args.dest_path, 'wb') as out_file:
        # noinspection PyTypeChecker
-        header: tuple[int, int, int, int, int, int] = struct.unpack('=iiiiii', in_file.read(6 * 4))
+        header: Tuple[int, int, int, int, int, int] = struct.unpack('=iiiiii', in_file.read(6 * 4))

        assert header[0] == 0x67676d66, 'Invalid magic value'
-        assert header[1] == 100, 'Invalid version number'
+        assert 100 <= header[1] <= 101, 'Invalid version number'
        assert header[5] == 0 or header[5] == 1, 'Only FP32 and FP16 models are supported'

        out_file.write(struct.pack('=iiiiii', *header))
@@ -68,9 +69,9 @@ def main() -> None:
            dim_count, key_length, data_type = struct.unpack('=iii', parameter_header_bytes)

            # noinspection PyTypeChecker
-            shape: tuple[int] = struct.unpack('=' + 'i' * dim_count, in_file.read(dim_count * 4))
+            shape: Tuple[int] = struct.unpack('=' + 'i' * dim_count, in_file.read(dim_count * 4))
            # ggml order to PyTorch
-            shape: list[int] = [d for d in reversed(shape)]
+            shape: List[int] = [d for d in reversed(shape)]

            key: str = in_file.read(key_length).decode('utf-8')

rwkv/quantize.py

@@ -1,6 +1,6 @@
# Quantizes rwkv.cpp model file from FP32 or FP16.
# Available format names are in rwkv_cpp_shared_library.QUANTIZED_FORMAT_NAMES
-# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-FP32.bin C:\rwkv.cpp-169M-Q4_2.bin Q4_2
+# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-FP32.bin C:\rwkv.cpp-169M-Q5_1.bin Q5_1

import argparse
import rwkv_cpp_shared_library
@@ -11,7 +11,7 @@ def parse_args():
    parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16')
    parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
    parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
-    parser.add_argument('format_name', help='Format name, one of ' + ', '.join(format_names), type=str, choices=format_names, default='Q4_2')
+    parser.add_argument('format_name', help='Format name, one of ' + ', '.join(format_names), type=str, choices=format_names, default='Q5_1')
    return parser.parse_args()

def main() -> None:

rwkv/rwkv_cpp_shared_library.py

@@ -7,7 +7,6 @@ from typing import Optional
QUANTIZED_FORMAT_NAMES = (
    'Q4_0',
    'Q4_1',
-    'Q4_2',
    'Q5_0',
    'Q5_1',
    'Q8_0'

tests/test_tiny_rwkv.c

@@ -26,7 +26,9 @@ void test_model(const char * model_path, const float * expected_logits, const fl
    fprintf(stderr, "Testing %s\n", model_path);

    struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS);
    enum rwkv_error_flags error = rwkv_get_last_error(NULL);
+    ASSERT(error == 0, "Unexpected error %d", error);

    uint32_t n_vocab = rwkv_get_logits_buffer_element_count(model);
@@ -76,14 +78,12 @@ int main(void) {
        -0.160030F,
        -0.370606F,
-        0.661480F,
        -0.170404F,
        0.278034F,
        0.071216F,
        0.154614F,
        -0.372169F,
-        0.658310F,
        -0.170043F,
        0.294953F,
        0.065571F,
@@ -94,31 +94,27 @@ int main(void) {
    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_0.bin", "Q4_0");
    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_1.bin", "Q4_1");
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_2.bin", "Q4_2");
    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q5_0.bin", "Q5_0");
    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q5_1.bin", "Q5_1");
    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q8_0.bin", "Q8_0");

    test_model("tiny-rwkv-660K-FP32-Q4_0.bin", expected_logits, expected_difference_sum[2]);
    test_model("tiny-rwkv-660K-FP32-Q4_1.bin", expected_logits, expected_difference_sum[3]);
-    test_model("tiny-rwkv-660K-FP32-Q4_2.bin", expected_logits, expected_difference_sum[4]);
-    test_model("tiny-rwkv-660K-FP32-Q5_0.bin", expected_logits, expected_difference_sum[5]);
-    test_model("tiny-rwkv-660K-FP32-Q5_1.bin", expected_logits, expected_difference_sum[6]);
-    test_model("tiny-rwkv-660K-FP32-Q8_0.bin", expected_logits, expected_difference_sum[7]);
+    test_model("tiny-rwkv-660K-FP32-Q5_0.bin", expected_logits, expected_difference_sum[4]);
+    test_model("tiny-rwkv-660K-FP32-Q5_1.bin", expected_logits, expected_difference_sum[5]);
+    test_model("tiny-rwkv-660K-FP32-Q8_0.bin", expected_logits, expected_difference_sum[6]);

    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_0.bin", "Q4_0");
    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_1.bin", "Q4_1");
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_2.bin", "Q4_2");
    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q5_0.bin", "Q5_0");
    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q5_1.bin", "Q5_1");
    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q8_0.bin", "Q8_0");

-    test_model("tiny-rwkv-660K-FP16-Q4_0.bin", expected_logits, expected_difference_sum[8]);
-    test_model("tiny-rwkv-660K-FP16-Q4_1.bin", expected_logits, expected_difference_sum[9]);
-    test_model("tiny-rwkv-660K-FP16-Q4_2.bin", expected_logits, expected_difference_sum[10]);
-    test_model("tiny-rwkv-660K-FP16-Q5_0.bin", expected_logits, expected_difference_sum[11]);
-    test_model("tiny-rwkv-660K-FP16-Q5_1.bin", expected_logits, expected_difference_sum[12]);
-    test_model("tiny-rwkv-660K-FP16-Q8_0.bin", expected_logits, expected_difference_sum[13]);
+    test_model("tiny-rwkv-660K-FP16-Q4_0.bin", expected_logits, expected_difference_sum[7]);
+    test_model("tiny-rwkv-660K-FP16-Q4_1.bin", expected_logits, expected_difference_sum[8]);
+    test_model("tiny-rwkv-660K-FP16-Q5_0.bin", expected_logits, expected_difference_sum[9]);
+    test_model("tiny-rwkv-660K-FP16-Q5_1.bin", expected_logits, expected_difference_sum[10]);
+    test_model("tiny-rwkv-660K-FP16-Q8_0.bin", expected_logits, expected_difference_sum[11]);

    free(expected_logits);