rwkv.cpp/rwkv/chat_with_bot.py

# Provides a terminal-based chat interface for an RWKV model.
import os
import sys
import argparse
import pathlib
import sampling
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library
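# sampling, rwkv_cpp_model and rwkv_cpp_shared_library are local modules that
# live alongside this script.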
# ======================================== Script settings ========================================
# Copied from https://github.com/ggerganov/llama.cpp/blob/6e7801d08d81c931a5427bae46f00763e993f54a/prompts/chat-with-bob.txt
prompt: str = """Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
User: Hello, Bob.
Bob: Hello. How may I help you today?
User: Please tell me the largest city in Europe.
Bob: Sure. The largest city in Europe is Moscow, the capital of Russia."""
# No trailing space here! The space between 'Bob:' and the reply text is expected to come from the model's own tokens.
bot_message_prefix: str = 'Bob:'
user_message_prefix: str = 'User:'
max_tokens_per_generation: int = 100
# Sampling settings.
temperature: float = 0.8
top_p: float = 0.5
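# temperature below 1.0 sharpens the token distribution (more deterministic
# replies); top_p keeps only the smallest set of tokens whose cumulative
# probability exceeds top_p (nucleus sampling) before drawing.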
# =================================================================================================
parser = argparse.ArgumentParser(description='Provide a terminal-based chat interface for an RWKV model')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
args = parser.parse_args()
assert prompt != '', 'Prompt must not be empty'
print('Loading 20B tokenizer')
tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')
print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)
prompt_tokens = tokenizer.encode(prompt).ids
prompt_token_count = len(prompt_tokens)
print(f'Processing {prompt_token_count} prompt tokens, may take a while')
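# eval() returns (logits, state); passing the previous state and logits buffers
# back in as the output arguments lets the library reuse them in place rather
# than allocate new ones for every token.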
logits, state = None, None
for token in prompt_tokens:
    logits, state = model.eval(token, state, state, logits)
print('\nChat initialized! Write something and press Enter.')
while True:
    # Read user input; readline() keeps the trailing '\n', so strip it to keep
    # the prompt format to a single newline between turns.
    print('> ', end='', flush=True)
    user_input = sys.stdin.readline().rstrip('\n')
    # Process the input
    new_tokens = tokenizer.encode('\n' + user_message_prefix + ' ' + user_input + '\n' + bot_message_prefix).ids
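    # Feed the new tokens through the model; ending the encoded text on
    # bot_message_prefix means the next sampled tokens continue as Bob's reply.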
    for token in new_tokens:
        logits, state = model.eval(token, state, state, logits)
    # Generate and print bot response
    print(bot_message_prefix, end='', flush=True)

    decoded = ''
    for _ in range(max_tokens_per_generation):
        token = sampling.sample_logits(logits, temperature, top_p)
        decoded = tokenizer.decode([token])
        print(decoded, end='', flush=True)
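        # A newline ends Bob's reply. Breaking before eval() means the newline
        # token itself is never fed to the model; the '\n' prepended to the
        # next user turn supplies it instead.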
        if '\n' in decoded:
            break

        logits, state = model.eval(token, state, state, logits)
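    # If max_tokens_per_generation was reached without a newline, end the line
    # manually so the next '> ' prompt starts on a fresh line.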
    if '\n' not in decoded:
        print()