Sync ggml with upstream (#38)

* Sync ggml with upstream

* Remove file filters from Actions triggers

* Update ggml

* Add Q4_2 and Q4_3 support

* Improve output of perplexity measuring script

* Add tests for new formats

* Add token limit argument to perplexity measuring script

* Update README

* Update README

* Update ggml

* Use master branch of ggml
Alex 2023-04-22 20:25:29 +05:00 committed by GitHub
parent ac663631e1
commit 3587ff9e58
9 changed files with 72 additions and 60 deletions


@@ -8,11 +8,9 @@ on:
description: 'Create new release'
required: true
type: boolean
-push:
-paths: ['.github/workflows/**', '**/CMakeLists.txt', '**/*.h', '**/*.c', '**/*.cpp']
+push: {}
pull_request:
types: [opened, synchronize, edited, reopened, review_requested, ready_for_review]
-paths: ['**/CMakeLists.txt', '**/*.h', '**/*.c', '**/*.cpp']
env:
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}


@@ -89,9 +89,13 @@ python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-
Formats available:
-- `4`: `Q4_1_O`, OK quality, moderately fast (20% slower than `FP16`).
-- `3`: `Q4_1`, worst quality, fast (comparable to `FP16`).
-- `2`: `Q4_0`, poor quality, very fast.
+- `6`: `Q4_3`, OK quality, fast.
+- `5`: `Q4_2`, poor quality, fast.
+- `4`: `Q4_1_O`, best quality, slow (20% slower than `FP16`).
+- `3`: `Q4_1`, poor quality, very fast.
+- `2`: `Q4_0`, worst quality, very fast.
If you use `rwkv.cpp` for anything serious (just having fun is serious enough!), please [test all available formats for perplexity and latency](rwkv%2Fmeasure_pexplexity.py) on a representative dataset, and decide which trade-off is best for you.
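As a concrete illustration of the comparison the README recommends, here is a minimal sketch (not part of the commit) that runs the bundled perplexity script over several already-quantized checkpoints. The checkpoint names and `wiki.txt` are placeholders; the positional arguments follow the script as updated in this commit (model path, text path, `ignore_first_n_tokens`, optional `token_limit`).

```python
# Minimal sketch: compare formats by running rwkv/measure_pexplexity.py on each
# quantized checkpoint. Checkpoint names and wiki.txt are placeholders.
import subprocess
import sys

checkpoints = [
    'rwkv-169M-FP16.bin',
    'rwkv-169M-Q4_0.bin',
    'rwkv-169M-Q4_1.bin',
    'rwkv-169M-Q4_1_O.bin',
    'rwkv-169M-Q4_2.bin',
    'rwkv-169M-Q4_3.bin',
]

for path in checkpoints:
    # Skip the first 1024 tokens when measuring loss, stop after 4096 tokens.
    subprocess.run(
        [sys.executable, 'rwkv/measure_pexplexity.py', path, 'wiki.txt', '1024', '4096'],
        check=True,
    )
```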
### 4. Run the model

ggml

@@ -1 +1 @@
-Subproject commit 03309047d2e65c05ffefbf64c6c4c943e6647c64
+Subproject commit bfa8d5b5ab4ffbae4c5f97525c3890f38619056d


@@ -43,12 +43,14 @@ bool read_int32(FILE * file, int32_t * dest) {
return true;
}
-static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[5] = {
+static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[7] = {
GGML_TYPE_F32,
GGML_TYPE_F16,
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1,
-GGML_TYPE_Q4_1_O
+GGML_TYPE_Q4_1_O,
+GGML_TYPE_Q4_2,
+GGML_TYPE_Q4_3
};
// --- Model definition and loading utilities ---
@@ -204,15 +206,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
RWKV_ASSERT_NULL(model->n_layer > 0, "Non-positive n_layer %d", model->n_layer);
read_int32(file, &(model->data_type));
-RWKV_ASSERT_NULL(
-model->data_type == 0 ||
-model->data_type == 1 ||
-model->data_type == 2 ||
-model->data_type == 3 ||
-model->data_type == 4,
-"Unsupported model data type %d",
-model->data_type
-);
+RWKV_ASSERT_NULL(model->data_type >= 0 && model->data_type <= 6, "Unsupported model data type %d", model->data_type);
// Parameter tensors would take at least this amount in memory.
size_t file_size;
@@ -262,15 +256,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
int32_t data_type;
read_int32(file, &data_type);
-RWKV_ASSERT_NULL(
-data_type == 0 ||
-data_type == 1 ||
-data_type == 2 ||
-data_type == 3 ||
-data_type == 4,
-"Unsupported parameter data type %d",
-data_type
-);
+RWKV_ASSERT_NULL(data_type >= 0 && data_type <= 6, "Unsupported parameter data type %d", data_type);
ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type];
@@ -581,9 +567,6 @@ bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float
memcpy(logits_out, ctx->logits->data, ctx->logits->ne[0] * FP32_SIZE);
-// Uncomment to measure used memory for adding the value into get_memory_required_mb.
-//fprintf(stderr, "Used mem: %d MB\n", ggml_used_mem(ctx->ctx) / 1024 / 1024);
return true;
}
@@ -597,7 +580,7 @@ void rwkv_free(struct rwkv_context * ctx) {
}
bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
-RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type);
+RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4 || q_type == 5 || q_type == 6, "Unsupported quantization type %d", q_type);
// Needed to initialize FP16 lookup table
{
@@ -690,7 +673,9 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
"F16",
"Q4_0",
"Q4_1",
-"Q4_1_O"
+"Q4_1_O",
+"Q4_2",
+"Q4_3"
};
printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
@@ -761,6 +746,14 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
{
cur_size = ggml_quantize_q4_1_o(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
+case GGML_TYPE_Q4_2:
+{
+cur_size = ggml_quantize_q4_2(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+} break;
+case GGML_TYPE_Q4_3:
+{
+cur_size = ggml_quantize_q4_3(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+} break;
default:
{
fprintf(stderr, "unsupported quantization type %d\n", type);

rwkv.h

@@ -56,7 +56,7 @@ extern "C" {
// Returns false on any error. Error messages would be printed to stderr.
// - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
// - model_file_path_out: quantized model will be written here.
-// - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1, set to 4 for GGML_TYPE_Q4_1_O.
+// - q_type: set to 2 for GGML_TYPE_Q4_0, 3 for GGML_TYPE_Q4_1, 4 for GGML_TYPE_Q4_1_O, 5 for GGML_TYPE_Q4_2, 6 for GGML_TYPE_Q4_3.
RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type);
// Returns system information string.
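For reference, a minimal sketch (not part of the commit) of driving this API from Python through the `rwkv_cpp_shared_library` wrapper in the `rwkv/` directory, which `rwkv/quantize.py` below also uses. The paths are placeholders, and the wrapper method is assumed to mirror the C signature above (input path, output path, `q_type`).

```python
# Minimal sketch: quantize an FP16 checkpoint into the two formats added in this commit.
# Paths are placeholders; the wrapper is assumed to forward arguments to the C API as-is.
import rwkv_cpp_shared_library

library = rwkv_cpp_shared_library.load_rwkv_shared_library()

# q_type 5 -> GGML_TYPE_Q4_2, q_type 6 -> GGML_TYPE_Q4_3.
library.rwkv_quantize_model_file('rwkv-169M-FP16.bin', 'rwkv-169M-Q4_2.bin', 5)
library.rwkv_quantize_model_file('rwkv-169M-FP16.bin', 'rwkv-169M-Q4_3.bin', 6)
```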


@@ -12,7 +12,7 @@
# int32 n_vocab;
# int32 n_embed;
# int32 n_layer;
-# // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O.
+# // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O, 5 if Q4_2, 6 if Q4_3.
# int32 data_type;
# // Read until EOF.
# Parameter[] parameters;
@@ -21,7 +21,7 @@
# Parameter {
# int32 dim_count;
# int32 key_length;
-# // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O.
+# // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O, 5 if Q4_2, 6 if Q4_3.
# int32 data_type;
# // Compared to PyTorch's tensor.shape, dimension order is reversed here!
# int32[dim_count] shape;
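Since the comment above spells out the header fields, a small sketch of reading them may help. This is not part of the commit: it assumes the read position is already at `n_vocab` (anything earlier in the file is outside this hunk and not handled), and it assumes the int32 byte order matches the converter's (native `=` here).

```python
# Minimal sketch: read the header fields documented above from an already-open model file.
import struct

DATA_TYPE_NAMES = ['float32', 'float16', 'Q4_0', 'Q4_1', 'Q4_1_O', 'Q4_2', 'Q4_3']

def read_header_tail(file) -> dict:
    # Four int32 values, in the order given in the format comment.
    n_vocab, n_embed, n_layer, data_type = struct.unpack('=4i', file.read(4 * 4))
    assert 0 <= data_type <= 6, f'Unsupported data type {data_type}'
    return {
        'n_vocab': n_vocab,
        'n_embed': n_embed,
        'n_layer': n_layer,
        'data_type': DATA_TYPE_NAMES[data_type],
    }
```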


@@ -14,9 +14,10 @@ from typing import List
def parse_args():
parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
-parser.add_argument('model_path', help='Path to model checkpoint file')
-parser.add_argument('text_path', help='Path to text file in UTF-8 encoding')
-parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int, default=1024)
+parser.add_argument('model_path', help='Path to model checkpoint file', type=str)
+parser.add_argument('text_path', help='Path to text file in UTF-8 encoding', type=str)
+parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int)
+parser.add_argument('token_limit', help='How many tokens to process; set to -1 to process all text', nargs='?', type=int, default=-1)
return parser.parse_args()
args = parse_args()
@@ -33,6 +34,15 @@ tokens: List[int] = tokenizer.encode(text).ids
token_count: int = len(tokens)
print(f'{token_count} tokens in the text')
+token_limit: int = args.token_limit
+assert token_limit == -1 or token_limit > 0, 'Invalid token_limit'
+if token_limit != -1 and token_count > token_limit:
+tokens = tokens[0:token_limit]
+token_count = token_limit
+print(f'Text was limited to {token_limit} tokens')
+assert token_count - args.ignore_first_n_tokens > 1, 'Need at least 2 tokens for evaluation'
# ---
@@ -73,7 +83,7 @@ for i in range(run_count):
loss_sum += losses
loss_count += 1
-if i % 10 == 0:
+if run_count <= 5 or i % (run_count // 10) == 0:
avg_loss_so_far = loss_sum / loss_count
duration: float = time.time() - start
@@ -90,11 +100,9 @@ for i in range(run_count):
else:
print()
print()
-print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token')
print()
print(f'Model: {os.path.basename(args.model_path)}, '
f'data: {os.path.basename(args.text_path)} with {token_count} tokens, '
f'skipped {args.ignore_first_n_tokens} tokens, '
-f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}')
+f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}, '
+f'latency {int((time.time() - start) * 1000 / run_count)} ms per token')


@@ -1,4 +1,4 @@
-# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1 or Q4_1_O (recommended).
+# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1, Q4_1_O, Q4_2, Q4_3.
# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1_o.bin 4
import argparse
@@ -8,20 +8,17 @@ def parse_args():
parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
-parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0), 3 (GGML_TYPE_Q4_1) or 4 (GGML_TYPE_Q4_1_O)', type=int, choices=[2, 3, 4], default=4)
+parser.add_argument('data_type', help='Data type, '
+'2 (GGML_TYPE_Q4_0), '
+'3 (GGML_TYPE_Q4_1), '
+'4 (GGML_TYPE_Q4_1_O), '
+'5 (Q4_2), '
+'6 (Q4_3)', type=int, choices=[2, 3, 4, 5, 6], default=4)
return parser.parse_args()
def main() -> None:
args = parse_args()
-if args.data_type == 2 or args.data_type == 3:
-print()
-print('WARNING!')
-print('You are using Q4_0 or Q4_1 quantization; it will heavily degrade RWKV quality.')
-print('For best quality preservation, it is recommended to use Q4_1_O.')
-print('More info at https://github.com/saharNooby/rwkv.cpp/issues/12')
-print()
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
library.rwkv_quantize_model_file(


@@ -69,17 +69,21 @@ int main(int argc, const char ** argv) {
ASSERT(elements_read == N_VOCAB, "Failed to read expected_logits.bin, read %zd elements", elements_read);
fclose(file);
-float expected_difference_sum[12] = {
+float expected_difference_sum[12] = {
0.000000F,
-0.005320F,
-0.501214F,
-1.092427F,
-0.370606F,
-0.268956F,
0.676837F,
0.237099F,
-0.501073F,
-1.103214F,
-0.244590F
-0.372169F,
-0.244590F,
0.674874F,
0.243007F
};
test_model("tiny-rwkv-660K-FP32.bin", expected_logits, expected_difference_sum[0]);
@@ -88,18 +92,26 @@ int main(int argc, const char ** argv) {
rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_0.bin", 2);
rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_1.bin", 3);
rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_1_O.bin", 4);
rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_2.bin", 5);
rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_3.bin", 6);
test_model("tiny-rwkv-660K-FP32-Q4_0.bin", expected_logits, expected_difference_sum[2]);
test_model("tiny-rwkv-660K-FP32-Q4_1.bin", expected_logits, expected_difference_sum[3]);
test_model("tiny-rwkv-660K-FP32-Q4_1_O.bin", expected_logits, expected_difference_sum[4]);
test_model("tiny-rwkv-660K-FP32-Q4_2.bin", expected_logits, expected_difference_sum[5]);
test_model("tiny-rwkv-660K-FP32-Q4_3.bin", expected_logits, expected_difference_sum[6]);
rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_0.bin", 2);
rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_1.bin", 3);
rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_1_O.bin", 4);
rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_2.bin", 5);
rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_3.bin", 6);
test_model("tiny-rwkv-660K-FP16-Q4_0.bin", expected_logits, expected_difference_sum[5]);
test_model("tiny-rwkv-660K-FP16-Q4_1.bin", expected_logits, expected_difference_sum[6]);
test_model("tiny-rwkv-660K-FP16-Q4_1_O.bin", expected_logits, expected_difference_sum[7]);
test_model("tiny-rwkv-660K-FP16-Q4_0.bin", expected_logits, expected_difference_sum[7]);
test_model("tiny-rwkv-660K-FP16-Q4_1.bin", expected_logits, expected_difference_sum[8]);
test_model("tiny-rwkv-660K-FP16-Q4_1_O.bin", expected_logits, expected_difference_sum[9]);
test_model("tiny-rwkv-660K-FP16-Q4_2.bin", expected_logits, expected_difference_sum[10]);
test_model("tiny-rwkv-660K-FP16-Q4_3.bin", expected_logits, expected_difference_sum[11]);
free(expected_logits);