From 18bf02fea465547c6a83cffaf351fcd927ef0584 Mon Sep 17 00:00:00 2001
From: saharNooby
Date: Fri, 7 Apr 2023 10:01:04 +0400
Subject: [PATCH] Use ggml function for parameter size calculation

---
 rwkv.cpp | 48 ++++++++++++++++--------------------------------
 1 file changed, 16 insertions(+), 32 deletions(-)

diff --git a/rwkv.cpp b/rwkv.cpp
index c7fd571..0c331c3 100644
--- a/rwkv.cpp
+++ b/rwkv.cpp
@@ -43,6 +43,14 @@ bool read_int32(FILE * file, int32_t * dest) {
     return true;
 }
 
+static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[5] = {
+    GGML_TYPE_F32,
+    GGML_TYPE_F16,
+    GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1,
+    GGML_TYPE_Q4_1_O
+};
+
 // --- Model definition and loading utilities ---
 
 struct rwkv_layer {
@@ -223,16 +231,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
         data_type
     );
 
-    ggml_type ggml_data_type;
-
-    switch (data_type) {
-        case 0: ggml_data_type = GGML_TYPE_F32; break;
-        case 1: ggml_data_type = GGML_TYPE_F16; break;
-        case 2: ggml_data_type = GGML_TYPE_Q4_0; break;
-        case 3: ggml_data_type = GGML_TYPE_Q4_1; break;
-        case 4: ggml_data_type = GGML_TYPE_Q4_1_O; break;
-        default: return NULL;
-    }
+    ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type];
 
     struct ggml_tensor * tensor;
 
@@ -558,14 +557,7 @@ void rwkv_free(struct rwkv_context * ctx) {
 bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
     RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type);
 
-    ggml_type type;
-
-    switch (q_type) {
-        case 2: type = GGML_TYPE_Q4_0; break;
-        case 3: type = GGML_TYPE_Q4_1; break;
-        case 4: type = GGML_TYPE_Q4_1_O; break;
-        default: return false;
-    };
+    ggml_type type = FORMAT_TYPE_TO_GGML_TYPE[q_type];
 
     printf("Loading model from '%s'\n", model_file_path_in);
 
@@ -645,23 +637,15 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
 
     {
         static const char * parameter_data_type_str[] = {
-            "f32",
-            "f16",
-            "q4_0",
-            "q4_1",
-            "q4_1_o"
+            "F32",
+            "F16",
+            "Q4_0",
+            "Q4_1",
+            "Q4_1_O"
         };
         printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
 
-        // TODO Should not be hardcoded here, but read from ggml
-        static const float parameter_data_type_size[] = {
-            4.0F,
-            2.0F,
-            20.0F / 32.0F,
-            24.0F / 32.0F,
-            24.0F / 32.0F
-        };
-        total_size_orig += (size_t) (nelements * parameter_data_type_size[parameter_data_type]);
+        total_size_orig += (size_t) (nelements * ggml_type_sizef(FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type]));
     }
 
     // Quantize only 2D tensors, except embedding and head matrices.