diff --git a/README.md b/README.md
index 1f409a7..f06fbf3 100644
--- a/README.md
+++ b/README.md
@@ -6,14 +6,13 @@ This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [gger
 
 ## Plan
 
-1. Remove reference implementation code from this repo
-2. Heavily refactor code; optimize where possible
-3. Make FP16 inference work
-4. Create proper interface (probably, C library)
-5. Create Python wrapper with sampling and simple chat interface
-6. Write a good `README.md` and publish links to this repo
-7. Make INT4 inference work
-8. Create pull request to main `ggml` repo with all improvements made here
+1. Heavily refactor code; optimize where possible
+2. Make FP16 inference work
+3. Create a proper interface (probably a C library)
+4. Create Python wrapper with sampling and simple chat interface
+5. Write a good `README.md` and publish links to this repo
+6. Make INT4 inference work
+7. Create pull request to main `ggml` repo with all improvements made here
 
 ## Structure
 
diff --git a/rwkv/compare_cpp_with_reference_implementation.py b/rwkv/compare_cpp_with_reference_implementation.py
index 8c7f080..8123900 100644
--- a/rwkv/compare_cpp_with_reference_implementation.py
+++ b/rwkv/compare_cpp_with_reference_implementation.py
@@ -1,16 +1,17 @@
 # Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
-# Usage: python compare_cpp_with_reference_implementation.py C:\RWKV-4-Pile-169M-20220807-8023.pth bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
+# Reference logits were generated with the RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
+# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
+# Usage: python compare_cpp_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
 
+import os
 import argparse
 import subprocess
-import rwkv_model
 import torch
 import numpy as np
 from typing import List
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
-    parser.add_argument('torch_model_path', help='Path to PyTorch checkpoint file')
     parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file')
     parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
     return parser.parse_args()
@@ -18,21 +19,27 @@ def parse_args():
 def main() -> None:
     args = parse_args()
 
-    token_count: int = 64
-    # It's not important what exactly these tokens are; just that the output of both models matches.
-    tokens: List[int] = [(i + 1) for i in range(token_count)]
+    # We don't want to depend on the tokenizer here.
+    # The exact string is:
+    # context = "1 In the beginning God created the heaven and the earth. " \
+    #           "2 And the earth was without form, and void; and darkness was upon the face of the deep. And the Spirit of God moved upon the face of the waters. " \
+    #           "3 And God said, Let there be light: and there was light. " \
+    #           "4 And God saw the light, that it was good: and God divided the light from the darkness."
+    # The Bible was the first public domain text that came to my mind.
+    tokens: List[int] = [18, 496, 253, 5068, 2656, 3562, 253, 13926, 285, 253, 6149, 15, 374, 1244, 253, 6149, 369, 1293, 830,
+                         13, 285, 2991, 28, 285, 13862, 369, 2220, 253, 2454, 273, 253, 3676, 15, 1244, 253, 14559, 273, 2656,
+                         4395, 2220, 253, 2454, 273, 253, 12685, 15, 495, 1244, 2656, 753, 13, 1281, 627, 320, 1708, 27, 285,
+                         627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272,
+                         253, 1708, 432, 253, 13862, 15]
+
+    token_count: int = len(tokens)
 
     state_path: str = './state.bin'
     logits_path: str = './logits.bin'
 
-    reference_model: rwkv_model.RWKV_RNN = rwkv_model.RWKV_RNN(args.torch_model_path)
-
-    ref_logits, ref_state = None, None
-
     for i in range(token_count):
         token: int = tokens[i]
-        print()
-        print(f'--- {i + 1}/{token_count} ---')
+        print(f'{i + 1}/{token_count}')
 
         subprocess.run(
             [
@@ -40,25 +47,31 @@ def main() -> None:
                 args.ggml_model_path,
                 str(token),
                 # If this is the first token, let the script create a new state.
-                '' if ref_state is None else state_path,
+                '' if i == 0 else state_path,
                 state_path,
                 logits_path
             ],
             check=True
         )
 
-        with open(logits_path, 'rb') as logits_file:
-            actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
+    expected_logits_path: str = 'expected_logits_169M_20220807_8023.bin'
 
-        ref_logits, ref_state = reference_model.forward(token, ref_state)
+    if not os.path.isfile(expected_logits_path):
+        expected_logits_path = 'rwkv/' + expected_logits_path
 
-        difference: float = (torch.sum(ref_logits - actual_logits) / len(ref_logits)).item()
+    with open(expected_logits_path, 'rb') as logits_file:
+        expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
 
-        print(f'Reference logits: {ref_logits}')
-        print(f'Actual logits: {actual_logits}')
-        print('Difference per token: %.8f' % (difference,))
+    with open(logits_path, 'rb') as logits_file:
+        actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
 
-        assert abs(difference) <= 0.00005, 'Difference is too big'
+    difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
+
+    print(f'Reference logits: {expected_logits}')
+    print(f'Actual logits: {actual_logits}')
+    print('Difference per logit: %.8f' % (difference,))
+
+    assert abs(difference) <= 0.00005, 'Difference is too big'
 
     print()
     print('Test passes')
diff --git a/rwkv/expected_logits_169M_20220807_8023.bin b/rwkv/expected_logits_169M_20220807_8023.bin
new file mode 100644
index 0000000..9ce6ca1
Binary files /dev/null and b/rwkv/expected_logits_169M_20220807_8023.bin differ
diff --git a/rwkv/rwkv_model.py b/rwkv/rwkv_model.py
deleted file mode 100644
index f0aa0eb..0000000
--- a/rwkv/rwkv_model.py
+++ /dev/null
@@ -1,239 +0,0 @@
-# Reference implementation of RWKV in PyTorch.
-
-# Original code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
-# Original code by https://github.com/BlinkDL, licensed under Apache License 2.0
-
-# Improvements made to the original code:
-# - safetensors loading support
-# - LoRA loading support
-# - ln0 absorption support
-# - general code style improvements
-
-import time
-import torch
-import types
-from typing import Union, Tuple, Dict, Optional
-from torch.nn import functional as F
-
-LORA_R: int = 4
-LORA_ALPHA: int = 32
-
-def load_state_dict(file_path: str, device: str) -> Dict[str, torch.Tensor]:
-    print(f'Loading {file_path}')
-
-    if file_path.endswith('.safetensors'):
-        from safetensors import safe_open
-
-        w = {}
-
-        with safe_open(file_path, framework='pt', device=device) as state_dict:
-            for key in state_dict.keys():
-                w[key] = state_dict.get_tensor(key)
-
-        return w
-    else:
-        return torch.load(file_path, map_location=device)
-
-def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
-    n_layer = 0
-
-    while f'blocks.{n_layer}.ln1.weight' in state_dict:
-        n_layer += 1
-
-    assert n_layer > 0
-
-    return n_layer
-
-class RWKV_RNN(torch.jit.ScriptModule):
-
-    def __init__(
-            self,
-            model_path: str,
-            additional_model_path: Optional[str] = None,
-            device: str = 'cpu',
-            absorb_layer_norm_0: bool = False
-    ):
-        super().__init__()
-
-        self.representation: torch.Tensor = torch.tensor([0], dtype=torch.float32, device=device)
-        self.eval()
-
-        print(f'Loading RWKV model from {model_path}')
-
-        w = load_state_dict(model_path, device)
-
-        if additional_model_path is not None:
-            additional_w = load_state_dict(additional_model_path, device)
-
-            for k in additional_w:
-                if k != '_training_state':
-                    w[k] = additional_w[k]
-
-            print('Merging LoRA into weights')
-
-            start = time.time()
-
-            for k in list(w.keys()):
-                module_k = k.replace('.weight', '')
-
-                if module_k + '.lora_A.weight' in w:
-                    lora_A = w[module_k + '.lora_A.weight']
-                    lora_B = w[module_k + '.lora_B.weight']
-                    assert lora_B.shape[1] == lora_A.shape[0] == LORA_R
-                    w[module_k + '.weight'] = w[module_k + '.weight'] + lora_B @ lora_A * (LORA_ALPHA / LORA_R)
-                    del w[module_k + '.lora_A.weight']
-                    del w[module_k + '.lora_B.weight']
-                    del lora_A
-                    del lora_B
-
-            print('Took %.3f sec' % ((time.time() - start),))
-
-        for k in w.keys():
-            if '.time_' in k:
-                # (1, 1, n_embed) -> (n_embed)
-                w[k] = w[k].squeeze()
-
-            if '.time_decay' in k:
-                # The real time decay is like e^{-e^x}
-                w[k] = -torch.exp(w[k].float())
-            elif w[k].dtype != torch.float32:
-                w[k] = w[k].float()
-
-        self.w = types.SimpleNamespace()
-        self.w.blocks = {}
-
-        # Example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
-        for k in w.keys():
-            parts = k.split('.')
-            last = parts.pop()
-            here = self.w
-
-            for p in parts:
-                if p.isdigit():
-                    p = int(p)
-
-                    if p not in here:
-                        here[p] = types.SimpleNamespace()
-
-                    here = here[p]
-                else:
-                    if not hasattr(here, p):
-                        setattr(here, p, types.SimpleNamespace())
-
-                    here = getattr(here, p)
-
-            setattr(here, last, w[k])
-
-        self.absorb_layer_norm_0 = absorb_layer_norm_0
-
-        if absorb_layer_norm_0:
-            print('Absorbing first LayerNorm into embedding matrix')
-
-            start = time.time()
-
-            for i in range(len(self.w.emb.weight)):
-                self.w.emb.weight[i] = self.layer_norm(self.w.emb.weight[i], self.w.blocks[0].ln0)
-
-            print('Took %.3f sec' % ((time.time() - start),))
-
-        self.n_layer = get_layer_count(w)
-        self.n_embed = self.w.emb.weight.shape[1]
-
-    def layer_norm(self, x, w):
-        return F.layer_norm(x, (self.n_embed,), weight=w.weight, bias=w.bias)
-
-    @torch.jit.script_method
-    def channel_mixing(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw):
-        xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
-        xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
-        state[5 * i + 0] = x
-        r = torch.sigmoid(rw @ xr)
-        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
-        return r * (vw @ k)
-
-    @torch.jit.script_method
-    def time_mixing(self, x, state, i: int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
-        xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
-        xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
-        xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
-        state[5 * i + 1] = x
-        r = torch.sigmoid(rw @ xr)
-        k = kw @ xk
-        v = vw @ xv
-
-        aa = state[5 * i + 2]
-        bb = state[5 * i + 3]
-        pp = state[5 * i + 4]
-        ww = time_first + k
-        qq = torch.maximum(pp, ww)
-        e1 = torch.exp(pp - qq)
-        e2 = torch.exp(ww - qq)
-        a = e1 * aa + e2 * v
-        b = e1 * bb + e2
-        wkv = a / b
-        ww = pp + time_decay
-        qq = torch.maximum(ww, k)
-        e1 = torch.exp(ww - qq)
-        e2 = torch.exp(k - qq)
-        state[5 * i + 2] = e1 * aa + e2 * v
-        state[5 * i + 3] = e1 * bb + e2
-        state[5 * i + 4] = qq
-        return ow @ (r * wkv)
-
-    def warm_up(self):
-        print('Warming up the model')
-        start = time.time()
-        self.forward(0, None)
-        print('Took %.3f sec' % ((time.time() - start),))
-
-    def forward(self, token: int, state: Union[torch.Tensor, None], save_representation: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
-        with torch.no_grad():
-            x: torch.Tensor = self.w.emb.weight[token]
-
-            if state is None:
-                state = torch.zeros(self.n_layer * 5, self.n_embed, device=x.device)
-
-                for i in range(self.n_layer):
-                    # ~Negative infinity
-                    state[5 * i + 4] = -1e30
-
-            if not self.absorb_layer_norm_0:
-                x = self.layer_norm(x, self.w.blocks[0].ln0)
-
-            for i in range(self.n_layer):
-                att = self.w.blocks[i].att
-                x = x + self.time_mixing(
-                    self.layer_norm(x, self.w.blocks[i].ln1),
-                    state,
-                    i,
-                    att.time_mix_k,
-                    att.time_mix_v,
-                    att.time_mix_r,
-                    att.time_first,
-                    att.time_decay,
-                    att.key.weight,
-                    att.value.weight,
-                    att.receptance.weight,
-                    att.output.weight
-                )
-
-                ffn = self.w.blocks[i].ffn
-                x = x + self.channel_mixing(
-                    self.layer_norm(x, self.w.blocks[i].ln2),
-                    state,
-                    i,
-                    ffn.time_mix_k,
-                    ffn.time_mix_r,
-                    ffn.key.weight,
-                    ffn.value.weight,
-                    ffn.receptance.weight
-                )
-
-            x = self.layer_norm(x, self.w.ln_out)
-
-            if save_representation:
-                self.representation = x.clone()
-
-            x = (self.w.head.weight @ x).float()
-
-            return x, state
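
Note: rwkv/expected_logits_169M_20220807_8023.bin holds the reference model's logits after the whole token sequence above has been fed in, dumped as raw float32 -- the same format the comparison script reads back with np.frombuffer(..., dtype=np.single). If the file ever needs to be regenerated, a sketch along these lines would do it. The script name is hypothetical; it assumes a copy of the rwkv_model.py deleted above is importable and that the RWKV-4-Pile-169M-20220807-8023.pth checkpoint sits in the working directory.

# regenerate_expected_logits.py -- hypothetical helper, not part of this commit.
import numpy as np
import rwkv_model  # the reference implementation removed by this commit

# The same token sequence as in compare_cpp_with_reference_implementation.py.
tokens = [18, 496, 253, 5068, 2656, 3562, 253, 13926, 285, 253, 6149, 15, 374, 1244, 253, 6149, 369, 1293, 830,
          13, 285, 2991, 28, 285, 13862, 369, 2220, 253, 2454, 273, 253, 3676, 15, 1244, 253, 14559, 273, 2656,
          4395, 2220, 253, 2454, 273, 253, 12685, 15, 495, 1244, 2656, 753, 13, 1281, 627, 320, 1708, 27, 285,
          627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272,
          253, 1708, 432, 253, 13862, 15]

model = rwkv_model.RWKV_RNN('RWKV-4-Pile-169M-20220807-8023.pth')

logits, state = None, None

# Feed the whole sequence; only the logits after the last token are kept.
for token in tokens:
    logits, state = model.forward(token, state)

logits.numpy().astype(np.single).tofile('expected_logits_169M_20220807_8023.bin')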
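
On the exp/maximum shuffling in the deleted time_mixing: it is a numerically stable rewrite of a plain recurrence over exponentially weighted sums, A_t = exp(time_decay) * A_{t-1} + exp(k_t) * v_t (and B_t likewise, without v), with the WKV output read out as (A + exp(time_first + k) * v) / (B + exp(time_first + k)). The aa and bb state entries store A and B rescaled by exp(-pp), so no exponential ever overflows. A self-contained check of that equivalence (hypothetical snippet, not from the repo, using small random inputs where the naive form is safe):

import torch

torch.manual_seed(0)

n_embed = 8
time_first = torch.randn(n_embed)
time_decay = -torch.exp(torch.randn(n_embed))  # matches w[k] = -torch.exp(w[k].float()) above

def naive_wkv(ks, vs):
    # Direct exponential sums; overflows once k values get large.
    A, B = torch.zeros(n_embed), torch.zeros(n_embed)
    out = []
    for k, v in zip(ks, vs):
        e = torch.exp(time_first + k)
        out.append((A + e * v) / (B + e))
        A = torch.exp(time_decay) * A + torch.exp(k) * v
        B = torch.exp(time_decay) * B + torch.exp(k)
    return torch.stack(out)

def stable_wkv(ks, vs):
    # Same recurrence, but A and B are kept rescaled by exp(-pp), as in time_mixing.
    aa, bb = torch.zeros(n_embed), torch.zeros(n_embed)
    pp = torch.full((n_embed,), -1e30)  # ~negative infinity, as in forward()
    out = []
    for k, v in zip(ks, vs):
        ww = time_first + k
        qq = torch.maximum(pp, ww)
        e1, e2 = torch.exp(pp - qq), torch.exp(ww - qq)
        out.append((e1 * aa + e2 * v) / (e1 * bb + e2))
        ww = pp + time_decay
        qq = torch.maximum(ww, k)
        e1, e2 = torch.exp(ww - qq), torch.exp(k - qq)
        aa, bb, pp = e1 * aa + e2 * v, e1 * bb + e2, qq
    return torch.stack(out)

ks, vs = torch.randn(16, n_embed), torch.randn(16, n_embed)
assert torch.allclose(naive_wkv(ks, vs), stable_wkv(ks, vs), atol=1e-5)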