Add initial AVX512 support for dot product on Linux (#320)
* Update the Makefile to detect AVX512 support and add the matching compiler flags when it is available
* Based on the existing AVX2 implementation: compute the dot product on one 32-value block of 4-bit quantized ints at a time
* Perform 8-bit -> 16-bit sign extension and multiply+add on 32 values at a time instead of 16
* Use the built-in AVX512 horizontal reduce-add to get the sum at the end
* Manually unroll the inner dot product loop to reduce loop counter overhead
This commit is contained in:
		
							parent
							
								
									8cf9f34edd
								
							
						
					
					
						commit
						2e664f1ff4
					
				
							
								
								
									
										32
									
								
								Makefile
								
								
								
								
							
							
						
						
									
										32
									
								
								Makefile
								
								
								
								
							|  | @ -95,6 +95,38 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686)) | |||
| 		ifneq (,$(findstring sse3,$(SSE3_M))) | ||||
| 			CFLAGS += -msse3 | ||||
| 		endif | ||||
| 		AVX512F_M := $(shell grep "avx512f " /proc/cpuinfo) | ||||
| 		ifneq (,$(findstring avx512f,$(AVX512F_M))) | ||||
| 			CFLAGS += -mavx512f | ||||
| 		endif | ||||
| 		AVX512BW_M := $(shell grep "avx512bw " /proc/cpuinfo) | ||||
| 		ifneq (,$(findstring avx512bw,$(AVX512BW_M))) | ||||
| 			CFLAGS += -mavx512bw | ||||
| 		endif | ||||
| 		AVX512DQ_M := $(shell grep "avx512dq " /proc/cpuinfo) | ||||
| 		ifneq (,$(findstring avx512dq,$(AVX512DQ_M))) | ||||
| 			CFLAGS += -mavx512dq | ||||
| 		endif | ||||
| 		AVX512VL_M := $(shell grep "avx512vl " /proc/cpuinfo) | ||||
| 		ifneq (,$(findstring avx512vl,$(AVX512VL_M))) | ||||
| 			CFLAGS += -mavx512vl | ||||
| 		endif | ||||
| 		AVX512CD_M := $(shell grep "avx512cd " /proc/cpuinfo) | ||||
| 		ifneq (,$(findstring avx512cd,$(AVX512CD_M))) | ||||
| 			CFLAGS += -mavx512cd | ||||
| 		endif | ||||
| 		AVX512ER_M := $(shell grep "avx512er " /proc/cpuinfo) | ||||
| 		ifneq (,$(findstring avx512er,$(AVX512ER_M))) | ||||
| 			CFLAGS += -mavx512er | ||||
| 		endif | ||||
| 		AVX512IFMA_M := $(shell grep "avx512ifma " /proc/cpuinfo) | ||||
| 		ifneq (,$(findstring avx512ifma,$(AVX512IFMA_M))) | ||||
| 			CFLAGS += -mavx512ifma | ||||
| 		endif | ||||
| 		AVX512PF_M := $(shell grep "avx512pf " /proc/cpuinfo) | ||||
| 		ifneq (,$(findstring avx512pf,$(AVX512PF_M))) | ||||
| 			CFLAGS += -mavx512pf | ||||
| 		endif | ||||
| 	else ifeq ($(UNAME_S),Haiku) | ||||
| 		AVX1_M := $(shell sysinfo -cpu | grep "AVX ") | ||||
| 		ifneq (,$(findstring avx,$(AVX1_M))) | ||||
|  |  | |||
							
								
								
									
										80
									
								
								ggml.c
								
								
								
								
							
							
						
						
									
										80
									
								
								ggml.c
								
								
								
								
							|  | @ -361,7 +361,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); | |||
| 
 | ||||
| // AVX routines provided by GH user Const-me
 | ||||
| // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
 | ||||
| #if __AVX2__ | ||||
| #if __AVX2__ || __AVX512F__ | ||||
| // Unpack 32 4-bit fields into 32 bytes
 | ||||
| // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
 | ||||
| static inline __m256i bytesFromNibbles( const uint8_t* rsi ) | ||||
|  | @ -397,7 +397,6 @@ static inline __m128i packNibbles( __m256i bytes ) | |||
| } | ||||
| #endif | ||||
| 
 | ||||
| 
 | ||||
| // method 5
 | ||||
| // blocks of QK elements
 | ||||
| // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
 | ||||
|  | @ -1262,6 +1261,47 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float | |||
|     *s = sumf; | ||||
| } | ||||
| 
 | ||||
