Remove reference impl comparison test
This commit is contained in:
		
							parent
							
								
									edd57a186c
								
							
						
					
					
						commit
						e04baa032c
					
				|  | @ -1,105 +0,0 @@ | |||
| # Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV. | ||||
| # Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch. | ||||
| # Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py | ||||
| # Usage: python compare_with_reference_implementation.py C:\rwkv.cpp-169M.bin | ||||
| 
 | ||||
| import os | ||||
| import struct | ||||
| import argparse | ||||
| import torch | ||||
| import numpy as np | ||||
| import rwkv_cpp_model | ||||
| import rwkv_cpp_shared_library | ||||
| from typing import List, Tuple, Any | ||||
| 
 | ||||
| def parse_args(): | ||||
|     parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV') | ||||
|     parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file') | ||||
|     return parser.parse_args() | ||||
| 
 | ||||
| def main() -> None: | ||||
|     args = parse_args() | ||||
| 
 | ||||
|     # Don't want to depend on tokenizer here. | ||||
|     tokens: List[int] = [4, 3631, 4420, 2412, 953, 432, 391, 30567, 87, 15, 14161, 7092, 273, 416, 27767, 55, 342, | ||||
|                          2412, 953, 432, 3806, 7092, 273, 416, 27767, 55, 15, 187, 4, 19039, 2412, 953, 497, 4561, | ||||
|                          342, 416, 27767, 55, 14, 21, 14, 49, 587, 14, 17809, 46, 14, 938, 14256, 28950, 14, 1438, | ||||
|                          1508, 15, 81, 394, 1566, 275, 8462, 22097, 348, 15, 187, 4, 43825, 27, 15548, 7277, 64, | ||||
|                          3113, 64, 14005, 64, 39595, 15, 4789, 10269, 61, 18992, 61, 7265, 64, 30217, 39297, 15, | ||||
|                          20963, 330, 27, 190, 30567, 87, 15, 14161, 14, 17809, 46, 15, 4805] | ||||
| 
 | ||||
|     threshold: float | ||||
| 
 | ||||
|     with open(args.ggml_model_path, 'rb') as model_file: | ||||
|         header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4)) | ||||
|         data_type: int = header[5] | ||||
| 
 | ||||
|         assert data_type == 0 or\ | ||||
|                data_type == 1 or\ | ||||
|                data_type == 2 or\ | ||||
|                data_type == 3 or\ | ||||
|                data_type == 4, f'Unsupported model data type {data_type}' | ||||
| 
 | ||||
|         if data_type == 0: | ||||
|             # FP32, high precision | ||||
|             threshold = 0.000005 | ||||
|         elif data_type == 1: | ||||
|             # FP16, lower precision, so higher threshold | ||||
|             threshold = 0.0032 | ||||
|         elif data_type == 2: | ||||
|             # Q4_0 quantized, even lower precision, so even higher threshold | ||||
|             threshold = 0.4 | ||||
|         elif data_type == 3: | ||||
|             # Q4_1 | ||||
|             threshold = 1.21 | ||||
|         elif data_type == 4: | ||||
|             # Q4_1_O | ||||
|             threshold = 0.2 | ||||
| 
 | ||||
|     model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path) | ||||
| 
 | ||||
|     def compare_logits(tokens_subset: List[int]) -> None: | ||||
|         token_count: int = len(tokens_subset) | ||||
| 
 | ||||
|         logits, state = None, None | ||||
| 
 | ||||
|         for i in range(token_count): | ||||
|             token: int = tokens_subset[i] | ||||
| 
 | ||||
|             if token_count <= 10 or i % (token_count // 10) == 0: | ||||
|                 print(f'{i + 1}/{token_count}') | ||||
| 
 | ||||
|             logits, state = model.eval(token, state, state, logits) | ||||
| 
 | ||||
|         actual_logits = logits | ||||
| 
 | ||||
|         # --- | ||||
| 
 | ||||
|         expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin' | ||||
| 
 | ||||
|         if not os.path.isfile(expected_logits_path): | ||||
|             expected_logits_path = 'rwkv/' + expected_logits_path | ||||
| 
 | ||||
|         with open(expected_logits_path, 'rb') as logits_file: | ||||
|             expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) | ||||
| 
 | ||||
|         # --- | ||||
| 
 | ||||
|         difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item() | ||||
| 
 | ||||
|         print(f'Reference logits: {expected_logits}') | ||||
|         print(f'Actual logits: {actual_logits}') | ||||
|         print('Difference per token: %.8f' % (difference,)) | ||||
| 
 | ||||
|         assert abs(difference) <= threshold, 'Difference is too big' | ||||
| 
 | ||||
|     compare_logits(tokens) | ||||
| 
 | ||||
|     print() | ||||
|     print('Test passes') | ||||
| 
 | ||||
|     if model is not None: | ||||
|         model.free() | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
										
											Binary file not shown.
										
									
								
							|  | @ -192,13 +192,17 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary: | |||
|     else: | ||||
|         file_name = 'librwkv.so' | ||||
| 
 | ||||
|     repo_root_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent.parent | ||||
| 
 | ||||
|     paths = [ | ||||
|         # If we are in "rwkv" directory | ||||
|         f'../bin/Release/{file_name}', | ||||
|         # If we are in repo root directory | ||||
|         f'bin/Release/{file_name}', | ||||
|         # Search relative to this file | ||||
|         str(repo_root_dir / 'bin' / 'Release' / file_name), | ||||
|         # Fallback | ||||
|         pathlib.Path(os.path.abspath(__file__)).parent.parent / file_name | ||||
|         str(repo_root_dir / file_name) | ||||
|     ] | ||||
| 
 | ||||
|     for path in paths: | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue