# Provides a terminal-based chat interface for an RWKV model.

import os
import sys
import argparse
import sampling
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library
from pathlib import Path

# ======================================== Script settings ========================================

# Copied from https://github.com/ggerganov/llama.cpp/blob/6e7801d08d81c931a5427bae46f00763e993f54a/prompts/chat-with-bob.txt
prompt: str = """Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.

User: Hello, Bob.
Bob: Hello. How may I help you today?
User: Please tell me the largest city in Europe.
Bob: Sure. The largest city in Europe is Moscow, the capital of Russia."""

# No trailing space here!
bot_message_prefix: str = 'Bob:'
user_message_prefix: str = 'User:'

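# Hard cap on response length: generation stops after this many tokens even if
# the model has not produced a newline to end its reply.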
max_tokens_per_generation: int = 100

# Sampling settings.
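# Lower temperature makes sampling more deterministic; top_p keeps only the
# smallest set of candidate tokens whose cumulative probability reaches top_p
# (nucleus sampling).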
temperature: float = 0.8
top_p: float = 0.5

# =================================================================================================

parser = argparse.ArgumentParser(description='Provide a terminal-based chat interface for an RWKV model')
parser.add_argument('model_path', help='Path to an RWKV model in ggml format')
args = parser.parse_args()

assert prompt != '', 'Prompt must not be empty'
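# The 20B tokenizer file is expected to sit next to this script.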
print('Loading 20B tokenizer')
tokenizer_path = Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))

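# load_rwkv_shared_library() locates a prebuilt librwkv binary for the current
# platform and loads it via ctypes.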
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')

print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)

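# Feeding the full prompt once up front warms the model state; later turns
# only need to process their own new tokens, not the whole history.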
prompt_tokens = tokenizer.encode(prompt).ids
prompt_token_count = len(prompt_tokens)
print(f'Processing {prompt_token_count} prompt tokens, may take a while')

logits, state = None, None
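# eval() consumes one token and returns next-token logits plus the updated
# model state; both may be None on the first call. Passing the previous state
# and logits back in as output buffers lets the wrapper reuse their memory.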
for token in prompt_tokens:
    logits, state = model.eval(token, state, state, logits)
print('\nChat initialized! Write something and press Enter.')
while True:
    # Read user input. Flush so the prompt appears before the blocking read;
    # strip the trailing newline so it does not break the dialog format below.
    print('> ', end='', flush=True)
    user_input = sys.stdin.readline().rstrip('\n')

    # Process the input: append this turn to the dialog in the format the
    # prompt established, ending with the bot prefix so the model replies as Bob.
    new_tokens = tokenizer.encode('\n' + user_message_prefix + ' ' + user_input + '\n' + bot_message_prefix).ids

    for token in new_tokens:
        logits, state = model.eval(token, state, state, logits)

    # Generate and print bot response
    print(bot_message_prefix, end='')

    decoded = ''

    for i in range(max_tokens_per_generation):
        token = sampling.sample_logits(logits, temperature, top_p)

        decoded = tokenizer.decode([token])

        # Flush so tokens stream to the terminal as they are generated.
        print(decoded, end='', flush=True)
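        # A newline ends Bob's reply in this dialog format. The newline token
        # itself is never evaluated; the next turn's leading '\n' restores it.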
        if '\n' in decoded:
            break

        logits, state = model.eval(token, state, state, logits)
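    # If the token limit was hit without a newline, finish the line ourselves.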
    if '\n' not in decoded:
        print()