diff --git a/README.md b/README.md index d6249ed..aae62f3 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,12 @@ This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [gger **WORK IN PROGRESS: NOTHING WORKS YET!** If you know C/C++/ggml, please help! -Inference code runs and outputs some correctly-looking numbers in logits. Values are checked to be correct at least up to `ln0`, they match with reference implementation. +**Status**: The model outputs correct logits for the first token (logits match reference implementation). But state saving is broken, so for every subsequent token logits are invalid. ## Plan 1. Make FP32 inference work - 1. Compare vectors step-by-step with reference implementation + 1. Fix state saving 2. Validate states and logits against [reference implementation](https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py) by creating a testing script 3. Heavily refactor code; optimize where possible 4. Make FP16 inference work @@ -23,7 +23,7 @@ Inference code runs and outputs some correctly-looking numbers in logits. Values This repo is based on the [llama.cpp repo](https://github.com/ggerganov/llama.cpp). RWKV-related code is in these directories: -- `./rwkv`: directory containing Python scripts +- `./rwkv`: directory containing Python scripts for conversion and validation - `./examples/main_rwkw`: directory containing script that loads and infers RWKV model Please do not change files in other directories — this will make pulling recent changes easier. diff --git a/examples/main_rwkv/main_rwkv.cpp b/examples/main_rwkv/main_rwkv.cpp index d392292..be88c65 100644 --- a/examples/main_rwkv/main_rwkv.cpp +++ b/examples/main_rwkv/main_rwkv.cpp @@ -364,8 +364,6 @@ int main(int argc, char ** argv) { struct rwkv_model model; load_rwkv_model(ctx, model_path, &model); - PRINT_TENSOR(model.emb); - int32_t n_vocab = model.n_vocab; int32_t n_embed = model.n_embed; int32_t n_layer = model.n_layer; @@ -393,7 +391,7 @@ int main(int argc, char ** argv) { RWKV_ASSERT(state_in_file != NULL, "Failed to open file %s", state_in_path); // TODO Saving/loading raw data makes state cache machine-dependent - RWKV_ASSERT(fread(state->data, 1, state_file_size, state_in_file) == state_file_size, "Failed to read tensor data from a file"); + RWKV_ASSERT(fread(state->data, 1, state_file_size, state_in_file) == state_file_size, "Failed to read state from a file"); fclose(state_in_file); } @@ -409,10 +407,6 @@ int main(int argc, char ** argv) { // x = self.layer_norm(x, self.w.blocks[0].ln0) x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias); - // For token 123 after ln0, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765] - // Prints (768, 1), [[-0.419416 1.169782 0.779827 ... -1.183806 -0.871573 -0.276483]] - COMPUTE_AND_PRINT_TENSOR(ctx, x); - for (int i = 0; i < n_layer; i++) { auto layer = model.layers[i]; @@ -422,7 +416,6 @@ int main(int argc, char ** argv) { struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); // state[5 * i + 1] struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * 4); - COMPUTE_AND_PRINT_TENSOR(ctx, x_prev); // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) @@ -444,11 +437,6 @@ int main(int argc, char ** argv) { // state[5 * i + 1] = x ggml_cpy(ctx, x0, x_prev); - COMPUTE_AND_PRINT_TENSOR(ctx, xk); - COMPUTE_AND_PRINT_TENSOR(ctx, xv); - COMPUTE_AND_PRINT_TENSOR(ctx, xr); - COMPUTE_AND_PRINT_TENSOR(ctx, x_prev); - // r = torch.sigmoid(rw @ xr) struct ggml_tensor * r = ggml_sigmoid( ctx, @@ -497,21 +485,21 @@ int main(int argc, char ** argv) { // e2 = torch.exp(k - qq) e2 = ggml_exp(ctx, ggml_sub(ctx, k, qq)); // state[5 * i + 2] = e1 * aa + e2 * v - // todo must save result + // TODO Must save result ggml_cpy(ctx, ggml_add( ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v) ), aa); // state[5 * i + 3] = e1 * bb + e2 - // todo must save result + // TODO Must save result ggml_cpy(ctx, ggml_add( ctx, ggml_mul(ctx, e1, bb), e2 ), bb); // state[5 * i + 4] = qq - // todo must save result + // TODO Must save result ggml_cpy(ctx, qq, pp); // ow @ (r * wkv) x = ggml_add( @@ -523,8 +511,6 @@ int main(int argc, char ** argv) { ggml_mul(ctx, r, wkv) ) ); - RWKV_LOG("RWKV %d completed", i); - COMPUTE_AND_PRINT_TENSOR(ctx, x); } // FFN/channel mixing @@ -546,7 +532,7 @@ int main(int argc, char ** argv) { ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_r)) ); // state[5 * i + 0] = x - // todo must save result + // TODO Must save result ggml_cpy(ctx, x0, x_prev); // r = torch.sigmoid(rw @ xr) @@ -569,8 +555,6 @@ int main(int argc, char ** argv) { ggml_mul_mat(ctx, layer.ffn_value, k) ) ); - RWKV_LOG("FFN %d completed", i); - COMPUTE_AND_PRINT_TENSOR(ctx, x); } } @@ -582,10 +566,31 @@ int main(int argc, char ** argv) { compute_graph(ctx, logits); - // TODO -nan(ind) -nan(ind) ... (maybe implement exp/max first?) PRINT_TENSOR(logits); - // TODO Save new state and logits + { + RWKV_LOG("Saving state to %s", state_out_path); + int32_t state_file_size = state_element_count * 4; + + FILE * state_out_file = fopen(state_out_path, "wb"); + RWKV_ASSERT(state_out_file != NULL, "Failed to open file %s", state_out_path); + + RWKV_ASSERT(fwrite(state->data, 1, state_file_size, state_out_file) == state_file_size, "Failed to write state to a file"); + + fclose(state_out_file); + } + + { + RWKV_LOG("Saving logits to %s", logits_out_path); + int32_t logits_file_size = n_vocab * 4; + + FILE * logits_out_file = fopen(logits_out_path, "wb"); + RWKV_ASSERT(logits_out_file != NULL, "Failed to open file %s", logits_out_path); + + RWKV_ASSERT(fwrite(logits->data, 1, logits_file_size, logits_out_file) == logits_file_size, "Failed to write logits to a file"); + + fclose(logits_out_file); + } ggml_free(ctx); diff --git a/rwkv/compare_cpp_with_reference_implementation.py b/rwkv/compare_cpp_with_reference_implementation.py new file mode 100644 index 0000000..7edc07a --- /dev/null +++ b/rwkv/compare_cpp_with_reference_implementation.py @@ -0,0 +1,63 @@ +# Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV. +# Usage: python compare_cpp_with_reference_implementation.py C:\RWKV-4-Pile-169M-20220807-8023.pth bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin + +import argparse +import subprocess +import rwkv_model +import torch +import numpy as np +from typing import List + +def parse_args(): + parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV') + parser.add_argument('torch_model_path', help='Path to PyTorch checkpoint file') + parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file') + parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file') + return parser.parse_args() + +def main() -> None: + args = parse_args() + + # It's not important what exactly these tokens are; just that output of both model matches. + tokens: List[int] = [(i + 1) for i in range(32)] + state_path: str = './state.bin' + logits_path: str = './logits.bin' + + reference_model: rwkv_model.RWKV_RNN = rwkv_model.RWKV_RNN(args.torch_model_path) + + ref_logits, ref_state = None, None + + for token in tokens: + print() + print(f'--- Token {token} ---') + + subprocess.run( + [ + args.main_executable_path, + args.ggml_model_path, + str(token), + # If this is the first token, let the script create a new state. + '' if ref_state is None else state_path, + state_path, + logits_path + ], + check=True + ) + + with open(logits_path, 'rb') as logits_file: + actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) + + ref_logits, ref_state = reference_model.forward(token, ref_state) + + difference: float = (torch.sum(ref_logits - actual_logits) / len(ref_logits)).item() + + print(f'Reference logits: {ref_logits}') + print(f'Actual logits: {actual_logits}') + print('Difference per token: %.8f' % (difference,)) + + assert abs(difference) <= 0.000001, 'Difference is too big' + + print('Test passes') + +if __name__ == "__main__": + main()