| #if __AVX512F__ && QK == 32 | ||||
| static inline __m512 dot_q4_0_oneblock_avx512( | ||||
|     __m512 acc, | ||||
|     const uint8_t * pd0, | ||||
|     const uint8_t * pd1, | ||||
|     const uint8_t * pb0, | ||||
|     const uint8_t * pb1, | ||||
|     size_t bs, | ||||
|     int i | ||||
| ) { | ||||
|     const float * d0_0 = (const float *) (pd0 + i*bs); | ||||
|     const float * d1_0 = (const float *) (pd1 + i*bs); | ||||
| 
 | ||||
|     const uint8_t * restrict p0 = pb0 + (i+0)*bs; | ||||
|     const uint8_t * restrict p1 = pb1 + (i+0)*bs; | ||||
| 
 | ||||
|     // Compute the combined scale for block i: the product of the two floats read at byte offset i*bs from pd0 and pd1, broadcast to every lane
 | ||||
|     float scaleScalar = d0_0[0] * d1_0[0]; | ||||
|     __m512 scale = _mm512_set1_ps( scaleScalar ); | ||||
| 
 | ||||
|     __m256i bx = bytesFromNibbles( p0 ); | ||||
|     __m256i by = bytesFromNibbles( p1 ); | ||||
| 
 | ||||
|     // bytesFromNibbles yields 32 bytes in the [ 0 .. 15 ] interval. Subtract 8 to offset them into the signed [ -8 .. +7 ] interval.
 | ||||
|     const __m256i off = _mm256_set1_epi8( 8 ); | ||||
|     bx = _mm256_sub_epi8( bx, off ); | ||||
|     by = _mm256_sub_epi8( by, off ); | ||||
| 
 | ||||
|     // Sign-extend all 32 signed bytes of each 256-bit vector into int16_t lanes of a 512-bit vector. NOTE(review): _mm512_cvtepi8_epi16 and _mm512_madd_epi16 appear to be AVX512BW intrinsics while this block is guarded only by __AVX512F__ - confirm the guard is sufficient.
 | ||||
|     __m512i x32 = _mm512_cvtepi8_epi16( bx ); | ||||
|     __m512i y32 = _mm512_cvtepi8_epi16( by ); | ||||
|     // Compute products of the int16_t lanes and add adjacent pairs into int32_t lanes (despite their names, x32/y32 hold int16_t values and i64 holds int32_t values)
 | ||||
|     __m512i i64 = _mm512_madd_epi16( x32, y32 ); | ||||
| 
 | ||||
|     // Convert the int32_t pair sums to float
 | ||||
|     __m512 p = _mm512_cvtepi32_ps( i64 ); | ||||
|     // Apply the scale and accumulate: the caller receives the updated accumulator (scale * p + acc)
 | ||||
|     return _mm512_fmadd_ps( scale, p, acc ); | ||||
| } | ||||
| #endif | ||||
| 
 | ||||
| inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { | ||||
|     ggml_float sumf = 0.0; | ||||
| 
 | ||||
|  | @ -1417,6 +1457,40 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void | |||
| #else | ||||
| #error "not implemented for QK" | ||||
| #endif | ||||
| #elif defined(__AVX512F__) | ||||
| 
 | ||||
| #if QK == 32 | ||||
|     // Initialize accumulator with zeros
 | ||||
|     __m512 acc0 = _mm512_setzero_ps(); | ||||
|     __m512 acc1 = _mm512_setzero_ps(); | ||||
| 
 | ||||
|     const int superblock_size = 8; | ||||
|     const int superblock_count = nb / superblock_size; | ||||
|     const int remainder = nb % superblock_size; | ||||
| 
 | ||||
|     for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) { | ||||
|         int i = superblock_ix * superblock_size; | ||||
| 
 | ||||
|         acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+0 ); | ||||
|         acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+1 ); | ||||
|         acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+2 ); | ||||
|         acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+3 ); | ||||
|         acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+4 ); | ||||
|         acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+5 ); | ||||
|         acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+6 ); | ||||
|         acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+7 ); | ||||
|     } | ||||
| 
 | ||||
|     // Remainders
 | ||||
|     for (int i = superblock_count * superblock_size; i < nb; ++i) { | ||||
|         acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i ); | ||||
|     } | ||||
| 
 | ||||
|     // Horizontal sum of all lanes of the accumulator
 | ||||
|     sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 ); | ||||
| #else | ||||
| #error "not implemented for QK" | ||||
| #endif | ||||
| #elif defined(__AVX2__) | ||||
| #if QK == 32 | ||||
|     const size_t countBlocks = nb; | ||||
|  | @ -1928,7 +2002,7 @@ inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * res | |||
|     const size_t bs = 2*sizeof(float) + QK/2; | ||||
| 
 | ||||
|     const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs); | ||||
|     const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs +   sizeof(float));  | ||||
|     const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs +   sizeof(float)); | ||||
|     const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float)); | ||||
| 
 | ||||
|     for (int i = 0; i < nb; i++) { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue