diff --git a/rwkv/measure_pexplexity.py b/rwkv/measure_pexplexity.py
new file mode 100644
index 0000000..a2a0e2c
--- /dev/null
+++ b/rwkv/measure_pexplexity.py
@@ -0,0 +1,110 @@
+# Measures perplexity and per-token latency of an RWKV model on a given text file.
+# Perplexity is defined here as exp() of average cross-entropy loss.
+# Usage: python measure_pexplexity.py C:\rwkv.cpp-169M.bin C:\text.txt 1024
+
+import os
+import time
+import pathlib
+import argparse
+import tokenizers
+import torch
+import rwkv_cpp_model
+import rwkv_cpp_shared_library
+from typing import List
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
+    parser.add_argument('model_path', help='Path to model checkpoint file')
+    parser.add_argument('text_path', help='Path to text file in UTF-8 encoding')
+    parser.add_argument('ignore_first_n_tokens', nargs='?', help='How many tokens should be skipped before loss is measured', type=int, default=1024)
+    return parser.parse_args()
+
+args = parse_args()
+
+# ---
+
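+# The tokenizer file 20B_tokenizer.json is expected to be located next to this script.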
+print('Loading 20B tokenizer')
+tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
+tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
+
+print('Loading text')
+text: str = open(args.text_path, encoding='utf-8').read()
+tokens: List[int] = tokenizer.encode(text).ids
+token_count: int = len(tokens)
+print(f'{token_count} tokens in the text')
+
+assert token_count - args.ignore_first_n_tokens > 1, 'Need at least 2 tokens for evaluation'
+
+# ---
+
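+# Helpers for printing: format_loss renders each element of a loss tensor with 3 decimals,
+# format_loss_with_perplexity additionally reports exp(loss) as perplexity.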
+def format_loss(loss: torch.Tensor) -> str:
+    return ', '.join('%.3f' % (loss[i].item(),) for i in range(len(loss)))
+
+def format_loss_with_perplexity(loss: torch.Tensor) -> str:
+    return f'loss [{format_loss(loss)}], perplexity {"%.3f" % (torch.exp(loss[0]).item(),)}'
+
+# ---
+
+model: rwkv_cpp_model.RWKVModel = rwkv_cpp_model.RWKVModel(
+    rwkv_cpp_shared_library.load_rwkv_shared_library(),
+    args.model_path
+)
+
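+# RWKV is recurrent: logits and the hidden state are carried over from one token to the next.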
+logits, state = None, None
+
+loss_sum: torch.Tensor = torch.tensor([0.0])
+loss_count: int = 0
+
+start: float = time.time()
+
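+# Each step predicts tokens[i + 1] from tokens[i], so the last token is only used as a target.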
+run_count: int = token_count - 1
+
+for i in range(run_count):
+    token: int = tokens[i]
+    target: int = tokens[i + 1]
+
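+    # Feed one token; state and logits from the previous step are passed back in so they can be reused as output buffers.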
+    logits, state = model.eval(token, state, state, logits)
+
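+    # Only measure loss once the first N tokens have been consumed as context.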
+    if args.ignore_first_n_tokens == 0 or i + 1 >= args.ignore_first_n_tokens:
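+        # Cross-entropy of the next-token prediction: logits is a vocabulary-sized vector,
+        # target is the id of the token that actually follows.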
+        losses = torch.tensor([
+            torch.nn.functional.cross_entropy(logits, torch.tensor(target, dtype=torch.long), reduction='none').item()
+        ])
+
+        loss_sum += losses
+        loss_count += 1
+
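+    # Report progress, ETA and running averages every 10 tokens.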
+    if i % 10 == 0:
+        duration: float = time.time() - start
+        duration_per_token: float = duration / (i + 1)
+        runs_remaining: int = run_count - i - 1
+        duration_remaining: int = int(runs_remaining * duration_per_token)
+
+        print(f'Token #{i}/{token_count}, '
+              f'{int(100.0 * i / token_count)}%, '
+              f'ETA {duration_remaining // 60} m {duration_remaining % 60} s', end='')
+
+        if loss_count > 0:
+            avg_loss_so_far = loss_sum / loss_count
+
+            print(f', averages so far: {format_loss_with_perplexity(avg_loss_so_far)}')
+        else:
+            print()
+
+print()
+print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token')
+
+print()
+print(f'Model: {os.path.basename(args.model_path)}, '
+      f'data: {os.path.basename(args.text_path)} with {token_count} tokens, '
+      f'skipped {args.ignore_first_n_tokens} tokens, '
+      f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}')