# Provides terminal-based chat interface for RWKV model.

import os
import sys
import argparse
import sampling
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library
from pathlib import Path

# ======================================== Script settings ========================================

# Copied from https://github.com/ggerganov/llama.cpp/blob/6e7801d08d81c931a5427bae46f00763e993f54a/prompts/chat-with-bob.txt
prompt: str = """Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.

User: Hello, Bob.
Bob: Hello. How may I help you today?
User: Please tell me the largest city in Europe.
Bob: Sure. The largest city in Europe is Moscow, the capital of Russia."""

# No trailing space here!
bot_message_prefix: str = 'Bob:'
user_message_prefix: str = 'User:'

max_tokens_per_generation: int = 100

# Sampling settings.
temperature: float = 0.8
top_p: float = 0.5

# =================================================================================================
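
# The two sampling settings above are consumed by sampling.sample_logits() in
# the generation loop below. For reference, here is a minimal sketch of what a
# temperature + top-p (nucleus) sampler of that shape typically does. It is
# illustrative only, is never called in this script, and the actual
# implementation in sampling.py may differ in detail.
def _reference_sample_logits(logits, temperature: float, top_p: float) -> int:
    import numpy as np

    # Temperature rescales the logits: values below 1.0 sharpen the
    # distribution (more deterministic), values above 1.0 flatten it.
    scaled = np.asarray(logits, dtype=np.float64) / temperature

    # Numerically stable softmax.
    probs = np.exp(scaled - np.max(scaled))
    probs /= probs.sum()

    # Nucleus sampling: keep the smallest set of highest-probability tokens
    # whose cumulative probability reaches top_p, then sample from that set.
    sorted_ids = np.argsort(probs)[::-1]
    cumulative = np.cumsum(probs[sorted_ids])
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    nucleus_ids = sorted_ids[:cutoff]
    nucleus_probs = probs[nucleus_ids] / probs[nucleus_ids].sum()

    return int(np.random.choice(nucleus_ids, p=nucleus_probs))
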
parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
args = parser.parse_args()
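# Example invocation (script and model file names are illustrative; the .bin
# file is a model converted to ggml format by rwkv.cpp's conversion scripts):
#
#   python chat_with_bot.py ~/models/rwkv-4-pile-169m-q4_0.bin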

assert prompt != '', 'Prompt must not be empty'
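# '20B' refers to the GPT-NeoX-20B tokenizer; RWKV Pile models share its BPE
# vocabulary. The tokenizer JSON is expected to sit next to this script.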
print('Loading 20B tokenizer')
tokenizer_path = Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
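# load_rwkv_shared_library() locates the compiled rwkv.cpp binary (librwkv.so,
# librwkv.dylib or rwkv.dll, depending on the platform), which must have been
# built beforehand.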
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')
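# Constructing RWKVModel presumably wraps the C API's rwkv_init_from_file(),
# loading the ggml model and keeping a context for the eval() calls below.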
print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)
prompt_tokens = tokenizer.encode(prompt).ids
prompt_token_count = len(prompt_tokens)
print(f'Processing {prompt_token_count} prompt tokens, may take a while')
logits, state = None, None
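# RWKV is an RNN: all context lives in a fixed-size state tensor rather than a
# growing attention cache. state=None means "start from an empty context".
# model.eval(token, state_in, state_out, logits_out) returns (logits, state);
# passing the previous tensors as the *_out arguments lets the library reuse
# the buffers instead of allocating new ones on every token.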
for token in prompt_tokens:
    logits, state = model.eval(token, state, state, logits)
print('\nChat initialized! Write something and press Enter.')
while True:
    # Read user input. Flush so the '> ' prompt appears before the blocking read.
    print('> ', end='', flush=True)
    # Strip the trailing newline so the transcript keeps the single-newline
    # turn format of the prompt above.
    user_input = sys.stdin.readline().rstrip('\n')
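    # The state already encodes the prompt and the conversation so far, so only
    # the newly typed turn (plus the 'Bob:' prefix that cues the model to
    # answer) needs to be fed in.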
    # Process the input
    new_tokens = tokenizer.encode('\n' + user_message_prefix + ' ' + user_input + '\n' + bot_message_prefix).ids
    for token in new_tokens:
        logits, state = model.eval(token, state, state, logits)
    # Generate and print bot response. Flush so the prefix shows immediately.
    print(bot_message_prefix, end='', flush=True)

    decoded = ''
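    # Sample one token at a time until the model emits a newline (the end of
    # Bob's chat line) or max_tokens_per_generation is reached; in the latter
    # case the final print() below ends the line manually. Note that the
    # terminating newline token is never fed into the state; the '\n' prepended
    # to the next user turn compensates for that.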
    for _ in range(max_tokens_per_generation):
        token = sampling.sample_logits(logits, temperature, top_p)

        decoded = tokenizer.decode([token])

        # Flush so each token appears as soon as it is generated.
        print(decoded, end='', flush=True)

        if '\n' in decoded:
            break

        logits, state = model.eval(token, state, state, logits)
    if '\n' not in decoded:
        print()