diff --git a/README.md b/README.md
index f06fbf3..8f14573 100644
--- a/README.md
+++ b/README.md
@@ -2,17 +2,15 @@
 
 This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml). The end goal is to allow 4-bit quantized inference on CPU.
 
-**WORK IN PROGRESS!** **Status**: FP32 inference works. For 64 tokens, logits from `rwkv.cpp` almost exactly match those from [reference implementation](https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py) (difference <= 0.00005 per token).
+**WORK IN PROGRESS!** **Status**: FP32 and FP16 inference work correctly. Currently, I'm working on creating a usable C library interface and a Python wrapper for it.
 
 ## Plan
 
-1. Heavily refactor code; optimize where possible
-2. Make FP16 inference work
-3. Create proper interface (probably, C library)
-4. Create Python wrapper with sampling and simple chat interface
-5. Write a good `README.md` and publish links to this repo
-6. Make INT4 inference work
-7. Create pull request to main `ggml` repo with all improvements made here
+1. Create a proper interface (probably a C library)
+2. Create a Python wrapper with sampling and a simple chat interface
+3. Write a good `README.md` and publish links to this repo
+4. Make INT4 inference work
+5. Create a pull request to the main `ggml` repo with all improvements made here
 
 ## Structure
 
diff --git a/examples/main_rwkv/main_rwkv.cpp b/examples/main_rwkv/main_rwkv.cpp
index d72cd92..ea30b62 100644
--- a/examples/main_rwkv/main_rwkv.cpp
+++ b/examples/main_rwkv/main_rwkv.cpp
@@ -210,8 +210,6 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
     read_int32(file, &(model->data_type));
     RWKV_ASSERT(model->data_type == 0 || model->data_type == 1, "Unsupported model data type %d", model->data_type);
 
-    RWKV_ASSERT(model->data_type == 0, "Data types other than float32 are not yet supported"); // TODO
-
     RWKV_LOG("n_vocab = %d", model->n_vocab);
     RWKV_LOG("n_embed = %d", model->n_embed);
     RWKV_LOG("n_layer = %d", model->n_layer);
@@ -236,7 +234,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
         read_int32(file, &data_type);
         RWKV_ASSERT(data_type == 0 || data_type == 1, "Unsupported parameter data type %d", data_type);
 
-        RWKV_ASSERT(data_type == 0, "Data types other than float32 are not yet supported"); // TODO
+        ggml_type ggml_data_type = data_type == 0 ? GGML_TYPE_F32 : GGML_TYPE_F16;
 
         struct ggml_tensor * tensor;
 
@@ -248,7 +246,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
         if (dim_count == 1) {
             read_int32(file, &x);
             element_count = x;
-            tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, x);
+            tensor = ggml_new_tensor_1d(ctx, ggml_data_type, x);
         } else if (dim_count == 2) {
             read_int32(file, &x);
             read_int32(file, &y);
@@ -257,7 +255,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
             // * PyTorch shape is (x rows, y columns)
             // * ggml shape is (y elements in a row, x elements in a column)
             // Both shapes represent the same tensor.
-            tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y, x);
+            tensor = ggml_new_tensor_2d(ctx, ggml_data_type, y, x);
         } else {
             abort();
         }
@@ -265,11 +263,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
         std::string key(key_length, 0);
         RWKV_ASSERT(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key");
 
-        size_t element_size = data_type == 0 ?
-            ggml_type_size(GGML_TYPE_F32) :
-            ggml_type_size(GGML_TYPE_F16);
-        size_t byte_count = element_count * element_size;
-
+        size_t byte_count = element_count * ggml_type_size(ggml_data_type);
         RWKV_ASSERT(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data");
 
         parameters[key] = tensor;
@@ -322,8 +316,8 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
     // Verify order of dimensions
     struct ggml_tensor * emb = model->emb;
     RWKV_ASSERT(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims);
-    RWKV_ASSERT(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
-    RWKV_ASSERT(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
+    RWKV_ASSERT(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
+    RWKV_ASSERT(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
 }
 
 // --- Operators ---
diff --git a/rwkv/compare_cpp_with_reference_implementation.py b/rwkv/compare_cpp_with_reference_implementation.py
index d8252e2..d6b6e6d 100644
--- a/rwkv/compare_cpp_with_reference_implementation.py
+++ b/rwkv/compare_cpp_with_reference_implementation.py
@@ -4,11 +4,12 @@
 # Usage: python compare_cpp_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
 
 import os
+import struct
 import argparse
 import subprocess
 import torch
 import numpy as np
-from typing import List
+from typing import List, Tuple, Any
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
@@ -32,6 +33,21 @@ def main() -> None:
                          627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272, 253,
                          1708, 432, 253, 13862, 15]
 
+    threshold: float
+
+    with open(args.ggml_model_path, 'rb') as model_file:
+        header: Tuple[Any, ...] = struct.unpack('=iiiiii', model_file.read(6 * 4))
+        data_type: int = header[5]
+
+        assert data_type == 0 or data_type == 1, f'Unsupported model data type {data_type}'
+
+        if data_type == 0:
+            # FP32, high precision
+            threshold = 0.000005
+        elif data_type == 1:
+            # FP16, lower precision, so a higher threshold
+            threshold = 0.003
+
     def compare_logits(tokens_subset: List[int]) -> None:
         token_count: int = len(tokens_subset)
         state_path: str = './state.bin'
@@ -72,7 +88,7 @@ def main() -> None:
         print(f'Actual logits: {actual_logits}')
         print('Difference per token: %.8f' % (difference,))
 
-        assert abs(difference) <= 0.000005, 'Difference is too big'
+        assert abs(difference) <= threshold, 'Difference is too big'
 
     # Check small token amount first to avoid waiting too long before seeing that model is broken
     compare_logits(tokens[:4])
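Note on the header read added to `compare_cpp_with_reference_implementation.py`: the script unpacks six native-order int32 fields and uses only the last one (`data_type`). A minimal sketch of parsing the whole 24-byte header follows; the names `magic` and `version` for the first two fields are assumptions, since the excerpts in this diff only show `n_vocab`, `n_embed`, `n_layer`, and `data_type` being read and logged.

```python
import struct
from typing import Dict

def read_rwkv_header(model_path: str) -> Dict[str, int]:
    """Parse the 24-byte rwkv.cpp model file header (six native-order int32s)."""
    with open(model_path, 'rb') as model_file:
        fields = struct.unpack('=iiiiii', model_file.read(6 * 4))

    # `magic` and `version` are assumed field names; the diff above only
    # confirms n_vocab, n_embed, n_layer, and data_type.
    magic, version, n_vocab, n_embed, n_layer, data_type = fields
    assert data_type in (0, 1), f'Unsupported model data type {data_type}'

    return {
        'magic': magic,
        'version': version,
        'n_vocab': n_vocab,
        'n_embed': n_embed,
        'n_layer': n_layer,
        'data_type': data_type,  # 0 = FP32, 1 = FP16
    }
```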
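On the choice of `threshold = 0.003` for FP16: the diff doesn't state the rationale, but the value is consistent with FP16's machine epsilon (2^-10 ≈ 0.00098), so per-token logit differences on the order of 1e-3 against the FP32 reference are expected, with some headroom for accumulated error. A quick check with NumPy (already a dependency of the script):

```python
import numpy as np

# Machine epsilon of each float type: the relative spacing between
# representable values around 1.0.
print(np.finfo(np.float16).eps)  # ~0.000977 -- same order as the 0.003 threshold
print(np.finfo(np.float32).eps)  # ~1.19e-07 -- well below the 0.000005 threshold
```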