Use ggml function for parameter size calculation
This commit is contained in:
parent
c40941d9d0
commit
18bf02fea4
48
rwkv.cpp
48
rwkv.cpp
|
@ -43,6 +43,14 @@ bool read_int32(FILE * file, int32_t * dest) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[5] = {
|
||||||
|
GGML_TYPE_F32,
|
||||||
|
GGML_TYPE_F16,
|
||||||
|
GGML_TYPE_Q4_0,
|
||||||
|
GGML_TYPE_Q4_1,
|
||||||
|
GGML_TYPE_Q4_1_O
|
||||||
|
};
|
||||||
|
|
||||||
// --- Model definition and loading utilities ---
|
// --- Model definition and loading utilities ---
|
||||||
|
|
||||||
struct rwkv_layer {
|
struct rwkv_layer {
|
||||||
|
@ -223,16 +231,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
|
||||||
data_type
|
data_type
|
||||||
);
|
);
|
||||||
|
|
||||||
ggml_type ggml_data_type;
|
ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type];
|
||||||
|
|
||||||
switch (data_type) {
|
|
||||||
case 0: ggml_data_type = GGML_TYPE_F32; break;
|
|
||||||
case 1: ggml_data_type = GGML_TYPE_F16; break;
|
|
||||||
case 2: ggml_data_type = GGML_TYPE_Q4_0; break;
|
|
||||||
case 3: ggml_data_type = GGML_TYPE_Q4_1; break;
|
|
||||||
case 4: ggml_data_type = GGML_TYPE_Q4_1_O; break;
|
|
||||||
default: return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor * tensor;
|
struct ggml_tensor * tensor;
|
||||||
|
|
||||||
|
@ -558,14 +557,7 @@ void rwkv_free(struct rwkv_context * ctx) {
|
||||||
bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
|
bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
|
||||||
RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type);
|
RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type);
|
||||||
|
|
||||||
ggml_type type;
|
ggml_type type = FORMAT_TYPE_TO_GGML_TYPE[q_type];
|
||||||
|
|
||||||
switch (q_type) {
|
|
||||||
case 2: type = GGML_TYPE_Q4_0; break;
|
|
||||||
case 3: type = GGML_TYPE_Q4_1; break;
|
|
||||||
case 4: type = GGML_TYPE_Q4_1_O; break;
|
|
||||||
default: return false;
|
|
||||||
};
|
|
||||||
|
|
||||||
printf("Loading model from '%s'\n", model_file_path_in);
|
printf("Loading model from '%s'\n", model_file_path_in);
|
||||||
|
|
||||||
|
@ -645,23 +637,15 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
|
||||||
|
|
||||||
{
|
{
|
||||||
static const char * parameter_data_type_str[] = {
|
static const char * parameter_data_type_str[] = {
|
||||||
"f32",
|
"F32",
|
||||||
"f16",
|
"F16",
|
||||||
"q4_0",
|
"Q4_0",
|
||||||
"q4_1",
|
"Q4_1",
|
||||||
"q4_1_o"
|
"Q4_1_O"
|
||||||
};
|
};
|
||||||
printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
|
printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
|
||||||
|
|
||||||
// TODO Should not be hardcoded here, but read from ggml
|
total_size_orig += (size_t) (nelements * ggml_type_sizef(FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type]));
|
||||||
static const float parameter_data_type_size[] = {
|
|
||||||
4.0F,
|
|
||||||
2.0F,
|
|
||||||
20.0F / 32.0F,
|
|
||||||
24.0F / 32.0F,
|
|
||||||
24.0F / 32.0F
|
|
||||||
};
|
|
||||||
total_size_orig += (size_t) (nelements * parameter_data_type_size[parameter_data_type]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quantize only 2D tensors, except embedding and head matrices.
|
// Quantize only 2D tensors, except embedding and head matrices.
|
||||||
|
|
Loading…
Reference in New Issue