diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 2049328..d325989 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -8,11 +8,9 @@ on:
         description: 'Create new release'
         required: true
         type: boolean
-  push:
-    paths: ['.github/workflows/**', '**/CMakeLists.txt', '**/*.h', '**/*.c', '**/*.cpp']
+  push: {}
   pull_request:
     types: [opened, synchronize, edited, reopened, review_requested, ready_for_review]
-    paths: ['**/CMakeLists.txt', '**/*.h', '**/*.c', '**/*.cpp']

 env:
   BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
diff --git a/README.md b/README.md
index 3e21ced..30511d5 100644
--- a/README.md
+++ b/README.md
@@ -89,9 +89,13 @@ python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-

 Formats available:

-- `4`: `Q4_1_O`, OK quality, moderately fast (20% slower than `FP16`).
-- `3`: `Q4_1`, worst quality, fast (comparable to `FP16`).
-- `2`: `Q4_0`, poor quality, very fast.
+- `6`: `Q4_3`, OK quality, fast.
+- `5`: `Q4_2`, poor quality, fast.
+- `4`: `Q4_1_O`, best quality, slow (20% slower than `FP16`).
+- `3`: `Q4_1`, poor quality, very fast.
+- `2`: `Q4_0`, worst quality, very fast.
+
+If you use `rwkv.cpp` for anything serious (just having fun is serious enough!), please [test all available formats for perplexity and latency](rwkv%2Fmeasure_pexplexity.py) on a representative dataset, and decide which trade-off is best for you.

 ### 4. Run the model
diff --git a/ggml b/ggml
index 0330904..bfa8d5b 160000
--- a/ggml
+++ b/ggml
@@ -1 +1 @@
-Subproject commit 03309047d2e65c05ffefbf64c6c4c943e6647c64
+Subproject commit bfa8d5b5ab4ffbae4c5f97525c3890f38619056d
diff --git a/rwkv.cpp b/rwkv.cpp
index 835d0d8..9441013 100644
--- a/rwkv.cpp
+++ b/rwkv.cpp
@@ -43,12 +43,14 @@ bool read_int32(FILE * file, int32_t * dest) {
     return true;
 }

-static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[5] = {
+static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[7] = {
     GGML_TYPE_F32,
     GGML_TYPE_F16,
     GGML_TYPE_Q4_0,
     GGML_TYPE_Q4_1,
-    GGML_TYPE_Q4_1_O
+    GGML_TYPE_Q4_1_O,
+    GGML_TYPE_Q4_2,
+    GGML_TYPE_Q4_3
 };

 // --- Model definition and loading utilities ---
@@ -204,15 +206,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
     RWKV_ASSERT_NULL(model->n_layer > 0, "Non-positive n_layer %d", model->n_layer);

     read_int32(file, &(model->data_type));
-    RWKV_ASSERT_NULL(
-        model->data_type == 0 ||
-            model->data_type == 1 ||
-            model->data_type == 2 ||
-            model->data_type == 3 ||
-            model->data_type == 4,
-        "Unsupported model data type %d",
-        model->data_type
-    );
+    RWKV_ASSERT_NULL(model->data_type >= 0 && model->data_type <= 6, "Unsupported model data type %d", model->data_type);

     // Parameter tensors would take at least this amount in memory.
     size_t file_size;
@@ -262,15 +256,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr

         int32_t data_type;
         read_int32(file, &data_type);
-        RWKV_ASSERT_NULL(
-            data_type == 0 ||
-                data_type == 1 ||
-                data_type == 2 ||
-                data_type == 3 ||
-                data_type == 4,
-            "Unsupported parameter data type %d",
-            data_type
-        );
+        RWKV_ASSERT_NULL(data_type >= 0 && data_type <= 6, "Unsupported parameter data type %d", data_type);

         ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type];
@@ -581,9 +567,6 @@ bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float

     memcpy(logits_out, ctx->logits->data, ctx->logits->ne[0] * FP32_SIZE);

-    // Uncomment to measure used memory for adding the value into get_memory_required_mb.
- //fprintf(stderr, "Used mem: %d MB\n", ggml_used_mem(ctx->ctx) / 1024 / 1024); - return true; } @@ -597,7 +580,7 @@ void rwkv_free(struct rwkv_context * ctx) { } bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) { - RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type); + RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4 || q_type == 5 || q_type == 6, "Unsupported quantization type %d", q_type); // Needed to initialize FP16 lookup table { @@ -690,7 +673,9 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode "F16", "Q4_0", "Q4_1", - "Q4_1_O" + "Q4_1_O", + "Q4_2", + "Q4_3" }; printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]); @@ -761,6 +746,14 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode { cur_size = ggml_quantize_q4_1_o(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); } break; + case GGML_TYPE_Q4_2: + { + cur_size = ggml_quantize_q4_2(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + } break; + case GGML_TYPE_Q4_3: + { + cur_size = ggml_quantize_q4_3(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); + } break; default: { fprintf(stderr, "unsupported quantization type %d\n", type); diff --git a/rwkv.h b/rwkv.h index ca4444e..54538a7 100644 --- a/rwkv.h +++ b/rwkv.h @@ -56,7 +56,7 @@ extern "C" { // Returns false on any error. Error messages would be printed to stderr. // - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16. // - model_file_path_out: quantized model will be written here. - // - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1, set to 4 for GGML_TYPE_Q4_1_O. + // - q_type: set to 2 for GGML_TYPE_Q4_0, 3 for GGML_TYPE_Q4_1, 4 for GGML_TYPE_Q4_1_O, 5 for GGML_TYPE_Q4_2, 6 for GGML_TYPE_Q4_3. RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type); // Returns system information string. diff --git a/rwkv/convert_pytorch_to_ggml.py b/rwkv/convert_pytorch_to_ggml.py index f42c316..13f5444 100644 --- a/rwkv/convert_pytorch_to_ggml.py +++ b/rwkv/convert_pytorch_to_ggml.py @@ -12,7 +12,7 @@ # int32 n_vocab; # int32 n_embed; # int32 n_layer; -# // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O. +# // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O, 5 if Q4_2, 6 if Q4_3. # int32 data_type; # // Read until EOF. # Parameter[] parameters; @@ -21,7 +21,7 @@ # Parameter { # int32 dim_count; # int32 key_length; -# // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O. +# // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O, 5 if Q4_2, 6 if Q4_3. # int32 data_type; # // Compared to PyTorch's tensor.shape, dimension order is reversed here! 
 #   int32[dim_count] shape;
diff --git a/rwkv/measure_pexplexity.py b/rwkv/measure_pexplexity.py
index a2a0e2c..dac6aee 100644
--- a/rwkv/measure_pexplexity.py
+++ b/rwkv/measure_pexplexity.py
@@ -14,9 +14,10 @@ from typing import List

 def parse_args():
     parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
-    parser.add_argument('model_path', help='Path to model checkpoint file')
-    parser.add_argument('text_path', help='Path to text file in UTF-8 encoding')
-    parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int, default=1024)
+    parser.add_argument('model_path', help='Path to model checkpoint file', type=str)
+    parser.add_argument('text_path', help='Path to text file in UTF-8 encoding', type=str)
+    parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int)
+    parser.add_argument('token_limit', help='How many tokens to process; set to -1 to process all text', nargs='?', type=int, default=-1)
     return parser.parse_args()

 args = parse_args()
@@ -33,6 +34,15 @@ tokens: List[int] = tokenizer.encode(text).ids
 token_count: int = len(tokens)
 print(f'{token_count} tokens in the text')

+token_limit: int = args.token_limit
+
+assert token_limit == -1 or token_limit > 0, 'Invalid token_limit'
+
+if token_limit != -1 and token_count > token_limit:
+    tokens = tokens[0:token_limit]
+    token_count = token_limit
+    print(f'Text was limited to {token_limit} tokens')
+
 assert token_count - args.ignore_first_n_tokens > 1, 'Need at least 2 tokens for evaluation'

 # ---
@@ -73,7 +83,7 @@ for i in range(run_count):
     loss_sum += losses
     loss_count += 1

-    if i % 10 == 0:
+    if run_count <= 5 or i % (run_count // 10) == 0:
         avg_loss_so_far = loss_sum / loss_count

         duration: float = time.time() - start
@@ -90,11 +100,9 @@ for i in range(run_count):
     else:
         print()

-print()
-print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token')
-
 print()
 print(f'Model: {os.path.basename(args.model_path)}, '
       f'data: {os.path.basename(args.text_path)} with {token_count} tokens, '
       f'skipped {args.ignore_first_n_tokens} tokens, '
-      f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}')
+      f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}, '
+      f'latency {int((time.time() - start) * 1000 / run_count)} ms per token')
diff --git a/rwkv/quantize.py b/rwkv/quantize.py
index 243dc92..68df859 100644
--- a/rwkv/quantize.py
+++ b/rwkv/quantize.py
@@ -1,4 +1,4 @@
-# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1 or Q4_1_O (recommended).
+# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1, Q4_1_O, Q4_2, Q4_3.
 # Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1_o.bin 4

 import argparse
@@ -8,20 +8,17 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
     parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
     parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
-    parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0), 3 (GGML_TYPE_Q4_1) or 4 (GGML_TYPE_Q4_1_O)', type=int, choices=[2, 3, 4], default=4)
+    parser.add_argument('data_type', help='Data type, '
+                                          '2 (GGML_TYPE_Q4_0), '
+                                          '3 (GGML_TYPE_Q4_1), '
+                                          '4 (GGML_TYPE_Q4_1_O), '
+                                          '5 (Q4_2), '
+                                          '6 (Q4_3)', type=int, choices=[2, 3, 4, 5, 6], default=4)
     return parser.parse_args()

 def main() -> None:
     args = parse_args()

-    if args.data_type == 2 or args.data_type == 3:
-        print()
-        print('WARNING!')
-        print('You are using Q4_0 or Q4_1 quantization; it will heavily degrade RWKV quality.')
-        print('For best quality preservation, it is recommended to use Q4_1_O.')
-        print('More info at https://github.com/saharNooby/rwkv.cpp/issues/12')
-        print()
-
     library = rwkv_cpp_shared_library.load_rwkv_shared_library()

     library.rwkv_quantize_model_file(
diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c
index 49c2873..a508ae3 100644
--- a/tests/test_tiny_rwkv.c
+++ b/tests/test_tiny_rwkv.c
@@ -69,17 +69,21 @@ int main(int argc, const char ** argv) {
     ASSERT(elements_read == N_VOCAB, "Failed to read expected_logits.bin, read %zd elements", elements_read);
     fclose(file);

-    float expected_difference_sum[8] = {
+    float expected_difference_sum[12] = {
         0.000000F,
         -0.005320F,
         -0.501214F,
-        -1.092427F,
+        -0.370606F,
         -0.268956F,
+        0.676837F,
+        0.237099F,
         -0.501073F,
-        -1.103214F,
-        -0.244590F
+        -0.372169F,
+        -0.244590F,
+        0.674874F,
+        0.243007F
     };

     test_model("tiny-rwkv-660K-FP32.bin", expected_logits, expected_difference_sum[0]);
@@ -88,18 +92,26 @@ int main(int argc, const char ** argv) {
     rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_0.bin", 2);
     rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_1.bin", 3);
     rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_1_O.bin", 4);
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_2.bin", 5);
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_3.bin", 6);

     test_model("tiny-rwkv-660K-FP32-Q4_0.bin", expected_logits, expected_difference_sum[2]);
     test_model("tiny-rwkv-660K-FP32-Q4_1.bin", expected_logits, expected_difference_sum[3]);
     test_model("tiny-rwkv-660K-FP32-Q4_1_O.bin", expected_logits, expected_difference_sum[4]);
+    test_model("tiny-rwkv-660K-FP32-Q4_2.bin", expected_logits, expected_difference_sum[5]);
+    test_model("tiny-rwkv-660K-FP32-Q4_3.bin", expected_logits, expected_difference_sum[6]);

     rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_0.bin", 2);
     rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_1.bin", 3);
     rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_1_O.bin", 4);
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_2.bin", 5);
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_3.bin", 6);

-    test_model("tiny-rwkv-660K-FP16-Q4_0.bin", expected_logits, expected_difference_sum[5]);
-    test_model("tiny-rwkv-660K-FP16-Q4_1.bin", expected_logits, expected_difference_sum[6]);
-    test_model("tiny-rwkv-660K-FP16-Q4_1_O.bin", expected_logits, expected_difference_sum[7]);
+    test_model("tiny-rwkv-660K-FP16-Q4_0.bin", expected_logits, expected_difference_sum[7]);
+    test_model("tiny-rwkv-660K-FP16-Q4_1.bin", expected_logits, expected_difference_sum[8]);
+    test_model("tiny-rwkv-660K-FP16-Q4_1_O.bin", expected_logits, expected_difference_sum[9]);
+    test_model("tiny-rwkv-660K-FP16-Q4_2.bin", expected_logits, expected_difference_sum[10]);
+    test_model("tiny-rwkv-660K-FP16-Q4_3.bin", expected_logits, expected_difference_sum[11]);

     free(expected_logits);
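
For reference, below is a minimal usage sketch (not part of the patch) showing how the new q_type values 5 (Q4_2) and 6 (Q4_3) would be used through the public rwkv.h API changed above. The model file paths and the thread count are placeholders; rwkv_init_from_file(file_path, n_threads) and rwkv_free(ctx) are assumed from the header context visible in this diff.

// Hypothetical sketch: quantize an FP16 model to the new Q4_2 format (q_type 5)
// and load the result. Paths and thread count are placeholders.
#include <stdio.h>
#include "rwkv.h"

int main(void) {
    // q_type 5 -> GGML_TYPE_Q4_2, q_type 6 -> GGML_TYPE_Q4_3 (see rwkv.h above).
    if (!rwkv_quantize_model_file("model-FP16.bin", "model-Q4_2.bin", 5)) {
        fprintf(stderr, "Quantization failed\n");
        return 1;
    }

    struct rwkv_context * ctx = rwkv_init_from_file("model-Q4_2.bin", 4);

    if (ctx == NULL) {
        fprintf(stderr, "Failed to load quantized model\n");
        return 1;
    }

    // ... run inference with rwkv_eval() here ...

    rwkv_free(ctx);
    return 0;
}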