From 85db23c7dec5ae4f9f8af53fa0ffb1f10db480b7 Mon Sep 17 00:00:00 2001
From: saharNooby
Date: Sat, 8 Apr 2023 10:41:16 +0400
Subject: [PATCH] Add script that measures perplexity

---
 rwkv/measure_pexplexity.py | 100 +++++++++++++++++++++++++++++++++++++
 1 file changed, 100 insertions(+)
 create mode 100644 rwkv/measure_pexplexity.py

diff --git a/rwkv/measure_pexplexity.py b/rwkv/measure_pexplexity.py
new file mode 100644
index 0000000..a2a0e2c
--- /dev/null
+++ b/rwkv/measure_pexplexity.py
@@ -0,0 +1,100 @@
+# Measures perplexity and per-token latency of an RWKV model on a given text file.
+# Perplexity is defined here as exp() of average cross-entropy loss.
+# Usage: python measure_pexplexity.py C:\rwkv.cpp-169M.bin C:\text.txt 1024
+
+import os
+import time
+import pathlib
+import argparse
+import tokenizers
+import torch
+import rwkv_cpp_model
+import rwkv_cpp_shared_library
+from typing import List
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
+    parser.add_argument('model_path', help='Path to model checkpoint file')
+    parser.add_argument('text_path', help='Path to text file in UTF-8 encoding')
+    parser.add_argument('ignore_first_n_tokens', nargs='?', help='How many tokens should be skipped before loss is measured', type=int, default=1024)
+    return parser.parse_args()
+
+args = parse_args()
+
+# ---
+
+print('Loading 20B tokenizer')
+tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
+tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
+
+print('Loading text')
+text: str = open(args.text_path, encoding='utf-8').read()
+tokens: List[int] = tokenizer.encode(text).ids
+token_count: int = len(tokens)
+print(f'{token_count} tokens in the text')
+
+assert token_count - args.ignore_first_n_tokens > 1, 'Need at least 2 tokens for evaluation'
+
+# ---
+
+def format_loss(loss: torch.Tensor) -> str:
+    return str(['%.3f' % (loss[i].item(),) for i in range(len(loss))]).replace('\'', '')[1:-1]
+
+def format_loss_with_perplexity(loss: torch.Tensor) -> str:
+    return f'loss [{format_loss(loss)}], perplexity {"%.3f" % (torch.exp(loss[0]).item(),)}'
+
+# ---
+
+model: rwkv_cpp_model.RWKVModel = rwkv_cpp_model.RWKVModel(
+    rwkv_cpp_shared_library.load_rwkv_shared_library(),
+    args.model_path
+)
+
+logits, state = None, None
+
+loss_sum: torch.Tensor = torch.tensor([0.0])
+loss_count: int = 0
+
+start: float = time.time()
+
+run_count: int = token_count - 1
+
+for i in range(run_count):
+    token: int = tokens[i]
+    target: int = tokens[i + 1]
+
+    logits, state = model.eval(token, state, state, logits)
+
+    if args.ignore_first_n_tokens == 0 or i + 1 >= args.ignore_first_n_tokens:
+        losses = torch.tensor([
+            torch.nn.functional.cross_entropy(logits, torch.tensor(target, dtype=torch.long), reduction='none').item()
+        ])
+
+        loss_sum += losses
+        loss_count += 1
+
+    if i % 10 == 0:
+        avg_loss_so_far = loss_sum / max(loss_count, 1)
+
+        duration: float = time.time() - start
+        duration_per_token: float = duration / (i + 1)
+        runs_remaining: int = run_count - i - 1
+        duration_remaining: int = int(runs_remaining * duration_per_token)
+
+        print(f'Token #{i}/{token_count}, '
+              f'{int(100.0 * i / token_count)}%, '
+              f'ETA {duration_remaining // 60} m {duration_remaining % 60} s', end='')
+
+        if loss_count > 0:
+            print(f', averages so far: {format_loss_with_perplexity(avg_loss_so_far)}')
+        else:
+            print()
+
+print()
+print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token')
+
+print()
+print(f'Model: {os.path.basename(args.model_path)}, '
+      f'data: {os.path.basename(args.text_path)} with {token_count} tokens, '
+      f'skipped {args.ignore_first_n_tokens} tokens, '
+      f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}')
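
Note on the definition used above: the script's header comment defines perplexity as exp() of the average per-token cross-entropy loss, which is what the loss_sum / loss_count accumulation in the loop computes. Below is a minimal, self-contained sketch of that definition, independent of RWKV; the tensors here (logits, targets) are hypothetical random stand-ins over a 10-token vocabulary, not model output from this patch.

import torch

# Hypothetical stand-ins: raw logits for 5 predicted positions and the
# actual next token at each position. Not RWKV output.
torch.manual_seed(0)
logits = torch.randn(5, 10)
targets = torch.randint(0, 10, (5,))

# One-shot form: mean cross-entropy over all positions, then exp().
avg_loss = torch.nn.functional.cross_entropy(logits, targets, reduction='mean')
perplexity = torch.exp(avg_loss)

# Token-by-token form, mirroring the script's loss_sum / loss_count
# accumulation; both forms produce the same average.
loss_sum = 0.0
for position_logits, target in zip(logits, targets):
    loss_sum += torch.nn.functional.cross_entropy(position_logits, target).item()
assert abs(loss_sum / len(targets) - avg_loss.item()) < 1e-5

print(f'loss {avg_loss.item():.3f}, perplexity {perplexity.item():.3f}')

The per-position call relies on PyTorch accepting a 1-D logits tensor with a scalar target (supported since PyTorch 1.10), which is the same behavior the script itself depends on inside its loop.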