Support FP16 inference

saharNooby 2023-04-01 11:53:49 +04:00
parent fe98c94a63
commit f6d45baec0
3 changed files with 30 additions and 22 deletions


@@ -2,17 +2,15 @@
This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml). The end goal is to allow 4-bit quantized inference on CPU.
-**WORK IN PROGRESS!** **Status**: FP32 inference works. For 64 tokens, logits from `rwkv.cpp` almost exactly match those from [reference implementation](https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py) (difference <= 0.00005 per token).
+**WORK IN PROGRESS!** **Status**: FP32 and FP16 inference work correctly. Currently, I'm working on creating a usable C library interface and a Python wrapper for it.
## Plan
-1. Heavily refactor code; optimize where possible
-2. Make FP16 inference work
-3. Create proper interface (probably, C library)
-4. Create Python wrapper with sampling and simple chat interface
-5. Write a good `README.md` and publish links to this repo
-6. Make INT4 inference work
-7. Create pull request to main `ggml` repo with all improvements made here
+1. Create proper interface (probably, C library)
+2. Create Python wrapper with sampling and simple chat interface
+3. Write a good `README.md` and publish links to this repo
+4. Make INT4 inference work
+5. Create pull request to main `ggml` repo with all improvements made here
## Structure


@@ -210,8 +210,6 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
read_int32(file, &(model->data_type));
RWKV_ASSERT(model->data_type == 0 || model->data_type == 1, "Unsupported model data type %d", model->data_type);
-RWKV_ASSERT(model->data_type == 0, "Data types other than float32 are not yet supported"); // TODO
RWKV_LOG("n_vocab = %d", model->n_vocab);
RWKV_LOG("n_embed = %d", model->n_embed);
RWKV_LOG("n_layer = %d", model->n_layer);
@@ -236,7 +234,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
read_int32(file, &data_type);
RWKV_ASSERT(data_type == 0 || data_type == 1, "Unsupported parameter data type %d", data_type);
-RWKV_ASSERT(data_type == 0, "Data types other than float32 are not yet supported"); // TODO
+ggml_type ggml_data_type = data_type == 0 ? GGML_TYPE_F32 : GGML_TYPE_F16;
struct ggml_tensor * tensor;
@@ -248,7 +246,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
if (dim_count == 1) {
read_int32(file, &x);
element_count = x;
-tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, x);
+tensor = ggml_new_tensor_1d(ctx, ggml_data_type, x);
} else if (dim_count == 2) {
read_int32(file, &x);
read_int32(file, &y);
@@ -257,7 +255,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
// * PyTorch shape is (x rows, y columns)
// * ggml shape is (y elements in a row, x elements in a column)
// Both shapes represent the same tensor.
-tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y, x);
+tensor = ggml_new_tensor_2d(ctx, ggml_data_type, y, x);
} else {
abort();
}
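
The comment in this hunk is the crux of the shape handling. A minimal numpy sketch of the same row-major convention (an illustration, not part of the commit):

```python
import numpy as np

# A PyTorch-style matrix with shape (x, y) = (2 rows, 3 columns).
w = np.arange(6, dtype=np.float32).reshape(2, 3)

# Row-major storage: the 3 elements of a row sit next to each other in
# memory, so the fastest-varying dimension has y = 3 elements. That is
# what ggml calls ne[0], hence ggml_new_tensor_2d(ctx, type, y, x)
# describes the same buffer as the PyTorch (x, y) tensor.
assert (w.ravel()[:3] == w[0]).all()
```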
@@ -265,11 +263,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
std::string key(key_length, 0);
RWKV_ASSERT(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key");
-size_t element_size = data_type == 0 ?
-    ggml_type_size(GGML_TYPE_F32) :
-    ggml_type_size(GGML_TYPE_F16);
-size_t byte_count = element_count * element_size;
+size_t byte_count = element_count * ggml_type_size(ggml_data_type);
RWKV_ASSERT(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data");
parameters[key] = tensor;
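
The byte count computed above just pairs the element count with the per-element size of the chosen type (`ggml_type_size(GGML_TYPE_F32)` is 4 bytes, `GGML_TYPE_F16` is 2). A small Python sketch of the same arithmetic, with hypothetical dimensions:

```python
import numpy as np

# Hypothetical n_embed * n_vocab; mirrors the loader's
# byte_count = element_count * element_size.
element_count = 768 * 50277
for dtype in (np.float32, np.float16):
    byte_count = element_count * np.dtype(dtype).itemsize
    print(dtype.__name__, byte_count)  # FP16 halves the bytes to read
```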
@@ -322,8 +316,8 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
// Verify order of dimensions
struct ggml_tensor * emb = model->emb;
RWKV_ASSERT(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims);
-RWKV_ASSERT(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
-RWKV_ASSERT(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
+RWKV_ASSERT(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
+RWKV_ASSERT(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
}
// --- Operators ---


@@ -4,11 +4,12 @@
# Usage: python compare_cpp_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
import os
+import struct
import argparse
import subprocess
import torch
import numpy as np
-from typing import List
+from typing import List, Tuple, Any
def parse_args():
parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
@@ -32,6 +33,21 @@ def main() -> None:
627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272,
253, 1708, 432, 253, 13862, 15]
+    threshold: float
+
+    with open(args.ggml_model_path, 'rb') as model_file:
+        header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4))
+        data_type: int = header[5]
+
+        assert data_type == 0 or data_type == 1, f'Unsupported model data type {data_type}'
+
+        if data_type == 0:
+            # FP32, high precision
+            threshold = 0.000005
+        elif data_type == 1:
+            # FP16, lower precision, so higher threshold
+            threshold = 0.003
    def compare_logits(tokens_subset: List[int]) -> None:
        token_count: int = len(tokens_subset)
        state_path: str = './state.bin'
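
For context, the `'=iiiiii'` format string above unpacks six native-endian int32 fields. A self-contained sketch with hypothetical header values; the names of the fields before `data_type` are assumptions inferred from what the C loader reads and logs:

```python
import struct

# Hypothetical bytes standing in for the first 24 bytes of a model file.
# Only the position of data_type (index 5) is confirmed by the script above.
header_bytes = struct.pack('=iiiiii', 0, 100, 50277, 768, 12, 1)
magic, version, n_vocab, n_embed, n_layer, data_type = struct.unpack('=iiiiii', header_bytes)
assert data_type in (0, 1)  # 0 = FP32, 1 = FP16
```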
@@ -72,7 +88,7 @@ def main() -> None:
        print(f'Actual logits: {actual_logits}')
        print('Difference per token: %.8f' % (difference,))
-        assert abs(difference) <= 0.000005, 'Difference is too big'
+        assert abs(difference) <= threshold, 'Difference is too big'
    # Check small token amount first to avoid waiting too long before seeing that model is broken
    compare_logits(tokens[:4])
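
A rough illustration (not part of the commit) of why the FP16 threshold is three orders of magnitude looser than the FP32 one:

```python
import numpy as np

# FP16 keeps ~3 significant decimal digits (10-bit mantissa), so logits
# computed from FP16 weights drift far more than FP32 ones --
# hence 0.003 instead of 0.000005 above.
x = np.float32(1.2345678)
print(abs(np.float32(np.float16(x)) - x))  # on the order of 1e-4
```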