Support FP16 inference

This commit is contained in:
saharNooby 2023-04-01 11:53:49 +04:00
parent fe98c94a63
commit f6d45baec0
3 changed files with 30 additions and 22 deletions

View File

@ -2,17 +2,15 @@
This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml). The end goal is to allow 4-bit quantized inference on CPU. This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml). The end goal is to allow 4-bit quantized inference on CPU.
**WORK IN PROGRESS!** **Status**: FP32 inference works. For 64 tokens, logits from `rwkv.cpp` almost exactly match those from [reference implementation](https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py) (difference <= 0.00005 per token). **WORK IN PROGRESS!** **Status**: FP32 and FP16 inference work correctly. Currently, I'm working on creating a usable C library interface and a Python wrapper for it.
## Plan ## Plan
1. Heavily refactor code; optimize where possible 1. Create proper interface (probably, C library)
2. Make FP16 inference work 2. Create Python wrapper with sampling and simple chat interface
3. Create proper interface (probably, C library) 3. Write a good `README.md` and publish links to this repo
4. Create Python wrapper with sampling and simple chat interface 4. Make INT4 inference work
5. Write a good `README.md` and publish links to this repo 5. Create pull request to main `ggml` repo with all improvements made here
6. Make INT4 inference work
7. Create pull request to main `ggml` repo with all improvements made here
## Structure ## Structure

View File

@ -210,8 +210,6 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
read_int32(file, &(model->data_type)); read_int32(file, &(model->data_type));
RWKV_ASSERT(model->data_type == 0 || model->data_type == 1, "Unsupported model data type %d", model->data_type); RWKV_ASSERT(model->data_type == 0 || model->data_type == 1, "Unsupported model data type %d", model->data_type);
RWKV_ASSERT(model->data_type == 0, "Data types other than float32 are not yet supported"); // TODO
RWKV_LOG("n_vocab = %d", model->n_vocab); RWKV_LOG("n_vocab = %d", model->n_vocab);
RWKV_LOG("n_embed = %d", model->n_embed); RWKV_LOG("n_embed = %d", model->n_embed);
RWKV_LOG("n_layer = %d", model->n_layer); RWKV_LOG("n_layer = %d", model->n_layer);
@ -236,7 +234,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
read_int32(file, &data_type); read_int32(file, &data_type);
RWKV_ASSERT(data_type == 0 || data_type == 1, "Unsupported parameter data type %d", data_type); RWKV_ASSERT(data_type == 0 || data_type == 1, "Unsupported parameter data type %d", data_type);
RWKV_ASSERT(data_type == 0, "Data types other than float32 are not yet supported"); // TODO ggml_type ggml_data_type = data_type == 0 ? GGML_TYPE_F32 : GGML_TYPE_F16;
struct ggml_tensor * tensor; struct ggml_tensor * tensor;
@ -248,7 +246,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
if (dim_count == 1) { if (dim_count == 1) {
read_int32(file, &x); read_int32(file, &x);
element_count = x; element_count = x;
tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, x); tensor = ggml_new_tensor_1d(ctx, ggml_data_type, x);
} else if (dim_count == 2) { } else if (dim_count == 2) {
read_int32(file, &x); read_int32(file, &x);
read_int32(file, &y); read_int32(file, &y);
@ -257,7 +255,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
// * PyTorch shape is (x rows, y columns) // * PyTorch shape is (x rows, y columns)
// * ggml shape is (y elements in a row, x elements in a column) // * ggml shape is (y elements in a row, x elements in a column)
// Both shapes represent the same tensor. // Both shapes represent the same tensor.
tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y, x); tensor = ggml_new_tensor_2d(ctx, ggml_data_type, y, x);
} else { } else {
abort(); abort();
} }
@ -265,11 +263,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
std::string key(key_length, 0); std::string key(key_length, 0);
RWKV_ASSERT(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key"); RWKV_ASSERT(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key");
size_t element_size = data_type == 0 ? size_t byte_count = element_count * ggml_type_size(ggml_data_type);
ggml_type_size(GGML_TYPE_F32) :
ggml_type_size(GGML_TYPE_F16);
size_t byte_count = element_count * element_size;
RWKV_ASSERT(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data"); RWKV_ASSERT(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data");
parameters[key] = tensor; parameters[key] = tensor;
@ -322,8 +316,8 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
// Verify order of dimensions // Verify order of dimensions
struct ggml_tensor * emb = model->emb; struct ggml_tensor * emb = model->emb;
RWKV_ASSERT(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims); RWKV_ASSERT(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims);
RWKV_ASSERT(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[1]); RWKV_ASSERT(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
RWKV_ASSERT(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[0]); RWKV_ASSERT(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
} }
// --- Operators --- // --- Operators ---

View File

@ -4,11 +4,12 @@
# Usage: python compare_cpp_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin # Usage: python compare_cpp_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
import os import os
import struct
import argparse import argparse
import subprocess import subprocess
import torch import torch
import numpy as np import numpy as np
from typing import List from typing import List, Tuple, Any
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV') parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
@ -32,6 +33,21 @@ def main() -> None:
627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272, 627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272,
253, 1708, 432, 253, 13862, 15] 253, 1708, 432, 253, 13862, 15]
threshold: float
with open(args.ggml_model_path, 'rb') as model_file:
header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4))
data_type: int = header[5]
assert data_type == 0 or data_type == 1, f'Unsupported model data type {data_type}'
if data_type == 0:
# FP32, high precision
threshold = 0.000005
elif data_type == 1:
# FP16, lower precision, so higher threshold
threshold = 0.003
def compare_logits(tokens_subset: List[int]) -> None: def compare_logits(tokens_subset: List[int]) -> None:
token_count: int = len(tokens_subset) token_count: int = len(tokens_subset)
state_path: str = './state.bin' state_path: str = './state.bin'
@ -72,7 +88,7 @@ def main() -> None:
print(f'Actual logits: {actual_logits}') print(f'Actual logits: {actual_logits}')
print('Difference per token: %.8f' % (difference,)) print('Difference per token: %.8f' % (difference,))
assert abs(difference) <= 0.000005, 'Difference is too big' assert abs(difference) <= threshold, 'Difference is too big'
# Check small token amount first to avoid waiting too long before seeing that model is broken # Check small token amount first to avoid waiting too long before seeing that model is broken
compare_logits(tokens[:4]) compare_logits(tokens[:4])