flask server added

This commit is contained in:
ed barz 2023-07-18 22:27:09 +02:00
parent 70e5f07d5f
commit e9ccfc44fd
2 changed files with 126 additions and 0 deletions

View File

@ -7,8 +7,28 @@ This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [gger
As the AI research is evolving at a rapid pace, I, edbrz9, made a "snapshot" of [saharNooby](https://github.com/saharNooby/rwkv.cpp/tree/master)'s original repo. This snapshot was made on June 12th 2023.
I also added a flask server with some hardcoded values. This is still a work in progress.
So far I've tried it with the 1b5 model quantized in Q8_0.
https://file.brz9.dev/model/rwkv.cpp-eng98o2-1B5-v12-Q8_0.bin
You can run the flask server with the following command:
```
$ python rwkv/flask_server.py --model ../model/rwkv.cpp-eng98o2-1B5-v12-Q8_0.bin --port 5349
```
Then, you can test the API with the following command:
```
curl -X POST -H "Content-Type: application/json" -d '{"prompt":"Write a hello world program in python", "temperature":0.8, "top_p":0.2, "max_length":250}' http://127.0.0.1:5349/chat
```
Below is the rest of the original README file.
---
Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT8** inference. This project is **CPU only**.
This project provides [a C library rwkv.h](rwkv.h) and [a convinient Python wrapper](rwkv%2Frwkv_cpp_model.py) for it.

106
rwkv/flask_server.py Normal file
View File

@ -0,0 +1,106 @@
from flask import Flask, request, jsonify
import argparse
import sampling
import rwkv_cpp_model
import rwkv_cpp_shared_library
from rwkv_tokenizer import get_tokenizer
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, help="Path to the model file.")
parser.add_argument("--port", type=int, default=5000, help="Port to run server on.")
args = parser.parse_args()
tokenizer = get_tokenizer('20B')[0]
print("Loading model")
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, args.model)
SYSTEM = "\nThe following is a verbose and detailed conversation between an AI assistant called Bot, and a human user called User. Bot is intelligent, knowledgeable, wise and polite.\n\nUser: french revolution what year\n\nBot: The French Revolution started in 1789, and lasted 10 years until 1799.\n\nUser: 3+5=?\n\nBot: The answer is 8.\n\nUser: guess i marry who ?\n\nBot: Only if you tell me more about yourself - what are your interests?\n\nUser: solve for a: 9-a=2\n\nBot: The answer is a = 7, because 9 - 7 = 2.\n\nUser: wat is lhc\n\nBot: LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.\n\n"
app = Flask(__name__)
def complete(prompt, temperature=0.8, top_p=0.5, max_length=100):
assert prompt != '', 'Prompt must not be empty'
prompt_tokens = tokenizer.encode(prompt).ids
prompt_token_count = len(prompt_tokens)
init_logits, init_state = None, None
for token in prompt_tokens:
init_logits, init_state = model.eval(token, init_state, init_state, init_logits)
logits, state = init_logits.clone(), init_state.clone()
completion = []
for i in range(max_length):
token = sampling.sample_logits(logits, temperature, top_p)
completion.append(tokenizer.decode([token]))
logits, state = model.eval(token, state, state, logits)
return ''.join(completion)
def complete_chat(prompt, temperature=0.8, top_p=0.5, max_length=100):
assert prompt != '', 'Prompt must not be empty'
prompt_tokens = tokenizer.encode(prompt).ids
init_logits, init_state = None, None
for token in prompt_tokens:
init_logits, init_state = model.eval(token, init_state, init_state, init_logits)
logits, state = init_logits.clone(), init_state.clone()
completion = []
user_next = tokenizer.encode("User").ids
user_sequence = []
for i in range(max_length):
token = sampling.sample_logits(logits, temperature, top_p)
user_sequence.append(token)
if len(user_sequence) > len(user_next):
user_sequence.pop(0)
if user_sequence == user_next:
break
completion.append(tokenizer.decode([token]))
logits, state = model.eval(token, state, state, logits)
response = ''.join(completion).rstrip()
return response
@app.route("/generate", methods=["POST"])
def generate():
print("request received")
data = request.json
prompt = data.get('prompt')
temperature = data.get('temperature', 0.8)
top_p = data.get('top_p', 0.5)
max_length = data.get('max_length', 100)
output = complete(prompt, temperature, top_p, max_length)
print("output:", output)
return jsonify({'generated_text': output})
@app.route("/chat", methods=["POST"])
def chat():
data = request.json
system = data.get('system')
if system:
system_prompt = system
else:
system_prompt = SYSTEM
prompt = system_prompt + "User: " + data.get('prompt') + '\n\n' + "Bot:"
temperature = data.get('temperature', 0.8)
top_p = data.get('top_p', 0.5)
max_length = data.get('max_length', 100)
output = complete_chat(prompt, temperature, top_p, max_length)
output = output.lstrip(" ")
print("output:", output)
return jsonify({'bot_response': output})
if __name__ == "__main__":
app.run(host="0.0.0.0", port=args.port)