From c736ef5411606b529d3a74c139ee111ef1a28bb9 Mon Sep 17 00:00:00 2001
From: Alex
Date: Sat, 22 Apr 2023 20:33:58 +0500
Subject: [PATCH] Improve chat_with_bot.py script (#39)

---
 rwkv/chat_with_bot.py | 280 +++++++++++++++++++++++-------------------
 1 file changed, 151 insertions(+), 129 deletions(-)

diff --git a/rwkv/chat_with_bot.py b/rwkv/chat_with_bot.py
index 4b4673f..03fb6da 100644
--- a/rwkv/chat_with_bot.py
+++ b/rwkv/chat_with_bot.py
@@ -1,11 +1,12 @@
 # Provides terminal-based chat interface for RWKV model.
+# Usage: python chat_with_bot.py C:\rwkv.cpp-169M.bin
+# Prompts and code adapted from https://github.com/BlinkDL/ChatRWKV/blob/9ca4cdba90efaee25cfec21a0bae72cbd48d8acd/chat.py

 import os
-import sys
 import argparse
 import pathlib
 import copy
-from typing import List
+import torch
 import sampling
 import tokenizers
 import rwkv_cpp_model
@@ -13,89 +14,97 @@ import rwkv_cpp_shared_library

 # ======================================== Script settings ========================================

-# Copied from https://github.com/BlinkDL/ChatRWKV/blob/9ca4cdba90efaee25cfec21a0bae72cbd48d8acd/chat.py#L92-L178
-CHAT_LANG = 'English' # English // Chinese
+# English, Chinese
+LANGUAGE: str = 'English'

-QA_PROMPT = False # True: Q & A prompt // False: chat prompt (need large model)
+# True: Q&A prompt
+# False: chat prompt (you need a large model for adequate quality, 7B+)
+QA_PROMPT: bool = False

-if CHAT_LANG == 'English':
-    interface = ':'
+MAX_GENERATION_LENGTH: int = 250
+
+# Sampling temperature. It could be a good idea to increase temperature when top_p is low.
+TEMPERATURE: float = 0.8
+# For better Q&A accuracy and less diversity, reduce top_p (to 0.5, 0.2, 0.1 etc.)
+TOP_P: float = 0.5
+
+if LANGUAGE == 'English':
+    separator: str = ':'

     if QA_PROMPT:
-        user = "User"
-        bot = "Bot" # Or: 'The following is a verbose and detailed Q & A conversation of factual information.'
-        init_prompt = f'''
-The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
+        user: str = 'User'
+        bot: str = 'Bot'
+        init_prompt: str = f'''
+The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and \
+polite.

-{user}{interface} french revolution what year
+{user}{separator} french revolution what year

-{bot}{interface} The French Revolution started in 1789, and lasted 10 years until 1799.
+{bot}{separator} The French Revolution started in 1789, and lasted 10 years until 1799.

-{user}{interface} 3+5=?
+{user}{separator} 3+5=?

-{bot}{interface} The answer is 8.
+{bot}{separator} The answer is 8.

-{user}{interface} guess i marry who ?
+{user}{separator} guess i marry who ?

-{bot}{interface} Only if you tell me more about yourself - what are your interests?
+{bot}{separator} Only if you tell me more about yourself - what are your interests?

-{user}{interface} solve for a: 9-a=2
+{user}{separator} solve for a: 9-a=2

-{bot}{interface} The answer is a = 7, because 9 - 7 = 2.
+{bot}{separator} The answer is a = 7, because 9 - 7 = 2.

-{user}{interface} what is lhc
+{user}{separator} what is lhc

-{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
+{bot}{separator} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.

-'''
+'''
     else:
-        user = "Bob"
-        bot = "Alice"
-        init_prompt = f'''
+        user: str = 'Bob'
+        bot: str = 'Alice'
+        init_prompt: str = f'''
 The following is a verbose detailed conversation between {user} and a young girl {bot}. {bot} is intelligent, friendly and cute. {bot} is likely to agree with {user}.

-{user}{interface} Hello {bot}, how are you doing?
+{user}{separator} Hello {bot}, how are you doing?

-{bot}{interface} Hi {user}! Thanks, I'm fine. What about you?
+{bot}{separator} Hi {user}! Thanks, I'm fine. What about you?

-{user}{interface} I am very good! It's nice to see you. Would you mind me chatting with you for a while?
+{user}{separator} I am very good! It's nice to see you. Would you mind me chatting with you for a while?

-{bot}{interface} Not at all! I'm listening.
+{bot}{separator} Not at all! I'm listening.

 '''
-elif CHAT_LANG == 'Chinese':
-    interface = ":"
+elif LANGUAGE == 'Chinese':
+    separator: str = ':'
+
     if QA_PROMPT:
-        user = "Q"
-        bot = "A"
-        init_prompt = f'''
+        user: str = 'Q'
+        bot: str = 'A'
+        init_prompt: str = f'''
 Expert Questions & Helpful Answers

 Ask Research Experts

 '''
     else:
-        user = "Bob"
-        bot = "Alice"
-        init_prompt = f'''
-The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
+        user: str = 'Bob'
+        bot: str = 'Alice'
+        init_prompt: str = f'''
+The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and \
+polite.

-{user}{interface} what is lhc
+{user}{separator} what is lhc

-{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
+{bot}{separator} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.

-{user}{interface} 企鹅会飞吗
+{user}{separator} 企鹅会飞吗

-{bot}{interface} 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。
+{bot}{separator} 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。

 '''
-
-FREE_GEN_LEN: int = 100
-
-# Sampling settings.
-GEN_TEMP: float = 0.8 # It could be a good idea to increase temp when top_p is low
-GEN_TOP_P: float = 0.5 # Reduce top_p (to 0.5, 0.2, 0.1 etc.) for better Q&A accuracy (and less diversity)
+else:
+    assert False, f'Invalid language {LANGUAGE}'

 # =================================================================================================
@@ -117,79 +126,87 @@ model = rwkv_cpp_model.RWKVModel(library, args.model_path)
 prompt_tokens = tokenizer.encode(init_prompt).ids
 prompt_token_count = len(prompt_tokens)

-print(f'Processing {prompt_token_count} prompt tokens, may take a while')
-
 ########################################################################################################

-def run_rnn(tokens: List[int]):
+model_tokens: list[int] = []
+
+logits, model_state = None, None
+
+def process_tokens(_tokens: list[int]) -> torch.Tensor:
     global model_tokens, model_state, logits

-    tokens = [int(x) for x in tokens]
-    model_tokens += tokens
+    _tokens = [int(x) for x in _tokens]
+
+    model_tokens += _tokens
+
+    for _token in _tokens:
+        logits, model_state = model.eval(_token, model_state, model_state, logits)

-    for token in tokens:
-        logits, model_state = model.eval(token, model_state, model_state, logits)
-
     return logits

-all_state = {}
+state_by_thread: dict[str, dict] = {}

-def save_all_stat(thread: str, last_out):
-    n = f'{thread}'
-    all_state[n] = {}
-    all_state[n]['logits'] = copy.deepcopy(last_out)
-    all_state[n]['rnn'] = copy.deepcopy(model_state)
-    all_state[n]['token'] = copy.deepcopy(model_tokens)
+def save_thread_state(_thread: str, _logits: torch.Tensor) -> None:
+    state_by_thread[_thread] = {}
+    state_by_thread[_thread]['logits'] = copy.deepcopy(_logits)
+    state_by_thread[_thread]['rnn'] = copy.deepcopy(model_state)
+    state_by_thread[_thread]['token'] = copy.deepcopy(model_tokens)

-def load_all_stat(thread: str):
+def load_thread_state(_thread: str) -> torch.Tensor:
     global model_tokens, model_state

-    n = f'{thread}'
-    model_state = copy.deepcopy(all_state[n]['rnn'])
-    model_tokens = copy.deepcopy(all_state[n]['token'])
-    return copy.deepcopy(all_state[n]['logits'])
+    model_state = copy.deepcopy(state_by_thread[_thread]['rnn'])
+    model_tokens = copy.deepcopy(state_by_thread[_thread]['token'])
+    return copy.deepcopy(state_by_thread[_thread]['logits'])

 ########################################################################################################

-model_tokens = []
-logits, model_state = None, None
+print(f'Processing {prompt_token_count} prompt tokens, may take a while')

 for token in prompt_tokens:
     logits, model_state = model.eval(token, model_state, model_state, logits)
+    model_tokens.append(token)

-save_all_stat('chat_init', logits)
-print('\nChat initialized! Write something and press Enter.')
-save_all_stat('chat', logits)
+save_thread_state('chat_init', logits)
+save_thread_state('chat', logits)
+
+print(f'\nChat initialized! Your name is {user}. Write something and press Enter. Use \\n to add line breaks to your message.')

 while True:
     # Read user input
-    user_input = input(f'> {user}{interface} ')
-    msg = user_input.replace('\\n','\n').strip()
+    user_input = input(f'> {user}{separator} ')
+    msg = user_input.replace('\\n', '\n').strip()
+
+    temperature = TEMPERATURE
+    top_p = TOP_P
+
+    if "-temp=" in msg:
+        temperature = float(msg.split('-temp=')[1].split(' ')[0])
+
+        msg = msg.replace('-temp='+f'{temperature:g}', '')
+
+        if temperature <= 0.2:
+            temperature = 0.2
+
+        if temperature >= 5:
+            temperature = 5
+
+    if "-top_p=" in msg:
+        top_p = float(msg.split('-top_p=')[1].split(' ')[0])
+
+        msg = msg.replace('-top_p='+f'{top_p:g}', '')
+
+        if top_p <= 0:
+            top_p = 0

-    temperature = GEN_TEMP
-    top_p = GEN_TOP_P
-    if ("-temp=" in msg):
-        temperature = float(msg.split("-temp=")[1].split(" ")[0])
-        msg = msg.replace("-temp="+f'{temperature:g}', "")
-        # print(f"temp: {temperature}")
-    if ("-top_p=" in msg):
-        top_p = float(msg.split("-top_p=")[1].split(" ")[0])
-        msg = msg.replace("-top_p="+f'{top_p:g}', "")
-        # print(f"top_p: {top_p}")
-    if temperature <= 0.2:
-        temperature = 0.2
-    if temperature >= 5:
-        temperature = 5
-    if top_p <= 0:
-        top_p = 0
     msg = msg.strip()

     # + reset --> reset chat
     if msg == '+reset':
-        logits = load_all_stat('chat_init')
-        save_all_stat('chat', logits)
-        print(f'{bot}{interface} "Chat reset."\n')
+        logits = load_thread_state('chat_init')
+        save_thread_state('chat', logits)
+        print(f'{bot}{separator} Chat reset.\n')
         continue

     elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++':
@@ -199,8 +216,8 @@ while True:
         # print(f'### prompt ###\n[{new}]')
         model_state = None
         model_tokens = []
-        logits = run_rnn(tokenizer.encode(new).ids)
-        save_all_stat('gen_0', logits)
+        logits = process_tokens(tokenizer.encode(new).ids)
+        save_thread_state('gen_0', logits)

     # +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model.
     elif msg[:3].lower() == '+i ':
@@ -215,8 +232,8 @@ Below is an instruction that describes a task. Write a response that appropriate
         # print(f'### prompt ###\n[{new}]')
         model_state = None
         model_tokens = []
-        logits = run_rnn(tokenizer.encode(new).ids)
-        save_all_stat('gen_0', logits)
+        logits = process_tokens(tokenizer.encode(new).ids)
+        save_thread_state('gen_0', logits)

     # +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context).
     elif msg[:4].lower() == '+qq ':
@@ -224,25 +241,25 @@ Below is an instruction that describes a task. Write a response that appropriate
         # print(f'### prompt ###\n[{new}]')
         model_state = None
         model_tokens = []
-        logits = run_rnn(tokenizer.encode(new).ids)
-        save_all_stat('gen_0', logits)
+        logits = process_tokens(tokenizer.encode(new).ids)
+        save_thread_state('gen_0', logits)

     # +qa YOUR QUESTION --> answer an independent question (regardless of context).
     elif msg[:4].lower() == '+qa ':
-        logits = load_all_stat('chat_init')
+        logits = load_thread_state('chat_init')

         real_msg = msg[4:].strip()
-        new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"
+        new = f"{user}{separator} {real_msg}\n\n{bot}{separator}"
         # print(f'### qa ###\n[{new}]')
-
-        logits = run_rnn(tokenizer.encode(new).ids)
-        save_all_stat('gen_0', logits)
+
+        logits = process_tokens(tokenizer.encode(new).ids)
+        save_thread_state('gen_0', logits)

     # +++ --> continue last free generation (only for +gen / +i)
     elif msg.lower() == '+++':
         try:
-            logits = load_all_stat('gen_1')
-            save_all_stat('gen_0', logits)
+            logits = load_thread_state('gen_1')
+            save_thread_state('gen_0', logits)
         except Exception as e:
             print(e)
             continue
@@ -250,7 +267,7 @@ Below is an instruction that describes a task. Write a response that appropriate
     # ++ --> retry last free generation (only for +gen / +i)
     elif msg.lower() == '++':
         try:
-            logits = load_all_stat('gen_0')
+            logits = load_thread_state('gen_0')
         except Exception as e:
             print(e)
             continue
@@ -260,41 +277,46 @@ Below is an instruction that describes a task. Write a response that appropriate
     # + --> alternate chat reply
     if msg.lower() == '+':
         try:
-            logits = load_all_stat('chat_pre')
+            logits = load_thread_state('chat_pre')
         except Exception as e:
             print(e)
             continue
     # chat with bot
     else:
-        logits = load_all_stat('chat')
-        new = f"{user}{interface} {msg}\n\n{bot}{interface}"
+        logits = load_thread_state('chat')
+        new = f"{user}{separator} {msg}\n\n{bot}{separator}"
         # print(f'### add ###\n[{new}]')
-        logits = run_rnn(tokenizer.encode(new).ids)
-        save_all_stat('chat_pre', logits)
-
+        logits = process_tokens(tokenizer.encode(new).ids)
+        save_thread_state('chat_pre', logits)
+
     thread = 'chat'

     # Print bot response
-    print(f"> {bot}{interface}", end='')
+    print(f"> {bot}{separator}", end='')

-    decoded = ''
-    begin = len(model_tokens)
-    out_last = begin
+    start_index: int = len(model_tokens)
+    accumulated_tokens: list[int] = []

-    for i in range(FREE_GEN_LEN):
-        token = sampling.sample_logits(logits, temperature, top_p)
-        logits = run_rnn([token])
-        decoded = tokenizer.decode(model_tokens[out_last:])
-        if '\ufffd' not in decoded: # avoid utf-8 display issues
+    for i in range(MAX_GENERATION_LENGTH):
+        token: int = sampling.sample_logits(logits, temperature, top_p)
+
+        logits: torch.Tensor = process_tokens([token])
+
+        # Avoid UTF-8 display issues
+        accumulated_tokens += [token]
+
+        decoded: str = tokenizer.decode(accumulated_tokens)
+
+        if '\uFFFD' not in decoded:
             print(decoded, end='', flush=True)
-            out_last = begin + i + 1
+
+            accumulated_tokens = []

         if thread == 'chat':
-            send_msg = tokenizer.decode(model_tokens[begin:])
-            if '\n\n' in send_msg:
-                send_msg = send_msg.strip()
+            if '\n\n' in tokenizer.decode(model_tokens[start_index:]):
                 break
-        if i == FREE_GEN_LEN - 1:
+
+        if i == MAX_GENERATION_LENGTH - 1:
             print()

-    save_all_stat(thread, logits)
+    save_thread_state(thread, logits)
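
Note (illustration, not part of the patch): the new generation loop accumulates sampled tokens and only prints once tokenizer.decode() returns text free of the U+FFFD replacement character, so a multi-byte UTF-8 character split across several tokens is never shown half-decoded. A minimal standalone sketch of the same idea, with hard-coded byte chunks standing in for per-token decoder output (assumed values, for illustration only):

    # Each chunk plays the role of the bytes contributed by one sampled token.
    chunks = [b'\xe4\xbc', b'\x81\xe9\xb9\x85', b' cannot fly']

    pending = b''

    for chunk in chunks:
        pending += chunk
        decoded = pending.decode('utf-8', errors='replace')

        # Print only when the accumulated bytes form complete UTF-8 sequences,
        # i.e. decoding produced no replacement character.
        if '\uFFFD' not in decoded:
            print(decoded, end='', flush=True)
            pending = b''

    print()

Running this with any Python 3 interpreter prints "企鹅 cannot fly", even though the first chunk ends in the middle of a multi-byte character.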