Improve chat_with_bot.py script (#39)

Alex 2023-04-22 20:33:58 +05:00 committed by GitHub
parent 3587ff9e58
commit c736ef5411
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 151 additions and 129 deletions


@@ -1,11 +1,12 @@
 # Provides terminal-based chat interface for RWKV model.
 # Usage: python chat_with_bot.py C:\rwkv.cpp-169M.bin
+# Prompts and code adapted from https://github.com/BlinkDL/ChatRWKV/blob/9ca4cdba90efaee25cfec21a0bae72cbd48d8acd/chat.py
 import os
+import sys
 import argparse
 import pathlib
 import copy
-from typing import List
+import torch
 import sampling
 import tokenizers
 import rwkv_cpp_model
@@ -13,89 +14,97 @@ import rwkv_cpp_shared_library
 # ======================================== Script settings ========================================
-# Copied from https://github.com/BlinkDL/ChatRWKV/blob/9ca4cdba90efaee25cfec21a0bae72cbd48d8acd/chat.py#L92-L178
-CHAT_LANG = 'English' # English // Chinese
-QA_PROMPT = False # True: Q & A prompt // False: chat prompt (need large model)
-if CHAT_LANG == 'English':
-    interface = ':'
+# English, Chinese
+LANGUAGE: str = 'English'
+# True: Q&A prompt
+# False: chat prompt (you need a large model for adequate quality, 7B+)
+QA_PROMPT: bool = False
+MAX_GENERATION_LENGTH: int = 250
+# Sampling temperature. It could be a good idea to increase temperature when top_p is low.
+TEMPERATURE: float = 0.8
+# For better Q&A accuracy and less diversity, reduce top_p (to 0.5, 0.2, 0.1 etc.)
+TOP_P: float = 0.5
+if LANGUAGE == 'English':
+    separator: str = ':'
     if QA_PROMPT:
-        user = "User"
-        bot = "Bot" # Or: 'The following is a verbose and detailed Q & A conversation of factual information.'
-        init_prompt = f'''
+        user: str = 'User'
+        bot: str = 'Bot'
+        init_prompt: str = f'''
-The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
+The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and \
+polite.
-{user}{interface} french revolution what year
+{user}{separator} french revolution what year
-{bot}{interface} The French Revolution started in 1789, and lasted 10 years until 1799.
+{bot}{separator} The French Revolution started in 1789, and lasted 10 years until 1799.
-{user}{interface} 3+5=?
+{user}{separator} 3+5=?
-{bot}{interface} The answer is 8.
+{bot}{separator} The answer is 8.
-{user}{interface} guess i marry who ?
+{user}{separator} guess i marry who ?
-{bot}{interface} Only if you tell me more about yourself - what are your interests?
+{bot}{separator} Only if you tell me more about yourself - what are your interests?
-{user}{interface} solve for a: 9-a=2
+{user}{separator} solve for a: 9-a=2
-{bot}{interface} The answer is a = 7, because 9 - 7 = 2.
+{bot}{separator} The answer is a = 7, because 9 - 7 = 2.
-{user}{interface} what is lhc
+{user}{separator} what is lhc
-{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
+{bot}{separator} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
 '''
     else:
-        user = "Bob"
-        bot = "Alice"
-        init_prompt = f'''
+        user: str = 'Bob'
+        bot: str = 'Alice'
+        init_prompt: str = f'''
 The following is a verbose detailed conversation between {user} and a young girl {bot}. {bot} is intelligent, friendly and cute. {bot} is likely to agree with {user}.
-{user}{interface} Hello {bot}, how are you doing?
+{user}{separator} Hello {bot}, how are you doing?
-{bot}{interface} Hi {user}! Thanks, I'm fine. What about you?
+{bot}{separator} Hi {user}! Thanks, I'm fine. What about you?
-{user}{interface} I am very good! It's nice to see you. Would you mind me chatting with you for a while?
+{user}{separator} I am very good! It's nice to see you. Would you mind me chatting with you for a while?
-{bot}{interface} Not at all! I'm listening.
+{bot}{separator} Not at all! I'm listening.
 '''
-elif CHAT_LANG == 'Chinese':
-    interface = ":"
+elif LANGUAGE == 'Chinese':
+    separator: str = ':'
     if QA_PROMPT:
-        user = "Q"
-        bot = "A"
-        init_prompt = f'''
+        user: str = 'Q'
+        bot: str = 'A'
+        init_prompt: str = f'''
 Expert Questions & Helpful Answers
 Ask Research Experts
 '''
     else:
-        user = "Bob"
-        bot = "Alice"
-        init_prompt = f'''
+        user: str = 'Bob'
+        bot: str = 'Alice'
+        init_prompt: str = f'''
-The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
+The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and \
+polite.
-{user}{interface} what is lhc
+{user}{separator} what is lhc
-{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
+{bot}{separator} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
-{user}{interface} 企鹅会飞吗
+{user}{separator} 企鹅会飞吗
-{bot}{interface} 企鹅是不会飞的它们的翅膀主要用于游泳和平衡而不是飞行
+{bot}{separator} 企鹅是不会飞的它们的翅膀主要用于游泳和平衡而不是飞行
 '''
-FREE_GEN_LEN: int = 100
-# Sampling settings.
-GEN_TEMP: float = 0.8 # It could be a good idea to increase temp when top_p is low
-GEN_TOP_P: float = 0.5 # Reduce top_p (to 0.5, 0.2, 0.1 etc.) for better Q&A accuracy (and less diversity)
+else:
+    assert False, f'Invalid language {LANGUAGE}'
 # =================================================================================================
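The new TEMPERATURE and TOP_P settings above are handed straight to sampling.sample_logits further down. As a rough, illustrative sketch of how the two knobs interact (this is not code from the commit; sample_logits_sketch and its NumPy-based nucleus filtering are an assumption about what such a sampler typically does):

# Illustrative only: temperature + top-p (nucleus) sampling over a 1-D logits array.
import numpy as np

def sample_logits_sketch(logits: np.ndarray, temperature: float = 0.8, top_p: float = 0.5) -> int:
    # Softmax over the raw logits.
    probs = np.exp(logits - np.max(logits))
    probs /= probs.sum()

    if top_p < 1.0:
        # Keep the smallest set of tokens whose cumulative probability reaches top_p.
        sorted_probs = np.sort(probs)[::-1]
        cutoff_index = int(np.searchsorted(np.cumsum(sorted_probs), top_p))
        cutoff = sorted_probs[min(cutoff_index, len(sorted_probs) - 1)]
        probs[probs < cutoff] = 0.0

    if temperature != 1.0:
        # Lower temperature sharpens the distribution, higher temperature flattens it.
        probs = probs ** (1.0 / temperature)

    probs /= probs.sum()
    return int(np.random.choice(len(probs), p=probs))

# Example: with top_p = 0.5 this is close to greedy; raise top_p for more diversity.
print(sample_logits_sketch(np.array([2.0, 1.0, 0.5, -1.0]), temperature=0.8, top_p=0.5))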
@@ -117,79 +126,87 @@ model = rwkv_cpp_model.RWKVModel(library, args.model_path)
 prompt_tokens = tokenizer.encode(init_prompt).ids
 prompt_token_count = len(prompt_tokens)
-print(f'Processing {prompt_token_count} prompt tokens, may take a while')
 ########################################################################################################
-def run_rnn(tokens: List[int]):
+model_tokens: list[int] = []
+logits, model_state = None, None
+def process_tokens(_tokens: list[int]) -> torch.Tensor:
     global model_tokens, model_state, logits
-    tokens = [int(x) for x in tokens]
-    model_tokens += tokens
-    for token in tokens:
-        logits, model_state = model.eval(token, model_state, model_state, logits)
+    _tokens = [int(x) for x in _tokens]
+    model_tokens += _tokens
+    for _token in _tokens:
+        logits, model_state = model.eval(_token, model_state, model_state, logits)
     return logits
-all_state = {}
+state_by_thread: dict[str, dict] = {}
-def save_all_stat(thread: str, last_out):
-    n = f'{thread}'
-    all_state[n] = {}
-    all_state[n]['logits'] = copy.deepcopy(last_out)
-    all_state[n]['rnn'] = copy.deepcopy(model_state)
-    all_state[n]['token'] = copy.deepcopy(model_tokens)
+def save_thread_state(_thread: str, _logits: torch.Tensor) -> None:
+    state_by_thread[_thread] = {}
+    state_by_thread[_thread]['logits'] = copy.deepcopy(_logits)
+    state_by_thread[_thread]['rnn'] = copy.deepcopy(model_state)
+    state_by_thread[_thread]['token'] = copy.deepcopy(model_tokens)
-def load_all_stat(thread: str):
+def load_thread_state(_thread: str) -> torch.Tensor:
     global model_tokens, model_state
-    n = f'{thread}'
-    model_state = copy.deepcopy(all_state[n]['rnn'])
-    model_tokens = copy.deepcopy(all_state[n]['token'])
-    return copy.deepcopy(all_state[n]['logits'])
+    model_state = copy.deepcopy(state_by_thread[_thread]['rnn'])
+    model_tokens = copy.deepcopy(state_by_thread[_thread]['token'])
+    return copy.deepcopy(state_by_thread[_thread]['logits'])
 ########################################################################################################
-model_tokens = []
-logits, model_state = None, None
+print(f'Processing {prompt_token_count} prompt tokens, may take a while')
 for token in prompt_tokens:
     logits, model_state = model.eval(token, model_state, model_state, logits)
     model_tokens.append(token)
-save_all_stat('chat_init', logits)
-print('\nChat initialized! Write something and press Enter.')
-save_all_stat('chat', logits)
+save_thread_state('chat_init', logits)
+save_thread_state('chat', logits)
+print(f'\nChat initialized! Your name is {user}. Write something and press Enter. Use \\n to add line breaks to your message.')
 while True:
     # Read user input
-    user_input = input(f'> {user}{interface} ')
-    msg = user_input.replace('\\n','\n').strip()
+    user_input = input(f'> {user}{separator} ')
+    msg = user_input.replace('\\n', '\n').strip()
-    temperature = GEN_TEMP
-    top_p = GEN_TOP_P
-    if ("-temp=" in msg):
-        temperature = float(msg.split("-temp=")[1].split(" ")[0])
-        msg = msg.replace("-temp="+f'{temperature:g}', "")
-        # print(f"temp: {temperature}")
-    if ("-top_p=" in msg):
-        top_p = float(msg.split("-top_p=")[1].split(" ")[0])
-        msg = msg.replace("-top_p="+f'{top_p:g}', "")
-        # print(f"top_p: {top_p}")
-    if temperature <= 0.2:
-        temperature = 0.2
-    if temperature >= 5:
-        temperature = 5
-    if top_p <= 0:
-        top_p = 0
+    temperature = TEMPERATURE
+    top_p = TOP_P
+    if "-temp=" in msg:
+        temperature = float(msg.split('-temp=')[1].split(' ')[0])
+        msg = msg.replace('-temp='+f'{temperature:g}', '')
+        if temperature <= 0.2:
+            temperature = 0.2
+        if temperature >= 5:
+            temperature = 5
+    if "-top_p=" in msg:
+        top_p = float(msg.split('-top_p=')[1].split(' ')[0])
+        msg = msg.replace('-top_p='+f'{top_p:g}', '')
+        if top_p <= 0:
+            top_p = 0
     msg = msg.strip()
     # + reset --> reset chat
     if msg == '+reset':
-        logits = load_all_stat('chat_init')
-        save_all_stat('chat', logits)
-        print(f'{bot}{interface} "Chat reset."\n')
+        logits = load_thread_state('chat_init')
+        save_thread_state('chat', logits)
+        print(f'{bot}{separator} Chat reset.\n')
         continue
     elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++':
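The -temp=/-top_p= handling added in this hunk boils down to pulling an optional override out of the message, stripping it from the text, and clamping the value. A self-contained sketch of that idea (parse_overrides is a hypothetical helper, not part of the script):

# Illustrative only: parse optional "-temp=X" / "-top_p=Y" overrides out of a chat message.
def parse_overrides(msg: str, default_temp: float = 0.8, default_top_p: float = 0.5):
    temperature, top_p = default_temp, default_top_p

    if '-temp=' in msg:
        temperature = float(msg.split('-temp=')[1].split(' ')[0])
        # Like the script, this assumes the typed value round-trips through :g formatting.
        msg = msg.replace('-temp=' + f'{temperature:g}', '')
        temperature = min(max(temperature, 0.2), 5.0)

    if '-top_p=' in msg:
        top_p = float(msg.split('-top_p=')[1].split(' ')[0])
        msg = msg.replace('-top_p=' + f'{top_p:g}', '')
        top_p = max(top_p, 0.0)

    return msg.strip(), temperature, top_p

# Example: "-temp=1.2 Tell me a story" -> ('Tell me a story', 1.2, 0.5)
print(parse_overrides('-temp=1.2 Tell me a story'))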
@@ -199,8 +216,8 @@ while True:
             # print(f'### prompt ###\n[{new}]')
             model_state = None
             model_tokens = []
-            logits = run_rnn(tokenizer.encode(new).ids)
-            save_all_stat('gen_0', logits)
+            logits = process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0', logits)
         # +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model.
         elif msg[:3].lower() == '+i ':
@@ -215,8 +232,8 @@ Below is an instruction that describes a task. Write a response that appropriate
             # print(f'### prompt ###\n[{new}]')
             model_state = None
             model_tokens = []
-            logits = run_rnn(tokenizer.encode(new).ids)
-            save_all_stat('gen_0', logits)
+            logits = process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0', logits)
         # +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context).
         elif msg[:4].lower() == '+qq ':
@@ -224,25 +241,25 @@ Below is an instruction that describes a task. Write a response that appropriate
             # print(f'### prompt ###\n[{new}]')
             model_state = None
             model_tokens = []
-            logits = run_rnn(tokenizer.encode(new).ids)
-            save_all_stat('gen_0', logits)
+            logits = process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0', logits)
         # +qa YOUR QUESTION --> answer an independent question (regardless of context).
         elif msg[:4].lower() == '+qa ':
-            logits = load_all_stat('chat_init')
+            logits = load_thread_state('chat_init')
             real_msg = msg[4:].strip()
-            new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"
+            new = f"{user}{separator} {real_msg}\n\n{bot}{separator}"
             # print(f'### qa ###\n[{new}]')
-            logits = run_rnn(tokenizer.encode(new).ids)
-            save_all_stat('gen_0', logits)
+            logits = process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0', logits)
         # +++ --> continue last free generation (only for +gen / +i)
         elif msg.lower() == '+++':
             try:
-                logits = load_all_stat('gen_1')
-                save_all_stat('gen_0', logits)
+                logits = load_thread_state('gen_1')
+                save_thread_state('gen_0', logits)
             except Exception as e:
                 print(e)
                 continue
@@ -250,7 +267,7 @@ Below is an instruction that describes a task. Write a response that appropriate
         # ++ --> retry last free generation (only for +gen / +i)
         elif msg.lower() == '++':
             try:
-                logits = load_all_stat('gen_0')
+                logits = load_thread_state('gen_0')
             except Exception as e:
                 print(e)
                 continue
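The gen_0/gen_1 states used by ++ and +++ are plain named snapshots managed by the save_thread_state/load_thread_state functions from the earlier hunk. A minimal stand-alone sketch of that snapshot pattern (save_snapshot/load_snapshot are hypothetical names; the real script stores logits and the RWKV RNN state as torch tensors):

# Illustrative only: name -> deep-copied snapshot of everything needed to resume generation.
import copy
from typing import Any, Dict

snapshots: Dict[str, Dict[str, Any]] = {}

def save_snapshot(name: str, logits: Any, rnn_state: Any, tokens: list) -> None:
    # Deep copies keep later generation from mutating a saved snapshot.
    snapshots[name] = {
        'logits': copy.deepcopy(logits),
        'rnn': copy.deepcopy(rnn_state),
        'token': copy.deepcopy(tokens),
    }

def load_snapshot(name: str):
    s = snapshots[name]
    return copy.deepcopy(s['logits']), copy.deepcopy(s['rnn']), copy.deepcopy(s['token'])

# '++'-style retry: save before generating, then restore to regenerate from the same point.
save_snapshot('gen_0', logits=[0.1, 0.9], rnn_state={'layer0': [0.0]}, tokens=[1, 2, 3])
logits, rnn_state, tokens = load_snapshot('gen_0')
print(tokens)  # [1, 2, 3]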
@@ -260,41 +277,46 @@ Below is an instruction that describes a task. Write a response that appropriate
         # + --> alternate chat reply
         if msg.lower() == '+':
             try:
-                logits = load_all_stat('chat_pre')
+                logits = load_thread_state('chat_pre')
             except Exception as e:
                 print(e)
                 continue
         # chat with bot
         else:
-            logits = load_all_stat('chat')
-            new = f"{user}{interface} {msg}\n\n{bot}{interface}"
+            logits = load_thread_state('chat')
+            new = f"{user}{separator} {msg}\n\n{bot}{separator}"
             # print(f'### add ###\n[{new}]')
-            logits = run_rnn(tokenizer.encode(new).ids)
-            save_all_stat('chat_pre', logits)
+            logits = process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('chat_pre', logits)
         thread = 'chat'
     # Print bot response
-    print(f"> {bot}{interface}", end='')
+    print(f"> {bot}{separator}", end='')
-    decoded = ''
-    begin = len(model_tokens)
-    out_last = begin
+    start_index: int = len(model_tokens)
+    accumulated_tokens: list[int] = []
-    for i in range(FREE_GEN_LEN):
-        token = sampling.sample_logits(logits, temperature, top_p)
-        logits = run_rnn([token])
-        decoded = tokenizer.decode(model_tokens[out_last:])
-        if '\ufffd' not in decoded: # avoid utf-8 display issues
+    for i in range(MAX_GENERATION_LENGTH):
+        token: int = sampling.sample_logits(logits, temperature, top_p)
+        logits: torch.Tensor = process_tokens([token])
+        # Avoid UTF-8 display issues
+        accumulated_tokens += [token]
+        decoded: str = tokenizer.decode(accumulated_tokens)
+        if '\uFFFD' not in decoded:
             print(decoded, end='', flush=True)
-            out_last = begin + i + 1
+            accumulated_tokens = []
         if thread == 'chat':
-            send_msg = tokenizer.decode(model_tokens[begin:])
-            if '\n\n' in send_msg:
-                send_msg = send_msg.strip()
+            if '\n\n' in tokenizer.decode(model_tokens[start_index:]):
                 break
-        if i == FREE_GEN_LEN - 1:
+        if i == MAX_GENERATION_LENGTH - 1:
             print()
-    save_all_stat(thread, logits)
+    save_thread_state(thread, logits)
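The accumulated_tokens buffer in the last hunk exists because a single token can stop in the middle of a multi-byte UTF-8 character, in which case decoding it yields U+FFFD; the new code therefore withholds output until the buffered tokens decode cleanly. A small sketch of the same idea using raw byte chunks instead of token IDs (print_stream is a hypothetical helper, not tokenizer.decode):

# Illustrative only: buffer incoming byte chunks and print only once they decode without U+FFFD.
def print_stream(byte_chunks: list[bytes]) -> None:
    buffer = b''
    for chunk in byte_chunks:
        buffer += chunk
        decoded = buffer.decode('utf-8', errors='replace')
        if '\uFFFD' not in decoded:
            print(decoded, end='', flush=True)
            buffer = b''  # Only flush once whole characters have arrived.
    print()

# 'é' is two bytes (0xC3 0xA9); split across chunks it is printed only after both halves arrive.
print_stream([b'caf', b'\xc3', b'\xa9', b'!'])  # prints "café!"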