Improve the prompt & fix Chinese display issue & support commands (#34)

* Update the prompt
* Fix Chinese display issue
* Remove debug code
* Support commands (#1): +reset, +gen, +i, +qq, +qa, +++, ++, +
* Run run_rnn before decode
* Deep-copy logits
* Remove extra print()
* Print a newline if max_tokens_per_generation is reached
* Fix typo in init prompt
* Update rwkv/chat_with_bot.py (Co-authored-by: Alex <saharNooby@users.noreply.github.com>)
* Refine code & type annotations
* Add comments for commands
* Support changing temp & top_p during chat
* Set default language & prompt

---------

Co-authored-by: Alex <saharNooby@users.noreply.github.com>
This commit is contained in:
parent 1be9fda248
commit ac663631e1
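One of the changes below lets sampling settings be overridden per message during chat (for example "-temp=1.2 -top_p=0.7 tell me a story"). The following is a minimal standalone sketch of that parsing with a hypothetical function name, using the same clamping ranges the script enforces; it is an illustration, not the script's own code.

def parse_overrides(msg: str, default_temp: float = 0.8, default_top_p: float = 0.5):
    # Pull "-temp=X" and "-top_p=Y" out of the message, mirroring the diff below.
    temperature, top_p = default_temp, default_top_p
    if '-temp=' in msg:
        temperature = float(msg.split('-temp=')[1].split(' ')[0])
        msg = msg.replace('-temp=' + f'{temperature:g}', '')
    if '-top_p=' in msg:
        top_p = float(msg.split('-top_p=')[1].split(' ')[0])
        msg = msg.replace('-top_p=' + f'{top_p:g}', '')
    # Clamp to the same ranges the script uses.
    temperature = min(max(temperature, 0.2), 5.0)
    top_p = max(top_p, 0.0)
    return msg.strip(), temperature, top_p

print(parse_overrides('-temp=1.2 -top_p=0.7 tell me a story'))
# ('tell me a story', 1.2, 0.7)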
@@ -4,6 +4,8 @@ import os
 import sys
 import argparse
 import pathlib
+import copy
+from typing import List
 import sampling
 import tokenizers
 import rwkv_cpp_model
@@ -11,23 +13,89 @@ import rwkv_cpp_shared_library
 
 # ======================================== Script settings ========================================
 
-# Copied from https://github.com/ggerganov/llama.cpp/blob/6e7801d08d81c931a5427bae46f00763e993f54a/prompts/chat-with-bob.txt
-prompt: str = """Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
+# Copied from https://github.com/BlinkDL/ChatRWKV/blob/9ca4cdba90efaee25cfec21a0bae72cbd48d8acd/chat.py#L92-L178
+CHAT_LANG = 'English' # English // Chinese
 
-User: Hello, Bob.
-Bob: Hello. How may I help you today?
-User: Please tell me the largest city in Europe.
-Bob: Sure. The largest city in Europe is Moscow, the capital of Russia."""
+QA_PROMPT = False # True: Q & A prompt // False: chat prompt (need large model)
 
-# No trailing space here!
-bot_message_prefix: str = 'Bob:'
-user_message_prefix: str = 'User:'
+if CHAT_LANG == 'English':
+    interface = ':'
 
-max_tokens_per_generation: int = 100
+    if QA_PROMPT:
+        user = "User"
+        bot = "Bot" # Or: 'The following is a verbose and detailed Q & A conversation of factual information.'
+        init_prompt = f'''
+The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
+
+{user}{interface} french revolution what year
+
+{bot}{interface} The French Revolution started in 1789, and lasted 10 years until 1799.
+
+{user}{interface} 3+5=?
+
+{bot}{interface} The answer is 8.
+
+{user}{interface} guess i marry who ?
+
+{bot}{interface} Only if you tell me more about yourself - what are your interests?
+
+{user}{interface} solve for a: 9-a=2
+
+{bot}{interface} The answer is a = 7, because 9 - 7 = 2.
+
+{user}{interface} what is lhc
+
+{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
+
+'''
+    else:
+        user = "Bob"
+        bot = "Alice"
+        init_prompt = f'''
+The following is a verbose detailed conversation between {user} and a young girl {bot}. {bot} is intelligent, friendly and cute. {bot} is likely to agree with {user}.
+
+{user}{interface} Hello {bot}, how are you doing?
+
+{bot}{interface} Hi {user}! Thanks, I'm fine. What about you?
+
+{user}{interface} I am very good! It's nice to see you. Would you mind me chatting with you for a while?
+
+{bot}{interface} Not at all! I'm listening.
+
+'''
+
+elif CHAT_LANG == 'Chinese':
+    interface = ":"
+    if QA_PROMPT:
+        user = "Q"
+        bot = "A"
+        init_prompt = f'''
+Expert Questions & Helpful Answers
+
+Ask Research Experts
+
+'''
+    else:
+        user = "Bob"
+        bot = "Alice"
+        init_prompt = f'''
+The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
+
+{user}{interface} what is lhc
+
+{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
+
+{user}{interface} 企鹅会飞吗
+
+{bot}{interface} 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡，而不是飞行。
+
+'''
+
+FREE_GEN_LEN: int = 100
 
 # Sampling settings.
-temperature: float = 0.8
-top_p: float = 0.5
+GEN_TEMP: float = 0.8 # It could be a good idea to increase temp when top_p is low
+GEN_TOP_P: float = 0.5 # Reduce top_p (to 0.5, 0.2, 0.1 etc.) for better Q&A accuracy (and less diversity)
 
 # =================================================================================================
 
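For reference, with the defaults above (CHAT_LANG = 'English', QA_PROMPT = False) the prompt variables resolve to user = 'Bob', bot = 'Alice', interface = ':'. A small illustration, not part of the script, of how each new chat turn is framed in that same shape:

user, bot, interface = 'Bob', 'Alice', ':'
msg = 'What is the capital of France?'
print(f"{user}{interface} {msg}\n\n{bot}{interface}")
# Bob: What is the capital of France?
#
# Alice: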
@@ -35,7 +103,7 @@ parser = argparse.ArgumentParser(description='Provide terminal-based chat interf
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
 args = parser.parse_args()
 
-assert prompt != '', 'Prompt must not be empty'
+assert init_prompt != '', 'Prompt must not be empty'
 
 print('Loading 20B tokenizer')
 tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
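The next hunk introduces run_rnn plus the save_all_stat/load_all_stat helpers, which snapshot the logits, RNN state and token history per conversation thread so commands such as +reset, + (alternate reply) and ++ (retry) can roll the model back. A minimal sketch of that checkpointing idea, with hypothetical names rather than the script's own:

import copy

checkpoints = {}

def save_checkpoint(name, logits, state, tokens):
    # Deep copies so later generation cannot mutate a saved snapshot.
    checkpoints[name] = (copy.deepcopy(logits), copy.deepcopy(state), copy.deepcopy(tokens))

def load_checkpoint(name):
    # Return fresh copies so the caller can also generate freely from a snapshot.
    logits, state, tokens = checkpoints[name]
    return copy.deepcopy(logits), copy.deepcopy(state), copy.deepcopy(tokens)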
@@ -47,44 +115,186 @@ print(f'System info: {library.rwkv_get_system_info_string()}')
 print('Loading RWKV model')
 model = rwkv_cpp_model.RWKVModel(library, args.model_path)
 
-prompt_tokens = tokenizer.encode(prompt).ids
+prompt_tokens = tokenizer.encode(init_prompt).ids
 prompt_token_count = len(prompt_tokens)
 print(f'Processing {prompt_token_count} prompt tokens, may take a while')
 
-logits, state = None, None
-
+########################################################################################################
+
+def run_rnn(tokens: List[int]):
+    global model_tokens, model_state, logits
+
+    tokens = [int(x) for x in tokens]
+    model_tokens += tokens
+
+    for token in tokens:
+        logits, model_state = model.eval(token, model_state, model_state, logits)
+
+    return logits
+
+all_state = {}
+
+def save_all_stat(thread: str, last_out):
+    n = f'{thread}'
+    all_state[n] = {}
+    all_state[n]['logits'] = copy.deepcopy(last_out)
+    all_state[n]['rnn'] = copy.deepcopy(model_state)
+    all_state[n]['token'] = copy.deepcopy(model_tokens)
+
+def load_all_stat(thread: str):
+    global model_tokens, model_state
+    n = f'{thread}'
+    model_state = copy.deepcopy(all_state[n]['rnn'])
+    model_tokens = copy.deepcopy(all_state[n]['token'])
+    return copy.deepcopy(all_state[n]['logits'])
+
+########################################################################################################
+
+model_tokens = []
+logits, model_state = None, None
 
 for token in prompt_tokens:
-    logits, state = model.eval(token, state, state, logits)
+    logits, model_state = model.eval(token, model_state, model_state, logits)
+    model_tokens.append(token)
 
+save_all_stat('chat_init', logits)
 print('\nChat initialized! Write something and press Enter.')
+save_all_stat('chat', logits)
 
 while True:
     # Read user input
-    print('> ', end='')
-    user_input = sys.stdin.readline()
+    user_input = input(f'> {user}{interface} ')
+    msg = user_input.replace('\\n','\n').strip()
 
-    # Process the input
-    new_tokens = tokenizer.encode('\n' + user_message_prefix + ' ' + user_input + '\n' + bot_message_prefix).ids
+    temperature = GEN_TEMP
+    top_p = GEN_TOP_P
+    if ("-temp=" in msg):
+        temperature = float(msg.split("-temp=")[1].split(" ")[0])
+        msg = msg.replace("-temp="+f'{temperature:g}', "")
+        # print(f"temp: {temperature}")
+    if ("-top_p=" in msg):
+        top_p = float(msg.split("-top_p=")[1].split(" ")[0])
+        msg = msg.replace("-top_p="+f'{top_p:g}', "")
+        # print(f"top_p: {top_p}")
+    if temperature <= 0.2:
+        temperature = 0.2
+    if temperature >= 5:
+        temperature = 5
+    if top_p <= 0:
+        top_p = 0
+    msg = msg.strip()
 
-    for token in new_tokens:
-        logits, state = model.eval(token, state, state, logits)
+    # + reset --> reset chat
+    if msg == '+reset':
+        logits = load_all_stat('chat_init')
+        save_all_stat('chat', logits)
+        print(f'{bot}{interface} "Chat reset."\n')
+        continue
+    elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++':
 
-    # Generate and print bot response
-    print(bot_message_prefix, end='')
+        # +gen YOUR PROMPT --> free single-round generation with any prompt. Requires Novel model.
+        if msg[:5].lower() == '+gen ':
+            new = '\n' + msg[5:].strip()
+            # print(f'### prompt ###\n[{new}]')
+            model_state = None
+            model_tokens = []
+            logits = run_rnn(tokenizer.encode(new).ids)
+            save_all_stat('gen_0', logits)
+
+        # +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model.
+        elif msg[:3].lower() == '+i ':
+            new = f'''
+Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+# Instruction:
+{msg[3:].strip()}
+
+# Response:
+'''
+            # print(f'### prompt ###\n[{new}]')
+            model_state = None
+            model_tokens = []
+            logits = run_rnn(tokenizer.encode(new).ids)
+            save_all_stat('gen_0', logits)
+
+        # +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context).
+        elif msg[:4].lower() == '+qq ':
+            new = '\nQ: ' + msg[4:].strip() + '\nA:'
+            # print(f'### prompt ###\n[{new}]')
+            model_state = None
+            model_tokens = []
+            logits = run_rnn(tokenizer.encode(new).ids)
+            save_all_stat('gen_0', logits)
+
+        # +qa YOUR QUESTION --> answer an independent question (regardless of context).
+        elif msg[:4].lower() == '+qa ':
+            logits = load_all_stat('chat_init')
+
+            real_msg = msg[4:].strip()
+            new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"
+            # print(f'### qa ###\n[{new}]')
+
+            logits = run_rnn(tokenizer.encode(new).ids)
+            save_all_stat('gen_0', logits)
+
+        # +++ --> continue last free generation (only for +gen / +i)
+        elif msg.lower() == '+++':
+            try:
+                logits = load_all_stat('gen_1')
+                save_all_stat('gen_0', logits)
+            except Exception as e:
+                print(e)
+                continue
+
+        # ++ --> retry last free generation (only for +gen / +i)
+        elif msg.lower() == '++':
+            try:
+                logits = load_all_stat('gen_0')
+            except Exception as e:
+                print(e)
+                continue
+        thread = "gen_1"
+
+    else:
+        # + --> alternate chat reply
+        if msg.lower() == '+':
+            try:
+                logits = load_all_stat('chat_pre')
+            except Exception as e:
+                print(e)
+                continue
+        # chat with bot
+        else:
+            logits = load_all_stat('chat')
+            new = f"{user}{interface} {msg}\n\n{bot}{interface}"
+            # print(f'### add ###\n[{new}]')
+            logits = run_rnn(tokenizer.encode(new).ids)
+            save_all_stat('chat_pre', logits)
+
+        thread = 'chat'
+
+    # Print bot response
+    print(f"> {bot}{interface}", end='')
+
+    decoded = ''
+    begin = len(model_tokens)
+    out_last = begin
 
-    for i in range(max_tokens_per_generation):
+    for i in range(FREE_GEN_LEN):
         token = sampling.sample_logits(logits, temperature, top_p)
 
-        decoded = tokenizer.decode([token])
-
+        logits = run_rnn([token])
+        decoded = tokenizer.decode(model_tokens[out_last:])
+        if '\ufffd' not in decoded: # avoid utf-8 display issues
+            print(decoded, end='', flush=True)
+            out_last = begin + i + 1
+
+        if '\n' in decoded:
+            if thread == 'chat':
+                send_msg = tokenizer.decode(model_tokens[begin:])
+                if '\n\n' in send_msg:
+                    send_msg = send_msg.strip()
+                    break
 
-        logits, state = model.eval(token, state, state, logits)
+
+        if '\n' not in decoded:
+            if i == FREE_GEN_LEN - 1:
+                print()
+
+    save_all_stat(thread, logits)
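The generation loop in the last hunk is what fixes the Chinese display issue: instead of printing each token's decode immediately, it decodes the tokens accumulated since the last flush and only prints once the text contains no U+FFFD replacement character, so multi-byte UTF-8 characters split across tokens are never shown half-decoded. A standalone sketch of that idea with a hypothetical helper, not the script's own function:

def emit_safely(decode, generated_tokens, out_last):
    # decode: callable mapping a list of token ids to text.
    pending = decode(generated_tokens[out_last:])
    if '\ufffd' not in pending:        # every byte sequence formed a complete character
        print(pending, end='', flush=True)
        return len(generated_tokens)   # advance the flush position
    return out_last                    # keep buffering until the character completes

# Example with a byte-level "tokenizer": each character of '你好' arrives one byte per step.
data = '你好'.encode('utf-8')
decode = lambda toks: bytes(toks).decode('utf-8', errors='replace')
out_last = 0
for i in range(1, len(data) + 1):
    out_last = emit_safely(decode, list(data[:i]), out_last)
print()  # 你好 is printed only once each character's bytes are complete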