Implement INT4 conversion and inference
This commit is contained in:
		
							parent
							
								
									b164bf4e27
								
							
						
					
					
						commit
						972e28d48d
					
				
							
								
								
									
										266
									
								
								rwkv.cpp
								
								
								
								
							
							
						
						
									
										266
									
								
								rwkv.cpp
								
								
								
								
							| 
						 | 
				
			
			@ -39,7 +39,6 @@
 | 
			
		|||
 | 
			
		||||
// Reads a single int32 value from a file into *dest.
// - file: stream to read from.
// - dest: receives the 4 bytes exactly as they are stored in the file.
// Returns true once the value has been read.
bool read_int32(FILE * file, int32_t * dest) {
    // TODO Will not read correct values on a machine with different endianness
    // NOTE(review): RWKV_ASSERT_FALSE is defined outside this view; presumably it reports
    // the message and makes this function return false when the condition fails -- confirm.
    RWKV_ASSERT_FALSE(fread(dest, 4, 1, file) == 1, "Failed to read an int32 value from a file");
    return true;
}
 | 
			
		||||
| 
						 | 
				
			
			@ -111,13 +110,24 @@ bool set_block_parameter(std::unordered_map<std::string, struct ggml_tensor *> *
 | 
			
		|||
 | 
			
		||||
size_t get_memory_required_mb(int32_t n_vocab, int32_t n_layer, int32_t n_embed, int32_t data_type) {
 | 
			
		||||
    if (n_vocab == 50277) {
 | 
			
		||||
        // 169M to 1.5B are exact, others are extrapolated (slightly bigger than needed).
 | 
			
		||||
        /* 169M */ if (n_layer == 12 && n_embed ==  768) return size_t(data_type == 0 ?   650 :   327);
 | 
			
		||||
        /* 430M */ if (n_layer == 24 && n_embed == 1024) return size_t(data_type == 0 ?  1650 :   830);
 | 
			
		||||
        /* 1.5B */ if (n_layer == 24 && n_embed == 2048) return size_t(data_type == 0 ?  5795 :  2907);
 | 
			
		||||
        /*   3B */ if (n_layer == 32 && n_embed == 2560) return size_t(data_type == 0 ? 11590 :  5720); // TODO Measure exactly (FP32 only)
 | 
			
		||||
        /*   7B */ if (n_layer == 32 && n_embed == 4096) return size_t(data_type == 0 ? 27043 : 13566); // TODO Measure exactly
 | 
			
		||||
        /*  14B */ if (n_layer == 40 && n_embed == 5120) return size_t(data_type == 0 ? 54086 : 27132); // TODO Measure exactly
 | 
			
		||||
        // Non-exact values are extrapolated (slightly bigger than needed).
 | 
			
		||||
        // TODO Measure values exactly
 | 
			
		||||
        static const size_t memory_required_mb[6][4] = {
 | 
			
		||||
            /*            FP32   FP16  Q4_0   Q4_1
 | 
			
		||||
               169M  */ {  650,   327,  105,   165}, // All measured exactly
 | 
			
		||||
            /* 430M  */ { 1650,   830,  263,   415}, // FP32, FP16 are exact
 | 
			
		||||
            /* 1.5B  */ { 5795,  2907,  923,  1454}, // FP32, FP16 are exact
 | 
			
		||||
            /*   3B  */ {11610,  5720, 1816,  2860}, // FP16 is exact
 | 
			
		||||
            /*   7B  */ {27090, 13634, 4328,  6817},
 | 
			
		||||
            /*  14B  */ {54180, 27267, 8656, 13634}
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        /* 169M */ if (n_layer == 12 && n_embed ==  768) return memory_required_mb[0][data_type];
 | 
			
		||||
        /* 430M */ if (n_layer == 24 && n_embed == 1024) return memory_required_mb[1][data_type];
 | 
			
		||||
        /* 1.5B */ if (n_layer == 24 && n_embed == 2048) return memory_required_mb[2][data_type];
 | 
			
		||||
        /*   3B */ if (n_layer == 32 && n_embed == 2560) return memory_required_mb[3][data_type];
 | 
			
		||||
        /*   7B */ if (n_layer == 32 && n_embed == 4096) return memory_required_mb[4][data_type];
 | 
			
		||||
        /*  14B */ if (n_layer == 40 && n_embed == 5120) return memory_required_mb[5][data_type];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fprintf(
 | 
			
		||||
| 
						 | 
				
			
			@ -177,7 +187,14 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
 | 
			
		|||
    RWKV_ASSERT_NULL(model->n_layer > 0, "Non-positive n_layer %d", model->n_layer);
 | 
			
		||||
 | 
			
		||||
    read_int32(file, &(model->data_type));
 | 
			
		||||
    RWKV_ASSERT_NULL(model->data_type == 0 || model->data_type == 1, "Unsupported model data type %d", model->data_type);
 | 
			
		||||
    RWKV_ASSERT_NULL(
 | 
			
		||||
        model->data_type == 0 ||
 | 
			
		||||
            model->data_type == 1 ||
 | 
			
		||||
            model->data_type == 2 ||
 | 
			
		||||
            model->data_type == 3,
 | 
			
		||||
        "Unsupported model data type %d",
 | 
			
		||||
        model->data_type
 | 
			
		||||
    );
 | 
			
		||||
 | 
			
		||||
    // Initialize ggml
 | 
			
		||||
    struct ggml_init_params params;
 | 
			
		||||
| 
						 | 
				
			
			@ -203,9 +220,24 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
 | 
			
		|||
 | 
			
		||||
        int32_t data_type;
 | 
			
		||||
        read_int32(file, &data_type);
 | 
			
		||||
        RWKV_ASSERT_NULL(data_type == 0 || data_type == 1, "Unsupported parameter data type %d", data_type);
 | 
			
		||||
        RWKV_ASSERT_NULL(
 | 
			
		||||
            data_type == 0 ||
 | 
			
		||||
                data_type == 1 ||
 | 
			
		||||
                data_type == 2 ||
 | 
			
		||||
                data_type == 3,
 | 
			
		||||
            "Unsupported parameter data type %d",
 | 
			
		||||
            data_type
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        ggml_type ggml_data_type = data_type == 0 ? GGML_TYPE_F32 : GGML_TYPE_F16;
 | 
			
		||||
        ggml_type ggml_data_type;
 | 
			
		||||
 | 
			
		||||
        switch (data_type) {
 | 
			
		||||
            case 0: ggml_data_type = GGML_TYPE_F32; break;
 | 
			
		||||
            case 1: ggml_data_type = GGML_TYPE_F16; break;
 | 
			
		||||
            case 2: ggml_data_type = GGML_TYPE_Q4_0; break;
 | 
			
		||||
            case 3: ggml_data_type = GGML_TYPE_Q4_1; break;
 | 
			
		||||
            default: return NULL;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        struct ggml_tensor * tensor;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -229,8 +261,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
 | 
			
		|||
        std::string key(key_length, 0);
 | 
			
		||||
        RWKV_ASSERT_NULL(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key");
 | 
			
		||||
 | 
			
		||||
        size_t byte_count = element_count * ggml_type_size(ggml_data_type);
 | 
			
		||||
        RWKV_ASSERT_NULL(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data");
 | 
			
		||||
        RWKV_ASSERT_NULL(fread(tensor->data, 1, ggml_nbytes(tensor), file) == ggml_nbytes(tensor), "Failed to read parameter data");
 | 
			
		||||
 | 
			
		||||
        parameters[key] = tensor;
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			@ -533,6 +564,215 @@ void rwkv_free(struct rwkv_context * ctx) {
 | 
			
		|||
    delete ctx;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type) {
 | 
			
		||||
    RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type);
 | 
			
		||||
 | 
			
		||||
    ggml_type type;
 | 
			
		||||
 | 
			
		||||
    switch (q_type) {
 | 
			
		||||
        case 2: type = GGML_TYPE_Q4_0; break;
 | 
			
		||||
        case 3: type = GGML_TYPE_Q4_1; break;
 | 
			
		||||
        default: return false;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    RWKV_ASSERT_FALSE(type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1, "Unsupported data type %d", type);
 | 
			
		||||
 | 
			
		||||
    printf("Loading model from '%s'\n", model_file_path_in);
 | 
			
		||||
 | 
			
		||||
    auto finp = std::ifstream(model_file_path_in, std::ios::binary);
 | 
			
		||||
    RWKV_ASSERT_FALSE(finp, "Failed to open %s for reading", model_file_path_in);
 | 
			
		||||
 | 
			
		||||
    auto fout = std::ofstream(model_file_path_out, std::ios::binary);
 | 
			
		||||
    RWKV_ASSERT_FALSE(fout, "Failed to open %s for writing", model_file_path_out);
 | 
			
		||||
 | 
			
		||||
    // Process header
 | 
			
		||||
    {
 | 
			
		||||
        uint32_t magic;
 | 
			
		||||
        finp.read((char *) &magic, sizeof(magic));
 | 
			
		||||
        RWKV_ASSERT_FALSE(magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic);
 | 
			
		||||
        fout.write((char *) &magic, sizeof(magic));
 | 
			
		||||
 | 
			
		||||
        uint32_t format_version;
 | 
			
		||||
        finp.read((char *) &format_version, sizeof(format_version));
 | 
			
		||||
        RWKV_ASSERT_FALSE(format_version == RWKV_FILE_VERSION, "Unsupported file version %d", format_version);
 | 
			
		||||
        fout.write((char *) &format_version, sizeof(format_version));
 | 
			
		||||
 | 
			
		||||
        int32_t n_vocab;
 | 
			
		||||
        int32_t n_embed;
 | 
			
		||||
        int32_t n_layer;
 | 
			
		||||
        int32_t data_type;
 | 
			
		||||
 | 
			
		||||
        finp.read((char *) &n_vocab, sizeof(n_vocab));
 | 
			
		||||
        finp.read((char *) &n_embed, sizeof(n_embed));
 | 
			
		||||
        finp.read((char *) &n_layer, sizeof(n_layer));
 | 
			
		||||
        finp.read((char *) &data_type, sizeof(data_type));
 | 
			
		||||
 | 
			
		||||
        RWKV_ASSERT_FALSE(data_type == 0 || data_type == 1, "Unsupported data type %d, only FP32 and FP16 can be quantized", data_type);
 | 
			
		||||
 | 
			
		||||
        data_type = q_type;
 | 
			
		||||
 | 
			
		||||
        fout.write((char *) &n_vocab, sizeof(n_vocab));
 | 
			
		||||
        fout.write((char *) &n_embed, sizeof(n_embed));
 | 
			
		||||
        fout.write((char *) &n_layer, sizeof(n_layer));
 | 
			
		||||
        fout.write((char *) &data_type, sizeof(data_type));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Process parameters
 | 
			
		||||
    {
 | 
			
		||||
        size_t total_size_orig = 0;
 | 
			
		||||
        size_t total_size_new = 0;
 | 
			
		||||
 | 
			
		||||
        std::vector<float> work;
 | 
			
		||||
 | 
			
		||||
        std::vector<uint8_t>     data_u8;
 | 
			
		||||
        std::vector<ggml_fp16_t> data_f16;
 | 
			
		||||
        std::vector<float>       data_f32;
 | 
			
		||||
 | 
			
		||||
        std::vector<int64_t> hist_all(1 << 4, 0);
 | 
			
		||||
 | 
			
		||||
        while (true) {
 | 
			
		||||
            int32_t n_dims;
 | 
			
		||||
            int32_t key_length;
 | 
			
		||||
            int32_t parameter_data_type;
 | 
			
		||||
 | 
			
		||||
            finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
 | 
			
		||||
            finp.read(reinterpret_cast<char *>(&key_length), sizeof(key_length));
 | 
			
		||||
            finp.read(reinterpret_cast<char *>(¶meter_data_type),  sizeof(parameter_data_type));
 | 
			
		||||
 | 
			
		||||
            if (finp.eof()) {
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            int32_t nelements = 1;
 | 
			
		||||
            int32_t ne[2] = { 1, 1 };
 | 
			
		||||
            for (int i = 0; i < n_dims; ++i) {
 | 
			
		||||
                finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
 | 
			
		||||
                nelements *= ne[i];
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            std::string name(key_length, 0);
 | 
			
		||||
            finp.read(&name[0], key_length);
 | 
			
		||||
 | 
			
		||||
            {
 | 
			
		||||
                static const char * parameter_data_type_str[] = {
 | 
			
		||||
                    "f32",
 | 
			
		||||
                    "f16",
 | 
			
		||||
                    "q4_0",
 | 
			
		||||
                    "q4_1"
 | 
			
		||||
                };
 | 
			
		||||
                printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // Quantize only 2D tensors
 | 
			
		||||
            bool quantize = n_dims == 2;
 | 
			
		||||
 | 
			
		||||
            if (quantize) {
 | 
			
		||||
                if (parameter_data_type != 0 && parameter_data_type != 1) {
 | 
			
		||||
                    fprintf(stderr, "unsupported data type %d for integer quantization\n", parameter_data_type);
 | 
			
		||||
                    return false;
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                if (parameter_data_type == 1) {
 | 
			
		||||
                    data_f16.resize(nelements);
 | 
			
		||||
                    finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
 | 
			
		||||
                    data_f32.resize(nelements);
 | 
			
		||||
                    for (int i = 0; i < nelements; ++i) {
 | 
			
		||||
                        data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
 | 
			
		||||
                    }
 | 
			
		||||
                } else {
 | 
			
		||||
                    data_f32.resize(nelements);
 | 
			
		||||
                    finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                parameter_data_type = q_type;
 | 
			
		||||
            } else {
 | 
			
		||||
                const int bytes_per_element = (parameter_data_type == 0) ? sizeof(float) : sizeof(uint16_t);
 | 
			
		||||
                data_u8.resize(nelements * bytes_per_element);
 | 
			
		||||
                finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bytes_per_element);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
 | 
			
		||||
            fout.write(reinterpret_cast<char *>(&key_length), sizeof(key_length));
 | 
			
		||||
            fout.write(reinterpret_cast<char *>(¶meter_data_type),  sizeof(parameter_data_type));
 | 
			
		||||
 | 
			
		||||
            for (int i = 0; i < n_dims; ++i) {
 | 
			
		||||
                fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            fout.write(&name[0], key_length);
 | 
			
		||||
 | 
			
		||||
            if (quantize) {
 | 
			
		||||
                printf("quantizing... ");
 | 
			
		||||
                work.resize(nelements); // for quantization
 | 
			
		||||
 | 
			
		||||
                size_t cur_size = 0;
 | 
			
		||||
                std::vector<int64_t> hist_cur(1 << 4, 0);
 | 
			
		||||
 | 
			
		||||
                switch (type) {
 | 
			
		||||
                    case GGML_TYPE_Q4_0:
 | 
			
		||||
                        {
 | 
			
		||||
                            cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
 | 
			
		||||
                        } break;
 | 
			
		||||
                    case GGML_TYPE_Q4_1:
 | 
			
		||||
                        {
 | 
			
		||||
                            cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
 | 
			
		||||
                        } break;
 | 
			
		||||
                    default:
 | 
			
		||||
                        {
 | 
			
		||||
                            fprintf(stderr, "unsupported quantization type %d\n", type);
 | 
			
		||||
                            return false;
 | 
			
		||||
                        }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                fout.write(reinterpret_cast<char *>(work.data()), cur_size);
 | 
			
		||||
                total_size_new += cur_size;
 | 
			
		||||
 | 
			
		||||
                printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float) / 1024.0 / 1024.0, cur_size / 1024.0 / 1024.0);
 | 
			
		||||
 | 
			
		||||
                for (int i = 0; i < (int) hist_cur.size(); ++i) {
 | 
			
		||||
                    hist_all[i] += hist_cur[i];
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                for (int i = 0; i < (int) hist_cur.size(); ++i) {
 | 
			
		||||
                    printf("%5.3f ", hist_cur[i] / float(nelements));
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                printf("\n");
 | 
			
		||||
            } else {
 | 
			
		||||
                printf("size = %8.3f MB\n", data_u8.size() / 1024.0 / 1024.0);
 | 
			
		||||
                fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
 | 
			
		||||
                total_size_new += data_u8.size();
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            total_size_orig += nelements * sizeof(float);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        printf("model size = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0);
 | 
			
		||||
        printf("quant size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            int64_t sum_all = 0;
 | 
			
		||||
 | 
			
		||||
            for (int i = 0; i < (int) hist_all.size(); ++i) {
 | 
			
		||||
                sum_all += hist_all[i];
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            printf("hist: ");
 | 
			
		||||
 | 
			
		||||
            for (int i = 0; i < (int) hist_all.size(); ++i) {
 | 
			
		||||
                printf("%5.3f ", hist_all[i] / float(sum_all));
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            printf("\n");
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    finp.close();
 | 
			
		||||
    fout.close();
 | 
			
		||||
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const char * rwkv_get_system_info_string(void) {
 | 
			
		||||
    static std::string s;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										9
									
								
								rwkv.h
								
								
								
								
							
							
						
						
									
										9
									
								
								rwkv.h
								
								
								
								
							| 
						 | 
				
			
			@ -31,6 +31,8 @@ extern "C" {
 | 
			
		|||
 | 
			
		||||
    // Loads the model from a file and prepares it for inference.
 | 
			
		||||
    // Returns NULL on any error. Error messages are printed to stderr.
 | 
			
		||||
    // - model_file_path: path to model file in ggml format.
 | 
			
		||||
    // - n_threads: count of threads to use, must be positive.
 | 
			
		||||
    RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads);
 | 
			
		||||
 | 
			
		||||
    // Evaluates the model for a single token.
 | 
			
		||||
| 
						 | 
				
			
			@ -50,6 +52,13 @@ extern "C" {
 | 
			
		|||
    // Frees all allocated memory and the context.
 | 
			
		||||
    RWKV_API void rwkv_free(struct rwkv_context * ctx);
 | 
			
		||||
 | 
			
		||||
    // Quantizes FP32 or FP16 model to one of INT4 formats.
 | 
			
		||||
    // Returns false on any error. Error messages are printed to stderr.
 | 
			
		||||
    // - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
 | 
			
		||||
    // - model_file_path_out: quantized model will be written here.
 | 
			
		||||
    // - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
 | 
			
		||||
    RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type);
 | 
			
		||||
 | 
			
		||||
    // Returns system information string.
 | 
			
		||||
    RWKV_API const char * rwkv_get_system_info_string(void);
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,7 +1,7 @@
 | 
			
		|||
# Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
 | 
			
		||||
# Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
 | 
			
		||||
# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
 | 
			
		||||
# Usage: python compare_cpp_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
 | 
			
		||||
# Usage: python compare_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import struct
 | 
			
		||||
| 
						 | 
				
			
			@ -40,7 +40,10 @@ def main() -> None:
 | 
			
		|||
        header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4))
 | 
			
		||||
        data_type: int = header[5]
 | 
			
		||||
 | 
			
		||||
        assert data_type == 0 or data_type == 1, f'Unsupported model data type {data_type}'
 | 
			
		||||
        assert data_type == 0 or\
 | 
			
		||||
               data_type == 1 or\
 | 
			
		||||
               data_type == 2 or\
 | 
			
		||||
               data_type == 3, f'Unsupported model data type {data_type}'
 | 
			
		||||
 | 
			
		||||
        if data_type == 0:
 | 
			
		||||
            # FP32, high precision
 | 
			
		||||
| 
						 | 
				
			
			@ -48,6 +51,13 @@ def main() -> None:
 | 
			
		|||
        elif data_type == 1:
 | 
			
		||||
            # FP16, lower precision, so higher threshold
 | 
			
		||||
            threshold = 0.003
 | 
			
		||||
        elif data_type == 2:
 | 
			
		||||
            # INT4 quantized, even lower precision, so even higher threshold
 | 
			
		||||
            # This threshold will let some bugs pass
 | 
			
		||||
            threshold = 4.0
 | 
			
		||||
        elif data_type == 3:
 | 
			
		||||
            # This format stores more data, so error would be lower
 | 
			
		||||
            threshold = 1.2
 | 
			
		||||
 | 
			
		||||
    model = None
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1,5 +1,5 @@
 | 
			
		|||
# Converts an RWKV model checkpoint to an rwkv.cpp compatible file.
 | 
			
		||||
# Usage: python convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
 | 
			
		||||
# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
 | 
			
		||||
# Get model checkpoints from https://huggingface.co/BlinkDL
 | 
			
		||||
 | 
			
		||||
# File format:
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,34 @@
 | 
			
		|||
# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1.
 | 
			
		||||
# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3
 | 
			
		||||
 | 
			
		||||
import ctypes
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line arguments of the quantization script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read them from ``sys.argv[1:]``.

    Returns:
        Namespace with ``shared_library_path``, ``src_path``, ``dest_path``
        and ``data_type`` (2 or 3; 3 when omitted).
    """
    parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
    parser.add_argument('shared_library_path', help='Path to rwkv.cpp shared library')
    parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
    parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
    # nargs='?' makes this positional optional; without it argparse always
    # requires the value and the declared default of 3 can never take effect.
    parser.add_argument(
        'data_type',
        help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)',
        type=int,
        choices=[2, 3],
        nargs='?',
        default=3,
    )
    return parser.parse_args(argv)
 | 
			
		||||
 | 
			
		||||
def main() -> None:
    """Quantize the model file named on the command line.

    Loads the rwkv.cpp shared library and calls its
    ``rwkv_quantize_model_file`` function with the parsed arguments.

    Raises:
        RuntimeError: If ``rwkv_quantize_model_file`` returns false.
    """
    args = parse_args()

    library = ctypes.cdll.LoadLibrary(args.shared_library_path)

    # Declare the C signature explicitly so ctypes converts the arguments
    # and the boolean return value correctly.
    quantize = library.rwkv_quantize_model_file
    quantize.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_int]
    quantize.restype = ctypes.c_bool

    result: bool = quantize(
        args.src_path.encode('utf-8'),
        args.dest_path.encode('utf-8'),
        ctypes.c_int(args.data_type)
    )

    # An assert would be stripped under `python -O`, letting a failed
    # quantization fall through to the success message.
    if not result:
        raise RuntimeError('Failed to quantize, check stderr')

    print('Done')


if __name__ == "__main__":
    main()
 | 
			
		||||
		Loading…
	
		Reference in New Issue