Implement INT4 conversion and inference
This commit is contained in:
parent
b164bf4e27
commit
972e28d48d
266
rwkv.cpp
266
rwkv.cpp
|
@ -39,7 +39,6 @@
|
|||
|
||||
// Reads single int32 value from a file.
|
||||
bool read_int32(FILE * file, int32_t * dest) {
|
||||
// TODO Will not read correct values on machine with different endianness
|
||||
RWKV_ASSERT_FALSE(fread(dest, 4, 1, file) == 1, "Failed to read an int32 value from a file");
|
||||
return true;
|
||||
}
|
||||
|
@ -111,13 +110,24 @@ bool set_block_parameter(std::unordered_map<std::string, struct ggml_tensor *> *
|
|||
|
||||
size_t get_memory_required_mb(int32_t n_vocab, int32_t n_layer, int32_t n_embed, int32_t data_type) {
|
||||
if (n_vocab == 50277) {
|
||||
// 169M to 1.5B are exact, others are extrapolated (slightly bigger than needed).
|
||||
/* 169M */ if (n_layer == 12 && n_embed == 768) return size_t(data_type == 0 ? 650 : 327);
|
||||
/* 430M */ if (n_layer == 24 && n_embed == 1024) return size_t(data_type == 0 ? 1650 : 830);
|
||||
/* 1.5B */ if (n_layer == 24 && n_embed == 2048) return size_t(data_type == 0 ? 5795 : 2907);
|
||||
/* 3B */ if (n_layer == 32 && n_embed == 2560) return size_t(data_type == 0 ? 11590 : 5720); // TODO Measure exactly (FP32 only)
|
||||
/* 7B */ if (n_layer == 32 && n_embed == 4096) return size_t(data_type == 0 ? 27043 : 13566); // TODO Measure exactly
|
||||
/* 14B */ if (n_layer == 40 && n_embed == 5120) return size_t(data_type == 0 ? 54086 : 27132); // TODO Measure exactly
|
||||
// Non-exact values are extrapolated (slightly bigger than needed).
|
||||
// TODO Measure values exactly
|
||||
static const size_t memory_required_mb[6][4] = {
|
||||
/* FP32 FP16 Q4_0 Q4_1
|
||||
169M */ { 650, 327, 105, 165}, // All measured exactly
|
||||
/* 430M */ { 1650, 830, 263, 415}, // FP32, FP16 are exact
|
||||
/* 1.5B */ { 5795, 2907, 923, 1454}, // FP32, FP16 are exact
|
||||
/* 3B */ {11610, 5720, 1816, 2860}, // FP16 is exact
|
||||
/* 7B */ {27090, 13634, 4328, 6817},
|
||||
/* 14B */ {54180, 27267, 8656, 13634}
|
||||
};
|
||||
|
||||
/* 169M */ if (n_layer == 12 && n_embed == 768) return memory_required_mb[0][data_type];
|
||||
/* 430M */ if (n_layer == 24 && n_embed == 1024) return memory_required_mb[1][data_type];
|
||||
/* 1.5B */ if (n_layer == 24 && n_embed == 2048) return memory_required_mb[2][data_type];
|
||||
/* 3B */ if (n_layer == 32 && n_embed == 2560) return memory_required_mb[3][data_type];
|
||||
/* 7B */ if (n_layer == 32 && n_embed == 4096) return memory_required_mb[4][data_type];
|
||||
/* 14B */ if (n_layer == 40 && n_embed == 5120) return memory_required_mb[5][data_type];
|
||||
}
|
||||
|
||||
fprintf(
|
||||
|
@ -177,7 +187,14 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
|
|||
RWKV_ASSERT_NULL(model->n_layer > 0, "Non-positive n_layer %d", model->n_layer);
|
||||
|
||||
read_int32(file, &(model->data_type));
|
||||
RWKV_ASSERT_NULL(model->data_type == 0 || model->data_type == 1, "Unsupported model data type %d", model->data_type);
|
||||
RWKV_ASSERT_NULL(
|
||||
model->data_type == 0 ||
|
||||
model->data_type == 1 ||
|
||||
model->data_type == 2 ||
|
||||
model->data_type == 3,
|
||||
"Unsupported model data type %d",
|
||||
model->data_type
|
||||
);
|
||||
|
||||
// Initialize ggml
|
||||
struct ggml_init_params params;
|
||||
|
@ -203,9 +220,24 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
|
|||
|
||||
int32_t data_type;
|
||||
read_int32(file, &data_type);
|
||||
RWKV_ASSERT_NULL(data_type == 0 || data_type == 1, "Unsupported parameter data type %d", data_type);
|
||||
RWKV_ASSERT_NULL(
|
||||
data_type == 0 ||
|
||||
data_type == 1 ||
|
||||
data_type == 2 ||
|
||||
data_type == 3,
|
||||
"Unsupported parameter data type %d",
|
||||
data_type
|
||||
);
|
||||
|
||||
ggml_type ggml_data_type = data_type == 0 ? GGML_TYPE_F32 : GGML_TYPE_F16;
|
||||
ggml_type ggml_data_type;
|
||||
|
||||
switch (data_type) {
|
||||
case 0: ggml_data_type = GGML_TYPE_F32; break;
|
||||
case 1: ggml_data_type = GGML_TYPE_F16; break;
|
||||
case 2: ggml_data_type = GGML_TYPE_Q4_0; break;
|
||||
case 3: ggml_data_type = GGML_TYPE_Q4_1; break;
|
||||
default: return NULL;
|
||||
}
|
||||
|
||||
struct ggml_tensor * tensor;
|
||||
|
||||
|
@ -229,8 +261,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
|
|||
std::string key(key_length, 0);
|
||||
RWKV_ASSERT_NULL(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key");
|
||||
|
||||
size_t byte_count = element_count * ggml_type_size(ggml_data_type);
|
||||
RWKV_ASSERT_NULL(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data");
|
||||
RWKV_ASSERT_NULL(fread(tensor->data, 1, ggml_nbytes(tensor), file) == ggml_nbytes(tensor), "Failed to read parameter data");
|
||||
|
||||
parameters[key] = tensor;
|
||||
}
|
||||
|
@ -533,6 +564,215 @@ void rwkv_free(struct rwkv_context * ctx) {
|
|||
delete ctx;
|
||||
}
|
||||
|
||||
bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type) {
|
||||
RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type);
|
||||
|
||||
ggml_type type;
|
||||
|
||||
switch (q_type) {
|
||||
case 2: type = GGML_TYPE_Q4_0; break;
|
||||
case 3: type = GGML_TYPE_Q4_1; break;
|
||||
default: return false;
|
||||
};
|
||||
|
||||
RWKV_ASSERT_FALSE(type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1, "Unsupported data type %d", type);
|
||||
|
||||
printf("Loading model from '%s'\n", model_file_path_in);
|
||||
|
||||
auto finp = std::ifstream(model_file_path_in, std::ios::binary);
|
||||
RWKV_ASSERT_FALSE(finp, "Failed to open %s for reading", model_file_path_in);
|
||||
|
||||
auto fout = std::ofstream(model_file_path_out, std::ios::binary);
|
||||
RWKV_ASSERT_FALSE(fout, "Failed to open %s for writing", model_file_path_out);
|
||||
|
||||
// Process header
|
||||
{
|
||||
uint32_t magic;
|
||||
finp.read((char *) &magic, sizeof(magic));
|
||||
RWKV_ASSERT_FALSE(magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic);
|
||||
fout.write((char *) &magic, sizeof(magic));
|
||||
|
||||
uint32_t format_version;
|
||||
finp.read((char *) &format_version, sizeof(format_version));
|
||||
RWKV_ASSERT_FALSE(format_version == RWKV_FILE_VERSION, "Unsupported file version %d", format_version);
|
||||
fout.write((char *) &format_version, sizeof(format_version));
|
||||
|
||||
int32_t n_vocab;
|
||||
int32_t n_embed;
|
||||
int32_t n_layer;
|
||||
int32_t data_type;
|
||||
|
||||
finp.read((char *) &n_vocab, sizeof(n_vocab));
|
||||
finp.read((char *) &n_embed, sizeof(n_embed));
|
||||
finp.read((char *) &n_layer, sizeof(n_layer));
|
||||
finp.read((char *) &data_type, sizeof(data_type));
|
||||
|
||||
RWKV_ASSERT_FALSE(data_type == 0 || data_type == 1, "Unsupported data type %d, only FP32 and FP16 can be quantized", data_type);
|
||||
|
||||
data_type = q_type;
|
||||
|
||||
fout.write((char *) &n_vocab, sizeof(n_vocab));
|
||||
fout.write((char *) &n_embed, sizeof(n_embed));
|
||||
fout.write((char *) &n_layer, sizeof(n_layer));
|
||||
fout.write((char *) &data_type, sizeof(data_type));
|
||||
}
|
||||
|
||||
// Process parameters
|
||||
{
|
||||
size_t total_size_orig = 0;
|
||||
size_t total_size_new = 0;
|
||||
|
||||
std::vector<float> work;
|
||||
|
||||
std::vector<uint8_t> data_u8;
|
||||
std::vector<ggml_fp16_t> data_f16;
|
||||
std::vector<float> data_f32;
|
||||
|
||||
std::vector<int64_t> hist_all(1 << 4, 0);
|
||||
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
int32_t key_length;
|
||||
int32_t parameter_data_type;
|
||||
|
||||
finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
finp.read(reinterpret_cast<char *>(&key_length), sizeof(key_length));
|
||||
finp.read(reinterpret_cast<char *>(¶meter_data_type), sizeof(parameter_data_type));
|
||||
|
||||
if (finp.eof()) {
|
||||
break;
|
||||
}
|
||||
|
||||
int32_t nelements = 1;
|
||||
int32_t ne[2] = { 1, 1 };
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||
nelements *= ne[i];
|
||||
}
|
||||
|
||||
std::string name(key_length, 0);
|
||||
finp.read(&name[0], key_length);
|
||||
|
||||
{
|
||||
static const char * parameter_data_type_str[] = {
|
||||
"f32",
|
||||
"f16",
|
||||
"q4_0",
|
||||
"q4_1"
|
||||
};
|
||||
printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
|
||||
}
|
||||
|
||||
// Quantize only 2D tensors
|
||||
bool quantize = n_dims == 2;
|
||||
|
||||
if (quantize) {
|
||||
if (parameter_data_type != 0 && parameter_data_type != 1) {
|
||||
fprintf(stderr, "unsupported data type %d for integer quantization\n", parameter_data_type);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (parameter_data_type == 1) {
|
||||
data_f16.resize(nelements);
|
||||
finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
|
||||
data_f32.resize(nelements);
|
||||
for (int i = 0; i < nelements; ++i) {
|
||||
data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
|
||||
}
|
||||
} else {
|
||||
data_f32.resize(nelements);
|
||||
finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
|
||||
}
|
||||
|
||||
parameter_data_type = q_type;
|
||||
} else {
|
||||
const int bytes_per_element = (parameter_data_type == 0) ? sizeof(float) : sizeof(uint16_t);
|
||||
data_u8.resize(nelements * bytes_per_element);
|
||||
finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bytes_per_element);
|
||||
}
|
||||
|
||||
fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
fout.write(reinterpret_cast<char *>(&key_length), sizeof(key_length));
|
||||
fout.write(reinterpret_cast<char *>(¶meter_data_type), sizeof(parameter_data_type));
|
||||
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||
}
|
||||
|
||||
fout.write(&name[0], key_length);
|
||||
|
||||
if (quantize) {
|
||||
printf("quantizing... ");
|
||||
work.resize(nelements); // for quantization
|
||||
|
||||
size_t cur_size = 0;
|
||||
std::vector<int64_t> hist_cur(1 << 4, 0);
|
||||
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
{
|
||||
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
fprintf(stderr, "unsupported quantization type %d\n", type);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
fout.write(reinterpret_cast<char *>(work.data()), cur_size);
|
||||
total_size_new += cur_size;
|
||||
|
||||
printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float) / 1024.0 / 1024.0, cur_size / 1024.0 / 1024.0);
|
||||
|
||||
for (int i = 0; i < (int) hist_cur.size(); ++i) {
|
||||
hist_all[i] += hist_cur[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < (int) hist_cur.size(); ++i) {
|
||||
printf("%5.3f ", hist_cur[i] / float(nelements));
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
} else {
|
||||
printf("size = %8.3f MB\n", data_u8.size() / 1024.0 / 1024.0);
|
||||
fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
|
||||
total_size_new += data_u8.size();
|
||||
}
|
||||
|
||||
total_size_orig += nelements * sizeof(float);
|
||||
}
|
||||
|
||||
printf("model size = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0);
|
||||
printf("quant size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
|
||||
|
||||
{
|
||||
int64_t sum_all = 0;
|
||||
|
||||
for (int i = 0; i < (int) hist_all.size(); ++i) {
|
||||
sum_all += hist_all[i];
|
||||
}
|
||||
|
||||
printf("hist: ");
|
||||
|
||||
for (int i = 0; i < (int) hist_all.size(); ++i) {
|
||||
printf("%5.3f ", hist_all[i] / float(sum_all));
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
finp.close();
|
||||
fout.close();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
const char * rwkv_get_system_info_string(void) {
|
||||
static std::string s;
|
||||
|
||||
|
|
9
rwkv.h
9
rwkv.h
|
@ -31,6 +31,8 @@ extern "C" {
|
|||
|
||||
// Loads the model from a file and prepares it for inference.
|
||||
// Returns NULL on any error. Error messages would be printed to stderr.
|
||||
// - model_file_path: path to model file in ggml format.
|
||||
// - n_threads: count of threads to use, must be positive.
|
||||
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads);
|
||||
|
||||
// Evaluates the model for a single token.
|
||||
|
@ -50,6 +52,13 @@ extern "C" {
|
|||
// Frees all allocated memory and the context.
|
||||
RWKV_API void rwkv_free(struct rwkv_context * ctx);
|
||||
|
||||
// Quantizes FP32 or FP16 model to one of INT4 formats.
|
||||
// Returns false on any error. Error messages would be printed to stderr.
|
||||
// - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
|
||||
// - model_file_path_out: quantized model will be written here.
|
||||
// - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
|
||||
RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type);
|
||||
|
||||
// Returns system information string.
|
||||
RWKV_API const char * rwkv_get_system_info_string(void);
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
|
||||
# Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
|
||||
# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
|
||||
# Usage: python compare_cpp_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
|
||||
# Usage: python compare_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
|
||||
|
||||
import os
|
||||
import struct
|
||||
|
@ -40,7 +40,10 @@ def main() -> None:
|
|||
header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4))
|
||||
data_type: int = header[5]
|
||||
|
||||
assert data_type == 0 or data_type == 1, f'Unsupported model data type {data_type}'
|
||||
assert data_type == 0 or\
|
||||
data_type == 1 or\
|
||||
data_type == 2 or\
|
||||
data_type == 3, f'Unsupported model data type {data_type}'
|
||||
|
||||
if data_type == 0:
|
||||
# FP32, high precision
|
||||
|
@ -48,6 +51,13 @@ def main() -> None:
|
|||
elif data_type == 1:
|
||||
# FP16, lower precision, so higher threshold
|
||||
threshold = 0.003
|
||||
elif data_type == 2:
|
||||
# INT4 quantized, even lower precision, so even higher threshold
|
||||
# This threshold will let some bugs pass
|
||||
threshold = 4.0
|
||||
elif data_type == 3:
|
||||
# This format stores more data, so error would be lower
|
||||
threshold = 1.2
|
||||
|
||||
model = None
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
# Converts an RWKV model checkpoint to an rwkv.cpp compatible file.
|
||||
# Usage: python convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
|
||||
# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
|
||||
# Get model checkpoints from https://huggingface.co/BlinkDL
|
||||
|
||||
# File format:
|
|
@ -0,0 +1,34 @@
|
|||
# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1.
|
||||
# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3
|
||||
|
||||
import ctypes
|
||||
import argparse
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
|
||||
parser.add_argument('shared_library_path', help='Path to rwkv.cpp shared library')
|
||||
parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
|
||||
parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
|
||||
parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3)
|
||||
return parser.parse_args()
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
library = ctypes.cdll.LoadLibrary(args.shared_library_path)
|
||||
|
||||
library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_int]
|
||||
library.rwkv_quantize_model_file.restype = ctypes.c_bool
|
||||
|
||||
result: bool = library.rwkv_quantize_model_file(
|
||||
args.src_path.encode('utf-8'),
|
||||
args.dest_path.encode('utf-8'),
|
||||
ctypes.c_int(args.data_type)
|
||||
)
|
||||
|
||||
assert result, 'Failed to quantize, check stderr'
|
||||
|
||||
print('Done')
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue