Remove reference implementation code and test against pre-created logits

parent bf88e8a246
commit 0fcb7c64c6

README.md: 15 lines changed
README.md
@@ -6,14 +6,13 @@ This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml)
 
 ## Plan
 
-1. Remove reference implementation code from this repo
-2. Heavily refactor code; optimize where possible
-3. Make FP16 inference work
-4. Create proper interface (probably, C library)
-5. Create Python wrapper with sampling and simple chat interface
-6. Write a good `README.md` and publish links to this repo
-7. Make INT4 inference work
-8. Create pull request to main `ggml` repo with all improvements made here
+1. Heavily refactor code; optimize where possible
+2. Make FP16 inference work
+3. Create proper interface (probably, C library)
+4. Create Python wrapper with sampling and simple chat interface
+5. Write a good `README.md` and publish links to this repo
+6. Make INT4 inference work
+7. Create pull request to main `ggml` repo with all improvements made here
 
 ## Structure
 
compare_cpp_with_reference_implementation.py
@@ -1,16 +1,17 @@
 # Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
-# Usage: python compare_cpp_with_reference_implementation.py C:\RWKV-4-Pile-169M-20220807-8023.pth bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
+# Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
+# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
+# Usage: python compare_cpp_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
 
+import os
 import argparse
 import subprocess
-import rwkv_model
 import torch
 import numpy as np
 from typing import List
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
-    parser.add_argument('torch_model_path', help='Path to PyTorch checkpoint file')
     parser.add_argument('main_executable_path', help='Path to main rwkv.cpp executable file')
     parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
     return parser.parse_args()
@@ -18,21 +19,27 @@ def parse_args():
 def main() -> None:
     args = parse_args()
 
-    token_count: int = 64
-    # It's not important what exactly these tokens are; just that output of both model matches.
-    tokens: List[int] = [(i + 1) for i in range(token_count)]
+    # Don't want to depend on tokenizer here.
+    # Exact string is:
+    # context = "1 In the beginning God created the heaven and the earth. " \
+    #           "2 And the earth was without form, and void; and darkness was upon the face of the deep. And the Spirit of God moved upon the face of the waters. " \
+    #           "3 And God said, Let there be light: and there was light. " \
+    #           "4 And God saw the light, that it was good: and God divided the light from the darkness."
+    # The Bible was the first non-copyrighted public domain text that came to my mind.
+    tokens: List[int] = [18, 496, 253, 5068, 2656, 3562, 253, 13926, 285, 253, 6149, 15, 374, 1244, 253, 6149, 369, 1293, 830,
+                         13, 285, 2991, 28, 285, 13862, 369, 2220, 253, 2454, 273, 253, 3676, 15, 1244, 253, 14559, 273, 2656,
+                         4395, 2220, 253, 2454, 273, 253, 12685, 15, 495, 1244, 2656, 753, 13, 1281, 627, 320, 1708, 27, 285,
+                         627, 369, 1708, 15, 577, 1244, 2656, 3047, 253, 1708, 13, 326, 352, 369, 1175, 27, 285, 2656, 4272,
+                         253, 1708, 432, 253, 13862, 15]
+
+    token_count: int = len(tokens)
     state_path: str = './state.bin'
     logits_path: str = './logits.bin'
 
-    reference_model: rwkv_model.RWKV_RNN = rwkv_model.RWKV_RNN(args.torch_model_path)
-
-    ref_logits, ref_state = None, None
-
     for i in range(token_count):
         token: int = tokens[i]
 
-        print()
-        print(f'--- {i + 1}/{token_count} ---')
+        print(f'{i + 1}/{token_count}')
 
         subprocess.run(
             [
@@ -40,25 +47,31 @@ def main() -> None:
                 args.ggml_model_path,
                 str(token),
                 # If this is the first token, let the script create a new state.
-                '' if ref_state is None else state_path,
+                '' if i == 0 else state_path,
                 state_path,
                 logits_path
             ],
             check=True
         )
 
-        with open(logits_path, 'rb') as logits_file:
-            actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
-
-        ref_logits, ref_state = reference_model.forward(token, ref_state)
-
-        difference: float = (torch.sum(ref_logits - actual_logits) / len(ref_logits)).item()
-
-        print(f'Reference logits: {ref_logits}')
-        print(f'Actual logits: {actual_logits}')
-        print('Difference per token: %.8f' % (difference,))
-
-        assert abs(difference) <= 0.00005, 'Difference is too big'
+    expected_logits_path: str = 'expected_logits_169M_20220807_8023.bin'
+
+    if not os.path.isfile(expected_logits_path):
+        expected_logits_path = 'rwkv/' + expected_logits_path
+
+    with open(expected_logits_path, 'rb') as logits_file:
+        expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
+
+    with open(logits_path, 'rb') as logits_file:
+        actual_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single))
+
+    difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item()
+
+    print(f'Reference logits: {expected_logits}')
+    print(f'Actual logits: {actual_logits}')
+    print('Difference per token: %.8f' % (difference,))
+
+    assert abs(difference) <= 0.00005, 'Difference is too big'
 
     print()
     print('Test passes')
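Note: hard-coding the token IDs above keeps the test independent of any tokenizer. For reference, a sketch of how such IDs could be regenerated; it assumes the GPT-NeoX-style `20B_tokenizer.json` used by RWKV-4 Pile models and the Hugging Face `tokenizers` package, neither of which ships with this commit:

```python
# Hypothetical helper, not part of this commit: re-derive the hardcoded token IDs.
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file('20B_tokenizer.json')  # assumed tokenizer file

context = '1 In the beginning God created the heaven and the earth. ' \
          '2 And the earth was without form, and void; and darkness was upon the face of the deep. ' \
          'And the Spirit of God moved upon the face of the waters. ' \
          '3 And God said, Let there be light: and there was light. ' \
          '4 And God saw the light, that it was good: and God divided the light from the darkness.'

print(tokenizer.encode(context).ids)  # should match the list hardcoded above
```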
										
expected_logits_169M_20220807_8023.bin
Binary file not shown.
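This binary holds the pre-created reference logits as a flat array of little-endian float32 values, exactly what `np.frombuffer(..., dtype=np.single)` reads back in the script above. A minimal sketch of a matching writer/reader pair (the helper names are hypothetical, not part of the commit):

```python
# Hypothetical helpers showing the on-disk format of the expected-logits file:
# a flat array of float32 values, one per vocabulary entry.
import numpy as np
import torch

def save_logits(logits: torch.Tensor, path: str) -> None:
    # Write raw float32 bytes; np.frombuffer(..., dtype=np.single) reads this back.
    logits.to(torch.float32).numpy().tofile(path)

def load_logits(path: str) -> torch.Tensor:
    with open(path, 'rb') as f:
        return torch.tensor(np.frombuffer(f.read(), dtype=np.single))
```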
										
									
								
rwkv_model.py (deleted)
@@ -1,239 +0,0 @@
# Reference implementation of RWKV in PyTorch.

# Original code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
# Original code by https://github.com/BlinkDL, licensed under Apache License 2.0

# Improvements made to the original code:
# - safetensors loading support
# - LoRA loading support
# - ln0 absorption support
# - general code style improvements

import time
import torch
import types
from typing import Union, Tuple, Dict, Optional
from torch.nn import functional as F

LORA_R: int = 4
LORA_ALPHA: int = 32

def load_state_dict(file_path: str, device: str) -> Dict[str, torch.Tensor]:
    print(f'Loading {file_path}')

    if file_path.endswith('.safetensors'):
        from safetensors import safe_open

        w = {}

        with safe_open(file_path, framework='pt', device=device) as state_dict:
            for key in state_dict.keys():
                w[key] = state_dict.get_tensor(key)

        return w
    else:
        return torch.load(file_path, map_location=device)

def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
    n_layer = 0

    while f'blocks.{n_layer}.ln1.weight' in state_dict:
        n_layer += 1

    assert n_layer > 0

    return n_layer

class RWKV_RNN(torch.jit.ScriptModule):

    def __init__(
            self,
            model_path: str,
            additional_model_path: Optional[str] = None,
            device: str = 'cpu',
            absorb_layer_norm_0: bool = False
    ):
        super().__init__()

        self.representation: torch.Tensor = torch.tensor([0], dtype=torch.float32, device=device)
        self.eval()

        print(f'Loading RWKV model from {model_path}')

        w = load_state_dict(model_path, device)

        if additional_model_path is not None:
            additional_w = load_state_dict(additional_model_path, device)

            for k in additional_w:
                if k != '_training_state':
                    w[k] = additional_w[k]

        print('Merging LoRA into weights')

        start = time.time()

        for k in list(w.keys()):
            module_k = k.replace('.weight', '')

            if module_k + '.lora_A.weight' in w:
                lora_A = w[module_k + '.lora_A.weight']
                lora_B = w[module_k + '.lora_B.weight']
                assert lora_B.shape[1] == lora_A.shape[0] == LORA_R
                w[module_k + '.weight'] = w[module_k + '.weight'] + lora_B @ lora_A * (LORA_ALPHA / LORA_R)
                del w[module_k + '.lora_A.weight']
                del w[module_k + '.lora_B.weight']
                del lora_A
                del lora_B

        print('Took %.3f sec' % ((time.time() - start),))
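        # [Editor's note, not in the original file] The loop above folds each
        # LoRA adapter pair into its base weight:
        #     W' = W + (LORA_ALPHA / LORA_R) * (lora_B @ lora_A)
        # with lora_A of shape (LORA_R, d_in) and lora_B of shape (d_out, LORA_R).
        # Once merged, the adapter adds no inference cost, which is why the
        # lora_A / lora_B tensors are deleted from the state dict right after.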
| 
        for k in w.keys():
            if '.time_' in k:
                # (1, 1, n_embed) -> (n_embed)
                w[k] = w[k].squeeze()

            if '.time_decay' in k:
                # The real time decay is like e^{-e^x}
                w[k] = -torch.exp(w[k].float())
            elif w[k].dtype != torch.float32:
                w[k] = w[k].float()

        self.w = types.SimpleNamespace()
        self.w.blocks = {}

        # Example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
        for k in w.keys():
            parts = k.split('.')
            last = parts.pop()
            here = self.w

            for p in parts:
                if p.isdigit():
                    p = int(p)

                    if p not in here:
                        here[p] = types.SimpleNamespace()

                    here = here[p]
                else:
                    if not hasattr(here, p):
                        setattr(here, p, types.SimpleNamespace())

                    here = getattr(here, p)

            setattr(here, last, w[k])

        self.absorb_layer_norm_0 = absorb_layer_norm_0

        if absorb_layer_norm_0:
            print('Absorbing first LayerNorm into embedding matrix')

            start = time.time()

            for i in range(len(self.w.emb.weight)):
                self.w.emb.weight[i] = self.layer_norm(self.w.emb.weight[i], self.w.blocks[0].ln0)

            print('Took %.3f sec' % ((time.time() - start),))

        self.n_layer = get_layer_count(w)
        self.n_embed = self.w.emb.weight.shape[1]

    def layer_norm(self, x, w):
        return F.layer_norm(x, (self.n_embed,), weight=w.weight, bias=w.bias)

    @torch.jit.script_method
    def channel_mixing(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw):
        xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
        state[5 * i + 0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk))  # square relu, primer paper
        return r * (vw @ k)
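    # [Editor's note, not in the original file] Channel mixing interpolates the
    # current input with the previous one ("token shift", state slot 5*i + 0),
    # then applies a squared-ReLU feed-forward gated by a sigmoid receptance:
    #     out = sigmoid(rw @ xr) * (vw @ relu(kw @ xk) ** 2)
    # The squared ReLU follows the Primer paper referenced in the comment above.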
| 
    @torch.jit.script_method
    def time_mixing(self, x, state, i: int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
        xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
        state[5 * i + 1] = x
        r = torch.sigmoid(rw @ xr)
        k = kw @ xk
        v = vw @ xv

        aa = state[5 * i + 2]
        bb = state[5 * i + 3]
        pp = state[5 * i + 4]
        ww = time_first + k
        qq = torch.maximum(pp, ww)
        e1 = torch.exp(pp - qq)
        e2 = torch.exp(ww - qq)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        wkv = a / b
        ww = pp + time_decay
        qq = torch.maximum(ww, k)
        e1 = torch.exp(ww - qq)
        e2 = torch.exp(k - qq)
        state[5 * i + 2] = e1 * aa + e2 * v
        state[5 * i + 3] = e1 * bb + e2
        state[5 * i + 4] = qq
        return ow @ (r * wkv)
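    # [Editor's note, not in the original file] aa, bb and pp (state slots
    # 5*i + 2 .. 5*i + 4) hold the WKV numerator, denominator and a running
    # maximum exponent; subtracting qq = max(...) before each torch.exp is the
    # log-sum-exp trick that keeps the recurrence inside float32 range (which
    # is also why pp is initialized to -1e30 for a fresh state in forward()).
    # Ignoring that stabilization, with u = time_first and per-channel decay
    # w = time_decay (stored as w = -e^x at load time, so the decay factor
    # e^w = e^(-e^x) lies in (0, 1)), the recurrence is:
    #     wkv_t = (a_{t-1} + e^(u + k_t) * v_t) / (b_{t-1} + e^(u + k_t))
    #     a_t   = e^w * a_{t-1} + e^(k_t) * v_t
    #     b_t   = e^w * b_{t-1} + e^(k_t)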
| 
    def warm_up(self):
        print('Warming up the model')
        start = time.time()
        self.forward(0, None)
        print('Took %.3f sec' % ((time.time() - start),))

    def forward(self, token: int, state: Union[torch.Tensor, None], save_representation: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.no_grad():
            x: torch.Tensor = self.w.emb.weight[token]

            if state is None:
                state = torch.zeros(self.n_layer * 5, self.n_embed, device=x.device)

                for i in range(self.n_layer):
                    # ~Negative infinity
                    state[5 * i + 4] = -1e30

            if not self.absorb_layer_norm_0:
                x = self.layer_norm(x, self.w.blocks[0].ln0)

            for i in range(self.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(
                    self.layer_norm(x, self.w.blocks[i].ln1),
                    state,
                    i,
                    att.time_mix_k,
                    att.time_mix_v,
                    att.time_mix_r,
                    att.time_first,
                    att.time_decay,
                    att.key.weight,
                    att.value.weight,
                    att.receptance.weight,
                    att.output.weight
                )

                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(
                    self.layer_norm(x, self.w.blocks[i].ln2),
                    state,
                    i,
                    ffn.time_mix_k,
                    ffn.time_mix_r,
                    ffn.key.weight,
                    ffn.value.weight,
                    ffn.receptance.weight
                )

            x = self.layer_norm(x, self.w.ln_out)

            if save_representation:
                self.representation = x.clone()

            x = (self.w.head.weight @ x).float()

            return x, state
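For context, the removed class was driven one token at a time: forward() takes a token ID plus the previous state (None allocates a fresh zeroed state) and returns the logits together with the updated state, which the caller threads back in, exactly as the old comparison script did. A minimal usage sketch; the checkpoint filename is an assumption:

```python
# Illustrative usage of the removed reference implementation; the checkpoint
# path is an assumption, and rwkv_model is no longer part of the repo.
import rwkv_model

model = rwkv_model.RWKV_RNN('RWKV-4-Pile-169M-20220807-8023.pth')

state = None  # None makes forward() allocate a zeroed (n_layer * 5, n_embed) state

for token in [18, 496, 253]:  # first tokens of the Genesis excerpt above
    logits, state = model.forward(token, state)

print(logits.shape)  # one logit per vocabulary entry for the next token
```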