Add script that measures perplexity
This commit is contained in:
parent
e04baa032c
commit
85db23c7de
|
@ -0,0 +1,100 @@
|
|||
# Measures perplexity and per-token latency of an RWKV model on a given text file.
|
||||
# Perplexity is defined here as exp() of average cross-entropy loss.
|
||||
# Usage: python measure_pexplexity.py C:\rwkv.cpp-169M.bin C:\text.txt 1024
|
||||
|
||||
import os
|
||||
import time
|
||||
import pathlib
|
||||
import argparse
|
||||
import tokenizers
|
||||
import torch
|
||||
import rwkv_cpp_model
|
||||
import rwkv_cpp_shared_library
|
||||
from typing import List
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
|
||||
parser.add_argument('model_path', help='Path to model checkpoint file')
|
||||
parser.add_argument('text_path', help='Path to text file in UTF-8 encoding')
|
||||
parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int, default=1024)
|
||||
return parser.parse_args()
|
||||
|
||||
args = parse_args()
|
||||
|
||||
# ---
|
||||
|
||||
print('Loading 20B tokenizer')
|
||||
tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
|
||||
tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
|
||||
|
||||
print('Loading text')
|
||||
text: str = open(args.text_path, encoding='utf-8').read()
|
||||
tokens: List[int] = tokenizer.encode(text).ids
|
||||
token_count: int = len(tokens)
|
||||
print(f'{token_count} tokens in the text')
|
||||
|
||||
assert token_count - args.ignore_first_n_tokens > 1, 'Need at least 2 tokens for evaluation'
|
||||
|
||||
# ---
|
||||
|
||||
def format_loss(loss: torch.Tensor) -> str:
|
||||
return str(['%.3f' % (loss[i].item(),) for i in range(len(loss))]).replace('\'', '')[1:-1]
|
||||
|
||||
def format_loss_with_perplexity(loss: torch.Tensor) -> str:
|
||||
return f'loss [{format_loss(loss)}], perplexity {"%.3f" % (torch.exp(loss[0]).item(),)}'
|
||||
|
||||
# ---
|
||||
|
||||
model: rwkv_cpp_model.RWKVModel = rwkv_cpp_model.RWKVModel(
|
||||
rwkv_cpp_shared_library.load_rwkv_shared_library(),
|
||||
args.model_path
|
||||
)
|
||||
|
||||
logits, state = None, None
|
||||
|
||||
loss_sum: torch.Tensor = torch.tensor([0.0])
|
||||
loss_count: int = 0
|
||||
|
||||
start: float = time.time()
|
||||
|
||||
run_count: int = token_count - 1
|
||||
|
||||
for i in range(run_count):
|
||||
token: int = tokens[i]
|
||||
target: int = tokens[i + 1]
|
||||
|
||||
logits, state = model.eval(token, state, state, logits)
|
||||
|
||||
if args.ignore_first_n_tokens == 0 or i + 1 >= args.ignore_first_n_tokens:
|
||||
losses = torch.tensor([
|
||||
torch.nn.functional.cross_entropy(logits, torch.tensor(target, dtype=torch.long), reduction='none').item()
|
||||
])
|
||||
|
||||
loss_sum += losses
|
||||
loss_count += 1
|
||||
|
||||
if i % 10 == 0:
|
||||
avg_loss_so_far = loss_sum / loss_count
|
||||
|
||||
duration: float = time.time() - start
|
||||
duration_per_token: float = duration / (i + 1)
|
||||
runs_remaining: int = run_count - i - 1
|
||||
duration_remaining: int = int(runs_remaining * duration_per_token)
|
||||
|
||||
print(f'Token #{i}/{token_count}, '
|
||||
f'{int(100.0 * i / token_count)}%, '
|
||||
f'ETA {duration_remaining // 60} m {duration_remaining % 60} s', end='')
|
||||
|
||||
if loss_count > 0:
|
||||
print(f', averages so far: {format_loss_with_perplexity(avg_loss_so_far)}')
|
||||
else:
|
||||
print()
|
||||
|
||||
print()
|
||||
print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token')
|
||||
|
||||
print()
|
||||
print(f'Model: {os.path.basename(args.model_path)}, '
|
||||
f'data: {os.path.basename(args.text_path)} with {token_count} tokens, '
|
||||
f'skipped {args.ignore_first_n_tokens} tokens, '
|
||||
f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}')
|
Loading…
Reference in New Issue