From ecbe466a364876927994e2f1ec14f4d82301d201 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sat, 25 Mar 2023 19:47:21 +0200
Subject: [PATCH] Retire the ggml_mul_mat() branch for transposed src0 (#500)

* Retire the ggml_mul_mat() for transposed src0

- It can always be made contiguous with ggml_cpy()
- The code is now simplified
- The results are deterministic in respect to num threads

* SIMD-ify dequantize_row_q4_0() for ARM_NEON (#502)

* Attempt to SIMD-ify dequantize_row_q4_0() for ARM_NEON

* Fix dequantization - forgot to interleave the quants
---
 ggml.c | 1033 +++++++++++++++-----------------------------------------
 1 file changed, 277 insertions(+), 756 deletions(-)

diff --git a/ggml.c b/ggml.c
index d8e1fbd..291e12a 100644
--- a/ggml.c
+++ b/ggml.c
@@ -496,7 +496,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, void * restric
 void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
     assert(k % QK == 0);
 
-#if __ARM_NEON || defined(__AVX2__) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__)
+#if defined(__ARM_NEON) || defined(__AVX2__) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__)
     const int nb = k / QK;
     const size_t bs = sizeof(float) + QK/2;
 
@@ -507,7 +507,6 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
 #endif
 
 #if defined(__POWER9_VECTOR__)
-#if QK == 32
     const vector float v85 = vec_splats(8.5f);
     for (int i = 0; i < nb; i++) {
         float amax = 0.0f; // absolute max
@@ -548,11 +547,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
         //memcpy(pb, pp, sizeof(pp));
         pb += bs;
     }
-#else
-#error "not implemented for QK"
-#endif
 #elif __ARM_NEON
-#if QK == 32
     for (int i = 0; i < nb; i++) {
         float amax = 0.0f; // absolute max
 
@@ -589,11 +584,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
         memcpy(pb, pp, sizeof(pp));
         pb += bs;
     }
-#else
-#error "not implemented for QK"
-#endif
 #elif defined(__AVX2__)
-#if QK == 32
     for (int i = 0; i < nb; i++) {
         // Load elements into 4 AVX vectors
         __m256 v0 = _mm256_loadu_ps( x );
@@ -660,11 +651,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
         _mm_storeu_si128( ( __m128i* )pb, res );
         pb += bs;
     }
-#else
-#error "not implemented for QK"
-#endif
 #elif defined(__wasm_simd128__)
-#if QK == 32
     for (int i = 0; i < nb; i++) {
         float amax = 0.0f; // absolute max
 
@@ -701,9 +688,6 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
         memcpy(pb, pp, sizeof(pp));
         pb += bs;
     }
-#else
-#error "not implemented for QK"
-#endif
 #else
     // scalar
     quantize_row_q4_0_reference(x, y, k);
@@ -771,7 +755,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
     const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
     const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float));
 
-#if defined(__AVX2__) && QK % 32 == 0
+#if defined(__AVX2__)
     for (int i = 0; i < nb; i++) {
         // scale factor
         const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs));
@@ -804,6 +788,59 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
             }
         }
     }
+#elif defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        const float d = *(const float *) (pd + i*bs);
+
+        const uint8_t * restrict pp = pb + i*bs;
+
+        const float32x4_t vd = vdupq_n_f32(d);
+
+        for (int l = 0; l < QK; l += 16) {
+            // Load 16x4-bit integers into 8x8-bit integers
+            const uint8x8_t v8 = vld1_u8(pp + l/2);
+
+            // Expand 4-bit nibbles to 8-bit bytes
+            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+            const uint8x8_t v1 = vshr_n_u8(v8, 4);
+
+            // Convert to signed 8-bit integers
+            const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
+            const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
+
+            // Subtract 8 from each byte
+            const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
+            const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
+
+            // Interleave and combine
+            const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
+            const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
+
+            const int8x16_t vq = vcombine_s8(vx_0, vx_1);
+
+            // convert to 2x int16x8_t
+            const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
+            const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
+
+            // convert to 4x float32x4_t
+            const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
+            const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
+            const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
+            const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
+
+            // Multiply by d
+            const float32x4_t r0 = vmulq_f32(vf_0, vd);
+            const float32x4_t r1 = vmulq_f32(vf_1, vd);
+            const float32x4_t r2 = vmulq_f32(vf_2, vd);
+            const float32x4_t r3 = vmulq_f32(vf_3, vd);
+
+            // Store
+            vst1q_f32(y + i*QK + l +  0, r0);
+            vst1q_f32(y + i*QK + l +  4, r1);
+            vst1q_f32(y + i*QK + l +  8, r2);
+            vst1q_f32(y + i*QK + l + 12, r3);
+        }
+    }
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
@@ -1500,8 +1537,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
 
     float sumf = 0.0;
 
