flask server added
parent 70e5f07d5f
commit e9ccfc44fd

README.md (+20 lines)
@@ -7,8 +7,28 @@ This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml)
As AI research is evolving at a rapid pace, I, edbrz9, made a "snapshot" of [saharNooby](https://github.com/saharNooby/rwkv.cpp/tree/master)'s original repo. This snapshot was made on June 12th, 2023.

I also added a Flask server with some hardcoded values. This is still a work in progress.

So far I've tried it with the 1B5 model quantized to Q8_0:

https://file.brz9.dev/model/rwkv.cpp-eng98o2-1B5-v12-Q8_0.bin

You can run the Flask server with the following command:

```
$ python rwkv/flask_server.py --model ../model/rwkv.cpp-eng98o2-1B5-v12-Q8_0.bin --port 5349
```

Then, you can test the API with the following command:

|
```
|
||||||
|
curl -X POST -H "Content-Type: application/json" -d '{"prompt":"Write a hello world program in python", "temperature":0.8, "top_p":0.2, "max_length":250}' http://127.0.0.1:5349/chat
|
||||||
|
```
|
||||||
|
|
||||||
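The `/chat` endpoint returns JSON of the form `{"bot_response": "..."}`. The server also exposes a `/generate` endpoint for plain completions (see `flask_server.py` below), which returns `{"generated_text": "..."}`; the prompt below is just an illustration, and the other fields simply restate the code's defaults:

```
curl -X POST -H "Content-Type: application/json" \
  -d '{"prompt":"Once upon a time", "temperature":0.8, "top_p":0.5, "max_length":100}' \
  http://127.0.0.1:5349/generate
```
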
Below is the rest of the original README file.

---

Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT8** inference. This project is **CPU only**.

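For the quantized formats, upstream rwkv.cpp produces a quantized model file from a converted one with `rwkv/quantize.py`; the invocation below (input path, output path, format name) is assumed from the upstream README and should be double-checked against this snapshot, and the file names are placeholders:

```
# Assumed from upstream rwkv.cpp; verify the script and its arguments in this snapshot.
$ python rwkv/quantize.py ./rwkv.cpp-1B5-FP16.bin ./rwkv.cpp-1B5-Q8_0.bin Q8_0
```
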
This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrapper](rwkv%2Frwkv_cpp_model.py) for it.

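For orientation, here is a minimal sketch of driving the Python wrapper directly, mirroring what `flask_server.py` below does (the model path and prompt are placeholders; the sampling helper and tokenizer loader are the ones that file uses):

```python
import sampling
import rwkv_cpp_model
import rwkv_cpp_shared_library
from rwkv_tokenizer import get_tokenizer

tokenizer = get_tokenizer('20B')[0]
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, './rwkv.cpp-1B5-Q8_0.bin')  # placeholder path

# Prime the recurrent state with the prompt, one token at a time.
logits, state = None, None
for token in tokenizer.encode('Hello').ids:
    logits, state = model.eval(token, state, state, logits)

# Sample a short continuation from the primed state.
out = []
for _ in range(16):
    token = sampling.sample_logits(logits, 0.8, 0.5)  # temperature, top_p
    out.append(tokenizer.decode([token]))
    logits, state = model.eval(token, state, state, logits)

print(''.join(out))
```
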
rwkv/flask_server.py
@@ -0,0 +1,106 @@
from flask import Flask, request, jsonify
import argparse
import sampling
import rwkv_cpp_model
import rwkv_cpp_shared_library
from rwkv_tokenizer import get_tokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, help="Path to the model file.")
parser.add_argument("--port", type=int, default=5000, help="Port to run server on.")
args = parser.parse_args()

tokenizer = get_tokenizer('20B')[0]

print("Loading model")
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, args.model)

# Hardcoded few-shot system prompt, used when a request does not supply its own.
SYSTEM = "\nThe following is a verbose and detailed conversation between an AI assistant called Bot, and a human user called User. Bot is intelligent, knowledgeable, wise and polite.\n\nUser: french revolution what year\n\nBot: The French Revolution started in 1789, and lasted 10 years until 1799.\n\nUser: 3+5=?\n\nBot: The answer is 8.\n\nUser: guess i marry who ?\n\nBot: Only if you tell me more about yourself - what are your interests?\n\nUser: solve for a: 9-a=2\n\nBot: The answer is a = 7, because 9 - 7 = 2.\n\nUser: wat is lhc\n\nBot: LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.\n\n"

app = Flask(__name__)


def complete(prompt, temperature=0.8, top_p=0.5, max_length=100):
    assert prompt != '', 'Prompt must not be empty'

    # Feed the prompt through the model token by token to build up the state.
    prompt_tokens = tokenizer.encode(prompt).ids
    init_logits, init_state = None, None
    for token in prompt_tokens:
        init_logits, init_state = model.eval(token, init_state, init_state, init_logits)

    # Sample up to max_length tokens from the primed state.
    logits, state = init_logits.clone(), init_state.clone()
    completion = []
    for _ in range(max_length):
        token = sampling.sample_logits(logits, temperature, top_p)
        completion.append(tokenizer.decode([token]))
        logits, state = model.eval(token, state, state, logits)

    return ''.join(completion)


def complete_chat(prompt, temperature=0.8, top_p=0.5, max_length=100):
    assert prompt != '', 'Prompt must not be empty'

    prompt_tokens = tokenizer.encode(prompt).ids
    init_logits, init_state = None, None
    for token in prompt_tokens:
        init_logits, init_state = model.eval(token, init_state, init_state, init_logits)

    logits, state = init_logits.clone(), init_state.clone()
    completion = []

    # Sliding window over the most recent sampled tokens; generation stops once
    # it matches the token ids of "User", i.e. the model starts a new user turn.
    user_next = tokenizer.encode("User").ids
    user_sequence = []

    for _ in range(max_length):
        token = sampling.sample_logits(logits, temperature, top_p)

        user_sequence.append(token)
        if len(user_sequence) > len(user_next):
            user_sequence.pop(0)

        if user_sequence == user_next:
            break

        completion.append(tokenizer.decode([token]))
        logits, state = model.eval(token, state, state, logits)

    return ''.join(completion).rstrip()


@app.route("/generate", methods=["POST"])
def generate():
    print("request received")
    data = request.json
    prompt = data.get('prompt')
    temperature = data.get('temperature', 0.8)
    top_p = data.get('top_p', 0.5)
    max_length = data.get('max_length', 100)

    output = complete(prompt, temperature, top_p, max_length)
    print("output:", output)
    return jsonify({'generated_text': output})


@app.route("/chat", methods=["POST"])
def chat():
    data = request.json
    # Use the caller-supplied system prompt if present, otherwise the hardcoded one.
    system_prompt = data.get('system') or SYSTEM
    prompt = system_prompt + "User: " + data.get('prompt') + '\n\n' + "Bot:"
    temperature = data.get('temperature', 0.8)
    top_p = data.get('top_p', 0.5)
    max_length = data.get('max_length', 100)

    output = complete_chat(prompt, temperature, top_p, max_length)
    output = output.lstrip(" ")
    print("output:", output)
    return jsonify({'bot_response': output})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=args.port)
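Note that `chat()` also reads an optional `"system"` field, so a request can override the hardcoded `SYSTEM` prompt. For example (the system text here is illustrative):

```
curl -X POST -H "Content-Type: application/json" \
  -d '{"prompt":"what is RWKV?", "system":"\nThe following is a conversation between an AI assistant called Bot and a human user called User. Bot answers briefly.\n\nUser: hi\n\nBot: Hello!\n\n", "max_length":120}' \
  http://127.0.0.1:5349/chat
```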