From 972e28d48dcabcb7a55393f67cda0d80c53f9214 Mon Sep 17 00:00:00 2001
From: saharNooby
Date: Sat, 1 Apr 2023 19:22:01 +0400
Subject: [PATCH] Implement INT4 conversion and inference

---
 rwkv.cpp                                      | 266 +++++++++++++++++-
 rwkv.h                                        |   9 +
 ... compare_with_reference_implementation.py} |  14 +-
 ..._to_ggml.py => convert_pytorch_to_ggml.py} |   2 +-
 rwkv/quantize.py                              |  34 +++
 5 files changed, 309 insertions(+), 16 deletions(-)
 rename rwkv/{compare_cpp_with_reference_implementation.py => compare_with_reference_implementation.py} (89%)
 rename rwkv/{convert_pytorch_rwkv_to_ggml.py => convert_pytorch_to_ggml.py} (97%)
 create mode 100644 rwkv/quantize.py
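The patch extends the model file's data_type codes from {0 = FP32, 1 = FP16} with 2 = GGML_TYPE_Q4_0 and 3 = GGML_TYPE_Q4_1. As a back-of-the-envelope guide to the sizes in the new memory table below, here is a minimal sketch of per-tensor storage cost. The helper name is hypothetical, and the 20- and 24-byte figures assume ggml's 32-element Q4 block layout as of this patch (Q4_0: one FP32 scale plus 16 bytes of packed 4-bit quants; Q4_1: an FP32 scale and min plus the same 16 bytes):

#include <cstddef>
#include <cstdint>

// Bytes needed to store n_elements values in each file data type.
// Assumes n_elements is a multiple of the 32-element Q4 block size.
static size_t tensor_bytes(int64_t n_elements, int32_t data_type) {
    switch (data_type) {
        case 0: return (size_t) n_elements * 4;         // FP32
        case 1: return (size_t) n_elements * 2;         // FP16
        case 2: return (size_t) (n_elements / 32) * 20; // Q4_0 block: scale + 32 quants
        case 3: return (size_t) (n_elements / 32) * 24; // Q4_1 block: scale + min + 32 quants
        default: return 0;
    }
}

For the 169M model's 50277 x 768 embedding matrix this gives roughly 147 MB at FP32, 74 MB at FP16, 23 MB at Q4_0 and 28 MB at Q4_1, consistent with the roughly 3x FP16-to-Q4_0 ratio in the measured 169M row below.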
diff --git a/rwkv.cpp b/rwkv.cpp
index 9276a95..7110057 100644
--- a/rwkv.cpp
+++ b/rwkv.cpp
@@ -39,7 +39,6 @@
 // Reads single int32 value from a file.
 bool read_int32(FILE * file, int32_t * dest) {
-    // TODO Will not read correct values on machine with different endianness
     RWKV_ASSERT_FALSE(fread(dest, 4, 1, file) == 1, "Failed to read an int32 value from a file");
     return true;
 }
@@ -111,13 +110,24 @@ bool set_block_parameter(std::unordered_map<std::string, struct ggml_tensor *> *
 size_t get_memory_required_mb(int32_t n_vocab, int32_t n_layer, int32_t n_embed, int32_t data_type) {
     if (n_vocab == 50277) {
-        // 169M to 1.5B are exact, others are extrapolated (slightly bigger than needed).
-        /* 169M */ if (n_layer == 12 && n_embed == 768)  return size_t(data_type == 0 ? 650 : 327);
-        /* 430M */ if (n_layer == 24 && n_embed == 1024) return size_t(data_type == 0 ? 1650 : 830);
-        /* 1.5B */ if (n_layer == 24 && n_embed == 2048) return size_t(data_type == 0 ? 5795 : 2907);
-        /* 3B */   if (n_layer == 32 && n_embed == 2560) return size_t(data_type == 0 ? 11590 : 5720); // TODO Measure exactly (FP32 only)
-        /* 7B */   if (n_layer == 32 && n_embed == 4096) return size_t(data_type == 0 ? 27043 : 13566); // TODO Measure exactly
-        /* 14B */  if (n_layer == 40 && n_embed == 5120) return size_t(data_type == 0 ? 54086 : 27132); // TODO Measure exactly
+        // Non-exact values are extrapolated (slightly bigger than needed).
+        // TODO Measure values exactly
+        static const size_t memory_required_mb[6][4] = {
+            /*            FP32    FP16   Q4_0   Q4_1
+             169M */ {     650,    327,   105,   165}, // All measured exactly
+            /* 430M */ {  1650,    830,   263,   415}, // FP32, FP16 are exact
+            /* 1.5B */ {  5795,   2907,   923,  1454}, // FP32, FP16 are exact
+            /*   3B */ { 11610,   5720,  1816,  2860}, // FP16 is exact
+            /*   7B */ { 27090,  13634,  4328,  6817},
+            /*  14B */ { 54180,  27267,  8656, 13634}
+        };
+
+        /* 169M */ if (n_layer == 12 && n_embed == 768)  return memory_required_mb[0][data_type];
+        /* 430M */ if (n_layer == 24 && n_embed == 1024) return memory_required_mb[1][data_type];
+        /* 1.5B */ if (n_layer == 24 && n_embed == 2048) return memory_required_mb[2][data_type];
+        /*   3B */ if (n_layer == 32 && n_embed == 2560) return memory_required_mb[3][data_type];
+        /*   7B */ if (n_layer == 32 && n_embed == 4096) return memory_required_mb[4][data_type];
+        /*  14B */ if (n_layer == 40 && n_embed == 5120) return memory_required_mb[5][data_type];
     }

     fprintf(
@@ -177,7 +187,14 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
     RWKV_ASSERT_NULL(model->n_layer > 0, "Non-positive n_layer %d", model->n_layer);

     read_int32(file, &(model->data_type));
-    RWKV_ASSERT_NULL(model->data_type == 0 || model->data_type == 1, "Unsupported model data type %d", model->data_type);
+    RWKV_ASSERT_NULL(
+        model->data_type == 0 ||
+        model->data_type == 1 ||
+        model->data_type == 2 ||
+        model->data_type == 3,
+        "Unsupported model data type %d",
+        model->data_type
+    );

     // Initialize ggml
     struct ggml_init_params params;
@@ -203,9 +220,24 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
         int32_t data_type;
         read_int32(file, &data_type);
-        RWKV_ASSERT_NULL(data_type == 0 || data_type == 1, "Unsupported parameter data type %d", data_type);
+        RWKV_ASSERT_NULL(
+            data_type == 0 ||
+            data_type == 1 ||
+            data_type == 2 ||
+            data_type == 3,
+            "Unsupported parameter data type %d",
+            data_type
+        );

-        ggml_type ggml_data_type = data_type == 0 ? GGML_TYPE_F32 : GGML_TYPE_F16;
+        ggml_type ggml_data_type;
+
+        switch (data_type) {
+            case 0: ggml_data_type = GGML_TYPE_F32; break;
+            case 1: ggml_data_type = GGML_TYPE_F16; break;
+            case 2: ggml_data_type = GGML_TYPE_Q4_0; break;
+            case 3: ggml_data_type = GGML_TYPE_Q4_1; break;
+            default: return NULL;
+        }

         struct ggml_tensor * tensor;
@@ -229,8 +261,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
         std::string key(key_length, 0);
         RWKV_ASSERT_NULL(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key");

-        size_t byte_count = element_count * ggml_type_size(ggml_data_type);
-        RWKV_ASSERT_NULL(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data");
+        RWKV_ASSERT_NULL(fread(tensor->data, 1, ggml_nbytes(tensor), file) == ggml_nbytes(tensor), "Failed to read parameter data");

         parameters[key] = tensor;
     }
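The loader changes above and the quantizer added below walk the same sequence of parameter records. For orientation, here is a sketch of one record's layout as implied by the reads and writes in this patch. The struct is illustrative only (rwkv.cpp reads the fields one by one, and, per the removed TODO in read_int32, with host endianness):

#include <cstdint>

// Illustrative layout of one parameter record in an rwkv.cpp model file.
struct parameter_record_header {
    int32_t n_dims;     // 1 or 2; only 2D tensors get quantized
    int32_t key_length; // length in bytes of the parameter name
    int32_t data_type;  // 0 = FP32, 1 = FP16, 2 = Q4_0, 3 = Q4_1
    // Followed by: int32_t ne[n_dims] (dimension sizes),
    // char key[key_length] (parameter name, not NUL-terminated),
    // then ggml_nbytes(tensor) bytes of tensor data.
};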
@@ -533,6 +564,215 @@ void rwkv_free(struct rwkv_context * ctx) {
     delete ctx;
 }

+bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type) {
+    RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type);
+
+    ggml_type type;
+
+    switch (q_type) {
+        case 2: type = GGML_TYPE_Q4_0; break;
+        case 3: type = GGML_TYPE_Q4_1; break;
+        default: return false;
+    }
+
+    RWKV_ASSERT_FALSE(type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1, "Unsupported data type %d", type);
+
+    printf("Loading model from '%s'\n", model_file_path_in);
+
+    auto finp = std::ifstream(model_file_path_in, std::ios::binary);
+    RWKV_ASSERT_FALSE(finp, "Failed to open %s for reading", model_file_path_in);
+
+    auto fout = std::ofstream(model_file_path_out, std::ios::binary);
+    RWKV_ASSERT_FALSE(fout, "Failed to open %s for writing", model_file_path_out);
+
+    // Process header
+    {
+        uint32_t magic;
+        finp.read((char *) &magic, sizeof(magic));
+        RWKV_ASSERT_FALSE(magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic);
+        fout.write((char *) &magic, sizeof(magic));
+
+        uint32_t format_version;
+        finp.read((char *) &format_version, sizeof(format_version));
+        RWKV_ASSERT_FALSE(format_version == RWKV_FILE_VERSION, "Unsupported file version %d", format_version);
+        fout.write((char *) &format_version, sizeof(format_version));
+
+        int32_t n_vocab;
+        int32_t n_embed;
+        int32_t n_layer;
+        int32_t data_type;
+
+        finp.read((char *) &n_vocab, sizeof(n_vocab));
+        finp.read((char *) &n_embed, sizeof(n_embed));
+        finp.read((char *) &n_layer, sizeof(n_layer));
+        finp.read((char *) &data_type, sizeof(data_type));
+
+        RWKV_ASSERT_FALSE(data_type == 0 || data_type == 1, "Unsupported data type %d, only FP32 and FP16 can be quantized", data_type);
+
+        data_type = q_type;
+
+        fout.write((char *) &n_vocab, sizeof(n_vocab));
+        fout.write((char *) &n_embed, sizeof(n_embed));
+        fout.write((char *) &n_layer, sizeof(n_layer));
+        fout.write((char *) &data_type, sizeof(data_type));
+    }
+
+    // Process parameters
+    {
+        size_t total_size_orig = 0;
+        size_t total_size_new = 0;
+
+        std::vector<float> work;
+
+        std::vector<uint8_t> data_u8;
+        std::vector<ggml_fp16_t> data_f16;
+        std::vector<float> data_f32;
+
+        std::vector<int64_t> hist_all(1 << 4, 0);
+
+        while (true) {
+            int32_t n_dims;
+            int32_t key_length;
+            int32_t parameter_data_type;
+
+            finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            finp.read(reinterpret_cast<char *>(&key_length), sizeof(key_length));
+            finp.read(reinterpret_cast<char *>(&parameter_data_type), sizeof(parameter_data_type));
+
+            if (finp.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                finp.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(key_length, 0);
+            finp.read(&name[0], key_length);
+
+            {
+                static const char * parameter_data_type_str[] = {
+                    "f32",
+                    "f16",
+                    "q4_0",
+                    "q4_1"
+                };
+                printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
+            }
+
+            // Quantize only 2D tensors
+            bool quantize = n_dims == 2;
+
+            if (quantize) {
+                if (parameter_data_type != 0 && parameter_data_type != 1) {
+                    fprintf(stderr, "unsupported data type %d for integer quantization\n", parameter_data_type);
+                    return false;
+                }
+
+                if (parameter_data_type == 1) {
+                    data_f16.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
+                    data_f32.resize(nelements);
+                    for (int i = 0; i < nelements; ++i) {
+                        data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
+                    }
+                } else {
+                    data_f32.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
+                }
+
+                parameter_data_type = q_type;
+            } else {
+                const int bytes_per_element = (parameter_data_type == 0) ? sizeof(float) : sizeof(uint16_t);
+                data_u8.resize(nelements * bytes_per_element);
+                finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bytes_per_element);
+            }
+
+            fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fout.write(reinterpret_cast<char *>(&key_length), sizeof(key_length));
+            fout.write(reinterpret_cast<char *>(&parameter_data_type), sizeof(parameter_data_type));
+
+            for (int i = 0; i < n_dims; ++i) {
+                fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+            }
+
+            fout.write(&name[0], key_length);
+
+            if (quantize) {
+                printf("quantizing... ");
+                work.resize(nelements); // for quantization
+
+                size_t cur_size = 0;
+                std::vector<int64_t> hist_cur(1 << 4, 0);
+
+                switch (type) {
+                    case GGML_TYPE_Q4_0:
+                        {
+                            cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                        } break;
+                    case GGML_TYPE_Q4_1:
+                        {
+                            cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                        } break;
+                    default:
+                        {
+                            fprintf(stderr, "unsupported quantization type %d\n", type);
+                            return false;
+                        }
+                }
+
+                fout.write(reinterpret_cast<char *>(work.data()), cur_size);
+                total_size_new += cur_size;
+
+                printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float) / 1024.0 / 1024.0, cur_size / 1024.0 / 1024.0);
+
+                for (int i = 0; i < (int) hist_cur.size(); ++i) {
+                    hist_all[i] += hist_cur[i];
+                }
+
+                for (int i = 0; i < (int) hist_cur.size(); ++i) {
+                    printf("%5.3f ", hist_cur[i] / float(nelements));
+                }
+
+                printf("\n");
+            } else {
+                printf("size = %8.3f MB\n", data_u8.size() / 1024.0 / 1024.0);
+                fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
+                total_size_new += data_u8.size();
+            }
+
+            total_size_orig += nelements * sizeof(float);
+        }
+
+        printf("model size = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0);
+        printf("quant size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
+
+        {
+            int64_t sum_all = 0;
+
+            for (int i = 0; i < (int) hist_all.size(); ++i) {
+                sum_all += hist_all[i];
+            }
+
+            printf("hist: ");
+
+            for (int i = 0; i < (int) hist_all.size(); ++i) {
+                printf("%5.3f ", hist_all[i] / float(sum_all));
+            }
+
+            printf("\n");
+        }
+    }
+
+    finp.close();
+    fout.close();
+
+    return true;
+}
+
 const char * rwkv_get_system_info_string(void) {
     static std::string s;
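The 4-bit encodings themselves live in ggml; the function above only dispatches to ggml_quantize_q4_0 and ggml_quantize_q4_1 and forwards their histograms. To illustrate why Q4_1 can be held to a tighter comparison threshold later in this patch, here is a sketch of the Q4_1 scheme as ggml implemented it around this time; this is an assumption for illustration, see ggml.c for the authoritative code:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Encode one 32-element block, Q4_1 style: x ~= d * q + m with q in [0, 15].
// Real ggml packs two 4-bit indices per byte; this sketch keeps one per byte.
void q4_1_encode_block(const float x[32], float * d, float * m, uint8_t q[32]) {
    const float lo = *std::min_element(x, x + 32);
    const float hi = *std::max_element(x, x + 32);

    *m = lo;
    *d = (hi - lo) / 15.0f; // 15 = largest 4-bit value

    const float id = *d != 0.0f ? 1.0f / *d : 0.0f;

    for (int i = 0; i < 32; ++i) {
        q[i] = (uint8_t) std::min(15.0f, std::roundf((x[i] - lo) * id));
    }
}

Q4_0 stores only a single scale per block and reconstructs values symmetrically around zero, so Q4_1's extra per-block minimum buys noticeably lower reconstruction error; the comparison script below reflects this with thresholds of 4.0 for Q4_0 and 1.2 for Q4_1.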
diff --git a/rwkv.h b/rwkv.h
index b44fbb0..f7dbfb4 100644
--- a/rwkv.h
+++ b/rwkv.h
@@ -31,6 +31,8 @@ extern "C" {

     // Loads the model from a file and prepares it for inference.
     // Returns NULL on any error. Error messages would be printed to stderr.
+    // - model_file_path: path to model file in ggml format.
+    // - n_threads: count of threads to use, must be positive.
     RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads);

     // Evaluates the model for a single token.
@@ -50,6 +52,13 @@ extern "C" {
     // Frees all allocated memory and the context.
     RWKV_API void rwkv_free(struct rwkv_context * ctx);

+    // Quantizes an FP32 or FP16 model to one of the INT4 formats.
+    // Returns false on any error. Error messages will be printed to stderr.
+    // - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
+    // - model_file_path_out: quantized model will be written here.
+    // - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
+    RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type);
+
     // Returns system information string.
     RWKV_API const char * rwkv_get_system_info_string(void);
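Putting the two exported functions together, a minimal usage sketch (file names and the thread count are hypothetical):

#include <stdio.h>

#include "rwkv.h"

int main(void) {
    // One-time offline step: convert an FP16 model to Q4_1 (q_type 3).
    if (!rwkv_quantize_model_file("rwkv-169M-f16.bin", "rwkv-169M-q4_1.bin", 3)) {
        return 1;
    }

    // Quantized files load through the same entry point as FP32/FP16 ones;
    // the loader dispatches on the data_type field in the file header.
    struct rwkv_context * ctx = rwkv_init_from_file("rwkv-169M-q4_1.bin", 4);

    if (ctx == NULL) {
        return 1;
    }

    printf("%s\n", rwkv_get_system_info_string());

    rwkv_free(ctx);

    return 0;
}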
diff --git a/rwkv/compare_cpp_with_reference_implementation.py b/rwkv/compare_with_reference_implementation.py
similarity index 89%
rename from rwkv/compare_cpp_with_reference_implementation.py
rename to rwkv/compare_with_reference_implementation.py
index f754fb1..7bd3ee8 100644
--- a/rwkv/compare_cpp_with_reference_implementation.py
+++ b/rwkv/compare_with_reference_implementation.py
@@ -1,7 +1,7 @@
 # Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
 # Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
 # Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
-# Usage: python compare_cpp_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
+# Usage: python compare_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin

 import os
 import struct
@@ -40,7 +40,10 @@ def main() -> None:
         header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4))
         data_type: int = header[5]

-        assert data_type == 0 or data_type == 1, f'Unsupported model data type {data_type}'
+        assert data_type == 0 or \
+            data_type == 1 or \
+            data_type == 2 or \
+            data_type == 3, f'Unsupported model data type {data_type}'

     if data_type == 0:
         # FP32, high precision
@@ -48,6 +51,13 @@ def main() -> None:
         threshold = 0.000005
     elif data_type == 1:
         # FP16, lower precision, so higher threshold
         threshold = 0.003
+    elif data_type == 2:
+        # INT4 quantized, even lower precision, so an even higher threshold
+        # This threshold will let some bugs pass
+        threshold = 4.0
+    elif data_type == 3:
+        # This format stores more data, so the error is lower
+        threshold = 1.2

     model = None
diff --git a/rwkv/convert_pytorch_rwkv_to_ggml.py b/rwkv/convert_pytorch_to_ggml.py
similarity index 97%
rename from rwkv/convert_pytorch_rwkv_to_ggml.py
rename to rwkv/convert_pytorch_to_ggml.py
index 67532de..f3731b6 100644
--- a/rwkv/convert_pytorch_rwkv_to_ggml.py
+++ b/rwkv/convert_pytorch_to_ggml.py
@@ -1,5 +1,5 @@
 # Converts an RWKV model checkpoint to an rwkv.cpp compatible file.
-# Usage: python convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
+# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
 # Get model checkpoints from https://huggingface.co/BlinkDL

 # File format:
diff --git a/rwkv/quantize.py b/rwkv/quantize.py
new file mode 100644
index 0000000..e76359c
--- /dev/null
+++ b/rwkv/quantize.py
@@ -0,0 +1,34 @@
+# Quantizes an rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1.
+# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3
+
+import ctypes
+import argparse
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
+    parser.add_argument('shared_library_path', help='Path to rwkv.cpp shared library')
+    parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
+    parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
+    parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3)
+    return parser.parse_args()
+
+def main() -> None:
+    args = parse_args()
+
+    library = ctypes.cdll.LoadLibrary(args.shared_library_path)
+
+    library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_int]
+    library.rwkv_quantize_model_file.restype = ctypes.c_bool
+
+    result: bool = library.rwkv_quantize_model_file(
+        args.src_path.encode('utf-8'),
+        args.dest_path.encode('utf-8'),
+        ctypes.c_int(args.data_type)
+    )
+
+    assert result, 'Failed to quantize, check stderr'
+
+    print('Done')
+
+if __name__ == "__main__":
+    main()