Improve chat_with_bot.py script (#39)

commit c736ef5411 (parent 3587ff9e58)
@@ -1,11 +1,12 @@
# Provides terminal-based chat interface for RWKV model.
+# Usage: python chat_with_bot.py C:\rwkv.cpp-169M.bin
+# Prompts and code adapted from https://github.com/BlinkDL/ChatRWKV/blob/9ca4cdba90efaee25cfec21a0bae72cbd48d8acd/chat.py

import os
-import sys
import argparse
import pathlib
import copy
-from typing import List
+import torch
import sampling
import tokenizers
import rwkv_cpp_model
@@ -13,89 +14,97 @@ import rwkv_cpp_shared_library

# ======================================== Script settings ========================================

-# Copied from https://github.com/BlinkDL/ChatRWKV/blob/9ca4cdba90efaee25cfec21a0bae72cbd48d8acd/chat.py#L92-L178
-CHAT_LANG = 'English' # English // Chinese
+# English, Chinese
+LANGUAGE: str = 'English'

-QA_PROMPT = False # True: Q & A prompt // False: chat prompt (need large model)
+# True: Q&A prompt
+# False: chat prompt (you need a large model for adequate quality, 7B+)
+QA_PROMPT: bool = False

-if CHAT_LANG == 'English':
-    interface = ':'
+MAX_GENERATION_LENGTH: int = 250
+
+# Sampling temperature. It could be a good idea to increase temperature when top_p is low.
+TEMPERATURE: float = 0.8
+# For better Q&A accuracy and less diversity, reduce top_p (to 0.5, 0.2, 0.1 etc.)
+TOP_P: float = 0.5

+if LANGUAGE == 'English':
+    separator: str = ':'

    if QA_PROMPT:
-        user = "User"
-        bot = "Bot" # Or: 'The following is a verbose and detailed Q & A conversation of factual information.'
-        init_prompt = f'''
-The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
+        user: str = 'User'
+        bot: str = 'Bot'
+        init_prompt: str = f'''
+The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and \
+polite.

-{user}{interface} french revolution what year
+{user}{separator} french revolution what year

-{bot}{interface} The French Revolution started in 1789, and lasted 10 years until 1799.
+{bot}{separator} The French Revolution started in 1789, and lasted 10 years until 1799.

-{user}{interface} 3+5=?
+{user}{separator} 3+5=?

-{bot}{interface} The answer is 8.
+{bot}{separator} The answer is 8.

-{user}{interface} guess i marry who ?
+{user}{separator} guess i marry who ?

-{bot}{interface} Only if you tell me more about yourself - what are your interests?
+{bot}{separator} Only if you tell me more about yourself - what are your interests?

-{user}{interface} solve for a: 9-a=2
+{user}{separator} solve for a: 9-a=2

-{bot}{interface} The answer is a = 7, because 9 - 7 = 2.
+{bot}{separator} The answer is a = 7, because 9 - 7 = 2.

-{user}{interface} what is lhc
+{user}{separator} what is lhc

-{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
+{bot}{separator} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.

'''
    else:
-        user = "Bob"
-        bot = "Alice"
-        init_prompt = f'''
+        user: str = 'Bob'
+        bot: str = 'Alice'
+        init_prompt: str = f'''
The following is a verbose detailed conversation between {user} and a young girl {bot}. {bot} is intelligent, friendly and cute. {bot} is likely to agree with {user}.

-{user}{interface} Hello {bot}, how are you doing?
+{user}{separator} Hello {bot}, how are you doing?

-{bot}{interface} Hi {user}! Thanks, I'm fine. What about you?
+{bot}{separator} Hi {user}! Thanks, I'm fine. What about you?

-{user}{interface} I am very good! It's nice to see you. Would you mind me chatting with you for a while?
+{user}{separator} I am very good! It's nice to see you. Would you mind me chatting with you for a while?

-{bot}{interface} Not at all! I'm listening.
+{bot}{separator} Not at all! I'm listening.

'''

-elif CHAT_LANG == 'Chinese':
-    interface = ":"
+elif LANGUAGE == 'Chinese':
+    separator: str = ':'

    if QA_PROMPT:
-        user = "Q"
-        bot = "A"
-        init_prompt = f'''
+        user: str = 'Q'
+        bot: str = 'A'
+        init_prompt: str = f'''
Expert Questions & Helpful Answers

Ask Research Experts

'''
    else:
-        user = "Bob"
-        bot = "Alice"
-        init_prompt = f'''
-The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
+        user: str = 'Bob'
+        bot: str = 'Alice'
+        init_prompt: str = f'''
+The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and \
+polite.

-{user}{interface} what is lhc
+{user}{separator} what is lhc

-{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
+{bot}{separator} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.

-{user}{interface} 企鹅会飞吗
+{user}{separator} 企鹅会飞吗

-{bot}{interface} 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。
+{bot}{separator} 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。

'''

-FREE_GEN_LEN: int = 100
+else:
+    assert False, f'Invalid language {LANGUAGE}'

-# Sampling settings.
-GEN_TEMP: float = 0.8 # It could be a good idea to increase temp when top_p is low
-GEN_TOP_P: float = 0.5 # Reduce top_p (to 0.5, 0.2, 0.1 etc.) for better Q&A accuracy (and less diversity)

# =================================================================================================

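Aside: TEMPERATURE and TOP_P above feed sampling.sample_logits in the generation loop later in this file. sampling.py is not part of this diff, so the sketch below is only an illustration of standard temperature plus top-p (nucleus) sampling; the function name and internals are assumptions, not the repository's actual code.

import numpy as np

def sample_logits_sketch(logits: np.ndarray, temperature: float = 0.8, top_p: float = 0.5) -> int:
    # Illustrative only; the real implementation lives in sampling.py.
    # Softmax over the raw logits.
    probs = np.exp(logits - np.max(logits))
    probs /= probs.sum()

    # Top-p (nucleus): keep the smallest set of tokens whose cumulative
    # probability reaches top_p, zero out the rest.
    if top_p < 1.0:
        sorted_probs = np.sort(probs)[::-1]
        cutoff = sorted_probs[np.argmax(np.cumsum(sorted_probs) >= top_p)]
        probs[probs < cutoff] = 0.0

    # Temperature: values below 1 sharpen the distribution, above 1 flatten it.
    if temperature != 1.0:
        probs = probs ** (1.0 / temperature)

    probs /= probs.sum()
    return int(np.random.choice(len(probs), p=probs))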
@@ -117,79 +126,87 @@ model = rwkv_cpp_model.RWKVModel(library, args.model_path)

prompt_tokens = tokenizer.encode(init_prompt).ids
prompt_token_count = len(prompt_tokens)
-print(f'Processing {prompt_token_count} prompt tokens, may take a while')


########################################################################################################

-def run_rnn(tokens: List[int]):
+model_tokens: list[int] = []
+
+logits, model_state = None, None
+
+
+def process_tokens(_tokens: list[int]) -> torch.Tensor:
    global model_tokens, model_state, logits

-    tokens = [int(x) for x in tokens]
-    model_tokens += tokens
+    _tokens = [int(x) for x in _tokens]
+
+    model_tokens += _tokens

-    for token in tokens:
-        logits, model_state = model.eval(token, model_state, model_state, logits)
+    for _token in _tokens:
+        logits, model_state = model.eval(_token, model_state, model_state, logits)

    return logits

-all_state = {}
+state_by_thread: dict[str, dict] = {}

-def save_all_stat(thread: str, last_out):
-    n = f'{thread}'
-    all_state[n] = {}
-    all_state[n]['logits'] = copy.deepcopy(last_out)
-    all_state[n]['rnn'] = copy.deepcopy(model_state)
-    all_state[n]['token'] = copy.deepcopy(model_tokens)
+def save_thread_state(_thread: str, _logits: torch.Tensor) -> None:
+    state_by_thread[_thread] = {}
+    state_by_thread[_thread]['logits'] = copy.deepcopy(_logits)
+    state_by_thread[_thread]['rnn'] = copy.deepcopy(model_state)
+    state_by_thread[_thread]['token'] = copy.deepcopy(model_tokens)

-def load_all_stat(thread: str):
+def load_thread_state(_thread: str) -> torch.Tensor:
    global model_tokens, model_state
-    n = f'{thread}'
-    model_state = copy.deepcopy(all_state[n]['rnn'])
-    model_tokens = copy.deepcopy(all_state[n]['token'])
-    return copy.deepcopy(all_state[n]['logits'])
+    model_state = copy.deepcopy(state_by_thread[_thread]['rnn'])
+    model_tokens = copy.deepcopy(state_by_thread[_thread]['token'])
+    return copy.deepcopy(state_by_thread[_thread]['logits'])

########################################################################################################

-model_tokens = []
-logits, model_state = None, None
+print(f'Processing {prompt_token_count} prompt tokens, may take a while')

for token in prompt_tokens:
    logits, model_state = model.eval(token, model_state, model_state, logits)

    model_tokens.append(token)

-save_all_stat('chat_init', logits)
-print('\nChat initialized! Write something and press Enter.')
-save_all_stat('chat', logits)
+save_thread_state('chat_init', logits)
+save_thread_state('chat', logits)
+
+print(f'\nChat initialized! Your name is {user}. Write something and press Enter. Use \\n to add line breaks to your message.')

while True:
    # Read user input
-    user_input = input(f'> {user}{interface} ')
-    msg = user_input.replace('\\n','\n').strip()
+    user_input = input(f'> {user}{separator} ')
+    msg = user_input.replace('\\n', '\n').strip()

+    temperature = TEMPERATURE
+    top_p = TOP_P
+
+    if "-temp=" in msg:
+        temperature = float(msg.split('-temp=')[1].split(' ')[0])
+
+        msg = msg.replace('-temp='+f'{temperature:g}', '')
+
+        if temperature <= 0.2:
+            temperature = 0.2
+
+        if temperature >= 5:
+            temperature = 5
+
+    if "-top_p=" in msg:
+        top_p = float(msg.split('-top_p=')[1].split(' ')[0])
+
+        msg = msg.replace('-top_p='+f'{top_p:g}', '')
+
+        if top_p <= 0:
+            top_p = 0
+
-    temperature = GEN_TEMP
-    top_p = GEN_TOP_P
-    if ("-temp=" in msg):
-        temperature = float(msg.split("-temp=")[1].split(" ")[0])
-        msg = msg.replace("-temp="+f'{temperature:g}', "")
-        # print(f"temp: {temperature}")
-    if ("-top_p=" in msg):
-        top_p = float(msg.split("-top_p=")[1].split(" ")[0])
-        msg = msg.replace("-top_p="+f'{top_p:g}', "")
-        # print(f"top_p: {top_p}")
-    if temperature <= 0.2:
-        temperature = 0.2
-    if temperature >= 5:
-        temperature = 5
-    if top_p <= 0:
-        top_p = 0
    msg = msg.strip()

    # + reset --> reset chat
    if msg == '+reset':
-        logits = load_all_stat('chat_init')
-        save_all_stat('chat', logits)
-        print(f'{bot}{interface} "Chat reset."\n')
+        logits = load_thread_state('chat_init')
+        save_thread_state('chat', logits)
+        print(f'{bot}{separator} Chat reset.\n')
        continue
    elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++':

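Aside: save_thread_state and load_thread_state (renamed from save_all_stat / load_all_stat) are the backbone of the +reset, +, ++ and +++ commands: each named thread, such as 'chat_init', 'chat', 'chat_pre' or 'gen_0', is a deep-copied snapshot of logits, RNN state and token history that generation can later resume from. A self-contained sketch of the same pattern, with illustrative names:

import copy
from typing import Any

state_by_thread: dict[str, dict] = {}

def save_state(thread: str, logits: Any, rnn_state: Any, tokens: list[int]) -> None:
    # Deep copies so that continuing generation cannot mutate the snapshot.
    state_by_thread[thread] = {
        'logits': copy.deepcopy(logits),
        'rnn': copy.deepcopy(rnn_state),
        'token': copy.deepcopy(tokens),
    }

def load_state(thread: str) -> tuple:
    snapshot = state_by_thread[thread]
    return (copy.deepcopy(snapshot['logits']),
            copy.deepcopy(snapshot['rnn']),
            copy.deepcopy(snapshot['token']))

# '+reset' boils down to: restore the post-prompt snapshot, make it the new 'chat'.
save_state('chat_init', logits=None, rnn_state=None, tokens=[])
logits, rnn_state, tokens = load_state('chat_init')
save_state('chat', logits, rnn_state, tokens)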
@@ -199,8 +216,8 @@ while True:
        # print(f'### prompt ###\n[{new}]')
        model_state = None
        model_tokens = []
-        logits = run_rnn(tokenizer.encode(new).ids)
-        save_all_stat('gen_0', logits)
+        logits = process_tokens(tokenizer.encode(new).ids)
+        save_thread_state('gen_0', logits)

    # +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model.
    elif msg[:3].lower() == '+i ':
@@ -215,8 +232,8 @@ Below is an instruction that describes a task. Write a response that appropriate
        # print(f'### prompt ###\n[{new}]')
        model_state = None
        model_tokens = []
-        logits = run_rnn(tokenizer.encode(new).ids)
-        save_all_stat('gen_0', logits)
+        logits = process_tokens(tokenizer.encode(new).ids)
+        save_thread_state('gen_0', logits)

    # +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context).
    elif msg[:4].lower() == '+qq ':
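Aside: the truncated hunk context above ('Below is an instruction that describes a task. Write a response that appropriate...') belongs to the +i prompt template, which sits in an unchanged part of the file. Raven models are tuned on Alpaca-style single-round prompts, so the template plausibly has the shape sketched below; the exact wording is an assumption, not taken from this diff.

def build_instruct_prompt(instruction: str) -> str:
    # Hypothetical reconstruction of an Alpaca-style template; the real one
    # is in the unchanged part of the file and is not shown in this diff.
    return (
        'Below is an instruction that describes a task. '
        'Write a response that appropriately completes the request.\n\n'
        f'# Instruction:\n{instruction.strip()}\n\n'
        '# Response:\n'
    )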
@@ -224,25 +241,25 @@ Below is an instruction that describes a task. Write a response that appropriate
        # print(f'### prompt ###\n[{new}]')
        model_state = None
        model_tokens = []
-        logits = run_rnn(tokenizer.encode(new).ids)
-        save_all_stat('gen_0', logits)
+        logits = process_tokens(tokenizer.encode(new).ids)
+        save_thread_state('gen_0', logits)

    # +qa YOUR QUESTION --> answer an independent question (regardless of context).
    elif msg[:4].lower() == '+qa ':
-        logits = load_all_stat('chat_init')
+        logits = load_thread_state('chat_init')

        real_msg = msg[4:].strip()
-        new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"
+        new = f"{user}{separator} {real_msg}\n\n{bot}{separator}"
        # print(f'### qa ###\n[{new}]')

-        logits = run_rnn(tokenizer.encode(new).ids)
-        save_all_stat('gen_0', logits)
+        logits = process_tokens(tokenizer.encode(new).ids)
+        save_thread_state('gen_0', logits)

    # +++ --> continue last free generation (only for +gen / +i)
    elif msg.lower() == '+++':
        try:
-            logits = load_all_stat('gen_1')
-            save_all_stat('gen_0', logits)
+            logits = load_thread_state('gen_1')
+            save_thread_state('gen_0', logits)
        except Exception as e:
            print(e)
            continue
@@ -250,7 +267,7 @@ Below is an instruction that describes a task. Write a response that appropriate
    # ++ --> retry last free generation (only for +gen / +i)
    elif msg.lower() == '++':
        try:
-            logits = load_all_stat('gen_0')
+            logits = load_thread_state('gen_0')
        except Exception as e:
            print(e)
            continue
@@ -260,41 +277,46 @@ Below is an instruction that describes a task. Write a response that appropriate
    # + --> alternate chat reply
    if msg.lower() == '+':
        try:
-            logits = load_all_stat('chat_pre')
+            logits = load_thread_state('chat_pre')
        except Exception as e:
            print(e)
            continue
    # chat with bot
    else:
-        logits = load_all_stat('chat')
-        new = f"{user}{interface} {msg}\n\n{bot}{interface}"
+        logits = load_thread_state('chat')
+        new = f"{user}{separator} {msg}\n\n{bot}{separator}"
        # print(f'### add ###\n[{new}]')
-        logits = run_rnn(tokenizer.encode(new).ids)
-        save_all_stat('chat_pre', logits)
+        logits = process_tokens(tokenizer.encode(new).ids)
+        save_thread_state('chat_pre', logits)

    thread = 'chat'

    # Print bot response
-    print(f"> {bot}{interface}", end='')
+    print(f"> {bot}{separator}", end='')

-    decoded = ''
-    begin = len(model_tokens)
-    out_last = begin
+    start_index: int = len(model_tokens)
+    accumulated_tokens: list[int] = []

-    for i in range(FREE_GEN_LEN):
-        token = sampling.sample_logits(logits, temperature, top_p)
-        logits = run_rnn([token])
-        decoded = tokenizer.decode(model_tokens[out_last:])
-        if '\ufffd' not in decoded: # avoid utf-8 display issues
+    for i in range(MAX_GENERATION_LENGTH):
+        token: int = sampling.sample_logits(logits, temperature, top_p)
+        logits: torch.Tensor = process_tokens([token])
+
+        # Avoid UTF-8 display issues
+        accumulated_tokens += [token]
+
+        decoded: str = tokenizer.decode(accumulated_tokens)
+
+        if '\uFFFD' not in decoded:
            print(decoded, end='', flush=True)
-            out_last = begin + i + 1
+
+            accumulated_tokens = []

        if thread == 'chat':
-            send_msg = tokenizer.decode(model_tokens[begin:])
-            if '\n\n' in send_msg:
-                send_msg = send_msg.strip()
+            if '\n\n' in tokenizer.decode(model_tokens[start_index:]):
                break

-        if i == FREE_GEN_LEN - 1:
+        if i == MAX_GENERATION_LENGTH - 1:
            print()

-    save_all_stat(thread, logits)
+    save_thread_state(thread, logits)
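Aside: the accumulated_tokens change in the hunk above fixes streaming of multi-byte output. A single Unicode character can span several tokens, and decoding a partial character yields the replacement character U+FFFD, so the loop now buffers tokens and prints only once the buffer decodes cleanly. A minimal sketch of that logic, with decode standing in for tokenizer.decode:

def emit_token(token: int, buffer: list[int], decode) -> list[int]:
    # Returns the new buffer: empty if the text was printed, grown otherwise.
    buffer = buffer + [token]
    decoded = decode(buffer)

    if '\uFFFD' not in decoded:
        # The buffer decodes to complete characters: print and start over.
        print(decoded, end='', flush=True)
        return []

    # Partial character: keep buffering.
    return buffer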