Add Q4_1_O format
This commit is contained in:
		
							parent
							
								
									ec99bc1765
								
							
						
					
					
						commit
						c40941d9d0
					
				
							
								
								
									
										648
									
								
								ggml.c
								
								
								
								
							
							
						
						
									
										648
									
								
								ggml.c
								
								
								
								
							|  | @ -25,6 +25,35 @@ | ||||||
| #define static_assert(cond, msg) struct global_scope_noop_trick | #define static_assert(cond, msg) struct global_scope_noop_trick | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
// https://gist.github.com/rygorous/2144712
// Public domain, by Fabian "ryg" Giesen
// Converts a raw IEEE-754 half-precision bit pattern to a 32-bit float.
inline static float ggml_half_to_float_reference(uint16_t value) {
    union FP32 {
        uint32_t u;
        float f;
    };

    // Multiplier that rebiases the exponent from half (bias 15) to float (bias 127).
    const union FP32 exp_scale = { (254UL - 15UL) << 23 };
    // Any half Inf/NaN lands at or above this value after rescaling.
    const union FP32 inf_nan_threshold = { (127UL + 16UL) << 23 };

    union FP32 result;

    // Place the exponent/mantissa bits into float position.
    result.u = (uint32_t)(value & 0x7FFFU) << 13;
    // Adjust the exponent (also correctly renormalizes half denormals).
    result.f *= exp_scale.f;

    // Force the maximum exponent so Inf/NaN inputs stay Inf/NaN.
    if (result.f >= inf_nan_threshold.f) {
        result.u |= 255UL << 23;
    }

    // Transfer the sign bit.
    result.u |= (uint32_t)(value & 0x8000U) << 16;

    return result.f;
}
|  | 
 | ||||||
| #if defined _MSC_VER || defined(__MINGW32__) | #if defined _MSC_VER || defined(__MINGW32__) | ||||||
| 
 | 
 | ||||||
| #if !defined(__MINGW32__) | #if !defined(__MINGW32__) | ||||||
|  | @ -326,42 +355,13 @@ static float table_f32_f16[1 << 16]; | ||||||
| // This is also true for POWER9.
 | // This is also true for POWER9.
 | ||||||
| #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) | #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) | ||||||
| 
 | 
 | ||||||
| // https://gist.github.com/rygorous/2144712
 |  | ||||||
| // Public domain, by Fabian "ryg" Giesen
 |  | ||||||
