Use ggml function for parameter size calculation
This commit is contained in:
parent
c40941d9d0
commit
18bf02fea4
48
rwkv.cpp
48
rwkv.cpp
|
@ -43,6 +43,14 @@ bool read_int32(FILE * file, int32_t * dest) {
|
|||
return true;
|
||||
}
|
||||
|
||||
static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[5] = {
|
||||
GGML_TYPE_F32,
|
||||
GGML_TYPE_F16,
|
||||
GGML_TYPE_Q4_0,
|
||||
GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_Q4_1_O
|
||||
};
|
||||
|
||||
// --- Model definition and loading utilities ---
|
||||
|
||||
struct rwkv_layer {
|
||||
|
@ -223,16 +231,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
|
|||
data_type
|
||||
);
|
||||
|
||||
ggml_type ggml_data_type;
|
||||
|
||||
switch (data_type) {
|
||||
case 0: ggml_data_type = GGML_TYPE_F32; break;
|
||||
case 1: ggml_data_type = GGML_TYPE_F16; break;
|
||||
case 2: ggml_data_type = GGML_TYPE_Q4_0; break;
|
||||
case 3: ggml_data_type = GGML_TYPE_Q4_1; break;
|
||||
case 4: ggml_data_type = GGML_TYPE_Q4_1_O; break;
|
||||
default: return NULL;
|
||||
}
|
||||
ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type];
|
||||
|
||||
struct ggml_tensor * tensor;
|
||||
|
||||
|
@ -558,14 +557,7 @@ void rwkv_free(struct rwkv_context * ctx) {
|
|||
bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
|
||||
RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type);
|
||||
|
||||
ggml_type type;
|
||||
|
||||
switch (q_type) {
|
||||
case 2: type = GGML_TYPE_Q4_0; break;
|
||||
case 3: type = GGML_TYPE_Q4_1; break;
|
||||
case 4: type = GGML_TYPE_Q4_1_O; break;
|
||||
default: return false;
|
||||
};
|
||||
ggml_type type = FORMAT_TYPE_TO_GGML_TYPE[q_type];
|
||||
|
||||
printf("Loading model from '%s'\n", model_file_path_in);
|
||||
|
||||
|
@ -645,23 +637,15 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
|
|||
|
||||
{
|
||||
static const char * parameter_data_type_str[] = {
|
||||
"f32",
|
||||
"f16",
|
||||
"q4_0",
|
||||
"q4_1",
|
||||
"q4_1_o"
|
||||
"F32",
|
||||
"F16",
|
||||
"Q4_0",
|
||||
"Q4_1",
|
||||
"Q4_1_O"
|
||||
};
|
||||
printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
|
||||
|
||||
// TODO Should not be hardcoded here, but read from ggml
|
||||
static const float parameter_data_type_size[] = {
|
||||
4.0F,
|
||||
2.0F,
|
||||
20.0F / 32.0F,
|
||||
24.0F / 32.0F,
|
||||
24.0F / 32.0F
|
||||
};
|
||||
total_size_orig += (size_t) (nelements * parameter_data_type_size[parameter_data_type]);
|
||||
total_size_orig += (size_t) (nelements * ggml_type_sizef(FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type]));
|
||||
}
|
||||
|
||||
// Quantize only 2D tensors, except embedding and head matrices.
|
||||
|
|
Loading…
Reference in New Issue