Remove reference implementation code and test against pre-created logits
parent bf88e8a246
commit 0fcb7c64c6

README.md: 15 lines changed
@@ -6,14 +6,13 @@ This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [gger
 
 ## Plan
 
-1. Remove reference implementation code from this repo
-2. Heavily refactor code; optimize where possible
-3. Make FP16 inference work
-4. Create proper interface (probably, C library)
-5. Create Python wrapper with sampling and simple chat interface
-6. Write a good `README.md` and publish links to this repo
-7. Make INT4 inference work
-8. Create pull request to main `ggml` repo with all improvements made here
+1. Heavily refactor code; optimize where possible
+2. Make FP16 inference work
+3. Create proper interface (probably, C library)
+4. Create Python wrapper with sampling and simple chat interface
+5. Write a good `README.md` and publish links to this repo
+6. Make INT4 inference work
+7. Create pull request to main `ggml` repo with all improvements made here
 
 ## Structure
 
compare_cpp_with_reference_implementation.py:
@@ -1,16 +1,17 @@
 # Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
-# Usage: python compare_cpp_with_reference_implementation.py C:\RWKV-4-Pile-169M-20220807-8023.pth bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
+# Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
+# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
+# Usage: python compare_cpp_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
 
+import os
 import argparse
 import subprocess
-import rwkv_model
 import torch
 import numpy as np
 from typing import List
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
-    parser.add_argument('torch_model_path', help='Path to PyTorch checkpoint file')
     parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file')
     parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
     return parser.parse_args()
@@ -18,21 +19,27 @@ def parse_args():
 def main() -> None:
     args = parse_args()
 
-    token_count: int = 64
-    # It's not important what exactly these tokens are; just that output of both model matches.
-    tokens: List[int] = [(i + 1) for i in range(token_count)]
+    # Don't want to depend on tokenizer here.
+    # Exact string is:
+    # context = "1 In the beginning God created the heaven and the earth. " \
+    #           "2 And the earth was without form, and void; and darkness was upon the face of the deep. And the Spirit of God moved upon the face of the waters. " \
+    #           "3 And God said, Let there be light: and there was light. " \
+    #           "4 And God saw the light, that it was good: and God divided the light from the darkness."
+    # The Bible was the first non-copyrighted public domain text that came to my mind.
+    tokens: List[int] = [18, 496, 253, 5068, 2656, 3562, 253, 13926, 285, 253, 6149, 15, 374, 1244, 253, 6149, 369, 1293, 830,
+                         13, 285, 2991, 28, 285, 13862, 369, 2220, 253, 2454, 273, 253, 3676, 15, 1244, 253, 14559, 273, 2656,
+                         4395, 2220, 253, 2454, 273, 253, 12685, 15, 495, 1244, 2656, 753, 13, 1281, 627, 320, 1708, 27, 285,
+                         627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272,
+                         253, 1708, 432, 253, 13862, 15]
+
+    token_count: int = len(tokens)
     state_path: str = './state.bin'
     logits_path: str = './logits.bin'
 
-    reference_model: rwkv_model.RWKV_RNN = rwkv_model.RWKV_RNN(args.torch_model_path)
-
-    ref_logits, ref_state = None, None
-
     for i in range(token_count):
         token: int = tokens[i]
 
-        print()
-        print(f'--- {i + 1}/{token_count} ---')
+        print(f'{i + 1}/{token_count}')
 
         subprocess.run(
             [
@@ -40,21 +47,27 @@ def main() -> None:
                 args.ggml_model_path,
                 str(token),
                 # If this is the first token, let the script create a new state.
-                '' if ref_state is None else state_path,
+                '' if i == 0 else state_path,
                 state_path,
                 logits_path
             ],
             check=True
         )
 
+    expected_logits_path: str = 'expected_logits_169M_20220807_8023.bin'
+
+    if not os.path.isfile(expected_logits_path):
+        expected_logits_path = 'rwkv/' + expected_logits_path
+
+    with open(expected_logits_path, 'rb') as logits_file:
+        expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
+
     with open(logits_path, 'rb') as logits_file:
         actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
 
-        ref_logits, ref_state = reference_model.forward(token, ref_state)
+    difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
 
-        difference: float = (torch.sum(ref_logits - actual_logits) / len(ref_logits)).item()
-
-        print(f'Reference logits: {ref_logits}')
+    print(f'Reference logits: {expected_logits}')
     print(f'Actual logits: {actual_logits}')
     print('Difference per token: %.8f' % (difference,))
 
											Binary file not shown.
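The binary file added alongside these changes is presumably `expected_logits_169M_20220807_8023.bin`, the file the comparison script above loads. As a rough, hypothetical sketch only (the helper name and paths below are made up, not part of the commit), such a file could be produced with the reference implementation this commit removes, in the same raw `float32` layout the script reads back via `np.frombuffer(..., dtype=np.single)`:

```python
# Hypothetical helper, not part of this commit: dump reference logits for a fixed
# token sequence into the raw float32 format that
# compare_cpp_with_reference_implementation.py reads back.
from typing import List, Optional

import numpy as np
import torch

import rwkv_model  # the reference implementation removed by this commit


def dump_expected_logits(torch_model_path: str, tokens: List[int], out_path: str) -> None:
    model = rwkv_model.RWKV_RNN(torch_model_path)

    logits: Optional[torch.Tensor] = None
    state: Optional[torch.Tensor] = None

    # Feed the whole sequence; only the logits after the last token are kept.
    for token in tokens:
        logits, state = model.forward(token, state)

    # Raw float32 values, one per vocabulary entry, native byte order.
    logits.numpy().astype(np.single).tofile(out_path)


# Example (paths and token list are illustrative):
# dump_expected_logits('RWKV-4-Pile-169M-20220807-8023.pth', tokens,
#                      'expected_logits_169M_20220807_8023.bin')
```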
										
									
								
rwkv_model.py (deleted):

@@ -1,239 +0,0 @@
# Reference implementation of RWKV in PyTorch.

# Original code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
# Original code by https://github.com/BlinkDL, licensed under Apache License 2.0

# Improvements made to the original code:
# - safetensors loading support
# - LoRA loading support
# - ln0 absortion support
# - general code style improvements

import time
import torch
import types
from typing import Union, Tuple, Dict, Optional
from torch.nn import functional as F

LORA_R: int = 4
LORA_ALPHA: int = 32

def load_state_dict(file_path: str, device: str) -> Dict[str, torch.Tensor]:
    print(f'Loading {file_path}')

    if file_path.endswith('.safetensors'):
        from safetensors import safe_open

        w = {}

        with safe_open(file_path, framework='pt', device=device) as state_dict:
            for key in state_dict.keys():
                w[key] = state_dict.get_tensor(key)

        return w
    else:
        return torch.load(file_path, map_location=device)

def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
    n_layer = 0

    while f'blocks.{n_layer}.ln1.weight' in state_dict:
        n_layer += 1

    assert n_layer > 0

    return n_layer

class RWKV_RNN(torch.jit.ScriptModule):

    def __init__(
            self,
            model_path: str,
            additional_model_path: Optional[str] = None,
            device: str = 'cpu',
            absorb_layer_norm_0: bool = False
    ):
        super().__init__()

        self.representation: torch.Tensor = torch.tensor([0], dtype=torch.float32, device=device)
        self.eval()

        print(f'Loading RWKV model from {model_path}')

        w = load_state_dict(model_path, device)

        if additional_model_path is not None:
            additional_w = load_state_dict(additional_model_path, device)

            for k in additional_w:
                if k != '_training_state':
                    w[k] = additional_w[k]

        print('Merging LoRA into weights')

        start = time.time()

        for k in list(w.keys()):
            module_k = k.replace('.weight', '')

            if module_k + '.lora_A.weight' in w:
                lora_A = w[module_k + '.lora_A.weight']
                lora_B = w[module_k + '.lora_B.weight']
                assert lora_B.shape[1] == lora_A.shape[0] == LORA_R
                w[module_k + '.weight'] = w[module_k + '.weight'] + lora_B @ lora_A * (LORA_ALPHA / LORA_R)
                del w[module_k + '.lora_A.weight']
                del w[module_k + '.lora_B.weight']
                del lora_A
                del lora_B

        print('Took %.3f sec' % ((time.time() - start),))

        for k in w.keys():
            if '.time_' in k:
                # (1, 1, n_embed) -> (n_embed)
                w[k] = w[k].squeeze()

            if '.time_decay' in k:
                # The real time decay is like e^{-e^x}
                w[k] = -torch.exp(w[k].float())
            elif w[k].dtype != torch.float32:
                w[k] = w[k].float()

        self.w = types.SimpleNamespace()
        self.w.blocks = {}

        # Example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
        for k in w.keys():
            parts = k.split('.')
            last = parts.pop()
            here = self.w

            for p in parts:
                if p.isdigit():
                    p = int(p)

                    if p not in here:
                        here[p] = types.SimpleNamespace()

                    here = here[p]
                else:
                    if not hasattr(here, p):
                        setattr(here, p, types.SimpleNamespace())

                    here = getattr(here, p)

            setattr(here, last, w[k])

        self.absorb_layer_norm_0 = absorb_layer_norm_0

        if absorb_layer_norm_0:
            print('Absorbing first LayerNorm into embedding matrix')

            start = time.time()

            for i in range(len(self.w.emb.weight)):
                self.w.emb.weight[i] = self.layer_norm(self.w.emb.weight[i], self.w.blocks[0].ln0)

            print('Took %.3f sec' % ((time.time() - start),))

        self.n_layer = get_layer_count(w)
        self.n_embed = self.w.emb.weight.shape[1]

    def layer_norm(self, x, w):
        return F.layer_norm(x, (self.n_embed,), weight=w.weight, bias=w.bias)

    @torch.jit.script_method
    def channel_mixing(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw):
        xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
        state[5 * i + 0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk))  # square relu, primer paper
        return r * (vw @ k)

    @torch.jit.script_method
    def time_mixing(self, x, state, i: int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
        xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
        state[5 * i + 1] = x
        r = torch.sigmoid(rw @ xr)
        k = kw @ xk
        v = vw @ xv

        aa = state[5 * i + 2]
        bb = state[5 * i + 3]
        pp = state[5 * i + 4]
        ww = time_first + k
        qq = torch.maximum(pp, ww)
        e1 = torch.exp(pp - qq)
        e2 = torch.exp(ww - qq)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        wkv = a / b
        ww = pp + time_decay
        qq = torch.maximum(ww, k)
        e1 = torch.exp(ww - qq)
        e2 = torch.exp(k - qq)
        state[5 * i + 2] = e1 * aa + e2 * v
        state[5 * i + 3] = e1 * bb + e2
        state[5 * i + 4] = qq
        return ow @ (r * wkv)

    def warm_up(self):
        print('Warming up the model')
        start = time.time()
        self.forward(0, None)
        print('Took %.3f sec' % ((time.time() - start),))

    def forward(self, token: int, state: Union[torch.Tensor, None], save_representation: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.no_grad():
            x: torch.Tensor = self.w.emb.weight[token]

            if state is None:
                state = torch.zeros(self.n_layer * 5, self.n_embed, device=x.device)

                for i in range(self.n_layer):
                    # ~Negative infinity
                    state[5 * i + 4] = -1e30

            if not self.absorb_layer_norm_0:
                x = self.layer_norm(x, self.w.blocks[0].ln0)

            for i in range(self.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(
                    self.layer_norm(x, self.w.blocks[i].ln1),
                    state,
                    i,
                    att.time_mix_k,
                    att.time_mix_v,
                    att.time_mix_r,
                    att.time_first,
                    att.time_decay,
                    att.key.weight,
                    att.value.weight,
                    att.receptance.weight,
                    att.output.weight
                )

                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(
                    self.layer_norm(x, self.w.blocks[i].ln2),
                    state,
                    i,
                    ffn.time_mix_k,
                    ffn.time_mix_r,
                    ffn.key.weight,
                    ffn.value.weight,
                    ffn.receptance.weight
                )

            x = self.layer_norm(x, self.w.ln_out)

            if save_representation:
                self.representation = x.clone()

            x = (self.w.head.weight @ x).float()

            return x, state
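For readers of the removed reference code: `time_mixing` above evaluates the RWKV-4 WKV recurrence in log-space so that large exponents never overflow. As a sketch under the usual RWKV formulation (with $u$ = `time_first` and $w$ = `time_decay`, which is negative after the `-torch.exp` transform applied at load time), the quantity it computes at step $t$ is:

$$
wkv_t \;=\; \frac{\sum_{i=1}^{t-1} e^{(t-1-i)\,w + k_i}\, v_i \;+\; e^{u + k_t}\, v_t}{\sum_{i=1}^{t-1} e^{(t-1-i)\,w + k_i} \;+\; e^{u + k_t}}
$$

The three per-layer state slots `aa`, `bb`, and `pp` hold the running numerator, denominator, and maximum exponent of this sum, so only differences of exponents are ever passed to `torch.exp`.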