From c40941d9d010a3e0cc3748705eac5d747e72451a Mon Sep 17 00:00:00 2001
From: saharNooby
Date: Fri, 7 Apr 2023 09:55:39 +0400
Subject: [PATCH] Add Q4_1_O format

---
 ggml.c                                        | 648 +++++++++++++++++-
 ggml.h                                        |   5 +
 rwkv.cpp                                      |  29 +-
 rwkv/compare_with_reference_implementation.py |  15 +-
 rwkv/quantize.py                              |  14 +-
 rwkv/rwkv_cpp_model.py                        |   2 +-
 6 files changed, 659 insertions(+), 54 deletions(-)

diff --git a/ggml.c b/ggml.c
index 6efd47e..7d75e2e 100644
--- a/ggml.c
+++ b/ggml.c
@@ -25,6 +25,35 @@
 #define static_assert(cond, msg) struct global_scope_noop_trick
 #endif
 
+// https://gist.github.com/rygorous/2144712
+// Public domain, by Fabian "ryg" Giesen
+inline static float ggml_half_to_float_reference(uint16_t value) {
+    union FP32 {
+        uint32_t u;
+        float f;
+    };
+
+    const union FP32 magic = { (254UL - 15UL) << 23 };
+    const union FP32 was_inf_nan = { (127UL + 16UL) << 23 };
+
+    union FP32 out;
+
+    // Exponent/mantissa bits
+    out.u = (value & 0x7FFFU) << 13;
+    // Exponent adjust
+    out.f *= magic.f;
+
+    // Make sure Inf/NaN survive
+    if (out.f >= was_inf_nan.f) {
+        out.u |= 255UL << 23;
+    }
+
+    // Sign bit
+    out.u |= (value & 0x8000UL) << 16;
+
+    return out.f;
+}
+
 #if defined _MSC_VER || defined(__MINGW32__)
 
 #if !defined(__MINGW32__)
@@ -326,42 +355,13 @@ static float table_f32_f16[1 << 16];
 // This is also true for POWER9.
 #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
 
-// https://gist.github.com/rygorous/2144712
-// Public domain, by Fabian "ryg" Giesen
-inline static float ggml_half_to_float(uint16_t value) {
-    union FP32 {
-        uint32_t u;
-        float f;
-    };
-
-    const union FP32 magic = { (254UL - 15UL) << 23 };
-    const union FP32 was_inf_nan = { (127UL + 16UL) << 23 };
-
-    union FP32 out;
-
-    // Exponent/mantissa bits
-    out.u = (value & 0x7FFFU) << 13;
-    // Exponent adjust
-    out.f *= magic.f;
-
-    // Make sure Inf/NaN survive
-    if (out.f >= was_inf_nan.f) {
-        out.u |= 255UL << 23;
-    }
-
-    // Sign bit
-    out.u |= (value & 0x8000UL) << 16;
-
-    return out.f;
-}
-
 inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
-    // For some reason, lookup table does not work on my machine:
-    // - Windows SDK version 10.0.19041.0
-    // - CMAKE_SYSTEM_PROCESSOR: AMD64
-    // Replaced lookup with some conversion code found online.
+    // For some reason, lookup table does not work on my machine.
+    // Replaced lookup with working reference code.
     // TODO This must be properly debugged and fixed
-    return ggml_half_to_float(f);
+    return ggml_half_to_float_reference(f);
 }
 
 #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
@@ -514,6 +514,19 @@ typedef struct {
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
 
+// Method 4 with better outlier handling.
+typedef struct {
+    ggml_fp16_t d;
+    ggml_fp16_t m;
+    // We need only 5 bits for the in-block index, so 16 bits is overkill.
+    // TODO Optimize if possible
+    uint16_t outlier_index;
+    ggml_fp16_t outlier_value;
+    // Nibbles / quants.
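+    // Each byte packs two quants: element l in the low nibble, element l + 1 in the high nibble.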
+    uint8_t qs[QK / 2];
+} block_q4_1_o;
+static_assert(sizeof(block_q4_1_o) == 8 + QK / 2, "wrong q4_1_o block size/padding");
+
 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
     assert(k % QK == 0);
@@ -1118,6 +1131,208 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 #endif
 }
 
+// Q4_1_O
+
+static inline void quantize_row_q4_1_o_reference_single_block(const float * restrict x, block_q4_1_o * restrict block) {
+    // An outlier is just the absmax element in the block.
+    // We store it separately and do not quantize it.
+    int outlier_index = -1;
+    float outlier_value = 0.0F;
+
+    for (int l = 0; l < QK; l++) {
+        const float v = x[l];
+
+        if (fabsf(v) > fabsf(outlier_value)) {
+            outlier_index = l;
+            outlier_value = v;
+        }
+    }
+
+    block->outlier_index = outlier_index;
+    block->outlier_value = GGML_COMPUTE_FP32_TO_FP16(outlier_value);
+
+    float min = FLT_MAX;
+    float max = -FLT_MAX;
+
+    for (int l = 0; l < QK; l++) {
+        if (l == outlier_index) {
+            // Ignore outlier when computing range.
+            continue;
+        }
+
+        const float v = x[l];
+        if (v < min) min = v;
+        if (v > max) max = v;
+    }
+
+    const float d = (max - min) / ((1 << 4) - 1);
+    const float id = d ? 1.0F / d : 0.0F;
+
+    block->d = GGML_COMPUTE_FP32_TO_FP16(d);
+    block->m = GGML_COMPUTE_FP32_TO_FP16(min);
+
+    uint8_t pp[QK / 2];
+
+    for (int l = 0; l < QK; l += 2) {
+        float v0 = (x[l + 0] - min) * id;
+        float v1 = (x[l + 1] - min) * id;
+
+        // Write some garbage but valid value for the outlier.
+        if (l + 0 == outlier_index) v0 = 0.0;
+        if (l + 1 == outlier_index) v1 = 0.0;
+
+        const uint8_t vi0 = roundf(v0);
+        const uint8_t vi1 = roundf(v1);
+
+        assert(vi0 >= 0 && vi0 < 16);
+        assert(vi1 >= 0 && vi1 < 16);
+
+        pp[l/2] = vi0 | (vi1 << 4);
+    }
+
+    memcpy(block->qs, pp, sizeof(pp));
+}
+
+static inline void dequantize_row_q4_1_o_reference_single_block(block_q4_1_o * restrict block, float * restrict y) {
+    const float d = ggml_half_to_float_reference(block->d);
+    const float m = ggml_half_to_float_reference(block->m);
+
+    const uint8_t * restrict pp = block->qs;
+
+    for (int l = 0; l < QK; l += 2) {
+        const uint8_t vi = pp[l / 2];
+
+        const int8_t vi0 = vi & 0xF;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = vi0 * d + m;
+        const float v1 = vi1 * d + m;
+
+        y[l + 0] = v0;
+        y[l + 1] = v1;
+
+        assert(!isnan(y[l + 0]));
+        assert(!isnan(y[l + 1]));
+    }
+
+    // Restore the outlier
+    y[block->outlier_index] = ggml_half_to_float_reference(block->outlier_value);
+}
+
+static void quantize_row_q4_1_o_reference(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    block_q4_1_o * restrict y = vy;
+
+    for (int i = 0; i < nb; i++) {
+        quantize_row_q4_1_o_reference_single_block(x + i * QK, y + i);
+    }
+}
+
+static void quantize_row_q4_1_o(const float * restrict x, void * restrict vy, int k) {
+    quantize_row_q4_1_o_reference(x, vy, k);
+}
+
+static void dequantize_row_q4_1_o(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    const block_q4_1_o * restrict x = vx;
+
+#if defined(__AVX2__)
+    for (int i = 0; i < nb; i++) {
+        const float x_d = ggml_half_to_float_reference(x[i].d);
+        const float x_m = ggml_half_to_float_reference(x[i].m);
+
+        const __m256 d_v = _mm256_broadcast_ss(&x_d);
+        const __m256 d_m = _mm256_broadcast_ss(&x_m);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK; l += 32) {
+            // Load 32x4-bit integers into 32x8-bit integers
+            __m256i vx8 = bytesFromNibbles(pp+l/2);
+
+            // Convert to 16-bit int
+            const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
+            const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
+
+            // Convert to 32-bit int -> float 32
+            const __m256 vf[4] = {
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
+            };
+
+            // Scale, add m and store
+            for (int j = 0; j < 4; j++) {
+                const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
+                _mm256_storeu_ps(y + i * QK + l + j*8, result);
+            }
+        }
+
+        // Restore the outlier
+        y[i * QK + x[i].outlier_index] = ggml_half_to_float_reference(x[i].outlier_value);
+    }
+#elif defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        const float x_d = ggml_half_to_float_reference(x[i].d);
+        const float x_m = ggml_half_to_float_reference(x[i].m);
+
+        const float32x4_t vd = vdupq_n_f32(x_d);
+        const float32x4_t vm = vdupq_n_f32(x_m);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK; l += 16) {
+            // Load 16x4-bit integers into 8x8-bit integers
+            const uint8x8_t v8 = vld1_u8(pp + l/2);
+
+            // Expand 4-bit qs to 8-bit bytes
+            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+            const uint8x8_t v1 = vshr_n_u8(v8, 4);
+
+            // Interleave and combine
+            const uint8x8_t vx_0 = vzip1_u8(v0, v1);
+            const uint8x8_t vx_1 = vzip2_u8(v0, v1);
+
+            const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
+
+            // convert to 2x uint16x8_t
+            const uint16x8_t vi_0 = vmovl_u8(vget_low_u8 (vq));
+            const uint16x8_t vi_1 = vmovl_u8(vget_high_u8(vq));
+
+            // convert to 4x float32x4_t
+            const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
+            const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0)));
+            const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1)));
+            const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1)));
+
+            // multiply by d and add m
+            const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd);
+            const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd);
+            const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd);
+            const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
+
+            // Store
+            vst1q_f32(y + i*QK + l +  0, r0);
+            vst1q_f32(y + i*QK + l +  4, r1);
+            vst1q_f32(y + i*QK + l +  8, r2);
+            vst1q_f32(y + i*QK + l + 12, r3);
+        }
+
+        // Restore the outlier
+        y[i * QK + x[i].outlier_index] = ggml_half_to_float_reference(x[i].outlier_value);
+    }
+#else
+    for (int i = 0; i < nb; i++) {
+        dequantize_row_q4_1_o_reference_single_block(x + i, y + i * QK);
+    }
+#endif
+}
+
 //
 // simd mappings
 //
@@ -2437,6 +2652,7 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
 //
 
 static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
+    QK,
     QK,
     QK,
     1,
@@ -2446,11 +2662,12 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     1,
 };
 
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_COUNT != 8");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     sizeof(block_q4_0),
     sizeof(block_q4_1),
+    sizeof(block_q4_1_o),
     sizeof(int8_t ),
     sizeof(int16_t),
     sizeof(int32_t),
@@ -2459,7 +2676,7 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
 };
 
 // don't forget to update the array above when adding new types
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_COUNT != 8");
 
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
@@ -3196,6 +3413,10 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
            {
                GGML_ASSERT(false);
            } break;
+        case GGML_TYPE_Q4_1_O:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 assert(tensor->nb[0] == sizeof(int8_t));
@@ -3256,6 +3477,10 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
            {
                GGML_ASSERT(false);
            } break;
+        case GGML_TYPE_Q4_1_O:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 assert(tensor->nb[0] == sizeof(int8_t));
@@ -3310,6 +3535,10 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
            {
                GGML_ASSERT(false);
            } break;
+        case GGML_TYPE_Q4_1_O:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3354,6 +3583,10 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
            {
                GGML_ASSERT(false);
            } break;
+        case GGML_TYPE_Q4_1_O:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3396,6 +3629,10 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
            {
                GGML_ASSERT(false);
            } break;
+        case GGML_TYPE_Q4_1_O:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3440,6 +3677,10 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
            {
                GGML_ASSERT(false);
            } break;
+        case GGML_TYPE_Q4_1_O:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -4990,6 +5231,7 @@ static void ggml_compute_forward_dup(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5067,6 +5309,7 @@ static void ggml_compute_forward_add(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5119,6 +5362,7 @@ static void ggml_compute_forward_sub(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5171,6 +5415,7 @@ static void ggml_compute_forward_mul(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5223,6 +5468,7 @@ static void ggml_compute_forward_div(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5271,6 +5517,7 @@ static void ggml_compute_forward_sqr(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5319,6 +5566,7 @@ static void ggml_compute_forward_sqrt(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5377,6 +5625,7 @@ static void ggml_compute_forward_sum(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5454,6 +5703,7 @@ static void ggml_compute_forward_mean(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5518,6 +5768,7 @@ static void ggml_compute_forward_repeat(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5566,6 +5817,7 @@ static void ggml_compute_forward_abs(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5614,6 +5866,7 @@ static void ggml_compute_forward_sgn(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5662,6 +5915,7 @@ static void ggml_compute_forward_neg(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5710,6 +5964,7 @@ static void ggml_compute_forward_exp(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5758,6 +6013,7 @@ static void ggml_compute_forward_1_minus_x(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5810,6 +6066,7 @@ static void ggml_compute_forward_max(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5858,6 +6115,7 @@ static void ggml_compute_forward_step(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5906,6 +6164,7 @@ static void ggml_compute_forward_relu(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -5971,6 +6230,7 @@ static void ggml_compute_forward_gelu(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -6021,6 +6281,7 @@ static void ggml_compute_forward_sigmoid(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -6086,6 +6347,7 @@ static void ggml_compute_forward_silu(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -6172,6 +6434,7 @@ static void ggml_compute_forward_norm(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -6252,6 +6515,7 @@ static void ggml_compute_forward_rms_norm(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -6669,6 +6933,11 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .quantize_row_q   = quantize_row_q4_1,
         .vec_dot_q        = ggml_vec_dot_q4_1,
     },
+    [GGML_TYPE_Q4_1_O] = {
+        .dequantize_row_q = dequantize_row_q4_1_o,
+        .quantize_row_q   = quantize_row_q4_1_o,
+        .vec_dot_q        = NULL,
+    },
 };
 
 static void ggml_compute_forward_mul_mat_q_f32(
@@ -6859,6 +7128,273 @@ static void ggml_compute_forward_mul_mat_q_f32(
     //}
 }
 
+static void ggml_compute_forward_mul_mat_q4_1_o_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    const int ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    const enum ggml_type type = src0->type;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+        if (params->ith != 0) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_INIT) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_FINALIZE) {
+            return;
+        }
+
+        float * const wdata = params->wdata;
+
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                {
+                    size_t id = 0;
+                    for (int i01 = 0; i01 < ne01; ++i01) {
+                        dequantize_row_q4_1_o((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
+                        id += ne00;
+                    }
+                }
+
+                const float * x = wdata;
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // zT = y * xT
+                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne11, ne01, ne10,
+                        1.0f, y, ne10,
+                        x, ne10,
+                        0.0f, d, ne01);
+            }
+        }
+
+        //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
+
+        return;
+    }
+#endif
+
+    if (params->type == GGML_TASK_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by src0 rows using ggml_vec_dot_f32
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+#if defined(__AVX2__)
+    float outlier_mask[QK];
+    memset(outlier_mask, 0, QK * sizeof(float));
+#endif
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+#if defined(__AVX2__)
+        for (int ic = 0; ic < ne11; ++ic) {
+            // src1 indices
+            const int i13 = i03;
+            const int i12 = i02;
+            const int i11 = ic;
+
+            // dst indices
+            const int i0 = i01;
+            const int i1 = i11;
+            const int i2 = i02;
+            const int i3 = i03;
+
+            const int block_count = ne00 / QK;
+
+            const block_q4_1_o * row_blocks = (block_q4_1_o *) ((char *) src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03));
+
+            __m256 accum = _mm256_setzero_ps();
+
+            // Here we do fused dequantization and dot product.
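+            // For each block: dequantize all 32 quants with one fused multiply-add
+            // (quant * d + m), blend the stored FP16 outlier back in at its saved
+            // in-block index, then multiply-accumulate against the matching 32 floats of src1.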
+            for (int block_index = 0; block_index < block_count; block_index++) {
+                const float block_d = ggml_half_to_float_reference(row_blocks[block_index].d);
+                const float block_m = ggml_half_to_float_reference(row_blocks[block_index].m);
+
+                // 0 .. 31
+                const uint16_t outlier_index = row_blocks[block_index].outlier_index;
+                const float outlier_value = ggml_half_to_float_reference(row_blocks[block_index].outlier_value);
+
+                const uint8_t * restrict quant_nibbles = row_blocks[block_index].qs;
+
+                // ---
+
+                // Broadcast values to 8x element float32 vectors
+                const __m256 broadcasted_d = _mm256_broadcast_ss(&block_d);
+                const __m256 broadcasted_m = _mm256_broadcast_ss(&block_m);
+                const __m256 broadcasted_outlier_value = _mm256_broadcast_ss(&outlier_value);
+
+                // Load 32x4-bit integers into 32x8-bit integers
+                const __m256i quant_bytes = bytesFromNibbles(quant_nibbles);
+
+                // Convert to 16-bit int
+                const __m256i quant_shorts_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(quant_bytes, 0));
+                const __m256i quant_shorts_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(quant_bytes, 1));
+
+                // Convert to 32-bit int and then to 32-bit float
+                const __m256 quant_floats_0 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_lo, 0)));
+                const __m256 quant_floats_1 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_lo, 1)));
+                const __m256 quant_floats_2 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_hi, 0)));
+                const __m256 quant_floats_3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_hi, 1)));
+
+                // Dequantize to ~original weights
+                const __m256 weight_0 = _mm256_fmadd_ps(quant_floats_0, broadcasted_d, broadcasted_m);
+                const __m256 weight_1 = _mm256_fmadd_ps(quant_floats_1, broadcasted_d, broadcasted_m);
+                const __m256 weight_2 = _mm256_fmadd_ps(quant_floats_2, broadcasted_d, broadcasted_m);
+                const __m256 weight_3 = _mm256_fmadd_ps(quant_floats_3, broadcasted_d, broadcasted_m);
+
+                // TODO This outlier handling is VERY slow
+                // Set outlier mask -- this should give 1 in the most significant bit
+                outlier_mask[outlier_index] = -1.0F;
+                // Load mask into vectors
+                const __m256 outlier_mask_0 = _mm256_loadu_ps(outlier_mask);
+                const __m256 outlier_mask_1 = _mm256_loadu_ps(outlier_mask + 8);
+                const __m256 outlier_mask_2 = _mm256_loadu_ps(outlier_mask + 16);
+                const __m256 outlier_mask_3 = _mm256_loadu_ps(outlier_mask + 24);
+                // Reset mask array to all zeroes for the next iteration
+                outlier_mask[outlier_index] = 0.0F;
+
+                // Replace the weight at the index of the outlier
+                const __m256 weight_0_with_outlier = _mm256_blendv_ps(weight_0, broadcasted_outlier_value, outlier_mask_0);
+                const __m256 weight_1_with_outlier = _mm256_blendv_ps(weight_1, broadcasted_outlier_value, outlier_mask_1);
+                const __m256 weight_2_with_outlier = _mm256_blendv_ps(weight_2, broadcasted_outlier_value, outlier_mask_2);
+                const __m256 weight_3_with_outlier = _mm256_blendv_ps(weight_3, broadcasted_outlier_value, outlier_mask_3);
+
+                // Load 32 floats of data of the second argument
+                const float * src1_data = (float *) ((char *) src1->data + (block_index * QK * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13));
+
+                const __m256 src1_0 = _mm256_loadu_ps(src1_data);
+                const __m256 src1_1 = _mm256_loadu_ps(src1_data + 8);
+                const __m256 src1_2 = _mm256_loadu_ps(src1_data + 16);
+                const __m256 src1_3 = _mm256_loadu_ps(src1_data + 24);
+
+                // Multiply weights and values of the second argument element-wise; add to accumulator
+                accum = _mm256_fmadd_ps(src1_0, weight_0_with_outlier, accum);
+                accum = _mm256_fmadd_ps(src1_1, weight_1_with_outlier, accum);
+                accum = _mm256_fmadd_ps(src1_2, weight_2_with_outlier, accum);
+                accum = _mm256_fmadd_ps(src1_3, weight_3_with_outlier, accum);
+            }
+
+            // Add elements of accumulator
+            __m128 res = _mm256_extractf128_ps(accum, 1);
+            res = _mm_add_ps(res, _mm256_castps256_ps128(accum));
+            res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+            res = _mm_add_ss(res, _mm_movehdup_ps(res));
+
+            *((float *) ((char *) dst->data + (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3))) = _mm_cvtss_f32(res);
+        }
+#else
+        float * const wdata = (float *) ((char *) params->wdata + (i01 * nb01 + i02 * nb02 + i03 * nb03));
+
+        dequantize_row_q4_1_o((char *) src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03), wdata, ne00);
+
+        for (int ic = 0; ic < ne11; ++ic) {
+            // src1 indices
+            const int i13 = i03;
+            const int i12 = i02;
+            const int i11 = ic;
+
+            // dst indices
+            const int i0 = i01;
+            const int i1 = i11;
+            const int i2 = i02;
+            const int i3 = i03;
+
+            ggml_vec_dot_f32(
+                    ne00,
+                    (float *) ((char *) dst->data + (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3)),
+                    wdata,
+                    (float *) ((char *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13))
+            );
+        }
+#endif
+    }
+}
+
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -6870,6 +7406,10 @@ static void ggml_compute_forward_mul_mat(
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_1_O:
+            {
+                ggml_compute_forward_mul_mat_q4_1_o_f32(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_F16:
             {
                 ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst);
@@ -6965,6 +7505,7 @@ static void ggml_compute_forward_scale(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -7121,6 +7662,7 @@ static void ggml_compute_forward_get_rows(
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -7210,6 +7752,7 @@ static void ggml_compute_forward_diag_mask_inf(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -7304,6 +7847,7 @@ static void ggml_compute_forward_soft_max(
             } break;
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
+       case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:
@@ -7446,6 +7990,7 @@ static void ggml_compute_forward_rope(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -7714,6 +8259,7 @@ static void ggml_compute_forward_conv_1d_1s(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -7982,6 +8528,7 @@ static void ggml_compute_forward_conv_1d_2s(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -8467,6 +9014,7 @@ static void ggml_compute_forward_flash_attn(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -8678,6 +9226,7 @@ static void ggml_compute_forward_flash_ff(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -9508,6 +10057,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 #endif
                     } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                         cur = 0;
+                    } else if (node->src0->type == GGML_TYPE_Q4_1_O && node->src1->type == GGML_TYPE_F32) {
+#if defined(__AVX2__)
+                        cur = 0;
+#else
+                        // Assuming that src1 is a vector
+                        // TODO Not sure whether this is correct
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * ggml_nelements(node->src1);
+#endif
                     } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
                         if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
@@ -10729,6 +11286,29 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
     return (n/QK*sizeof(block_q4_1));
 }
 
+size_t ggml_quantize_q4_1_o(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    for (int j = 0; j < n; j += k) {
+        block_q4_1_o * restrict y = (block_q4_1_o *) dst + j / QK;
+
+        quantize_row_q4_1_o_reference(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            for (int l = 0; l < QK; l += 2) {
+                const uint8_t vi0 = y[i].qs[l / 2] & 0xF;
+                const uint8_t vi1 = y[i].qs[l / 2] >> 4;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n / QK * sizeof(block_q4_1_o));
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 int ggml_cpu_has_avx(void) {
diff --git a/ggml.h b/ggml.h
index 9ca2fe8..03b3369 100644
--- a/ggml.h
+++ b/ggml.h
@@ -226,7 +226,11 @@ struct ggml_context;
 
 enum ggml_type {
     GGML_TYPE_Q4_0,
+    // Stores min and delta per block, does quantized matmul.
     GGML_TYPE_Q4_1,
+    // Same as Q4_1, but stores outliers separately, and matmul is done in FP32.
+    // An outlier is the single absmax element in the quantized block.
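+    // Each block also records the outlier's in-block index and its FP16 value.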
+    GGML_TYPE_Q4_1_O,
     GGML_TYPE_I8,
     GGML_TYPE_I16,
     GGML_TYPE_I32,
@@ -807,6 +811,7 @@ enum ggml_opt_result ggml_opt(
 
 size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
 size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
+size_t ggml_quantize_q4_1_o(const float * src, void * dst, int n, int k, int64_t * hist);
 
 //
 // system info
diff --git a/rwkv.cpp b/rwkv.cpp
index 08b4ad3..c7fd571 100644
--- a/rwkv.cpp
+++ b/rwkv.cpp
@@ -160,7 +160,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
         model->data_type == 0 ||
         model->data_type == 1 ||
         model->data_type == 2 ||
-        model->data_type == 3,
+        model->data_type == 3 ||
+        model->data_type == 4,
         "Unsupported model data type %d",
         model->data_type
     );
@@ -216,7 +217,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
             data_type == 0 ||
             data_type == 1 ||
             data_type == 2 ||
-            data_type == 3,
+            data_type == 3 ||
+            data_type == 4,
             "Unsupported parameter data type %d",
             data_type
         );
@@ -228,6 +230,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
             case 1: ggml_data_type = GGML_TYPE_F16; break;
             case 2: ggml_data_type = GGML_TYPE_Q4_0; break;
             case 3: ggml_data_type = GGML_TYPE_Q4_1; break;
+            case 4: ggml_data_type = GGML_TYPE_Q4_1_O; break;
             default: return NULL;
         }
@@ -553,18 +556,17 @@ void rwkv_free(struct rwkv_context * ctx) {
 }
 
 bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
-    RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type);
+    RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type);
 
     ggml_type type;
 
     switch (q_type) {
         case 2: type = GGML_TYPE_Q4_0; break;
         case 3: type = GGML_TYPE_Q4_1; break;
+        case 4: type = GGML_TYPE_Q4_1_O; break;
         default: return false;
     };
 
-    RWKV_ASSERT_FALSE(type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1, "Unsupported data type %d", type);
-
     printf("Loading model from '%s'\n", model_file_path_in);
 
     auto finp = std::ifstream(model_file_path_in, std::ios::binary);
@@ -646,7 +648,8 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
         "f32",
         "f16",
         "q4_0",
-        "q4_1"
+        "q4_1",
+        "q4_1_o"
     };
 
     printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
@@ -655,6 +658,7 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
         4.0F,
         2.0F,
         20.0F / 32.0F,
+        24.0F / 32.0F,
         24.0F / 32.0F
     };
 
     total_size_orig += (size_t) (nelements * parameter_data_type_size[parameter_data_type]);
@@ -668,10 +672,11 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
             name != std::string("head.weight");
 
         if (quantize) {
-            if (parameter_data_type != 0 && parameter_data_type != 1) {
-                fprintf(stderr, "unsupported data type %d for integer quantization\n", parameter_data_type);
-                return false;
-            }
+            RWKV_ASSERT_FALSE(
+                parameter_data_type == 0 || parameter_data_type == 1,
+                "Unsupported parameter data type %d, only FP32 and FP16 can be quantized",
+                parameter_data_type
+            );
 
             if (parameter_data_type == 1) {
                 data_f16.resize(nelements);
@@ -719,6 +724,10 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
                     {
                         cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
                     } break;
+                case GGML_TYPE_Q4_1_O:
+                    {
+                        cur_size = ggml_quantize_q4_1_o(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                    } break;
                 default:
                     {
                         fprintf(stderr, "unsupported quantization type %d\n", type);
diff --git a/rwkv/compare_with_reference_implementation.py b/rwkv/compare_with_reference_implementation.py
index 69a5828..827dc06 100644
--- a/rwkv/compare_with_reference_implementation.py
+++ b/rwkv/compare_with_reference_implementation.py
@@ -37,7 +37,8 @@ def main() -> None:
     assert data_type == 0 or\
            data_type == 1 or\
            data_type == 2 or\
-           data_type == 3, f'Unsupported model data type {data_type}'
+           data_type == 3 or\
+           data_type == 4, f'Unsupported model data type {data_type}'
 
     if data_type == 0:
         # FP32, high precision
@@ -46,12 +47,14 @@ def main() -> None:
         # FP16, lower precision, so higher threshold
         threshold = 0.0032
     elif data_type == 2:
-        # INT4 quantized, even lower precision, so even higher threshold
-        # This threshold will let some bugs pass
-        threshold = 4.0
+        # Q4_0 quantized, even lower precision, so even higher threshold
+        threshold = 0.4
     elif data_type == 3:
-        # This format stores more data, so error would be lower
-        threshold = 1.2
+        # Q4_1
+        threshold = 1.21
+    elif data_type == 4:
+        # Q4_1_O
+        threshold = 0.2
 
     model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path)
diff --git a/rwkv/quantize.py b/rwkv/quantize.py
index e798855..243dc92 100644
--- a/rwkv/quantize.py
+++ b/rwkv/quantize.py
@@ -1,5 +1,5 @@
-# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1.
-# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3
+# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1 or Q4_1_O (recommended).
+# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1_o.bin 4
 
 import argparse
 import rwkv_cpp_shared_library
@@ -8,12 +8,20 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
     parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
     parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
-    parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3)
+    parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0), 3 (GGML_TYPE_Q4_1) or 4 (GGML_TYPE_Q4_1_O)', type=int, choices=[2, 3, 4], default=4)
 
     return parser.parse_args()
 
 def main() -> None:
     args = parse_args()
 
+    if args.data_type == 2 or args.data_type == 3:
+        print()
+        print('WARNING!')
+        print('You are using Q4_0 or Q4_1 quantization; it will heavily degrade RWKV quality.')
+        print('For best quality preservation, it is recommended to use Q4_1_O.')
+        print('More info at https://github.com/saharNooby/rwkv.cpp/issues/12')
+        print()
+
     library = rwkv_cpp_shared_library.load_rwkv_shared_library()
 
     library.rwkv_quantize_model_file(
diff --git a/rwkv/rwkv_cpp_model.py b/rwkv/rwkv_cpp_model.py
index 70c4258..f7bb32b 100644
--- a/rwkv/rwkv_cpp_model.py
+++ b/rwkv/rwkv_cpp_model.py
@@ -118,5 +118,5 @@ class RWKVModel:
 
     def __del__(self):
         # Free the context on GC in case user forgot to call free() explicitly.
-        if self._valid:
+        if hasattr(self, '_valid') and self._valid:
             self.free()
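
Note for reviewers (not part of the patch): below is a minimal, self-contained sketch of the Q4_1_O round trip, mirroring quantize_row_q4_1_o_reference_single_block() and its dequantizer above. The block size constant, sample data, and injected outlier are invented for illustration, and FP16 storage of d/m/outlier_value is omitted, so this shows only the outlier-aware 4-bit affine scheme, not the exact on-disk format.

// q4_1_o_sketch.c -- illustrative only; compile with: cc q4_1_o_sketch.c -lm
#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define BLOCK 32 // stand-in for QK

int main(void) {
    float x[BLOCK], y[BLOCK];

    // Fill the block with small values, then inject one large outlier.
    for (int i = 0; i < BLOCK; i++) {
        x[i] = 0.01F * (i - 16);
    }
    x[7] = 12.5F;

    // The outlier is the absmax element; it is stored as-is, not quantized.
    int outlier_index = 0;
    for (int i = 1; i < BLOCK; i++) {
        if (fabsf(x[i]) > fabsf(x[outlier_index])) {
            outlier_index = i;
        }
    }
    const float outlier_value = x[outlier_index];

    // min/max are computed over the remaining elements only, so the
    // outlier does not stretch the quantization range.
    float min = INFINITY, max = -INFINITY;
    for (int i = 0; i < BLOCK; i++) {
        if (i == outlier_index) continue;
        if (x[i] < min) min = x[i];
        if (x[i] > max) max = x[i];
    }

    const float d  = (max - min) / 15.0F; // 4-bit affine scale
    const float id = d ? 1.0F / d : 0.0F;

    // Quantize to 4 bits, dequantize, then restore the outlier.
    for (int i = 0; i < BLOCK; i++) {
        // The outlier slot gets a garbage-but-valid quant, as in the patch.
        const float v = (i == outlier_index) ? 0.0F : (x[i] - min) * id;
        const uint8_t q = (uint8_t) roundf(v); // 0 .. 15
        y[i] = q * d + min;
    }
    y[outlier_index] = outlier_value;

    for (int i = 0; i < BLOCK; i++) {
        printf("%2d: %+.4f -> %+.4f\n", i, x[i], y[i]);
    }
    return 0;
}

With these sample numbers, keeping the 12.5 outlier inside the min/max scan would stretch the scale d roughly 40x (from about 0.021 to about 0.84) and collapse the other 31 elements onto one or two quantization levels; that collapse is the failure mode the Q4_1_O format is designed to avoid.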