| inline static float ggml_half_to_float(uint16_t value) { |  | ||||||
|     union FP32 { |  | ||||||
|         uint32_t u; |  | ||||||
|         float f; |  | ||||||
|     }; |  | ||||||
| 
 | 
 | ||||||
|     const union FP32 magic = { (254UL - 15UL) << 23 }; |  | ||||||
|     const union FP32 was_inf_nan = { (127UL + 16UL) << 23 }; |  | ||||||
| 
 |  | ||||||
|     union FP32 out; |  | ||||||
| 
 |  | ||||||
|     // Exponent/mantissa bits
 |  | ||||||
|     out.u = (value & 0x7FFFU) << 13; |  | ||||||
|     // Exponent adjust
 |  | ||||||
|     out.f *= magic.f; |  | ||||||
| 
 |  | ||||||
|     // Make sure Inf/NaN survive
 |  | ||||||
|     if (out.f >= was_inf_nan.f) { |  | ||||||
|         out.u |= 255UL << 23; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // Sign bit
 |  | ||||||
|     out.u |= (value & 0x8000UL) << 16; |  | ||||||
| 
 |  | ||||||
|     return out.f; |  | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { | inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { | ||||||
|     // For some reason, lookup table does not work on my machine:
 |     // For some reason, lookup table does not work on my machine.
 | ||||||
|     // - Windows SDK version 10.0.19041.0
 |     // Replaced lookup with working reference code.
 | ||||||
|     // - CMAKE_SYSTEM_PROCESSOR: AMD64
 |  | ||||||
|     // Replaced lookup with some conversion code found online.
 |  | ||||||
|     // TODO This must be properly debugged and fixed
 |     // TODO This must be properly debugged and fixed
 | ||||||
|     return ggml_half_to_float(f); |     return ggml_half_to_float_reference(f); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) | #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) | ||||||
|  | @ -514,6 +514,19 @@ typedef struct { | ||||||
| } block_q4_1; | } block_q4_1; | ||||||
| static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding"); | static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding"); | ||||||
| 
 | 
 | ||||||
|  | // Method 4 with better outlier handling.
 | ||||||
// Method 4 with better outlier handling.
// Like Q4_1 (scale + min, 4-bit quants), but the single absmax element of the
// block is stored separately in fp16 and excluded from quantization.
typedef struct {
    ggml_fp16_t d; // scale (delta)
    ggml_fp16_t m; // minimum of the non-outlier elements
    // We need only 5 bits for the in-block index, so 16 bits is overkill.
    // TODO Optimize if possible
    uint16_t outlier_index;    // position of the outlier within the block
    ggml_fp16_t outlier_value; // unquantized value of the outlier
    // Nibbles / quants.
    uint8_t qs[QK / 2];
} block_q4_1_o;
static_assert(sizeof(block_q4_1_o) == 8 + QK / 2, "wrong q4_1_o block size/padding");
|  | 
 | ||||||
| // reference implementation for deterministic creation of model files
 | // reference implementation for deterministic creation of model files
 | ||||||
| static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { | static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { | ||||||
|     assert(k % QK == 0); |     assert(k % QK == 0); | ||||||
|  | @ -1118,6 +1131,208 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in | ||||||
| #endif | #endif | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Q4_1_O
 | ||||||
|  | 
 | ||||||
|  | static inline void quantize_row_q4_1_o_reference_single_block(const float * restrict x, block_q4_1_o * restrict block) { | ||||||
|  |     // An outlier is just the absmax element in the block.
 | ||||||
|  |     // We store it separately and do not quantize it.
 | ||||||
|  |     int outlier_index = -1; | ||||||
|  |     float outlier_value = 0.0F; | ||||||
|  | 
 | ||||||
|  |     for (int l = 0; l < QK; l++) { | ||||||
|  |         const float v = x[l]; | ||||||
|  | 
 | ||||||
|  |         if (fabsf(v) > fabsf(outlier_value)) { | ||||||
|  |             outlier_index = l; | ||||||
|  |             outlier_value = v; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     block->outlier_index = outlier_index; | ||||||
|  |     block->outlier_value = GGML_COMPUTE_FP32_TO_FP16(outlier_value); | ||||||
|  | 
 | ||||||
|  |     float min = FLT_MAX; | ||||||
|  |     float max = -FLT_MAX; | ||||||
|  | 
 | ||||||
|  |     for (int l = 0; l < QK; l++) { | ||||||
|  |         if (l == outlier_index) { | ||||||
|  |             // Ignore outlier when computing range.
 | ||||||
|  |             continue; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         const float v = x[l]; | ||||||
|  |         if (v < min) min = v; | ||||||
|  |         if (v > max) max = v; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     const float d = (max - min) / ((1 << 4) - 1); | ||||||
|  |     const float id = d ? 1.0F / d : 0.0F; | ||||||
|  | 
 | ||||||
|  |     block->d = GGML_COMPUTE_FP32_TO_FP16(d); | ||||||
|  |     block->m = GGML_COMPUTE_FP32_TO_FP16(min); | ||||||
|  | 
 | ||||||
|  |     uint8_t pp[QK / 2]; | ||||||
|  | 
 | ||||||
|  |     for (int l = 0; l < QK; l += 2) { | ||||||
|  |         float v0 = (x[l + 0] - min) * id; | ||||||
|  |         float v1 = (x[l + 1] - min) * id; | ||||||
|  | 
 | ||||||
|  |         // Write some garbage but valid index for the outlier.
 | ||||||
|  |         if (l + 0 == outlier_index) v0 = 0.0; | ||||||
|  |         if (l + 1 == outlier_index) v1 = 0.0; | ||||||
|  | 
 | ||||||
|  |         const uint8_t vi0 = roundf(v0); | ||||||
|  |         const uint8_t vi1 = roundf(v1); | ||||||
|  | 
 | ||||||
|  |         assert(vi0 >= 0 && vi0 < 16); | ||||||
|  |         assert(vi1 >= 0 && vi1 < 16); | ||||||
|  | 
 | ||||||
|  |         pp[l/2] = vi0 | (vi1 << 4); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     memcpy(block->qs, pp, sizeof(pp)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | static inline void dequantize_row_q4_1_o_reference_single_block(block_q4_1_o * restrict block, float * restrict y) { | ||||||
|  |     const float d = ggml_half_to_float_reference(block->d); | ||||||
|  |     const float m = ggml_half_to_float_reference(block->m); | ||||||
|  | 
 | ||||||
|  |     const uint8_t * restrict pp = block->qs; | ||||||
|  | 
 | ||||||
|  |     for (int l = 0; l < QK; l += 2) { | ||||||
|  |         const uint8_t vi = pp[l / 2]; | ||||||
|  | 
 | ||||||
|  |         const int8_t vi0 = vi & 0xF; | ||||||
|  |         const int8_t vi1 = vi >> 4; | ||||||
|  | 
 | ||||||
|  |         const float v0 = vi0 * d + m; | ||||||
|  |         const float v1 = vi1 * d + m; | ||||||
|  | 
 | ||||||
|  |         y[l + 0] = v0; | ||||||
|  |         y[l + 1] = v1; | ||||||
|  | 
 | ||||||
|  |         assert(!isnan(y[l + 0])); | ||||||
|  |         assert(!isnan(y[l + 1])); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Restore the outlier
 | ||||||
|  |     y[block->outlier_index] = ggml_half_to_float_reference(block->outlier_value); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | static void quantize_row_q4_1_o_reference(const float * restrict x, void * restrict vy, int k) { | ||||||
|  |     assert(k % QK == 0); | ||||||
|  |     const int nb = k / QK; | ||||||
|  | 
 | ||||||
|  |     block_q4_1_o * restrict y = vy; | ||||||
|  | 
 | ||||||
|  |     for (int i = 0; i < nb; i++) { | ||||||
|  |         quantize_row_q4_1_o_reference_single_block(x + i * QK, y + i); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
// Q4_1_O quantization entry point; only the scalar reference path exists
// (no SIMD-optimized variant yet).
static void quantize_row_q4_1_o(const float * restrict x, void * restrict vy, int k) {
    quantize_row_q4_1_o_reference(x, vy, k);
}
|  | 
 | ||||||
|  | static void dequantize_row_q4_1_o(const void * restrict vx, float * restrict y, int k) { | ||||||
|  |     assert(k % QK == 0); | ||||||
|  |     const int nb = k / QK; | ||||||
|  | 
 | ||||||
|  |     const block_q4_1_o * restrict x = vx; | ||||||
|  | 
 | ||||||
|  | #if defined(__AVX2__) | ||||||
|  |     for (int i = 0; i < nb; i++) { | ||||||
|  |         const float x_d = ggml_half_to_float_reference(x[i].d); | ||||||
|  |         const float x_m = ggml_half_to_float_reference(x[i].m); | ||||||
|  | 
 | ||||||
|  |         const __m256 d_v = _mm256_broadcast_ss(&x_d); | ||||||
|  |         const __m256 d_m = _mm256_broadcast_ss(&x_m); | ||||||
|  | 
 | ||||||
|  |         const uint8_t * restrict pp = x[i].qs; | ||||||
|  | 
 | ||||||
|  |         for (int l = 0; l < QK; l += 32) { | ||||||
|  |             // Load 32x4-bit integers into 32x8-bit integers
 | ||||||
|  |             __m256i vx8 = bytesFromNibbles(pp+l/2); | ||||||
|  | 
 | ||||||
|  |             // Convert to 16-bit int
 | ||||||
|  |             const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0)); | ||||||
|  |             const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1)); | ||||||
|  | 
 | ||||||
|  |             // Convert to 32-bit int -> float 32
 | ||||||
|  |             const __m256 vf[4] = { | ||||||
|  |                 _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))), | ||||||
|  |                 _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))), | ||||||
|  |                 _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))), | ||||||
|  |                 _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1))) | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             // Scale, add m and store
 | ||||||
|  |             for (int j = 0; j < 4; j++) { | ||||||
|  |                 const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m); | ||||||
|  |                 _mm256_storeu_ps(y + i * QK + l + j*8, result); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Restore the outlier
 | ||||||
|  |         y[i * QK + x[i].outlier_index] = ggml_half_to_float_reference(x[i].outlier_value); | ||||||
|  |     } | ||||||
|  | #elif defined(__ARM_NEON) | ||||||
|  |     for (int i = 0; i < nb; i++) { | ||||||
|  |         const float x_d = ggml_half_to_float_reference(x[i].d); | ||||||
|  |         const float x_m = ggml_half_to_float_reference(x[i].m); | ||||||
|  | 
 | ||||||
|  |         const float32x4_t vd = vdupq_n_f32(x_d); | ||||||
|  |         const float32x4_t vm = vdupq_n_f32(x_m); | ||||||
|  | 
 | ||||||
|  |         const uint8_t * restrict pp = x[i].qs; | ||||||
|  | 
 | ||||||
|  |         for (int l = 0; l < QK; l += 16) { | ||||||
|  |             // Load 16x4-bit integers into 8x8-bit integers
 | ||||||
|  |             const uint8x8_t v8 = vld1_u8(pp + l/2); | ||||||
|  | 
 | ||||||
|  |             // Expand 4-bit qs to 8-bit bytes
 | ||||||
|  |             const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); | ||||||
|  |             const uint8x8_t v1 = vshr_n_u8(v8, 4); | ||||||
|  | 
 | ||||||
|  |             // Interleave and combine
 | ||||||
|  |             const uint8x8_t vx_0 = vzip1_u8(v0, v1); | ||||||
|  |             const uint8x8_t vx_1 = vzip2_u8(v0, v1); | ||||||
|  | 
 | ||||||
|  |             const uint8x16_t vq = vcombine_u8(vx_0, vx_1); | ||||||
|  | 
 | ||||||
|  |             // convert to 2x uint16x8_t
 | ||||||
|  |             const uint16x8_t vi_0 = vmovl_s8(vget_low_u8 (vq)); | ||||||
|  |             const uint16x8_t vi_1 = vmovl_s8(vget_high_u8(vq)); | ||||||
|  | 
 | ||||||
|  |             // convert to 4x float32x4_t
 | ||||||
|  |             const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0))); | ||||||
|  |             const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0))); | ||||||
|  |             const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1))); | ||||||
|  |             const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1))); | ||||||
|  | 
 | ||||||
