From e9ccfc44fdf504a25f1b92333925b547dd6086fb Mon Sep 17 00:00:00 2001
From: ed
Date: Tue, 18 Jul 2023 22:27:09 +0200
Subject: [PATCH] flask server added

---
 README.md            |  20 ++++++++
 rwkv/flask_server.py | 106 +++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 126 insertions(+)
 create mode 100644 rwkv/flask_server.py

diff --git a/README.md b/README.md
index e44f3d4..bc9b3fc 100644
--- a/README.md
+++ b/README.md
@@ -7,8 +7,28 @@ This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [gger
 
 As the AI research is evolving at a rapid pace, I, edbrz9, made a "snapshot" of [saharNooby](https://github.com/saharNooby/rwkv.cpp/tree/master)'s original repo. This snapshot was made on June 12th 2023.
 
+I also added a Flask server with some hardcoded values. This is still a work in progress.
+
+So far I've tried it with the 1B5 model quantized to Q8_0.
+
+https://file.brz9.dev/model/rwkv.cpp-eng98o2-1B5-v12-Q8_0.bin
+
+You can run the Flask server with the following command:
+
+```
+$ python rwkv/flask_server.py --model ../model/rwkv.cpp-eng98o2-1B5-v12-Q8_0.bin --port 5349
+```
+
+Then, you can test the API with the following command:
+
+```
+curl -X POST -H "Content-Type: application/json" -d '{"prompt":"Write a hello world program in python", "temperature":0.8, "top_p":0.2, "max_length":250}' http://127.0.0.1:5349/chat
+```
+
 Below is the rest of the original README file.
 
+---
+
 Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT8** inference. This project is **CPU only**.
 This project provides [a C library rwkv.h](rwkv.h) and [a convinient Python wrapper](rwkv%2Frwkv_cpp_model.py) for it.
diff --git a/rwkv/flask_server.py b/rwkv/flask_server.py
new file mode 100644
index 0000000..5eb53d8
--- /dev/null
+++ b/rwkv/flask_server.py
@@ -0,0 +1,106 @@
+from flask import Flask, request, jsonify
+import argparse
+import sampling
+import rwkv_cpp_model
+import rwkv_cpp_shared_library
+from rwkv_tokenizer import get_tokenizer
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model", required=True, help="Path to the model file.")
+parser.add_argument("--port", type=int, default=5000, help="Port to run server on.")
+args = parser.parse_args()
+
+tokenizer = get_tokenizer('20B')[0]
+
+print("Loading model")
+library = rwkv_cpp_shared_library.load_rwkv_shared_library()
+model = rwkv_cpp_model.RWKVModel(library, args.model)
+
+SYSTEM = "\nThe following is a verbose and detailed conversation between an AI assistant called Bot, and a human user called User. Bot is intelligent, knowledgeable, wise and polite.\n\nUser: french revolution what year\n\nBot: The French Revolution started in 1789, and lasted 10 years until 1799.\n\nUser: 3+5=?\n\nBot: The answer is 8.\n\nUser: guess i marry who ?\n\nBot: Only if you tell me more about yourself - what are your interests?\n\nUser: solve for a: 9-a=2\n\nBot: The answer is a = 7, because 9 - 7 = 2.\n\nUser: wat is lhc\n\nBot: LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.\n\n"
+
+app = Flask(__name__)
+
+def complete(prompt, temperature=0.8, top_p=0.5, max_length=100):
+    assert prompt != '', 'Prompt must not be empty'
+
+    prompt_tokens = tokenizer.encode(prompt).ids
+
+    # Feed the prompt through the model to build the initial state and logits.
+    init_logits, init_state = None, None
+
+    for token in prompt_tokens:
+        init_logits, init_state = model.eval(token, init_state, init_state, init_logits)
+
+    logits, state = init_logits.clone(), init_state.clone()
+    completion = []
+    for i in range(max_length):
+        token = sampling.sample_logits(logits, temperature, top_p)
+        completion.append(tokenizer.decode([token]))
+        logits, state = model.eval(token, state, state, logits)
+
+    return ''.join(completion)
+
+def complete_chat(prompt, temperature=0.8, top_p=0.5, max_length=100):
+    assert prompt != '', 'Prompt must not be empty'
+
+    prompt_tokens = tokenizer.encode(prompt).ids
+    init_logits, init_state = None, None
+
+    for token in prompt_tokens:
+        init_logits, init_state = model.eval(token, init_state, init_state, init_logits)
+
+    logits, state = init_logits.clone(), init_state.clone()
+    completion = []
+
+    user_next = tokenizer.encode("User").ids
+    user_sequence = []
+
+    for i in range(max_length):
+        token = sampling.sample_logits(logits, temperature, top_p)
+
+        user_sequence.append(token)
+        if len(user_sequence) > len(user_next):
+            user_sequence.pop(0)
+
+        if user_sequence == user_next:
+            break
+
+        completion.append(tokenizer.decode([token]))
+        logits, state = model.eval(token, state, state, logits)
+
+    response = ''.join(completion).rstrip()
+
+    return response
+
+@app.route("/generate", methods=["POST"])
+def generate():
+    print("request received")
+    data = request.json
+    prompt = data.get('prompt')
+    temperature = data.get('temperature', 0.8)
+    top_p = data.get('top_p', 0.5)
+    max_length = data.get('max_length', 100)
+
+    output = complete(prompt, temperature, top_p, max_length)
+    print("output:", output)
+    return jsonify({'generated_text': output})
+
+@app.route("/chat", methods=["POST"])
+def chat():
+    data = request.json
+    system = data.get('system')
+    if system:
+        system_prompt = system
+    else:
+        system_prompt = SYSTEM
+    prompt = system_prompt + "User: " + data.get('prompt') + '\n\n' + "Bot:"
+    temperature = data.get('temperature', 0.8)
+    top_p = data.get('top_p', 0.5)
+    max_length = data.get('max_length', 100)
+    output = complete_chat(prompt, temperature, top_p, max_length)
+    output = output.lstrip(" ")
+    print("output:", output)
+    return jsonify({'bot_response': output})
+
+if __name__ == "__main__":
+    app.run(host="0.0.0.0", port=args.port)
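
For reference, the two endpoints added by this patch can also be exercised from Python rather than curl. The sketch below is illustrative and not part of the patch: it assumes the server is running locally on port 5349 (as in the README example above) and that the `requests` package is installed; the field names (`prompt`, `temperature`, `top_p`, `max_length`, `generated_text`, `bot_response`) come from the Flask routes in `rwkv/flask_server.py`.

```python
# Minimal client sketch for the /generate and /chat endpoints added in this patch.
# Assumptions (not part of the patch): the server runs on 127.0.0.1:5349 and
# the `requests` package is installed.
import requests

BASE_URL = "http://127.0.0.1:5349"

# Raw completion: /generate returns {"generated_text": ...}
gen = requests.post(
    f"{BASE_URL}/generate",
    json={"prompt": "Write a hello world program in python",
          "temperature": 0.8, "top_p": 0.2, "max_length": 250},
    timeout=300,
)
gen.raise_for_status()
print(gen.json()["generated_text"])

# Chat-style completion: /chat wraps the prompt in the hardcoded SYSTEM preamble
# (or a custom "system" field) and returns {"bot_response": ...}
chat = requests.post(
    f"{BASE_URL}/chat",
    json={"prompt": "french revolution what year", "max_length": 100},
    timeout=300,
)
chat.raise_for_status()
print(chat.json()["bot_response"])
```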