Merge pull request #16 from saharNooby/outliers-preserving-quantization-PR
Add Q4_1_O quantization format that preserves outliers in weights and does dot in FP32
This commit is contained in:
commit
84e0698f2b
27
README.md
27
README.md
|
@ -10,9 +10,10 @@ This project provides [a C library rwkv.h](rwkv.h) and [a convinient Python wrap
|
|||
|
||||
**TODO (contributions welcome!)**:
|
||||
|
||||
1. Measure latency and perplexity of different model sizes (169M to 14B) and data types (FP32, FP16, Q4_0, Q4_1)
|
||||
2. Test on Linux (including Colab) and MacOS
|
||||
3. Make required memory calculation more robust (see #4)
|
||||
1. Optimize AVX2 implementation of `Q4_1_O` matmul — currently, it is as slow as `FP32`
|
||||
2. Measure latency and perplexity of different model sizes (169M to 14B) and data types (`FP32`, `FP16`, `Q4_0`, `Q4_1`, `Q4_1_O`)
|
||||
3. Test on Linux (including Colab) and MacOS
|
||||
4. Make required memory calculation more robust (see [#4](https://github.com/saharNooby/rwkv.cpp/issues/4))
|
||||
|
||||
## How to use
|
||||
|
||||
|
@ -68,7 +69,7 @@ If everything went OK, `librwkv.so` (Linux) or `rwkv.o` (MacOS) file should appe
|
|||
|
||||
```commandline
|
||||
# Windows
|
||||
python rwkv\convert_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float16
|
||||
python rwkv\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float16
|
||||
|
||||
# Linux / MacOS
|
||||
python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin float16
|
||||
|
@ -80,13 +81,17 @@ To convert the model into INT4 quantized format, run:
|
|||
|
||||
```commandline
|
||||
# Windows
|
||||
python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q4_1.bin 3
|
||||
python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q4_1_O.bin 4
|
||||
|
||||
# Linux / MacOS
|
||||
python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q4_1.bin 3
|
||||
python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q4_1_O.bin 4
|
||||
```
|
||||
|
||||
Pass `2` for `Q4_0` format (smaller size, lower quality), `3` for `Q4_1` format (larger size, higher quality).
|
||||
Formats available:
|
||||
|
||||
- `4`: `Q4_1_O`, best quality, very slow (as `FP32`).
|
||||
- `3`: `Q4_1`, poor quality, very fast (as `FP16`).
|
||||
- `2`: `Q4_0`, worst quality, breaks larger models, moderately fast (between `FP16` and `FP32`).
|
||||
|
||||
### 4. Run the model
|
||||
|
||||
|
@ -98,20 +103,20 @@ To generate some text, run:
|
|||
|
||||
```commandline
|
||||
# Windows
|
||||
python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q4_1.bin
|
||||
python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q4_1_O.bin
|
||||
|
||||
# Linux / MacOS
|
||||
python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q4_1.bin
|
||||
python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q4_1_O.bin
|
||||
```
|
||||
|
||||
To chat with a bot, run:
|
||||
|
||||
```commandline
|
||||
# Windows
|
||||
python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q4_1.bin
|
||||
python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q4_1_O.bin
|
||||
|
||||
# Linux / MacOS
|
||||
python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q4_1.bin
|
||||
python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q4_1_O.bin
|
||||
```
|
||||
|
||||
Edit [generate_completions.py](rwkv%2Fgenerate_completions.py) or [chat_with_bot.py](rwkv%2Fchat_with_bot.py) to change prompts and sampling settings.
|
||||
|
|
8
ggml.h
8
ggml.h
|
@ -186,7 +186,8 @@
|
|||
// - to `ggml_compute_forward` and call the forward dispatch function here.
|
||||
// - to `ggml_compute_backward` and add `GGML_ASSERT(false)` here.
|
||||
// - to `ggml_graph_compute` and add `node->n_tasks = 1` here.
|
||||
// 6. Fix all assertions that check value of `GGML_OP_COUNT`: you've added 1 operator, so increment asserted value by one.
|
||||
// 6. Add operator label to `GGML_OP_LABEL` array and operator symbol to `GGML_OP_SYMBOL` array.
|
||||
// 7. Fix all assertions that check value of `GGML_OP_COUNT`: you've added 1 operator, so increment asserted value by one.
|
||||
//
|
||||
// When in doubt, consult the code of existing operators similar to that you're implementing.
|
||||
// Resulting operator would work for the forward pass, but will lack backward implementation and multi-threading support.
|
||||
|
@ -225,7 +226,11 @@ struct ggml_context;
|
|||
|
||||
enum ggml_type {
|
||||
GGML_TYPE_Q4_0,
|
||||
// Stores min and delta per block, does quantized matmul.
|
||||
GGML_TYPE_Q4_1,
|
||||
// Same as Q4_1, but stores outliers separately, and matmul is done in FP32.
|
||||
// An outlier is the single absmax element in the quantized block.
|
||||
GGML_TYPE_Q4_1_O,
|
||||
GGML_TYPE_I8,
|
||||
GGML_TYPE_I16,
|
||||
GGML_TYPE_I32,
|
||||
|
@ -806,6 +811,7 @@ enum ggml_opt_result ggml_opt(
|
|||
|
||||
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
|
||||
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
|
||||
size_t ggml_quantize_q4_1_o(const float * src, void * dst, int n, int k, int64_t * hist);
|
||||
|
||||
//
|
||||
// system info
|
||||
|
|
75
rwkv.cpp
75
rwkv.cpp
|
@ -43,6 +43,14 @@ bool read_int32(FILE * file, int32_t * dest) {
|
|||
return true;
|
||||
}
|
||||
|
||||
static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[5] = {
|
||||
GGML_TYPE_F32,
|
||||
GGML_TYPE_F16,
|
||||
GGML_TYPE_Q4_0,
|
||||
GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_Q4_1_O
|
||||
};
|
||||
|
||||
// --- Model definition and loading utilities ---
|
||||
|
||||
struct rwkv_layer {
|
||||
|
@ -160,7 +168,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
|
|||
model->data_type == 0 ||
|
||||
model->data_type == 1 ||
|
||||
model->data_type == 2 ||
|
||||
model->data_type == 3,
|
||||
model->data_type == 3 ||
|
||||
model->data_type == 4,
|
||||
"Unsupported model data type %d",
|
||||
model->data_type
|
||||
);
|
||||
|
@ -216,20 +225,13 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
|
|||
data_type == 0 ||
|
||||
data_type == 1 ||
|
||||
data_type == 2 ||
|
||||
data_type == 3,
|
||||
data_type == 3 ||
|
||||
data_type == 4,
|
||||
"Unsupported parameter data type %d",
|
||||
data_type
|
||||
);
|
||||
|
||||
ggml_type ggml_data_type;
|
||||
|
||||
switch (data_type) {
|
||||
case 0: ggml_data_type = GGML_TYPE_F32; break;
|
||||
case 1: ggml_data_type = GGML_TYPE_F16; break;
|
||||
case 2: ggml_data_type = GGML_TYPE_Q4_0; break;
|
||||
case 3: ggml_data_type = GGML_TYPE_Q4_1; break;
|
||||
default: return NULL;
|
||||
}
|
||||
ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type];
|
||||
|
||||
struct ggml_tensor * tensor;
|
||||
|
||||
|
@ -553,17 +555,9 @@ void rwkv_free(struct rwkv_context * ctx) {
|
|||
}
|
||||
|
||||
bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
|
||||
RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type);
|
||||
RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type);
|
||||
|
||||
ggml_type type;
|
||||
|
||||
switch (q_type) {
|
||||
case 2: type = GGML_TYPE_Q4_0; break;
|
||||
case 3: type = GGML_TYPE_Q4_1; break;
|
||||
default: return false;
|
||||
};
|
||||
|
||||
RWKV_ASSERT_FALSE(type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1, "Unsupported data type %d", type);
|
||||
ggml_type type = FORMAT_TYPE_TO_GGML_TYPE[q_type];
|
||||
|
||||
printf("Loading model from '%s'\n", model_file_path_in);
|
||||
|
||||
|
@ -643,22 +637,30 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
|
|||
|
||||
{
|
||||
static const char * parameter_data_type_str[] = {
|
||||
"f32",
|
||||
"f16",
|
||||
"q4_0",
|
||||
"q4_1"
|
||||
"F32",
|
||||
"F16",
|
||||
"Q4_0",
|
||||
"Q4_1",
|
||||
"Q4_1_O"
|
||||
};
|
||||
printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
|
||||
|
||||
total_size_orig += (size_t) (nelements * ggml_type_sizef(FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type]));
|
||||
}
|
||||
|
||||
// Quantize only 2D tensors
|
||||
bool quantize = n_dims == 2;
|
||||
// Quantize only 2D tensors, except embedding and head matrices.
|
||||
// Embedding and head take not too much space, especially in bigger models;
|
||||
// but they significantly increase perplexity when quantized.
|
||||
bool quantize = n_dims == 2 &&
|
||||
name != std::string("emb.weight") &&
|
||||
name != std::string("head.weight");
|
||||
|
||||
if (quantize) {
|
||||
if (parameter_data_type != 0 && parameter_data_type != 1) {
|
||||
fprintf(stderr, "unsupported data type %d for integer quantization\n", parameter_data_type);
|
||||
return false;
|
||||
}
|
||||
RWKV_ASSERT_FALSE(
|
||||
parameter_data_type == 0 || parameter_data_type == 1,
|
||||
"Unsupported parameter data type %d, only FP32 and FP16 can be quantized",
|
||||
parameter_data_type
|
||||
);
|
||||
|
||||
if (parameter_data_type == 1) {
|
||||
data_f16.resize(nelements);
|
||||
|
@ -706,6 +708,10 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
|
|||
{
|
||||
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1_O:
|
||||
{
|
||||
cur_size = ggml_quantize_q4_1_o(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
fprintf(stderr, "unsupported quantization type %d\n", type);
|
||||
|
@ -732,12 +738,11 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
|
|||
fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
|
||||
total_size_new += data_u8.size();
|
||||
}
|
||||
|
||||
total_size_orig += nelements * sizeof(float);
|
||||
}
|
||||
|
||||
printf("model size = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0);
|
||||
printf("quant size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
|
||||
printf("original size = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0);
|
||||
printf("quantized size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
|
||||
printf("compression ratio = %8.2f%\n", 1.0 * total_size_orig / total_size_new);
|
||||
|
||||
{
|
||||
int64_t sum_all = 0;
|
||||
|
|
|
@ -1,102 +0,0 @@
|
|||
# Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
|
||||
# Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
|
||||
# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
|
||||
# Usage: python compare_with_reference_implementation.py C:\rwkv.cpp-169M.bin
|
||||
|
||||
import os
|
||||
import struct
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
import rwkv_cpp_model
|
||||
import rwkv_cpp_shared_library
|
||||
from typing import List, Tuple, Any
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
|
||||
parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
|
||||
return parser.parse_args()
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
# Don't want to depend on tokenizer here.
|
||||
tokens: List[int] = [4, 3631, 4420, 2412, 953, 432, 391, 30567, 87, 15, 14161, 7092, 273, 416, 27767, 55, 342,
|
||||
2412, 953, 432, 3806, 7092, 273, 416, 27767, 55, 15, 187, 4, 19039, 2412, 953, 497, 4561,
|
||||
342, 416, 27767, 55, 14, 21, 14, 49, 587, 14, 17809, 46, 14, 938, 14256, 28950, 14, 1438,
|
||||
1508, 15, 81, 394, 1566, 275, 8462, 22097, 348, 15, 187, 4, 43825, 27, 15548, 7277, 64,
|
||||
3113, 64, 14005, 64, 39595, 15, 4789, 10269, 61, 18992, 61, 7265, 64, 30217, 39297, 15,
|
||||
20963, 330, 27, 190, 30567, 87, 15, 14161, 14, 17809, 46, 15, 4805]
|
||||
|
||||
threshold: float
|
||||
|
||||
with open(args.ggml_model_path, 'rb') as model_file:
|
||||
header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4))
|
||||
data_type: int = header[5]
|
||||
|
||||
assert data_type == 0 or\
|
||||
data_type == 1 or\
|
||||
data_type == 2 or\
|
||||
data_type == 3, f'Unsupported model data type {data_type}'
|
||||
|
||||
if data_type == 0:
|
||||
# FP32, high precision
|
||||
threshold = 0.000005
|
||||
elif data_type == 1:
|
||||
# FP16, lower precision, so higher threshold
|
||||
threshold = 0.0032
|
||||
elif data_type == 2:
|
||||
# INT4 quantized, even lower precision, so even higher threshold
|
||||
# This threshold will let some bugs pass
|
||||
threshold = 4.0
|
||||
elif data_type == 3:
|
||||
# This format stores more data, so error would be lower
|
||||
threshold = 1.2
|
||||
|
||||
model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path)
|
||||
|
||||
def compare_logits(tokens_subset: List[int]) -> None:
|
||||
token_count: int = len(tokens_subset)
|
||||
|
||||
logits, state = None, None
|
||||
|
||||
for i in range(token_count):
|
||||
token: int = tokens_subset[i]
|
||||
|
||||
if token_count <= 10 or i % (token_count // 10) == 0:
|
||||
print(f'{i + 1}/{token_count}')
|
||||
|
||||
logits, state = model.eval(token, state, state, logits)
|
||||
|
||||
actual_logits = logits
|
||||
|
||||
# ---
|
||||
|
||||
expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin'
|
||||
|
||||
if not os.path.isfile(expected_logits_path):
|
||||
expected_logits_path = 'rwkv/' + expected_logits_path
|
||||
|
||||
with open(expected_logits_path, 'rb') as logits_file:
|
||||
expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
|
||||
|
||||
# ---
|
||||
|
||||
difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
|
||||
|
||||
print(f'Reference logits: {expected_logits}')
|
||||
print(f'Actual logits: {actual_logits}')
|
||||
print('Difference per token: %.8f' % (difference,))
|
||||
|
||||
assert abs(difference) <= threshold, 'Difference is too big'
|
||||
|
||||
compare_logits(tokens)
|
||||
|
||||
print()
|
||||
print('Test passes')
|
||||
|
||||
if model is not None:
|
||||
model.free()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Binary file not shown.
|
@ -0,0 +1,100 @@
|
|||
# Measures perplexity and per-token latency of an RWKV model on a given text file.
|
||||
# Perplexity is defined here as exp() of average cross-entropy loss.
|
||||
# Usage: python measure_pexplexity.py C:\rwkv.cpp-169M.bin C:\text.txt 1024
|
||||
|
||||
import os
|
||||
import time
|
||||
import pathlib
|
||||
import argparse
|
||||
import tokenizers
|
||||
import torch
|
||||
import rwkv_cpp_model
|
||||
import rwkv_cpp_shared_library
|
||||
from typing import List
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
|
||||
parser.add_argument('model_path', help='Path to model checkpoint file')
|
||||
parser.add_argument('text_path', help='Path to text file in UTF-8 encoding')
|
||||
parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int, default=1024)
|
||||
return parser.parse_args()
|
||||
|
||||
args = parse_args()
|
||||
|
||||
# ---
|
||||
|
||||
print('Loading 20B tokenizer')
|
||||
tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
|
||||
tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
|
||||
|
||||
print('Loading text')
|
||||
text: str = open(args.text_path, encoding='utf-8').read()
|
||||
tokens: List[int] = tokenizer.encode(text).ids
|
||||
token_count: int = len(tokens)
|
||||
print(f'{token_count} tokens in the text')
|
||||
|
||||
assert token_count - args.ignore_first_n_tokens > 1, 'Need at least 2 tokens for evaluation'
|
||||
|
||||
# ---
|
||||
|
||||
def format_loss(loss: torch.Tensor) -> str:
|
||||
return str(['%.3f' % (loss[i].item(),) for i in range(len(loss))]).replace('\'', '')[1:-1]
|
||||
|
||||
def format_loss_with_perplexity(loss: torch.Tensor) -> str:
|
||||
return f'loss [{format_loss(loss)}], perplexity {"%.3f" % (torch.exp(loss[0]).item(),)}'
|
||||
|
||||
# ---
|
||||
|
||||
model: rwkv_cpp_model.RWKVModel = rwkv_cpp_model.RWKVModel(
|
||||
rwkv_cpp_shared_library.load_rwkv_shared_library(),
|
||||
args.model_path
|
||||
)
|
||||
|
||||
logits, state = None, None
|
||||
|
||||
loss_sum: torch.Tensor = torch.tensor([0.0])
|
||||
loss_count: int = 0
|
||||
|
||||
start: float = time.time()
|
||||
|
||||
run_count: int = token_count - 1
|
||||
|
||||
for i in range(run_count):
|
||||
token: int = tokens[i]
|
||||
target: int = tokens[i + 1]
|
||||
|
||||
logits, state = model.eval(token, state, state, logits)
|
||||
|
||||
if args.ignore_first_n_tokens == 0 or i + 1 >= args.ignore_first_n_tokens:
|
||||
losses = torch.tensor([
|
||||
torch.nn.functional.cross_entropy(logits, torch.tensor(target, dtype=torch.long), reduction='none').item()
|
||||
])
|
||||
|
||||
loss_sum += losses
|
||||
loss_count += 1
|
||||
|
||||
if i % 10 == 0:
|
||||
avg_loss_so_far = loss_sum / loss_count
|
||||
|
||||
duration: float = time.time() - start
|
||||
duration_per_token: float = duration / (i + 1)
|
||||
runs_remaining: int = run_count - i - 1
|
||||
duration_remaining: int = int(runs_remaining * duration_per_token)
|
||||
|
||||
print(f'Token #{i}/{token_count}, '
|
||||
f'{int(100.0 * i / token_count)}%, '
|
||||
f'ETA {duration_remaining // 60} m {duration_remaining % 60} s', end='')
|
||||
|
||||
if loss_count > 0:
|
||||
print(f', averages so far: {format_loss_with_perplexity(avg_loss_so_far)}')
|
||||
else:
|
||||
print()
|
||||
|
||||
print()
|
||||
print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token')
|
||||
|
||||
print()
|
||||
print(f'Model: {os.path.basename(args.model_path)}, '
|
||||
f'data: {os.path.basename(args.text_path)} with {token_count} tokens, '
|
||||
f'skipped {args.ignore_first_n_tokens} tokens, '
|
||||
f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}')
|
|
@ -1,5 +1,5 @@
|
|||
# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1.
|
||||
# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3
|
||||
# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1 or Q4_1_O (recommended).
|
||||
# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1_o.bin 4
|
||||
|
||||
import argparse
|
||||
import rwkv_cpp_shared_library
|
||||
|
@ -8,12 +8,20 @@ def parse_args():
|
|||
parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
|
||||
parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
|
||||
parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
|
||||
parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3)
|
||||
parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0), 3 (GGML_TYPE_Q4_1) or 4 (GGML_TYPE_Q4_1_O)', type=int, choices=[2, 3, 4], default=4)
|
||||
return parser.parse_args()
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
if args.data_type == 2 or args.data_type == 3:
|
||||
print()
|
||||
print('WARNING!')
|
||||
print('You are using Q4_0 or Q4_1 quantization; it will heavily degrade RWKV quality.')
|
||||
print('For best quality preservation, it is recommended to use Q4_1_O.')
|
||||
print('More info at https://github.com/saharNooby/rwkv.cpp/issues/12')
|
||||
print()
|
||||
|
||||
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
|
||||
|
||||
library.rwkv_quantize_model_file(
|
||||
|
|
|
@ -32,14 +32,14 @@ class RWKVModel:
|
|||
assert os.path.isfile(model_path), f'{model_path} is not a file'
|
||||
assert thread_count > 0, 'Thread count must be positive'
|
||||
|
||||
self.library = shared_library
|
||||
self._library = shared_library
|
||||
|
||||
self.ctx = self.library.rwkv_init_from_file(model_path, thread_count)
|
||||
self._ctx = self._library.rwkv_init_from_file(model_path, thread_count)
|
||||
|
||||
self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
|
||||
self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)
|
||||
self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx)
|
||||
self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
|
||||
|
||||
self.valid = True
|
||||
self._valid = True
|
||||
|
||||
def eval(
|
||||
self,
|
||||
|
@ -69,7 +69,7 @@ class RWKVModel:
|
|||
Logits vector of shape (n_vocab); state for the next step.
|
||||
"""
|
||||
|
||||
assert self.valid, 'Model was freed'
|
||||
assert self._valid, 'Model was freed'
|
||||
|
||||
def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None:
|
||||
assert buf.dtype == torch.float32, f'{name} is not of type float32'
|
||||
|
@ -77,24 +77,24 @@ class RWKVModel:
|
|||
assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})'
|
||||
|
||||
if state_in is not None:
|
||||
validate_buffer(state_in, 'state_in', self.state_buffer_element_count)
|
||||
validate_buffer(state_in, 'state_in', self._state_buffer_element_count)
|
||||
|
||||
state_in_ptr = state_in.storage().data_ptr()
|
||||
else:
|
||||
state_in_ptr = 0
|
||||
|
||||
if state_out is not None:
|
||||
validate_buffer(state_out, 'state_out', self.state_buffer_element_count)
|
||||
validate_buffer(state_out, 'state_out', self._state_buffer_element_count)
|
||||
else:
|
||||
state_out = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
|
||||
state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu')
|
||||
|
||||
if logits_out is not None:
|
||||
validate_buffer(logits_out, 'logits_out', self.logits_buffer_element_count)
|
||||
validate_buffer(logits_out, 'logits_out', self._logits_buffer_element_count)
|
||||
else:
|
||||
logits_out = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')
|
||||
logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu')
|
||||
|
||||
self.library.rwkv_eval(
|
||||
self.ctx,
|
||||
self._library.rwkv_eval(
|
||||
self._ctx,
|
||||
token,
|
||||
state_in_ptr,
|
||||
state_out.storage().data_ptr(),
|
||||
|
@ -110,8 +110,13 @@ class RWKVModel:
|
|||
The object must not be used anymore after calling this method.
|
||||
"""
|
||||
|
||||
assert self.valid, 'Already freed'
|
||||
assert self._valid, 'Already freed'
|
||||
|
||||
self.valid = False
|
||||
self._valid = False
|
||||
|
||||
self.library.rwkv_free(self.ctx)
|
||||
self._library.rwkv_free(self._ctx)
|
||||
|
||||
def __del__(self):
|
||||
# Free the context on GC in case user forgot to call free() explicitly.
|
||||
if hasattr(self, '_valid') and self._valid:
|
||||
self.free()
|
||||
|
|
|
@ -192,13 +192,17 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
|
|||
else:
|
||||
file_name = 'librwkv.so'
|
||||
|
||||
repo_root_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent.parent
|
||||
|
||||
paths = [
|
||||
# If we are in "rwkv" directory
|
||||
f'../bin/Release/{file_name}',
|
||||
# If we are in repo root directory
|
||||
f'bin/Release/{file_name}',
|
||||
# Search relative to this file
|
||||
str(repo_root_dir / 'bin' / 'Release' / file_name),
|
||||
# Fallback
|
||||
pathlib.Path(os.path.abspath(__file__)).parent.parent / file_name
|
||||
str(repo_root_dir / file_name)
|
||||
]
|
||||
|
||||
for path in paths:
|
||||
|
|
Loading…
Reference in New Issue