|  |             // multiply by d and add m
 | ||||||
|  |             const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd); | ||||||
|  |             const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd); | ||||||
|  |             const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd); | ||||||
|  |             const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd); | ||||||
|  | 
 | ||||||
|  |             // Store
 | ||||||
|  |             vst1q_f32(y + i*QK + l +  0, r0); | ||||||
|  |             vst1q_f32(y + i*QK + l +  4, r1); | ||||||
|  |             vst1q_f32(y + i*QK + l +  8, r2); | ||||||
|  |             vst1q_f32(y + i*QK + l + 12, r3); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Restore the outlier
 | ||||||
|  |         y[i * QK + x[i].outlier_index] = ggml_half_to_float_reference(x[i].outlier_value); | ||||||
|  |     } | ||||||
|  | #else | ||||||
|  |     for (int i = 0; i < nb; i++) { | ||||||
|  |         dequantize_row_q4_1_o_reference_single_block(x + i, y + i * QK); | ||||||
|  |     } | ||||||
|  | #endif | ||||||
|  | } | ||||||
|  | 
 | ||||||
| //
 | //
 | ||||||
| // simd mappings
 | // simd mappings
 | ||||||
| //
 | //
 | ||||||
|  | @ -2437,6 +2652,7 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x | ||||||
| //
 | //
 | ||||||
| 
 | 
 | ||||||
