diff --git a/rwkv.cpp b/rwkv.cpp
index 208502c..08b4ad3 100644
--- a/rwkv.cpp
+++ b/rwkv.cpp
@@ -650,6 +650,7 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
             };
             printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
 
+            // TODO These sizes should not be hardcoded here, but read from ggml
             static const float parameter_data_type_size[] = {
                 4.0F,
                 2.0F,
@@ -659,8 +660,12 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
             total_size_orig += (size_t) (nelements * parameter_data_type_size[parameter_data_type]);
         }
 
-        // Quantize only 2D tensors
-        bool quantize = n_dims == 2;
+        // Quantize only 2D tensors, except the embedding and head matrices.
+        // Embedding and head do not take up much space, especially in bigger models,
+        // but quantizing them significantly increases perplexity.
+        bool quantize = n_dims == 2 &&
+            name != std::string("emb.weight") &&
+            name != std::string("head.weight");
 
         if (quantize) {
             if (parameter_data_type != 0 && parameter_data_type != 1) {
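
As a follow-up to the TODO above, one way to drop the hardcoded size table would be to map the file format's data type index to ggml's own type enum and ask ggml for the per-element size. This is only a sketch, not part of the patch: the FORMAT_TYPE_TO_GGML_TYPE table and parameter_data_type_sizef() helper are hypothetical names, and it assumes the ggml version in use still exposes ggml_type_sizef().

#include <stdint.h>

#include "ggml.h"

// Hypothetical mapping from the rwkv.cpp file format's data type index
// (0 = FP32, 1 = FP16, 2 = Q4_0, 3 = Q4_1) to ggml's type enum.
static const enum ggml_type FORMAT_TYPE_TO_GGML_TYPE[4] = {
    GGML_TYPE_F32,
    GGML_TYPE_F16,
    GGML_TYPE_Q4_0,
    GGML_TYPE_Q4_1
};

// Per-element size in bytes, taken from ggml instead of a hardcoded table.
// ggml_type_sizef() returns ggml_type_size() / ggml_blck_size() as a float,
// so quantized types come out as fractional byte counts per element.
static float parameter_data_type_sizef(int32_t parameter_data_type) {
    return ggml_type_sizef(FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type]);
}

// Usage inside the loop above would then be:
//   total_size_orig += (size_t) (nelements * parameter_data_type_sizef(parameter_data_type));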