diff --git a/rwkv/rwkv_model.py b/rwkv/rwkv_model.py
new file mode 100644
index 0000000..f0aa0eb
--- /dev/null
+++ b/rwkv/rwkv_model.py
@@ -0,0 +1,239 @@
+# Reference implementation of RWKV in PyTorch.
+
+# Original code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
+# Original code by https://github.com/BlinkDL, licensed under Apache License 2.0
+
+# Improvements made to the original code:
+# - safetensors loading support
+# - LoRA loading support
+# - ln0 absorption support
+# - general code style improvements
+
+import time
+import torch
+import types
+from typing import Union, Tuple, Dict, Optional
+from torch.nn import functional as F
+
+LORA_R: int = 4
+LORA_ALPHA: int = 32
+
+def load_state_dict(file_path: str, device: str) -> Dict[str, torch.Tensor]:
+    print(f'Loading {file_path}')
+
+    if file_path.endswith('.safetensors'):
+        from safetensors import safe_open
+
+        w = {}
+
+        with safe_open(file_path, framework='pt', device=device) as state_dict:
+            for key in state_dict.keys():
+                w[key] = state_dict.get_tensor(key)
+
+        return w
+    else:
+        return torch.load(file_path, map_location=device)
+
+def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
+    n_layer = 0
+
+    while f'blocks.{n_layer}.ln1.weight' in state_dict:
+        n_layer += 1
+
+    assert n_layer > 0
+
+    return n_layer
+
+class RWKV_RNN(torch.jit.ScriptModule):
+
+    def __init__(
+        self,
+        model_path: str,
+        additional_model_path: Optional[str] = None,
+        device: str = 'cpu',
+        absorb_layer_norm_0: bool = False
+    ):
+        super().__init__()
+
+        self.representation: torch.Tensor = torch.tensor([0], dtype=torch.float32, device=device)
+        self.eval()
+
+        print(f'Loading RWKV model from {model_path}')
+
+        w = load_state_dict(model_path, device)
+
+        if additional_model_path is not None:
+            additional_w = load_state_dict(additional_model_path, device)
+
+            for k in additional_w:
+                if k != '_training_state':
+                    w[k] = additional_w[k]
+
+            print('Merging LoRA into weights')
+
+            start = time.time()
+
+            for k in list(w.keys()):
+                module_k = k.replace('.weight', '')
+
+                if module_k + '.lora_A.weight' in w:
+                    lora_A = w[module_k + '.lora_A.weight']
+                    lora_B = w[module_k + '.lora_B.weight']
+                    assert lora_B.shape[1] == lora_A.shape[0] == LORA_R
+                    # Merge the LoRA update into the base weight: W <- W + (alpha / r) * B @ A
+                    w[module_k + '.weight'] = w[module_k + '.weight'] + lora_B @ lora_A * (LORA_ALPHA / LORA_R)
+                    del w[module_k + '.lora_A.weight']
+                    del w[module_k + '.lora_B.weight']
+                    del lora_A
+                    del lora_B
+
+            print('Took %.3f sec' % ((time.time() - start),))
+
+        for k in w.keys():
+            if '.time_' in k:
+                # (1, 1, n_embed) -> (n_embed)
+                w[k] = w[k].squeeze()
+
+            if '.time_decay' in k:
+                # The real time decay is like e^{-e^x}
+                w[k] = -torch.exp(w[k].float())
+            elif w[k].dtype != torch.float32:
+                w[k] = w[k].float()
+
+        self.w = types.SimpleNamespace()
+        self.w.blocks = {}
+
+        # Example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
+        for k in w.keys():
+            parts = k.split('.')
+            last = parts.pop()
+            here = self.w
+
+            for p in parts:
+                if p.isdigit():
+                    p = int(p)
+
+                    if p not in here:
+                        here[p] = types.SimpleNamespace()
+
+                    here = here[p]
+                else:
+                    if not hasattr(here, p):
+                        setattr(here, p, types.SimpleNamespace())
+
+                    here = getattr(here, p)
+
+            setattr(here, last, w[k])
+
+        # layer_norm() reads self.n_embed, so these must be set before the
+        # optional ln0 absorption below.
+        self.n_layer = get_layer_count(w)
+        self.n_embed = self.w.emb.weight.shape[1]
+
+        self.absorb_layer_norm_0 = absorb_layer_norm_0
+
+        if absorb_layer_norm_0:
+            print('Absorbing first LayerNorm into embedding matrix')
+
+            start = time.time()
+
+            for i in range(len(self.w.emb.weight)):
+                self.w.emb.weight[i] = self.layer_norm(self.w.emb.weight[i], self.w.blocks[0].ln0)
+
+            print('Took %.3f sec' % ((time.time() - start),))
+
+    def layer_norm(self, x, w):
+        return F.layer_norm(x, (self.n_embed,), weight=w.weight, bias=w.bias)
+
+    @torch.jit.script_method
+    def channel_mixing(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw):
+        xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
+        xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
+        state[5 * i + 0] = x
+        r = torch.sigmoid(rw @ xr)
+        k = torch.square(torch.relu(kw @ xk))  # square relu, primer paper
+        return r * (vw @ k)
+
+    @torch.jit.script_method
+    def time_mixing(self, x, state, i: int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
+        xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
+        xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
+        xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
+        state[5 * i + 1] = x
+        r = torch.sigmoid(rw @ xr)
+        k = kw @ xk
+        v = vw @ xv
+
+        # WKV: exponentially weighted average of past values, computed with a
+        # running maximum (pp) to keep the exponentials numerically stable.
+        aa = state[5 * i + 2]
+        bb = state[5 * i + 3]
+        pp = state[5 * i + 4]
+        ww = time_first + k
+        qq = torch.maximum(pp, ww)
+        e1 = torch.exp(pp - qq)
+        e2 = torch.exp(ww - qq)
+        a = e1 * aa + e2 * v
+        b = e1 * bb + e2
+        wkv = a / b
+        ww = pp + time_decay
+        qq = torch.maximum(ww, k)
+        e1 = torch.exp(ww - qq)
+        e2 = torch.exp(k - qq)
+        state[5 * i + 2] = e1 * aa + e2 * v
+        state[5 * i + 3] = e1 * bb + e2
+        state[5 * i + 4] = qq
+        return ow @ (r * wkv)
+
+    def warm_up(self):
+        print('Warming up the model')
+        start = time.time()
+        self.forward(0, None)
+        print('Took %.3f sec' % ((time.time() - start),))
+
+    def forward(self, token: int, state: Union[torch.Tensor, None], save_representation: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
+        with torch.no_grad():
+            x: torch.Tensor = self.w.emb.weight[token]
+
+            if state is None:
+                # Per-layer state layout: 0 = channel-mixing x, 1 = time-mixing x,
+                # 2 = WKV numerator, 3 = WKV denominator, 4 = WKV max exponent.
+                state = torch.zeros(self.n_layer * 5, self.n_embed, device=x.device)
+
+                for i in range(self.n_layer):
+                    # ~Negative infinity
+                    state[5 * i + 4] = -1e30
+
+            if not self.absorb_layer_norm_0:
+                x = self.layer_norm(x, self.w.blocks[0].ln0)
+
+            for i in range(self.n_layer):
+                att = self.w.blocks[i].att
+                x = x + self.time_mixing(
+                    self.layer_norm(x, self.w.blocks[i].ln1),
+                    state,
+                    i,
+                    att.time_mix_k,
+                    att.time_mix_v,
+                    att.time_mix_r,
+                    att.time_first,
+                    att.time_decay,
+                    att.key.weight,
+                    att.value.weight,
+                    att.receptance.weight,
+                    att.output.weight
+                )
+
+                ffn = self.w.blocks[i].ffn
+                x = x + self.channel_mixing(
+                    self.layer_norm(x, self.w.blocks[i].ln2),
+                    state,
+                    i,
+                    ffn.time_mix_k,
+                    ffn.time_mix_r,
+                    ffn.key.weight,
+                    ffn.value.weight,
+                    ffn.receptance.weight
+                )
+
+            x = self.layer_norm(x, self.w.ln_out)
+
+            if save_representation:
+                self.representation = x.clone()
+
+            x = (self.w.head.weight @ x).float()
+
+            return x, state
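
A minimal usage sketch, not part of the diff above: it loads a checkpoint into RWKV_RNN and greedily decodes a few tokens by threading the recurrent state through successive forward calls. The checkpoint paths and the hard-coded prompt token IDs are placeholders for illustration; a real caller would obtain token IDs from an RWKV-compatible tokenizer.

# Hypothetical usage example; file paths and token IDs are placeholders.
import torch

from rwkv.rwkv_model import RWKV_RNN

model = RWKV_RNN(
    'model.safetensors',               # base RWKV checkpoint (assumed path)
    additional_model_path='lora.pth',  # optional LoRA checkpoint (assumed path)
    device='cpu',
    absorb_layer_norm_0=True
)

prompt_tokens = [510, 3158, 17899]     # token IDs from an RWKV tokenizer (illustrative)
state = None

# Feed the prompt one token at a time, carrying the recurrent state forward.
for token in prompt_tokens:
    logits, state = model.forward(token, state)

# Greedy decoding: repeatedly pick the highest-scoring next token.
generated = []
for _ in range(32):
    next_token = int(torch.argmax(logits).item())
    generated.append(next_token)
    logits, state = model.forward(next_token, state)

print(generated)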