Remove reference impl comparison test
This commit is contained in:
parent
edd57a186c
commit
e04baa032c
|
@ -1,105 +0,0 @@
|
|||
# Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
|
||||
# Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
|
||||
# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
|
||||
# Usage: python compare_with_reference_implementation.py C:\rwkv.cpp-169M.bin
|
||||
|
||||
import os
|
||||
import struct
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
import rwkv_cpp_model
|
||||
import rwkv_cpp_shared_library
|
||||
from typing import List, Tuple, Any
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
|
||||
parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
|
||||
return parser.parse_args()
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
# Don't want to depend on tokenizer here.
|
||||
tokens: List[int] = [4, 3631, 4420, 2412, 953, 432, 391, 30567, 87, 15, 14161, 7092, 273, 416, 27767, 55, 342,
|
||||
2412, 953, 432, 3806, 7092, 273, 416, 27767, 55, 15, 187, 4, 19039, 2412, 953, 497, 4561,
|
||||
342, 416, 27767, 55, 14, 21, 14, 49, 587, 14, 17809, 46, 14, 938, 14256, 28950, 14, 1438,
|
||||
1508, 15, 81, 394, 1566, 275, 8462, 22097, 348, 15, 187, 4, 43825, 27, 15548, 7277, 64,
|
||||
3113, 64, 14005, 64, 39595, 15, 4789, 10269, 61, 18992, 61, 7265, 64, 30217, 39297, 15,
|
||||
20963, 330, 27, 190, 30567, 87, 15, 14161, 14, 17809, 46, 15, 4805]
|
||||
|
||||
threshold: float
|
||||
|
||||
with open(args.ggml_model_path, 'rb') as model_file:
|
||||
header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4))
|
||||
data_type: int = header[5]
|
||||
|
||||
assert data_type == 0 or\
|
||||
data_type == 1 or\
|
||||
data_type == 2 or\
|
||||
data_type == 3 or\
|
||||
data_type == 4, f'Unsupported model data type {data_type}'
|
||||
|
||||
if data_type == 0:
|
||||
# FP32, high precision
|
||||
threshold = 0.000005
|
||||
elif data_type == 1:
|
||||
# FP16, lower precision, so higher threshold
|
||||
threshold = 0.0032
|
||||
elif data_type == 2:
|
||||
# Q4_0 quantized, even lower precision, so even higher threshold
|
||||
threshold = 0.4
|
||||
elif data_type == 3:
|
||||
# Q4_1
|
||||
threshold = 1.21
|
||||
elif data_type == 4:
|
||||
# Q4_1_O
|
||||
threshold = 0.2
|
||||
|
||||
model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path)
|
||||
|
||||
def compare_logits(tokens_subset: List[int]) -> None:
|
||||
token_count: int = len(tokens_subset)
|
||||
|
||||
logits, state = None, None
|
||||
|
||||
for i in range(token_count):
|
||||
token: int = tokens_subset[i]
|
||||
|
||||
if token_count <= 10 or i % (token_count // 10) == 0:
|
||||
print(f'{i + 1}/{token_count}')
|
||||
|
||||
logits, state = model.eval(token, state, state, logits)
|
||||
|
||||
actual_logits = logits
|
||||
|
||||
# ---
|
||||
|
||||
expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin'
|
||||
|
||||
if not os.path.isfile(expected_logits_path):
|
||||
expected_logits_path = 'rwkv/' + expected_logits_path
|
||||
|
||||
with open(expected_logits_path, 'rb') as logits_file:
|
||||
expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
|
||||
|
||||
# ---
|
||||
|
||||
difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
|
||||
|
||||
print(f'Reference logits: {expected_logits}')
|
||||
print(f'Actual logits: {actual_logits}')
|
||||
print('Difference per token: %.8f' % (difference,))
|
||||
|
||||
assert abs(difference) <= threshold, 'Difference is too big'
|
||||
|
||||
compare_logits(tokens)
|
||||
|
||||
print()
|
||||
print('Test passes')
|
||||
|
||||
if model is not None:
|
||||
model.free()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Binary file not shown.
|
@ -192,13 +192,17 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
|
|||
else:
|
||||
file_name = 'librwkv.so'
|
||||
|
||||
repo_root_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent.parent
|
||||
|
||||
paths = [
|
||||
# If we are in "rwkv" directory
|
||||
f'../bin/Release/{file_name}',
|
||||
# If we are in repo root directory
|
||||
f'bin/Release/{file_name}',
|
||||
# Search relative to this file
|
||||
str(repo_root_dir / 'bin' / 'Release' / file_name),
|
||||
# Fallback
|
||||
pathlib.Path(os.path.abspath(__file__)).parent.parent / file_name
|
||||
str(repo_root_dir / file_name)
|
||||
]
|
||||
|
||||
for path in paths:
|
||||
|
|
Loading…
Reference in New Issue