Merge pull request #16 from saharNooby/outliers-preserving-quantization-PR

Add Q4_1_O quantization format that preserves outliers in weights and does dot in FP32

Commit 84e0698f2b

README.md (27 lines changed)

````diff
@@ -10,9 +10,10 @@ This project provides [a C library rwkv.h](rwkv.h) and [a convinient Python wrap
 **TODO (contributions welcome!)**:
 
-1. Measure latency and perplexity of different model sizes (169M to 14B) and data types (FP32, FP16, Q4_0, Q4_1)
-2. Test on Linux (including Colab) and MacOS
-3. Make required memory calculation more robust (see #4)
+1. Optimize AVX2 implementation of `Q4_1_O` matmul — currently, it is as slow as `FP32`
+2. Measure latency and perplexity of different model sizes (169M to 14B) and data types (`FP32`, `FP16`, `Q4_0`, `Q4_1`, `Q4_1_O`)
+3. Test on Linux (including Colab) and MacOS
+4. Make required memory calculation more robust (see [#4](https://github.com/saharNooby/rwkv.cpp/issues/4))
 
 ## How to use
 
@@ -68,7 +69,7 @@ If everything went OK, `librwkv.so` (Linux) or `rwkv.o` (MacOS) file should appe
 ```commandline
 # Windows
-python rwkv\convert_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float16
+python rwkv\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float16
 
 # Linux / MacOS
 python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin float16
@@ -80,13 +81,17 @@ To convert the model into INT4 quantized format, run:
 
 ```commandline
 # Windows
-python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q4_1.bin 3
+python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q4_1_O.bin 4
 
 # Linux / MacOS
-python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q4_1.bin 3
+python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q4_1_O.bin 4
 ```
 
-Pass `2` for `Q4_0` format (smaller size, lower quality), `3` for `Q4_1` format (larger size, higher quality).
+Formats available:
+
+- `4`: `Q4_1_O`, best quality, very slow (as `FP32`).
+- `3`: `Q4_1`, poor quality, very fast (as `FP16`).
+- `2`: `Q4_0`, worst quality, breaks larger models, moderately fast (between `FP16` and `FP32`).
 
 ### 4. Run the model
````
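The format id passed on the command line maps directly onto the library call that `quantize.py` makes (see its diff below). A minimal scripted sketch, assuming the Python wrapper mirrors the C signature `rwkv_quantize_model_file(in_path, out_path, q_type)` shown in the rwkv.cpp hunks of this PR; paths are placeholders:

```python
import rwkv_cpp_shared_library

library = rwkv_cpp_shared_library.load_rwkv_shared_library()

# 4 selects the new Q4_1_O format; 2 and 3 select Q4_0 and Q4_1.
library.rwkv_quantize_model_file(
    'rwkv.cpp-169M.bin',         # source FP32/FP16 checkpoint (placeholder path)
    'rwkv.cpp-169M-Q4_1_O.bin',  # destination, will be overwritten (placeholder path)
    4
)
```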
````diff
@@ -98,20 +103,20 @@ To generate some text, run:
 
 ```commandline
 # Windows
-python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q4_1.bin
+python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q4_1_O.bin
 
 # Linux / MacOS
-python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q4_1.bin
+python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q4_1_O.bin
 ```
 
 To chat with a bot, run:
 
 ```commandline
 # Windows
-python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q4_1.bin
+python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q4_1_O.bin
 
 # Linux / MacOS
-python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q4_1.bin
+python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q4_1_O.bin
 ```
 
 Edit [generate_completions.py](rwkv%2Fgenerate_completions.py) or [chat_with_bot.py](rwkv%2Fchat_with_bot.py) to change prompts and sampling settings.
````
ggml.h (8 lines changed)

```diff
@@ -186,7 +186,8 @@
 // - to `ggml_compute_forward` and call the forward dispatch function here.
 // - to `ggml_compute_backward` and add `GGML_ASSERT(false)` here.
 // - to `ggml_graph_compute` and add `node->n_tasks = 1` here.
-// 6. Fix all assertions that check value of `GGML_OP_COUNT`: you've added 1 operator, so increment asserted value by one.
+// 6. Add operator label to `GGML_OP_LABEL` array and operator symbol to `GGML_OP_SYMBOL` array.
+// 7. Fix all assertions that check value of `GGML_OP_COUNT`: you've added 1 operator, so increment asserted value by one.
 //
 // When in doubt, consult the code of existing operators similar to that you're implementing.
 // Resulting operator would work for the forward pass, but will lack backward implementation and multi-threading support.
@@ -225,7 +226,11 @@ struct ggml_context;
 
 enum ggml_type {
     GGML_TYPE_Q4_0,
+    // Stores min and delta per block, does quantized matmul.
     GGML_TYPE_Q4_1,
+    // Same as Q4_1, but stores outliers separately, and matmul is done in FP32.
+    // An outlier is the single absmax element in the quantized block.
+    GGML_TYPE_Q4_1_O,
     GGML_TYPE_I8,
     GGML_TYPE_I16,
     GGML_TYPE_I32,
```
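The two comment lines above are the entire `Q4_1_O` spec given in this header: quantize like `Q4_1` (per-block min and delta), but keep the single absmax element of each block exact and do the matmul in `FP32`. A rough Python sketch of the quantize/dequantize idea; the block size of 32 and the flat tuple layout are illustrative assumptions, not ggml's actual bit packing:

```python
import numpy as np

QK = 32  # assumed block size, matching ggml's Q4_1 blocks at the time

def quantize_block_q4_1_o(block: np.ndarray):
    """Quantize one block, keeping its absmax element (the 'outlier') in full precision."""
    assert block.shape == (QK,)

    outlier_index = int(np.argmax(np.abs(block)))
    outlier_value = float(block[outlier_index])

    # Quantize the remaining elements exactly like Q4_1: min plus 4-bit scaled offsets.
    rest = np.delete(block, outlier_index)
    vmin = float(rest.min())
    delta = (float(rest.max()) - vmin) / 15.0 or 1.0  # avoid division by zero

    # The outlier's slot is also quantized (and clipped) here, but its
    # value is restored exactly on dequantization, so the garbage is harmless.
    quants = np.clip(np.round((block - vmin) / delta), 0, 15).astype(np.uint8)

    return vmin, delta, outlier_index, outlier_value, quants

def dequantize_block_q4_1_o(vmin, delta, outlier_index, outlier_value, quants):
    out = vmin + delta * quants.astype(np.float32)
    out[outlier_index] = outlier_value  # restore the outlier exactly
    return out
```

Keeping one extreme weight exact means it no longer stretches the quantization range of its 31 neighbors, which is why the README section above ranks `Q4_1_O` as the best-quality 4-bit option.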
@ -806,6 +811,7 @@ enum ggml_opt_result ggml_opt(
|
||||||
|
|
||||||
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
|
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
|
||||||
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
|
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
|
||||||
|
size_t ggml_quantize_q4_1_o(const float * src, void * dst, int n, int k, int64_t * hist);
|
||||||
|
|
||||||
//
|
//
|
||||||
// system info
|
// system info
|
||||||
|
|
rwkv.cpp (75 lines changed)

```diff
@@ -43,6 +43,14 @@ bool read_int32(FILE * file, int32_t * dest) {
     return true;
 }
 
+static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[5] = {
+    GGML_TYPE_F32,
+    GGML_TYPE_F16,
+    GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1,
+    GGML_TYPE_Q4_1_O
+};
+
 // --- Model definition and loading utilities ---
 
 struct rwkv_layer {
@@ -160,7 +168,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
         model->data_type == 0 ||
         model->data_type == 1 ||
         model->data_type == 2 ||
-        model->data_type == 3,
+        model->data_type == 3 ||
+        model->data_type == 4,
         "Unsupported model data type %d",
         model->data_type
     );
@@ -216,20 +225,13 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
             data_type == 0 ||
             data_type == 1 ||
             data_type == 2 ||
-            data_type == 3,
+            data_type == 3 ||
+            data_type == 4,
             "Unsupported parameter data type %d",
             data_type
         );
 
-        ggml_type ggml_data_type;
-
-        switch (data_type) {
-            case 0: ggml_data_type = GGML_TYPE_F32; break;
-            case 1: ggml_data_type = GGML_TYPE_F16; break;
-            case 2: ggml_data_type = GGML_TYPE_Q4_0; break;
-            case 3: ggml_data_type = GGML_TYPE_Q4_1; break;
-            default: return NULL;
-        }
+        ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type];
 
         struct ggml_tensor * tensor;
@@ -553,17 +555,9 @@ void rwkv_free(struct rwkv_context * ctx) {
 }
 
 bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
-    RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type);
+    RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type);
 
-    ggml_type type;
-
-    switch (q_type) {
-        case 2: type = GGML_TYPE_Q4_0; break;
-        case 3: type = GGML_TYPE_Q4_1; break;
-        default: return false;
-    };
-
-    RWKV_ASSERT_FALSE(type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1, "Unsupported data type %d", type);
+    ggml_type type = FORMAT_TYPE_TO_GGML_TYPE[q_type];
 
     printf("Loading model from '%s'\n", model_file_path_in);
@@ -643,22 +637,30 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
         {
             static const char * parameter_data_type_str[] = {
-                "f32",
-                "f16",
-                "q4_0",
-                "q4_1"
+                "F32",
+                "F16",
+                "Q4_0",
+                "Q4_1",
+                "Q4_1_O"
             };
             printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
+
+            total_size_orig += (size_t) (nelements * ggml_type_sizef(FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type]));
         }
 
-        // Quantize only 2D tensors
-        bool quantize = n_dims == 2;
+        // Quantize only 2D tensors, except embedding and head matrices.
+        // Embedding and head take not too much space, especially in bigger models;
+        // but they significantly increase perplexity when quantized.
+        bool quantize = n_dims == 2 &&
+                name != std::string("emb.weight") &&
+                name != std::string("head.weight");
 
         if (quantize) {
-            if (parameter_data_type != 0 && parameter_data_type != 1) {
-                fprintf(stderr, "unsupported data type %d for integer quantization\n", parameter_data_type);
-                return false;
-            }
+            RWKV_ASSERT_FALSE(
+                parameter_data_type == 0 || parameter_data_type == 1,
+                "Unsupported parameter data type %d, only FP32 and FP16 can be quantized",
+                parameter_data_type
+            );
 
             if (parameter_data_type == 1) {
                 data_f16.resize(nelements);
@@ -706,6 +708,10 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
                 {
                     cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
                 } break;
+            case GGML_TYPE_Q4_1_O:
+                {
+                    cur_size = ggml_quantize_q4_1_o(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                } break;
             default:
                 {
                     fprintf(stderr, "unsupported quantization type %d\n", type);
@@ -732,12 +738,11 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
             fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
             total_size_new += data_u8.size();
         }
-
-        total_size_orig += nelements * sizeof(float);
     }
 
-    printf("model size  = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0);
-    printf("quant size  = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
+    printf("original size     = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0);
+    printf("quantized size    = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
+    printf("compression ratio = %8.2f%\n", 1.0 * total_size_orig / total_size_new);
 
     {
         int64_t sum_all = 0;
```
rwkv/compare_with_reference_implementation.py (deleted)

```diff
@@ -1,102 +0,0 @@
-# Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
-# Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
-# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
-# Usage: python compare_with_reference_implementation.py C:\rwkv.cpp-169M.bin
-
-import os
-import struct
-import argparse
-import torch
-import numpy as np
-import rwkv_cpp_model
-import rwkv_cpp_shared_library
-from typing import List, Tuple, Any
-
-def parse_args():
-    parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
-    parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
-    return parser.parse_args()
-
-def main() -> None:
-    args = parse_args()
-
-    # Don't want to depend on tokenizer here.
-    tokens: List[int] = [4, 3631, 4420, 2412, 953, 432, 391, 30567, 87, 15, 14161, 7092, 273, 416, 27767, 55, 342,
-                         2412, 953, 432, 3806, 7092, 273, 416, 27767, 55, 15, 187, 4, 19039, 2412, 953, 497, 4561,
-                         342, 416, 27767, 55, 14, 21, 14, 49, 587, 14, 17809, 46, 14, 938, 14256, 28950, 14, 1438,
-                         1508, 15, 81, 394, 1566, 275, 8462, 22097, 348, 15, 187, 4, 43825, 27, 15548, 7277, 64,
-                         3113, 64, 14005, 64, 39595, 15, 4789, 10269, 61, 18992, 61, 7265, 64, 30217, 39297, 15,
-                         20963, 330, 27, 190, 30567, 87, 15, 14161, 14, 17809, 46, 15, 4805]
-
-    threshold: float
-
-    with open(args.ggml_model_path, 'rb') as model_file:
-        header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4))
-        data_type: int = header[5]
-
-        assert data_type == 0 or\
-            data_type == 1 or\
-            data_type == 2 or\
-            data_type == 3, f'Unsupported model data type {data_type}'
-
-        if data_type == 0:
-            # FP32, high precision
-            threshold = 0.000005
-        elif data_type == 1:
-            # FP16, lower precision, so higher threshold
-            threshold = 0.0032
-        elif data_type == 2:
-            # INT4 quantized, even lower precision, so even higher threshold
-            # This threshold will let some bugs pass
-            threshold = 4.0
-        elif data_type == 3:
-            # This format stores more data, so error would be lower
-            threshold = 1.2
-
-    model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path)
-
-    def compare_logits(tokens_subset: List[int]) -> None:
-        token_count: int = len(tokens_subset)
-
-        logits, state = None, None
-
-        for i in range(token_count):
-            token: int = tokens_subset[i]
-
-            if token_count <= 10 or i % (token_count // 10) == 0:
-                print(f'{i + 1}/{token_count}')
-
-            logits, state = model.eval(token, state, state, logits)
-
-        actual_logits = logits
-
-        # ---
-
-        expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin'
-
-        if not os.path.isfile(expected_logits_path):
-            expected_logits_path = 'rwkv/' + expected_logits_path
-
-        with open(expected_logits_path, 'rb') as logits_file:
-            expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
-
-        # ---
-
-        difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
-
-        print(f'Reference logits: {expected_logits}')
-        print(f'Actual logits: {actual_logits}')
-        print('Difference per token: %.8f' % (difference,))
-
-        assert abs(difference) <= threshold, 'Difference is too big'
-
-    compare_logits(tokens)
-
-    print()
-    print('Test passes')
-
-    if model is not None:
-        model.free()
-
-if __name__ == "__main__":
-    main()
```
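Besides the logit comparison itself, this deleted script is the one place the diff documents the checkpoint header layout: six native-endian int32 fields read with `struct.unpack('=iiiiii', ...)`, the last being the data type that `rwkv_init_from_file` validates above. A minimal stand-alone reader based on that layout; the name mapping mirrors this PR's `FORMAT_TYPE_TO_GGML_TYPE` table:

```python
import struct

# Index into this list with the header's data type field; index 4 is new in this PR.
DATA_TYPE_NAMES = ['FP32', 'FP16', 'Q4_0', 'Q4_1', 'Q4_1_O']

def read_data_type(model_path: str) -> int:
    """Read the data type field from an rwkv.cpp checkpoint (six int32 header fields)."""
    with open(model_path, 'rb') as model_file:
        header = struct.unpack('=iiiiii', model_file.read(6 * 4))
    return header[5]  # the sixth field is the data type
```

For example, `DATA_TYPE_NAMES[read_data_type('rwkv.cpp-169M-Q4_1_O.bin')]` (placeholder path) would evaluate to `'Q4_1_O'` for a checkpoint produced by the quantization commands above.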
(Binary file not shown.)

rwkv/measure_pexplexity.py (added)

```diff
@@ -0,0 +1,100 @@
+# Measures perplexity and per-token latency of an RWKV model on a given text file.
+# Perplexity is defined here as exp() of average cross-entropy loss.
+# Usage: python measure_pexplexity.py C:\rwkv.cpp-169M.bin C:\text.txt 1024
+
+import os
+import time
+import pathlib
+import argparse
+import tokenizers
+import torch
+import rwkv_cpp_model
+import rwkv_cpp_shared_library
+from typing import List
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
+    parser.add_argument('model_path', help='Path to model checkpoint file')
+    parser.add_argument('text_path', help='Path to text file in UTF-8 encoding')
+    parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int, default=1024)
+    return parser.parse_args()
+
+args = parse_args()
+
+# ---
+
+print('Loading 20B tokenizer')
+tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
+tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
+
+print('Loading text')
+text: str = open(args.text_path, encoding='utf-8').read()
+tokens: List[int] = tokenizer.encode(text).ids
+token_count: int = len(tokens)
+print(f'{token_count} tokens in the text')
+
+assert token_count - args.ignore_first_n_tokens > 1, 'Need at least 2 tokens for evaluation'
+
+# ---
+
+def format_loss(loss: torch.Tensor) -> str:
+    return str(['%.3f' % (loss[i].item(),) for i in range(len(loss))]).replace('\'', '')[1:-1]
+
+def format_loss_with_perplexity(loss: torch.Tensor) -> str:
+    return f'loss [{format_loss(loss)}], perplexity {"%.3f" % (torch.exp(loss[0]).item(),)}'
+
+# ---
+
+model: rwkv_cpp_model.RWKVModel = rwkv_cpp_model.RWKVModel(
+    rwkv_cpp_shared_library.load_rwkv_shared_library(),
+    args.model_path
+)
+
+logits, state = None, None
+
+loss_sum: torch.Tensor = torch.tensor([0.0])
+loss_count: int = 0
+
+start: float = time.time()
+
+run_count: int = token_count - 1
+
+for i in range(run_count):
+    token: int = tokens[i]
+    target: int = tokens[i + 1]
+
+    logits, state = model.eval(token, state, state, logits)
+
+    if args.ignore_first_n_tokens == 0 or i + 1 >= args.ignore_first_n_tokens:
+        losses = torch.tensor([
+            torch.nn.functional.cross_entropy(logits, torch.tensor(target, dtype=torch.long), reduction='none').item()
+        ])
+
+        loss_sum += losses
+        loss_count += 1
+
+    if i % 10 == 0:
+        avg_loss_so_far = loss_sum / loss_count
+
+        duration: float = time.time() - start
+        duration_per_token: float = duration / (i + 1)
+        runs_remaining: int = run_count - i - 1
+        duration_remaining: int = int(runs_remaining * duration_per_token)
+
+        print(f'Token #{i}/{token_count}, '
+              f'{int(100.0 * i / token_count)}%, '
+              f'ETA {duration_remaining // 60} m {duration_remaining % 60} s', end='')
+
+        if loss_count > 0:
+            print(f', averages so far: {format_loss_with_perplexity(avg_loss_so_far)}')
+        else:
+            print()
+
+print()
+print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token')
+
+print()
+print(f'Model: {os.path.basename(args.model_path)}, '
+      f'data: {os.path.basename(args.text_path)} with {token_count} tokens, '
+      f'skipped {args.ignore_first_n_tokens} tokens, '
+      f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}')
```
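The new script's header defines perplexity as exp() of the average cross-entropy loss, which is exactly what `format_loss_with_perplexity` reports. A tiny self-contained illustration of that relationship, with made-up logits over a made-up 5-token vocabulary:

```python
import torch

logits = torch.tensor([0.1, 2.5, -1.0, 0.3, 0.7])  # hypothetical model output
target = torch.tensor(1, dtype=torch.long)         # the token that actually came next

loss = torch.nn.functional.cross_entropy(logits, target)
perplexity = torch.exp(loss)  # exp() of the (here single-sample) average loss

print(f'loss {loss.item():.3f}, perplexity {perplexity.item():.3f}')
```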
rwkv/quantize.py

```diff
@@ -1,5 +1,5 @@
-# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1.
-# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3
+# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1 or Q4_1_O (recommended).
+# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1_o.bin 4
 
 import argparse
 import rwkv_cpp_shared_library
@@ -8,12 +8,20 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
     parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
     parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
-    parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3)
+    parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0), 3 (GGML_TYPE_Q4_1) or 4 (GGML_TYPE_Q4_1_O)', type=int, choices=[2, 3, 4], default=4)
     return parser.parse_args()
 
 def main() -> None:
     args = parse_args()
 
+    if args.data_type == 2 or args.data_type == 3:
+        print()
+        print('WARNING!')
+        print('You are using Q4_0 or Q4_1 quantization; it will heavily degrade RWKV quality.')
+        print('For best quality preservation, it is recommended to use Q4_1_O.')
+        print('More info at https://github.com/saharNooby/rwkv.cpp/issues/12')
+        print()
+
     library = rwkv_cpp_shared_library.load_rwkv_shared_library()
 
     library.rwkv_quantize_model_file(
```
rwkv/rwkv_cpp_model.py

```diff
@@ -32,14 +32,14 @@ class RWKVModel:
         assert os.path.isfile(model_path), f'{model_path} is not a file'
         assert thread_count > 0, 'Thread count must be positive'
 
-        self.library = shared_library
+        self._library = shared_library
 
-        self.ctx = self.library.rwkv_init_from_file(model_path, thread_count)
+        self._ctx = self._library.rwkv_init_from_file(model_path, thread_count)
 
-        self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
-        self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)
+        self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx)
+        self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
 
-        self.valid = True
+        self._valid = True
 
     def eval(
         self,
@@ -69,7 +69,7 @@ class RWKVModel:
         Logits vector of shape (n_vocab); state for the next step.
         """
 
-        assert self.valid, 'Model was freed'
+        assert self._valid, 'Model was freed'
 
         def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None:
             assert buf.dtype == torch.float32, f'{name} is not of type float32'
@@ -77,24 +77,24 @@ class RWKVModel:
             assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})'
 
         if state_in is not None:
-            validate_buffer(state_in, 'state_in', self.state_buffer_element_count)
+            validate_buffer(state_in, 'state_in', self._state_buffer_element_count)
 
             state_in_ptr = state_in.storage().data_ptr()
         else:
             state_in_ptr = 0
 
         if state_out is not None:
-            validate_buffer(state_out, 'state_out', self.state_buffer_element_count)
+            validate_buffer(state_out, 'state_out', self._state_buffer_element_count)
         else:
-            state_out = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
+            state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu')
 
         if logits_out is not None:
-            validate_buffer(logits_out, 'logits_out', self.logits_buffer_element_count)
+            validate_buffer(logits_out, 'logits_out', self._logits_buffer_element_count)
         else:
-            logits_out = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')
+            logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu')
 
-        self.library.rwkv_eval(
-            self.ctx,
+        self._library.rwkv_eval(
+            self._ctx,
             token,
             state_in_ptr,
             state_out.storage().data_ptr(),
@@ -110,8 +110,13 @@ class RWKVModel:
         The object must not be used anymore after calling this method.
         """
 
-        assert self.valid, 'Already freed'
+        assert self._valid, 'Already freed'
 
-        self.valid = False
+        self._valid = False
 
-        self.library.rwkv_free(self.ctx)
+        self._library.rwkv_free(self._ctx)
+
+    def __del__(self):
+        # Free the context on GC in case user forgot to call free() explicitly.
+        if hasattr(self, '_valid') and self._valid:
+            self.free()
```
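With the attributes made private and `__del__` added, the native context is now released on garbage collection even if `free()` is forgotten; calling `free()` explicitly remains the intended pattern. A minimal usage sketch (model path and token ids are placeholders):

```python
import rwkv_cpp_model
import rwkv_cpp_shared_library

model = rwkv_cpp_model.RWKVModel(
    rwkv_cpp_shared_library.load_rwkv_shared_library(),
    'rwkv.cpp-169M-Q4_1_O.bin'  # placeholder path
)

logits, state = None, None

for token in [4, 3631, 4420]:  # placeholder token ids
    # Reuse the state tensor as both input and output, as the repo's scripts do.
    logits, state = model.eval(token, state, state, logits)

model.free()  # explicit free; __del__ is now only a safety net on GC
```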
rwkv/rwkv_cpp_shared_library.py

```diff
@@ -192,13 +192,17 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
     else:
         file_name = 'librwkv.so'
 
+    repo_root_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent.parent
+
     paths = [
         # If we are in "rwkv" directory
         f'../bin/Release/{file_name}',
         # If we are in repo root directory
         f'bin/Release/{file_name}',
+        # Search relative to this file
+        str(repo_root_dir / 'bin' / 'Release' / file_name),
         # Fallback
-        pathlib.Path(os.path.abspath(__file__)).parent.parent / file_name
+        str(repo_root_dir / file_name)
     ]
 
     for path in paths:
```