From 01d667f0660f4f0c8b85e99a48c078a21856c053 Mon Sep 17 00:00:00 2001
From: saharNooby <saharnooby@protonmail.com>
Date: Fri, 31 Mar 2023 19:04:35 +0400
Subject: [PATCH] Implement exp, max, 1_minus_x, sigmoid operators in ggml

---
 examples/main_rwkv/main_rwkv.cpp |  63 +++--
 ggml.c                           | 428 ++++++++++++++++++++++++++++++-
 ggml.h                           |  61 +++++
 3 files changed, 509 insertions(+), 43 deletions(-)

diff --git a/examples/main_rwkv/main_rwkv.cpp b/examples/main_rwkv/main_rwkv.cpp
index 9f81e1a..d392292 100644
--- a/examples/main_rwkv/main_rwkv.cpp
+++ b/examples/main_rwkv/main_rwkv.cpp
@@ -323,31 +323,23 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
 
 // --- Operators ---
 
-// TODO Fuse and benchmark
 struct ggml_tensor * ggml_layer_norm(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) {
-    // LayerNorm in RWKV is:
-    // x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias
-    // Looks like ggml_norm does the first part, we only need to apply weight & bias
+    // LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias`
+    // Looks like ggml_norm does the first part; we only need to apply weight & bias.
     x = ggml_norm(ctx, x);
     x = ggml_mul(ctx, x, weight);
     x = ggml_add(ctx, x, bias);
     return x;
 }
 
-// TODO Fuse and benchmark
-struct ggml_tensor * ggml_sigmoid(ggml_context * ctx, struct ggml_tensor * x) {
-    // ggml has no native sigmoid, but silu(x) / x can be an approximation
-    x = ggml_silu(ctx, x);
-    x = ggml_div(ctx, x, x);
-    return x;
-}
-
 // --- Script ---
 
 // Usage: main_rwkv.exe "C:\model.bin" <token index> "C:\state_in.bin" "C:\state_out.bin" "C:\logits_out.bin"
 // Token index is 0-based.
 // To start from new state, pass empty string instead of input state file path.
 int main(int argc, char ** argv) {
+    ggml_run_test_suite();
+
     RWKV_ASSERT(argc - 1 == 5, "Expected 5 arguments, got %d", argc - 1);
     char * model_path = argv[1];
     char * token_s = argv[2];
@@ -408,9 +400,6 @@ int main(int argc, char ** argv) {
 
     // --- Evaluate model ---
 
-    struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed);
-    ggml_set_f32(ones, 1.0F);
-
     // x = self.w.emb.weight[token]
     // TODO Replace with ggml_get_rows or similar
     struct ggml_tensor * one_hot = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_vocab, 1);
@@ -433,27 +422,33 @@ int main(int argc, char ** argv) {
             struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
             // state[5 * i + 1]
             struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * 4);
+            COMPUTE_AND_PRINT_TENSOR(ctx, x_prev);
             // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
             // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
             // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
             struct ggml_tensor * xk = ggml_add(
                 ctx,
                 ggml_mul(ctx, x0, layer.att_time_mix_k),
-                ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_k))
+                ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_k))
             );
             struct ggml_tensor * xv = ggml_add(
                 ctx,
                 ggml_mul(ctx, x0, layer.att_time_mix_v),
-                ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_v))
+                ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_v))
             );
             struct ggml_tensor * xr = ggml_add(
                 ctx,
                 ggml_mul(ctx, x0, layer.att_time_mix_r),
-                ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_r))
+                ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_r))
             );
             // state[5 * i + 1] = x
             ggml_cpy(ctx, x0, x_prev);
 
+            COMPUTE_AND_PRINT_TENSOR(ctx, xk);
+            COMPUTE_AND_PRINT_TENSOR(ctx, xv);
+            COMPUTE_AND_PRINT_TENSOR(ctx, xr);
+            COMPUTE_AND_PRINT_TENSOR(ctx, x_prev);
+
             // r = torch.sigmoid(rw @ xr)
             struct ggml_tensor * r = ggml_sigmoid(
                 ctx,
@@ -474,14 +469,11 @@ int main(int argc, char ** argv) {
             // ww = time_first + k
             struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k);
             // qq = torch.maximum(pp, ww)
-            // TODO Implement element-wise max in ggml
-            struct ggml_tensor * qq = pp;
+            struct ggml_tensor * qq = ggml_max(ctx, pp, ww);
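+            // Subtracting the running maximum qq before exponentiating keeps both
+            // exponents non-positive, so the exp() calls below cannot overflow.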
             // e1 = torch.exp(pp - qq)
-            // TODO Implement element-wise exp in ggml
-            struct ggml_tensor * e1 = ggml_sub(ctx, pp, qq);
+            struct ggml_tensor * e1 = ggml_exp(ctx, ggml_sub(ctx, pp, qq));
             // e2 = torch.exp(ww - qq)
-            // TODO Use exp
-            struct ggml_tensor * e2 = ggml_sub(ctx, ww, qq);
+            struct ggml_tensor * e2 = ggml_exp(ctx, ggml_sub(ctx, ww, qq));
             // a = e1 * aa + e2 * v
             struct ggml_tensor * a = ggml_add(
                 ctx,
@@ -499,27 +491,27 @@ int main(int argc, char ** argv) {
             // ww = pp + time_decay
             ww = ggml_add(ctx, pp, layer.att_time_decay);
             // qq = torch.maximum(ww, k)
-            // TODO Use max
-            qq = ww;
+            qq = ggml_max(ctx, ww, k);
             // e1 = torch.exp(ww - qq)
-            // TODO Use exp
-            e1 = ggml_sub(ctx, ww, qq);
+            e1 = ggml_exp(ctx, ggml_sub(ctx, ww, qq));
             // e2 = torch.exp(k - qq)
-            // TODO Use exp
-            e2 = ggml_sub(ctx, k, qq);
+            e2 = ggml_exp(ctx, ggml_sub(ctx, k, qq));
             // state[5 * i + 2] = e1 * aa + e2 * v
+            // TODO Must save the result of ggml_cpy, otherwise the copy is not computed
             ggml_cpy(ctx, ggml_add(
                 ctx,
                 ggml_mul(ctx, e1, aa),
                 ggml_mul(ctx, e2, v)
             ), aa);
             // state[5 * i + 3] = e1 * bb + e2
+            // TODO Must save the result of ggml_cpy, otherwise the copy is not computed
             ggml_cpy(ctx, ggml_add(
                 ctx,
                 ggml_mul(ctx, e1, bb),
                 e2
             ), bb);
             // state[5 * i + 4] = qq
+            // TODO Must save the result of ggml_cpy, otherwise the copy is not computed
             ggml_cpy(ctx, qq, pp);
             // ow @ (r * wkv)
             x = ggml_add(
@@ -531,6 +523,8 @@ int main(int argc, char ** argv) {
                     ggml_mul(ctx, r, wkv)
                 )
             );
+            RWKV_LOG("RWKV %d completed", i);
+            COMPUTE_AND_PRINT_TENSOR(ctx, x);
         }
 
         // FFN/channel mixing
@@ -544,14 +538,15 @@ int main(int argc, char ** argv) {
             struct ggml_tensor * xk = ggml_add(
                 ctx,
                 ggml_mul(ctx, x0, layer.ffn_time_mix_k),
-                ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.ffn_time_mix_k))
+                ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_k))
             );
             struct ggml_tensor * xr = ggml_add(
                 ctx,
                 ggml_mul(ctx, x0, layer.ffn_time_mix_r),
-                ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.ffn_time_mix_r))
+                ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_r))
             );
             // state[5 * i + 0] = x
+            // TODO Must save the result of ggml_cpy, otherwise the copy is not computed
             ggml_cpy(ctx, x0, x_prev);
 
             // r = torch.sigmoid(rw @ xr)
@@ -574,6 +569,8 @@ int main(int argc, char ** argv) {
                     ggml_mul_mat(ctx, layer.ffn_value, k)
                 )
             );
+            RWKV_LOG("FFN %d completed", i);
+            COMPUTE_AND_PRINT_TENSOR(ctx, x);
         }
     }
 
@@ -588,6 +585,8 @@ int main(int argc, char ** argv) {
     // TODO -nan(ind) -nan(ind) ... (maybe implement exp/max first?)
     PRINT_TENSOR(logits);
 
+    // TODO Save new state and logits
+
     ggml_free(ctx);
 
     RWKV_LOG("OK");
diff --git a/ggml.c b/ggml.c
index b7d79ab..a7d932f 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1603,15 +1603,18 @@ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) {
 
 inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 
-inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
-inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }
-inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }
-inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }
-inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }
-inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }
-inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }
-inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
-inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y)              { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i];     }
+inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                               { for (int i = 0; i < n; ++i) y[i] += x[i];            }
+inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                               { for (int i = 0; i < n; ++i) y[i] += v;               }
+inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y)              { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i];     }
+inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                               { for (int i = 0; i < n; ++i) x[i]  = v;               }
+inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                               { for (int i = 0; i < n; ++i) y[i]  = x[i];            }
+inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                               { for (int i = 0; i < n; ++i) y[i]  = -x[i];           }
+inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x)                               { for (int i = 0; i < n; ++i) y[i]  = expf(x[i]);      }
+inline static void ggml_vec_1_minus_x_f32 (const int n, float * y, const float * x)                         { for (int i = 0; i < n; ++i) y[i]  = 1 - x[i];        }
+inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y)              { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];       }
+inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y)              { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];       }
+inline static void ggml_vec_element_wise_max_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = fmaxf(x[i], y[i]); }
 
 inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
 #ifdef GGML_SIMD
@@ -2304,6 +2307,17 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 }
 #endif
 
+// Sigmoid function
+inline static float ggml_sigmoid_f32(float x) {
+    return 1.0F / (1.0F + expf(-x));
+}
+
+inline static void ggml_vec_sigmoid_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_sigmoid_f32(x[i]);
+    }
+}
+
 // Sigmoid Linear Unit (SiLU) function
 inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
@@ -2457,7 +2471,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "FLASH_FF",
 };
 
-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2501,7 +2515,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_ff(x)",
 };
 
-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
 
 //
 // ggml object
@@ -3866,6 +3880,54 @@ struct ggml_tensor * ggml_neg_inplace(
     return ggml_neg_impl(ctx, a, true);
 }
 
+// ggml_exp
+
+struct ggml_tensor * ggml_exp(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_EXP;
+    result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_1_minus_x
+
+struct ggml_tensor * ggml_1_minus_x(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_1_MINUS_X;
+    result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_max
+
+struct ggml_tensor * ggml_max(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_MAX;
+    result->grad = (a->grad || b->grad) ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_step
 
 struct ggml_tensor * ggml_step_impl(
@@ -3968,6 +4030,21 @@ struct ggml_tensor * ggml_gelu_inplace(
     return ggml_gelu_impl(ctx, a, true);
 }
 
+// ggml_sigmoid
+
+struct ggml_tensor * ggml_sigmoid(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SIGMOID;
+    result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
 // ggml_silu
 
 struct ggml_tensor * ggml_silu_impl(
@@ -5546,6 +5623,154 @@ static void ggml_compute_forward_neg(
     }
 }
 
+// ggml_compute_forward_exp
+
+static void ggml_compute_forward_exp_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_exp_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_exp(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_exp_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_1_minus_x
+
+static void ggml_compute_forward_1_minus_x_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_1_minus_x_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_1_minus_x(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_1_minus_x_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_max
+
+static void ggml_compute_forward_max_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_element_wise_max_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_max(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_max_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_step
 
 static void ggml_compute_forward_step_f32(
@@ -5709,6 +5934,54 @@ static void ggml_compute_forward_gelu(
     //printf("XXXXXXXX gelu\n");
 }
 
+// ggml_compute_forward_sigmoid
+
+static void ggml_compute_forward_sigmoid_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sigmoid_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sigmoid(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sigmoid_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_silu
 
 static void ggml_compute_forward_silu_f32(
@@ -8423,6 +8696,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_neg(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_EXP:
+            {
+                ggml_compute_forward_exp(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_1_MINUS_X:
+            {
+                ggml_compute_forward_1_minus_x(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_MAX:
+            {
+                ggml_compute_forward_max(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_STEP:
             {
                 ggml_compute_forward_step(params, tensor->src0, tensor);
@@ -8435,6 +8720,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gelu(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SIGMOID:
+            {
+                ggml_compute_forward_sigmoid(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_SILU:
             {
                 ggml_compute_forward_silu(params, tensor->src0, tensor);
@@ -8660,6 +8949,18 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
                 }
             } break;
+        case GGML_OP_EXP:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_1_MINUS_X:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_MAX:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_STEP:
             {
                 if (src0->grad) {
@@ -8681,6 +8982,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_SIGMOID:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SILU:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -9101,8 +9406,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 case GGML_OP_ABS:
                 case GGML_OP_SGN:
                 case GGML_OP_NEG:
+                case GGML_OP_EXP:
+                case GGML_OP_1_MINUS_X:
+                case GGML_OP_MAX:
                 case GGML_OP_STEP:
                 case GGML_OP_RELU:
+                case GGML_OP_SIGMOID:
                     {
                         node->n_tasks = 1;
                     } break;
@@ -10469,3 +10778,100 @@ int ggml_cpu_has_vsx(void) {
 }
 
 ////////////////////////////////////////////////////////////////////////////////
+
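+// Helpers for the test suite below: they read/write an F32 tensor element by
+// flat index; the assertion uses an absolute tolerance of 1e-4 against values
+// precomputed with PyTorch.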
+#define GGML_TEST_SET_ELEMENT_F32(tensor, i, value) *(float *) ((char *) tensor->data + 4 * i) = value
+
+#define GGML_TEST_ASSERT_ELEMENT_F32(tensor, i, expected_value) do {\
+        float actual = *(float *) ((char *) tensor->data + 4 * i);\
+        if (fabsf(actual - expected_value) >= 0.0001F) {\
+            fprintf(stderr, "*** Assertion failed ***\n");\
+            fprintf(stderr, "At %s[%d]: expected %f, actual %f\n", #tensor, i, expected_value, actual);\
+            fprintf(stderr, "%s:%d\n", __FILE__, __LINE__);\
+            abort();\
+        }\
+    } while (0)
+
+void ggml_run_test_suite(void) {
+    fprintf(stderr, "Running ggml test suite...\n");
+
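+    // A 16 KiB context should be plenty for the small 3x2 test tensors used below.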
+    struct ggml_init_params params;
+    params.mem_size = 16 * 1024;
+    params.mem_buffer = NULL;
+    struct ggml_context * ctx = ggml_init(params);
+
+    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 2);
+    GGML_TEST_SET_ELEMENT_F32(a, 0, 1.0051F);
+    GGML_TEST_SET_ELEMENT_F32(a, 1, 1.0484F);
+    GGML_TEST_SET_ELEMENT_F32(a, 2, -0.4361F);
+    GGML_TEST_SET_ELEMENT_F32(a, 3, -0.6984F);
+    GGML_TEST_SET_ELEMENT_F32(a, 4, 1.7310F);
+    GGML_TEST_SET_ELEMENT_F32(a, 5, -0.0446F);
+
+    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 2);
+    GGML_TEST_SET_ELEMENT_F32(b, 0, -0.2566F);
+    GGML_TEST_SET_ELEMENT_F32(b, 1, -0.1412F);
+    GGML_TEST_SET_ELEMENT_F32(b, 2, 1.6200F);
+    GGML_TEST_SET_ELEMENT_F32(b, 3, 0.5156F);
+    GGML_TEST_SET_ELEMENT_F32(b, 4, -0.3934F);
+    GGML_TEST_SET_ELEMENT_F32(b, 5, -0.0694F);
+
+    // Test against torch.exp(a)
+    struct ggml_tensor * exp_a = ggml_exp(ctx, a);
+
+    struct ggml_cgraph graph = ggml_build_forward(exp_a);
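+    // Two threads exercise the thread pool, although these operators declare
+    // n_tasks = 1 in ggml_graph_compute and therefore run on a single thread.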
+    graph.n_threads = 2;
+    ggml_graph_compute(ctx, &graph);
+
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 0, 2.7322F);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 1, 2.8531F);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 2, 0.6466F);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 3, 0.4974F);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 4, 5.6463F);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 5, 0.9564F);
+
+    // Test against (1 - a) in PyTorch
+    struct ggml_tensor * one_minus_a = ggml_1_minus_x(ctx, a);
+
+    graph = ggml_build_forward(one_minus_a);
+    graph.n_threads = 2;
+    ggml_graph_compute(ctx, &graph);
+
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 0, -0.0051F);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 1, -0.0484F);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 2, 1.4361F);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 3, 1.6984F);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 4, -0.7310F);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 5, 1.0446F);
+
+    // Test against torch.sigmoid(a)
+    struct ggml_tensor * sigmoid_a = ggml_sigmoid(ctx, a);
+
+    graph = ggml_build_forward(sigmoid_a);
+    graph.n_threads = 2;
+    ggml_graph_compute(ctx, &graph);
+
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 0, 0.7321F);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 1, 0.7405F);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 2, 0.3927F);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 3, 0.3322F);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 4, 0.8495F);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 5, 0.4889F);
+
+    // Test against torch.maximum(a, b)
+    struct ggml_tensor * max_a_b = ggml_max(ctx, a, b);
+
+    graph = ggml_build_forward(max_a_b);
+    graph.n_threads = 2;
+    ggml_graph_compute(ctx, &graph);
+
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 0, 1.0051F);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 1, 1.0484F);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 2, 1.6200F);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 3, 0.5156F);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 4, 1.7310F);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 5, -0.0446F);
+
+    ggml_free(ctx);
+
+    fprintf(stderr, "All ggml tests pass\n");
+}
diff --git a/ggml.h b/ggml.h
index 335230f..0b7f9a3 100644
--- a/ggml.h
+++ b/ggml.h
@@ -167,6 +167,32 @@
 //
 // TODO
 //
+// ## Adding new operators
+//
+// Suppose you want to add an `e^x` unary operator. The following steps need to be done:
+//
+// In `ggml.h`:
+//
+// 1. Add member `GGML_OP_EXP` to `ggml_op` enum.
+// 2. Declare the operator function: `struct ggml_tensor * ggml_exp(struct ggml_context * ctx, struct ggml_tensor * x);`.
+//
+// In `ggml.c`:
+//
+// 1. Implement the `ggml_exp` function: it creates the result tensor and sets its operator and arguments.
+// 2. Create the forward computation function for FP32, `ggml_compute_forward_exp_f32`: it does the actual computation.
+// 3. If needed, create forward computation functions for other types: FP16, INT32, etc.
+// 4. Create the forward dispatch function `ggml_compute_forward_exp`: it dispatches the call based on the tensor's data type.
+// 5. Add `case GGML_OP_EXP`:
+//   - to `ggml_compute_forward` and call the forward dispatch function here.
+//   - to `ggml_compute_backward` and add `GGML_ASSERT(false)` here.
+//   - to `ggml_graph_compute` and add `node->n_tasks = 1` here.
+// 6. Fix all assertions that check the value of `GGML_OP_COUNT`: you've added one operator, so increment the asserted value by one.
+//
+// When in doubt, consult the code of existing operators similar to the one
+// you're implementing; a minimal sketch is shown below.
+// The resulting operator will work for the forward pass, but will lack a
+// backward implementation and multi-threading support.
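+//
+// For reference, the graph-building function for a unary operator is just
+// bookkeeping; this mirrors `ggml_exp` as implemented in this patch:
+//
+//   struct ggml_tensor * ggml_exp(
+//           struct ggml_context * ctx,
+//           struct ggml_tensor  * a) {
+//       struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+//
+//       result->op   = GGML_OP_EXP;
+//       result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL;
+//       result->src0 = a;
+//       result->src1 = NULL;
+//
+//       return result;
+//   }
+//
+// The actual math lives in `ggml_compute_forward_exp_f32`, which loops over
+// rows and applies `expf` element-wise.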
+//
+// TODO Describe how to implement the backward pass
+// TODO Describe how to add multi-threading support
 //
 
 #ifdef  __cplusplus
@@ -225,9 +251,22 @@ enum ggml_op {
     GGML_OP_ABS,
     GGML_OP_SGN,
     GGML_OP_NEG,
+    // Element-wise exponential function `e^x`.
+    // Same as `torch.exp(x)` from PyTorch.
+    GGML_OP_EXP,
+
+    // Element-wise `1 - x`.
+    GGML_OP_1_MINUS_X,
+
+    // Element-wise maximum of two tensors. Argument shapes must match.
+    // Same as `torch.maximum(x, y)` from PyTorch.
+    GGML_OP_MAX,
+
     GGML_OP_STEP,
     GGML_OP_RELU,
     GGML_OP_GELU,
+    // Element-wise sigmoid activation `1 / (1 + e^-x)`, also called the logistic function.
+    // Same as `torch.sigmoid(x)` from PyTorch.
+    GGML_OP_SIGMOID,
     GGML_OP_SILU,
     GGML_OP_NORM, // normalize
     GGML_OP_RMS_NORM,
@@ -463,6 +502,19 @@ struct ggml_tensor * ggml_neg(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
 
+struct ggml_tensor * ggml_exp(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_1_minus_x(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
+struct ggml_tensor * ggml_max(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
 struct ggml_tensor * ggml_step(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
@@ -476,6 +528,10 @@ struct ggml_tensor * ggml_gelu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
 
+struct ggml_tensor * ggml_sigmoid(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
 struct ggml_tensor * ggml_silu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
@@ -768,6 +824,11 @@ int ggml_cpu_has_blas(void);
 int ggml_cpu_has_sse3(void);
 int ggml_cpu_has_vsx(void);
 
+// Runs the test suite for ggml.
+// Returns normally if all tests pass; aborts execution if any test fails.
+void ggml_run_test_suite(void);
+
 #ifdef  __cplusplus
 }
 #endif