rwkv.cpp/rwkv/compare_with_reference_impl...

103 lines
3.9 KiB
Python

# Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
# Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
# Usage: python compare_with_reference_implementation.py C:\rwkv.cpp-169M.bin
import os
import struct
import argparse
import torch
import numpy as np
import rwkv_cpp_model
import rwkv_cpp_shared_library
from typing import List, Tuple, Any
def parse_args():
    """Parse the command line for the logits comparison script.

    Returns:
        argparse.Namespace whose only attribute is ``ggml_model_path``, the
        path given as the single positional argument.
    """
    description = (
        'Compare logits from rwkv.cpp implementation of RWKV with logits '
        'from reference implementation of RWKV'
    )
    arg_parser = argparse.ArgumentParser(description=description)
    arg_parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
    parsed_args = arg_parser.parse_args()
    return parsed_args
def _read_data_type(model_path: str) -> int:
    """Return the data type id stored in the header of an rwkv.cpp model file.

    The header is read as six 32-bit signed ints in native byte order; the
    data type is the sixth of them.
    """
    header_format = '=iiiiii'
    with open(model_path, 'rb') as model_file:
        header_bytes = model_file.read(struct.calcsize(header_format))
    header: Tuple[int, ...] = struct.unpack(header_format, header_bytes)
    return header[5]


def _threshold_for_data_type(data_type: int) -> float:
    """Return the largest tolerated |mean logit difference| for a model data type.

    Raises:
        ValueError: if ``data_type`` is not one of the supported ids 0-3.
    """
    if data_type == 0:
        # FP32, high precision
        return 0.000005
    if data_type == 1:
        # FP16, lower precision, so higher threshold
        return 0.0032
    if data_type == 2:
        # INT4 quantized, even lower precision, so even higher threshold
        # This threshold will let some bugs pass
        return 4.0
    if data_type == 3:
        # This format stores more data, so error would be lower
        return 1.2
    # Raised explicitly (not via assert) so the check survives `python -O`.
    raise ValueError(f'Unsupported model data type {data_type}')


def _load_expected_logits(token_count: int) -> torch.Tensor:
    """Load the reference logits saved after evaluating ``token_count`` tokens.

    Looks for the file in the current directory first, then under ``rwkv/``.
    The whole file is interpreted as native-byte-order float32 values.
    """
    file_name: str = f'expected_logits_169M_20220807_8023_{token_count}_tokens.bin'
    expected_logits_path: str = file_name
    if not os.path.isfile(expected_logits_path):
        expected_logits_path = os.path.join('rwkv', file_name)
    with open(expected_logits_path, 'rb') as logits_file:
        return torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))


def main() -> None:
    """Compare rwkv.cpp logits with reference logits and fail on a large difference.

    Reads the model path from the command line, evaluates a fixed token
    sequence with rwkv.cpp and compares the final logits against the
    reference logits stored for that sequence.

    Raises:
        ValueError: if the model's data type is unsupported.
        AssertionError: if the mean difference exceeds the data type's threshold.
    """
    args = parse_args()

    # Don't want to depend on tokenizer here.
    tokens: List[int] = [4, 3631, 4420, 2412, 953, 432, 391, 30567, 87, 15, 14161, 7092, 273, 416, 27767, 55, 342,
                         2412, 953, 432, 3806, 7092, 273, 416, 27767, 55, 15, 187, 4, 19039, 2412, 953, 497, 4561,
                         342, 416, 27767, 55, 14, 21, 14, 49, 587, 14, 17809, 46, 14, 938, 14256, 28950, 14, 1438,
                         1508, 15, 81, 394, 1566, 275, 8462, 22097, 348, 15, 187, 4, 43825, 27, 15548, 7277, 64,
                         3113, 64, 14005, 64, 39595, 15, 4789, 10269, 61, 18992, 61, 7265, 64, 30217, 39297, 15,
                         20963, 330, 27, 190, 30567, 87, 15, 14161, 14, 17809, 46, 15, 4805]

    threshold: float = _threshold_for_data_type(_read_data_type(args.ggml_model_path))

    model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path)

    def compare_logits(tokens_subset: List[int]) -> None:
        """Evaluate ``tokens_subset`` from an empty state and check the final logits.

        Raises:
            ValueError: if ``tokens_subset`` is empty (no logits would be produced).
            AssertionError: if the mean difference exceeds ``threshold``.
        """
        token_count: int = len(tokens_subset)
        if token_count == 0:
            raise ValueError('Cannot compare logits for an empty token sequence')

        # Report progress about ten times; every token for sequences of 10 or fewer.
        progress_step: int = max(1, token_count // 10)

        logits, state = None, None
        for i, token in enumerate(tokens_subset):
            if i % progress_step == 0:
                print(f'{i + 1}/{token_count}')
            # NOTE(review): state is passed as both input and output, presumably so it
            # is updated in place - confirm against RWKVModel.eval.
            logits, state = model.eval(token, state, state, logits)
        actual_logits = logits

        expected_logits = _load_expected_logits(token_count)

        # Signed mean over all logit entries (not per token, despite the label below).
        # NOTE(review): positive and negative errors cancel; a mean absolute difference
        # would be stricter but would need re-tuned thresholds.
        difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
        print(f'Reference logits: {expected_logits}')
        print(f'Actual logits: {actual_logits}')
        print('Difference per token: %.8f' % (difference,))
        # Explicit raise instead of assert so the test still fails under `python -O`;
        # written as `not <=` so that a NaN difference also fails.
        if not abs(difference) <= threshold:
            raise AssertionError('Difference is too big')

    try:
        compare_logits(tokens)
        print()
        print('Test passes')
    finally:
        # Free the native model even when the comparison raises.
        model.free()
# Run the comparison only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()