Add text generation and chat scripts

parent ee46ad208e
commit e0684e8104

README.md | 51

@@ -2,29 +2,23 @@

This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml).

Besides usual **FP32**, it supports **FP16** and **quantized INT4** inference on CPU. This project is **CPU only**.
Besides the usual **FP32**, it supports **FP16** and **quantized INT4** inference on CPU. This project is **CPU only**.

**WORK IN PROGRESS!** **Status**: INT4 quality is not good yet; perplexity needs to be properly measured and compared.

RWKV is a novel large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformers with `O(n^2)` attention, RWKV requires only the state from the previous step to calculate logits. This makes RWKV very CPU-friendly at large context lengths.

## Plan

**TODO**:

1. Create Python script with sampling and simple chat interface
2. Measure performance and quality of different model sizes and data types
3. Write a good `README.md` and publish links to this repo
4. Create pull request to main `ggml` repo with all improvements made here

## Structure

- `./rwkv.h`, `./rwkv.cpp`: source code of the shared library.
- `./rwkv`: directory containing Python scripts for conversion, inference and validation.

1. Measure performance and perplexity of different model sizes and data types
2. Write a good `README.md` (motivation, benchmarks, perplexity) and publish links to this repo
3. Create pull request to main `ggml` repo with all improvements made here

## How to use

### 1. Clone the repo and build the library

### Windows

Requirements: [git](https://gitforwindows.org/), [CMake](https://cmake.org/download/), MSVC compiler, Python 3.x with PyTorch.

#### 1. Clone the repo and build it:
**Requirements**: [git](https://gitforwindows.org/), [CMake](https://cmake.org/download/), MSVC compiler.

```commandline
git clone https://github.com/saharNooby/rwkv.cpp.git
```

@@ -35,16 +29,37 @@ cmake --build . --config Release

If everything went OK, the `bin\Release\rwkv.dll` file should appear.

#### 2. Download an RWKV model from [Huggingface](https://huggingface.co/BlinkDL) and convert it into `ggml` format:
### 2. Download an RWKV model from [Hugging Face](https://huggingface.co/BlinkDL) and convert it into `ggml` format

**Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/).

```commandline
python rwkv\convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
```

#### 3. Use the model in Python:
### 3. Run the model

**Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/) and [tokenizers](https://pypi.org/project/tokenizers/).

To generate some text, run:

```commandline
python rwkv\generate_completions.py C:\rwkv.cpp-169M.bin
```

To chat with a bot, run:

```commandline
python rwkv\chat_with_bot.py C:\rwkv.cpp-169M.bin
```

Edit [generate_completions.py](rwkv/generate_completions.py) or [chat_with_bot.py](rwkv/chat_with_bot.py) to change prompts and sampling settings.
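
The tunable values sit in the `Script settings` block at the top of each script; for reference, the defaults added in this commit (shown in full later in this diff) are:

```python
# Settings block from generate_completions.py in this commit.
generation_count: int = 3          # How many completions to generate.
tokens_per_generation: int = 100   # Token count per single completion.

# Sampling settings.
temperature: float = 0.8
top_p: float = 0.5
```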

---

Example of using `rwkv.cpp` in your custom Python script:

```python
# These files are located in rwkv directory
import rwkv_cpp_model
import rwkv_cpp_shared_library
```
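
The hunk above is cut off after the imports. A minimal sketch of how such a custom script might continue, based only on the `rwkv_cpp_shared_library`, `rwkv_cpp_model`, `tokenizers` and `sampling` usage visible in the scripts added by this commit (the model path and prompt below are placeholders):

```python
import sampling
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library

# Load the shared library and a converted ggml model (placeholder path).
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, 'C:\\rwkv.cpp-169M.bin')

tokenizer = tokenizers.Tokenizer.from_file('20B_tokenizer.json')

logits, state = None, None

# Feed the prompt token by token; RWKV carries all context in `state`.
for token in tokenizer.encode('Hello, my name is').ids:
    logits, state = model.eval(token, state, state, logits)

# Sample and print a short continuation.
for _ in range(20):
    token = sampling.sample_logits(logits, temperature=0.8, top_p=0.5)
    print(tokenizer.decode([token]), end='')
    logits, state = model.eval(token, state, state, logits)
```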

File diff suppressed because it is too large

@@ -0,0 +1,87 @@

# Provides terminal-based chat interface for RWKV model.

import sys
import argparse
import sampling
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library

# ======================================== Script settings ========================================

# Copied from https://github.com/ggerganov/llama.cpp/blob/6e7801d08d81c931a5427bae46f00763e993f54a/prompts/chat-with-bob.txt
prompt: str = """Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.

User: Hello, Bob.
Bob: Hello. How may I help you today?
User: Please tell me the largest city in Europe.
Bob: Sure. The largest city in Europe is Moscow, the capital of Russia."""

# No trailing space here!
bot_message_prefix: str = 'Bob:'
user_message_prefix: str = 'User:'

max_tokens_per_generation: int = 100

# Sampling settings.
temperature: float = 0.8
top_p: float = 0.5

# =================================================================================================

parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
args = parser.parse_args()

assert prompt != '', 'Prompt must not be empty'

print('Loading 20B tokenizer')
tokenizer = tokenizers.Tokenizer.from_file('20B_tokenizer.json')

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')

print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)

prompt_tokens = tokenizer.encode(prompt).ids
prompt_token_count = len(prompt_tokens)
print(f'Processing {prompt_token_count} prompt tokens, may take a while')

logits, state = None, None
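
# Feed the prompt one token at a time; RWKV needs only the previous step's state to produce the next logits.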
for token in prompt_tokens:
    logits, state = model.eval(token, state, state, logits)

print('\nChat initialized! Write something and press Enter.')

while True:
    # Read user input
    print('> ', end='')
    user_input = sys.stdin.readline()

    # Process the input
    new_tokens = tokenizer.encode('\n' + user_message_prefix + ' ' + user_input + '\n' + bot_message_prefix).ids

    for token in new_tokens:
        logits, state = model.eval(token, state, state, logits)

    # Generate and print bot response
    print(bot_message_prefix, end='')

    decoded = ''

    for i in range(max_tokens_per_generation):
        token = sampling.sample_logits(logits, temperature, top_p)

        decoded = tokenizer.decode([token])

        print(decoded, end='')

        if '\n' in decoded:
            break

        logits, state = model.eval(token, state, state, logits)

    if '\n' not in decoded:
        print()

@@ -0,0 +1,68 @@

# Generates completions from RWKV model based on a prompt.

import argparse
import time
import sampling
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library

# ======================================== Script settings ========================================

prompt: str = """# rwkv.cpp

This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml).

Besides usual **FP32**, it supports **FP16** and **quantized INT4** inference on CPU. This project is **CPU only**."""

# How many completions to generate.
generation_count: int = 3
# Token count per single completion.
tokens_per_generation: int = 100

# Sampling settings.
temperature: float = 0.8
top_p: float = 0.5

# =================================================================================================

parser = argparse.ArgumentParser(description='Generate completions from RWKV model based on a prompt')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
args = parser.parse_args()

assert prompt != '', 'Prompt must not be empty'

print('Loading 20B tokenizer')
tokenizer = tokenizers.Tokenizer.from_file('20B_tokenizer.json')

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')

print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)

prompt_tokens = tokenizer.encode(prompt).ids
prompt_token_count = len(prompt_tokens)
print(f'{prompt_token_count} tokens in prompt')

init_logits, init_state = None, None
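
# Process the prompt once; each generation below starts from a clone of the resulting state and logits.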
for token in prompt_tokens:
    init_logits, init_state = model.eval(token, init_state, init_state, init_logits)

for GENERATION in range(generation_count):
    print(f'\n--- Generation {GENERATION} ---\n')
    print(prompt, end='[')
    start = time.time()

    logits, state = init_logits.clone(), init_state.clone()

    for i in range(tokens_per_generation):
        token = sampling.sample_logits(logits, temperature, top_p)

        print(tokenizer.decode([token]), end='')

        logits, state = model.eval(token, state, state, logits)

    delay = time.time() - start
    print(']\n\nTook %.3f sec, %d ms per token' % (delay, delay / tokens_per_generation * 1000))

@@ -0,0 +1,3 @@

numpy
torch
tokenizers
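
Assuming these three entries are the Python dependencies for the scripts above (the file name is not visible in this view), they can be installed with pip:

```commandline
pip install numpy torch tokenizers
```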

@@ -173,7 +173,7 @@ class RWKVSharedLibrary:

        Returns system information string.
        """

        return self.library.rwkv_get_system_info_string()
        return self.library.rwkv_get_system_info_string().decode('utf-8')

def load_rwkv_shared_library() -> RWKVSharedLibrary:
    """

@@ -0,0 +1,40 @@

import numpy as np
import torch
from typing import Dict, Optional
from torch.nn import functional as F

def sample_logits(out: torch.Tensor, temperature: float = 1.0, top_p: float = 0.8, logit_bias: Optional[Dict[int, float]] = None) -> int:
    probs = F.softmax(out.cpu(), dim=-1).numpy()

    return sample_probs(probs, temperature, top_p, logit_bias)

def sample_probs(probs: np.ndarray, temperature: float = 1.0, top_p: float = 0.8, logit_bias: Optional[Dict[int, float]] = None) -> int:
    assert 0.0 <= temperature, 'temperature must be non-negative'
    assert 0.0 <= top_p <= 1.0, 'top_p must be in range [0.0, 1.0]'

    if top_p == 0.0:
        top_p = 1.0

    if logit_bias is not None:
        logits = np.log(probs)

        for token in logit_bias.keys():
            logits[token] += logit_bias[token]

        probs = np.exp(logits) / np.sum(np.exp(logits))

    if temperature == 0.0:
        return np.argmax(probs).item()

    if top_p < 1.0:
        sorted_probs = np.sort(probs)[::-1]
        cumulative_probs = np.cumsum(sorted_probs)
        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
        probs[probs < cutoff] = 0

    if temperature != 1.0:
        probs = np.power(probs, 1.0 / temperature)

    probs = probs / np.sum(probs)

    return np.random.choice(a=len(probs), p=probs)
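
The function above implements nucleus (top-p) filtering followed by temperature scaling. A quick illustration of `sample_probs` on a made-up distribution, not part of the commit:

```python
import numpy as np
from sampling import sample_probs  # the module added above

# Toy 5-token distribution (values invented for illustration).
probs = np.array([0.05, 0.40, 0.10, 0.30, 0.15])

# With top_p=0.5 the cumulative cutoff keeps only the two most likely
# tokens (0.40 and 0.30); temperature < 1.0 then sharpens what remains.
token = sample_probs(probs, temperature=0.8, top_p=0.5)
print(token)  # prints 1 or 3
```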