Support FP16 inference
commit f6d45baec0
parent fe98c94a63

README.md: 14 lines changed
In `README.md`:

```diff
@@ -2,17 +2,15 @@
 This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml). The end goal is to allow 4-bit quantized inference on CPU.
 
-**WORK IN PROGRESS!** **Status**: FP32 inference works. For 64 tokens, logits from `rwkv.cpp` almost exactly match those from [reference implementation](https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py) (difference <= 0.00005 per token).
+**WORK IN PROGRESS!** **Status**: FP32 and FP16 inference work correctly. Currently, I'm working on creating a usable C library interface and a Python wrapper for it.
 
 ## Plan
 
-1. Heavily refactor code; optimize where possible
-2. Make FP16 inference work
-3. Create proper interface (probably, C library)
-4. Create Python wrapper with sampling and simple chat interface
-5. Write a good `README.md` and publish links to this repo
-6. Make INT4 inference work
-7. Create pull request to main `ggml` repo with all improvements made here
+1. Create proper interface (probably, C library)
+2. Create Python wrapper with sampling and simple chat interface
+3. Write a good `README.md` and publish links to this repo
+4. Make INT4 inference work
+5. Create pull request to main `ggml` repo with all improvements made here
 
 ## Structure
```
In the C++ model loader, `load_rwkv_model` now accepts data type 1 (FP16) in the model header instead of rejecting everything but FP32:

```diff
@@ -210,8 +210,6 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
     read_int32(file, &(model->data_type));
     RWKV_ASSERT(model->data_type == 0 || model->data_type == 1, "Unsupported model data type %d", model->data_type);
 
-    RWKV_ASSERT(model->data_type == 0, "Data types other than float32 are not yet supported"); // TODO
-
     RWKV_LOG("n_vocab = %d", model->n_vocab);
     RWKV_LOG("n_embed = %d", model->n_embed);
     RWKV_LOG("n_layer = %d", model->n_layer);
```
Each parameter's data type is mapped to the corresponding ggml type once, replacing the FP32-only assertion:

```diff
@@ -236,7 +234,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
     read_int32(file, &data_type);
     RWKV_ASSERT(data_type == 0 || data_type == 1, "Unsupported parameter data type %d", data_type);
 
-    RWKV_ASSERT(data_type == 0, "Data types other than float32 are not yet supported"); // TODO
+    ggml_type ggml_data_type = data_type == 0 ? GGML_TYPE_F32 : GGML_TYPE_F16;
 
     struct ggml_tensor * tensor;
```
Tensors are then created with that type instead of hard-coded `GGML_TYPE_F32`:

```diff
@@ -248,7 +246,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
     if (dim_count == 1) {
         read_int32(file, &x);
         element_count = x;
-        tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, x);
+        tensor = ggml_new_tensor_1d(ctx, ggml_data_type, x);
     } else if (dim_count == 2) {
         read_int32(file, &x);
         read_int32(file, &y);
```
```diff
@@ -257,7 +255,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
         // * PyTorch shape is (x rows, y columns)
         // * ggml shape is (y elements in a row, x elements in a column)
         // Both shapes represent the same tensor.
-        tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y, x);
+        tensor = ggml_new_tensor_2d(ctx, ggml_data_type, y, x);
     } else {
         abort();
     }
```
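The comment above describes ggml's dimension-order convention, which is easy to trip over: `ne[0]` is the fastest-varying dimension (elements in a row), the reverse of PyTorch's shape tuple, while the underlying row-major bytes are identical. A minimal numpy sketch of the same idea (numpy stands in for PyTorch here; the names are illustrative, not from the repo):

```python
import numpy as np

# A "PyTorch-style" matrix: 3 rows, 4 columns, row-major in memory.
matrix = np.arange(12, dtype=np.float32).reshape(3, 4)

# PyTorch reports (rows, columns)...
pytorch_shape = matrix.shape  # (3, 4)

# ...while ggml describes the same buffer with ne[0] = elements in
# a row, ne[1] = elements in a column, i.e. the reversed tuple.
ggml_ne = pytorch_shape[::-1]  # (4, 3)

# Either way, the flat memory layout is unchanged.
assert matrix.reshape(-1)[4] == matrix[1, 0]
print(pytorch_shape, ggml_ne)
```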
With `ggml_data_type` in scope, the per-parameter byte count also collapses into one expression:

```diff
@@ -265,11 +263,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
     std::string key(key_length, 0);
     RWKV_ASSERT(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key");
 
-    size_t element_size = data_type == 0 ?
-        ggml_type_size(GGML_TYPE_F32) :
-        ggml_type_size(GGML_TYPE_F16);
-    size_t byte_count = element_count * element_size;
+    size_t byte_count = element_count * ggml_type_size(ggml_data_type);
 
     RWKV_ASSERT(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data");
 
     parameters[key] = tensor;
```
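The byte count is just element count times element size: 4 bytes for `GGML_TYPE_F32`, 2 for `GGML_TYPE_F16`. A back-of-the-envelope sketch of what that means for file size (the 169M parameter count is taken from the model name in the comparison script's usage line and is approximate):

```python
# Bytes per element for the two data types this commit supports.
ELEMENT_SIZE = {0: 4, 1: 2}  # 0 = FP32, 1 = FP16

def tensor_byte_count(element_count: int, data_type: int) -> int:
    # Mirrors: byte_count = element_count * ggml_type_size(ggml_data_type)
    return element_count * ELEMENT_SIZE[data_type]

# FP16 roughly halves the weight payload of the model file.
params = 169_000_000  # approximate parameter count of RWKV 169M
print(f'FP32: ~{tensor_byte_count(params, 0) / 1024**2:.0f} MiB')
print(f'FP16: ~{tensor_byte_count(params, 1) / 1024**2:.0f} MiB')
```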
The embedding-shape assertions also get their `%d` arguments fixed so that each message prints the dimension it actually checks:

```diff
@@ -322,8 +316,8 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
     // Verify order of dimensions
     struct ggml_tensor * emb = model->emb;
     RWKV_ASSERT(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims);
-    RWKV_ASSERT(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
-    RWKV_ASSERT(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
+    RWKV_ASSERT(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
+    RWKV_ASSERT(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
 }
 
 // --- Operators ---
```
In `compare_cpp_with_reference_implementation.py`, the script gains the imports it needs to read the model header:

```diff
@@ -4,11 +4,12 @@
 # Usage: python compare_cpp_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
 
 import os
+import struct
 import argparse
 import subprocess
 import torch
 import numpy as np
-from typing import List
+from typing import List, Tuple, Any
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
```
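For reference, `struct.unpack('=iiiiii', ...)` reads six native-endian int32 fields; this excerpt only interprets the last one as the data type, and the meaning of the other five fields is not shown here. A self-contained round trip demonstrating the mechanics (the zeros are placeholder values, not the real header layout):

```python
import struct

# Pack six int32 values the way the script reads them back.
# Only index 5 (the data type) is interpreted in this excerpt;
# the zeros are placeholders, not the actual header fields.
fake_header: bytes = struct.pack('=iiiiii', 0, 0, 0, 0, 0, 1)
assert len(fake_header) == 6 * 4

fields = struct.unpack('=iiiiii', fake_header)
data_type = fields[5]
assert data_type in (0, 1), f'Unsupported model data type {data_type}'
print('FP16' if data_type == 1 else 'FP32')
```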
The comparison threshold is now chosen from the data type stored in the model header:

```diff
@@ -32,6 +33,21 @@ def main() -> None:
         627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272,
         253, 1708, 432, 253, 13862, 15]
 
+    threshold: float
+
+    with open(args.ggml_model_path, 'rb') as model_file:
+        header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4))
+        data_type: int = header[5]
+
+        assert data_type == 0 or data_type == 1, f'Unsupported model data type {data_type}'
+
+        if data_type == 0:
+            # FP32, high precision
+            threshold = 0.000005
+        elif data_type == 1:
+            # FP16, lower precision, so higher threshold
+            threshold = 0.003
+
     def compare_logits(tokens_subset: List[int]) -> None:
         token_count: int = len(tokens_subset)
         state_path: str = './state.bin'
```
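The looser FP16 threshold matches the format's precision: with a 10-bit mantissa, FP16's machine epsilon is about 1e-3, versus roughly 1.2e-7 for FP32. A quick check with numpy (the thresholds 0.000005 and 0.003 themselves are this commit's empirical choices, not values derived below):

```python
import numpy as np

# Machine epsilon: the smallest relative rounding step of each format.
print(np.finfo(np.float32).eps)  # ~1.1920929e-07
print(np.finfo(np.float16).eps)  # ~0.000977

# The script's per-token thresholds sit well above these, leaving
# headroom for rounding error to accumulate across layers.
```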
And the assertion uses that threshold instead of the hard-coded FP32 value:

```diff
@@ -72,7 +88,7 @@ def main() -> None:
         print(f'Actual logits: {actual_logits}')
         print('Difference per token: %.8f' % (difference,))
 
-        assert abs(difference) <= 0.000005, 'Difference is too big'
+        assert abs(difference) <= threshold, 'Difference is too big'
 
     # Check small token amount first to avoid waiting too long before seeing that model is broken
     compare_logits(tokens[:4])
```
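The formula behind "Difference per token" is not visible in this excerpt; a plausible minimal sketch, assuming it is the mean absolute difference between expected and actual logits (an assumption, and `difference_per_token` is a hypothetical helper, not the script's code):

```python
import numpy as np

def difference_per_token(expected: np.ndarray, actual: np.ndarray) -> float:
    # Hypothetical metric: mean absolute logit difference.
    # The real script's formula is not shown in this excerpt.
    return float(np.mean(np.abs(expected - actual)))

# Arbitrary illustrative shapes: 4 tokens, 1024 logits each.
rng = np.random.default_rng(0)
expected = rng.random((4, 1024)).astype(np.float32)
actual = expected + np.float32(1e-6)

assert abs(difference_per_token(expected, actual)) <= 0.000005
```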