from flask import Flask, request, jsonify
import argparse

import sampling
import rwkv_cpp_model
import rwkv_cpp_shared_library
from rwkv_tokenizer import get_tokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, help="Path to the model file.")
parser.add_argument("--port", type=int, default=5000, help="Port to run the server on.")
args = parser.parse_args()

tokenizer = get_tokenizer('20B')[0]

print("Loading model")
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, args.model)

# Default system prompt: a few-shot transcript that primes the model for the
# "User:"/"Bot:" chat format used by the /chat endpoint.
SYSTEM = "\nThe following is a verbose and detailed conversation between an AI assistant called Bot, and a human user called User. Bot is intelligent, knowledgeable, wise and polite.\n\nUser: french revolution what year\n\nBot: The French Revolution started in 1789, and lasted 10 years until 1799.\n\nUser: 3+5=?\n\nBot: The answer is 8.\n\nUser: guess i marry who ?\n\nBot: Only if you tell me more about yourself - what are your interests?\n\nUser: solve for a: 9-a=2\n\nBot: The answer is a = 7, because 9 - 7 = 2.\n\nUser: wat is lhc\n\nBot: LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.\n\n"

app = Flask(__name__)

def complete(prompt, temperature=0.8, top_p=0.5, max_length=100):
    assert prompt != '', 'Prompt must not be empty'

    prompt_tokens = tokenizer.encode(prompt).ids

    # Feed the prompt through the model token by token to build up the state.
    init_logits, init_state = None, None
    for token in prompt_tokens:
        init_logits, init_state = model.eval(token, init_state, init_state, init_logits)

    logits, state = init_logits.clone(), init_state.clone()

    # Sample up to max_length tokens autoregressively.
    completion = []
    for _ in range(max_length):
        token = sampling.sample_logits(logits, temperature, top_p)
        completion.append(tokenizer.decode([token]))
        logits, state = model.eval(token, state, state, logits)

    return ''.join(completion)

def complete_chat(prompt, temperature=0.8, top_p=0.5, max_length=100):
    assert prompt != '', 'Prompt must not be empty'

    prompt_tokens = tokenizer.encode(prompt).ids

    # Feed the prompt through the model token by token to build up the state.
    init_logits, init_state = None, None
    for token in prompt_tokens:
        init_logits, init_state = model.eval(token, init_state, init_state, init_logits)

    logits, state = init_logits.clone(), init_state.clone()

    completion = []
    # Stop generating as soon as the model starts the next "User" turn.
    user_next = tokenizer.encode("User").ids
    user_sequence = []

    for _ in range(max_length):
        token = sampling.sample_logits(logits, temperature, top_p)

        # Keep a sliding window of the most recent tokens and compare it
        # against the token sequence for "User".
        user_sequence.append(token)
        if len(user_sequence) > len(user_next):
            user_sequence.pop(0)
        if user_sequence == user_next:
            break

        completion.append(tokenizer.decode([token]))
        logits, state = model.eval(token, state, state, logits)

    return ''.join(completion).rstrip()

@app.route("/generate", methods=["POST"])
def generate():
    print("request received")
    data = request.json

    prompt = data.get('prompt')
    temperature = data.get('temperature', 0.8)
    top_p = data.get('top_p', 0.5)
    max_length = data.get('max_length', 100)

    output = complete(prompt, temperature, top_p, max_length)
    print("output:", output)
    return jsonify({'generated_text': output})

@app.route("/chat", methods=["POST"])
def chat():
    data = request.json

    # Use the caller-supplied system prompt if present, otherwise the default.
    system_prompt = data.get('system') or SYSTEM
    prompt = system_prompt + "User: " + data.get('prompt') + '\n\n' + "Bot:"

    temperature = data.get('temperature', 0.8)
    top_p = data.get('top_p', 0.5)
    max_length = data.get('max_length', 100)

    output = complete_chat(prompt, temperature, top_p, max_length)
    output = output.lstrip(" ")
    print("output:", output)
    return jsonify({'bot_response': output})

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=args.port)