Support FP16 inference
This commit is contained in:
parent
fe98c94a63
commit
f6d45baec0
14
README.md
14
README.md
|
@ -2,17 +2,15 @@
|
|||
|
||||
This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml). The end goal is to allow 4-bit quanized inference on CPU.
|
||||
|
||||
**WORK IN PROGRESS!** **Status**: FP32 inference works. For 64 tokens, logits from `rwkv.cpp` almost exactly match those from [reference implementation](https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py) (difference <= 0.00005 per token).
|
||||
**WORK IN PROGRESS!** **Status**: FP32 and FP16 inference work correctly. Currently, I'm working on creating usable C library interface and Python wrapper for it.
|
||||
|
||||
## Plan
|
||||
|
||||
1. Heavily refactor code; optimize where possible
|
||||
2. Make FP16 inference work
|
||||
3. Create proper interface (probably, C library)
|
||||
4. Create Python wrapper with sampling and simple chat interface
|
||||
5. Write a good `README.md` and publish links to this repo
|
||||
6. Make INT4 inference work
|
||||
7. Create pull request to main `ggml` repo with all improvements made here
|
||||
1. Create proper interface (probably, C library)
|
||||
2. Create Python wrapper with sampling and simple chat interface
|
||||
3. Write a good `README.md` and publish links to this repo
|
||||
4. Make INT4 inference work
|
||||
5. Create pull request to main `ggml` repo with all improvements made here
|
||||
|
||||
## Structure
|
||||
|
||||
|
|
|
@ -210,8 +210,6 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
|
|||
read_int32(file, &(model->data_type));
|
||||
RWKV_ASSERT(model->data_type == 0 || model->data_type == 1, "Unsupported model data type %d", model->data_type);
|
||||
|
||||
RWKV_ASSERT(model->data_type == 0, "Data types other than float32 are not yet supported"); // TODO
|
||||
|
||||
RWKV_LOG("n_vocab = %d", model->n_vocab);
|
||||
RWKV_LOG("n_embed = %d", model->n_embed);
|
||||
RWKV_LOG("n_layer = %d", model->n_layer);
|
||||
|
@ -236,7 +234,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
|
|||
read_int32(file, &data_type);
|
||||
RWKV_ASSERT(data_type == 0 || data_type == 1, "Unsupported parameter data type %d", data_type);
|
||||
|
||||
RWKV_ASSERT(data_type == 0, "Data types other than float32 are not yet supported"); // TODO
|
||||
ggml_type ggml_data_type = data_type == 0 ? GGML_TYPE_F32 : GGML_TYPE_F16;
|
||||
|
||||
struct ggml_tensor * tensor;
|
||||
|
||||
|
@ -248,7 +246,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
|
|||
if (dim_count == 1) {
|
||||
read_int32(file, &x);
|
||||
element_count = x;
|
||||
tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, x);
|
||||
tensor = ggml_new_tensor_1d(ctx, ggml_data_type, x);
|
||||
} else if (dim_count == 2) {
|
||||
read_int32(file, &x);
|
||||
read_int32(file, &y);
|
||||
|
@ -257,7 +255,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
|
|||
// * PyTorch shape is (x rows, y columns)
|
||||
// * ggml shape is (y elements in a row, x elements in a column)
|
||||
// Both shapes represent the same tensor.
|
||||
tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y, x);
|
||||
tensor = ggml_new_tensor_2d(ctx, ggml_data_type, y, x);
|
||||
} else {
|
||||
abort();
|
||||
}
|
||||
|
@ -265,11 +263,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
|
|||
std::string key(key_length, 0);
|
||||
RWKV_ASSERT(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key");
|
||||
|
||||
size_t element_size = data_type == 0 ?
|
||||
ggml_type_size(GGML_TYPE_F32) :
|
||||
ggml_type_size(GGML_TYPE_F16);
|
||||
size_t byte_count = element_count * element_size;
|
||||
|
||||
size_t byte_count = element_count * ggml_type_size(ggml_data_type);
|
||||
RWKV_ASSERT(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data");
|
||||
|
||||
parameters[key] = tensor;
|
||||
|
@ -322,8 +316,8 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
|
|||
// Verify order of dimensions
|
||||
struct ggml_tensor * emb = model->emb;
|
||||
RWKV_ASSERT(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims);
|
||||
RWKV_ASSERT(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
|
||||
RWKV_ASSERT(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
|
||||
RWKV_ASSERT(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
|
||||
RWKV_ASSERT(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
|
||||
}
|
||||
|
||||
// --- Operators ---
|
||||
|
|
|
@ -4,11 +4,12 @@
|
|||
# Usage: python compare_cpp_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
|
||||
|
||||
import os
|
||||
import struct
|
||||
import argparse
|
||||
import subprocess
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import List
|
||||
from typing import List, Tuple, Any
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
|
||||
|
@ -32,6 +33,21 @@ def main() -> None:
|
|||
627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272,
|
||||
253, 1708, 432, 253, 13862, 15]
|
||||
|
||||
threshold: float
|
||||
|
||||
with open(args.ggml_model_path, 'rb') as model_file:
|
||||
header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4))
|
||||
data_type: int = header[5]
|
||||
|
||||
assert data_type == 0 or data_type == 1, f'Unsupported model data type {data_type}'
|
||||
|
||||
if data_type == 0:
|
||||
# FP32, high precision
|
||||
threshold = 0.000005
|
||||
elif data_type == 1:
|
||||
# FP16, lower precision, so higher threshold
|
||||
threshold = 0.003
|
||||
|
||||
def compare_logits(tokens_subset: List[int]) -> None:
|
||||
token_count: int = len(tokens_subset)
|
||||
state_path: str = './state.bin'
|
||||
|
@ -72,7 +88,7 @@ def main() -> None:
|
|||
print(f'Actual logits: {actual_logits}')
|
||||
print('Difference per token: %.8f' % (difference,))
|
||||
|
||||
assert abs(difference) <= 0.000005, 'Difference is too big'
|
||||
assert abs(difference) <= threshold, 'Difference is too big'
|
||||
|
||||
# Check small token amount first to avoid waiting too long before seeing that model is broken
|
||||
compare_logits(tokens[:4])
|
||||
|
|
Loading…
Reference in New Issue