Add text generation and chat scripts

saharNooby 2023-04-02 15:03:31 +04:00
parent ee46ad208e
commit e0684e8104
7 changed files with 100757 additions and 19 deletions

README.md

@@ -2,29 +2,23 @@
 This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml).
-Besides usual **FP32**, it supports **FP16** and **quantized INT4** inference on CPU. This project is **CPU only**.
+Besides the usual **FP32**, it supports **FP16** and **quantized INT4** inference on CPU. This project is **CPU only**.
-**WORK IN PROGRESS!** **Status**: INT4 gives not so good quality, need to properly measure and compare perplexity.
+RWKV is a novel large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to a Transformer with `O(n^2)` attention, RWKV requires only the state from the previous step to calculate logits. This makes RWKV very CPU-friendly on large context lengths.
-## Plan
+**TODO**:
-1. Create Python script with sampling and simple chat interface
-2. Measure performance and quality of different model sizes and data types
-3. Write a good `README.md` and publish links to this repo
-4. Create pull request to main `ggml` repo with all improvements made here
+1. Measure performance and perplexity of different model sizes and data types
+2. Write a good `README.md` (motivation, benchmarks, perplexity) and publish links to this repo
+3. Create pull request to main `ggml` repo with all improvements made here
+## Structure
+- `./rwkv.h`, `./rwkv.cpp`: source code of the shared library.
+- `./rwkv`: directory containing Python scripts for conversion, inference and validation.
 ## How to use
+### 1. Clone the repo and build the library
 ### Windows
-Requirements: [git](https://gitforwindows.org/), [CMake](https://cmake.org/download/), MSVC compiler, Python 3.x with PyTorch.
+**Requirements**: [git](https://gitforwindows.org/), [CMake](https://cmake.org/download/), MSVC compiler.
-#### 1. Clone the repo and build it:
 ```commandline
 git clone https://github.com/saharNooby/rwkv.cpp.git
@@ -35,16 +29,37 @@ cmake --build . --config Release
 If everything went OK, `bin\Release\rwkv.dll` file should appear.
-#### 2. Download an RWKV model from [Huggingface](https://huggingface.co/BlinkDL) and convert it into `ggml` format:
+### 2. Download an RWKV model from [Hugging Face](https://huggingface.co/BlinkDL) and convert it into `ggml` format
+**Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/).
 ```commandline
 python rwkv\convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
 ```
-#### 3. Use the model in Python:
+### 3. Run the model
+**Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/) and [tokenizers](https://pypi.org/project/tokenizers/).
+To generate some text, run:
+```commandline
+python rwkv\generate_completions.py C:\rwkv.cpp-169M.bin
+```
+To chat with a bot, run:
+```commandline
+python rwkv\chat_with_bot.py C:\rwkv.cpp-169M.bin
+```
+Edit [generate_completions.py](rwkv%2Fgenerate_completions.py) or [chat_with_bot.py](rwkv%2Fchat_with_bot.py) to change prompts and sampling settings.
+---
+Example of using `rwkv.cpp` in your custom Python script:
 ```python
+# These files are located in rwkv directory
 import rwkv_cpp_model
 import rwkv_cpp_shared_library
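The README hunk ends at the imports above; the rest of the example is unchanged context and is not shown in this diff. For orientation, here is a minimal sketch of how such a custom script can use the API exercised by the scripts added in this commit (`rwkv_cpp_shared_library.load_rwkv_shared_library`, `rwkv_cpp_model.RWKVModel`, `model.eval`, `sampling.sample_logits`). The model path and prompt below are illustrative, not taken from the README.

```python
import sampling
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library

# Load the 20B tokenizer that ships in the rwkv directory.
tokenizer = tokenizers.Tokenizer.from_file('20B_tokenizer.json')

# Load the shared library and a converted ggml model (path is illustrative).
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, 'C:\\rwkv.cpp-169M.bin')

# Feed the prompt one token at a time; eval() returns updated logits and state.
logits, state = None, None

for token in tokenizer.encode('Hello! Tell me about RWKV.').ids:
    logits, state = model.eval(token, state, state, logits)

# Sample a short continuation token by token.
for _ in range(16):
    token = sampling.sample_logits(logits, temperature=0.8, top_p=0.5)
    print(tokenizer.decode([token]), end='')
    logits, state = model.eval(token, state, state, logits)
```

The key point is that generation is a simple recurrence: each `eval` call consumes one token and returns the logits plus the state needed for the next step, which is what makes RWKV cheap on long contexts.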

rwkv/20B_tokenizer.json Normal file (100525 lines)
File diff suppressed because it is too large

rwkv/chat_with_bot.py Normal file (87 lines)
@@ -0,0 +1,87 @@
# Provides terminal-based chat interface for RWKV model.

import sys
import argparse
import sampling
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library

# ======================================== Script settings ========================================

# Copied from https://github.com/ggerganov/llama.cpp/blob/6e7801d08d81c931a5427bae46f00763e993f54a/prompts/chat-with-bob.txt
prompt: str = """Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.

User: Hello, Bob.
Bob: Hello. How may I help you today?
User: Please tell me the largest city in Europe.
Bob: Sure. The largest city in Europe is Moscow, the capital of Russia."""

# No trailing space here!
bot_message_prefix: str = 'Bob:'
user_message_prefix: str = 'User:'

max_tokens_per_generation: int = 100

# Sampling settings.
temperature: float = 0.8
top_p: float = 0.5

# =================================================================================================

parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
args = parser.parse_args()

assert prompt != '', 'Prompt must not be empty'

print('Loading 20B tokenizer')
tokenizer = tokenizers.Tokenizer.from_file('20B_tokenizer.json')

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')

print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)

prompt_tokens = tokenizer.encode(prompt).ids
prompt_token_count = len(prompt_tokens)
print(f'Processing {prompt_token_count} prompt tokens, may take a while')

logits, state = None, None

for token in prompt_tokens:
    logits, state = model.eval(token, state, state, logits)

print('\nChat initialized! Write something and press Enter.')

while True:
    # Read user input
    print('> ', end='')
    user_input = sys.stdin.readline()

    # Process the input
    new_tokens = tokenizer.encode('\n' + user_message_prefix + ' ' + user_input + '\n' + bot_message_prefix).ids

    for token in new_tokens:
        logits, state = model.eval(token, state, state, logits)

    # Generate and print bot response
    print(bot_message_prefix, end='')

    decoded = ''

    for i in range(max_tokens_per_generation):
        token = sampling.sample_logits(logits, temperature, top_p)
        decoded = tokenizer.decode([token])
        print(decoded, end='')

        if '\n' in decoded:
            break

        logits, state = model.eval(token, state, state, logits)

    if '\n' not in decoded:
        print()

rwkv/generate_completions.py Normal file (68 lines)
@@ -0,0 +1,68 @@
# Generates completions from RWKV model based on a prompt.

import argparse
import time
import sampling
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library

# ======================================== Script settings ========================================

prompt: str = """# rwkv.cpp

This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml).

Besides usual **FP32**, it supports **FP16** and **quantized INT4** inference on CPU. This project is **CPU only**."""

# How many completions to generate.
generation_count: int = 3
# Token count per single completion.
tokens_per_generation: int = 100

# Sampling settings.
temperature: float = 0.8
top_p: float = 0.5

# =================================================================================================

parser = argparse.ArgumentParser(description='Generate completions from RWKV model based on a prompt')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
args = parser.parse_args()

assert prompt != '', 'Prompt must not be empty'

print('Loading 20B tokenizer')
tokenizer = tokenizers.Tokenizer.from_file('20B_tokenizer.json')

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')

print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)

prompt_tokens = tokenizer.encode(prompt).ids
prompt_token_count = len(prompt_tokens)
print(f'{prompt_token_count} tokens in prompt')

init_logits, init_state = None, None

for token in prompt_tokens:
    init_logits, init_state = model.eval(token, init_state, init_state, init_logits)

for GENERATION in range(generation_count):
    print(f'\n--- Generation {GENERATION} ---\n')
    print(prompt, end='[')
    start = time.time()

    logits, state = init_logits.clone(), init_state.clone()

    for i in range(tokens_per_generation):
        token = sampling.sample_logits(logits, temperature, top_p)
        print(tokenizer.decode([token]), end='')

        logits, state = model.eval(token, state, state, logits)

    delay = time.time() - start

    print(']\n\nTook %.3f sec, %d ms per token' % (delay, delay / tokens_per_generation * 1000))

rwkv/requirements.txt Normal file (3 lines)
@@ -0,0 +1,3 @@
numpy
torch
tokenizers

rwkv/rwkv_cpp_shared_library.py
@@ -173,7 +173,7 @@ class RWKVSharedLibrary:
         Returns system information string.
         """
-        return self.library.rwkv_get_system_info_string()
+        return self.library.rwkv_get_system_info_string().decode('utf-8')

 def load_rwkv_shared_library() -> RWKVSharedLibrary:
     """

rwkv/sampling.py Normal file (40 lines)
@@ -0,0 +1,40 @@
import numpy as np
import torch
from typing import Dict
from torch.nn import functional as F

def sample_logits(out: torch.Tensor, temperature: float = 1.0, top_p: float = 0.8, logit_bias: Dict[int, float] = None) -> int:
    probs = F.softmax(out.cpu(), dim=-1).numpy()

    return sample_probs(probs, temperature, top_p, logit_bias)

def sample_probs(probs: np.ndarray, temperature: float = 1.0, top_p: float = 0.8, logit_bias: Dict[int, float] = None) -> int:
    assert 0.0 <= temperature, 'temperature'
    assert 0.0 <= top_p <= 1.0, 'top_p'

    if top_p == 0.0:
        top_p = 1.0

    if logit_bias is not None:
        logits = np.log(probs)

        for token in logit_bias.keys():
            logits[token] += logit_bias[token]

        probs = np.exp(logits) / np.sum(np.exp(logits))

    if temperature == 0.0:
        return np.argmax(probs).item()

    if top_p < 1.0:
        sorted_probs = np.sort(probs)[::-1]
        cumulative_probs = np.cumsum(sorted_probs)
        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
        probs[probs < cutoff] = 0

    if temperature != 1.0:
        probs = np.power(probs, 1.0 / temperature)

    probs = probs / np.sum(probs)

    return np.random.choice(a=len(probs), p=probs)
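
For reference, a small usage sketch of these helpers; the toy logits and the expectations in the comments are illustrative and assume only the behavior visible in the functions above.

```python
import torch
import sampling

# Toy logits over a 4-token vocabulary.
logits = torch.tensor([2.0, 1.0, 0.5, -1.0])

# temperature == 0.0 is greedy: always returns the argmax index (0 here).
assert sampling.sample_logits(logits, temperature=0.0) == 0

# Nucleus (top-p) sampling: tokens below the cutoff probability, i.e. outside
# the most probable set whose cumulative mass reaches top_p, are zeroed out
# before renormalization and sampling.
token = sampling.sample_logits(logits, temperature=0.8, top_p=0.5)
print('sampled token id:', token)

# logit_bias adds a per-token offset in log space before sampling.
token = sampling.sample_logits(logits, logit_bias={3: 5.0})
print('biased sample:', token)
```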