diff --git a/rwkv/compare_cpp_with_reference_implementation.py b/rwkv/compare_cpp_with_reference_implementation.py index 8123900..0e08e67 100644 --- a/rwkv/compare_cpp_with_reference_implementation.py +++ b/rwkv/compare_cpp_with_reference_implementation.py @@ -32,46 +32,51 @@ def main() -> None: 627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272, 253, 1708, 432, 253, 13862, 15] - token_count: int = len(tokens) - state_path: str = './state.bin' - logits_path: str = './logits.bin' + def compare_logits(tokens_subset: List[int]) -> None: + token_count: int = len(tokens_subset) + state_path: str = './state.bin' + logits_path: str = './logits.bin' - for i in range(token_count): - token: int = tokens[i] + for i in range(token_count): + token: int = tokens_subset[i] - print(f'{i + 1}/{token_count}') + print(f'{i + 1}/{token_count}') - subprocess.run( - [ - args.main_executable_path, - args.ggml_model_path, - str(token), - # If this is the first token, let the script create a new state. - '' if i == 0 else state_path, - state_path, - logits_path - ], - check=True - ) + subprocess.run( + [ + args.main_executable_path, + args.ggml_model_path, + str(token), + # If this is the first token, let the script create a new state. + '' if i == 0 else state_path, + state_path, + logits_path + ], + check=True + ) - expected_logits_path: str = 'expected_logits_169M_20220807_8023.bin' + expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin' - if not os.path.isfile(expected_logits_path): - expected_logits_path = 'rwkv/' + expected_logits_path + if not os.path.isfile(expected_logits_path): + expected_logits_path = 'rwkv/' + expected_logits_path - with open(expected_logits_path, 'rb') as logits_file: - expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) + with open(expected_logits_path, 'rb') as logits_file: + expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) - with open(logits_path, 'rb') as logits_file: - actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) + with open(logits_path, 'rb') as logits_file: + actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) - difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item() + difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item() - print(f'Reference logits: {expected_logits}') - print(f'Actual logits: {actual_logits}') - print('Difference per token: %.8f' % (difference,)) + print(f'Reference logits: {expected_logits}') + print(f'Actual logits: {actual_logits}') + print('Difference per token: %.8f' % (difference,)) - assert abs(difference) <= 0.00005, 'Difference is too big' + assert abs(difference) <= 0.00005, 'Difference is too big' + + # Check small token amount first to avoid waiting too long before seeing that model is broken + compare_logits(tokens[:4]) + compare_logits(tokens) print() print('Test passes') diff --git a/rwkv/expected_logits_169M_20220807_8023_4_tokens.bin b/rwkv/expected_logits_169M_20220807_8023_4_tokens.bin new file mode 100644 index 0000000..e1ddfc0 Binary files /dev/null and b/rwkv/expected_logits_169M_20220807_8023_4_tokens.bin differ diff --git a/rwkv/expected_logits_169M_20220807_8023.bin b/rwkv/expected_logits_169M_20220807_8023_82_tokens.bin similarity index 100% rename from rwkv/expected_logits_169M_20220807_8023.bin rename to rwkv/expected_logits_169M_20220807_8023_82_tokens.bin