punish repetitions & break if END_OF_TEXT & decouple prompts from chat script (#37)
* punish repetitions & break if END_OF_TEXT * decouple prompts from chat_with_bot.py * improve code style * Update rwkv/chat_with_bot.py Co-authored-by: Alex <saharNooby@users.noreply.github.com> * Update rwkv/chat_with_bot.py Co-authored-by: Alex <saharNooby@users.noreply.github.com> * add types * JSON prompt --------- Co-authored-by: Alex <saharNooby@users.noreply.github.com>
This commit is contained in:
parent
06dac0f80d
commit
3621172428
|
@ -11,15 +11,23 @@ import sampling
|
|||
import tokenizers
|
||||
import rwkv_cpp_model
|
||||
import rwkv_cpp_shared_library
|
||||
import json
|
||||
|
||||
# ======================================== Script settings ========================================
|
||||
|
||||
# English, Chinese
|
||||
# English, Chinese, Japanese
|
||||
LANGUAGE: str = 'English'
|
||||
# QA: Question and Answer prompt
|
||||
# Chat: chat prompt (you need a large model for adequate quality, 7B+)
|
||||
PROMPT_TYPE: str = "Chat"
|
||||
|
||||
# True: Q&A prompt
|
||||
# False: chat prompt (you need a large model for adequate quality, 7B+)
|
||||
QA_PROMPT: bool = False
|
||||
PROMPT_FILE: str = f'./rwkv/prompt/{LANGUAGE}-{PROMPT_TYPE}.json'
|
||||
|
||||
def load_prompt(PROMPT_FILE: str):
|
||||
with open(PROMPT_FILE, 'r') as json_file:
|
||||
variables = json.load(json_file)
|
||||
user, bot, separator, prompt = variables['user'], variables['bot'], variables['separator'], variables['prompt']
|
||||
return user, bot, separator, prompt
|
||||
|
||||
MAX_GENERATION_LENGTH: int = 250
|
||||
|
||||
|
@ -27,84 +35,12 @@ MAX_GENERATION_LENGTH: int = 250
|
|||
TEMPERATURE: float = 0.8
|
||||
# For better Q&A accuracy and less diversity, reduce top_p (to 0.5, 0.2, 0.1 etc.)
|
||||
TOP_P: float = 0.5
|
||||
|
||||
if LANGUAGE == 'English':
|
||||
separator: str = ':'
|
||||
|
||||
if QA_PROMPT:
|
||||
user: str = 'User'
|
||||
bot: str = 'Bot'
|
||||
init_prompt: str = f'''
|
||||
The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and \
|
||||
polite.
|
||||
|
||||
{user}{separator} french revolution what year
|
||||
|
||||
{bot}{separator} The French Revolution started in 1789, and lasted 10 years until 1799.
|
||||
|
||||
{user}{separator} 3+5=?
|
||||
|
||||
{bot}{separator} The answer is 8.
|
||||
|
||||
{user}{separator} guess i marry who ?
|
||||
|
||||
{bot}{separator} Only if you tell me more about yourself - what are your interests?
|
||||
|
||||
{user}{separator} solve for a: 9-a=2
|
||||
|
||||
{bot}{separator} The answer is a = 7, because 9 - 7 = 2.
|
||||
|
||||
{user}{separator} what is lhc
|
||||
|
||||
{bot}{separator} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
|
||||
|
||||
'''
|
||||
else:
|
||||
user: str = 'Bob'
|
||||
bot: str = 'Alice'
|
||||
init_prompt: str = f'''
|
||||
The following is a verbose detailed conversation between {user} and a young girl {bot}. {bot} is intelligent, friendly and cute. {bot} is likely to agree with {user}.
|
||||
|
||||
{user}{separator} Hello {bot}, how are you doing?
|
||||
|
||||
{bot}{separator} Hi {user}! Thanks, I'm fine. What about you?
|
||||
|
||||
{user}{separator} I am very good! It's nice to see you. Would you mind me chatting with you for a while?
|
||||
|
||||
{bot}{separator} Not at all! I'm listening.
|
||||
|
||||
'''
|
||||
|
||||
elif LANGUAGE == 'Chinese':
|
||||
separator: str = ':'
|
||||
|
||||
if QA_PROMPT:
|
||||
user: str = 'Q'
|
||||
bot: str = 'A'
|
||||
init_prompt: str = f'''
|
||||
Expert Questions & Helpful Answers
|
||||
|
||||
Ask Research Experts
|
||||
|
||||
'''
|
||||
else:
|
||||
user: str = 'Bob'
|
||||
bot: str = 'Alice'
|
||||
init_prompt: str = f'''
|
||||
The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and \
|
||||
polite.
|
||||
|
||||
{user}{separator} what is lhc
|
||||
|
||||
{bot}{separator} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
|
||||
|
||||
{user}{separator} 企鹅会飞吗
|
||||
|
||||
{bot}{separator} 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。
|
||||
|
||||
'''
|
||||
else:
|
||||
assert False, f'Invalid language {LANGUAGE}'
|
||||
# Penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
|
||||
PRESENCE_PENALTY: float = 0.2
|
||||
# Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
|
||||
FREQUENCY_PENALTY: float = 0.2
|
||||
END_OF_LINE_TOKEN: int = 187
|
||||
END_OF_TEXT_TOKEN: int = 0
|
||||
|
||||
# =================================================================================================
|
||||
|
||||
|
@ -112,6 +48,7 @@ parser = argparse.ArgumentParser(description='Provide terminal-based chat interf
|
|||
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
|
||||
args = parser.parse_args()
|
||||
|
||||
user, bot, separator, init_prompt = load_prompt(PROMPT_FILE)
|
||||
assert init_prompt != '', 'Prompt must not be empty'
|
||||
|
||||
print('Loading 20B tokenizer')
|
||||
|
@ -133,7 +70,7 @@ model_tokens: list[int] = []
|
|||
|
||||
logits, model_state = None, None
|
||||
|
||||
def process_tokens(_tokens: list[int]) -> torch.Tensor:
|
||||
def process_tokens(_tokens: list[int], newline_adj: int = 0) -> torch.Tensor:
|
||||
global model_tokens, model_state, logits
|
||||
|
||||
_tokens = [int(x) for x in _tokens]
|
||||
|
@ -143,6 +80,8 @@ def process_tokens(_tokens: list[int]) -> torch.Tensor:
|
|||
for _token in _tokens:
|
||||
logits, model_state = model.eval(_token, model_state, model_state, logits)
|
||||
|
||||
logits[END_OF_LINE_TOKEN] += newline_adj # adjust \n probability
|
||||
|
||||
return logits
|
||||
|
||||
state_by_thread: dict[str, dict] = {}
|
||||
|
@ -163,10 +102,7 @@ def load_thread_state(_thread: str) -> torch.Tensor:
|
|||
|
||||
print(f'Processing {prompt_token_count} prompt tokens, may take a while')
|
||||
|
||||
for token in prompt_tokens:
|
||||
logits, model_state = model.eval(token, model_state, model_state, logits)
|
||||
|
||||
model_tokens.append(token)
|
||||
logits = process_tokens(tokenizer.encode(init_prompt).ids)
|
||||
|
||||
save_thread_state('chat_init', logits)
|
||||
save_thread_state('chat', logits)
|
||||
|
@ -286,7 +222,7 @@ Below is an instruction that describes a task. Write a response that appropriate
|
|||
logits = load_thread_state('chat')
|
||||
new = f"{user}{separator} {msg}\n\n{bot}{separator}"
|
||||
# print(f'### add ###\n[{new}]')
|
||||
logits = process_tokens(tokenizer.encode(new).ids)
|
||||
logits = process_tokens(tokenizer.encode(new).ids, newline_adj=-999999999)
|
||||
save_thread_state('chat_pre', logits)
|
||||
|
||||
thread = 'chat'
|
||||
|
@ -296,10 +232,19 @@ Below is an instruction that describes a task. Write a response that appropriate
|
|||
|
||||
start_index: int = len(model_tokens)
|
||||
accumulated_tokens: list[int] = []
|
||||
occurrence: dict[int, int] = {}
|
||||
|
||||
for i in range(MAX_GENERATION_LENGTH):
|
||||
for n in occurrence:
|
||||
logits[n] -= (PRESENCE_PENALTY + occurrence[n] * FREQUENCY_PENALTY)
|
||||
token: int = sampling.sample_logits(logits, temperature, top_p)
|
||||
|
||||
if token == END_OF_TEXT_TOKEN:
|
||||
print()
|
||||
break
|
||||
if token not in occurrence:
|
||||
occurrence[token] = 1
|
||||
else:
|
||||
occurrence[token] += 1
|
||||
logits: torch.Tensor = process_tokens([token])
|
||||
|
||||
# Avoid UTF-8 display issues
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
{
|
||||
"user": "Bob",
|
||||
"bot": "Alice",
|
||||
"separator": ":",
|
||||
"prompt": "\nThe following is a coherent verbose detailed conversation between a Chinese girl named Alice and her friend Bob. Alice is very intelligent, creative and friendly. Alice likes to tell Bob a lot about herself and her opinions. Alice usually gives Bob kind, helpful and informative advices.\n\nBob: lhc\n\nAlice: LHC是指大型强子对撞机(Large Hadron Collider),是世界最大最强的粒子加速器,由欧洲核子中心(CERN)在瑞士日内瓦地下建造。LHC的原理是加速质子(氢离子)并让它们相撞,让科学家研究基本粒子和它们之间的相互作用,并在2012年证实了希格斯玻色子的存在。\n\nBob: 企鹅会飞吗\n\nAlice: 企鹅是不会飞的。企鹅的翅膀短而扁平,更像是游泳时的一对桨。企鹅的身体结构和羽毛密度也更适合在水中游泳,而不是飞行。\n\n"
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
{
|
||||
"user": "Q",
|
||||
"bot": "A",
|
||||
"separator": ":",
|
||||
"prompt": "\nExpert Questions & Helpful Answers\n\nAsk Research Experts\n\n"
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
{
|
||||
"user": "Bob",
|
||||
"bot": "Alice",
|
||||
"separator": ":",
|
||||
"prompt": "\nThe following is a coherent verbose detailed conversation between a girl named Alice and her friend Bob. Alice is very intelligent, creative and friendly. Alice is unlikely to disagree with Bob, and Alice doesn't like to ask Bob questions. Alice likes to tell Bob a lot about herself and her opinions. Alice usually gives Bob kind, helpful and informative advices.\n\nBob: Hello Alice, how are you doing?\n\nAlice: Hi! Thanks, I'm fine. What about you?\n\nBob: I am fine. It's nice to see you. Look, here is a store selling tea and juice.\n\nAlice: Sure. Let's go inside. I would like to have some Mocha latte, which is my favourite!\n\nBob: What is it?\n\nAlice: Mocha latte is usually made with espresso, milk, chocolate, and frothed milk. Its flavors are frequently sweet.\n\nBob: Sounds tasty. I'll try it next time. Would you like to chat with me for a while?\n\nAlice: Of course! I'm glad to answer your questions or give helpful advices. You know, I am confident with my expertise. So please go ahead!\n\n"
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
{
|
||||
"user": "User",
|
||||
"bot": "Bot",
|
||||
"separator": ":",
|
||||
"prompt": "\nThe following is a verbose and detailed conversation between an AI assistant called Bot, and a human user called User. Bot is intelligent, knowledgeable, wise and polite.\n\nUser: french revolution what year\n\nBot: The French Revolution started in 1789, and lasted 10 years until 1799.\n\nUser: 3+5=?\n\nBot: The answer is 8.\n\nUser: guess i marry who ?\n\nBot: Only if you tell me more about yourself - what are your interests?\n\nUser: solve for a: 9-a=2\n\nBot: The answer is a = 7, because 9 - 7 = 2.\n\nUser: wat is lhc\n\nBot: LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.\n\n"
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
{
|
||||
"user": "Bob",
|
||||
"bot": "Alice",
|
||||
"separator": ":",
|
||||
"prompt": "\n以下は、Aliceという女の子とその友人Bobの間で行われた会話です。 Aliceはとても賢く、想像力があり、友好的です。 AliceはBobに反対することはなく、AliceはBobに質問するのは苦手です。 AliceはBobに自分のことや自分の意見をたくさん伝えるのが好きです。 AliceはいつもBobに親切で役に立つ、有益なアドバイスをしてくれます。\n\nBob: こんにちはAlice、調子はどうですか?\n\nAlice: こんにちは!元気ですよ。あたなはどうですか?\n\nBob: 元気ですよ。君に会えて嬉しいよ。見て、この店ではお茶とジュースが売っているよ。\n\nAlice: 本当ですね。中に入りましょう。大好きなモカラテを飲んでみたいです!\n\nBob: モカラテって何ですか?\n\nAlice: モカラテはエスプレッソ、ミルク、チョコレート、泡立てたミルクから作られた飲み物です。香りはとても甘いです。\n\nBob: それは美味しそうですね。今度飲んでみます。しばらく私とおしゃべりしてくれますか?\n\nAlice: もちろん!ご質問やアドバイスがあれば、喜んでお答えします。専門的な知識には自信がありますよ。どうぞよろしくお願いいたします!\n\n"
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
{
|
||||
"user": "User",
|
||||
"bot": "Bot",
|
||||
"separator": ":",
|
||||
"prompt": "\n以下は、Botと呼ばれるAIアシスタントとUserと呼ばれる人間との間で行われた会話です。Botは知的で、知識が豊富で、賢くて、礼儀正しいです。\n\nUser: フランス革命は何年に起きましたか?\n\nBot: フランス革命は1789年に始まり、1799年まで10年間続きました。\n\nUser: 3+5=?\n\nBot: 答えは8です。\n\nUser: 私は誰と結婚すると思いますか?\n\nBot: あなたのことをもっと教えていただけないとお答えすることができません。\n\nUser: aの値を求めてください: 9-a=2\n\nBot: a = 7です、なぜなら 9 - 7 = 2だからです。\n\nUser: lhcって何ですか?\n\nBot: LHCは、CERNが建設し、2008年に完成した高エネルギー粒子衝突型加速器です。2012年にヒッグス粒子の存在を確認するために使用されました。\n\n"
|
||||
}
|
Loading…
Reference in New Issue