Add comparison against reference implementation script, implement state & logits saving
This commit is contained in:
parent
d00f28581a
commit
61c6b1a4e0
|
@ -4,12 +4,12 @@ This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [gger
|
|||
|
||||
**WORK IN PROGRESS: NOTHING WORKS YET!** If you know C/C++/ggml, please help!
|
||||
|
||||
Inference code runs and outputs some correctly-looking numbers in logits. Values are checked to be correct at least up to `ln0`, they match with reference implementation.
|
||||
**Status**: The model outputs correct logits for the first token (logits match reference implementation). But state saving is broken, so for every subsequent token logits are invalid.
|
||||
|
||||
## Plan
|
||||
|
||||
1. Make FP32 inference work
|
||||
1. Compare vectors step-by-step with reference implementation
|
||||
1. Fix state saving
|
||||
2. Validate states and logits against [reference implementation](https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py) by creating a testing script
|
||||
3. Heavily refactor code; optimize where possible
|
||||
4. Make FP16 inference work
|
||||
|
@ -23,7 +23,7 @@ Inference code runs and outputs some correctly-looking numbers in logits. Values
|
|||
|
||||
This repo is based on the [llama.cpp repo](https://github.com/ggerganov/llama.cpp). RWKV-related code is in these directories:
|
||||
|
||||
- `./rwkv`: directory containing Python scripts
|
||||
- `./rwkv`: directory containing Python scripts for conversion and validation
|
||||
- `./examples/main_rwkw`: directory containing script that loads and infers RWKV model
|
||||
|
||||
Please do not change files in other directories — this will make pulling recent changes easier.
|
||||
|
|
|
@ -364,8 +364,6 @@ int main(int argc, char ** argv) {
|
|||
struct rwkv_model model;
|
||||
load_rwkv_model(ctx, model_path, &model);
|
||||
|
||||
PRINT_TENSOR(model.emb);
|
||||
|
||||
int32_t n_vocab = model.n_vocab;
|
||||
int32_t n_embed = model.n_embed;
|
||||
int32_t n_layer = model.n_layer;
|
||||
|
@ -393,7 +391,7 @@ int main(int argc, char ** argv) {
|
|||
RWKV_ASSERT(state_in_file != NULL, "Failed to open file %s", state_in_path);
|
||||
|
||||
// TODO Saving/loading raw data makes state cache machine-dependent
|
||||
RWKV_ASSERT(fread(state->data, 1, state_file_size, state_in_file) == state_file_size, "Failed to read tensor data from a file");
|
||||
RWKV_ASSERT(fread(state->data, 1, state_file_size, state_in_file) == state_file_size, "Failed to read state from a file");
|
||||
|
||||
fclose(state_in_file);
|
||||
}
|
||||
|
@ -409,10 +407,6 @@ int main(int argc, char ** argv) {
|
|||
// x = self.layer_norm(x, self.w.blocks[0].ln0)
|
||||
x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);
|
||||
|
||||
// For token 123 after ln0, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765]
|
||||
// Prints (768, 1), [[-0.419416 1.169782 0.779827 ... -1.183806 -0.871573 -0.276483]]
|
||||
COMPUTE_AND_PRINT_TENSOR(ctx, x);
|
||||
|
||||
for (int i = 0; i < n_layer; i++) {
|
||||
auto layer = model.layers[i];
|
||||
|
||||
|
@ -422,7 +416,6 @@ int main(int argc, char ** argv) {
|
|||
struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
|
||||
// state[5 * i + 1]
|
||||
struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * 4);
|
||||
COMPUTE_AND_PRINT_TENSOR(ctx, x_prev);
|
||||
// xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
|
||||
// xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
|
||||
// xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
|
||||
|
@ -444,11 +437,6 @@ int main(int argc, char ** argv) {
|
|||
// state[5 * i + 1] = x
|
||||
ggml_cpy(ctx, x0, x_prev);
|
||||
|
||||
COMPUTE_AND_PRINT_TENSOR(ctx, xk);
|
||||
COMPUTE_AND_PRINT_TENSOR(ctx, xv);
|
||||
COMPUTE_AND_PRINT_TENSOR(ctx, xr);
|
||||
COMPUTE_AND_PRINT_TENSOR(ctx, x_prev);
|
||||
|
||||
// r = torch.sigmoid(rw @ xr)
|
||||
struct ggml_tensor * r = ggml_sigmoid(
|
||||
ctx,
|
||||
|
@ -497,21 +485,21 @@ int main(int argc, char ** argv) {
|
|||
// e2 = torch.exp(k - qq)
|
||||
e2 = ggml_exp(ctx, ggml_sub(ctx, k, qq));
|
||||
// state[5 * i + 2] = e1 * aa + e2 * v
|
||||
// todo must save result
|
||||
// TODO Must save result
|
||||
ggml_cpy(ctx, ggml_add(
|
||||
ctx,
|
||||
ggml_mul(ctx, e1, aa),
|
||||
ggml_mul(ctx, e2, v)
|
||||
), aa);
|
||||
// state[5 * i + 3] = e1 * bb + e2
|
||||
// todo must save result
|
||||
// TODO Must save result
|
||||
ggml_cpy(ctx, ggml_add(
|
||||
ctx,
|
||||
ggml_mul(ctx, e1, bb),
|
||||
e2
|
||||
), bb);
|
||||
// state[5 * i + 4] = qq
|
||||
// todo must save result
|
||||
// TODO Must save result
|
||||
ggml_cpy(ctx, qq, pp);
|
||||
// ow @ (r * wkv)
|
||||
x = ggml_add(
|
||||
|
@ -523,8 +511,6 @@ int main(int argc, char ** argv) {
|
|||
ggml_mul(ctx, r, wkv)
|
||||
)
|
||||
);
|
||||
RWKV_LOG("RWKV %d completed", i);
|
||||
COMPUTE_AND_PRINT_TENSOR(ctx, x);
|
||||
}
|
||||
|
||||
// FFN/channel mixing
|
||||
|
@ -546,7 +532,7 @@ int main(int argc, char ** argv) {
|
|||
ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_r))
|
||||
);
|
||||
// state[5 * i + 0] = x
|
||||
// todo must save result
|
||||
// TODO Must save result
|
||||
ggml_cpy(ctx, x0, x_prev);
|
||||
|
||||
// r = torch.sigmoid(rw @ xr)
|
||||
|
@ -569,8 +555,6 @@ int main(int argc, char ** argv) {
|
|||
ggml_mul_mat(ctx, layer.ffn_value, k)
|
||||
)
|
||||
);
|
||||
RWKV_LOG("FFN %d completed", i);
|
||||
COMPUTE_AND_PRINT_TENSOR(ctx, x);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -582,10 +566,31 @@ int main(int argc, char ** argv) {
|
|||
|
||||
compute_graph(ctx, logits);
|
||||
|
||||
// TODO -nan(ind) -nan(ind) ... (maybe implement exp/max first?)
|
||||
PRINT_TENSOR(logits);
|
||||
|
||||
// TODO Save new state and logits
|
||||
{
|
||||
RWKV_LOG("Saving state to %s", state_out_path);
|
||||
int32_t state_file_size = state_element_count * 4;
|
||||
|
||||
FILE * state_out_file = fopen(state_out_path, "wb");
|
||||
RWKV_ASSERT(state_out_file != NULL, "Failed to open file %s", state_out_path);
|
||||
|
||||
RWKV_ASSERT(fwrite(state->data, 1, state_file_size, state_out_file) == state_file_size, "Failed to write state to a file");
|
||||
|
||||
fclose(state_out_file);
|
||||
}
|
||||
|
||||
{
|
||||
RWKV_LOG("Saving logits to %s", logits_out_path);
|
||||
int32_t logits_file_size = n_vocab * 4;
|
||||
|
||||
FILE * logits_out_file = fopen(logits_out_path, "wb");
|
||||
RWKV_ASSERT(logits_out_file != NULL, "Failed to open file %s", logits_out_path);
|
||||
|
||||
RWKV_ASSERT(fwrite(logits->data, 1, logits_file_size, logits_out_file) == logits_file_size, "Failed to write logits to a file");
|
||||
|
||||
fclose(logits_out_file);
|
||||
}
|
||||
|
||||
ggml_free(ctx);
|
||||
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
# Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
|
||||
# Usage: python compare_cpp_with_reference_implementation.py C:\RWKV-4-Pile-169M-20220807-8023.pth bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import rwkv_model
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
|
||||
parser.add_argument('torch_model_path', help='Path to PyTorch checkpoint file')
|
||||
parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file')
|
||||
parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
|
||||
return parser.parse_args()
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
# It's not important what exactly these tokens are; just that output of both model matches.
|
||||
tokens: List[int] = [(i + 1) for i in range(32)]
|
||||
state_path: str = './state.bin'
|
||||
logits_path: str = './logits.bin'
|
||||
|
||||
reference_model: rwkv_model.RWKV_RNN = rwkv_model.RWKV_RNN(args.torch_model_path)
|
||||
|
||||
ref_logits, ref_state = None, None
|
||||
|
||||
for token in tokens:
|
||||
print()
|
||||
print(f'--- Token {token} ---')
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
args.main_executable_path,
|
||||
args.ggml_model_path,
|
||||
str(token),
|
||||
# If this is the first token, let the script create a new state.
|
||||
'' if ref_state is None else state_path,
|
||||
state_path,
|
||||
logits_path
|
||||
],
|
||||
check=True
|
||||
)
|
||||
|
||||
with open(logits_path, 'rb') as logits_file:
|
||||
actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
|
||||
|
||||
ref_logits, ref_state = reference_model.forward(token, ref_state)
|
||||
|
||||
difference: float = (torch.sum(ref_logits - actual_logits) / len(ref_logits)).item()
|
||||
|
||||
print(f'Reference logits: {ref_logits}')
|
||||
print(f'Actual logits: {actual_logits}')
|
||||
print('Difference per token: %.8f' % (difference,))
|
||||
|
||||
assert abs(difference) <= 0.000001, 'Difference is too big'
|
||||
|
||||
print('Test passes')
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue