Add Q4_1_O format

saharNooby 2023-04-07 09:55:39 +04:00
parent ec99bc1765
commit c40941d9d0
6 changed files with 659 additions and 54 deletions

ggml.c

@ -25,6 +25,35 @@
#define static_assert(cond, msg) struct global_scope_noop_trick
#endif
// https://gist.github.com/rygorous/2144712
// Public domain, by Fabian "ryg" Giesen
inline static float ggml_half_to_float_reference(uint16_t value) {
union FP32 {
uint32_t u;
float f;
};
const union FP32 magic = { (254UL - 15UL) << 23 };
const union FP32 was_inf_nan = { (127UL + 16UL) << 23 };
union FP32 out;
// Exponent/mantissa bits
out.u = (value & 0x7FFFU) << 13;
// Exponent adjust
out.f *= magic.f;
// Make sure Inf/NaN survive
if (out.f >= was_inf_nan.f) {
out.u |= 255UL << 23;
}
// Sign bit
out.u |= (value & 0x8000UL) << 16;
return out.f;
}
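As a quick sanity check, the reference conversion can be exercised on a few well-known FP16 bit patterns (a minimal standalone sketch; it assumes ggml_half_to_float_reference above is in scope):

    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        assert(ggml_half_to_float_reference(0x3C00) == 1.0F);     // FP16 1.0
        assert(ggml_half_to_float_reference(0xC000) == -2.0F);    // FP16 -2.0
        assert(ggml_half_to_float_reference(0x0000) == 0.0F);     // FP16 +0.0
        assert(ggml_half_to_float_reference(0x7BFF) == 65504.0F); // largest finite FP16
        return 0;
    }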
#if defined _MSC_VER || defined(__MINGW32__)
#if !defined(__MINGW32__)
@ -326,42 +355,13 @@ static float table_f32_f16[1 << 16];
// This is also true for POWER9.
#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
// https://gist.github.com/rygorous/2144712
// Public domain, by Fabian "ryg" Giesen
inline static float ggml_half_to_float(uint16_t value) {
union FP32 {
uint32_t u;
float f;
};
const union FP32 magic = { (254UL - 15UL) << 23 };
const union FP32 was_inf_nan = { (127UL + 16UL) << 23 };
union FP32 out;
// Exponent/mantissa bits
out.u = (value & 0x7FFFU) << 13;
// Exponent adjust
out.f *= magic.f;
// Make sure Inf/NaN survive
if (out.f >= was_inf_nan.f) {
out.u |= 255UL << 23;
}
// Sign bit
out.u |= (value & 0x8000UL) << 16;
return out.f;
}
inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
// For some reason, lookup table does not work on my machine:
// - Windows SDK version 10.0.19041.0
// - CMAKE_SYSTEM_PROCESSOR: AMD64
// Replaced lookup with some conversion code found online.
// For some reason, lookup table does not work on my machine.
// Replaced lookup with working reference code.
// TODO This must be properly debugged and fixed
return ggml_half_to_float(f);
return ggml_half_to_float_reference(f);
}
#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
@ -514,6 +514,19 @@ typedef struct {
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
// Method 4 with better outlier handling.
typedef struct {
ggml_fp16_t d;
ggml_fp16_t m;
// We need only 5 bits for the in-block index, so 16 bits is overkill.
// TODO Optimize if possible
uint16_t outlier_index;
ggml_fp16_t outlier_value;
// Nibbles / quants.
uint8_t qs[QK / 2];
} block_q4_1_o;
static_assert(sizeof(block_q4_1_o) == 8 + QK / 2, "wrong q4_1_o block size/padding");
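Note that the ggml_fp16_t d and m plus the 16-bit outlier index and FP16 outlier value together occupy exactly the 8 bytes that block_q4_1 spends on its two float fields, so at QK == 32 both formats come out to 24 bytes per block (6 bits per weight effective); the outlier is stored "for free" by halving the precision of d and m.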
// reference implementation for deterministic creation of model files
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
assert(k % QK == 0);
@ -1118,6 +1131,208 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
#endif
}
// Q4_1_O
static inline void quantize_row_q4_1_o_reference_single_block(const float * restrict x, block_q4_1_o * restrict block) {
// An outlier is just the absmax element in the block.
// We store it separately and do not quantize it.
int outlier_index = -1;
float outlier_value = 0.0F;
for (int l = 0; l < QK; l++) {
const float v = x[l];
if (fabsf(v) > fabsf(outlier_value)) {
outlier_index = l;
outlier_value = v;
}
}
block->outlier_index = outlier_index;
block->outlier_value = GGML_COMPUTE_FP32_TO_FP16(outlier_value);
float min = FLT_MAX;
float max = -FLT_MAX;
for (int l = 0; l < QK; l++) {
if (l == outlier_index) {
// Ignore outlier when computing range.
continue;
}
const float v = x[l];
if (v < min) min = v;
if (v > max) max = v;
}
const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0F / d : 0.0F;
block->d = GGML_COMPUTE_FP32_TO_FP16(d);
block->m = GGML_COMPUTE_FP32_TO_FP16(min);
uint8_t pp[QK / 2];
for (int l = 0; l < QK; l += 2) {
float v0 = (x[l + 0] - min) * id;
float v1 = (x[l + 1] - min) * id;
// Write a garbage, but valid, quant value at the outlier position;
// it will be replaced by the stored outlier value on dequantization.
if (l + 0 == outlier_index) v0 = 0.0F;
if (l + 1 == outlier_index) v1 = 0.0F;
const uint8_t vi0 = roundf(v0);
const uint8_t vi1 = roundf(v1);
assert(vi0 >= 0 && vi0 < 16);
assert(vi1 >= 0 && vi1 < 16);
pp[l/2] = vi0 | (vi1 << 4);
}
memcpy(block->qs, pp, sizeof(pp));
}
static inline void dequantize_row_q4_1_o_reference_single_block(block_q4_1_o * restrict block, float * restrict y) {
const float d = ggml_half_to_float_reference(block->d);
const float m = ggml_half_to_float_reference(block->m);
const uint8_t * restrict pp = block->qs;
for (int l = 0; l < QK; l += 2) {
const uint8_t vi = pp[l / 2];
const int8_t vi0 = vi & 0xF;
const int8_t vi1 = vi >> 4;
const float v0 = vi0 * d + m;
const float v1 = vi1 * d + m;
y[l + 0] = v0;
y[l + 1] = v1;
assert(!isnan(y[l + 0]));
assert(!isnan(y[l + 1]));
}
// Restore the outlier
y[block->outlier_index] = ggml_half_to_float_reference(block->outlier_value);
}
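To see the whole scheme end to end, here is a self-contained round-trip sketch. It is a simplification of the code above, not ggml itself: d, m and the outlier are kept as plain float instead of ggml_fp16_t, and all names (demo_block, demo_quantize, demo_dequantize) are illustrative:

    // demo_q4_1_o.c -- illustrative round trip for the Q4_1_O idea.
    #include <assert.h>
    #include <float.h>
    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    #define QK 32

    typedef struct {
        float d;                 // scale
        float m;                 // minimum of the non-outlier elements
        uint16_t outlier_index;  // position of the absmax element
        float outlier_value;     // its unquantized value
        uint8_t qs[QK / 2];      // 4-bit quants, two per byte
    } demo_block;

    static void demo_quantize(const float * x, demo_block * b) {
        int oi = 0;
        for (int l = 1; l < QK; l++)
            if (fabsf(x[l]) > fabsf(x[oi])) oi = l;
        b->outlier_index = (uint16_t) oi;
        b->outlier_value = x[oi];
        float min = FLT_MAX, max = -FLT_MAX;
        for (int l = 0; l < QK; l++) {
            if (l == oi) continue; // range is computed without the outlier
            if (x[l] < min) min = x[l];
            if (x[l] > max) max = x[l];
        }
        const float d = (max - min) / 15.0F;
        const float id = d != 0.0F ? 1.0F / d : 0.0F;
        b->d = d;
        b->m = min;
        for (int l = 0; l < QK; l += 2) {
            const float v0 = l + 0 == oi ? 0.0F : (x[l + 0] - min) * id;
            const float v1 = l + 1 == oi ? 0.0F : (x[l + 1] - min) * id;
            b->qs[l / 2] = (uint8_t) roundf(v0) | ((uint8_t) roundf(v1) << 4);
        }
    }

    static void demo_dequantize(const demo_block * b, float * y) {
        for (int l = 0; l < QK; l += 2) {
            y[l + 0] = (b->qs[l / 2] & 0xF) * b->d + b->m;
            y[l + 1] = (b->qs[l / 2] >> 4) * b->d + b->m;
        }
        y[b->outlier_index] = b->outlier_value; // restore the outlier exactly
    }

    int main(void) {
        float x[QK], y[QK];
        for (int l = 0; l < QK; l++) x[l] = 0.01F * (float) l;
        x[7] = -100.0F; // a single large outlier does not stretch the range
        demo_block b;
        demo_quantize(x, &b);
        demo_dequantize(&b, y);
        assert(y[7] == -100.0F); // the outlier survives exactly
        float max_err = 0.0F;
        for (int l = 0; l < QK; l++)
            if (fabsf(y[l] - x[l]) > max_err) max_err = fabsf(y[l] - x[l]);
        printf("max abs error: %f\n", max_err); // bounded by d/2, here ~0.01
        return 0;
    }

Because the absmax element is excluded from the min/max range and restored verbatim, a single extreme weight no longer stretches d for the other 31 values, which is the failure mode this format targets.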
static void quantize_row_q4_1_o_reference(const float * restrict x, void * restrict vy, int k) {
assert(k % QK == 0);
const int nb = k / QK;
block_q4_1_o * restrict y = vy;
for (int i = 0; i < nb; i++) {
quantize_row_q4_1_o_reference_single_block(x + i * QK, y + i);
}
}
static void quantize_row_q4_1_o(const float * restrict x, void * restrict vy, int k) {
quantize_row_q4_1_o_reference(x, vy, k);
}
static void dequantize_row_q4_1_o(const void * restrict vx, float * restrict y, int k) {
assert(k % QK == 0);
const int nb = k / QK;
const block_q4_1_o * restrict x = vx;
#if defined(__AVX2__)
for (int i = 0; i < nb; i++) {
const float x_d = ggml_half_to_float_reference(x[i].d);
const float x_m = ggml_half_to_float_reference(x[i].m);
const __m256 d_v = _mm256_broadcast_ss(&x_d);
const __m256 d_m = _mm256_broadcast_ss(&x_m);
const uint8_t * restrict pp = x[i].qs;
for (int l = 0; l < QK; l += 32) {
// Load 32x4-bit integers into 32x8-bit integers
__m256i vx8 = bytesFromNibbles(pp+l/2);
// Convert to 16-bit int
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
// Convert to 32-bit int -> float 32
const __m256 vf[4] = {
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
};
// Scale, add m and store
for (int j = 0; j < 4; j++) {
const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
_mm256_storeu_ps(y + i * QK + l + j*8, result);
}
}
// Restore the outlier
y[i * QK + x[i].outlier_index] = ggml_half_to_float_reference(x[i].outlier_value);
}
#elif defined(__ARM_NEON)
for (int i = 0; i < nb; i++) {
const float x_d = ggml_half_to_float_reference(x[i].d);
const float x_m = ggml_half_to_float_reference(x[i].m);
const float32x4_t vd = vdupq_n_f32(x_d);
const float32x4_t vm = vdupq_n_f32(x_m);
const uint8_t * restrict pp = x[i].qs;
for (int l = 0; l < QK; l += 16) {
// Load 16x4-bit integers into 8x8-bit integers
const uint8x8_t v8 = vld1_u8(pp + l/2);
// Expand 4-bit qs to 8-bit bytes
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
const uint8x8_t v1 = vshr_n_u8(v8, 4);
// Interleave and combine
const uint8x8_t vx_0 = vzip1_u8(v0, v1);
const uint8x8_t vx_1 = vzip2_u8(v0, v1);
const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
// convert to 2x uint16x8_t
const uint16x8_t vi_0 = vmovl_u8(vget_low_u8 (vq));
const uint16x8_t vi_1 = vmovl_u8(vget_high_u8(vq));
// convert to 4x float32x4_t
const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0)));
const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1)));
const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1)));
// multiply by d and add m
const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd);
const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd);
const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd);
const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
// Store
vst1q_f32(y + i*QK + l + 0, r0);
vst1q_f32(y + i*QK + l + 4, r1);
vst1q_f32(y + i*QK + l + 8, r2);
vst1q_f32(y + i*QK + l + 12, r3);
}
// Restore the outlier
y[i * QK + x[i].outlier_index] = ggml_half_to_float_reference(x[i].outlier_value);
}
#else
for (int i = 0; i < nb; i++) {
dequantize_row_q4_1_o_reference_single_block(x + i, y + i * QK);
}
#endif
}
//
// simd mappings
//
@ -2437,6 +2652,7 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
//
static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
QK,
QK,
QK,
1,
@ -2446,11 +2662,12 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
1,
};
static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_COUNT != 8");
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
sizeof(block_q4_0),
sizeof(block_q4_1),
sizeof(block_q4_1_o),
sizeof(int8_t ),
sizeof(int16_t),
sizeof(int32_t),
@ -2459,7 +2676,7 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
};
// don't forget to update the array above when adding new types
static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_COUNT != 8");
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"NONE",
@ -3196,6 +3413,10 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_1_O:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
assert(tensor->nb[0] == sizeof(int8_t));
@ -3256,6 +3477,10 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_1_O:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
assert(tensor->nb[0] == sizeof(int8_t));
@ -3310,6 +3535,10 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_1_O:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -3354,6 +3583,10 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_1_O:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -3396,6 +3629,10 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_1_O:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -3440,6 +3677,10 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_1_O:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -4990,6 +5231,7 @@ static void ggml_compute_forward_dup(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5067,6 +5309,7 @@ static void ggml_compute_forward_add(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5119,6 +5362,7 @@ static void ggml_compute_forward_sub(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5171,6 +5415,7 @@ static void ggml_compute_forward_mul(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5223,6 +5468,7 @@ static void ggml_compute_forward_div(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5271,6 +5517,7 @@ static void ggml_compute_forward_sqr(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5319,6 +5566,7 @@ static void ggml_compute_forward_sqrt(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5377,6 +5625,7 @@ static void ggml_compute_forward_sum(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5454,6 +5703,7 @@ static void ggml_compute_forward_mean(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5518,6 +5768,7 @@ static void ggml_compute_forward_repeat(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5566,6 +5817,7 @@ static void ggml_compute_forward_abs(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5614,6 +5866,7 @@ static void ggml_compute_forward_sgn(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5662,6 +5915,7 @@ static void ggml_compute_forward_neg(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5710,6 +5964,7 @@ static void ggml_compute_forward_exp(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5758,6 +6013,7 @@ static void ggml_compute_forward_1_minus_x(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5810,6 +6066,7 @@ static void ggml_compute_forward_max(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5858,6 +6115,7 @@ static void ggml_compute_forward_step(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5906,6 +6164,7 @@ static void ggml_compute_forward_relu(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5971,6 +6230,7 @@ static void ggml_compute_forward_gelu(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -6021,6 +6281,7 @@ static void ggml_compute_forward_sigmoid(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -6086,6 +6347,7 @@ static void ggml_compute_forward_silu(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -6172,6 +6434,7 @@ static void ggml_compute_forward_norm(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -6252,6 +6515,7 @@ static void ggml_compute_forward_rms_norm(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -6669,6 +6933,11 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
.quantize_row_q = quantize_row_q4_1,
.vec_dot_q = ggml_vec_dot_q4_1,
},
[GGML_TYPE_Q4_1_O] = {
.dequantize_row_q = dequantize_row_q4_1_o,
.quantize_row_q = quantize_row_q4_1_o,
.vec_dot_q = NULL,
},
};
static void ggml_compute_forward_mul_mat_q_f32(
@ -6859,6 +7128,273 @@ static void ggml_compute_forward_mul_mat_q_f32(
//}
}
static void ggml_compute_forward_mul_mat_q4_1_o_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
int64_t t0 = ggml_perf_time_us();
UNUSED(t0);
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const int ne10 = src1->ne[0];
const int ne11 = src1->ne[1];
const int ne12 = src1->ne[2];
const int ne13 = src1->ne[3];
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
const int ne2 = dst->ne[2];
const int ne3 = dst->ne[3];
const int nb00 = src0->nb[0];
const int nb01 = src0->nb[1];
const int nb02 = src0->nb[2];
const int nb03 = src0->nb[3];
const int nb10 = src1->nb[0];
const int nb11 = src1->nb[1];
const int nb12 = src1->nb[2];
const int nb13 = src1->nb[3];
const int nb0 = dst->nb[0];
const int nb1 = dst->nb[1];
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
const int ith = params->ith;
const int nth = params->nth;
GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
const enum ggml_type type = src0->type;
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
GGML_ASSERT(nb10 == sizeof(float));
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne03);
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
if (params->ith != 0) {
return;
}
if (params->type == GGML_TASK_INIT) {
return;
}
if (params->type == GGML_TASK_FINALIZE) {
return;
}
float * const wdata = params->wdata;
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
{
size_t id = 0;
for (int i01 = 0; i01 < ne01; ++i01) {
dequantize_row_q4_1_o((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
id += ne00;
}
}
const float * x = wdata;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
// zT = y * xT
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne11, ne01, ne10,
1.0f, y, ne10,
x, ne10,
0.0f, d, ne01);
}
}
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
return;
}
#endif
if (params->type == GGML_TASK_INIT) {
return;
}
if (params->type == GGML_TASK_FINALIZE) {
return;
}
// parallelize by src0 rows using ggml_vec_dot_f32
// total rows in src0
const int nr = ne01*ne02*ne03;
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
#if defined(__AVX2__)
float outlier_mask[QK];
memset(outlier_mask, 0, QK * sizeof(float));
#endif
for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices
const int i03 = ir/(ne02*ne01);
const int i02 = (ir - i03*ne02*ne01)/ne01;
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
#if defined(__AVX2__)
for (int ic = 0; ic < ne11; ++ic) {
// src1 indices
const int i13 = i03;
const int i12 = i02;
const int i11 = ic;
// dst indices
const int i0 = i01;
const int i1 = i11;
const int i2 = i02;
const int i3 = i03;
const int block_count = ne00 / QK;
const block_q4_1_o * row_blocks = (block_q4_1_o *) ((char *) src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03));
__m256 accum = _mm256_setzero_ps();
// Here we do fused dequantization and dot product.
for (int block_index = 0; block_index < block_count; block_index++) {
const float block_d = ggml_half_to_float_reference(row_blocks[block_index].d);
const float block_m = ggml_half_to_float_reference(row_blocks[block_index].m);
// 0 .. 31
const uint16_t outlier_index = row_blocks[block_index].outlier_index;
const float outlier_value = ggml_half_to_float_reference(row_blocks[block_index].outlier_value);
const uint8_t * restrict quant_nibbles = row_blocks[block_index].qs;
// ---
// Broadcast values to 8x element float32 vectors
const __m256 broadcasted_d = _mm256_broadcast_ss(&block_d);
const __m256 broadcasted_m = _mm256_broadcast_ss(&block_m);
const __m256 broadcasted_outlier_value = _mm256_broadcast_ss(&outlier_value);
// Load 32x4-bit integers into 32x8-bit integers
const __m256i quant_bytes = bytesFromNibbles(quant_nibbles);
// Convert to 16-bit int
const __m256i quant_shorts_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(quant_bytes, 0));
const __m256i quant_shorts_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(quant_bytes, 1));
// Convert to 32-bit int and then to 32-bit float
const __m256 quant_floats_0 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_lo, 0)));
const __m256 quant_floats_1 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_lo, 1)));
const __m256 quant_floats_2 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_hi, 0)));
const __m256 quant_floats_3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_hi, 1)));
// Dequantize to ~original weights
const __m256 weight_0 = _mm256_fmadd_ps(quant_floats_0, broadcasted_d, broadcasted_m);
const __m256 weight_1 = _mm256_fmadd_ps(quant_floats_1, broadcasted_d, broadcasted_m);
const __m256 weight_2 = _mm256_fmadd_ps(quant_floats_2, broadcasted_d, broadcasted_m);
const __m256 weight_3 = _mm256_fmadd_ps(quant_floats_3, broadcasted_d, broadcasted_m);
// TODO This outlier handling is VERY slow
// Set the outlier mask: -1.0F sets the sign (most significant) bit,
// which is the bit _mm256_blendv_ps selects lanes by
outlier_mask[outlier_index] = -1.0F;
// Load the mask into vectors (unaligned loads: a stack array is not guaranteed to be 32-byte aligned)
const __m256 outlier_mask_0 = _mm256_loadu_ps(outlier_mask);
const __m256 outlier_mask_1 = _mm256_loadu_ps(outlier_mask + 8);
const __m256 outlier_mask_2 = _mm256_loadu_ps(outlier_mask + 16);
const __m256 outlier_mask_3 = _mm256_loadu_ps(outlier_mask + 24);
// Reset mask array to all zeroes for the next iteration
outlier_mask[outlier_index] = 0.0F;
// Replace the weight at the index of the outlier
const __m256 weight_0_with_outlier = _mm256_blendv_ps(weight_0, broadcasted_outlier_value, outlier_mask_0);
const __m256 weight_1_with_outlier = _mm256_blendv_ps(weight_1, broadcasted_outlier_value, outlier_mask_1);
const __m256 weight_2_with_outlier = _mm256_blendv_ps(weight_2, broadcasted_outlier_value, outlier_mask_2);
const __m256 weight_3_with_outlier = _mm256_blendv_ps(weight_3, broadcasted_outlier_value, outlier_mask_3);
// Load 32 floats of data of the second argument
// (unaligned loads: src1 rows are not guaranteed to be 32-byte aligned)
const float * src1_data = (float *) ((char *) src1->data + (block_index * QK * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13));
const __m256 src1_0 = _mm256_loadu_ps(src1_data);
const __m256 src1_1 = _mm256_loadu_ps(src1_data + 8);
const __m256 src1_2 = _mm256_loadu_ps(src1_data + 16);
const __m256 src1_3 = _mm256_loadu_ps(src1_data + 24);
// Multiply weights and values of the second argument element-wise; add to accumulator
accum = _mm256_fmadd_ps(src1_0, weight_0_with_outlier, accum);
accum = _mm256_fmadd_ps(src1_1, weight_1_with_outlier, accum);
accum = _mm256_fmadd_ps(src1_2, weight_2_with_outlier, accum);
accum = _mm256_fmadd_ps(src1_3, weight_3_with_outlier, accum);
}
// Add elements of accumulator
__m128 res = _mm256_extractf128_ps(accum, 1);
res = _mm_add_ps(res, _mm256_castps256_ps128(accum));
res = _mm_add_ps(res, _mm_movehl_ps(res, res ));
res = _mm_add_ss(res, _mm_movehdup_ps(res));
*((float *) ((char *) dst->data + (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3))) = _mm_cvtss_f32(res);
}
#else
float * const wdata = (float *) ((char *) params->wdata + (i01 * nb01 + i02 * nb02 + i03 * nb03));
dequantize_row_q4_1_o((char *) src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03), wdata, ne00);
for (int ic = 0; ic < ne11; ++ic) {
// src1 indices
const int i13 = i03;
const int i12 = i02;
const int i11 = ic;
// dst indices
const int i0 = i01;
const int i1 = i11;
const int i2 = i02;
const int i3 = i03;
ggml_vec_dot_f32(
ne00,
(float *) ((char *) dst->data + (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3)),
wdata,
(float *) ((char *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13))
);
}
#endif
}
}
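The outlier substitution above hinges on _mm256_blendv_ps selecting lanes by the most significant (sign) bit of the mask, which is why writing -1.0F into one slot of a zeroed float array marks exactly one lane. A minimal standalone illustration of just that trick (values are made up; compile with -mavx, the kernel itself additionally needs -mavx2 and FMA):

    // blendv_demo.c -- how the outlier substitution works, in isolation.
    #include <immintrin.h>
    #include <stdio.h>

    int main(void) {
        float weights[8]  = { 1, 2, 3, 4, 5, 6, 7, 8 };
        float mask_mem[8] = { 0 };
        mask_mem[3] = -1.0F; // pretend the outlier sits in lane 3

        const __m256 w       = _mm256_loadu_ps(weights);
        const __m256 mask    = _mm256_loadu_ps(mask_mem);
        const __m256 outlier = _mm256_set1_ps(-100.0F);

        float out[8];
        // Take the outlier value wherever the mask lane has its sign bit set
        _mm256_storeu_ps(out, _mm256_blendv_ps(w, outlier, mask));
        for (int i = 0; i < 8; i++) printf("%g ", out[i]); // 1 2 3 -100 5 6 7 8
        printf("\n");
        return 0;
    }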
static void ggml_compute_forward_mul_mat(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
@ -6870,6 +7406,10 @@ static void ggml_compute_forward_mul_mat(
{
ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_1_O:
{
ggml_compute_forward_mul_mat_q4_1_o_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst);
@ -6965,6 +7505,7 @@ static void ggml_compute_forward_scale(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -7121,6 +7662,7 @@ static void ggml_compute_forward_get_rows(
switch (src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
{
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
} break;
@ -7210,6 +7752,7 @@ static void ggml_compute_forward_diag_mask_inf(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -7304,6 +7847,7 @@ static void ggml_compute_forward_soft_max(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -7446,6 +7990,7 @@ static void ggml_compute_forward_rope(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -7714,6 +8259,7 @@ static void ggml_compute_forward_conv_1d_1s(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -7982,6 +8528,7 @@ static void ggml_compute_forward_conv_1d_2s(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -8467,6 +9014,7 @@ static void ggml_compute_forward_flash_attn(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -8678,6 +9226,7 @@ static void ggml_compute_forward_flash_ff(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1_O:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -9508,6 +10057,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
#endif
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
cur = 0;
} else if (node->src0->type == GGML_TYPE_Q4_1_O && node->src1->type == GGML_TYPE_F32) {
#if defined(__AVX2__)
cur = 0;
#else
// Assuming that src1 is a vector
// TODO Not sure whether this is correct
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * ggml_nelements(node->src1);
#endif
} else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
@ -10729,6 +11286,29 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
return (n/QK*sizeof(block_q4_1));
}
size_t ggml_quantize_q4_1_o(const float * src, void * dst, int n, int k, int64_t * hist) {
assert(k % QK == 0);
const int nb = k / QK;
for (int j = 0; j < n; j += k) {
block_q4_1_o * restrict y = (block_q4_1_o *) dst + j / QK;
quantize_row_q4_1_o_reference(src + j, y, k);
for (int i = 0; i < nb; i++) {
for (int l = 0; l < QK; l += 2) {
const uint8_t vi0 = y[i].qs[l / 2] & 0xF;
const uint8_t vi1 = y[i].qs[l / 2] >> 4;
hist[vi0]++;
hist[vi1]++;
}
}
}
return (n / QK * sizeof(block_q4_1_o));
}
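For reference, a hypothetical caller might look like this (quantize_q4_1_o_demo and its variable names are illustrative, not part of ggml; n and row_len must be multiples of QK == 32):

    #include <stdint.h>
    #include <stdlib.h>

    static void * quantize_q4_1_o_demo(const float * src, int n, int row_len, size_t * out_size) {
        // Per-nibble usage counts, accumulated by ggml_quantize_q4_1_o
        int64_t hist[16] = { 0 };
        // One 24-byte block (sizeof(block_q4_1_o)) per 32 source floats
        void * dst = malloc((size_t) (n / 32) * 24);
        *out_size = ggml_quantize_q4_1_o(src, dst, n, row_len, hist);
        return dst;
    }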
////////////////////////////////////////////////////////////////////////////////
int ggml_cpu_has_avx(void) {

ggml.h

@ -226,7 +226,11 @@ struct ggml_context;
enum ggml_type {
GGML_TYPE_Q4_0,
// Stores min and delta per block, does quantized matmul.
GGML_TYPE_Q4_1,
// Same as Q4_1, but stores outliers separately, and matmul is done in FP32.
// An outlier is the single absmax element in the quantized block.
GGML_TYPE_Q4_1_O,
GGML_TYPE_I8,
GGML_TYPE_I16,
GGML_TYPE_I32,
@ -807,6 +811,7 @@ enum ggml_opt_result ggml_opt(
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
size_t ggml_quantize_q4_1_o(const float * src, void * dst, int n, int k, int64_t * hist);
//
// system info

rwkv.cpp

@ -160,7 +160,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
model->data_type == 0 ||
model->data_type == 1 ||
model->data_type == 2 ||
model->data_type == 3,
model->data_type == 3 ||
model->data_type == 4,
"Unsupported model data type %d",
model->data_type
);
@ -216,7 +217,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
data_type == 0 ||
data_type == 1 ||
data_type == 2 ||
data_type == 3,
data_type == 3 ||
data_type == 4,
"Unsupported parameter data type %d",
data_type
);
@ -228,6 +230,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
case 1: ggml_data_type = GGML_TYPE_F16; break;
case 2: ggml_data_type = GGML_TYPE_Q4_0; break;
case 3: ggml_data_type = GGML_TYPE_Q4_1; break;
case 4: ggml_data_type = GGML_TYPE_Q4_1_O; break;
default: return NULL;
}
@ -553,18 +556,17 @@ void rwkv_free(struct rwkv_context * ctx) {
}
bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type);
RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type);
ggml_type type;
switch (q_type) {
case 2: type = GGML_TYPE_Q4_0; break;
case 3: type = GGML_TYPE_Q4_1; break;
case 4: type = GGML_TYPE_Q4_1_O; break;
default: return false;
};
RWKV_ASSERT_FALSE(type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1, "Unsupported data type %d", type);
printf("Loading model from '%s'\n", model_file_path_in);
auto finp = std::ifstream(model_file_path_in, std::ios::binary);
@ -646,7 +648,8 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
"f32",
"f16",
"q4_0",
"q4_1"
"q4_1",
"q4_1_o"
};
printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
@ -655,6 +658,7 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
4.0F,
2.0F,
20.0F / 32.0F,
24.0F / 32.0F,
24.0F / 32.0F
};
total_size_orig += (size_t) (nelements * parameter_data_type_size[parameter_data_type]);
@ -668,10 +672,11 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
name != std::string("head.weight");
if (quantize) {
if (parameter_data_type != 0 && parameter_data_type != 1) {
fprintf(stderr, "unsupported data type %d for integer quantization\n", parameter_data_type);
return false;
}
RWKV_ASSERT_FALSE(
parameter_data_type == 0 || parameter_data_type == 1,
"Unsupported parameter data type %d, only FP32 and FP16 can be quantized",
parameter_data_type
);
if (parameter_data_type == 1) {
data_f16.resize(nelements);
@ -719,6 +724,10 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
{
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q4_1_O:
{
cur_size = ggml_quantize_q4_1_o(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
default:
{
fprintf(stderr, "unsupported quantization type %d\n", type);

rwkv/compare_with_reference_implementation.py

@ -37,7 +37,8 @@ def main() -> None:
assert data_type == 0 or\
data_type == 1 or\
data_type == 2 or\
data_type == 3, f'Unsupported model data type {data_type}'
data_type == 3 or\
data_type == 4, f'Unsupported model data type {data_type}'
if data_type == 0:
# FP32, high precision
@ -46,12 +47,14 @@ def main() -> None:
# FP16, lower precision, so higher threshold
threshold = 0.0032
elif data_type == 2:
# INT4 quantized, even lower precision, so even higher threshold
# This threshold will let some bugs pass
threshold = 4.0
# Q4_0 quantized, even lower precision, so even higher threshold
threshold = 0.4
elif data_type == 3:
# This format stores more data, so error would be lower
threshold = 1.2
# Q4_1
threshold = 1.21
elif data_type == 4:
# Q4_1_O
threshold = 0.2
model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path)

rwkv/quantize.py

@ -1,5 +1,5 @@
# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1.
# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3
# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1 or Q4_1_O (recommended).
# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1_o.bin 4
import argparse
import rwkv_cpp_shared_library
@ -8,12 +8,20 @@ def parse_args():
parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1 or Q4_1_O')
parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3)
parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0), 3 (GGML_TYPE_Q4_1) or 4 (GGML_TYPE_Q4_1_O)', type=int, choices=[2, 3, 4], default=4)
return parser.parse_args()
def main() -> None:
args = parse_args()
if args.data_type == 2 or args.data_type == 3:
print()
print('WARNING!')
print('You are using Q4_0 or Q4_1 quantization; it will heavily degrade RWKV quality.')
print('For best quality preservation, it is recommended to use Q4_1_O.')
print('More info at https://github.com/saharNooby/rwkv.cpp/issues/12')
print()
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
library.rwkv_quantize_model_file(

rwkv/rwkv_cpp_model.py

@ -118,5 +118,5 @@ class RWKVModel:
def __del__(self):
# Free the context on GC in case user forgot to call free() explicitly.
if self._valid:
if hasattr(self, '_valid') and self._valid:
self.free()