Add fail-fast version of the test

This commit is contained in:
saharNooby 2023-04-01 11:15:15 +04:00
parent 0fcb7c64c6
commit 16ec7a5c18
3 changed files with 35 additions and 30 deletions

View File

@ -32,46 +32,51 @@ def main() -> None:
627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272, 627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272,
253, 1708, 432, 253, 13862, 15] 253, 1708, 432, 253, 13862, 15]
token_count: int = len(tokens) def compare_logits(tokens_subset: List[int]) -> None:
state_path: str = './state.bin' token_count: int = len(tokens_subset)
logits_path: str = './logits.bin' state_path: str = './state.bin'
logits_path: str = './logits.bin'
for i in range(token_count): for i in range(token_count):
token: int = tokens[i] token: int = tokens_subset[i]
print(f'{i + 1}/{token_count}') print(f'{i + 1}/{token_count}')
subprocess.run( subprocess.run(
[ [
args.main_executable_path, args.main_executable_path,
args.ggml_model_path, args.ggml_model_path,
str(token), str(token),
# If this is the first token, let the script create a new state. # If this is the first token, let the script create a new state.
'' if i == 0 else state_path, '' if i == 0 else state_path,
state_path, state_path,
logits_path logits_path
], ],
check=True check=True
) )
expected_logits_path: str = 'expected_logits_169M_20220807_8023.bin' expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin'
if not os.path.isfile(expected_logits_path): if not os.path.isfile(expected_logits_path):
expected_logits_path = 'rwkv/' + expected_logits_path expected_logits_path = 'rwkv/' + expected_logits_path
with open(expected_logits_path, 'rb') as logits_file: with open(expected_logits_path, 'rb') as logits_file:
expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
with open(logits_path, 'rb') as logits_file: with open(logits_path, 'rb') as logits_file:
actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item() difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
print(f'Reference logits: {expected_logits}') print(f'Reference logits: {expected_logits}')
print(f'Actual logits: {actual_logits}') print(f'Actual logits: {actual_logits}')
print('Difference per token: %.8f' % (difference,)) print('Difference per token: %.8f' % (difference,))
assert abs(difference) <= 0.00005, 'Difference is too big' assert abs(difference) <= 0.00005, 'Difference is too big'
# Check small token amount first to avoid waiting too long before seeing that model is broken
compare_logits(tokens[:4])
compare_logits(tokens)
print() print()
print('Test passes') print('Test passes')

Binary file not shown.