Add fail-fast version of the test
This commit is contained in:
		
							parent
							
								
									0fcb7c64c6
								
							
						
					
					
						commit
						16ec7a5c18
					
				|  | @ -32,12 +32,13 @@ def main() -> None: | ||||||
|                          627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272, |                          627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272, | ||||||
|                          253, 1708, 432, 253, 13862, 15] |                          253, 1708, 432, 253, 13862, 15] | ||||||
| 
 | 
 | ||||||
|     token_count: int = len(tokens) |     def compare_logits(tokens_subset: List[int]) -> None: | ||||||
|  |         token_count: int = len(tokens_subset) | ||||||
|         state_path: str = './state.bin' |         state_path: str = './state.bin' | ||||||
|         logits_path: str = './logits.bin' |         logits_path: str = './logits.bin' | ||||||
| 
 | 
 | ||||||
|         for i in range(token_count): |         for i in range(token_count): | ||||||
|         token: int = tokens[i] |             token: int = tokens_subset[i] | ||||||
| 
 | 
 | ||||||
|             print(f'{i + 1}/{token_count}') |             print(f'{i + 1}/{token_count}') | ||||||
| 
 | 
 | ||||||
|  | @ -54,7 +55,7 @@ def main() -> None: | ||||||
|                 check=True |                 check=True | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|     expected_logits_path: str = 'expected_logits_169M_20220807_8023.bin' |         expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin' | ||||||
| 
 | 
 | ||||||
|         if not os.path.isfile(expected_logits_path): |         if not os.path.isfile(expected_logits_path): | ||||||
|             expected_logits_path = 'rwkv/' + expected_logits_path |             expected_logits_path = 'rwkv/' + expected_logits_path | ||||||
|  | @ -73,6 +74,10 @@ def main() -> None: | ||||||
| 
 | 
 | ||||||
|         assert abs(difference) <= 0.00005, 'Difference is too big' |         assert abs(difference) <= 0.00005, 'Difference is too big' | ||||||
| 
 | 
 | ||||||
|  |     # Check small token amount first to avoid waiting too long before seeing that model is broken | ||||||
|  |     compare_logits(tokens[:4]) | ||||||
|  |     compare_logits(tokens) | ||||||
|  | 
 | ||||||
|     print() |     print() | ||||||
|     print('Test passes') |     print('Test passes') | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
										
											Binary file not shown.
										
									
								
							
		Loading…
	
		Reference in New Issue