Add text generation and chat scripts

This commit is contained in:
saharNooby 2023-04-02 15:03:31 +04:00
parent ee46ad208e
commit e0684e8104
7 changed files with 100757 additions and 19 deletions

README.md

@@ -2,29 +2,23 @@
This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml).
Besides usual **FP32**, it supports **FP16** and **quantized INT4** inference on CPU. This project is **CPU only**.
Besides the usual **FP32**, it supports **FP16** and **quantized INT4** inference on CPU. This project is **CPU only**.
**WORK IN PROGRESS!** **Status**: INT4 quality is currently not good; perplexity needs to be properly measured and compared against FP16/FP32.
RWKV is a novel large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformers with `O(n^2)` attention, RWKV requires only the state from the previous step to calculate logits. This makes RWKV very CPU-friendly for large context lengths.
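In code, this recurrence looks roughly like the loop below (a minimal sketch using the `rwkv_cpp_model` wrapper from this repo, with the same argument order as the scripts further down; `tokens` is any list of token ids):

```python
# Minimal sketch of RWKV's per-token recurrence (see the scripts below for full usage).
logits, state = None, None

for token in tokens:
    # Each step needs only the previous state, not the whole context.
    logits, state = model.eval(token, state, state, logits)
```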
## Plan
**TODO**:
1. Create Python script with sampling and simple chat interface
2. Measure performance and quality of different model sizes and data types
3. Write a good `README.md` and publish links to this repo
4. Create pull request to main `ggml` repo with all improvements made here
## Structure
- `./rwkv.h`, `./rwkv.cpp`: source code of the shared library.
- `./rwkv`: directory containing Python scripts for conversion, inference and validation.
1. Measure performance and perplexity of different model sizes and data types
2. Write a good `README.md` (motivation, benchmarks, perplexity) and publish links to this repo
3. Create pull request to main `ggml` repo with all improvements made here
## How to use
### 1. Clone the repo and build the library
### Windows
Requirements: [git](https://gitforwindows.org/), [CMake](https://cmake.org/download/), MSVC compiler, Python 3.x with PyTorch.
#### 1. Clone the repo and build it:
**Requirements**: [git](https://gitforwindows.org/), [CMake](https://cmake.org/download/), MSVC compiler.
```commandline
git clone https://github.com/saharNooby/rwkv.cpp.git
@@ -35,16 +29,37 @@ cmake --build . --config Release
If everything went OK, the file `bin\Release\rwkv.dll` should appear.
#### 2. Download an RWKV model from [Huggingface](https://huggingface.co/BlinkDL) and convert it into `ggml` format:
### 2. Download an RWKV model from [Hugging Face](https://huggingface.co/BlinkDL) and convert it into `ggml` format
**Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/).
```commandline
python rwkv\convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
```
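The last argument selects the storage data type of the converted model. Since FP16 inference is supported, passing `float16` here should produce a roughly half-size file (this value and the output path are assumptions, not verified against the converter; check the script's help if it rejects them):

```commandline
python rwkv\convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M-FP16.bin float16
```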
#### 3. Use the model in Python:
### 3. Run the model
**Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/) and [tokenizers](https://pypi.org/project/tokenizers/).
To generate some text, run:
```commandline
python rwkv\generate_completions.py C:\rwkv.cpp-169M.bin
```
To chat with a bot, run:
```commandline
python rwkv\chat_with_bot.py C:\rwkv.cpp-169M.bin
```
Edit [generate_completions.py](rwkv/generate_completions.py) or [chat_with_bot.py](rwkv/chat_with_bot.py) to change prompts and sampling settings.
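Both scripts keep these knobs in the "Script settings" block at the top of the file; the values below are the defaults shipped in this commit:

```python
# Defaults from the scripts in this commit.
max_tokens_per_generation: int = 100  # chat_with_bot.py: max length of a single bot reply
tokens_per_generation: int = 100      # generate_completions.py: tokens per completion

# Sampling settings (same defaults in both scripts).
temperature: float = 0.8
top_p: float = 0.5
```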
---
Example of using `rwkv.cpp` in your custom Python script:
```python
# These files are located in rwkv directory
import rwkv_cpp_model
import rwkv_cpp_shared_library
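
# --- Continuation sketch (not the verbatim rest of the original example) ---
# It follows the same wrapper API used by the scripts in this commit.

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, r'C:\rwkv.cpp-169M.bin')

logits, state = None, None

# Feed tokens one by one; only the previous state is carried between calls.
for token in [1, 2, 3]:
    logits, state = model.eval(token, state, state, logits)

print(f'Output logits: {logits}')

# Assumed cleanup call; check rwkv_cpp_model.py for the exact method name.
model.free()
```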

100525
rwkv/20B_tokenizer.json Normal file

File diff suppressed because it is too large

87
rwkv/chat_with_bot.py Normal file

@@ -0,0 +1,87 @@
# Provides terminal-based chat interface for RWKV model.

import sys
import argparse
import sampling
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library

# ======================================== Script settings ========================================

# Copied from https://github.com/ggerganov/llama.cpp/blob/6e7801d08d81c931a5427bae46f00763e993f54a/prompts/chat-with-bob.txt
prompt: str = """Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
User: Hello, Bob.
Bob: Hello. How may I help you today?
User: Please tell me the largest city in Europe.
Bob: Sure. The largest city in Europe is Moscow, the capital of Russia."""

# No trailing space here!
bot_message_prefix: str = 'Bob:'
user_message_prefix: str = 'User:'

max_tokens_per_generation: int = 100

# Sampling settings.
temperature: float = 0.8
top_p: float = 0.5

# =================================================================================================

parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
args = parser.parse_args()

assert prompt != '', 'Prompt must not be empty'

print('Loading 20B tokenizer')
tokenizer = tokenizers.Tokenizer.from_file('20B_tokenizer.json')

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')

print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)

prompt_tokens = tokenizer.encode(prompt).ids
prompt_token_count = len(prompt_tokens)
print(f'Processing {prompt_token_count} prompt tokens, may take a while')

logits, state = None, None

for token in prompt_tokens:
    logits, state = model.eval(token, state, state, logits)

print('\nChat initialized! Write something and press Enter.')

while True:
    # Read user input
    print('> ', end='')
    user_input = sys.stdin.readline()

    # Process the input
    new_tokens = tokenizer.encode('\n' + user_message_prefix + ' ' + user_input + '\n' + bot_message_prefix).ids

    for token in new_tokens:
        logits, state = model.eval(token, state, state, logits)

    # Generate and print bot response
    print(bot_message_prefix, end='')

    decoded = ''

    for i in range(max_tokens_per_generation):
        token = sampling.sample_logits(logits, temperature, top_p)

        decoded = tokenizer.decode([token])

        print(decoded, end='')

        if '\n' in decoded:
            break

        logits, state = model.eval(token, state, state, logits)

    if '\n' not in decoded:
        print()

68
rwkv/generate_completions.py Normal file

@@ -0,0 +1,68 @@
# Generates completions from RWKV model based on a prompt.

import argparse
import time
import sampling
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library

# ======================================== Script settings ========================================

prompt: str = """# rwkv.cpp
This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml).
Besides usual **FP32**, it supports **FP16** and **quantized INT4** inference on CPU. This project is **CPU only**."""

# How many completions to generate.
generation_count: int = 3
# Token count per single completion.
tokens_per_generation: int = 100

# Sampling settings.
temperature: float = 0.8
top_p: float = 0.5

# =================================================================================================

parser = argparse.ArgumentParser(description='Generate completions from RWKV model based on a prompt')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
args = parser.parse_args()

assert prompt != '', 'Prompt must not be empty'

print('Loading 20B tokenizer')
tokenizer = tokenizers.Tokenizer.from_file('20B_tokenizer.json')

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')

print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)

prompt_tokens = tokenizer.encode(prompt).ids
prompt_token_count = len(prompt_tokens)
print(f'{prompt_token_count} tokens in prompt')

init_logits, init_state = None, None

for token in prompt_tokens:
    init_logits, init_state = model.eval(token, init_state, init_state, init_logits)

for GENERATION in range(generation_count):
    print(f'\n--- Generation {GENERATION} ---\n')
    print(prompt, end='[')

    start = time.time()

    logits, state = init_logits.clone(), init_state.clone()

    for i in range(tokens_per_generation):
        token = sampling.sample_logits(logits, temperature, top_p)

        print(tokenizer.decode([token]), end='')

        logits, state = model.eval(token, state, state, logits)

    delay = time.time() - start

    print(']\n\nTook %.3f sec, %d ms per token' % (delay, delay / tokens_per_generation * 1000))

3
rwkv/requirements.txt Normal file

@@ -0,0 +1,3 @@
numpy
torch
tokenizers

rwkv/rwkv_cpp_shared_library.py

@@ -173,7 +173,7 @@ class RWKVSharedLibrary:
        Returns system information string.
        """

        return self.library.rwkv_get_system_info_string()
        return self.library.rwkv_get_system_info_string().decode('utf-8')

def load_rwkv_shared_library() -> RWKVSharedLibrary:
    """

40
rwkv/sampling.py Normal file

@@ -0,0 +1,40 @@
import numpy as np
import torch
from typing import Dict, Optional
from torch.nn import functional as F

def sample_logits(out: torch.Tensor, temperature: float = 1.0, top_p: float = 0.8, logit_bias: Optional[Dict[int, float]] = None) -> int:
    # Convert raw logits into a probability distribution, then sample from it.
    probs = F.softmax(out.cpu(), dim=-1).numpy()

    return sample_probs(probs, temperature, top_p, logit_bias)

def sample_probs(probs: np.ndarray, temperature: float = 1.0, top_p: float = 0.8, logit_bias: Optional[Dict[int, float]] = None) -> int:
    assert 0.0 <= temperature, 'temperature'
    assert 0.0 <= top_p <= 1.0, 'top_p'

    # top_p == 0.0 is treated as "disabled", same as 1.0.
    if top_p == 0.0:
        top_p = 1.0

    if logit_bias is not None:
        # Apply per-token bias in log space, then renormalize.
        logits = np.log(probs)

        for token in logit_bias.keys():
            logits[token] += logit_bias[token]

        probs = np.exp(logits) / np.sum(np.exp(logits))

    if temperature == 0.0:
        # Greedy sampling: always pick the most probable token.
        return np.argmax(probs).item()

    if top_p < 1.0:
        # Nucleus (top-p) sampling: zero out tokens outside the smallest set
        # whose cumulative probability exceeds top_p.
        sorted_probs = np.sort(probs)[::-1]
        cumulative_probs = np.cumsum(sorted_probs)
        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
        probs[probs < cutoff] = 0

    if temperature != 1.0:
        # Temperature < 1.0 sharpens the distribution, > 1.0 flattens it.
        probs = np.power(probs, 1.0 / temperature)

    # Renormalize and sample a token index.
    probs = probs / np.sum(probs)

    return np.random.choice(a=len(probs), p=probs)
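
A quick, illustrative way to sanity-check `sample_probs` outside of the scripts (not part of this commit; assumes it is run from the `rwkv` directory so that `sampling` imports):

```python
import numpy as np
import sampling

# Toy distribution over 4 tokens; copies are passed because sample_probs
# modifies the array in place when applying the top-p cutoff.
probs = np.array([0.1, 0.2, 0.3, 0.4])

# temperature=0.0 is greedy sampling: always the argmax, token 3 here.
assert sampling.sample_probs(probs.copy(), temperature=0.0) == 3

# With top_p=0.5, only tokens 2 and 3 survive the nucleus cutoff.
token = sampling.sample_probs(probs.copy(), temperature=0.8, top_p=0.5)
print('Sampled token:', token)
```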