Implement INT4 conversion and inference
This commit is contained in:
parent
b164bf4e27
commit
972e28d48d
266
rwkv.cpp
266
rwkv.cpp
|
@ -39,7 +39,6 @@
|
||||||
|
|
||||||
// Reads single int32 value from a file.
|
// Reads single int32 value from a file.
|
||||||
bool read_int32(FILE * file, int32_t * dest) {
|
bool read_int32(FILE * file, int32_t * dest) {
|
||||||
// TODO Will not read correct values on machine with different endianness
|
|
||||||
RWKV_ASSERT_FALSE(fread(dest, 4, 1, file) == 1, "Failed to read an int32 value from a file");
|
RWKV_ASSERT_FALSE(fread(dest, 4, 1, file) == 1, "Failed to read an int32 value from a file");
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -111,13 +110,24 @@ bool set_block_parameter(std::unordered_map<std::string, struct ggml_tensor *> *
|
||||||
|
|
||||||
size_t get_memory_required_mb(int32_t n_vocab, int32_t n_layer, int32_t n_embed, int32_t data_type) {
|
size_t get_memory_required_mb(int32_t n_vocab, int32_t n_layer, int32_t n_embed, int32_t data_type) {
|
||||||
if (n_vocab == 50277) {
|
if (n_vocab == 50277) {
|
||||||
// 169M to 1.5B are exact, others are extrapolated (slightly bigger than needed).
|
// Non-exact values are extrapolated (slightly bigger than needed).
|
||||||
/* 169M */ if (n_layer == 12 && n_embed == 768) return size_t(data_type == 0 ? 650 : 327);
|
// TODO Measure values exactly
|
||||||
/* 430M */ if (n_layer == 24 && n_embed == 1024) return size_t(data_type == 0 ? 1650 : 830);
|
static const size_t memory_required_mb[6][4] = {
|
||||||
/* 1.5B */ if (n_layer == 24 && n_embed == 2048) return size_t(data_type == 0 ? 5795 : 2907);
|
/* FP32 FP16 Q4_0 Q4_1
|
||||||
/* 3B */ if (n_layer == 32 && n_embed == 2560) return size_t(data_type == 0 ? 11590 : 5720); // TODO Measure exactly (FP32 only)
|
169M */ { 650, 327, 105, 165}, // All measured exactly
|
||||||
/* 7B */ if (n_layer == 32 && n_embed == 4096) return size_t(data_type == 0 ? 27043 : 13566); // TODO Measure exactly
|
/* 430M */ { 1650, 830, 263, 415}, // FP32, FP16 are exact
|
||||||
/* 14B */ if (n_layer == 40 && n_embed == 5120) return size_t(data_type == 0 ? 54086 : 27132); // TODO Measure exactly
|
/* 1.5B */ { 5795, 2907, 923, 1454}, // FP32, FP16 are exact
|
||||||
|
/* 3B */ {11610, 5720, 1816, 2860}, // FP16 is exact
|
||||||
|
/* 7B */ {27090, 13634, 4328, 6817},
|
||||||
|
/* 14B */ {54180, 27267, 8656, 13634}
|
||||||
|
};
|
||||||
|
|
||||||
|
/* 169M */ if (n_layer == 12 && n_embed == 768) return memory_required_mb[0][data_type];
|
||||||
|
/* 430M */ if (n_layer == 24 && n_embed == 1024) return memory_required_mb[1][data_type];
|
||||||
|
/* 1.5B */ if (n_layer == 24 && n_embed == 2048) return memory_required_mb[2][data_type];
|
||||||
|
/* 3B */ if (n_layer == 32 && n_embed == 2560) return memory_required_mb[3][data_type];
|
||||||
|
/* 7B */ if (n_layer == 32 && n_embed == 4096) return memory_required_mb[4][data_type];
|
||||||
|
/* 14B */ if (n_layer == 40 && n_embed == 5120) return memory_required_mb[5][data_type];
|
||||||
}
|
}
|
||||||
|
|
||||||
fprintf(
|
fprintf(
|
||||||
|
@ -177,7 +187,14 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
|
||||||
RWKV_ASSERT_NULL(model->n_layer > 0, "Non-positive n_layer %d", model->n_layer);
|
RWKV_ASSERT_NULL(model->n_layer > 0, "Non-positive n_layer %d", model->n_layer);
|
||||||
|
|
||||||
read_int32(file, &(model->data_type));
|
read_int32(file, &(model->data_type));
|
||||||
RWKV_ASSERT_NULL(model->data_type == 0 || model->data_type == 1, "Unsupported model data type %d", model->data_type);
|
RWKV_ASSERT_NULL(
|
||||||
|
model->data_type == 0 ||
|
||||||
|
model->data_type == 1 ||
|
||||||
|
model->data_type == 2 ||
|
||||||
|
model->data_type == 3,
|
||||||
|
"Unsupported model data type %d",
|
||||||
|
model->data_type
|
||||||
|
);
|
||||||
|
|
||||||
// Initialize ggml
|
// Initialize ggml
|
||||||
struct ggml_init_params params;
|
struct ggml_init_params params;
|
||||||
|
@ -203,9 +220,24 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
|
||||||
|
|
||||||
int32_t data_type;
|
int32_t data_type;
|
||||||
read_int32(file, &data_type);
|
read_int32(file, &data_type);
|
||||||
RWKV_ASSERT_NULL(data_type == 0 || data_type == 1, "Unsupported parameter data type %d", data_type);
|
RWKV_ASSERT_NULL(
|
||||||
|
data_type == 0 ||
|
||||||
|
data_type == 1 ||
|
||||||
|
data_type == 2 ||
|
||||||
|
data_type == 3,
|
||||||
|
"Unsupported parameter data type %d",
|
||||||
|
data_type
|
||||||
|
);
|
||||||
|
|
||||||
ggml_type ggml_data_type = data_type == 0 ? GGML_TYPE_F32 : GGML_TYPE_F16;
|
ggml_type ggml_data_type;
|
||||||
|
|
||||||
|
switch (data_type) {
|
||||||
|
case 0: ggml_data_type = GGML_TYPE_F32; break;
|
||||||
|
case 1: ggml_data_type = GGML_TYPE_F16; break;
|
||||||
|
case 2: ggml_data_type = GGML_TYPE_Q4_0; break;
|
||||||
|
case 3: ggml_data_type = GGML_TYPE_Q4_1; break;
|
||||||
|
default: return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_tensor * tensor;
|
struct ggml_tensor * tensor;
|
||||||
|
|
||||||
|
@ -229,8 +261,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
|
||||||
std::string key(key_length, 0);
|
std::string key(key_length, 0);
|
||||||
RWKV_ASSERT_NULL(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key");
|
RWKV_ASSERT_NULL(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key");
|
||||||
|
|
||||||
size_t byte_count = element_count * ggml_type_size(ggml_data_type);
|
RWKV_ASSERT_NULL(fread(tensor->data, 1, ggml_nbytes(tensor), file) == ggml_nbytes(tensor), "Failed to read parameter data");
|
||||||
RWKV_ASSERT_NULL(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data");
|
|
||||||
|
|
||||||
parameters[key] = tensor;
|
parameters[key] = tensor;
|
||||||
}
|
}
|
||||||
|
@ -533,6 +564,215 @@ void rwkv_free(struct rwkv_context * ctx) {
|
||||||
delete ctx;
|
delete ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type) {
|
||||||
|
RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type);
|
||||||
|
|
||||||
|
ggml_type type;
|
||||||
|
|
||||||
|
switch (q_type) {
|
||||||
|
case 2: type = GGML_TYPE_Q4_0; break;
|
||||||
|
case 3: type = GGML_TYPE_Q4_1; break;
|
||||||
|
default: return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
RWKV_ASSERT_FALSE(type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1, "Unsupported data type %d", type);
|
||||||
|
|
||||||
|
printf("Loading model from '%s'\n", model_file_path_in);
|
||||||
|
|
||||||
|
auto finp = std::ifstream(model_file_path_in, std::ios::binary);
|
||||||
|
RWKV_ASSERT_FALSE(finp, "Failed to open %s for reading", model_file_path_in);
|
||||||
|
|
||||||
|
auto fout = std::ofstream(model_file_path_out, std::ios::binary);
|
||||||
|
RWKV_ASSERT_FALSE(fout, "Failed to open %s for writing", model_file_path_out);
|
||||||
|
|
||||||
|
// Process header
|
||||||
|
{
|
||||||
|
uint32_t magic;
|
||||||
|
finp.read((char *) &magic, sizeof(magic));
|
||||||
|
RWKV_ASSERT_FALSE(magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic);
|
||||||
|
fout.write((char *) &magic, sizeof(magic));
|
||||||
|
|
||||||
|
uint32_t format_version;
|
||||||
|
finp.read((char *) &format_version, sizeof(format_version));
|
||||||
|
RWKV_ASSERT_FALSE(format_version == RWKV_FILE_VERSION, "Unsupported file version %d", format_version);
|
||||||
|
fout.write((char *) &format_version, sizeof(format_version));
|
||||||
|
|
||||||
|
int32_t n_vocab;
|
||||||
|
int32_t n_embed;
|
||||||
|
int32_t n_layer;
|
||||||
|
int32_t data_type;
|
||||||
|
|
||||||
|
finp.read((char *) &n_vocab, sizeof(n_vocab));
|
||||||
|
finp.read((char *) &n_embed, sizeof(n_embed));
|
||||||
|
finp.read((char *) &n_layer, sizeof(n_layer));
|
||||||
|
finp.read((char *) &data_type, sizeof(data_type));
|
||||||
|
|
||||||
|
RWKV_ASSERT_FALSE(data_type == 0 || data_type == 1, "Unsupported data type %d, only FP32 and FP16 can be quantized", data_type);
|
||||||
|
|
||||||
|
data_type = q_type;
|
||||||
|
|
||||||
|
fout.write((char *) &n_vocab, sizeof(n_vocab));
|
||||||
|
fout.write((char *) &n_embed, sizeof(n_embed));
|
||||||
|
fout.write((char *) &n_layer, sizeof(n_layer));
|
||||||
|
fout.write((char *) &data_type, sizeof(data_type));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process parameters
|
||||||
|
{
|
||||||
|
size_t total_size_orig = 0;
|
||||||
|
size_t total_size_new = 0;
|
||||||
|
|
||||||
|
std::vector<float> work;
|
||||||
|
|
||||||
|
std::vector<uint8_t> data_u8;
|
||||||
|
std::vector<ggml_fp16_t> data_f16;
|
||||||
|
std::vector<float> data_f32;
|
||||||
|
|
||||||
|
std::vector<int64_t> hist_all(1 << 4, 0);
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
int32_t n_dims;
|
||||||
|
int32_t key_length;
|
||||||
|
int32_t parameter_data_type;
|
||||||
|
|
||||||
|
finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||||
|
finp.read(reinterpret_cast<char *>(&key_length), sizeof(key_length));
|
||||||
|
finp.read(reinterpret_cast<char *>(¶meter_data_type), sizeof(parameter_data_type));
|
||||||
|
|
||||||
|
if (finp.eof()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t nelements = 1;
|
||||||
|
int32_t ne[2] = { 1, 1 };
|
||||||
|
for (int i = 0; i < n_dims; ++i) {
|
||||||
|
finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||||
|
nelements *= ne[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string name(key_length, 0);
|
||||||
|
finp.read(&name[0], key_length);
|
||||||
|
|
||||||
|
{
|
||||||
|
static const char * parameter_data_type_str[] = {
|
||||||
|
"f32",
|
||||||
|
"f16",
|
||||||
|
"q4_0",
|
||||||
|
"q4_1"
|
||||||
|
};
|
||||||
|
printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quantize only 2D tensors
|
||||||
|
bool quantize = n_dims == 2;
|
||||||
|
|
||||||
|
if (quantize) {
|
||||||
|
if (parameter_data_type != 0 && parameter_data_type != 1) {
|
||||||
|
fprintf(stderr, "unsupported data type %d for integer quantization\n", parameter_data_type);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parameter_data_type == 1) {
|
||||||
|
data_f16.resize(nelements);
|
||||||
|
finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
|
||||||
|
data_f32.resize(nelements);
|
||||||
|
for (int i = 0; i < nelements; ++i) {
|
||||||
|
data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
data_f32.resize(nelements);
|
||||||
|
finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
|
||||||
|
}
|
||||||
|
|
||||||
|
parameter_data_type = q_type;
|
||||||
|
} else {
|
||||||
|
const int bytes_per_element = (parameter_data_type == 0) ? sizeof(float) : sizeof(uint16_t);
|
||||||
|
data_u8.resize(nelements * bytes_per_element);
|
||||||
|
finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bytes_per_element);
|
||||||
|
}
|
||||||
|
|
||||||
|
fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||||
|
fout.write(reinterpret_cast<char *>(&key_length), sizeof(key_length));
|
||||||
|
fout.write(reinterpret_cast<char *>(¶meter_data_type), sizeof(parameter_data_type));
|
||||||
|
|
||||||
|
for (int i = 0; i < n_dims; ++i) {
|
||||||
|
fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
fout.write(&name[0], key_length);
|
||||||
|
|
||||||
|
if (quantize) {
|
||||||
|
printf("quantizing... ");
|
||||||
|
work.resize(nelements); // for quantization
|
||||||
|
|
||||||
|
size_t cur_size = 0;
|
||||||
|
std::vector<int64_t> hist_cur(1 << 4, 0);
|
||||||
|
|
||||||
|
switch (type) {
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
{
|
||||||
|
cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q4_1:
|
||||||
|
{
|
||||||
|
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
fprintf(stderr, "unsupported quantization type %d\n", type);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fout.write(reinterpret_cast<char *>(work.data()), cur_size);
|
||||||
|
total_size_new += cur_size;
|
||||||
|
|
||||||
|
printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float) / 1024.0 / 1024.0, cur_size / 1024.0 / 1024.0);
|
||||||
|
|
||||||
|
for (int i = 0; i < (int) hist_cur.size(); ++i) {
|
||||||
|
hist_all[i] += hist_cur[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < (int) hist_cur.size(); ++i) {
|
||||||
|
printf("%5.3f ", hist_cur[i] / float(nelements));
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("\n");
|
||||||
|
} else {
|
||||||
|
printf("size = %8.3f MB\n", data_u8.size() / 1024.0 / 1024.0);
|
||||||
|
fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
|
||||||
|
total_size_new += data_u8.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
total_size_orig += nelements * sizeof(float);
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("model size = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0);
|
||||||
|
printf("quant size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
|
||||||
|
|
||||||
|
{
|
||||||
|
int64_t sum_all = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < (int) hist_all.size(); ++i) {
|
||||||
|
sum_all += hist_all[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("hist: ");
|
||||||
|
|
||||||
|
for (int i = 0; i < (int) hist_all.size(); ++i) {
|
||||||
|
printf("%5.3f ", hist_all[i] / float(sum_all));
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
finp.close();
|
||||||
|
fout.close();
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
const char * rwkv_get_system_info_string(void) {
|
const char * rwkv_get_system_info_string(void) {
|
||||||
static std::string s;
|
static std::string s;
|
||||||
|
|
||||||
|
|
9
rwkv.h
9
rwkv.h
|
@ -31,6 +31,8 @@ extern "C" {
|
||||||
|
|
||||||
// Loads the model from a file and prepares it for inference.
|
// Loads the model from a file and prepares it for inference.
|
||||||
// Returns NULL on any error. Error messages would be printed to stderr.
|
// Returns NULL on any error. Error messages would be printed to stderr.
|
||||||
|
// - model_file_path: path to model file in ggml format.
|
||||||
|
// - n_threads: count of threads to use, must be positive.
|
||||||
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads);
|
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads);
|
||||||
|
|
||||||
// Evaluates the model for a single token.
|
// Evaluates the model for a single token.
|
||||||
|
@ -50,6 +52,13 @@ extern "C" {
|
||||||
// Frees all allocated memory and the context.
|
// Frees all allocated memory and the context.
|
||||||
RWKV_API void rwkv_free(struct rwkv_context * ctx);
|
RWKV_API void rwkv_free(struct rwkv_context * ctx);
|
||||||
|
|
||||||
|
// Quantizes FP32 or FP16 model to one of INT4 formats.
|
||||||
|
// Returns false on any error. Error messages would be printed to stderr.
|
||||||
|
// - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
|
||||||
|
// - model_file_path_out: quantized model will be written here.
|
||||||
|
// - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
|
||||||
|
RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type);
|
||||||
|
|
||||||
// Returns system information string.
|
// Returns system information string.
|
||||||
RWKV_API const char * rwkv_get_system_info_string(void);
|
RWKV_API const char * rwkv_get_system_info_string(void);
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
|
# Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
|
||||||
# Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
|
# Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
|
||||||
# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
|
# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
|
||||||
# Usage: python compare_cpp_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
|
# Usage: python compare_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
|
@ -40,7 +40,10 @@ def main() -> None:
|
||||||
header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4))
|
header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4))
|
||||||
data_type: int = header[5]
|
data_type: int = header[5]
|
||||||
|
|
||||||
assert data_type == 0 or data_type == 1, f'Unsupported model data type {data_type}'
|
assert data_type == 0 or\
|
||||||
|
data_type == 1 or\
|
||||||
|
data_type == 2 or\
|
||||||
|
data_type == 3, f'Unsupported model data type {data_type}'
|
||||||
|
|
||||||
if data_type == 0:
|
if data_type == 0:
|
||||||
# FP32, high precision
|
# FP32, high precision
|
||||||
|
@ -48,6 +51,13 @@ def main() -> None:
|
||||||
elif data_type == 1:
|
elif data_type == 1:
|
||||||
# FP16, lower precision, so higher threshold
|
# FP16, lower precision, so higher threshold
|
||||||
threshold = 0.003
|
threshold = 0.003
|
||||||
|
elif data_type == 2:
|
||||||
|
# INT4 quantized, even lower precision, so even higher threshold
|
||||||
|
# This threshold will let some bugs pass
|
||||||
|
threshold = 4.0
|
||||||
|
elif data_type == 3:
|
||||||
|
# This format stores more data, so error would be lower
|
||||||
|
threshold = 1.2
|
||||||
|
|
||||||
model = None
|
model = None
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# Converts an RWKV model checkpoint to an rwkv.cpp compatible file.
|
# Converts an RWKV model checkpoint to an rwkv.cpp compatible file.
|
||||||
# Usage: python convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
|
# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
|
||||||
# Get model checkpoints from https://huggingface.co/BlinkDL
|
# Get model checkpoints from https://huggingface.co/BlinkDL
|
||||||
|
|
||||||
# File format:
|
# File format:
|
|
@ -0,0 +1,34 @@
|
||||||
|
# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1.
|
||||||
|
# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3
|
||||||
|
|
||||||
|
import ctypes
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
|
||||||
|
parser.add_argument('shared_library_path', help='Path to rwkv.cpp shared library')
|
||||||
|
parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
|
||||||
|
parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
|
||||||
|
parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
library = ctypes.cdll.LoadLibrary(args.shared_library_path)
|
||||||
|
|
||||||
|
library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_int]
|
||||||
|
library.rwkv_quantize_model_file.restype = ctypes.c_bool
|
||||||
|
|
||||||
|
result: bool = library.rwkv_quantize_model_file(
|
||||||
|
args.src_path.encode('utf-8'),
|
||||||
|
args.dest_path.encode('utf-8'),
|
||||||
|
ctypes.c_int(args.data_type)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result, 'Failed to quantize, check stderr'
|
||||||
|
|
||||||
|
print('Done')
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
Reference in New Issue