Add fail-fast version of the test
This commit is contained in:
parent
0fcb7c64c6
commit
16ec7a5c18
|
@ -32,46 +32,51 @@ def main() -> None:
|
||||||
627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272,
|
627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272,
|
||||||
253, 1708, 432, 253, 13862, 15]
|
253, 1708, 432, 253, 13862, 15]
|
||||||
|
|
||||||
token_count: int = len(tokens)
|
def compare_logits(tokens_subset: List[int]) -> None:
|
||||||
state_path: str = './state.bin'
|
token_count: int = len(tokens_subset)
|
||||||
logits_path: str = './logits.bin'
|
state_path: str = './state.bin'
|
||||||
|
logits_path: str = './logits.bin'
|
||||||
|
|
||||||
for i in range(token_count):
|
for i in range(token_count):
|
||||||
token: int = tokens[i]
|
token: int = tokens_subset[i]
|
||||||
|
|
||||||
print(f'{i + 1}/{token_count}')
|
print(f'{i + 1}/{token_count}')
|
||||||
|
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
[
|
[
|
||||||
args.main_executable_path,
|
args.main_executable_path,
|
||||||
args.ggml_model_path,
|
args.ggml_model_path,
|
||||||
str(token),
|
str(token),
|
||||||
# If this is the first token, let the script create a new state.
|
# If this is the first token, let the script create a new state.
|
||||||
'' if i == 0 else state_path,
|
'' if i == 0 else state_path,
|
||||||
state_path,
|
state_path,
|
||||||
logits_path
|
logits_path
|
||||||
],
|
],
|
||||||
check=True
|
check=True
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_logits_path: str = 'expected_logits_169M_20220807_8023.bin'
|
expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin'
|
||||||
|
|
||||||
if not os.path.isfile(expected_logits_path):
|
if not os.path.isfile(expected_logits_path):
|
||||||
expected_logits_path = 'rwkv/' + expected_logits_path
|
expected_logits_path = 'rwkv/' + expected_logits_path
|
||||||
|
|
||||||
with open(expected_logits_path, 'rb') as logits_file:
|
with open(expected_logits_path, 'rb') as logits_file:
|
||||||
expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
|
expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
|
||||||
|
|
||||||
with open(logits_path, 'rb') as logits_file:
|
with open(logits_path, 'rb') as logits_file:
|
||||||
actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
|
actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
|
||||||
|
|
||||||
difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
|
difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
|
||||||
|
|
||||||
print(f'Reference logits: {expected_logits}')
|
print(f'Reference logits: {expected_logits}')
|
||||||
print(f'Actual logits: {actual_logits}')
|
print(f'Actual logits: {actual_logits}')
|
||||||
print('Difference per token: %.8f' % (difference,))
|
print('Difference per token: %.8f' % (difference,))
|
||||||
|
|
||||||
assert abs(difference) <= 0.00005, 'Difference is too big'
|
assert abs(difference) <= 0.00005, 'Difference is too big'
|
||||||
|
|
||||||
|
# Check small token amount first to avoid waiting too long before seeing that model is broken
|
||||||
|
compare_logits(tokens[:4])
|
||||||
|
compare_logits(tokens)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print('Test passes')
|
print('Test passes')
|
||||||
|
|
Binary file not shown.
Loading…
Reference in New Issue