-#ifdef __ARM_NEON
-#if QK == 32
+#if defined(__ARM_NEON)
     float sum0 = 0.0f;
     float sum1 = 0.0f;
 
@@ -1600,12 +1636,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
     }
 
     sumf = sum0 + sum1;
-#else
-#error "not implemented for QK"
-#endif
 #elif defined(__AVX512F__)
-
-#if QK == 32
     // Initialize accumulator with zeros
     __m512 acc0 = _mm512_setzero_ps();
     __m512 acc1 = _mm512_setzero_ps();
@@ -1634,11 +1665,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
 
     // Horizontal sum of all lanes of the accumulator
     sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
-#else
-#error "not implemented for QK"
-#endif
 #elif defined(__AVX2__)
-#if QK == 32
     const size_t countBlocks = nb;
 
     // Initialize accumulator with zeros
@@ -1689,11 +1716,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
     sumf = _mm_cvtss_f32( res );
-#else
-#error "not implemented for QK"
-#endif
 #elif defined(__wasm_simd128__)
-#if QK == 32
     // wasm simd
     float sum0 = 0.0f;
     float sum1 = 0.0f;
@@ -1776,9 +1799,6 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
     }
 
     sumf = sum0 + sum1;
-#else
-#error "not implemented for QK"
-#endif
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
@@ -1823,7 +1843,6 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
     float sumf = 0.0;
 
 #if defined(__AVX2__)
-#if QK == 32
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
     // Accumulator for constant offsets
@@ -1898,9 +1917,6 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
     sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
-#else
-#error "not implemented for QK"
-#endif
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
@@ -2017,167 +2033,6 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
 #endif
 }
 
-inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) {
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F16_STEP - 1));
-
-    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
-
-    GGML_F16_VEC ax[GGML_F16_ARR];
-    GGML_F16_VEC ay[GGML_F16_ARR];
-
-    for (int i = 0; i < np; i += GGML_F16_STEP) {
-        for (int j = 0; j < GGML_F16_ARR; j++) {
-            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
-            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
-
-            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
-        }
-    }
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        GGML_ASSERT(false);
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
-    }
-#else
-    for (int i = 0; i < n; ++i) {
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
-    }
-#endif
-}
-
-inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * restrict x, const float v) {
-    assert(n % QK == 0);
-
-    const int nb = n / QK;
-    const size_t bs = sizeof(float) + QK/2;
-
-    const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
-    const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float));
-
-#if __ARM_NEON
-#if QK == 32
-    for (int i = 0; i < nb; ++i) {
-        const float d0 = v*(*(const float *) (pd + i*bs));
-
-        const uint8_t * restrict pp = pb + i*bs;
-
-        const uint8x8_t m4b = vdup_n_u8(0xf);
-        const int8x8_t  s8b = vdup_n_s8(0x8);
-
-        const float32x4_t vd = vdupq_n_f32(d0);
-
-        for (int j = 0; j < 2; j++) {
-            const uint8x8_t vx = vld1_u8(pp + j*8);
-
-            const int8x8_t vxl = vreinterpret_s8_u8(vand_u8(vx, m4b));
-            const int8x8_t vxh = vreinterpret_s8_u8(vshr_n_u8(vx, 4));
-
-            // sub 8
-            const int8x8_t vxls = vsub_s8(vxl, s8b);
-            const int8x8_t vxhs = vsub_s8(vxh, s8b);
-
-            //const int8x8_t vxlt = vzip_s8(vxls, vxhs)[0];
-            //const int8x8_t vxht = vzip_s8(vxls, vxhs)[1];
-            const int8x8_t vxlt = vzip1_s8(vxls, vxhs);
-            const int8x8_t vxht = vzip2_s8(vxls, vxhs);
-
-            const int8x16_t vxq = vcombine_s8(vxlt, vxht);
-
-            // convert to 2x int16x8_t
-            const int16x8_t vxq0 = vmovl_s8(vget_low_s8 (vxq));
-            const int16x8_t vxq1 = vmovl_s8(vget_high_s8(vxq));
-
-            // convert to 4x float32x4_t
-            const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq0)));
-            const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq0)));
-            const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq1)));
-            const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq1)));
-
-            const float32x4_t vy0 = vld1q_f32(y + i*32 + j*16 + 0);
-            const float32x4_t vy1 = vld1q_f32(y + i*32 + j*16 + 4);
-            const float32x4_t vy2 = vld1q_f32(y + i*32 + j*16 + 8);
-            const float32x4_t vy3 = vld1q_f32(y + i*32 + j*16 + 12);
-
-            const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd);
-            const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd);
-            const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd);
-            const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd);
-
-            vst1q_f32(y + i*32 + j*16 + 0,  vr0);
-            vst1q_f32(y + i*32 + j*16 + 4,  vr1);
-            vst1q_f32(y + i*32 + j*16 + 8,  vr2);
-            vst1q_f32(y + i*32 + j*16 + 12, vr3);
-        }
-    }
-#endif
-#else
-    // scalar
-    for (int i = 0; i < nb; i++) {
-        const float d = *(const float *) (pd + i*bs);
-
-        const uint8_t * restrict pp = pb + i*bs;
-
-        for (int l = 0; l < QK; l += 2) {
-            const uint8_t vi = pp[l/2];
-
-            const int8_t vi0 = vi & 0xf;
-            const int8_t vi1 = vi >> 4;
-
-            const float v0 = (vi0 - 8)*d;
-            const float v1 = (vi1 - 8)*d;
-
-            y[i*QK + l + 0] += v0*v;
-            y[i*QK + l + 1] += v1*v;
-
-            assert(!isnan(y[i*QK + l + 0]));
-            assert(!isnan(y[i*QK + l + 1]));
-            assert(!isinf(y[i*QK + l + 0]));
-            assert(!isinf(y[i*QK + l + 1]));
-        }
-    }
-#endif
-}
-
-inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * restrict x, const float v) {
-    assert(n % QK == 0);
-
-    const int nb = n / QK;
-    const size_t bs = 2*sizeof(float) + QK/2;
-
-    const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
-    const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs +   sizeof(float));
-    const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
-
-    for (int i = 0; i < nb; i++) {
-        const float d = *(const float *) (pd + i*bs);
-        const float m = *(const float *) (pm + i*bs);
-
-        const uint8_t * restrict pp = pb + i*bs;
-
-        for (int l = 0; l < QK; l += 2) {
-            const uint8_t vi = pp[l/2];
-
-            const uint8_t vi0 = vi & 0xf;
-            const uint8_t vi1 = vi >> 4;
-
-            const float v0 = d*vi0 + m;
-            const float v1 = d*vi1 + m;
-
-            y[i*QK + l + 0] += v0*v;
-            y[i*QK + l + 1] += v1*v;
-
-            assert(!isnan(y[i*QK + l + 0]));
-            assert(!isnan(y[i*QK + l + 1]));
-            assert(!isinf(y[i*QK + l + 0]));
-            assert(!isinf(y[i*QK + l + 1]));
-            //printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
-        }
-    }
-}
-
 //inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
 inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
 #if defined(GGML_SIMD)
@@ -2612,9 +2467,13 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
-        (t0->ne[0]  == t1->ne[0])  &&
-        (t0->ne[2]  == t1->ne[2])  &&
-        (t0->ne[3]  == t1->ne[3]);
+        (t0->ne[0] == t1->ne[0])  &&
+        (t0->ne[2] == t1->ne[2])  &&
+        (t0->ne[3] == t1->ne[3]);
+}
+
+static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
+    return tensor->nb[0] > tensor->nb[1];
 }
 
 static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
@@ -4010,6 +3869,7 @@ struct ggml_tensor * ggml_mul_mat(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
     GGML_ASSERT(ggml_can_mul_mat(a, b));
+    GGML_ASSERT(!ggml_is_transposed(a));
 
     bool is_node = false;
 
@@ -5949,7 +5809,7 @@ static void ggml_compute_forward_mul_mat_f32(
     assert(ne3  == ne13);
 
     // TODO: we don't support permuted src0
-    assert(nb00 == sizeof(float) || nb01 == sizeof(float));
+    assert(nb00 == sizeof(float));
 
     // dst cannot be transposed or permuted
     assert(nb0 == sizeof(float));
@@ -5964,9 +5824,6 @@ static void ggml_compute_forward_mul_mat_f32(
 
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
-    //
-    // nb00 <  nb01 - src0 is transposed
-    //   compute by src0 columns
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
@@ -6007,126 +5864,50 @@ static void ggml_compute_forward_mul_mat_f32(
 #endif
 
     if (params->type == GGML_TASK_INIT) {
-        if (nb01 >= nb00) {
-            return;
-        }
-
-        // TODO: fix this memset (wsize is overestimated)
-        memset(params->wdata, 0, params->wsize);
         return;
     }
 
     if (params->type == GGML_TASK_FINALIZE) {
-        if (nb01 >= nb00) {
-            return;
-        }
-
-        // TODO: fix this memset (wsize is overestimated)
-        //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
-
-        float * const wdata = params->wdata;
-
-        // cols per thread
-        const int dc = (ne + nth - 1)/nth;
-
-        // col range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, ne);
-
-        ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
-
-        for (int k = 1; k < nth; k++) {
-            ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
-        }
-
         return;
     }
 
-    if (nb01 >= nb00) {
-        // TODO: do not support transposed src1
-        assert(nb10 == sizeof(float));
+    // TODO: do not support transposed src1
+    assert(nb10 == sizeof(float));
 
-        // parallelize by src0 rows using ggml_vec_dot_f32
+    // parallelize by src0 rows using ggml_vec_dot_f32
 
-        // total rows in src0
-        const int nr = ne01*ne02*ne03;
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
 
-        // rows per thread
-        const int dr = (nr + nth - 1)/nth;
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
 
-        // row range for this thread
-        const int ir0 = dr*ith;
-        const int ir1 = MIN(ir0 + dr, nr);
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
 
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0 indices
-            const int i03 = ir/(ne02*ne01);
-            const int i02 = (ir - i03*ne02*ne01)/ne01;
-            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-            for (int ic = 0; ic < ne11; ++ic) {
-                // src1 indices
-                const int i13 = i03;
-                const int i12 = i02;
-                const int i11 = ic;
+        for (int ic = 0; ic < ne11; ++ic) {
+            // src1 indices
+            const int i13 = i03;
+            const int i12 = i02;
+            const int i11 = ic;
 
-                // dst indices
-                const int i0 = i01;
-                const int i1 = i11;
-                const int i2 = i02;
-                const int i3 = i03;
+            // dst indices
+            const int i0 = i01;
+            const int i1 = i11;
+            const int i2 = i02;
+            const int i3 = i03;
 
-                ggml_vec_dot_f32(ne00,
-                        (float *) ((char *)  dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
-                        (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
-                        (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
-            }
-        }
-    } else {
-        // parallelize by src1 columns using ggml_vec_mad_f32
-        // each thread has its own work data
-        // during FINALIZE we accumulate all work data into dst
-
-        // total columns in src1
-        const int nc = ne10;
-
-        // columns per thread
-        const int dc = (nc + nth - 1)/nth;
-
-        // column range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, nc);
-
-        // work data for thread
-        const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
-        float * const wdata = params->wdata;
-
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                for (int i11 = 0; i11 < ne11; ++i11) {
-                    for (int ic = ic0; ic < ic1; ++ic) {
-                        // src1 indices
-                        const int i10 = ic;
-
-                        // src0 indices
-                        const int i03 = i13;
-                        const int i02 = i12;
-                        const int i00 = ic;
-
-                        // dst indices
-                        const int i1 = i11;
-                        const int i2 = i12;
-                        const int i3 = i13;
-
-                        assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
-
-                        ggml_vec_mad_f32(ne01,
-                                (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0),
-                                (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)),
-                               *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)));
-                    }
-                }
-            }
+            ggml_vec_dot_f32(ne00,
+                    (float *) ((char *)  dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+                    (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
+                    (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
         }
     }
 
@@ -6192,7 +5973,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     GGML_ASSERT(ne3  == ne13);
 
     // TODO: we don't support permuted src0
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -6207,9 +5988,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
-    //
-    // nb00 <  nb01 - src0 is transposed
-    //   compute by src0 columns
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
@@ -6261,148 +6039,66 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 #endif
 
     if (params->type == GGML_TASK_INIT) {
-        if (nb01 >= nb00) {
-            ggml_fp16_t * const wdata = params->wdata;
+        ggml_fp16_t * const wdata = params->wdata;
 
-            size_t id = 0;
-            for (int i13 = 0; i13 < ne13; ++i13) {
-                for (int i12 = 0; i12 < ne12; ++i12) {
-                    for (int i11 = 0; i11 < ne11; ++i11) {
-                        for (int i10 = 0; i10 < ne10; ++i10) {
-                            wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
-                        }
+        size_t id = 0;
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                for (int i11 = 0; i11 < ne11; ++i11) {
+                    for (int i10 = 0; i10 < ne10; ++i10) {
+                        wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
                     }
                 }
             }
-
-            GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
-
-            return;
         }
 
-        // TODO: fix this memset (wsize is overestimated)
-        memset(params->wdata, 0, params->wsize);
+        GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
+
         return;
     }
 
     if (params->type == GGML_TASK_FINALIZE) {
-        if (nb01 >= nb00) {
-            return;
-        }
-
-        // TODO: fix this memset (wsize is overestimated)
-        //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
-
-        ggml_fp16_t * const wdata = params->wdata;
-
-        // cols per thread
-        const int dc = (ne + nth - 1)/nth;
-
-        // col range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, ne);
-
-        for (int i = ic0; i < ic1; ++i) {
-            ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]);
-        }
-
-        for (int k = 1; k < nth; k++) {
-            for (int i = ic0; i < ic1; ++i) {
-                ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]);
-            }
-        }
-
         return;
     }
 
-    if (nb01 >= nb00) {
-        // fp16 -> half the size, so divide by 2
-        // TODO: do not support transposed src1
-        assert(nb10/2 == sizeof(ggml_fp16_t));
+    // fp16 -> half the size, so divide by 2
+    // TODO: do not support transposed src1
+    assert(nb10/2 == sizeof(ggml_fp16_t));
 
-        // parallelize by src0 rows using ggml_vec_dot_f16
+    // parallelize by src0 rows using ggml_vec_dot_f16
 
-        // total rows in src0
-        const int nr = ne01*ne02*ne03;
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
 
-        // rows per thread
-        const int dr = (nr + nth - 1)/nth;
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
 
-        // row range for this thread
-        const int ir0 = dr*ith;
-        const int ir1 = MIN(ir0 + dr, nr);
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
 
-        ggml_fp16_t * wdata = params->wdata;
+    ggml_fp16_t * wdata = params->wdata;
 
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0 indices
-            const int i03 = ir/(ne02*ne01);
-            const int i02 = (ir - i03*ne02*ne01)/ne01;
-            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-            const int i13 = i03;
-            const int i12 = i02;
+        const int i13 = i03;
+        const int i12 = i02;
 
-            const int i0 = i01;
-            const int i2 = i02;
-            const int i3 = i03;
+        const int i0 = i01;
+        const int i2 = i02;
+        const int i3 = i03;
 
-            ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-            ggml_fp16_t * src1_col =                                wdata + (       0 + i12*ne11 + i13*ne12*ne11)*ne00;
+        ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        ggml_fp16_t * src1_col =                                wdata + (       0 + i12*ne11 + i13*ne12*ne11)*ne00;
 
-            float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
 
-            for (int ic = 0; ic < ne11; ++ic) {
-                ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
-            }
-        }
-    } else {
-        // parallelize by src1 columns using ggml_vec_mad_f16
-        // each thread has its own work data
-        // during FINALIZE we accumulate all work data into dst
-
-        // total columns in src1
-        const int nc = ne10;
-
-        // columns per thread
-        const int dc = (nc + nth - 1)/nth;
-
-        // column range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, nc);
-
-        // work data for thread
-        const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
-        ggml_fp16_t * const wdata = params->wdata;
-
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                for (int i11 = 0; i11 < ne11; ++i11) {
-                    // dst indices
-                    const int i1 = i11;
-                    const int i2 = i12;
-                    const int i3 = i13;
-
-                    ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
-
-                    for (int ic = ic0; ic < ic1; ++ic) {
-                        // src1 indices
-                        const int i10 = ic;
-
-                        // src0 indices
-                        const int i03 = i13;
-                        const int i02 = i12;
-                        const int i00 = ic;
-
-                        assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
-
-                        ggml_fp16_t * src0_col =  (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
-                        float         src1_val = *      (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-
-                        ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val);
-                    }
-                }
-            }
+        for (int ic = 0; ic < ne11; ++ic) {
+            ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
         }
     }
 
@@ -6467,7 +6163,7 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
     GGML_ASSERT(ne3  == ne13);
 
     // TODO: we don't support permuted src0
-    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0] || nb01 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -6482,9 +6178,6 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
 
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
-    //
-    // nb00 <  nb01 - src0 is transposed
-    //   compute by src0 columns
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
@@ -6509,9 +6202,6 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
                 {
                     size_t id = 0;
                     for (int i01 = 0; i01 < ne01; ++i01) {
-                        //for (int i00 = 0; i00 < ne00; ++i00) {
-                        //    wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
-                        //}
                         dequantize_row_q4_0((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
                         id += ne00;
                     }
@@ -6538,143 +6228,63 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
 #endif
 
     if (params->type == GGML_TASK_INIT) {
-        //printf("HHHHHHHHH ith = %d, nth = %d\n", ith, nth);
-        if (nb01 >= nb00) {
-            char * wdata = params->wdata;
-
-            for (int i13 = 0; i13 < ne13; ++i13) {
-                for (int i12 = 0; i12 < ne12; ++i12) {
-                    for (int i11 = 0; i11 < ne11; ++i11) {
-                        //for (int i10 = 0; i10 < ne10; ++i10) {
-                        //    wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
-                        //}
-                        quantize_row_q4_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
-                        wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
-                    }
-                }
-            }
-
-            return;
-        }
-
-        // TODO: fix this memset (wsize is overestimated)
-        memset(params->wdata, 0, params->wsize);
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        if (nb01 >= nb00) {
-            return;
-        }
-
-        float * const wdata = params->wdata;
-
-        // cols per thread
-        const int dc = (ne + nth - 1)/nth;
-
-        // col range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, ne);
-
-        ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
-
-        for (int k = 1; k < nth; k++) {
-            ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
-        }
-
-        return;
-    }
-
-    if (nb01 >= nb00) {
-        // TODO: do not support transposed src1
-
-        // parallelize by src0 rows using ggml_vec_dot_q4_0
-
-        // total rows in src0
-        const int nr = ne01*ne02*ne03;
-
-        // rows per thread
-        const int dr = (nr + nth - 1)/nth;
-
-        // row range for this thread
-        const int ir0 = dr*ith;
-        const int ir1 = MIN(ir0 + dr, nr);
-
-        void * wdata = params->wdata;
-
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0 indices
-            const int i03 = ir/(ne02*ne01);
-            const int i02 = (ir - i03*ne02*ne01)/ne01;
-            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int i13 = i03;
-            const int i12 = i02;
-
-            const int i0 = i01;
-            const int i2 = i02;
-            const int i3 = i03;
-
-            void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-            char * src1_col =          ((char *)      wdata + (      (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]);
-
-            float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
-
-            assert(ne00 % 32 == 0);
-
-            for (int ic = 0; ic < ne11; ++ic) {
-                ggml_vec_dot_q4_0(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0])));
-            }
-        }
-    } else {
-        //printf("AAAAA ith = %d, nth = %d\n", ith, nth);
-        // parallelize by src1 columns using ggml_vec_mad_q4_0
-        // each thread has its own work data
-        // during FINALIZE we accumulate all work data into dst
-
-        // total columns in src1
-        const int nc = ne10;
-
-        // columns per thread
-        const int dc = (nc + nth - 1)/nth;
-
-        // column range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, nc);
-
-        // work data for thread
-        const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
-        float * const wdata = params->wdata;
+        char * wdata = params->wdata;
 
         for (int i13 = 0; i13 < ne13; ++i13) {
             for (int i12 = 0; i12 < ne12; ++i12) {
                 for (int i11 = 0; i11 < ne11; ++i11) {
-                    // dst indices
-                    const int i1 = i11;
-                    const int i2 = i12;
-                    const int i3 = i13;
-
-                    float * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
-
-                    for (int ic = ic0; ic < ic1; ++ic) {
-                        // src1 indices
-                        const int i10 = ic;
-
-                        // src0 indices
-                        const int i03 = i13;
-                        const int i02 = i12;
-                        const int i00 = ic;
-
-                        assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
-
-                        void * src0_col =   (void *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
-                        float  src1_val = *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-
-                        ggml_vec_mad_q4_0(ne01, dst_row, src0_col, src1_val);
-                    }
+                    quantize_row_q4_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                    wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
                 }
             }
         }
+
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // TODO: do not support transposed src1
+
+    // parallelize by src0 rows using ggml_vec_dot_q4_0
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    void * wdata = params->wdata;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const int i13 = i03;
+        const int i12 = i02;
+
+        const int i0 = i01;
+        const int i2 = i02;
+        const int i3 = i03;
+
+        void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        char * src1_col =          ((char *)      wdata + (      (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]);
+
+        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+
+        assert(ne00 % 32 == 0);
+
+        for (int ic = 0; ic < ne11; ++ic) {
+            ggml_vec_dot_q4_0(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0])));
+        }
     }
 
     //int64_t t1 = ggml_time_us();
@@ -6738,7 +6348,7 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
     GGML_ASSERT(ne3  == ne13);
 
     // TODO: we don't support permuted src0
-    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1] || nb01 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1]);
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1]);
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -6753,9 +6363,6 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
 
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
-    //
-    // nb00 <  nb01 - src0 is transposed
-    //   compute by src0 columns
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
@@ -6780,9 +6387,6 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
                 {
                     size_t id = 0;
                     for (int i01 = 0; i01 < ne01; ++i01) {
-                        //for (int i00 = 0; i00 < ne00; ++i00) {
-                        //    wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
-                        //}
                         dequantize_row_q4_1((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
                         id += ne00;
                     }
@@ -6809,143 +6413,66 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
 #endif
 
     if (params->type == GGML_TASK_INIT) {
-        //printf("HHHHHHHHH ith = %d, nth = %d\n", ith, nth);
-        if (nb01 >= nb00) {
-            char * wdata = params->wdata;
-
-            for (int i13 = 0; i13 < ne13; ++i13) {
-                for (int i12 = 0; i12 < ne12; ++i12) {
-                    for (int i11 = 0; i11 < ne11; ++i11) {
-                        //for (int i10 = 0; i10 < ne10; ++i10) {
-                        //    wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
-                        //}
-                        quantize_row_q4_1((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
-                        wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
-                    }
-                }
-            }
-
-            return;
-        }
-
-        // TODO: fix this memset (wsize is overestimated)
-        memset(params->wdata, 0, params->wsize);
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        if (nb01 >= nb00) {
-            return;
-        }
-
-        float * const wdata = params->wdata;
-
-        // cols per thread
-        const int dc = (ne + nth - 1)/nth;
-
-        // col range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, ne);
-
-        ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
-
-        for (int k = 1; k < nth; k++) {
-            ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
-        }
-
-        return;
-    }
-
-    if (nb01 >= nb00) {
-        // TODO: do not support transposed src1
-
-        // parallelize by src0 rows using ggml_vec_dot_q4_1
-
-        // total rows in src0
-        const int nr = ne01*ne02*ne03;
-
-        // rows per thread
-        const int dr = (nr + nth - 1)/nth;
-
-        // row range for this thread
-        const int ir0 = dr*ith;
-        const int ir1 = MIN(ir0 + dr, nr);
-
-        void * wdata = params->wdata;
-
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0 indices
-            const int i03 = ir/(ne02*ne01);
-            const int i02 = (ir - i03*ne02*ne01)/ne01;
-            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int i13 = i03;
-            const int i12 = i02;
-
-            const int i0 = i01;
-            const int i2 = i02;
-            const int i3 = i03;
-
-            void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-            char * src1_col =          ((char *)      wdata + (      (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]);
-
-            float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
-
-            assert(ne00 % 32 == 0);
-
-            for (int ic = 0; ic < ne11; ++ic) {
-                ggml_vec_dot_q4_1(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1])));
-            }
-        }
-    } else {
-        //printf("AAAAA ith = %d, nth = %d\n", ith, nth);
-        // parallelize by src1 columns using ggml_vec_mad_q4_1
-        // each thread has its own work data
-        // during FINALIZE we accumulate all work data into dst
-
-        // total columns in src1
-        const int nc = ne10;
-
-        // columns per thread
-        const int dc = (nc + nth - 1)/nth;
-
-        // column range for this thread
-        const int ic0 = dc*ith;
-        const int ic1 = MIN(ic0 + dc, nc);
-
-        // work data for thread
-        const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
-        float * const wdata = params->wdata;
+        char * wdata = params->wdata;
 
         for (int i13 = 0; i13 < ne13; ++i13) {
             for (int i12 = 0; i12 < ne12; ++i12) {
                 for (int i11 = 0; i11 < ne11; ++i11) {
-                    // dst indices
-                    const int i1 = i11;
-                    const int i2 = i12;
-                    const int i3 = i13;
-
-                    float * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
-
-                    for (int ic = ic0; ic < ic1; ++ic) {
-                        // src1 indices
-                        const int i10 = ic;
-
-                        // src0 indices
-                        const int i03 = i13;
-                        const int i02 = i12;
-                        const int i00 = ic;
-
-                        assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
-
-                        void * src0_col =   (void *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
-                        float  src1_val = *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-
-                        ggml_vec_mad_q4_1(ne01, dst_row, src0_col, src1_val);
-                    }
+                    //for (int i10 = 0; i10 < ne10; ++i10) {
+                    //    wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
+                    //}
+                    quantize_row_q4_1((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                    wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
                 }
             }
         }
+
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // TODO: do not support transposed src1
+
+    // parallelize by src0 rows using ggml_vec_dot_q4_1
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    void * wdata = params->wdata;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const int i13 = i03;
+        const int i12 = i02;
+
+        const int i0 = i01;
+        const int i2 = i02;
+        const int i3 = i03;
+
+        void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        char * src1_col =          ((char *)      wdata + (      (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]);
+
+        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+
+        assert(ne00 % 32 == 0);
+
+        for (int ic = 0; ic < ne11; ++ic) {
+            ggml_vec_dot_q4_1(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1])));
+        }
     }
 
     //int64_t t1 = ggml_time_us();
@@ -9588,57 +9115,51 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
                         size_t cur = 0;
 
-                        // TODO: better way to determine if the matrix is transposed
-                        if (node->src0->nb[1] < node->src0->nb[0]) {
-                            cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1)
-                                                                   // TODO: overestimated by factor of x2 for FP16
-                        } else {
-                            if (node->src0->type == GGML_TYPE_F16 &&
+                        if (node->src0->type == GGML_TYPE_F16 &&
                                 node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-                                if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                                    node->n_tasks = 1; // TODO: this actually is doing nothing
-                                                       //       the threads are still spinning
-                                    cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
-                                    //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
-                                    //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
-                                    //printf("cur = %zu\n", cur);
-                                } else {
-                                    cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
-                                }
-#else
-                                cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
-#endif
-                            } else if (node->src0->type == GGML_TYPE_F32 &&
-                                       node->src1->type == GGML_TYPE_F32) {
-                                cur = 0;
-                            } else if (node->src0->type == GGML_TYPE_Q4_0 &&
-                                       node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-                                if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                                    node->n_tasks = 1;
-                                    cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
-                                } else {
-                                    cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
-                                }
-#else
-                                cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
-#endif
-                            } else if (node->src0->type == GGML_TYPE_Q4_1 &&
-                                       node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-                                if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                                    node->n_tasks = 1;
-                                    cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
-                                } else {
-                                    cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
-                                }
-#else
-                                cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
-#endif
+                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                                node->n_tasks = 1; // TODO: this actually is doing nothing
+                                                   //       the threads are still spinning
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+                                //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
+                                //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
+                                //printf("cur = %zu\n", cur);
                             } else {
-                                GGML_ASSERT(false);
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
                             }
+#else
+                            cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
+#endif
+                        } else if (node->src0->type == GGML_TYPE_F32 &&
+                                node->src1->type == GGML_TYPE_F32) {
+                            cur = 0;
+                        } else if (node->src0->type == GGML_TYPE_Q4_0 &&
+                                node->src1->type == GGML_TYPE_F32) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                                node->n_tasks = 1;
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+                            } else {
+                                cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
+                            }
+#else
+                            cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
+#endif
+                        } else if (node->src0->type == GGML_TYPE_Q4_1 &&
+                                node->src1->type == GGML_TYPE_F32) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                                node->n_tasks = 1;
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+                            } else {
+                                cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
+                            }
+#else
+                            cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
+#endif
+                        } else {
+                            GGML_ASSERT(false);
                         }
 
                         work_size = MAX(work_size, cur);