rwkv.cpp/rwkv/rwkv_model.py

# Reference implementation of RWKV in PyTorch.
# Original code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
# Original code by https://github.com/BlinkDL, licensed under Apache License 2.0
# Improvements made to the original code:
# - safetensors loading support
# - LoRA loading support
# - ln0 absorption support
# - general code style improvements
import time
import torch
import types
from typing import Union, Tuple, Dict, Optional
from torch.nn import functional as F
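
# LoRA hyperparameters used by the merge in RWKV_RNN.__init__ below: low-rank
# update matrices of rank LORA_R are scaled by LORA_ALPHA / LORA_R and added to
# the base weights (W' = W + (LORA_ALPHA / LORA_R) * B @ A). These values are
# assumed to match the ones the LoRA checkpoint was trained with; the rank is
# checked by an assert during the merge, the alpha is not.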
LORA_R: int = 4
LORA_ALPHA: int = 32


def load_state_dict(file_path: str, device: str) -> Dict[str, torch.Tensor]:
    print(f'Loading {file_path}')

    if file_path.endswith('.safetensors'):
        from safetensors import safe_open

        w = {}

        with safe_open(file_path, framework='pt', device=device) as state_dict:
            for key in state_dict.keys():
                w[key] = state_dict.get_tensor(key)

        return w
    else:
        return torch.load(file_path, map_location=device)


def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
    n_layer = 0

    while f'blocks.{n_layer}.ln1.weight' in state_dict:
        n_layer += 1

    assert n_layer > 0

    return n_layer
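
# The loaded state dict is assumed to follow the standard RWKV naming scheme,
# with keys such as 'emb.weight', 'blocks.N.ln1.weight', 'blocks.N.att.*',
# 'blocks.N.ffn.*', 'ln_out.*' and 'head.weight'; get_layer_count relies on the
# per-block 'ln1.weight' keys to detect the number of layers.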


class RWKV_RNN(torch.jit.ScriptModule):

    def __init__(
        self,
        model_path: str,
        additional_model_path: Optional[str] = None,
        device: str = 'cpu',
        absorb_layer_norm_0: bool = False
    ):
        super().__init__()

        self.representation: torch.Tensor = torch.tensor([0], dtype=torch.float32, device=device)

        self.eval()

        print(f'Loading RWKV model from {model_path}')

        w = load_state_dict(model_path, device)

        if additional_model_path is not None:
            additional_w = load_state_dict(additional_model_path, device)

            for k in additional_w:
                if k != '_training_state':
                    w[k] = additional_w[k]

            print('Merging LoRA into weights')

            start = time.time()

            for k in list(w.keys()):
                module_k = k.replace('.weight', '')

                if module_k + '.lora_A.weight' in w:
                    lora_A = w[module_k + '.lora_A.weight']
                    lora_B = w[module_k + '.lora_B.weight']

                    assert lora_B.shape[1] == lora_A.shape[0] == LORA_R

                    w[module_k + '.weight'] = w[module_k + '.weight'] + lora_B @ lora_A * (LORA_ALPHA / LORA_R)

                    del w[module_k + '.lora_A.weight']
                    del w[module_k + '.lora_B.weight']
                    del lora_A
                    del lora_B

            print('Took %.3f sec' % ((time.time() - start),))
        for k in w.keys():
            if '.time_' in k:
                # (1, 1, n_embed) -> (n_embed)
                w[k] = w[k].squeeze()

            if '.time_decay' in k:
                # The real time decay is like e^{-e^x}
                w[k] = -torch.exp(w[k].float())
            elif w[k].dtype != torch.float32:
                w[k] = w[k].float()

        self.w = types.SimpleNamespace()
        self.w.blocks = {}

        # Example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
        for k in w.keys():
            parts = k.split('.')
            last = parts.pop()

            here = self.w

            for p in parts:
                if p.isdigit():
                    p = int(p)

                    if p not in here:
                        here[p] = types.SimpleNamespace()

                    here = here[p]
                else:
                    if not hasattr(here, p):
                        setattr(here, p, types.SimpleNamespace())

                    here = getattr(here, p)

            setattr(here, last, w[k])

        # n_layer and n_embed must be known before layer_norm is called below
        self.n_layer = get_layer_count(w)
        self.n_embed = self.w.emb.weight.shape[1]

        self.absorb_layer_norm_0 = absorb_layer_norm_0

        if absorb_layer_norm_0:
            print('Absorbing first LayerNorm into embedding matrix')

            start = time.time()

            for i in range(len(self.w.emb.weight)):
                self.w.emb.weight[i] = self.layer_norm(self.w.emb.weight[i], self.w.blocks[0].ln0)

            print('Took %.3f sec' % ((time.time() - start),))

    def layer_norm(self, x, w):
        return F.layer_norm(x, (self.n_embed,), weight=w.weight, bias=w.bias)

    @torch.jit.script_method
    def channel_mixing(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw):
        xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
        state[5 * i + 0] = x

        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk))  # square relu, primer paper

        return r * (vw @ k)

    @torch.jit.script_method
    def time_mixing(self, x, state, i: int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
        xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
        state[5 * i + 1] = x

        r = torch.sigmoid(rw @ xr)
        k = kw @ xk
        v = vw @ xv

        # WKV state: aa and bb accumulate the numerator and denominator of the
        # exponentially weighted average; pp tracks the largest exponent seen so
        # far, so only non-positive values are ever passed to torch.exp
        aa = state[5 * i + 2]
        bb = state[5 * i + 3]
        pp = state[5 * i + 4]

        ww = time_first + k
        qq = torch.maximum(pp, ww)
        e1 = torch.exp(pp - qq)
        e2 = torch.exp(ww - qq)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        wkv = a / b

        ww = pp + time_decay
        qq = torch.maximum(ww, k)
        e1 = torch.exp(ww - qq)
        e2 = torch.exp(k - qq)
        state[5 * i + 2] = e1 * aa + e2 * v
        state[5 * i + 3] = e1 * bb + e2
        state[5 * i + 4] = qq

        return ow @ (r * wkv)

    def warm_up(self):
        print('Warming up the model')

        start = time.time()

        self.forward(0, None)

        print('Took %.3f sec' % ((time.time() - start),))

    def forward(self, token: int, state: Union[torch.Tensor, None], save_representation: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.no_grad():
            x: torch.Tensor = self.w.emb.weight[token]

            if state is None:
                # 5 state vectors per layer: ffn xx, att xx, att aa, att bb, att pp
                state = torch.zeros(self.n_layer * 5, self.n_embed, device=x.device)

                for i in range(self.n_layer):
                    # ~Negative infinity
                    state[5 * i + 4] = -1e30

            if not self.absorb_layer_norm_0:
                x = self.layer_norm(x, self.w.blocks[0].ln0)

            for i in range(self.n_layer):
                att = self.w.blocks[i].att

                x = x + self.time_mixing(
                    self.layer_norm(x, self.w.blocks[i].ln1),
                    state,
                    i,
                    att.time_mix_k,
                    att.time_mix_v,
                    att.time_mix_r,
                    att.time_first,
                    att.time_decay,
                    att.key.weight,
                    att.value.weight,
                    att.receptance.weight,
                    att.output.weight
                )

                ffn = self.w.blocks[i].ffn

                x = x + self.channel_mixing(
                    self.layer_norm(x, self.w.blocks[i].ln2),
                    state,
                    i,
                    ffn.time_mix_k,
                    ffn.time_mix_r,
                    ffn.key.weight,
                    ffn.value.weight,
                    ffn.receptance.weight
                )

            x = self.layer_norm(x, self.w.ln_out)

            if save_representation:
                self.representation = x.clone()

            x = (self.w.head.weight @ x).float()

            return x, state
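

# Minimal usage sketch, not part of the reference implementation above.
# The checkpoint path and token ids below are placeholders: substitute a real
# RWKV .pth or .safetensors checkpoint and tokens produced by its tokenizer.
if __name__ == '__main__':
    model = RWKV_RNN('RWKV-4-Pile-169M-20220807-8023.pth', device='cpu')

    state = None
    logits = None

    # Feed a short, arbitrary prompt token by token, threading the state through.
    for token in [510, 3158, 8516, 30013]:
        logits, state = model.forward(token, state)

    # Greedily pick the most likely next token from the final logits.
    print('Most likely next token id:', int(torch.argmax(logits).item()))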