From 459e93cce07cab9052c06b5bf360819893442e1e Mon Sep 17 00:00:00 2001
From: slaren <2141330+slaren@users.noreply.github.com>
Date: Sat, 25 Mar 2023 19:31:48 +0100
Subject: [PATCH] Add AVX2 implementation of dequantize_row_q4_1 (#505)

---
 ggml.c | 34 +++++++++++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/ggml.c b/ggml.c
index b566b56..c9a4e86 100644
--- a/ggml.c
+++ b/ggml.c
@@ -783,7 +783,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
 
             // Scale and store
             for (int j = 0; j < 4; j++) {
-                __m256 result = _mm256_mul_ps(vf[j], d_v);
+                const __m256 result = _mm256_mul_ps(vf[j], d_v);
                 _mm256_storeu_ps(y + i * QK + l + j*8, result);
             }
         }
@@ -879,6 +879,37 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
     const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
     const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
 
+#if defined(__AVX2__)
+    for (int i = 0; i < nb; i++) {
+        const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs));
+        const __m256 d_m = _mm256_broadcast_ss((const float *) (pm + i*bs));
+
+        const uint8_t * restrict pp = pb + i*bs;
+
+        for (int l = 0; l < QK; l += 32) {
+            // Load 32x4-bit integers into 32x8-bit integers
+            __m256i vx8 = bytesFromNibbles(pp+l/2);
+
+            // Convert to 16-bit int
+            const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
+            const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
+
+            // Convert to 32-bit int -> float 32
+            const __m256 vf[4] = {
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
+            };
+
+            // Scale, add m and store
+            for (int j = 0; j < 4; j++) {
+                const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
+                _mm256_storeu_ps(y + i * QK + l + j*8, result);
+            }
+        }
+    }
+#else
     for (int i = 0; i < nb; i++) {
         const float d = *(const float *) (pd + i*bs);
         const float m = *(const float *) (pm + i*bs);
@@ -901,6 +932,7 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
             assert(!isnan(y[i*QK + l + 1]));
         }
     }
+#endif
 }
 
 //