From 678f5233a58c546b9b8869a7a8a1b0423c32f217 Mon Sep 17 00:00:00 2001
From: saharNooby
Date: Sat, 15 Apr 2023 20:46:30 +0400
Subject: [PATCH] Add LoRA loading support

---
 README.md                    |   2 +
 rwkv/merge_lora_into_ggml.py | 152 +++++++++++++++++++++++++++++++++++
 2 files changed, 154 insertions(+)
 create mode 100644 rwkv/merge_lora_into_ggml.py

diff --git a/README.md b/README.md
index 20bf774..d13bb59 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,8 @@ RWKV is a novel large language model architecture, [with the largest model in th
 
 This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrapper](rwkv%2Frwkv_cpp_model.py) for it.
 
+Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through the [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py).
+
 **TODO (contributions welcome!)**:
 
 1. Optimize AVX2 implementation of `Q4_1_O` matmul — currently, it is as slow as `FP32`

diff --git a/rwkv/merge_lora_into_ggml.py b/rwkv/merge_lora_into_ggml.py
new file mode 100644
index 0000000..e5c7d3a
--- /dev/null
+++ b/rwkv/merge_lora_into_ggml.py
@@ -0,0 +1,152 @@
+# Merges a LoRA checkpoint in PyTorch format (.pth) into an rwkv.cpp model file.
+# Usage: python merge_lora_into_ggml.py C:\rwkv.cpp-169M.bin C:\my-lora.pth 32 C:\rwkv.cpp-169M-with-my-lora.bin
+# LoRA format is compatible with https://github.com/Blealtan/RWKV-LM-LoRA
+# You need to know the lora_alpha value used during training to perform the merge.
+# The source model must be in either FP16 or FP32 format. Quantization can be performed after merging.
+
+import argparse
+import struct
+import torch
+import numpy as np
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Merge a PyTorch LoRA checkpoint (.pth) into an rwkv.cpp model file')
+    parser.add_argument('src_path', help='Path to source rwkv.cpp model')
+    parser.add_argument('lora_path', help='Path to LoRA checkpoint in PyTorch format')
+    parser.add_argument('lora_alpha', type=int, help='Value of lora_alpha parameter used when training this LoRA checkpoint')
+    parser.add_argument('dest_path', help='Path to destination rwkv.cpp model, will be overwritten with the merged model')
+    return parser.parse_args()
+
+def write_parameter(out_file, key: str, parameter: torch.Tensor) -> None:
+    assert parameter.dtype == torch.float32 or parameter.dtype == torch.float16
+
+    key_encoded: bytes = key.encode('utf-8')
+
+    out_file.write(struct.pack(
+        '=iii',
+        len(parameter.shape),
+        len(key_encoded),
+        1 if parameter.dtype == torch.float16 else 0
+    ))
+
+    # Dimension order is reversed here:
+    # * PyTorch shape is (x rows, y columns)
+    # * ggml shape is (y elements in a row, x elements in a column)
+    # Both shapes represent the same tensor.
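+    # For example, a hypothetical (1024, 768) PyTorch matrix,
+    # i.e. 1024 rows by 768 columns, is stored in the ggml file
+    # with dims written in the order [768, 1024].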
+    for dim in reversed(parameter.shape):
+        out_file.write(struct.pack('=i', dim))
+
+    out_file.write(key_encoded)
+
+    parameter.numpy().tofile(out_file)
+
+def main() -> None:
+    args = parse_args()
+
+    print(f'Reading {args.lora_path}')
+
+    lora_state_dict: dict[str, torch.Tensor] = torch.load(args.lora_path, map_location='cpu')
+
+    print('Merging')
+
+    with open(args.src_path, 'rb') as in_file, open(args.dest_path, 'wb') as out_file:
+        # noinspection PyTypeChecker
+        header: tuple[int, int, int, int, int, int] = struct.unpack('=iiiiii', in_file.read(6 * 4))
+
+        assert header[0] == 0x67676d66, 'Invalid magic value'
+        assert header[1] == 100, 'Invalid version number'
+        assert header[5] == 0 or header[5] == 1, 'Only FP32 and FP16 models are supported'
+
+        out_file.write(struct.pack('=iiiiii', *header))
+
+        while True:
+            parameter_header_bytes: bytes = in_file.read(3 * 4)
+
+            if len(parameter_header_bytes) == 0:
+                break
+
+            dim_count, key_length, data_type = struct.unpack('=iii', parameter_header_bytes)
+
+            # noinspection PyTypeChecker
+            shape: tuple[int, ...] = struct.unpack('=' + 'i' * dim_count, in_file.read(dim_count * 4))
+            # ggml dimension order to PyTorch order
+            shape: list[int] = list(reversed(shape))
+
+            key: str = in_file.read(key_length).decode('utf-8')
+
+            print(f'* {key} {shape}')
+
+            assert data_type == 0 or data_type == 1, 'Only FP32 and FP16 models are supported'
+
+            element_count: int = 1
+
+            for dim in shape:
+                element_count *= dim
+
+            parameter_np: np.ndarray = np.frombuffer(
+                in_file.read((2 if data_type == 1 else 4) * element_count),
+                dtype=(np.half if data_type == 1 else np.single)
+            )
+
+            parameter: torch.Tensor = torch.tensor(parameter_np).view(shape)
+
+            if key in lora_state_dict:
+                replacement: torch.Tensor = lora_state_dict[key].float()
+
+                # Same processing as in convert_pytorch_to_ggml.py
+                if '.time_' in key:
+                    # (1, 1, n_embed) -> (n_embed)
+                    replacement = replacement.squeeze()
+
+                if '.time_decay' in key:
+                    replacement = -torch.exp(replacement)
+
+                if parameter.dtype == torch.float16:
+                    replacement = replacement.half()
+
+                assert replacement.shape == parameter.shape, f'Parameter {key} has shape {parameter.shape} in model file ' \
+                                                             f'and shape {replacement.shape} in LoRA file'
+
+                parameter = replacement
+
+                print(f'Replaced parameter {key}')
+
+                del lora_state_dict[key]
+
+            lora_A_key: str = key.replace('.weight', '') + '.lora_A.weight'
+            lora_B_key: str = key.replace('.weight', '') + '.lora_B.weight'
+
+            if lora_A_key in lora_state_dict:
+                # Convert to FP32: FP16 matrix multiplication is not supported on CPU
+                lora_A: torch.Tensor = lora_state_dict[lora_A_key].float()
+                lora_B: torch.Tensor = lora_state_dict[lora_B_key].float()
+
+                assert lora_B.shape[1] == lora_A.shape[0], f'Invalid shape of LoRA matrices for {key}: ' \
+                                                           f'{lora_A.shape}, {lora_B.shape}'
+
+                lora_R: int = lora_B.shape[1]
+
+                replacement: torch.Tensor = parameter + lora_B @ lora_A * (args.lora_alpha / lora_R)
+
+                if parameter.dtype == torch.float16:
+                    replacement = replacement.half()
+
+                parameter = replacement
+
+                print(f'Merged LoRA into parameter {key}, lora_r = {lora_R}')
+
+                del lora_state_dict[lora_A_key]
+                del lora_state_dict[lora_B_key]
+
+            write_parameter(out_file, key, parameter)
+
+    for key in lora_state_dict:
+        print(f'WARNING: Unused parameter in LoRA state dict: {key}')
+
+    print('Done')
+
+if __name__ == '__main__':
+    main()
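
For reference, the merge applied to each matched parameter above is the standard LoRA update W' = W + (lora_alpha / lora_r) * B @ A. Below is a minimal standalone PyTorch sketch of that update; the shapes and values are hypothetical, chosen only to illustrate the arithmetic:

    import torch

    # Hypothetical example: a 768x768 weight with a rank-8 LoRA pair.
    lora_alpha = 32

    W = torch.randn(768, 768)            # base weight from the model file
    lora_A = torch.randn(8, 768) * 0.01  # (lora_r, in_features)
    lora_B = torch.zeros(768, 8)         # (out_features, lora_r)

    lora_r = lora_B.shape[1]

    # The same update the script applies before writing the parameter back out
    W_merged = W + lora_B @ lora_A * (lora_alpha / lora_r)

    assert W_merged.shape == W.shape
    # lora_B is initialized to zero at the start of training, so an
    # untrained LoRA pair leaves the weight unchanged
    assert torch.equal(W_merged, W)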