flask server added
This commit is contained in:
parent 70e5f07d5f
commit e9ccfc44fd
README.md | 20
@@ -7,8 +7,28 @@ This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml)
As AI research is evolving at a rapid pace, I, edbrz9, made a "snapshot" of [saharNooby](https://github.com/saharNooby/rwkv.cpp/tree/master)'s original repo on June 12th, 2023.
I have also added a Flask server with some hardcoded values. This is still a work in progress.
So far I have tried it with the 1B5 model quantized to Q8_0:
https://file.brz9.dev/model/rwkv.cpp-eng98o2-1B5-v12-Q8_0.bin
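
To download it, any HTTP client works; for example, with `wget` (assuming the file is still hosted at the URL above, and using the `../model/` directory that the run command below expects):

```
$ wget -P ../model/ https://file.brz9.dev/model/rwkv.cpp-eng98o2-1B5-v12-Q8_0.bin
```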
You can run the Flask server with the following command:
```
$ python rwkv/flask_server.py --model ../model/rwkv.cpp-eng98o2-1B5-v12-Q8_0.bin --port 5349
```
Then, you can test the API with the following command:
```
curl -X POST -H "Content-Type: application/json" -d '{"prompt":"Write a hello world program in python", "temperature":0.8, "top_p":0.2, "max_length":250}' http://127.0.0.1:5349/chat
```
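
The same request can also be made from Python. Here is a minimal sketch using the `requests` library (assumed to be installed; the server is assumed to be running on port 5349 as above). The `/chat` endpoint replies with a JSON object whose `bot_response` field holds the generated text:

```
import requests

resp = requests.post(
    "http://127.0.0.1:5349/chat",
    json={
        "prompt": "Write a hello world program in python",
        "temperature": 0.8,
        "top_p": 0.2,
        "max_length": 250,
    },
)

# The server wraps the completion in a 'bot_response' field.
print(resp.json()["bot_response"])
```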
Below is the rest of the original README file.
---
Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT8** inference. This project is **CPU only**.
This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrapper](rwkv%2Frwkv_cpp_model.py) for it.
rwkv/flask_server.py
@@ -0,0 +1,106 @@
from flask import Flask, request, jsonify
import argparse
import sampling
import rwkv_cpp_model
import rwkv_cpp_shared_library
from rwkv_tokenizer import get_tokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, help="Path to the model file.")
parser.add_argument("--port", type=int, default=5000, help="Port to run server on.")
args = parser.parse_args()

tokenizer = get_tokenizer('20B')[0]

print("Loading model")
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, args.model)

# Hardcoded few-shot chat preamble prepended to every /chat prompt.
SYSTEM = "\nThe following is a verbose and detailed conversation between an AI assistant called Bot, and a human user called User. Bot is intelligent, knowledgeable, wise and polite.\n\nUser: french revolution what year\n\nBot: The French Revolution started in 1789, and lasted 10 years until 1799.\n\nUser: 3+5=?\n\nBot: The answer is 8.\n\nUser: guess i marry who ?\n\nBot: Only if you tell me more about yourself - what are your interests?\n\nUser: solve for a: 9-a=2\n\nBot: The answer is a = 7, because 9 - 7 = 2.\n\nUser: wat is lhc\n\nBot: LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.\n\n"

app = Flask(__name__)


def complete(prompt, temperature=0.8, top_p=0.5, max_length=100):
    assert prompt != '', 'Prompt must not be empty'

    prompt_tokens = tokenizer.encode(prompt).ids
    init_logits, init_state = None, None

    # Feed the prompt through the model token by token to build the initial state.
    for token in prompt_tokens:
        init_logits, init_state = model.eval(token, init_state, init_state, init_logits)

    logits, state = init_logits.clone(), init_state.clone()
    completion = []
    for _ in range(max_length):
        token = sampling.sample_logits(logits, temperature, top_p)
        completion.append(tokenizer.decode([token]))
        logits, state = model.eval(token, state, state, logits)

    return ''.join(completion)


def complete_chat(prompt, temperature=0.8, top_p=0.5, max_length=100):
    assert prompt != '', 'Prompt must not be empty'

    prompt_tokens = tokenizer.encode(prompt).ids
    init_logits, init_state = None, None

    for token in prompt_tokens:
        init_logits, init_state = model.eval(token, init_state, init_state, init_logits)

    logits, state = init_logits.clone(), init_state.clone()
    completion = []

    # Stop generating as soon as the model starts a new "User" turn.
    user_next = tokenizer.encode("User").ids
    user_sequence = []

    for _ in range(max_length):
        token = sampling.sample_logits(logits, temperature, top_p)

        # Keep a sliding window of the most recent tokens and compare it
        # against the token sequence for "User".
        user_sequence.append(token)
        if len(user_sequence) > len(user_next):
            user_sequence.pop(0)

        if user_sequence == user_next:
            break

        completion.append(tokenizer.decode([token]))
        logits, state = model.eval(token, state, state, logits)

    return ''.join(completion).rstrip()


@app.route("/generate", methods=["POST"])
def generate():
    print("request received")
    data = request.json
    prompt = data.get('prompt')
    temperature = data.get('temperature', 0.8)
    top_p = data.get('top_p', 0.5)
    max_length = data.get('max_length', 100)

    output = complete(prompt, temperature, top_p, max_length)
    print("output:", output)
    return jsonify({'generated_text': output})


@app.route("/chat", methods=["POST"])
def chat():
    data = request.json
    # Allow the caller to override the hardcoded system prompt.
    system_prompt = data.get('system') or SYSTEM
    prompt = system_prompt + "User: " + data.get('prompt') + '\n\n' + "Bot:"
    temperature = data.get('temperature', 0.8)
    top_p = data.get('top_p', 0.5)
    max_length = data.get('max_length', 100)
    output = complete_chat(prompt, temperature, top_p, max_length)
    output = output.lstrip(" ")
    print("output:", output)
    return jsonify({'bot_response': output})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=args.port)
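
Besides `/chat`, the server above also exposes a `/generate` endpoint that skips the chat template and returns the raw completion in a `generated_text` field. A quick smoke test might look like this (the prompt value is just an illustrative example; the server is assumed to be running on port 5349):

```
curl -X POST -H "Content-Type: application/json" -d '{"prompt":"Once upon a time", "temperature":0.8, "top_p":0.5, "max_length":100}' http://127.0.0.1:5349/generate
```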