# Generates completions from RWKV model based on a prompt.
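# Usage: python generate_completions.py <model_path> [tokenizer]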

import argparse
import os
import time
import sampling
import rwkv_cpp_model
import rwkv_cpp_shared_library
from rwkv_tokenizer import get_tokenizer
from typing import List

# ======================================== Script settings ========================================

prompt: str = """# rwkv.cpp

This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml).

Besides the usual **FP32**, it supports **FP16** and **quantized INT4** inference on CPU. This project is **CPU only**."""

# How many completions to generate.
generation_count: int = 3
# Token count per single completion.
tokens_per_generation: int = 100

# Sampling settings.
temperature: float = 0.8
top_p: float = 0.5

# =================================================================================================

parser = argparse.ArgumentParser(description='Generate completions from RWKV model based on a prompt')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
parser.add_argument('tokenizer', help='Which tokenizer to use', nargs='?', type=str, default='20B')
args = parser.parse_args()

assert prompt != '', 'Prompt must not be empty'

tokenizer, tokenizer_encode = get_tokenizer(args.tokenizer)

prompt_tokens = tokenizer_encode(prompt)

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')

print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)

prompt_token_count = len(prompt_tokens)
print(f'{prompt_token_count} tokens in prompt')
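
# Evaluate the prompt once and cache the final logits and state, so every
# generation can branch off the same starting point without re-processing the prompt.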
init_logits, init_state = None, None

for token in prompt_tokens:
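    # model.eval() returns (logits, state); the previously returned tensors are
    # passed back in so they can be reused as output buffers.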
    init_logits, init_state = model.eval(token, init_state, init_state, init_logits)

for generation in range(generation_count):
    print(f'\n--- Generation {generation} ---\n')
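    # The generated continuation is printed between '[' and ']' after the prompt.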
    print(prompt, end='[')
    start = time.time()
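
    # Work on clones so sampling never mutates the cached prompt state shared
    # between generations.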
    logits, state = init_logits.clone(), init_state.clone()

    for i in range(tokens_per_generation):
        token = sampling.sample_logits(logits, temperature, top_p)

        print(tokenizer.decode([token]), end='', flush=True)

        logits, state = model.eval(token, state, state, logits)

    delay = time.time() - start
    print(']\n\nTook %.3f sec, %d ms per token' % (delay, delay / tokens_per_generation * 1000))