Use ggml function for parameter size calculation

This commit is contained in:
saharNooby 2023-04-07 10:01:04 +04:00
parent c40941d9d0
commit 18bf02fea4
1 changed file with 16 additions and 32 deletions

View File

@@ -43,6 +43,14 @@ bool read_int32(FILE * file, int32_t * dest) {
return true;
}
static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[5] = {
GGML_TYPE_F32,
GGML_TYPE_F16,
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1,
GGML_TYPE_Q4_1_O
};
// --- Model definition and loading utilities ---
struct rwkv_layer {
@@ -223,16 +231,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
data_type
);
ggml_type ggml_data_type;
switch (data_type) {
case 0: ggml_data_type = GGML_TYPE_F32; break;
case 1: ggml_data_type = GGML_TYPE_F16; break;
case 2: ggml_data_type = GGML_TYPE_Q4_0; break;
case 3: ggml_data_type = GGML_TYPE_Q4_1; break;
case 4: ggml_data_type = GGML_TYPE_Q4_1_O; break;
default: return NULL;
}
ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type];
struct ggml_tensor * tensor;
@@ -558,14 +557,7 @@ void rwkv_free(struct rwkv_context * ctx) {
bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type);
ggml_type type;
switch (q_type) {
case 2: type = GGML_TYPE_Q4_0; break;
case 3: type = GGML_TYPE_Q4_1; break;
case 4: type = GGML_TYPE_Q4_1_O; break;
default: return false;
};
ggml_type type = FORMAT_TYPE_TO_GGML_TYPE[q_type];
printf("Loading model from '%s'\n", model_file_path_in);
@@ -645,23 +637,15 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
{
static const char * parameter_data_type_str[] = {
"f32",
"f16",
"q4_0",
"q4_1",
"q4_1_o"
"F32",
"F16",
"Q4_0",
"Q4_1",
"Q4_1_O"
};
printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
// TODO Should not be hardcoded here, but read from ggml
static const float parameter_data_type_size[] = {
4.0F,
2.0F,
20.0F / 32.0F,
24.0F / 32.0F,
24.0F / 32.0F
};
total_size_orig += (size_t) (nelements * parameter_data_type_size[parameter_data_type]);
total_size_orig += (size_t) (nelements * ggml_type_sizef(FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type]));
}
// Quantize only 2D tensors, except embedding and head matrices.