| static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { | static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { | ||||||
|  |     QK, | ||||||
|     QK, |     QK, | ||||||
|     QK, |     QK, | ||||||
|     1, |     1, | ||||||
|  | @ -2446,11 +2662,12 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { | ||||||
|     1, |     1, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5"); | static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_COUNT != 8"); | ||||||
| 
 | 
 | ||||||
| static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { | static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { | ||||||
|     sizeof(block_q4_0), |     sizeof(block_q4_0), | ||||||
|     sizeof(block_q4_1), |     sizeof(block_q4_1), | ||||||
|  |     sizeof(block_q4_1_o), | ||||||
|     sizeof(int8_t ), |     sizeof(int8_t ), | ||||||
|     sizeof(int16_t), |     sizeof(int16_t), | ||||||
|     sizeof(int32_t), |     sizeof(int32_t), | ||||||
|  | @ -2459,7 +2676,7 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| // don't forget to update the array above when adding new types
 | // don't forget to update the array above when adding new types
 | ||||||
| static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5"); | static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_COUNT != 8"); | ||||||
| 
 | 
 | ||||||
| static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { | static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { | ||||||
|     "NONE", |     "NONE", | ||||||
|  | @ -3196,6 +3413,10 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { | ||||||
|             { |             { | ||||||
|                 GGML_ASSERT(false); |                 GGML_ASSERT(false); | ||||||
|             } break; |             } break; | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|  |             { | ||||||
|  |                 GGML_ASSERT(false); | ||||||
|  |             } break; | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|             { |             { | ||||||
|                 assert(tensor->nb[0] == sizeof(int8_t)); |                 assert(tensor->nb[0] == sizeof(int8_t)); | ||||||
|  | @ -3256,6 +3477,10 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { | ||||||
|             { |             { | ||||||
|                 GGML_ASSERT(false); |                 GGML_ASSERT(false); | ||||||
|             } break; |             } break; | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|  |             { | ||||||
|  |                 GGML_ASSERT(false); | ||||||
|  |             } break; | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|             { |             { | ||||||
|                 assert(tensor->nb[0] == sizeof(int8_t)); |                 assert(tensor->nb[0] == sizeof(int8_t)); | ||||||
|  | @ -3310,6 +3535,10 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { | ||||||
|             { |             { | ||||||
|                 GGML_ASSERT(false); |                 GGML_ASSERT(false); | ||||||
|             } break; |             } break; | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|  |             { | ||||||
|  |                 GGML_ASSERT(false); | ||||||
|  |             } break; | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|             { |             { | ||||||
|                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); |                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); | ||||||
|  | @ -3354,6 +3583,10 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { | ||||||
|             { |             { | ||||||
|                 GGML_ASSERT(false); |                 GGML_ASSERT(false); | ||||||
|             } break; |             } break; | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|  |             { | ||||||
|  |                 GGML_ASSERT(false); | ||||||
|  |             } break; | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|             { |             { | ||||||
|                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); |                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); | ||||||
|  | @ -3396,6 +3629,10 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { | ||||||
|             { |             { | ||||||
|                 GGML_ASSERT(false); |                 GGML_ASSERT(false); | ||||||
|             } break; |             } break; | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|  |             { | ||||||
|  |                 GGML_ASSERT(false); | ||||||
|  |             } break; | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|             { |             { | ||||||
|                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); |                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); | ||||||
|  | @ -3440,6 +3677,10 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { | ||||||
|             { |             { | ||||||
|                 GGML_ASSERT(false); |                 GGML_ASSERT(false); | ||||||
|             } break; |             } break; | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|  |             { | ||||||
|  |                 GGML_ASSERT(false); | ||||||
|  |             } break; | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|             { |             { | ||||||
|                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); |                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); | ||||||
|  | @ -4990,6 +5231,7 @@ static void ggml_compute_forward_dup( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5067,6 +5309,7 @@ static void ggml_compute_forward_add( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5119,6 +5362,7 @@ static void ggml_compute_forward_sub( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5171,6 +5415,7 @@ static void ggml_compute_forward_mul( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5223,6 +5468,7 @@ static void ggml_compute_forward_div( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5271,6 +5517,7 @@ static void ggml_compute_forward_sqr( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5319,6 +5566,7 @@ static void ggml_compute_forward_sqrt( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5377,6 +5625,7 @@ static void ggml_compute_forward_sum( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5454,6 +5703,7 @@ static void ggml_compute_forward_mean( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5518,6 +5768,7 @@ static void ggml_compute_forward_repeat( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5566,6 +5817,7 @@ static void ggml_compute_forward_abs( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5614,6 +5866,7 @@ static void ggml_compute_forward_sgn( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5662,6 +5915,7 @@ static void ggml_compute_forward_neg( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5710,6 +5964,7 @@ static void ggml_compute_forward_exp( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5758,6 +6013,7 @@ static void ggml_compute_forward_1_minus_x( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5810,6 +6066,7 @@ static void ggml_compute_forward_max( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5858,6 +6115,7 @@ static void ggml_compute_forward_step( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5906,6 +6164,7 @@ static void ggml_compute_forward_relu( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -5971,6 +6230,7 @@ static void ggml_compute_forward_gelu( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -6021,6 +6281,7 @@ static void ggml_compute_forward_sigmoid( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -6086,6 +6347,7 @@ static void ggml_compute_forward_silu( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -6172,6 +6434,7 @@ static void ggml_compute_forward_norm( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -6252,6 +6515,7 @@ static void ggml_compute_forward_rms_norm( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -6669,6 +6933,11 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { | ||||||
|         .quantize_row_q   = quantize_row_q4_1, |         .quantize_row_q   = quantize_row_q4_1, | ||||||
|         .vec_dot_q        = ggml_vec_dot_q4_1, |         .vec_dot_q        = ggml_vec_dot_q4_1, | ||||||
|     }, |     }, | ||||||
|  |     [GGML_TYPE_Q4_1_O] = { | ||||||
|  |         .dequantize_row_q = dequantize_row_q4_1_o, | ||||||
|  |         .quantize_row_q   = quantize_row_q4_1_o, | ||||||
|  |         .vec_dot_q        = NULL, | ||||||
|  |     }, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| static void ggml_compute_forward_mul_mat_q_f32( | static void ggml_compute_forward_mul_mat_q_f32( | ||||||
|  | @ -6859,6 +7128,273 @@ static void ggml_compute_forward_mul_mat_q_f32( | ||||||
|     //}
 |     //}
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | static void ggml_compute_forward_mul_mat_q4_1_o_f32( | ||||||
|  |         const struct ggml_compute_params * params, | ||||||
|  |         const struct ggml_tensor * src0, | ||||||
|  |         const struct ggml_tensor * src1, | ||||||
|  |               struct ggml_tensor * dst) { | ||||||
|  |     int64_t t0 = ggml_perf_time_us(); | ||||||
|  |     UNUSED(t0); | ||||||
|  | 
 | ||||||
|  |     const int ne00 = src0->ne[0]; | ||||||
|  |     const int ne01 = src0->ne[1]; | ||||||
|  |     const int ne02 = src0->ne[2]; | ||||||
|  |     const int ne03 = src0->ne[3]; | ||||||
|  | 
 | ||||||
|  |     const int ne10 = src1->ne[0]; | ||||||
|  |     const int ne11 = src1->ne[1]; | ||||||
|  |     const int ne12 = src1->ne[2]; | ||||||
|  |     const int ne13 = src1->ne[3]; | ||||||
|  | 
 | ||||||
|  |     const int ne0  = dst->ne[0]; | ||||||
|  |     const int ne1  = dst->ne[1]; | ||||||
|  |     const int ne2  = dst->ne[2]; | ||||||
|  |     const int ne3  = dst->ne[3]; | ||||||
|  | 
 | ||||||
|  |     const int nb00 = src0->nb[0]; | ||||||
|  |     const int nb01 = src0->nb[1]; | ||||||
|  |     const int nb02 = src0->nb[2]; | ||||||
|  |     const int nb03 = src0->nb[3]; | ||||||
|  | 
 | ||||||
|  |     const int nb10 = src1->nb[0]; | ||||||
|  |     const int nb11 = src1->nb[1]; | ||||||
|  |     const int nb12 = src1->nb[2]; | ||||||
|  |     const int nb13 = src1->nb[3]; | ||||||
|  | 
 | ||||||
|  |     const int nb0  = dst->nb[0]; | ||||||
|  |     const int nb1  = dst->nb[1]; | ||||||
|  |     const int nb2  = dst->nb[2]; | ||||||
|  |     const int nb3  = dst->nb[3]; | ||||||
|  | 
 | ||||||
|  |     const int ith = params->ith; | ||||||
|  |     const int nth = params->nth; | ||||||
|  | 
 | ||||||
|  |     GGML_ASSERT(ne02 == ne12); | ||||||
|  |     GGML_ASSERT(ne03 == ne13); | ||||||
|  |     GGML_ASSERT(ne2  == ne12); | ||||||
|  |     GGML_ASSERT(ne3  == ne13); | ||||||
|  | 
 | ||||||
|  |     const enum ggml_type type = src0->type; | ||||||
|  | 
 | ||||||
|  |     // we don't support permuted src0 or src1
 | ||||||
|  |     GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]); | ||||||
|  |     GGML_ASSERT(nb10 == sizeof(float)); | ||||||
|  | 
 | ||||||
|  |     // dst cannot be transposed or permuted
 | ||||||
|  |     GGML_ASSERT(nb0 == sizeof(float)); | ||||||
|  |     GGML_ASSERT(nb0 <= nb1); | ||||||
|  |     GGML_ASSERT(nb1 <= nb2); | ||||||
|  |     GGML_ASSERT(nb2 <= nb3); | ||||||
|  | 
 | ||||||
|  |     GGML_ASSERT(ne0 == ne01); | ||||||
|  |     GGML_ASSERT(ne1 == ne11); | ||||||
|  |     GGML_ASSERT(ne2 == ne02); | ||||||
|  |     GGML_ASSERT(ne3 == ne03); | ||||||
|  | 
 | ||||||
|  |     // nb01 >= nb00 - src0 is not transposed
 | ||||||
|  |     //   compute by src0 rows
 | ||||||
|  | 
 | ||||||
|  | #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) | ||||||
|  |     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { | ||||||
|  |         if (params->ith != 0) { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if (params->type == GGML_TASK_INIT) { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if (params->type == GGML_TASK_FINALIZE) { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         float * const wdata = params->wdata; | ||||||
|  | 
 | ||||||
|  |         for (int i03 = 0; i03 < ne03; i03++) { | ||||||
|  |             for (int i02 = 0; i02 < ne02; i02++) { | ||||||
|  |                 { | ||||||
|  |                     size_t id = 0; | ||||||
|  |                     for (int i01 = 0; i01 < ne01; ++i01) { | ||||||
|  |                         dequantize_row_q4_1_o((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00); | ||||||
|  |                         id += ne00; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 const float * x = wdata; | ||||||
|  |                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); | ||||||
|  | 
 | ||||||
|  |                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); | ||||||
|  | 
 | ||||||
|  |                 // zT = y * xT
 | ||||||
|  |                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, | ||||||
|  |                         ne11, ne01, ne10, | ||||||
|  |                         1.0f,    y, ne10, | ||||||
|  |                                  x, ne10, | ||||||
|  |                         0.0f,    d, ne01); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 | ||||||
|  | 
 | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
|  |     if (params->type == GGML_TASK_INIT) { | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (params->type == GGML_TASK_FINALIZE) { | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // parallelize by src0 rows using ggml_vec_dot_f32
 | ||||||
|  | 
 | ||||||
|  |     // total rows in src0
 | ||||||
|  |     const int nr = ne01*ne02*ne03; | ||||||
|  | 
 | ||||||
|  |     // rows per thread
 | ||||||
|  |     const int dr = (nr + nth - 1)/nth; | ||||||
|  | 
 | ||||||
|  |     // row range for this thread
 | ||||||
|  |     const int ir0 = dr*ith; | ||||||
|  |     const int ir1 = MIN(ir0 + dr, nr); | ||||||
|  | 
 | ||||||
|  | #if defined(__AVX2__) | ||||||
|  |     float outlier_mask[QK]; | ||||||
|  |     memset(outlier_mask, 0, QK * sizeof(float)); | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
|  |     for (int ir = ir0; ir < ir1; ++ir) { | ||||||
|  |         // src0 indices
 | ||||||
|  |         const int i03 = ir/(ne02*ne01); | ||||||
|  |         const int i02 = (ir - i03*ne02*ne01)/ne01; | ||||||
|  |         const int i01 = (ir - i03*ne02*ne01 - i02*ne01); | ||||||
|  | 
 | ||||||
|  | #if defined(__AVX2__) | ||||||
|  |         for (int ic = 0; ic < ne11; ++ic) { | ||||||
|  |             // src1 indices
 | ||||||
|  |             const int i13 = i03; | ||||||
|  |             const int i12 = i02; | ||||||
|  |             const int i11 = ic; | ||||||
|  | 
 | ||||||
|  |             // dst indices
 | ||||||
|  |             const int i0 = i01; | ||||||
|  |             const int i1 = i11; | ||||||
|  |             const int i2 = i02; | ||||||
|  |             const int i3 = i03; | ||||||
|  | 
 | ||||||
|  |             const int block_count = ne00 / QK; | ||||||
|  | 
 | ||||||
|  |             const block_q4_1_o * row_blocks = (block_q4_1_o *) ((char *) src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03)); | ||||||
|  | 
 | ||||||
|  |             __m256 accum = _mm256_setzero_ps(); | ||||||
|  | 
 | ||||||
|  |             // Here we do fused dequantization and dot product.
 | ||||||
|  |             for (int block_index = 0; block_index < block_count; block_index++) { | ||||||
|  |                 const float block_d = ggml_half_to_float_reference(row_blocks[block_index].d); | ||||||
|  |                 const float block_m = ggml_half_to_float_reference(row_blocks[block_index].m); | ||||||
|  | 
 | ||||||
|  |                 // 0 .. 31
 | ||||||
|  |                 const uint16_t outlier_index = row_blocks[block_index].outlier_index; | ||||||
|  |                 const float outlier_value = ggml_half_to_float_reference(row_blocks[block_index].outlier_value); | ||||||
|  | 
 | ||||||
|  |                 const uint8_t * restrict quant_nibbles = row_blocks[block_index].qs; | ||||||
|  | 
 | ||||||
|  |                 // ---
 | ||||||
|  | 
 | ||||||
|  |                 // Broadcast values to 8x element float32 vectors
 | ||||||
|  |                 const __m256 broadcasted_d = _mm256_broadcast_ss(&block_d); | ||||||
|  |                 const __m256 broadcasted_m = _mm256_broadcast_ss(&block_m); | ||||||
|  |                 const __m256 broadcasted_outlier_value = _mm256_broadcast_ss(&outlier_value); | ||||||
|  | 
 | ||||||
|  |                 // Load 32x4-bit integers into 32x8-bit integers
 | ||||||
|  |                 const __m256i quant_bytes = bytesFromNibbles(quant_nibbles); | ||||||
|  | 
 | ||||||
|  |                 // Convert to 16-bit int
 | ||||||
|  |                 const __m256i quant_shorts_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(quant_bytes, 0)); | ||||||
|  |                 const __m256i quant_shorts_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(quant_bytes, 1)); | ||||||
|  | 
 | ||||||
|  |                 // Convert to 32-bit int and then to 32-bit float
 | ||||||
|  |                 const __m256 quant_floats_0 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_lo, 0))); | ||||||
|  |                 const __m256 quant_floats_1 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_lo, 1))); | ||||||
|  |                 const __m256 quant_floats_2 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_hi, 0))); | ||||||
|  |                 const __m256 quant_floats_3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_hi, 1))); | ||||||
|  | 
 | ||||||
|  |                 // Dequantize to ~original weights
 | ||||||
|  |                 const __m256 weight_0 = _mm256_fmadd_ps(quant_floats_0, broadcasted_d, broadcasted_m); | ||||||
|  |                 const __m256 weight_1 = _mm256_fmadd_ps(quant_floats_1, broadcasted_d, broadcasted_m); | ||||||
|  |                 const __m256 weight_2 = _mm256_fmadd_ps(quant_floats_2, broadcasted_d, broadcasted_m); | ||||||
|  |                 const __m256 weight_3 = _mm256_fmadd_ps(quant_floats_3, broadcasted_d, broadcasted_m); | ||||||
|  | 
 | ||||||
|  |                 // TODO This outlier handling is VERY slow
 | ||||||
|  |                 // Set outlier mask -- this should give 1 in the most significant bit
 | ||||||
|  |                 outlier_mask[outlier_index] = -1.0F; | ||||||
|  |                 // Load mask into vectors
 | ||||||
|  |                 const __m256 outlier_mask_0 = _mm256_load_ps(outlier_mask); | ||||||
|  |                 const __m256 outlier_mask_1 = _mm256_load_ps(outlier_mask + 8); | ||||||
|  |                 const __m256 outlier_mask_2 = _mm256_load_ps(outlier_mask + 16); | ||||||
|  |                 const __m256 outlier_mask_3 = _mm256_load_ps(outlier_mask + 24); | ||||||
|  |                 // Reset mask array to all zeroes for the next iteration
 | ||||||
|  |                 outlier_mask[outlier_index] = 0.0F; | ||||||
|  | 
 | ||||||
|  |                 // Replace the weight at the index of the outlier
 | ||||||
|  |                 const __m256 weight_0_with_outlier = _mm256_blendv_ps(weight_0, broadcasted_outlier_value, outlier_mask_0); | ||||||
|  |                 const __m256 weight_1_with_outlier = _mm256_blendv_ps(weight_1, broadcasted_outlier_value, outlier_mask_1); | ||||||
|  |                 const __m256 weight_2_with_outlier = _mm256_blendv_ps(weight_2, broadcasted_outlier_value, outlier_mask_2); | ||||||
|  |                 const __m256 weight_3_with_outlier = _mm256_blendv_ps(weight_3, broadcasted_outlier_value, outlier_mask_3); | ||||||
|  | 
 | ||||||
|  |                 // Load 32 floats of data of the second argument
 | ||||||
|  |                 const float * src1_data = (float *) ((char *) src1->data + (block_index * QK * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13)); | ||||||
|  | 
 | ||||||
|  |                 const __m256 src1_0 = _mm256_load_ps(src1_data); | ||||||
|  |                 const __m256 src1_1 = _mm256_load_ps(src1_data + 8); | ||||||
|  |                 const __m256 src1_2 = _mm256_load_ps(src1_data + 16); | ||||||
|  |                 const __m256 src1_3 = _mm256_load_ps(src1_data + 24); | ||||||
|  | 
 | ||||||
|  |                 // Multiply weights and values of the second argument element-wise; add to accumulator
 | ||||||
|  |                 accum = _mm256_fmadd_ps(src1_0, weight_0_with_outlier, accum); | ||||||
|  |                 accum = _mm256_fmadd_ps(src1_1, weight_1_with_outlier, accum); | ||||||
|  |                 accum = _mm256_fmadd_ps(src1_2, weight_2_with_outlier, accum); | ||||||
|  |                 accum = _mm256_fmadd_ps(src1_3, weight_3_with_outlier, accum); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Add elements of accumulator
 | ||||||
|  |             __m128 res = _mm256_extractf128_ps(accum, 1); | ||||||
|  |             res = _mm_add_ps(res, _mm256_castps256_ps128(accum)); | ||||||
|  |             res = _mm_add_ps(res, _mm_movehl_ps(res, res )); | ||||||
|  |             res = _mm_add_ss(res, _mm_movehdup_ps(res)); | ||||||
|  | 
 | ||||||
|  |             *((float *) ((char *) dst->data + (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3))) = _mm_cvtss_f32(res); | ||||||
|  |         } | ||||||
|  | #else | ||||||
|  |         float * const wdata = (float *) ((char *) params->wdata + (i01 * nb01 + i02 * nb02 + i03 * nb03)); | ||||||
|  | 
 | ||||||
|  |         dequantize_row_q4_1_o((char *) src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03), wdata, ne00); | ||||||
|  | 
 | ||||||
|  |         for (int ic = 0; ic < ne11; ++ic) { | ||||||
|  |             // src1 indices
 | ||||||
|  |             const int i13 = i03; | ||||||
|  |             const int i12 = i02; | ||||||
|  |             const int i11 = ic; | ||||||
|  | 
 | ||||||
|  |             // dst indices
 | ||||||
|  |             const int i0 = i01; | ||||||
|  |             const int i1 = i11; | ||||||
|  |             const int i2 = i02; | ||||||
|  |             const int i3 = i03; | ||||||
|  | 
 | ||||||
|  |             ggml_vec_dot_f32( | ||||||
|  |                     ne00, | ||||||
|  |                     (float *) ((char *) dst->data + (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3)), | ||||||
|  |                     wdata, | ||||||
|  |                     (float *) ((char *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13)) | ||||||
|  |             ); | ||||||
|  |         } | ||||||
|  | #endif | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
| static void ggml_compute_forward_mul_mat( | static void ggml_compute_forward_mul_mat( | ||||||
|         const struct ggml_compute_params * params, |         const struct ggml_compute_params * params, | ||||||
|         const struct ggml_tensor * src0, |         const struct ggml_tensor * src0, | ||||||
|  | @ -6870,6 +7406,10 @@ static void ggml_compute_forward_mul_mat( | ||||||
|             { |             { | ||||||
|                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); |                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); | ||||||
|             } break; |             } break; | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|  |             { | ||||||
|  |                 ggml_compute_forward_mul_mat_q4_1_o_f32(params, src0, src1, dst); | ||||||
|  |             } break; | ||||||
|         case GGML_TYPE_F16: |         case GGML_TYPE_F16: | ||||||
|             { |             { | ||||||
|                 ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst); |                 ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst); | ||||||
|  | @ -6965,6 +7505,7 @@ static void ggml_compute_forward_scale( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -7121,6 +7662,7 @@ static void ggml_compute_forward_get_rows( | ||||||
|     switch (src0->type) { |     switch (src0->type) { | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|             { |             { | ||||||
|                 ggml_compute_forward_get_rows_q(params, src0, src1, dst); |                 ggml_compute_forward_get_rows_q(params, src0, src1, dst); | ||||||
|             } break; |             } break; | ||||||
|  | @ -7210,6 +7752,7 @@ static void ggml_compute_forward_diag_mask_inf( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -7304,6 +7847,7 @@ static void ggml_compute_forward_soft_max( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -7446,6 +7990,7 @@ static void ggml_compute_forward_rope( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -7714,6 +8259,7 @@ static void ggml_compute_forward_conv_1d_1s( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -7982,6 +8528,7 @@ static void ggml_compute_forward_conv_1d_2s( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -8467,6 +9014,7 @@ static void ggml_compute_forward_flash_attn( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -8678,6 +9226,7 @@ static void ggml_compute_forward_flash_ff( | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|         case GGML_TYPE_Q4_1: |         case GGML_TYPE_Q4_1: | ||||||
|  |         case GGML_TYPE_Q4_1_O: | ||||||
|         case GGML_TYPE_I8: |         case GGML_TYPE_I8: | ||||||
|         case GGML_TYPE_I16: |         case GGML_TYPE_I16: | ||||||
|         case GGML_TYPE_I32: |         case GGML_TYPE_I32: | ||||||
|  | @ -9508,6 +10057,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) | ||||||
| #endif | #endif | ||||||
|                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { |                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { | ||||||
|                             cur = 0; |                             cur = 0; | ||||||
|  |                         } else if (node->src0->type == GGML_TYPE_Q4_1_O && node->src1->type == GGML_TYPE_F32) { | ||||||
|  | #if defined(__AVX2__) | ||||||
|  |                             cur = 0; | ||||||
|  | #else | ||||||
|  |                             // Assuming that src1 is a vector
 | ||||||
|  |                             // TODO Not sure whether this is correct
 | ||||||
|  |                             cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * ggml_nelements(node->src1); | ||||||
|  | #endif | ||||||
|                         } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) { |                         } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) { | ||||||
| #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) | #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) | ||||||
|                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { |                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { | ||||||
|  | @ -10729,6 +11286,29 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * | ||||||
|     return (n/QK*sizeof(block_q4_1)); |     return (n/QK*sizeof(block_q4_1)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | size_t ggml_quantize_q4_1_o(const float * src, void * dst, int n, int k, int64_t * hist) { | ||||||
|  |     assert(k % QK == 0); | ||||||
|  |     const int nb = k / QK; | ||||||
|  | 
 | ||||||
|  |     for (int j = 0; j < n; j += k) { | ||||||
|  |         block_q4_1_o * restrict y = (block_q4_1_o *) dst + j / QK; | ||||||
|  | 
 | ||||||
|  |         quantize_row_q4_1_o_reference(src + j, y, k); | ||||||
|  | 
 | ||||||
|  |         for (int i = 0; i < nb; i++) { | ||||||
|  |             for (int l = 0; l < QK; l += 2) { | ||||||
|  |                 const uint8_t vi0 = y[i].qs[l / 2] & 0xF; | ||||||
|  |                 const uint8_t vi1 = y[i].qs[l / 2] >> 4; | ||||||
|  | 
 | ||||||
|  |                 hist[vi0]++; | ||||||
|  |                 hist[vi1]++; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     return (n / QK * sizeof(block_q4_1_o)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////////////
 | ////////////////////////////////////////////////////////////////////////////////
 | ||||||
| 
 | 
 | ||||||
| int ggml_cpu_has_avx(void) { | int ggml_cpu_has_avx(void) { | ||||||
|  |  | ||||||
							
								
								
									
										5
									
								
								ggml.h
								
								
								
								
							
							
						
						
									
										5
									
								
								ggml.h
								
								
								
								
							|  | @ -226,7 +226,11 @@ struct ggml_context; | ||||||
| 
 | 
 | ||||||
| enum ggml_type { | enum ggml_type { | ||||||
|     GGML_TYPE_Q4_0, |     GGML_TYPE_Q4_0, | ||||||
|  |     // Stores min and delta per block, does quantized matmul.
 | ||||||
|     GGML_TYPE_Q4_1, |     GGML_TYPE_Q4_1, | ||||||
|  |     // Same as Q4_1, but stores outliers separately, and matmul is done in FP32.
 | ||||||
|  |     // An outlier is the single absmax element in the quantized block.
 | ||||||
|  |     GGML_TYPE_Q4_1_O, | ||||||
|     GGML_TYPE_I8, |     GGML_TYPE_I8, | ||||||
|     GGML_TYPE_I16, |     GGML_TYPE_I16, | ||||||
|     GGML_TYPE_I32, |     GGML_TYPE_I32, | ||||||
|  | @ -807,6 +811,7 @@ enum ggml_opt_result ggml_opt( | ||||||
| 
 | 
 | ||||||
| size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); | size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); | ||||||
| size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); | size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); | ||||||
|  | size_t ggml_quantize_q4_1_o(const float * src, void * dst, int n, int k, int64_t * hist); | ||||||
| 
 | 
 | ||||||
| //
 | //
 | ||||||
| // system info
 | // system info
 | ||||||
|  |  | ||||||
							
								
								
									
										29
									
								
								rwkv.cpp
								
								
								
								
							
							
						
						
									
										29
									
								
								rwkv.cpp
								
								
								
								
							|  | @ -160,7 +160,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr | ||||||
|         model->data_type == 0 || |         model->data_type == 0 || | ||||||
|             model->data_type == 1 || |             model->data_type == 1 || | ||||||
|             model->data_type == 2 || |             model->data_type == 2 || | ||||||
|             model->data_type == 3, |             model->data_type == 3 || | ||||||
|  |             model->data_type == 4, | ||||||
|         "Unsupported model data type %d", |         "Unsupported model data type %d", | ||||||
|         model->data_type |         model->data_type | ||||||
|     ); |     ); | ||||||
|  | @ -216,7 +217,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr | ||||||
|             data_type == 0 || |             data_type == 0 || | ||||||
|                 data_type == 1 || |                 data_type == 1 || | ||||||
|                 data_type == 2 || |                 data_type == 2 || | ||||||
|                 data_type == 3, |                 data_type == 3 || | ||||||
|  |                 data_type == 4, | ||||||
|             "Unsupported parameter data type %d", |             "Unsupported parameter data type %d", | ||||||
|             data_type |             data_type | ||||||
|         ); |         ); | ||||||
|  | @ -228,6 +230,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr | ||||||
|             case 1: ggml_data_type = GGML_TYPE_F16; break; |             case 1: ggml_data_type = GGML_TYPE_F16; break; | ||||||
|             case 2: ggml_data_type = GGML_TYPE_Q4_0; break; |             case 2: ggml_data_type = GGML_TYPE_Q4_0; break; | ||||||
|             case 3: ggml_data_type = GGML_TYPE_Q4_1; break; |             case 3: ggml_data_type = GGML_TYPE_Q4_1; break; | ||||||
|  |             case 4: ggml_data_type = GGML_TYPE_Q4_1_O; break; | ||||||
|             default: return NULL; |             default: return NULL; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  | @ -553,18 +556,17 @@ void rwkv_free(struct rwkv_context * ctx) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) { | bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) { | ||||||
|     RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type); |     RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type); | ||||||
| 
 | 
 | ||||||
|     ggml_type type; |     ggml_type type; | ||||||
| 
 | 
 | ||||||
|     switch (q_type) { |     switch (q_type) { | ||||||
|         case 2: type = GGML_TYPE_Q4_0; break; |         case 2: type = GGML_TYPE_Q4_0; break; | ||||||
|         case 3: type = GGML_TYPE_Q4_1; break; |         case 3: type = GGML_TYPE_Q4_1; break; | ||||||
|  |         case 4: type = GGML_TYPE_Q4_1_O; break; | ||||||
|         default: return false; |         default: return false; | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     RWKV_ASSERT_FALSE(type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1, "Unsupported data type %d", type); |  | ||||||
| 
 |  | ||||||
|     printf("Loading model from '%s'\n", model_file_path_in); |     printf("Loading model from '%s'\n", model_file_path_in); | ||||||
| 
 | 
 | ||||||
|     auto finp = std::ifstream(model_file_path_in, std::ios::binary); |     auto finp = std::ifstream(model_file_path_in, std::ios::binary); | ||||||
|  | @ -646,7 +648,8 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode | ||||||
|                     "f32", |                     "f32", | ||||||
|                     "f16", |                     "f16", | ||||||
|                     "q4_0", |                     "q4_0", | ||||||
|                     "q4_1" |                     "q4_1", | ||||||
|  |                     "q4_1_o" | ||||||
|                 }; |                 }; | ||||||
|                 printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]); |                 printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]); | ||||||
| 
 | 
 | ||||||
|  | @ -655,6 +658,7 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode | ||||||
|                     4.0F, |                     4.0F, | ||||||
|                     2.0F, |                     2.0F, | ||||||
|                     20.0F / 32.0F, |                     20.0F / 32.0F, | ||||||
|  |                     24.0F / 32.0F, | ||||||
|                     24.0F / 32.0F |                     24.0F / 32.0F | ||||||
|                 }; |                 }; | ||||||
|                 total_size_orig += (size_t) (nelements * parameter_data_type_size[parameter_data_type]); |                 total_size_orig += (size_t) (nelements * parameter_data_type_size[parameter_data_type]); | ||||||
|  | @ -668,10 +672,11 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode | ||||||
|                     name != std::string("head.weight"); |                     name != std::string("head.weight"); | ||||||
| 
 | 
 | ||||||
|             if (quantize) { |             if (quantize) { | ||||||
|                 if (parameter_data_type != 0 && parameter_data_type != 1) { |                 RWKV_ASSERT_FALSE( | ||||||
|                     fprintf(stderr, "unsupported data type %d for integer quantization\n", parameter_data_type); |                     parameter_data_type == 0 || parameter_data_type == 1, | ||||||
|                     return false; |                     "Unsupported parameter data type %d, only FP32 and FP16 can be quantized", | ||||||
|                 } |                     parameter_data_type | ||||||
|  |                 ); | ||||||
| 
 | 
 | ||||||
|                 if (parameter_data_type == 1) { |                 if (parameter_data_type == 1) { | ||||||
|                     data_f16.resize(nelements); |                     data_f16.resize(nelements); | ||||||
|  | @ -719,6 +724,10 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode | ||||||
|                         { |                         { | ||||||
|                             cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); |                             cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); | ||||||
|                         } break; |                         } break; | ||||||
|  |                     case GGML_TYPE_Q4_1_O: | ||||||
|  |                         { | ||||||
|  |                             cur_size = ggml_quantize_q4_1_o(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data()); | ||||||
|  |                         } break; | ||||||
|                     default: |                     default: | ||||||
|                         { |                         { | ||||||
|                             fprintf(stderr, "unsupported quantization type %d\n", type); |                             fprintf(stderr, "unsupported quantization type %d\n", type); | ||||||
|  |  | ||||||
|  | @ -37,7 +37,8 @@ def main() -> None: | ||||||
|         assert data_type == 0 or\ |         assert data_type == 0 or\ | ||||||
|                data_type == 1 or\ |                data_type == 1 or\ | ||||||
|                data_type == 2 or\ |                data_type == 2 or\ | ||||||
|                data_type == 3, f'Unsupported model data type {data_type}' |                data_type == 3 or\ | ||||||
|  |                data_type == 4, f'Unsupported model data type {data_type}' | ||||||
| 
 | 
 | ||||||
|         if data_type == 0: |         if data_type == 0: | ||||||
|             # FP32, high precision |             # FP32, high precision | ||||||
|  | @ -46,12 +47,14 @@ def main() -> None: | ||||||
|             # FP16, lower precision, so higher threshold |             # FP16, lower precision, so higher threshold | ||||||
|             threshold = 0.0032 |             threshold = 0.0032 | ||||||
|         elif data_type == 2: |         elif data_type == 2: | ||||||
|             # INT4 quantized, even lower precision, so even higher threshold |             # Q4_0 quantized, even lower precision, so even higher threshold | ||||||
|             # This threshold will let some bugs pass |             threshold = 0.4 | ||||||
|             threshold = 4.0 |  | ||||||
|         elif data_type == 3: |         elif data_type == 3: | ||||||
|             # This format stores more data, so error would be lower |             # Q4_1 | ||||||
|             threshold = 1.2 |             threshold = 1.21 | ||||||
|  |         elif data_type == 4: | ||||||
|  |             # Q4_1_O | ||||||
|  |             threshold = 0.2 | ||||||
| 
 | 
 | ||||||
|     model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path) |     model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,5 +1,5 @@ | ||||||
| # Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1. | # Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1 or Q4_1_O (recommended). | ||||||
| # Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3 | # Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1_o.bin 4 | ||||||
| 
 | 
 | ||||||
| import argparse | import argparse | ||||||
| import rwkv_cpp_shared_library | import rwkv_cpp_shared_library | ||||||
|  | @ -8,12 +8,20 @@ def parse_args(): | ||||||
|     parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1') |     parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1') | ||||||
|     parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file') |     parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file') | ||||||
|     parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten') |     parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten') | ||||||
|     parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3) |     parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0), 3 (GGML_TYPE_Q4_1) or 4 (GGML_TYPE_Q4_1_O)', type=int, choices=[2, 3, 4], default=4) | ||||||
|     return parser.parse_args() |     return parser.parse_args() | ||||||
| 
 | 
 | ||||||
| def main() -> None: | def main() -> None: | ||||||
|     args = parse_args() |     args = parse_args() | ||||||
| 
 | 
 | ||||||
|  |     if args.data_type == 2 or args.data_type == 3: | ||||||
|  |         print() | ||||||
|  |         print('WARNING!') | ||||||
|  |         print('You are using Q4_0 or Q4_1 quantization; it will heavily degrade RWKV quality.') | ||||||
|  |         print('For best quality preservation, it is recommended to use Q4_1_O.') | ||||||
|  |         print('More info at https://github.com/saharNooby/rwkv.cpp/issues/12') | ||||||
|  |         print() | ||||||
|  | 
 | ||||||
|     library = rwkv_cpp_shared_library.load_rwkv_shared_library() |     library = rwkv_cpp_shared_library.load_rwkv_shared_library() | ||||||
| 
 | 
 | ||||||
|     library.rwkv_quantize_model_file( |     library.rwkv_quantize_model_file( | ||||||
|  |  | ||||||
|  | @ -118,5 +118,5 @@ class RWKVModel: | ||||||
| 
 | 
 | ||||||
|     def __del__(self): |     def __del__(self): | ||||||
|         # Free the context on GC in case user forgot to call free() explicitly. |         # Free the context on GC in case user forgot to call free() explicitly. | ||||||
|         if self._valid: |         if hasattr(self, '_valid') and self._valid: | ||||||
|             self.free() |             self.free() | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue