Merge pull request #16 from saharNooby/outliers-preserving-quantization-PR

Add Q4_1_O quantization format that preserves outliers in weights and does dot in FP32
Alex 2023-04-08 16:51:47 +05:00 committed by GitHub
commit 84e0698f2b
10 changed files with 873 additions and 206 deletions
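
For readers skimming the diff: the core idea of `Q4_1_O` is that, within each quantized block, the single element with the largest absolute value (the "outlier") is stored separately at full precision, the remaining values use the usual `Q4_1` min/delta scheme, and the matrix multiplication dequantizes to and accumulates in `FP32`. Below is a minimal NumPy sketch of that idea; it is not the actual ggml kernel, and the block size of 32, the exclusion of the outlier from the min/delta fit, and the storage layout are assumptions made for illustration only.

```python
import numpy as np

QK = 32  # assumed block size

def quantize_q4_1_o_block(block: np.ndarray):
    """Quantize one block, keeping its single absmax element (the "outlier") in full precision."""
    assert block.shape == (QK,)
    outlier_index = int(np.argmax(np.abs(block)))
    outlier_value = float(block[outlier_index])

    rest = block.copy()
    rest[outlier_index] = 0.0        # the outlier does not participate in the 4-bit fit (assumption)
    vmin = float(rest.min())
    delta = (float(rest.max()) - vmin) / 15.0
    if delta == 0.0:
        delta = 1.0                  # avoid division by zero for constant blocks

    quants = np.clip(np.round((rest - vmin) / delta), 0, 15).astype(np.uint8)
    return vmin, delta, outlier_index, outlier_value, quants

def dequantize_q4_1_o_block(vmin, delta, outlier_index, outlier_value, quants) -> np.ndarray:
    """Reconstruct FP32 values; the dot product then runs on these FP32 values."""
    out = vmin + delta * quants.astype(np.float32)
    out[outlier_index] = outlier_value   # restore the outlier exactly
    return out
```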

README.md

@@ -10,9 +10,10 @@ This project provides [a C library rwkv.h](rwkv.h) and [a convinient Python wrap
 **TODO (contributions welcome!)**:
-1. Measure latency and perplexity of different model sizes (169M to 14B) and data types (FP32, FP16, Q4_0, Q4_1)
-2. Test on Linux (including Colab) and MacOS
-3. Make required memory calculation more robust (see #4)
+1. Optimize AVX2 implementation of `Q4_1_O` matmul — currently, it is as slow as `FP32`
+2. Measure latency and perplexity of different model sizes (169M to 14B) and data types (`FP32`, `FP16`, `Q4_0`, `Q4_1`, `Q4_1_O`)
+3. Test on Linux (including Colab) and MacOS
+4. Make required memory calculation more robust (see [#4](https://github.com/saharNooby/rwkv.cpp/issues/4))
 ## How to use
@@ -68,7 +69,7 @@ If everything went OK, `librwkv.so` (Linux) or `rwkv.o` (MacOS) file should appe
 ```commandline
 # Windows
-python rwkv\convert_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float16
+python rwkv\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float16
 # Linux / MacOS
 python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin float16
@@ -80,13 +81,17 @@ To convert the model into INT4 quantized format, run:
 ```commandline
 # Windows
-python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q4_1.bin 3
+python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q4_1_O.bin 4
 # Linux / MacOS
-python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q4_1.bin 3
+python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q4_1_O.bin 4
 ```
-Pass `2` for `Q4_0` format (smaller size, lower quality), `3` for `Q4_1` format (larger size, higher quality).
+Formats available:
+- `4`: `Q4_1_O`, best quality, very slow (as `FP32`).
+- `3`: `Q4_1`, poor quality, very fast (as `FP16`).
+- `2`: `Q4_0`, worst quality, breaks larger models, moderately fast (between `FP16` and `FP32`).
 ### 4. Run the model
@@ -98,20 +103,20 @@ To generate some text, run:
 ```commandline
 # Windows
-python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q4_1.bin
+python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q4_1_O.bin
 # Linux / MacOS
-python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q4_1.bin
+python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q4_1_O.bin
 ```
 To chat with a bot, run:
 ```commandline
 # Windows
-python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q4_1.bin
+python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q4_1_O.bin
 # Linux / MacOS
-python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q4_1.bin
+python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q4_1_O.bin
 ```
 Edit [generate_completions.py](rwkv%2Fgenerate_completions.py) or [chat_with_bot.py](rwkv%2Fchat_with_bot.py) to change prompts and sampling settings.

ggml.c (710 changed lines)

File diff suppressed because it is too large.

ggml.h (8 changed lines)

@@ -186,7 +186,8 @@
 // - to `ggml_compute_forward` and call the forward dispatch function here.
 // - to `ggml_compute_backward` and add `GGML_ASSERT(false)` here.
 // - to `ggml_graph_compute` and add `node->n_tasks = 1` here.
-// 6. Fix all assertions that check value of `GGML_OP_COUNT`: you've added 1 operator, so increment asserted value by one.
+// 6. Add operator label to `GGML_OP_LABEL` array and operator symbol to `GGML_OP_SYMBOL` array.
+// 7. Fix all assertions that check value of `GGML_OP_COUNT`: you've added 1 operator, so increment asserted value by one.
 //
 // When in doubt, consult the code of existing operators similar to that you're implementing.
 // Resulting operator would work for the forward pass, but will lack backward implementation and multi-threading support.
@@ -225,7 +226,11 @@ struct ggml_context;
 enum ggml_type {
     GGML_TYPE_Q4_0,
+    // Stores min and delta per block, does quantized matmul.
     GGML_TYPE_Q4_1,
+    // Same as Q4_1, but stores outliers separately, and matmul is done in FP32.
+    // An outlier is the single absmax element in the quantized block.
+    GGML_TYPE_Q4_1_O,
     GGML_TYPE_I8,
     GGML_TYPE_I16,
     GGML_TYPE_I32,
@@ -806,6 +811,7 @@ enum ggml_opt_result ggml_opt(
 size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
 size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
+size_t ggml_quantize_q4_1_o(const float * src, void * dst, int n, int k, int64_t * hist);
 //
 // system info

rwkv.cpp

@@ -43,6 +43,14 @@ bool read_int32(FILE * file, int32_t * dest) {
     return true;
 }
+static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[5] = {
+    GGML_TYPE_F32,
+    GGML_TYPE_F16,
+    GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1,
+    GGML_TYPE_Q4_1_O
+};
 // --- Model definition and loading utilities ---
 struct rwkv_layer {
@@ -160,7 +168,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
         model->data_type == 0 ||
         model->data_type == 1 ||
         model->data_type == 2 ||
-        model->data_type == 3,
+        model->data_type == 3 ||
+        model->data_type == 4,
         "Unsupported model data type %d",
         model->data_type
     );
@@ -216,20 +225,13 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
             data_type == 0 ||
             data_type == 1 ||
             data_type == 2 ||
-            data_type == 3,
+            data_type == 3 ||
+            data_type == 4,
             "Unsupported parameter data type %d",
             data_type
         );
-        ggml_type ggml_data_type;
-        switch (data_type) {
-            case 0: ggml_data_type = GGML_TYPE_F32; break;
-            case 1: ggml_data_type = GGML_TYPE_F16; break;
-            case 2: ggml_data_type = GGML_TYPE_Q4_0; break;
-            case 3: ggml_data_type = GGML_TYPE_Q4_1; break;
-            default: return NULL;
-        }
+        ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type];
         struct ggml_tensor * tensor;
@@ -553,17 +555,9 @@ void rwkv_free(struct rwkv_context * ctx) {
 }
 bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
-    RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type);
+    RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type);
-    ggml_type type;
-    switch (q_type) {
-        case 2: type = GGML_TYPE_Q4_0; break;
-        case 3: type = GGML_TYPE_Q4_1; break;
-        default: return false;
-    };
-    RWKV_ASSERT_FALSE(type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1, "Unsupported data type %d", type);
+    ggml_type type = FORMAT_TYPE_TO_GGML_TYPE[q_type];
     printf("Loading model from '%s'\n", model_file_path_in);
@@ -643,22 +637,30 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
         {
             static const char * parameter_data_type_str[] = {
-                "f32",
-                "f16",
-                "q4_0",
-                "q4_1"
+                "F32",
+                "F16",
+                "Q4_0",
+                "Q4_1",
+                "Q4_1_O"
             };
             printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
+            total_size_orig += (size_t) (nelements * ggml_type_sizef(FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type]));
         }
-        // Quantize only 2D tensors
-        bool quantize = n_dims == 2;
+        // Quantize only 2D tensors, except embedding and head matrices.
+        // Embedding and head take not too much space, especially in bigger models;
+        // but they significantly increase perplexity when quantized.
+        bool quantize = n_dims == 2 &&
+                name != std::string("emb.weight") &&
+                name != std::string("head.weight");
         if (quantize) {
-            if (parameter_data_type != 0 && parameter_data_type != 1) {
-                fprintf(stderr, "unsupported data type %d for integer quantization\n", parameter_data_type);
-                return false;
-            }
+            RWKV_ASSERT_FALSE(
+                parameter_data_type == 0 || parameter_data_type == 1,
+                "Unsupported parameter data type %d, only FP32 and FP16 can be quantized",
+                parameter_data_type
+            );
             if (parameter_data_type == 1) {
                 data_f16.resize(nelements);
@@ -706,6 +708,10 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
                 {
                     cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
                 } break;
+            case GGML_TYPE_Q4_1_O:
+                {
+                    cur_size = ggml_quantize_q4_1_o(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                } break;
             default:
                 {
                     fprintf(stderr, "unsupported quantization type %d\n", type);
@@ -732,12 +738,11 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
             fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
             total_size_new += data_u8.size();
         }
-        total_size_orig += nelements * sizeof(float);
     }
-    printf("model size = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0);
-    printf("quant size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
+    printf("original size = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0);
+    printf("quantized size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
+    printf("compression ratio = %8.2f%\n", 1.0 * total_size_orig / total_size_new);
     {
         int64_t sum_all = 0;
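
One of the rwkv.cpp hunks above also narrows which parameters get quantized: only 2D matrices, with the embedding and head matrices now excluded, since quantizing them saves relatively little space but noticeably worsens perplexity. A rough Python restatement of that predicate follows (names copied from the diff; this is not code from the repository):

```python
def should_quantize(name: str, n_dims: int) -> bool:
    # Only 2D weight matrices are quantized; emb.weight and head.weight stay
    # in their original precision, since quantizing them hurts quality while
    # saving relatively little space.
    return n_dims == 2 and name not in ('emb.weight', 'head.weight')
```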

compare_with_reference_implementation.py (deleted)

@@ -1,102 +0,0 @@
# Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
# Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
# Usage: python compare_with_reference_implementation.py C:\rwkv.cpp-169M.bin

import os
import struct
import argparse
import torch
import numpy as np
import rwkv_cpp_model
import rwkv_cpp_shared_library
from typing import List, Tuple, Any

def parse_args():
    parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
    parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
    return parser.parse_args()

def main() -> None:
    args = parse_args()

    # Don't want to depend on tokenizer here.
    tokens: List[int] = [4, 3631, 4420, 2412, 953, 432, 391, 30567, 87, 15, 14161, 7092, 273, 416, 27767, 55, 342,
                         2412, 953, 432, 3806, 7092, 273, 416, 27767, 55, 15, 187, 4, 19039, 2412, 953, 497, 4561,
                         342, 416, 27767, 55, 14, 21, 14, 49, 587, 14, 17809, 46, 14, 938, 14256, 28950, 14, 1438,
                         1508, 15, 81, 394, 1566, 275, 8462, 22097, 348, 15, 187, 4, 43825, 27, 15548, 7277, 64,
                         3113, 64, 14005, 64, 39595, 15, 4789, 10269, 61, 18992, 61, 7265, 64, 30217, 39297, 15,
                         20963, 330, 27, 190, 30567, 87, 15, 14161, 14, 17809, 46, 15, 4805]

    threshold: float

    with open(args.ggml_model_path, 'rb') as model_file:
        header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4))
        data_type: int = header[5]

        assert data_type == 0 or\
            data_type == 1 or\
            data_type == 2 or\
            data_type == 3, f'Unsupported model data type {data_type}'

        if data_type == 0:
            # FP32, high precision
            threshold = 0.000005
        elif data_type == 1:
            # FP16, lower precision, so higher threshold
            threshold = 0.0032
        elif data_type == 2:
            # INT4 quantized, even lower precision, so even higher threshold
            # This threshold will let some bugs pass
            threshold = 4.0
        elif data_type == 3:
            # This format stores more data, so error would be lower
            threshold = 1.2

    model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path)

    def compare_logits(tokens_subset: List[int]) -> None:
        token_count: int = len(tokens_subset)
        logits, state = None, None

        for i in range(token_count):
            token: int = tokens_subset[i]

            if token_count <= 10 or i % (token_count // 10) == 0:
                print(f'{i + 1}/{token_count}')

            logits, state = model.eval(token, state, state, logits)

        actual_logits = logits

        # ---

        expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin'

        if not os.path.isfile(expected_logits_path):
            expected_logits_path = 'rwkv/' + expected_logits_path

        with open(expected_logits_path, 'rb') as logits_file:
            expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))

        # ---

        difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()

        print(f'Reference logits: {expected_logits}')
        print(f'Actual logits: {actual_logits}')
        print('Difference per token: %.8f' % (difference,))

        assert abs(difference) <= threshold, 'Difference is too big'

    compare_logits(tokens)

    print()
    print('Test passes')

    if model is not None:
        model.free()

if __name__ == "__main__":
    main()

rwkv/measure_pexplexity.py (new file, 100 lines)

@@ -0,0 +1,100 @@
# Measures perplexity and per-token latency of an RWKV model on a given text file.
# Perplexity is defined here as exp() of average cross-entropy loss.
# Usage: python measure_pexplexity.py C:\rwkv.cpp-169M.bin C:\text.txt 1024

import os
import time
import pathlib
import argparse
import tokenizers
import torch
import rwkv_cpp_model
import rwkv_cpp_shared_library
from typing import List

def parse_args():
    parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
    parser.add_argument('model_path', help='Path to model checkpoint file')
    parser.add_argument('text_path', help='Path to text file in UTF-8 encoding')
    parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int, default=1024)
    return parser.parse_args()

args = parse_args()

# ---

print('Loading 20B tokenizer')
tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))

print('Loading text')
text: str = open(args.text_path, encoding='utf-8').read()

tokens: List[int] = tokenizer.encode(text).ids

token_count: int = len(tokens)
print(f'{token_count} tokens in the text')

assert token_count - args.ignore_first_n_tokens > 1, 'Need at least 2 tokens for evaluation'

# ---

def format_loss(loss: torch.Tensor) -> str:
    return str(['%.3f' % (loss[i].item(),) for i in range(len(loss))]).replace('\'', '')[1:-1]

def format_loss_with_perplexity(loss: torch.Tensor) -> str:
    return f'loss [{format_loss(loss)}], perplexity {"%.3f" % (torch.exp(loss[0]).item(),)}'

# ---

model: rwkv_cpp_model.RWKVModel = rwkv_cpp_model.RWKVModel(
    rwkv_cpp_shared_library.load_rwkv_shared_library(),
    args.model_path
)

logits, state = None, None

loss_sum: torch.Tensor = torch.tensor([0.0])
loss_count: int = 0

start: float = time.time()

run_count: int = token_count - 1

for i in range(run_count):
    token: int = tokens[i]
    target: int = tokens[i + 1]

    logits, state = model.eval(token, state, state, logits)

    if args.ignore_first_n_tokens == 0 or i + 1 >= args.ignore_first_n_tokens:
        losses = torch.tensor([
            torch.nn.functional.cross_entropy(logits, torch.tensor(target, dtype=torch.long), reduction='none').item()
        ])

        loss_sum += losses
        loss_count += 1

    if i % 10 == 0:
        avg_loss_so_far = loss_sum / loss_count

        duration: float = time.time() - start
        duration_per_token: float = duration / (i + 1)
        runs_remaining: int = run_count - i - 1
        duration_remaining: int = int(runs_remaining * duration_per_token)

        print(f'Token #{i}/{token_count}, '
              f'{int(100.0 * i / token_count)}%, '
              f'ETA {duration_remaining // 60} m {duration_remaining % 60} s', end='')

        if loss_count > 0:
            print(f', averages so far: {format_loss_with_perplexity(avg_loss_so_far)}')
        else:
            print()

print()
print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token')

print()
print(f'Model: {os.path.basename(args.model_path)}, '
      f'data: {os.path.basename(args.text_path)} with {token_count} tokens, '
      f'skipped {args.ignore_first_n_tokens} tokens, '
      f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}')
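
As the header comment of the new script says, perplexity here is exp() of the average cross-entropy loss, and the script accumulates that loss token by token. A tiny self-contained illustration of the same definition, using hypothetical logits unrelated to the script's data:

```python
import torch

# Hypothetical logits for 3 prediction steps over a 5-token vocabulary, plus targets.
logits = torch.randn(3, 5)
targets = torch.tensor([1, 4, 2])

avg_loss = torch.nn.functional.cross_entropy(logits, targets)  # average cross-entropy
perplexity = torch.exp(avg_loss)                               # perplexity = exp(mean loss)
print(f'loss {avg_loss.item():.3f}, perplexity {perplexity.item():.3f}')
```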

rwkv/quantize.py

@@ -1,5 +1,5 @@
-# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1.
-# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3
+# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1 or Q4_1_O (recommended).
+# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1_o.bin 4
 import argparse
 import rwkv_cpp_shared_library
@@ -8,12 +8,20 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
     parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
     parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
-    parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3)
+    parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0), 3 (GGML_TYPE_Q4_1) or 4 (GGML_TYPE_Q4_1_O)', type=int, choices=[2, 3, 4], default=4)
     return parser.parse_args()
 def main() -> None:
     args = parse_args()
+    if args.data_type == 2 or args.data_type == 3:
+        print()
+        print('WARNING!')
+        print('You are using Q4_0 or Q4_1 quantization; it will heavily degrade RWKV quality.')
+        print('For best quality preservation, it is recommended to use Q4_1_O.')
+        print('More info at https://github.com/saharNooby/rwkv.cpp/issues/12')
+        print()
     library = rwkv_cpp_shared_library.load_rwkv_shared_library()
     library.rwkv_quantize_model_file(

rwkv/rwkv_cpp_model.py

@@ -32,14 +32,14 @@ class RWKVModel:
         assert os.path.isfile(model_path), f'{model_path} is not a file'
         assert thread_count > 0, 'Thread count must be positive'
-        self.library = shared_library
+        self._library = shared_library
-        self.ctx = self.library.rwkv_init_from_file(model_path, thread_count)
+        self._ctx = self._library.rwkv_init_from_file(model_path, thread_count)
-        self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
-        self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)
+        self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx)
+        self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
-        self.valid = True
+        self._valid = True
     def eval(
         self,
@@ -69,7 +69,7 @@ class RWKVModel:
             Logits vector of shape (n_vocab); state for the next step.
         """
-        assert self.valid, 'Model was freed'
+        assert self._valid, 'Model was freed'
         def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None:
             assert buf.dtype == torch.float32, f'{name} is not of type float32'
@@ -77,24 +77,24 @@ class RWKVModel:
             assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})'
         if state_in is not None:
-            validate_buffer(state_in, 'state_in', self.state_buffer_element_count)
+            validate_buffer(state_in, 'state_in', self._state_buffer_element_count)
             state_in_ptr = state_in.storage().data_ptr()
         else:
            state_in_ptr = 0
         if state_out is not None:
-            validate_buffer(state_out, 'state_out', self.state_buffer_element_count)
+            validate_buffer(state_out, 'state_out', self._state_buffer_element_count)
         else:
-            state_out = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
+            state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu')
         if logits_out is not None:
-            validate_buffer(logits_out, 'logits_out', self.logits_buffer_element_count)
+            validate_buffer(logits_out, 'logits_out', self._logits_buffer_element_count)
         else:
-            logits_out = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')
+            logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu')
-        self.library.rwkv_eval(
-            self.ctx,
+        self._library.rwkv_eval(
+            self._ctx,
             token,
             state_in_ptr,
             state_out.storage().data_ptr(),
@@ -110,8 +110,13 @@ class RWKVModel:
         The object must not be used anymore after calling this method.
         """
-        assert self.valid, 'Already freed'
+        assert self._valid, 'Already freed'
-        self.valid = False
+        self._valid = False
-        self.library.rwkv_free(self.ctx)
+        self._library.rwkv_free(self._ctx)
+    def __del__(self):
+        # Free the context on GC in case user forgot to call free() explicitly.
+        if hasattr(self, '_valid') and self._valid:
+            self.free()
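
With this change, the context is also released when the model object is garbage-collected; calling `free()` explicitly remains the deterministic path. A hedged usage sketch (the model path below is a placeholder, not a file from the repository):

```python
import rwkv_cpp_model
import rwkv_cpp_shared_library

model = rwkv_cpp_model.RWKVModel(
    rwkv_cpp_shared_library.load_rwkv_shared_library(),
    'C:\\rwkv.cpp-169M-Q4_1_O.bin'  # placeholder path to a converted/quantized model
)

logits, state = model.eval(187, None, None, None)  # evaluate token 187 with no previous state
model.free()  # deterministic cleanup; __del__ is only a safety net if this call is forgotten
```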

rwkv/rwkv_cpp_shared_library.py

@@ -192,13 +192,17 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
     else:
         file_name = 'librwkv.so'
+    repo_root_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent.parent
     paths = [
         # If we are in "rwkv" directory
         f'../bin/Release/{file_name}',
         # If we are in repo root directory
         f'bin/Release/{file_name}',
+        # Search relative to this file
+        str(repo_root_dir / 'bin' / 'Release' / file_name),
         # Fallback
-        pathlib.Path(os.path.abspath(__file__)).parent.parent / file_name
+        str(repo_root_dir / file_name)
     ]
     for path in paths: