Add support for Q5_0, Q5_1 and Q8_0 formats; remove Q4_1_O format (#44)

* Remove Q4_3 support

* Add Q5_0, Q5_1, Q8_0 support

* Add a clearer message when loading a Q4_3 model

* Remove Q4_1_O format

* Fix indentation in .gitmodules

* Simplify sanitizer matrix
Author: Alex, 2023-04-29 17:39:11 +05:00 (committed by GitHub)
parent c736ef5411
commit 1198892888
14 changed files with 230 additions and 422 deletions

CI workflow (GitHub Actions)

@ -25,7 +25,6 @@ jobs:
       matrix:
         sanitizer: [ADDRESS, THREAD, UNDEFINED]
         build_type: [Debug, Release]
-        accelerate: [ON, OFF]

     steps:
       - name: Clone
@ -45,7 +44,7 @@ jobs:
         run: |
           mkdir build
           cd build
-          cmake .. -DRWKV_SANITIZE_${{ matrix.sanitizer }}=ON -DGGML_SANITIZE_${{ matrix.sanitizer }}=ON -DRWKV_ACCELERATE=${{ matrix.accelerate }} -DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
+          cmake .. -DRWKV_SANITIZE_${{ matrix.sanitizer }}=ON -DGGML_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
           cmake --build . --config ${{ matrix.build_type }}

       - name: Test

.gitmodules

@ -1,3 +1,4 @@
 [submodule "ggml"]
 	path = ggml
 	url = https://github.com/saharNooby/ggml
+	branch = master-2023-04-29

FILE_FORMAT.md (new file)

@ -0,0 +1,53 @@
# rwkv.cpp file format
This format is used by `rwkv.cpp` to store RWKV model checkpoints.
Preferred file extension: `.bin`
Specification in C-like pseudocode:
```
RWKVModelFile {
  // All ints and floats are in machine byte order.
  // Magic is "ggmf" string bytes.
  int32 magic = 0x67676d66;
  int32 version = 100;
  int32 n_vocab;
  int32 n_embed;
  int32 n_layer;
  // Data type of most of the parameters. See "Data types" below for possible values.
  int32 data_type;
  // Read until EOF.
  Parameter[] parameters;
}

Parameter {
  int32 dim_count;
  int32 key_length;
  // Data type of the parameter. See "Data types" below for possible values.
  int32 data_type;
  // Compared to PyTorch's parameter.shape, dimension order is reversed here!
  int32[dim_count] shape;
  // Keys are like "emb.weight", "block.0.ln1.weight".
  uint8[key_length] key_utf8;
  // Length of the data array depends on parameter data type:
  // - FP32: 4 * element_count
  // - FP16: 2 * element_count
  // - QX_Y (quantized): element_count / QKX_Y * sizeof(block_qx_y)
  // See ggml.c for values of QK and block sizes of specific formats.
  byte[] data;
}
```
## Data types
- 0: `FP32`
- 1: `FP16`
- 2: `Q4_0`
- 3: `Q4_1`
- 4: *unused*
- 5: `Q4_2`
- 6: *unused*
- 7: `Q5_0`
- 8: `Q5_1`
- 9: `Q8_0`
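
For illustration, here is a minimal Python sketch that walks the header and the first parameter record exactly as laid out in the pseudocode above. The field order comes from the specification; the file path is a placeholder, and little-endian machine byte order (typical x86) is assumed:

```python
import struct

def read_i32(f):
    # "Machine byte order" per the spec; '=' uses native order with standard 4-byte ints.
    return struct.unpack('=i', f.read(4))[0]

with open('rwkv.cpp-169M.bin', 'rb') as f:  # placeholder path
    magic = read_i32(f)
    assert magic == 0x67676d66, f'Unexpected magic {magic:#x}'
    version = read_i32(f)
    n_vocab, n_embed, n_layer = read_i32(f), read_i32(f), read_i32(f)
    data_type = read_i32(f)
    print(f'version={version}, n_vocab={n_vocab}, n_embed={n_embed}, n_layer={n_layer}, data_type={data_type}')

    # First Parameter record: dim_count, key_length, data_type, shape, key.
    dim_count = read_i32(f)
    key_length = read_i32(f)
    param_data_type = read_i32(f)
    shape = [read_i32(f) for _ in range(dim_count)]
    key = f.read(key_length).decode('utf-8')
    print(f'first parameter: {key}, shape={shape}, data_type={param_data_type}')
```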

README.md

@ -2,18 +2,30 @@
 This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml).

-Besides the usual **FP32**, it supports **FP16** and **quantized INT4** inference on CPU. This project is **CPU only**.
+Besides the usual **FP32**, it supports **FP16**, **quantized INT4** and **quantized INT8** inference. This project is **CPU only**.
-RWKV is a novel large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformer with `O(n^2)` attention, RWKV requires only the state from the previous step to calculate logits. This makes RWKV very CPU-friendly on large context lengths.

 This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrapper](rwkv%2Frwkv_cpp_model.py) for it.

+RWKV is a novel large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformer with `O(n^2)` attention, RWKV requires only the state from the previous step to calculate logits. This makes RWKV very CPU-friendly on large context lengths.

 Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py).

-**TODO (contributions welcome!)**:
-1. Measure latency and perplexity of different model sizes (169M to 14B) and data types (`FP32`, `FP16`, `Q4_0`, `Q4_1`, `Q4_1_O`)
-2. Make required memory calculation more robust (see [#4](https://github.com/saharNooby/rwkv.cpp/issues/4))
+### Quality and performance
+
+If you use `rwkv.cpp` for anything serious, please [test all available formats for perplexity and latency](rwkv%2Fmeasure_pexplexity.py) on a representative dataset, and decide which trade-off is best for you.
+
+The table below is for reference only. Measurements were made on a 4C/8T x86 CPU with AVX2, using 4 threads.
+
+| Format | Perplexity (169M) | Latency, ms (1.5B) | File size, GB (1.5B) |
+|--------|-------------------|--------------------|----------------------|
+| `Q4_0` | 17.507            | *76*               | **1.53**             |
+| `Q4_1` | 17.187            | **72**             | 1.68                 |
+| `Q4_2` | 17.060            | 85                 | **1.53**             |
+| `Q5_0` | 16.194            | 78                 | *1.60*               |
+| `Q5_1` | 15.851            | 81                 | 1.68                 |
+| `Q8_0` | *15.652*          | 89                 | 2.13                 |
+| `FP16` | **15.623**        | 117                | 2.82                 |
+| `FP32` | **15.623**        | 198                | 5.64                 |
 ## How to use

@ -77,26 +89,16 @@ python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-802
 #### 3.1. Optionally, quantize the model

-To convert the model into INT4 quantized format, run:
+To convert the model into one of the quantized formats from the table above, run:

 ```commandline
 # Windows
-python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q4_1_O.bin 4
+python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q4_2.bin Q4_2

 # Linux / MacOS
-python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q4_1_O.bin 4
+python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q4_2.bin Q4_2
 ```

-Formats available:
-
-- `6`: `Q4_3`, OK quality, fast.
-- `5`: `Q4_2`, poor quality, fast.
-- `4`: `Q4_1_O`, best quality, slow (20% slower than `FP16`).
-- `3`: `Q4_1`, poor quality, very fast.
-- `2`: `Q4_0`, worst quality, very fast.
-
-If you use `rwkv.cpp` for anything serious (just having fun is serious enough!), please [test all available formats for perplexity and latency](rwkv%2Fmeasure_pexplexity.py) on a representative dataset, and decide which trade-off is best for you.

 ### 4. Run the model

 **Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/) and [tokenizers](https://pypi.org/project/tokenizers/).

@ -107,20 +109,20 @@ To generate some text, run:

 ```commandline
 # Windows
-python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q4_1_O.bin
+python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q4_2.bin

 # Linux / MacOS
-python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q4_1_O.bin
+python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q4_2.bin
 ```

 To chat with a bot, run:

 ```commandline
 # Windows
-python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q4_1_O.bin
+python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q4_2.bin

 # Linux / MacOS
-python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q4_1_O.bin
+python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q4_2.bin
 ```

 Edit [generate_completions.py](rwkv%2Fgenerate_completions.py) or [chat_with_bot.py](rwkv%2Fchat_with_bot.py) to change prompts and sampling settings.

ggml

@ -1 +1 @@
-Subproject commit bfa8d5b5ab4ffbae4c5f97525c3890f38619056d
+Subproject commit a0687a3a3c4b31811219d7a61adfb66230b09201

rwkv.cpp

@ -15,8 +15,6 @@
 // --- Utilities ---

-#define FP32_SIZE 4
-
 // Checks that x is not false. If x is false, prints fancy message to stderr and returns 0.
 #define RWKV_ASSERT_FALSE(x, ...) \
     do { \
@ -43,16 +41,34 @@ bool read_int32(FILE * file, int32_t * dest) {
     return true;
 }

-static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[7] = {
+#define GGML_TYPE_UNKNOWN GGML_TYPE_COUNT
+
+#define FORMAT_TYPE_COUNT 10
+
+static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[FORMAT_TYPE_COUNT] = {
     GGML_TYPE_F32,
     GGML_TYPE_F16,
     GGML_TYPE_Q4_0,
     GGML_TYPE_Q4_1,
-    GGML_TYPE_Q4_1_O,
+    GGML_TYPE_UNKNOWN, // Unused
     GGML_TYPE_Q4_2,
-    GGML_TYPE_Q4_3
+    GGML_TYPE_UNKNOWN, // Unused
+    GGML_TYPE_Q5_0,
+    GGML_TYPE_Q5_1,
+    GGML_TYPE_Q8_0
 };

+static int32_t format_name_to_format_type(const char * format_name) {
+    if (strcmp(format_name, "Q4_0") == 0) return 2;
+    if (strcmp(format_name, "Q4_1") == 0) return 3;
+    if (strcmp(format_name, "Q4_2") == 0) return 5;
+    if (strcmp(format_name, "Q5_0") == 0) return 7;
+    if (strcmp(format_name, "Q5_1") == 0) return 8;
+    if (strcmp(format_name, "Q8_0") == 0) return 9;
+
+    return -1;
+}
+
 // --- Model definition and loading utilities ---

 struct rwkv_layer {
@ -206,7 +222,17 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
     RWKV_ASSERT_NULL(model->n_layer > 0, "Non-positive n_layer %d", model->n_layer);

     read_int32(file, &(model->data_type));
-    RWKV_ASSERT_NULL(model->data_type >= 0 && model->data_type <= 6, "Unsupported model data type %d", model->data_type);
+    RWKV_ASSERT_NULL(model->data_type >= 0 && model->data_type < FORMAT_TYPE_COUNT, "Unsupported model data type %d", model->data_type);
+
+    RWKV_ASSERT_NULL(
+        model->data_type != 4,
+        "Models in Q4_1_O format cannot be loaded anymore because the format was removed. You need to quantize the model into another format"
+    );
+
+    RWKV_ASSERT_NULL(
+        model->data_type != 6,
+        "Models in Q4_3 format cannot be loaded anymore because the format was removed. You need to quantize the model into another format"
+    );

     // Parameter tensors would take at least this amount in memory.
     size_t file_size;
@ -256,10 +282,12 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
         int32_t data_type;
         read_int32(file, &data_type);
-        RWKV_ASSERT_NULL(data_type >= 0 && data_type <= 6, "Unsupported parameter data type %d", data_type);
+        RWKV_ASSERT_NULL(data_type >= 0 && data_type < FORMAT_TYPE_COUNT, "Unsupported parameter data type %d", data_type);

         ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type];

+        RWKV_ASSERT_NULL(ggml_data_type != GGML_TYPE_UNKNOWN, "Unsupported parameter data type %d", data_type);
+
         struct ggml_tensor * tensor;

         int32_t x = -1;
@ -356,7 +384,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
         // self.layer_norm(x, self.w.blocks[i].ln1)
         struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);

         // state[5 * i + 1]
-        struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * FP32_SIZE);
+        struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * sizeof(float));

         // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
         // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
         // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
@ -391,9 +419,9 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
         // aa = state[5 * i + 2]
         // bb = state[5 * i + 3]
         // pp = state[5 * i + 4]
-        struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * FP32_SIZE);
-        struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * FP32_SIZE);
-        struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * FP32_SIZE);
+        struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * sizeof(float));
+        struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * sizeof(float));
+        struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * sizeof(float));

         // ww = time_first + k
         struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k);
@ -456,7 +484,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
         // self.layer_norm(x, self.w.blocks[i].ln2)
         struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias);

         // state[5 * i + 0]
-        struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * FP32_SIZE);
+        struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * sizeof(float));

         // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
         // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
         struct ggml_tensor * xk = ggml_add(
@ -549,12 +577,12 @@ bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float
         for (int i = 0; i < n_layer; i++) {
             // state[5 * i + 4] = -1e30
             ggml_set_f32(
-                ggml_view_1d(ctx->ctx, ctx->state, n_embed, (5 * i + 4) * n_embed * FP32_SIZE),
+                ggml_view_1d(ctx->ctx, ctx->state, n_embed, (5 * i + 4) * n_embed * sizeof(float)),
                 -1e30F
             );
         }
     } else {
-        memcpy(ctx->state->data, state_in, ctx->state->ne[0] * FP32_SIZE);
+        memcpy(ctx->state->data, state_in, ctx->state->ne[0] * sizeof(float));
     }

     ggml_graph_compute(ctx->ctx, ctx->graph);
@ -562,10 +590,10 @@ bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float
     for (size_t i = 0; i < size_t(n_layer * 5); i++) {
         struct ggml_tensor * part = ctx->state_parts[i];
-        memcpy(state_out + i * n_embed, part->data, part->ne[0] * FP32_SIZE);
+        memcpy(state_out + i * n_embed, part->data, part->ne[0] * sizeof(float));
     }

-    memcpy(logits_out, ctx->logits->data, ctx->logits->ne[0] * FP32_SIZE);
+    memcpy(logits_out, ctx->logits->data, ctx->logits->ne[0] * sizeof(float));

     return true;
 }
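
As context for the `ggml_view_1d` offsets above: the state is one flat FP32 buffer of `n_layer * 5 * n_embed` elements, and each view selects one `n_embed`-sized part. A rough NumPy sketch of how a fresh state is laid out and initialized (names are mine, mirroring the C code above):

```python
import numpy as np

def make_fresh_state(n_layer: int, n_embed: int) -> np.ndarray:
    # One flat FP32 buffer: 5 parts per layer, n_embed floats per part.
    state = np.zeros(n_layer * 5 * n_embed, dtype=np.float32)
    for i in range(n_layer):
        # Mirrors the C code above: part 5 * i + 4 (pp) starts at -1e30 for a fresh state.
        begin = (5 * i + 4) * n_embed
        state[begin:begin + n_embed] = -1e30
    return state
```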
@ -579,8 +607,14 @@ void rwkv_free(struct rwkv_context * ctx) {
     free(ctx);
 }

-bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
-    RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4 || q_type == 5 || q_type == 6, "Unsupported quantization type %d", q_type);
+bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name) {
+    int32_t format_type = format_name_to_format_type(format_name);
+    RWKV_ASSERT_FALSE(format_type != -1, "Unsupported format \"%s\"", format_name);
+
+    ggml_type type = FORMAT_TYPE_TO_GGML_TYPE[format_type];
+    RWKV_ASSERT_FALSE(type != GGML_TYPE_UNKNOWN, "Unsupported format \"%s\"", format_name);

     // Needed to initialize FP16 lookup table
     {
@ -589,8 +623,6 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
         ggml_free(ctx);
     }

-    ggml_type type = FORMAT_TYPE_TO_GGML_TYPE[q_type];
-
     printf("Loading model from '%s'\n", model_file_path_in);
     auto finp = std::ifstream(model_file_path_in, std::ios::binary);
@ -623,7 +655,7 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
     RWKV_ASSERT_FALSE(data_type == 0 || data_type == 1, "Unsupported data type %d, only FP32 and FP16 can be quantized", data_type);

-    data_type = q_type;
+    data_type = format_type;

     fout.write((char *) &n_vocab, sizeof(n_vocab));
     fout.write((char *) &n_embed, sizeof(n_embed));
@ -657,6 +689,12 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
             break;
         }

+        RWKV_ASSERT_FALSE(parameter_data_type >= 0 && parameter_data_type < FORMAT_TYPE_COUNT, "Invalid parameter data type %d", parameter_data_type);
+
+        ggml_type parameter_ggml_type = FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type];
+        RWKV_ASSERT_FALSE(parameter_ggml_type != GGML_TYPE_UNKNOWN, "Invalid parameter data type %d", parameter_data_type);
+
         int32_t nelements = 1;
         int32_t ne[2] = { 1, 1 };
         for (int i = 0; i < n_dims; ++i) {
@ -668,18 +706,9 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
         finp.read(&name[0], key_length);

         {
-            static const char * parameter_data_type_str[] = {
-                "F32",
-                "F16",
-                "Q4_0",
-                "Q4_1",
-                "Q4_1_O",
-                "Q4_2",
-                "Q4_3"
-            };
-            printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
+            printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ggml_type_name(parameter_ggml_type));

-            total_size_orig += (size_t) (nelements * ggml_type_sizef(FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type]));
+            total_size_orig += (size_t) (nelements * ggml_type_sizef(parameter_ggml_type));
         }

         // Quantize only 2D tensors, except embedding and head matrices.
@ -708,7 +737,7 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
                 finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
             }

-            parameter_data_type = q_type;
+            parameter_data_type = format_type;
         } else {
             const int bytes_per_element = (parameter_data_type == 0) ? sizeof(float) : sizeof(uint16_t);
             data_u8.resize(nelements * bytes_per_element);
@ -735,27 +764,24 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
             switch (type) {
                 case GGML_TYPE_Q4_0:
-                    {
                     cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                    } break;
+                    break;
                 case GGML_TYPE_Q4_1:
-                    {
                     cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                    } break;
-                case GGML_TYPE_Q4_1_O:
-                    {
-                    cur_size = ggml_quantize_q4_1_o(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                    } break;
+                    break;
                 case GGML_TYPE_Q4_2:
-                    {
                     cur_size = ggml_quantize_q4_2(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                    } break;
-                case GGML_TYPE_Q4_3:
-                    {
-                    cur_size = ggml_quantize_q4_3(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                    } break;
-                default:
-                    {
+                    break;
+                case GGML_TYPE_Q5_0:
+                    cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                    break;
+                case GGML_TYPE_Q5_1:
+                    cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                    break;
+                case GGML_TYPE_Q8_0:
+                    cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                    break;
+                default: {
                     fprintf(stderr, "unsupported quantization type %d\n", type);
                     return false;
                 }

rwkv.h

@ -52,12 +52,19 @@ extern "C" {
 // Frees all allocated memory and the context.
 RWKV_API void rwkv_free(struct rwkv_context * ctx);

-// Quantizes FP32 or FP16 model to one of INT4 formats.
+// Quantizes FP32 or FP16 model to one of quantized formats.
 // Returns false on any error. Error messages would be printed to stderr.
 // - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
 // - model_file_path_out: quantized model will be written here.
-// - q_type: set to 2 for GGML_TYPE_Q4_0, 3 for GGML_TYPE_Q4_1, 4 for GGML_TYPE_Q4_1_O, 5 for GGML_TYPE_Q4_2, 6 for GGML_TYPE_Q4_3.
-RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type);
+// - format_name: must be one of available format names below.
+// Available format names:
+// - Q4_0
+// - Q4_1
+// - Q4_2
+// - Q5_0
+// - Q5_1
+// - Q8_0
+RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name);

 // Returns system information string.
 RWKV_API const char * rwkv_get_system_info_string(void);

rwkv/convert_pytorch_to_ggml.py

@ -1,39 +1,7 @@
-# Converts an RWKV model checkpoint to an rwkv.cpp compatible file.
+# Converts an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file.
 # Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
 # Get model checkpoints from https://huggingface.co/BlinkDL
+# See FILE_FORMAT.md for the documentation on the file format.

-# File format:
-#
-# RWKVModelFile {
-#   // All ints and floats are in machine byte order.
-#   // Magic is "ggml" string bytes.
-#   int32 magic = 0x67676d66;
-#   int32 version = 100;
-#   int32 n_vocab;
-#   int32 n_embed;
-#   int32 n_layer;
-#   // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O, 5 if Q4_2, 6 if Q4_3.
-#   int32 data_type;
-#   // Read until EOF.
-#   Parameter[] parameters;
-# }
-#
-# Parameter {
-#   int32 dim_count;
-#   int32 key_length;
-#   // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O, 5 if Q4_2, 6 if Q4_3.
-#   int32 data_type;
-#   // Compared to PyTorch's tensor.shape, dimension order is reversed here!
-#   int32[dim_count] shape;
-#   // Keys are like "emb.weight", "block.0.ln1.weight".
-#   uint8[key_length] key_utf8;
-#   // float32: 4 * element_count bytes.
-#   // float16: 2 * element_count bytes.
-#   // Q4_0: element_count / 32 * 20 bytes.
-#   // Q4_1: element_count / 32 * 24 bytes.
-#   // Q4_1_O: element_count / 32 * 24 bytes.
-#   byte[] data;
-# }

 import os
 import argparse
@ -42,7 +10,7 @@ import torch
 from typing import Dict

 def parse_args():
-    parser = argparse.ArgumentParser(description='Convert an RWKV model checkpoint to an rwkv.cpp compatible file')
+    parser = argparse.ArgumentParser(description='Convert an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file')
     parser.add_argument('src_path', help='Path to PyTorch checkpoint file')
     parser.add_argument('dest_path', help='Path to rwkv.cpp checkpoint file, will be overwritten')
     parser.add_argument('data_type', help='Data type, float16 or float32', type=str, choices=['float16', 'float32'], default='float32')

rwkv/quantize.py

@ -1,19 +1,17 @@
-# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1, Q4_1_O, Q4_2, Q4_3.
-# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1_o.bin 4
+# Quantizes rwkv.cpp model file from FP32 or FP16.
+# Available format names are in rwkv_cpp_shared_library.QUANTIZED_FORMAT_NAMES
+# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-FP32.bin C:\rwkv.cpp-169M-Q4_2.bin Q4_2

 import argparse
 import rwkv_cpp_shared_library

 def parse_args():
-    parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
+    format_names = rwkv_cpp_shared_library.QUANTIZED_FORMAT_NAMES
+
+    parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16')
     parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
     parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
-    parser.add_argument('data_type', help='Data type, '
-                                          '2 (GGML_TYPE_Q4_0), '
-                                          '3 (GGML_TYPE_Q4_1), '
-                                          '4 (GGML_TYPE_Q4_1_O), '
-                                          '5 (Q4_2), '
-                                          '6 (Q4_3)', type=int, choices=[2, 3, 4, 5, 6], default=4)
+    parser.add_argument('format_name', help='Format name, one of ' + ', '.join(format_names), type=str, choices=format_names, default='Q4_2')
     return parser.parse_args()

 def main() -> None:
@ -24,7 +22,7 @@ def main() -> None:
     library.rwkv_quantize_model_file(
         args.src_path,
         args.dest_path,
-        args.data_type
+        args.format_name
     )

     print('Done')

rwkv/rwkv_cpp_shared_library.py

@ -4,6 +4,14 @@ import ctypes
 import pathlib
 from typing import Optional

+QUANTIZED_FORMAT_NAMES = (
+    'Q4_0',
+    'Q4_1',
+    'Q4_2',
+    'Q5_0',
+    'Q5_1',
+    'Q8_0'
+)
+
 P_FLOAT = ctypes.POINTER(ctypes.c_float)

@ -54,7 +62,7 @@ class RWKVSharedLibrary:
         self.library.rwkv_free.argtypes = [ctypes.c_void_p]
         self.library.rwkv_free.restype = None

-        self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_uint32]
+        self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p]
         self.library.rwkv_quantize_model_file.restype = ctypes.c_bool

         self.library.rwkv_get_system_info_string.argtypes = []

@ -149,7 +157,7 @@ class RWKVSharedLibrary:
         ctx.ptr = ctypes.cast(0, ctypes.c_void_p)

-    def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, q_type: int) -> None:
+    def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, format_name: str) -> None:
         """
         Quantizes FP32 or FP16 model to one of INT4 formats.
         Throws an exception in case of any error. Error messages would be printed to stderr.

@ -160,14 +168,16 @@ class RWKVSharedLibrary:
             Path to model file in ggml format, must be either FP32 or FP16.
         model_file_path_out : str
             Quantized model will be written here.
-        q_type : int
-            Set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
+        format_name : str
+            One of QUANTIZED_FORMAT_NAMES.
         """

+        assert format_name in QUANTIZED_FORMAT_NAMES, f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}'
+
         assert self.library.rwkv_quantize_model_file(
             model_file_path_in.encode('utf-8'),
             model_file_path_out.encode('utf-8'),
-            ctypes.c_uint32(q_type)
+            format_name.encode('utf-8')
         ), 'rwkv_quantize_model_file failed, check stderr'

     def rwkv_get_system_info_string(self) -> str:
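
To make the call sequence concrete, a small usage sketch of this wrapper follows. It assumes a loader helper such as `load_rwkv_shared_library()` exists in this module to locate the compiled library (not shown in this diff), and the file paths are placeholders:

```python
import rwkv_cpp_shared_library

# Assumed helper that locates and loads the compiled librwkv shared library.
library = rwkv_cpp_shared_library.load_rwkv_shared_library()

format_name = 'Q5_1'
assert format_name in rwkv_cpp_shared_library.QUANTIZED_FORMAT_NAMES

library.rwkv_quantize_model_file(
    'rwkv.cpp-169M.bin',       # FP32/FP16 source model (placeholder path)
    'rwkv.cpp-169M-Q5_1.bin',  # quantized output (placeholder path)
    format_name
)
```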

tests/CMakeLists.txt

@ -10,6 +10,4 @@ file(COPY tiny-rwkv-660K-FP16.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
 file(COPY expected_logits.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})

 rwkv_add_test(test_ggml_basics.c)
-rwkv_add_test(test_Q4_1_O.c)
-rwkv_add_test(test_Q4_1_O_large_matmul.c)
 rwkv_add_test(test_tiny_rwkv.c)

test_Q4_1_O.c (deleted file)

@ -1,174 +0,0 @@
// Tests that Q4_1_O basics (quantization, dequantization, matmul) work.
#include "ggml.h"
#include "rwkv.h"
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#define GET_ELEMENT_F32(tensor, i) (((float *) tensor->data)[i])
#define SET_ELEMENT_F32(tensor, i, value) ((float *) tensor->data)[i] = value
#define ASSERT(x, ...) {\
if (!(x)) {\
fprintf(stderr, "*** Assertion failed ***\n");\
fprintf(stderr, __VA_ARGS__);\
fprintf(stderr, "\n%s:%d\n", __FILE__, __LINE__);\
abort();\
}\
}
// ---
#define QK 32
// Copied from ggml.c
typedef struct {
ggml_fp16_t d;
ggml_fp16_t m;
uint16_t outlier_index;
ggml_fp16_t outlier_value;
uint8_t qs[QK / 2];
} block_q4_1_o;
int main(int argc, const char ** argv) {
ASSERT(sizeof(block_q4_1_o) == 8 + QK / 2, "Wrong q4_1_o block size/padding");
// Needed to initialize FP16 lookup table
{
struct ggml_init_params params = { 0, NULL, false };
struct ggml_context * ctx = ggml_init(params);
ggml_free(ctx);
}
fprintf(stderr, "System info: %s\n", rwkv_get_system_info_string());
quantize_fns_t quantize_fns = ggml_internal_get_quantize_fn(GGML_TYPE_Q4_1_O);
float src[QK];
uint8_t dest[24];
// 1..32
for (int i = 0; i < QK; i++) {
src[i] = (float) (i + 1);
}
// --- Quantization ---
(quantize_fns.quantize_row_q)(src, dest, QK);
float delta_result = ggml_fp16_to_fp32(((block_q4_1_o *) dest)->d);
float delta_expected = (src[30] - src[0]) / ((1 << 4) - 1);
ASSERT(delta_result == delta_expected, "%f, %f", delta_result, delta_expected);
float min_result = ggml_fp16_to_fp32(((block_q4_1_o *) dest)->m);
float min_expected = src[0];
ASSERT(min_result == min_expected, "%f, %f", min_result, min_expected);
uint16_t outlier_index = ((block_q4_1_o *) dest)->outlier_index;
uint16_t outlier_index_expected = 31;
ASSERT(outlier_index == outlier_index_expected, "%d, %d", outlier_index, outlier_index_expected);
float outlier_value_result = ggml_fp16_to_fp32(((block_q4_1_o *) dest)->outlier_value);
float outlier_value_expected = src[31];
ASSERT(outlier_value_result == outlier_value_expected, "%f, %f", outlier_value_result, outlier_value_expected);
for (int i = 0; i < QK - 1; i++) {
uint8_t q4_result = (i % 2) ? (dest[sizeof(float) * 2 + i / 2] >> 4) : (dest[sizeof(float) * 2 + i / 2] & 0xF);
uint8_t q4_expected = roundf((src[i] - min_expected) / delta_expected);
ASSERT(q4_result == q4_expected, "%d: %d, %d", i, q4_result, q4_expected);
}
// --- Dequantization ---
float dequantized[QK];
(quantize_fns.dequantize_row_q)(dest, dequantized, QK);
for (int i = 0; i < QK; i++) {
float actual = dequantized[i];
float expected = src[i];
float diff = fabsf(actual - expected);
// Difference looks huge, but the range is 0..31 -- compared to the range, it is not that huge
ASSERT(diff <= 1.0F, "%d: %f, %f", i, actual, expected);
}
// --- Matmul ---
struct ggml_init_params params = {
.mem_size = 16 * 1024,
.mem_buffer = NULL,
.no_alloc = false,
};
struct ggml_context * ctx = ggml_init(params);
struct ggml_tensor * mat = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, QK, 4);
// Note rare outlier values: -88, -83, etc.
float mat_values[QK * 4] = {
-1.371795F, -88.901100F, -0.412088F, -0.486081F, 1.280220F, -1.067033F, 1.371795F, 1.099267F, 1.079487F, -0.204029F, 1.237729F, -0.563736F,
-0.633333F, 0.700000F, 0.211355F, 0.510989F, -0.981319F, -0.456777F, 0.011355F, 0.911722F, -0.976191F, 0.078022F, -0.757143F, -0.744689F,
-0.768865F, 0.656777F, 0.141026F, -0.038462F, 1.023810F, 1.221612F, -0.393773F, 1.135165F, -1.341758F, -83.113556F, 1.291209F, 0.313187F,
1.032601F, -0.401099F, 1.482418F, 0.823077F, 0.619414F, -0.583516F, 0.527106F, 1.489011F, 1.327839F, 0.846520F, -1.437729F, 0.461172F,
1.031136F, 0.293407F, 0.284615F, -1.102198F, -1.481685F, 0.602564F, -0.480952F, -0.745421F, -1.376190F, -1.319780F, 1.338828F, -1.062637F,
1.266300F, 0.360073F, 1.472894F, 1.063370F, -0.833333F, 49.047626F, -1.229670F, 1.079487F, -0.004762F, -0.696337F, -0.541758F, 0.993773F,
-1.323443F, 0.908059F, -1.059707F, 0.965201F, -0.376923F, 1.158608F, -1.100000F, -1.002564F, -0.355678F, 1.157143F, 0.450916F, -0.497802F,
1.270696F, 0.028205F, 1.075092F, 1.462637F, 0.252381F, -0.579121F, -0.880220F, -0.041392F, -1.017949F, -0.754945F, 0.582784F, -1.193773F,
-1.411355F, 122.014656F, -1.053114F, -0.949084F, 0.448718F, 0.209890F, 0.815751F, 0.071429F, -0.125641F, -0.600366F, -0.914652F, -0.956410F,
-0.278755F, 0.235531F, -0.573260F, -1.484615F, -0.327839F, -0.297070F, -1.195238F, -1.160073F, 0.932967F, -0.606960F, 0.798901F, 0.212088F,
0.113187F, -0.116117F, -0.532967F, 0.077289F, 0.016484F, 1.352747F, -1.487546F, -1.363736F
};
for (int i = 0; i < QK * 4; i++) {
SET_ELEMENT_F32(mat, i, mat_values[i]);
}
struct ggml_tensor * quantized_mat = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_1_O, QK, 4);
int64_t histogram[16];
ggml_quantize_q4_1_o(mat->data, quantized_mat->data, QK * 4, QK, histogram);
struct ggml_tensor * vec = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, QK);
float vec_values[] = {
-0.578388F, -0.770330F, -0.183516F, 0.264103F, 0.585714F, -0.226740F, 1.319048F, 0.652381F,
-1.161538F, 0.428205F, -0.907326F, -0.837729F, 0.673626F, 0.248718F, 0.392308F, -0.225275F,
0.910989F, 0.483150F, -0.669963F, -0.412088F, 0.954945F, 0.826007F, 0.113919F, 0.095604F,
-1.042125F, -1.094872F, 0.589377F, -0.426007F, 0.669231F, -0.243590F, -0.179121F, 0.325641F
};
for (int i = 0; i < QK; i++) {
SET_ELEMENT_F32(vec, i, vec_values[i]);
}
struct ggml_tensor * expected_result = ggml_mul_mat(ctx, mat, vec);
struct ggml_tensor * quantized_result = ggml_mul_mat(ctx, quantized_mat, vec);
struct ggml_cgraph graph = ggml_build_forward(expected_result);
ggml_build_forward_expand(&graph, quantized_result);
graph.n_threads = 2;
ggml_graph_compute(ctx, &graph);
float diff_sum = 0.0F;
for (int i = 0; i < 4; i++) {
fprintf(
stderr,
"[%d] expected %f, actual %f\n",
i,
GET_ELEMENT_F32(expected_result, i),
GET_ELEMENT_F32(quantized_result, i)
);
diff_sum += fabsf(GET_ELEMENT_F32(expected_result, i) - GET_ELEMENT_F32(quantized_result, i));
}
float diff_average = diff_sum / 4;
// If Q4_1_O format works correctly, difference should be this or lower
ASSERT(diff_average <= 0.112F, "Unexpected average difference value %f", diff_average);
ggml_free(ctx);
return 0;
}

test_Q4_1_O_large_matmul.c (deleted file)

@ -1,86 +0,0 @@
// Tests that Q4_1_O matmul on a large matrix works (does not crash, etc.)
#include "ggml.h"
#include "rwkv.h"
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#define GET_ELEMENT_F32(tensor, i) (((float *) tensor->data)[i])
#define SET_ELEMENT_F32(tensor, i, value) ((float *) tensor->data)[i] = value
#define ASSERT(x, ...) {\
if (!(x)) {\
fprintf(stderr, "*** Assertion failed ***\n");\
fprintf(stderr, __VA_ARGS__);\
fprintf(stderr, "\n%s:%d\n", __FILE__, __LINE__);\
abort();\
}\
}
#define RANDOM_FLOAT() (((rand() & 0xFFF) / ((float) 0xFFF) - 0.5F) * 3.0F)
// ---
#define QK 32
#define MATRIX_SIZE 1024
int main(int argc, const char ** argv) {
srand(42);
struct ggml_init_params params = {
.mem_size = 8 * 1024 * 1024,
.mem_buffer = NULL,
.no_alloc = false,
};
struct ggml_context * ctx = ggml_init(params);
struct ggml_tensor * mat = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, MATRIX_SIZE, MATRIX_SIZE);
for (int i = 0; i < MATRIX_SIZE * MATRIX_SIZE; i++) {
SET_ELEMENT_F32(mat, i, RANDOM_FLOAT());
}
// Add some outliers
for (int i = 0; i < MATRIX_SIZE; i++) {
SET_ELEMENT_F32(mat, i * MATRIX_SIZE + 1, RANDOM_FLOAT() * 100.0F);
}
struct ggml_tensor * quantized_mat = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_1_O, MATRIX_SIZE, MATRIX_SIZE);
int64_t histogram[16];
ggml_quantize_q4_1_o(mat->data, quantized_mat->data, MATRIX_SIZE * MATRIX_SIZE, QK, histogram);
struct ggml_tensor * vec = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, MATRIX_SIZE);
for (int i = 0; i < MATRIX_SIZE; i++) {
SET_ELEMENT_F32(vec, i, RANDOM_FLOAT());
}
struct ggml_tensor * expected_result = ggml_mul_mat(ctx, mat, vec);
struct ggml_tensor * quantized_result = ggml_mul_mat(ctx, quantized_mat, vec);
struct ggml_cgraph graph = ggml_build_forward(expected_result);
ggml_build_forward_expand(&graph, quantized_result);
graph.n_threads = 4;
ggml_graph_compute(ctx, &graph);
float diff_sum = 0.0F;
for (int i = 0; i < MATRIX_SIZE; i++) {
diff_sum += fabsf(GET_ELEMENT_F32(expected_result, i) - GET_ELEMENT_F32(quantized_result, i));
}
float diff_average = diff_sum / MATRIX_SIZE;
// More strict test is in test_Q4_1_O.c, here we just do sanity check
ASSERT(diff_average <= 2.0F, "Unexpected average difference value %f", diff_average);
ggml_free(ctx);
return 0;
}

test_tiny_rwkv.c

@ -69,49 +69,55 @@ int main(int argc, const char ** argv) {
     ASSERT(elements_read == N_VOCAB, "Failed to read expected_logits.bin, read %zd elements", elements_read);
     fclose(file);

-    float expected_difference_sum[12] = {
+    float expected_difference_sum[14] = {
         0.000000F,
         -0.005320F,
-        -0.501214F,
+        -0.160030F,
         -0.370606F,
-        -0.268956F,
-        0.676837F,
-        0.237099F,
-        -0.501073F,
+        0.661480F,
+        -0.170404F,
+        0.278034F,
+        0.071216F,
+        0.154614F,
         -0.372169F,
-        -0.244590F,
-        0.674874F,
-        0.243007F
+        0.658310F,
+        -0.170043F,
+        0.294953F,
+        0.065571F,
     };

     test_model("tiny-rwkv-660K-FP32.bin", expected_logits, expected_difference_sum[0]);
     test_model("tiny-rwkv-660K-FP16.bin", expected_logits, expected_difference_sum[1]);

-    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_0.bin", 2);
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_1.bin", 3);
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_1_O.bin", 4);
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_2.bin", 5);
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_3.bin", 6);
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_0.bin", "Q4_0");
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_1.bin", "Q4_1");
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_2.bin", "Q4_2");
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q5_0.bin", "Q5_0");
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q5_1.bin", "Q5_1");
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q8_0.bin", "Q8_0");

     test_model("tiny-rwkv-660K-FP32-Q4_0.bin", expected_logits, expected_difference_sum[2]);
     test_model("tiny-rwkv-660K-FP32-Q4_1.bin", expected_logits, expected_difference_sum[3]);
-    test_model("tiny-rwkv-660K-FP32-Q4_1_O.bin", expected_logits, expected_difference_sum[4]);
-    test_model("tiny-rwkv-660K-FP32-Q4_2.bin", expected_logits, expected_difference_sum[5]);
-    test_model("tiny-rwkv-660K-FP32-Q4_3.bin", expected_logits, expected_difference_sum[6]);
+    test_model("tiny-rwkv-660K-FP32-Q4_2.bin", expected_logits, expected_difference_sum[4]);
+    test_model("tiny-rwkv-660K-FP32-Q5_0.bin", expected_logits, expected_difference_sum[5]);
+    test_model("tiny-rwkv-660K-FP32-Q5_1.bin", expected_logits, expected_difference_sum[6]);
+    test_model("tiny-rwkv-660K-FP32-Q8_0.bin", expected_logits, expected_difference_sum[7]);

-    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_0.bin", 2);
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_1.bin", 3);
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_1_O.bin", 4);
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_2.bin", 5);
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_3.bin", 6);
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_0.bin", "Q4_0");
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_1.bin", "Q4_1");
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_2.bin", "Q4_2");
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q5_0.bin", "Q5_0");
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q5_1.bin", "Q5_1");
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q8_0.bin", "Q8_0");

-    test_model("tiny-rwkv-660K-FP16-Q4_0.bin", expected_logits, expected_difference_sum[7]);
-    test_model("tiny-rwkv-660K-FP16-Q4_1.bin", expected_logits, expected_difference_sum[8]);
-    test_model("tiny-rwkv-660K-FP16-Q4_1_O.bin", expected_logits, expected_difference_sum[9]);
-    test_model("tiny-rwkv-660K-FP16-Q4_2.bin", expected_logits, expected_difference_sum[10]);
-    test_model("tiny-rwkv-660K-FP16-Q4_3.bin", expected_logits, expected_difference_sum[11]);
+    test_model("tiny-rwkv-660K-FP16-Q4_0.bin", expected_logits, expected_difference_sum[8]);
+    test_model("tiny-rwkv-660K-FP16-Q4_1.bin", expected_logits, expected_difference_sum[9]);
+    test_model("tiny-rwkv-660K-FP16-Q4_2.bin", expected_logits, expected_difference_sum[10]);
+    test_model("tiny-rwkv-660K-FP16-Q5_0.bin", expected_logits, expected_difference_sum[11]);
+    test_model("tiny-rwkv-660K-FP16-Q5_1.bin", expected_logits, expected_difference_sum[12]);
+    test_model("tiny-rwkv-660K-FP16-Q8_0.bin", expected_logits, expected_difference_sum[13]);

     free(expected_logits);
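
The reference values above appear to be summed signed differences between produced and expected logits (negative sums occur), so the check inside `test_model` presumably reduces to something like the following NumPy sketch; the helper name and tolerance are assumptions, not the test's actual code:

```python
import numpy as np

def logits_difference_sum(actual: np.ndarray, expected: np.ndarray) -> float:
    # Signed difference summed over the whole vocabulary (N_VOCAB elements).
    return float(np.sum(actual - expected))

# Hypothetical check against one entry of expected_difference_sum:
# diff = logits_difference_sum(model_logits, expected_logits)
# assert abs(diff - expected_difference_sum[index]) < 1e-3  # tolerance is an assumption
```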