Add Q4_1_O format

parent ec99bc1765
commit c40941d9d0

ggml.c | 648
@@ -25,6 +25,35 @@
 #define static_assert(cond, msg) struct global_scope_noop_trick
 #endif
 
+// https://gist.github.com/rygorous/2144712
+// Public domain, by Fabian "ryg" Giesen
+inline static float ggml_half_to_float_reference(uint16_t value) {
+    union FP32 {
+        uint32_t u;
+        float f;
+    };
+
+    const union FP32 magic = { (254UL - 15UL) << 23 };
+    const union FP32 was_inf_nan = { (127UL + 16UL) << 23 };
+
+    union FP32 out;
+
+    // Exponent/mantissa bits
+    out.u = (value & 0x7FFFU) << 13;
+    // Exponent adjust
+    out.f *= magic.f;
+
+    // Make sure Inf/NaN survive
+    if (out.f >= was_inf_nan.f) {
+        out.u |= 255UL << 23;
+    }
+
+    // Sign bit
+    out.u |= (value & 0x8000UL) << 16;
+
+    return out.f;
+}
+
 #if defined _MSC_VER || defined(__MINGW32__)
 
 #if !defined(__MINGW32__)
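A note on the bit trick above: (value & 0x7FFF) << 13 moves the half-precision exponent and mantissa into FP32 bit positions, and multiplying by magic (exponent field 254 - 15, i.e. 2^112) rebases the exponent from bias 15 to bias 127. A minimal sanity check, assuming only the function above and standard IEEE 754 half encodings:

// Hypothetical test, not part of the commit.
#include <assert.h>
void test_half_to_float(void) {
    assert(ggml_half_to_float_reference(0x3C00) ==  1.0f); // exp 15, mantissa 0
    assert(ggml_half_to_float_reference(0xC000) == -2.0f); // sign bit set, exp 16
    assert(ggml_half_to_float_reference(0x0000) ==  0.0f); // +0.0
}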
@@ -326,42 +355,13 @@ static float table_f32_f16[1 << 16];
 // This is also true for POWER9.
 #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
 
-// https://gist.github.com/rygorous/2144712
-// Public domain, by Fabian "ryg" Giesen
-inline static float ggml_half_to_float(uint16_t value) {
-    union FP32 {
-        uint32_t u;
-        float f;
-    };
-
-    const union FP32 magic = { (254UL - 15UL) << 23 };
-    const union FP32 was_inf_nan = { (127UL + 16UL) << 23 };
-
-    union FP32 out;
-
-    // Exponent/mantissa bits
-    out.u = (value & 0x7FFFU) << 13;
-    // Exponent adjust
-    out.f *= magic.f;
-
-    // Make sure Inf/NaN survive
-    if (out.f >= was_inf_nan.f) {
-        out.u |= 255UL << 23;
-    }
-
-    // Sign bit
-    out.u |= (value & 0x8000UL) << 16;
-
-    return out.f;
-}
-
 inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
-    // For some reason, lookup table does not work on my machine:
-    // - Windows SDK version 10.0.19041.0
-    // - CMAKE_SYSTEM_PROCESSOR: AMD64
-    // Replaced lookup with some conversion code found online.
+    // For some reason, lookup table does not work on my machine.
+    // Replaced lookup with working reference code.
     // TODO This must be properly debugged and fixed
-    return ggml_half_to_float(f);
+    return ggml_half_to_float_reference(f);
 }
 
 #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
@@ -514,6 +514,19 @@ typedef struct {
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
 
+// Method 4 with better outlier handling.
+typedef struct {
+    ggml_fp16_t d;
+    ggml_fp16_t m;
+    // We need only 5 bits for the in-block index, so 16 bits is overkill.
+    // TODO Optimize if possible
+    uint16_t outlier_index;
+    ggml_fp16_t outlier_value;
+    // Nibbles / quants.
+    uint8_t qs[QK / 2];
+} block_q4_1_o;
+static_assert(sizeof(block_q4_1_o) == 8 + QK / 2, "wrong q4_1_o block size/padding");
+
 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
     assert(k % QK == 0);
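Both size asserts come out to 24 bytes per block when QK == 32: Q4_1 spends 8 bytes on two float32 fields (d, m) plus 16 bytes of nibbles, while Q4_1_O packs four 2-byte fields (d, m, outlier_index, outlier_value) into the same 8-byte budget by storing d and m as FP16. The outlier bookkeeping is therefore free in terms of model size:

// Layout arithmetic, assuming QK == 32 and a 2-byte ggml_fp16_t:
//   block_q4_1:   2 * 4 (float d, m)               + 32 / 2 = 24 bytes
//   block_q4_1_o: 4 * 2 (fp16 d, m, index, value)  + 32 / 2 = 24 bytes
// Both formats cost 24 bytes per 32 weights, i.e. 6 bits per weight.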
@@ -1118,6 +1131,208 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
 #endif
 }
 
+// Q4_1_O
+
+static inline void quantize_row_q4_1_o_reference_single_block(const float * restrict x, block_q4_1_o * restrict block) {
+    // An outlier is just the absmax element in the block.
+    // We store it separately and do not quantize it.
+    int outlier_index = -1;
+    float outlier_value = 0.0F;
+
+    for (int l = 0; l < QK; l++) {
+        const float v = x[l];
+
+        if (fabsf(v) > fabsf(outlier_value)) {
+            outlier_index = l;
+            outlier_value = v;
+        }
+    }
+
+    block->outlier_index = outlier_index;
+    block->outlier_value = GGML_COMPUTE_FP32_TO_FP16(outlier_value);
+
+    float min = FLT_MAX;
+    float max = -FLT_MAX;
+
+    for (int l = 0; l < QK; l++) {
+        if (l == outlier_index) {
+            // Ignore outlier when computing range.
+            continue;
+        }
+
+        const float v = x[l];
+        if (v < min) min = v;
+        if (v > max) max = v;
+    }
+
+    const float d = (max - min) / ((1 << 4) - 1);
+    const float id = d ? 1.0F / d : 0.0F;
+
+    block->d = GGML_COMPUTE_FP32_TO_FP16(d);
+    block->m = GGML_COMPUTE_FP32_TO_FP16(min);
+
+    uint8_t pp[QK / 2];
+
+    for (int l = 0; l < QK; l += 2) {
+        float v0 = (x[l + 0] - min) * id;
+        float v1 = (x[l + 1] - min) * id;
+
+        // Write some garbage but valid index for the outlier.
+        if (l + 0 == outlier_index) v0 = 0.0;
+        if (l + 1 == outlier_index) v1 = 0.0;
+
+        const uint8_t vi0 = roundf(v0);
+        const uint8_t vi1 = roundf(v1);
+
+        assert(vi0 >= 0 && vi0 < 16);
+        assert(vi1 >= 0 && vi1 < 16);
+
+        pp[l/2] = vi0 | (vi1 << 4);
+    }
+
+    memcpy(block->qs, pp, sizeof(pp));
+}
+
+static inline void dequantize_row_q4_1_o_reference_single_block(block_q4_1_o * restrict block, float * restrict y) {
+    const float d = ggml_half_to_float_reference(block->d);
+    const float m = ggml_half_to_float_reference(block->m);
+
+    const uint8_t * restrict pp = block->qs;
+
+    for (int l = 0; l < QK; l += 2) {
+        const uint8_t vi = pp[l / 2];
+
+        const int8_t vi0 = vi & 0xF;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = vi0 * d + m;
+        const float v1 = vi1 * d + m;
+
+        y[l + 0] = v0;
+        y[l + 1] = v1;
+
+        assert(!isnan(y[l + 0]));
+        assert(!isnan(y[l + 1]));
+    }
+
+    // Restore the outlier
+    y[block->outlier_index] = ggml_half_to_float_reference(block->outlier_value);
+}
+
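Taken together, the pair above round-trips a block with the outlier preserved at FP16 precision, while every other value is snapped to a 16-level grid between the block's own min and max. A hypothetical check, assuming QK == 32:

// Sketch, not part of the commit: round-trip one block through the
// reference quantizer and dequantizer.
void test_q4_1_o_round_trip(void) {
    float x[QK];
    float y[QK];
    block_q4_1_o block;

    for (int l = 0; l < QK; l++) x[l] = 0.01F * l; // well-behaved values
    x[7] = -100.0F;                                // a single large outlier

    quantize_row_q4_1_o_reference_single_block(x, &block);
    dequantize_row_q4_1_o_reference_single_block(&block, y);

    assert(block.outlier_index == 7);
    assert(fabsf(y[7] - x[7]) < 0.1F);      // outlier kept (exact in FP16 here)
    for (int l = 0; l < QK; l++) {
        if (l == 7) continue;
        assert(fabsf(y[l] - x[l]) < 0.02F); // within one 4-bit quantization step
    }
}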
+static void quantize_row_q4_1_o_reference(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    block_q4_1_o * restrict y = vy;
+
+    for (int i = 0; i < nb; i++) {
+        quantize_row_q4_1_o_reference_single_block(x + i * QK, y + i);
+    }
+}
+
+static void quantize_row_q4_1_o(const float * restrict x, void * restrict vy, int k) {
+    quantize_row_q4_1_o_reference(x, vy, k);
+}
+
+static void dequantize_row_q4_1_o(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    const block_q4_1_o * restrict x = vx;
+
+#if defined(__AVX2__)
+    for (int i = 0; i < nb; i++) {
+        const float x_d = ggml_half_to_float_reference(x[i].d);
+        const float x_m = ggml_half_to_float_reference(x[i].m);
+
+        const __m256 d_v = _mm256_broadcast_ss(&x_d);
+        const __m256 d_m = _mm256_broadcast_ss(&x_m);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK; l += 32) {
+            // Load 32x4-bit integers into 32x8-bit integers
+            __m256i vx8 = bytesFromNibbles(pp+l/2);
+
+            // Convert to 16-bit int
+            const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
+            const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
+
+            // Convert to 32-bit int -> float 32
+            const __m256 vf[4] = {
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
+            };
+
+            // Scale, add m and store
+            for (int j = 0; j < 4; j++) {
+                const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
+                _mm256_storeu_ps(y + i * QK + l + j*8, result);
+            }
+        }
+
+        // Restore the outlier
+        y[i * QK + x[i].outlier_index] = ggml_half_to_float_reference(x[i].outlier_value);
+    }
+#elif defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        const float x_d = ggml_half_to_float_reference(x[i].d);
+        const float x_m = ggml_half_to_float_reference(x[i].m);
+
+        const float32x4_t vd = vdupq_n_f32(x_d);
+        const float32x4_t vm = vdupq_n_f32(x_m);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK; l += 16) {
+            // Load 16x4-bit integers into 8x8-bit integers
+            const uint8x8_t v8 = vld1_u8(pp + l/2);
+
+            // Expand 4-bit qs to 8-bit bytes
+            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+            const uint8x8_t v1 = vshr_n_u8(v8, 4);
+
+            // Interleave and combine
+            const uint8x8_t vx_0 = vzip1_u8(v0, v1);
+            const uint8x8_t vx_1 = vzip2_u8(v0, v1);
+
+            const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
+
+            // convert to 2x uint16x8_t
+            const uint16x8_t vi_0 = vmovl_s8(vget_low_u8 (vq));
+            const uint16x8_t vi_1 = vmovl_s8(vget_high_u8(vq));
+
+            // convert to 4x float32x4_t
+            const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
+            const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0)));
+            const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1)));
+            const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1)));
+
+            // multiply by d and add m
+            const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd);
+            const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd);
+            const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd);
+            const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
+
+            // Store
+            vst1q_f32(y + i*QK + l +  0, r0);
+            vst1q_f32(y + i*QK + l +  4, r1);
+            vst1q_f32(y + i*QK + l +  8, r2);
+            vst1q_f32(y + i*QK + l + 12, r3);
+        }
+
+        // Restore the outlier
+        y[i * QK + x[i].outlier_index] = ggml_half_to_float_reference(x[i].outlier_value);
+    }
+#else
+    for (int i = 0; i < nb; i++) {
+        dequantize_row_q4_1_o_reference_single_block(x + i, y + i * QK);
+    }
+#endif
+}
+
 //
 // simd mappings
 //
@@ -2437,6 +2652,7 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
 //
 
 static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     QK,
     QK,
+    QK,
     1,
@@ -2446,11 +2662,12 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     1,
 };
 
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_COUNT != 8");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     sizeof(block_q4_0),
     sizeof(block_q4_1),
+    sizeof(block_q4_1_o),
     sizeof(int8_t ),
     sizeof(int16_t),
     sizeof(int32_t),
@@ -2459,7 +2676,7 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
 };
 
 // don't forget to update the array above when adding new types
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_COUNT != 8");
 
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
@@ -3196,6 +3413,10 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
             {
                 GGML_ASSERT(false);
             } break;
+        case GGML_TYPE_Q4_1_O:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 assert(tensor->nb[0] == sizeof(int8_t));

@@ -3256,6 +3477,10 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
             {
                 GGML_ASSERT(false);
             } break;
+        case GGML_TYPE_Q4_1_O:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
            {
                assert(tensor->nb[0] == sizeof(int8_t));

@@ -3310,6 +3535,10 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
             {
                 GGML_ASSERT(false);
             } break;
+        case GGML_TYPE_Q4_1_O:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));

@@ -3354,6 +3583,10 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
             {
                 GGML_ASSERT(false);
             } break;
+        case GGML_TYPE_Q4_1_O:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));

@@ -3396,6 +3629,10 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
             {
                 GGML_ASSERT(false);
             } break;
+        case GGML_TYPE_Q4_1_O:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));

@@ -3440,6 +3677,10 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
             {
                 GGML_ASSERT(false);
             } break;
+        case GGML_TYPE_Q4_1_O:
+            {
+                GGML_ASSERT(false);
+            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -4990,6 +5231,7 @@ static void ggml_compute_forward_dup(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5067,6 +5309,7 @@ static void ggml_compute_forward_add(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5119,6 +5362,7 @@ static void ggml_compute_forward_sub(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5171,6 +5415,7 @@ static void ggml_compute_forward_mul(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5223,6 +5468,7 @@ static void ggml_compute_forward_div(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5271,6 +5517,7 @@ static void ggml_compute_forward_sqr(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5319,6 +5566,7 @@ static void ggml_compute_forward_sqrt(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5377,6 +5625,7 @@ static void ggml_compute_forward_sum(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5454,6 +5703,7 @@ static void ggml_compute_forward_mean(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5518,6 +5768,7 @@ static void ggml_compute_forward_repeat(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5566,6 +5817,7 @@ static void ggml_compute_forward_abs(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5614,6 +5866,7 @@ static void ggml_compute_forward_sgn(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5662,6 +5915,7 @@ static void ggml_compute_forward_neg(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5710,6 +5964,7 @@ static void ggml_compute_forward_exp(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5758,6 +6013,7 @@ static void ggml_compute_forward_1_minus_x(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5810,6 +6066,7 @@ static void ggml_compute_forward_max(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5858,6 +6115,7 @@ static void ggml_compute_forward_step(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5906,6 +6164,7 @@ static void ggml_compute_forward_relu(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -5971,6 +6230,7 @@ static void ggml_compute_forward_gelu(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -6021,6 +6281,7 @@ static void ggml_compute_forward_sigmoid(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -6086,6 +6347,7 @@ static void ggml_compute_forward_silu(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -6172,6 +6434,7 @@ static void ggml_compute_forward_norm(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -6252,6 +6515,7 @@ static void ggml_compute_forward_rms_norm(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:
@@ -6669,6 +6933,11 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .quantize_row_q   = quantize_row_q4_1,
         .vec_dot_q        = ggml_vec_dot_q4_1,
     },
+    [GGML_TYPE_Q4_1_O] = {
+        .dequantize_row_q = dequantize_row_q4_1_o,
+        .quantize_row_q   = quantize_row_q4_1_o,
+        .vec_dot_q        = NULL,
+    },
 };
 
 static void ggml_compute_forward_mul_mat_q_f32(
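Note that .vec_dot_q is NULL here: unlike Q4_0 and Q4_1, the new format has no quantized dot product, so the generic ggml_compute_forward_mul_mat_q_f32 path cannot serve it. Code that consults quantize_fns has to guard for this; a sketch of the pattern (the actual dispatch added by this commit lives in ggml_compute_forward_mul_mat and ggml_graph_compute below):

// Sketch: pick a matmul path depending on vec_dot_q availability.
if (quantize_fns[type].vec_dot_q != NULL) {
    // generic quantized path: dot products directly on the quantized blocks
} else {
    // dedicated FP32 path, e.g. ggml_compute_forward_mul_mat_q4_1_o_f32
}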
@@ -6859,6 +7128,273 @@ static void ggml_compute_forward_mul_mat_q_f32(
     //}
 }
 
+static void ggml_compute_forward_mul_mat_q4_1_o_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    const int ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    const enum ggml_type type = src0->type;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    // compute by src0 rows
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+        if (params->ith != 0) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_INIT) {
+            return;
+        }
+
+        if (params->type == GGML_TASK_FINALIZE) {
+            return;
+        }
+
+        float * const wdata = params->wdata;
+
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                {
+                    size_t id = 0;
+                    for (int i01 = 0; i01 < ne01; ++i01) {
+                        dequantize_row_q4_1_o((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
+                        id += ne00;
+                    }
+                }
+
+                const float * x = wdata;
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // zT = y * xT
+                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne11, ne01, ne10,
+                        1.0f, y, ne10,
+                        x, ne10,
+                        0.0f, d, ne01);
+            }
+        }
+
+        //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
+
+        return;
+    }
+#endif
+
+    if (params->type == GGML_TASK_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by src0 rows using ggml_vec_dot_f32
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+#if defined(__AVX2__)
+    float outlier_mask[QK];
+    memset(outlier_mask, 0, QK * sizeof(float));
+#endif
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+#if defined(__AVX2__)
+        for (int ic = 0; ic < ne11; ++ic) {
+            // src1 indices
+            const int i13 = i03;
+            const int i12 = i02;
+            const int i11 = ic;
+
+            // dst indices
+            const int i0 = i01;
+            const int i1 = i11;
+            const int i2 = i02;
+            const int i3 = i03;
+
+            const int block_count = ne00 / QK;
+
+            const block_q4_1_o * row_blocks = (block_q4_1_o *) ((char *) src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03));
+
+            __m256 accum = _mm256_setzero_ps();
+
+            // Here we do fused dequantization and dot product.
+            for (int block_index = 0; block_index < block_count; block_index++) {
+                const float block_d = ggml_half_to_float_reference(row_blocks[block_index].d);
+                const float block_m = ggml_half_to_float_reference(row_blocks[block_index].m);
+
+                // 0 .. 31
+                const uint16_t outlier_index = row_blocks[block_index].outlier_index;
+                const float outlier_value = ggml_half_to_float_reference(row_blocks[block_index].outlier_value);
+
+                const uint8_t * restrict quant_nibbles = row_blocks[block_index].qs;
+
+                // ---
+
+                // Broadcast values to 8x element float32 vectors
+                const __m256 broadcasted_d = _mm256_broadcast_ss(&block_d);
+                const __m256 broadcasted_m = _mm256_broadcast_ss(&block_m);
+                const __m256 broadcasted_outlier_value = _mm256_broadcast_ss(&outlier_value);
+
+                // Load 32x4-bit integers into 32x8-bit integers
+                const __m256i quant_bytes = bytesFromNibbles(quant_nibbles);
+
+                // Convert to 16-bit int
+                const __m256i quant_shorts_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(quant_bytes, 0));
+                const __m256i quant_shorts_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(quant_bytes, 1));
+
+                // Convert to 32-bit int and then to 32-bit float
+                const __m256 quant_floats_0 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_lo, 0)));
+                const __m256 quant_floats_1 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_lo, 1)));
+                const __m256 quant_floats_2 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_hi, 0)));
+                const __m256 quant_floats_3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_hi, 1)));
+
+                // Dequantize to ~original weights
+                const __m256 weight_0 = _mm256_fmadd_ps(quant_floats_0, broadcasted_d, broadcasted_m);
+                const __m256 weight_1 = _mm256_fmadd_ps(quant_floats_1, broadcasted_d, broadcasted_m);
+                const __m256 weight_2 = _mm256_fmadd_ps(quant_floats_2, broadcasted_d, broadcasted_m);
+                const __m256 weight_3 = _mm256_fmadd_ps(quant_floats_3, broadcasted_d, broadcasted_m);
+
+                // TODO This outlier handling is VERY slow
+                // Set outlier mask -- this should give 1 in the most significant bit
+                outlier_mask[outlier_index] = -1.0F;
+                // Load mask into vectors
+                const __m256 outlier_mask_0 = _mm256_load_ps(outlier_mask);
+                const __m256 outlier_mask_1 = _mm256_load_ps(outlier_mask + 8);
+                const __m256 outlier_mask_2 = _mm256_load_ps(outlier_mask + 16);
+                const __m256 outlier_mask_3 = _mm256_load_ps(outlier_mask + 24);
+                // Reset mask array to all zeroes for the next iteration
+                outlier_mask[outlier_index] = 0.0F;
+
+                // Replace the weight at the index of the outlier
+                const __m256 weight_0_with_outlier = _mm256_blendv_ps(weight_0, broadcasted_outlier_value, outlier_mask_0);
+                const __m256 weight_1_with_outlier = _mm256_blendv_ps(weight_1, broadcasted_outlier_value, outlier_mask_1);
+                const __m256 weight_2_with_outlier = _mm256_blendv_ps(weight_2, broadcasted_outlier_value, outlier_mask_2);
+                const __m256 weight_3_with_outlier = _mm256_blendv_ps(weight_3, broadcasted_outlier_value, outlier_mask_3);
+
+                // Load 32 floats of data of the second argument
+                const float * src1_data = (float *) ((char *) src1->data + (block_index * QK * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13));
+
+                const __m256 src1_0 = _mm256_load_ps(src1_data);
+                const __m256 src1_1 = _mm256_load_ps(src1_data + 8);
+                const __m256 src1_2 = _mm256_load_ps(src1_data + 16);
+                const __m256 src1_3 = _mm256_load_ps(src1_data + 24);
+
+                // Multiply weights and values of the second argument element-wise; add to accumulator
+                accum = _mm256_fmadd_ps(src1_0, weight_0_with_outlier, accum);
+                accum = _mm256_fmadd_ps(src1_1, weight_1_with_outlier, accum);
+                accum = _mm256_fmadd_ps(src1_2, weight_2_with_outlier, accum);
+                accum = _mm256_fmadd_ps(src1_3, weight_3_with_outlier, accum);
+            }
+
+            // Add elements of accumulator
+            __m128 res = _mm256_extractf128_ps(accum, 1);
+            res = _mm_add_ps(res, _mm256_castps256_ps128(accum));
+            res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+            res = _mm_add_ss(res, _mm_movehdup_ps(res));
+
+            *((float *) ((char *) dst->data + (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3))) = _mm_cvtss_f32(res);
+        }
+#else
+        float * const wdata = (float *) ((char *) params->wdata + (i01 * nb01 + i02 * nb02 + i03 * nb03));
+
+        dequantize_row_q4_1_o((char *) src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03), wdata, ne00);
+
+        for (int ic = 0; ic < ne11; ++ic) {
+            // src1 indices
+            const int i13 = i03;
+            const int i12 = i02;
+            const int i11 = ic;
+
+            // dst indices
+            const int i0 = i01;
+            const int i1 = i11;
+            const int i2 = i02;
+            const int i3 = i03;
+
+            ggml_vec_dot_f32(
+                    ne00,
+                    (float *) ((char *) dst->data + (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3)),
+                    wdata,
+                    (float *) ((char *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13))
+            );
+        }
+#endif
+    }
+}
+
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
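Two details of the AVX2 path above are worth spelling out. First, the outlier replacement relies on _mm256_blendv_ps selecting lanes by the sign bit of the mask: writing -1.0F into outlier_mask[outlier_index] sets that lane's sign bit, so exactly one of the 32 dequantized weights is swapped for the stored outlier value. Second, the four-instruction tail reduces the 8-lane accumulator to a scalar. In scalar terms (a sketch, assuming the same names as above; accum_lanes stands for the 8 floats held in accum):

// Per-lane meaning of the blendv-based outlier replacement:
for (int lane = 0; lane < QK; lane++) {
    weights[lane] = (lane == outlier_index) ? outlier_value : weights[lane];
}
// Meaning of the extractf128/movehl/movehdup horizontal-sum tail:
float sum = 0.0F;
for (int lane = 0; lane < 8; lane++) {
    sum += accum_lanes[lane];
}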
@@ -6870,6 +7406,10 @@ static void ggml_compute_forward_mul_mat(
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_Q4_1_O:
+            {
+                ggml_compute_forward_mul_mat_q4_1_o_f32(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_F16:
             {
                 ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst);
@@ -6965,6 +7505,7 @@ static void ggml_compute_forward_scale(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -7121,6 +7662,7 @@ static void ggml_compute_forward_get_rows(
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
            {
                ggml_compute_forward_get_rows_q(params, src0, src1, dst);
            } break;

@@ -7210,6 +7752,7 @@ static void ggml_compute_forward_diag_mask_inf(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -7304,6 +7847,7 @@ static void ggml_compute_forward_soft_max(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -7446,6 +7990,7 @@ static void ggml_compute_forward_rope(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -7714,6 +8259,7 @@ static void ggml_compute_forward_conv_1d_1s(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -7982,6 +8528,7 @@ static void ggml_compute_forward_conv_1d_2s(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -8467,6 +9014,7 @@ static void ggml_compute_forward_flash_attn(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -8678,6 +9226,7 @@ static void ggml_compute_forward_flash_ff(
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1_O:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:
@@ -9508,6 +10057,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
 #endif
                     } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                         cur = 0;
+                    } else if (node->src0->type == GGML_TYPE_Q4_1_O && node->src1->type == GGML_TYPE_F32) {
+#if defined(__AVX2__)
+                        cur = 0;
+#else
+                        // Assuming that src1 is a vector
+                        // TODO Not sure whether this is correct
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * ggml_nelements(node->src1);
+#endif
                     } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
                         if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
@@ -10729,6 +11286,29 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
     return (n/QK*sizeof(block_q4_1));
 }
 
+size_t ggml_quantize_q4_1_o(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    for (int j = 0; j < n; j += k) {
+        block_q4_1_o * restrict y = (block_q4_1_o *) dst + j / QK;
+
+        quantize_row_q4_1_o_reference(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            for (int l = 0; l < QK; l += 2) {
+                const uint8_t vi0 = y[i].qs[l / 2] & 0xF;
+                const uint8_t vi1 = y[i].qs[l / 2] >> 4;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n / QK * sizeof(block_q4_1_o));
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 int ggml_cpu_has_avx(void) {
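A hypothetical caller of the function above, assuming an n-element FP32 source whose rows are k elements long (as with the other ggml_quantize_* functions):

// Sketch, not part of the commit.
int64_t hist[16] = {0};  // counts of each of the 16 nibble values
size_t bytes_written = ggml_quantize_q4_1_o(src, dst, n, k, hist);
// bytes_written == n / QK * sizeof(block_q4_1_o), i.e. 24 bytes per 32 weights.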
ggml.h | 5
@@ -226,7 +226,11 @@ struct ggml_context;
 
 enum ggml_type {
     GGML_TYPE_Q4_0,
+    // Stores min and delta per block, does quantized matmul.
     GGML_TYPE_Q4_1,
+    // Same as Q4_1, but stores outliers separately, and matmul is done in FP32.
+    // An outlier is the single absmax element in the quantized block.
+    GGML_TYPE_Q4_1_O,
    GGML_TYPE_I8,
    GGML_TYPE_I16,
    GGML_TYPE_I32,

@@ -807,6 +811,7 @@ enum ggml_opt_result ggml_opt(
 
 size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
 size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
+size_t ggml_quantize_q4_1_o(const float * src, void * dst, int n, int k, int64_t * hist);
 
 //
 // system info
rwkv.cpp | 29
@@ -160,7 +160,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_threads) {
         model->data_type == 0 ||
         model->data_type == 1 ||
         model->data_type == 2 ||
-        model->data_type == 3,
+        model->data_type == 3 ||
+        model->data_type == 4,
         "Unsupported model data type %d",
         model->data_type
     );

@@ -216,7 +217,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_threads) {
         data_type == 0 ||
         data_type == 1 ||
         data_type == 2 ||
-        data_type == 3,
+        data_type == 3 ||
+        data_type == 4,
         "Unsupported parameter data type %d",
         data_type
     );

@@ -228,6 +230,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_threads) {
         case 1: ggml_data_type = GGML_TYPE_F16; break;
         case 2: ggml_data_type = GGML_TYPE_Q4_0; break;
         case 3: ggml_data_type = GGML_TYPE_Q4_1; break;
+        case 4: ggml_data_type = GGML_TYPE_Q4_1_O; break;
         default: return NULL;
     }
@@ -553,18 +556,17 @@ void rwkv_free(struct rwkv_context * ctx) {
 }
 
 bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
-    RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type);
+    RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type);
 
     ggml_type type;
 
     switch (q_type) {
         case 2: type = GGML_TYPE_Q4_0; break;
         case 3: type = GGML_TYPE_Q4_1; break;
+        case 4: type = GGML_TYPE_Q4_1_O; break;
         default: return false;
     };
 
-    RWKV_ASSERT_FALSE(type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1, "Unsupported data type %d", type);
-
     printf("Loading model from '%s'\n", model_file_path_in);
 
     auto finp = std::ifstream(model_file_path_in, std::ios::binary);
@@ -646,7 +648,8 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
         "f32",
         "f16",
         "q4_0",
-        "q4_1"
+        "q4_1",
+        "q4_1_o"
     };
     printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);

@@ -655,6 +658,7 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
         4.0F,
         2.0F,
         20.0F / 32.0F,
         24.0F / 32.0F,
+        24.0F / 32.0F
     };
     total_size_orig += (size_t) (nelements * parameter_data_type_size[parameter_data_type]);
@@ -668,10 +672,11 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
         name != std::string("head.weight");
 
     if (quantize) {
-        if (parameter_data_type != 0 && parameter_data_type != 1) {
-            fprintf(stderr, "unsupported data type %d for integer quantization\n", parameter_data_type);
-            return false;
-        }
+        RWKV_ASSERT_FALSE(
+            parameter_data_type == 0 || parameter_data_type == 1,
+            "Unsupported parameter data type %d, only FP32 and FP16 can be quantized",
+            parameter_data_type
+        );
 
         if (parameter_data_type == 1) {
             data_f16.resize(nelements);
@@ -719,6 +724,10 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
             {
                 cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
            } break;
+        case GGML_TYPE_Q4_1_O:
+            {
+                cur_size = ggml_quantize_q4_1_o(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+            } break;
         default:
             {
                 fprintf(stderr, "unsupported quantization type %d\n", type);

@@ -37,7 +37,8 @@ def main() -> None:
     assert data_type == 0 or\
         data_type == 1 or\
         data_type == 2 or\
-        data_type == 3, f'Unsupported model data type {data_type}'
+        data_type == 3 or\
+        data_type == 4, f'Unsupported model data type {data_type}'
 
     if data_type == 0:
         # FP32, high precision
@@ -46,12 +47,14 @@ def main() -> None:
         # FP16, lower precision, so higher threshold
         threshold = 0.0032
     elif data_type == 2:
-        # INT4 quantized, even lower precision, so even higher threshold
-        # This threshold will let some bugs pass
-        threshold = 4.0
+        # Q4_0 quantized, even lower precision, so even higher threshold
+        threshold = 0.4
     elif data_type == 3:
-        # This format stores more data, so error would be lower
-        threshold = 1.2
+        # Q4_1
+        threshold = 1.21
+    elif data_type == 4:
+        # Q4_1_O
+        threshold = 0.2
 
     model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path)

quantize.py

@@ -1,5 +1,5 @@
-# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1.
-# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3
+# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1 or Q4_1_O (recommended).
+# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1_o.bin 4
 
 import argparse
 import rwkv_cpp_shared_library
@@ -8,12 +8,20 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
     parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
     parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
-    parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3)
+    parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0), 3 (GGML_TYPE_Q4_1) or 4 (GGML_TYPE_Q4_1_O)', type=int, choices=[2, 3, 4], default=4)
     return parser.parse_args()
 
 def main() -> None:
     args = parse_args()
 
+    if args.data_type == 2 or args.data_type == 3:
+        print()
+        print('WARNING!')
+        print('You are using Q4_0 or Q4_1 quantization; it will heavily degrade RWKV quality.')
+        print('For best quality preservation, it is recommended to use Q4_1_O.')
+        print('More info at https://github.com/saharNooby/rwkv.cpp/issues/12')
+        print()
+
     library = rwkv_cpp_shared_library.load_rwkv_shared_library()
 
     library.rwkv_quantize_model_file(

rwkv_cpp_model.py

@@ -118,5 +118,5 @@ class RWKVModel:
 
     def __del__(self):
         # Free the context on GC in case user forgot to call free() explicitly.
-        if self._valid:
+        if hasattr(self, '_valid') and self._valid:
             self.free()