Implement INT4 conversion and inference

commit 972e28d48d (parent b164bf4e27)
saharNooby, 2023-04-01 19:22:01 +04:00
5 changed files with 309 additions and 16 deletions

rwkv.cpp (266 changes)

@@ -39,7 +39,6 @@
 // Reads single int32 value from a file.
 bool read_int32(FILE * file, int32_t * dest) {
-    // TODO Will not read correct values on machine with different endianness
     RWKV_ASSERT_FALSE(fread(dest, 4, 1, file) == 1, "Failed to read an int32 value from a file");
     return true;
 }
@@ -111,13 +110,24 @@ bool set_block_parameter(std::unordered_map<std::string, struct ggml_tensor *> *
 size_t get_memory_required_mb(int32_t n_vocab, int32_t n_layer, int32_t n_embed, int32_t data_type) {
     if (n_vocab == 50277) {
-        // 169M to 1.5B are exact, others are extrapolated (slightly bigger than needed).
-        /* 169M */ if (n_layer == 12 && n_embed == 768)  return size_t(data_type == 0 ? 650 : 327);
-        /* 430M */ if (n_layer == 24 && n_embed == 1024) return size_t(data_type == 0 ? 1650 : 830);
-        /* 1.5B */ if (n_layer == 24 && n_embed == 2048) return size_t(data_type == 0 ? 5795 : 2907);
-        /* 3B   */ if (n_layer == 32 && n_embed == 2560) return size_t(data_type == 0 ? 11590 : 5720); // TODO Measure exactly (FP32 only)
-        /* 7B   */ if (n_layer == 32 && n_embed == 4096) return size_t(data_type == 0 ? 27043 : 13566); // TODO Measure exactly
-        /* 14B  */ if (n_layer == 40 && n_embed == 5120) return size_t(data_type == 0 ? 54086 : 27132); // TODO Measure exactly
+        // Non-exact values are extrapolated (slightly bigger than needed).
+        // TODO Measure values exactly
+        static const size_t memory_required_mb[6][4] = {
+            /*           FP32   FP16  Q4_0   Q4_1 */
+            /* 169M */ {  650,   327,  105,   165}, // All measured exactly
+            /* 430M */ { 1650,   830,  263,   415}, // FP32, FP16 are exact
+            /* 1.5B */ { 5795,  2907,  923,  1454}, // FP32, FP16 are exact
+            /* 3B   */ {11610,  5720, 1816,  2860}, // FP16 is exact
+            /* 7B   */ {27090, 13634, 4328,  6817},
+            /* 14B  */ {54180, 27267, 8656, 13634}
+        };
+
+        /* 169M */ if (n_layer == 12 && n_embed == 768)  return memory_required_mb[0][data_type];
+        /* 430M */ if (n_layer == 24 && n_embed == 1024) return memory_required_mb[1][data_type];
+        /* 1.5B */ if (n_layer == 24 && n_embed == 2048) return memory_required_mb[2][data_type];
+        /* 3B   */ if (n_layer == 32 && n_embed == 2560) return memory_required_mb[3][data_type];
+        /* 7B   */ if (n_layer == 32 && n_embed == 4096) return memory_required_mb[4][data_type];
+        /* 14B  */ if (n_layer == 40 && n_embed == 5120) return memory_required_mb[5][data_type];
     }
 
     fprintf(
@@ -177,7 +187,14 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
     RWKV_ASSERT_NULL(model->n_layer > 0, "Non-positive n_layer %d", model->n_layer);
 
     read_int32(file, &(model->data_type));
-    RWKV_ASSERT_NULL(model->data_type == 0 || model->data_type == 1, "Unsupported model data type %d", model->data_type);
+    RWKV_ASSERT_NULL(
+        model->data_type == 0 ||
+            model->data_type == 1 ||
+            model->data_type == 2 ||
+            model->data_type == 3,
+        "Unsupported model data type %d",
+        model->data_type
+    );
 
     // Initialize ggml
     struct ggml_init_params params;
@@ -203,9 +220,24 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
         int32_t data_type;
         read_int32(file, &data_type);
-        RWKV_ASSERT_NULL(data_type == 0 || data_type == 1, "Unsupported parameter data type %d", data_type);
+        RWKV_ASSERT_NULL(
+            data_type == 0 ||
+                data_type == 1 ||
+                data_type == 2 ||
+                data_type == 3,
+            "Unsupported parameter data type %d",
+            data_type
+        );
 
-        ggml_type ggml_data_type = data_type == 0 ? GGML_TYPE_F32 : GGML_TYPE_F16;
+        ggml_type ggml_data_type;
+
+        switch (data_type) {
+            case 0: ggml_data_type = GGML_TYPE_F32; break;
+            case 1: ggml_data_type = GGML_TYPE_F16; break;
+            case 2: ggml_data_type = GGML_TYPE_Q4_0; break;
+            case 3: ggml_data_type = GGML_TYPE_Q4_1; break;
+            default: return NULL;
+        }
 
         struct ggml_tensor * tensor;
@@ -229,8 +261,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, int n_threads)
         std::string key(key_length, 0);
         RWKV_ASSERT_NULL(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key");
 
-        size_t byte_count = element_count * ggml_type_size(ggml_data_type);
-        RWKV_ASSERT_NULL(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data");
+        RWKV_ASSERT_NULL(fread(tensor->data, 1, ggml_nbytes(tensor), file) == ggml_nbytes(tensor), "Failed to read parameter data");
 
         parameters[key] = tensor;
     }
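Note on the hunk above: the switch to ggml_nbytes() is what makes loading quantized tensors work. Q4_0 and Q4_1 pack values into fixed-size blocks, so byte count is a per-block rather than per-element computation. A rough sketch of the arithmetic, assuming ggml's 32-element block layouts of this era (an assumption; the layouts are not shown in this diff):

    # Illustrative only, not part of the commit.
    QK = 32  # elements per quantization block (assumed)

    def q4_0_nbytes(element_count: int) -> int:
        # Each block: one 4-byte FP32 scale + 32 values * 4 bits = 20 bytes.
        return (element_count // QK) * (4 + QK // 2)

    def q4_1_nbytes(element_count: int) -> int:
        # Each block: FP32 scale + FP32 min + 16 bytes of quants = 24 bytes.
        return (element_count // QK) * (8 + QK // 2)

    # A 1024x1024 FP32 matrix takes 4.0 MiB; the same matrix takes:
    print(q4_0_nbytes(1024 * 1024) / 1024 / 1024)  # 0.625 (MiB)
    print(q4_1_nbytes(1024 * 1024) / 1024 / 1024)  # 0.75 (MiB)

Under this layout, Q4_1 stores an extra FP32 value per block, which is consistent with the larger Q4_1 column in the memory table above.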
@@ -533,6 +564,215 @@ void rwkv_free(struct rwkv_context * ctx) {
     delete ctx;
 }
 
+bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type) {
+    RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type);
+
+    ggml_type type;
+
+    switch (q_type) {
+        case 2: type = GGML_TYPE_Q4_0; break;
+        case 3: type = GGML_TYPE_Q4_1; break;
+        default: return false;
+    };
+
+    RWKV_ASSERT_FALSE(type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1, "Unsupported data type %d", type);
+
+    printf("Loading model from '%s'\n", model_file_path_in);
+
+    auto finp = std::ifstream(model_file_path_in, std::ios::binary);
+    RWKV_ASSERT_FALSE(finp, "Failed to open %s for reading", model_file_path_in);
+
+    auto fout = std::ofstream(model_file_path_out, std::ios::binary);
+    RWKV_ASSERT_FALSE(fout, "Failed to open %s for writing", model_file_path_out);
+
+    // Process header
+    {
+        uint32_t magic;
+        finp.read((char *) &magic, sizeof(magic));
+        RWKV_ASSERT_FALSE(magic == RWKV_FILE_MAGIC, "Unexpected magic value %d", magic);
+        fout.write((char *) &magic, sizeof(magic));
+
+        uint32_t format_version;
+        finp.read((char *) &format_version, sizeof(format_version));
+        RWKV_ASSERT_FALSE(format_version == RWKV_FILE_VERSION, "Unsupported file version %d", format_version);
+        fout.write((char *) &format_version, sizeof(format_version));
+
+        int32_t n_vocab;
+        int32_t n_embed;
+        int32_t n_layer;
+        int32_t data_type;
+
+        finp.read((char *) &n_vocab, sizeof(n_vocab));
+        finp.read((char *) &n_embed, sizeof(n_embed));
+        finp.read((char *) &n_layer, sizeof(n_layer));
+        finp.read((char *) &data_type, sizeof(data_type));
+
+        RWKV_ASSERT_FALSE(data_type == 0 || data_type == 1, "Unsupported data type %d, only FP32 and FP16 can be quantized", data_type);
+
+        data_type = q_type;
+
+        fout.write((char *) &n_vocab, sizeof(n_vocab));
+        fout.write((char *) &n_embed, sizeof(n_embed));
+        fout.write((char *) &n_layer, sizeof(n_layer));
+        fout.write((char *) &data_type, sizeof(data_type));
+    }
+
+    // Process parameters
+    {
+        size_t total_size_orig = 0;
+        size_t total_size_new = 0;
+
+        std::vector<float> work;
+
+        std::vector<uint8_t> data_u8;
+        std::vector<ggml_fp16_t> data_f16;
+        std::vector<float> data_f32;
+
+        std::vector<int64_t> hist_all(1 << 4, 0);
+
+        while (true) {
+            int32_t n_dims;
+            int32_t key_length;
+            int32_t parameter_data_type;
+
+            finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            finp.read(reinterpret_cast<char *>(&key_length), sizeof(key_length));
+            finp.read(reinterpret_cast<char *>(&parameter_data_type), sizeof(parameter_data_type));
+
+            if (finp.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(key_length, 0);
+            finp.read(&name[0], key_length);
+
+            {
+                static const char * parameter_data_type_str[] = {
+                    "f32",
+                    "f16",
+                    "q4_0",
+                    "q4_1"
+                };
+                printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
+            }
+
+            // Quantize only 2D tensors
+            bool quantize = n_dims == 2;
+
+            if (quantize) {
+                if (parameter_data_type != 0 && parameter_data_type != 1) {
+                    fprintf(stderr, "unsupported data type %d for integer quantization\n", parameter_data_type);
+                    return false;
+                }
+
+                if (parameter_data_type == 1) {
+                    data_f16.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
+                    data_f32.resize(nelements);
+                    for (int i = 0; i < nelements; ++i) {
+                        data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
+                    }
+                } else {
+                    data_f32.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
+                }
+
+                parameter_data_type = q_type;
+            } else {
+                const int bytes_per_element = (parameter_data_type == 0) ? sizeof(float) : sizeof(uint16_t);
+                data_u8.resize(nelements * bytes_per_element);
+                finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bytes_per_element);
+            }
+
+            fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fout.write(reinterpret_cast<char *>(&key_length), sizeof(key_length));
+            fout.write(reinterpret_cast<char *>(&parameter_data_type), sizeof(parameter_data_type));
+
+            for (int i = 0; i < n_dims; ++i) {
+                fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+            }
+
+            fout.write(&name[0], key_length);
+
+            if (quantize) {
+                printf("quantizing... ");
+                work.resize(nelements); // for quantization
+
+                size_t cur_size = 0;
+                std::vector<int64_t> hist_cur(1 << 4, 0);
+
+                switch (type) {
+                    case GGML_TYPE_Q4_0:
+                        {
+                            cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                        } break;
+                    case GGML_TYPE_Q4_1:
+                        {
+                            cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                        } break;
+                    default:
+                        {
+                            fprintf(stderr, "unsupported quantization type %d\n", type);
+                            return false;
+                        }
+                }
+
+                fout.write(reinterpret_cast<char *>(work.data()), cur_size);
+                total_size_new += cur_size;
+
+                printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float) / 1024.0 / 1024.0, cur_size / 1024.0 / 1024.0);
+
+                for (int i = 0; i < (int) hist_cur.size(); ++i) {
+                    hist_all[i] += hist_cur[i];
+                }
+
+                for (int i = 0; i < (int) hist_cur.size(); ++i) {
+                    printf("%5.3f ", hist_cur[i] / float(nelements));
+                }
+
+                printf("\n");
+            } else {
+                printf("size = %8.3f MB\n", data_u8.size() / 1024.0 / 1024.0);
+                fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
+                total_size_new += data_u8.size();
+            }
+
+            total_size_orig += nelements * sizeof(float);
+        }
+
+        printf("model size = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0);
+        printf("quant size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
+
+        {
+            int64_t sum_all = 0;
+
+            for (int i = 0; i < (int) hist_all.size(); ++i) {
+                sum_all += hist_all[i];
+            }
+
+            printf("hist: ");
+
+            for (int i = 0; i < (int) hist_all.size(); ++i) {
+                printf("%5.3f ", hist_all[i] / float(sum_all));
+            }
+
+            printf("\n");
+        }
+    }
+
+    finp.close();
+    fout.close();
+
+    return true;
+}
+
 const char * rwkv_get_system_info_string(void) {
     static std::string s;

rwkv.h (9 changes)

@@ -31,6 +31,8 @@ extern "C" {
 
     // Loads the model from a file and prepares it for inference.
    // Returns NULL on any error. Error messages would be printed to stderr.
+    // - model_file_path: path to model file in ggml format.
+    // - n_threads: count of threads to use, must be positive.
    RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, int n_threads);
 
    // Evaluates the model for a single token.
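Since rwkv.h is a plain C API, scripts can drive it with ctypes, the same approach rwkv/quantize.py (added below in this commit) takes for quantization. A minimal loading sketch; the library and model paths are placeholders, and this calling code is not part of the commit:

    import ctypes

    library = ctypes.cdll.LoadLibrary('bin/Release/rwkv.dll')  # placeholder path

    library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_int]
    library.rwkv_init_from_file.restype = ctypes.c_void_p  # opaque struct rwkv_context *

    ctx = library.rwkv_init_from_file('C:\\rwkv.cpp-169M.bin'.encode('utf-8'), 4)
    assert ctx is not None, 'Failed to load the model, check stderr'

    library.rwkv_free.argtypes = [ctypes.c_void_p]
    library.rwkv_free(ctx)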
@@ -50,6 +52,13 @@ extern "C" {
 
    // Frees all allocated memory and the context.
    RWKV_API void rwkv_free(struct rwkv_context * ctx);
 
+    // Quantizes FP32 or FP16 model to one of INT4 formats.
+    // Returns false on any error. Error messages would be printed to stderr.
+    // - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
+    // - model_file_path_out: quantized model will be written here.
+    // - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1.
+    RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, int q_type);
+
    // Returns system information string.
    RWKV_API const char * rwkv_get_system_info_string(void);

rwkv/compare_with_reference_implementation.py

@@ -1,7 +1,7 @@
 # Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
 # Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
 # Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
-# Usage: python compare_cpp_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
+# Usage: python compare_with_reference_implementation.py bin\Release\main_rwkv.exe C:\rwkv.cpp-169M.bin
 
 import os
 import struct
@@ -40,7 +40,10 @@ def main() -> None:
     header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4))
     data_type: int = header[5]
 
-    assert data_type == 0 or data_type == 1, f'Unsupported model data type {data_type}'
+    assert data_type == 0 or\
+        data_type == 1 or\
+        data_type == 2 or\
+        data_type == 3, f'Unsupported model data type {data_type}'
 
     if data_type == 0:
         # FP32, high precision
@@ -48,6 +51,13 @@ def main() -> None:
     elif data_type == 1:
         # FP16, lower precision, so higher threshold
         threshold = 0.003
+    elif data_type == 2:
+        # INT4 quantized, even lower precision, so even higher threshold
+        # This threshold will let some bugs pass
+        threshold = 4.0
+    elif data_type == 3:
+        # This format stores more data, so error would be lower
+        threshold = 1.2
 
     model = None
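The thresholds above gate a scalar difference between rwkv.cpp logits and reference logits; the comparison itself is outside this hunk. A sketch of how such a check could look (illustrative only; the metric the script actually uses is not shown in this diff):

    import numpy as np

    def logits_close_enough(actual: np.ndarray, expected: np.ndarray, threshold: float) -> bool:
        # One plausible scalar metric: root-mean-square difference over the vocabulary.
        difference = float(np.sqrt(np.mean((actual - expected) ** 2)))
        return difference <= threshold

    # With the Q4_0 threshold from the branch above and the 50277-entry vocabulary:
    print(logits_close_enough(np.zeros(50277), np.zeros(50277), 4.0))  # True

Under a metric like this, the loose 4.0 threshold for Q4_0 tolerates the large error of 4-bit weights, while Q4_1's richer format ("stores more data", per the comment above) justifies the tighter 1.2.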

rwkv/convert_pytorch_to_ggml.py

@@ -1,5 +1,5 @@
 # Converts an RWKV model checkpoint to an rwkv.cpp compatible file.
-# Usage: python convert_pytorch_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
+# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
 # Get model checkpoints from https://huggingface.co/BlinkDL
 # File format:

rwkv/quantize.py (new file, 34 lines)

@@ -0,0 +1,34 @@
+# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1.
+# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3
+
+import ctypes
+import argparse
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
+    parser.add_argument('shared_library_path', help='Path to rwkv.cpp shared library')
+    parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
+    parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
+    parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3)
+    return parser.parse_args()
+
+def main() -> None:
+    args = parse_args()
+
+    library = ctypes.cdll.LoadLibrary(args.shared_library_path)
+
+    library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_int]
+    library.rwkv_quantize_model_file.restype = ctypes.c_bool
+
+    result: bool = library.rwkv_quantize_model_file(
+        args.src_path.encode('utf-8'),
+        args.dest_path.encode('utf-8'),
+        ctypes.c_int(args.data_type)
+    )
+
+    assert result, 'Failed to quantize, check stderr'
+
+    print('Done')
+
+if __name__ == "__main__":
+    main()