From a96ec01b1aff73532396c8e4d8898658250482d8 Mon Sep 17 00:00:00 2001
From: saharNooby <saharnooby@protonmail.com>
Date: Mon, 17 Apr 2023 16:47:11 +0400
Subject: [PATCH] Revert "Replace ggml_1_minus_x with ggml_sub"

This reverts commit 189ad78a0ddc9cb13ed488b49cf52e5c4b3359cc.
---
 rwkv.cpp | 17 +++++------------
 1 file changed, 5 insertions(+), 12 deletions(-)

diff --git a/rwkv.cpp b/rwkv.cpp
index 44f3952..3e004b6 100644
--- a/rwkv.cpp
+++ b/rwkv.cpp
@@ -311,13 +311,6 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
     // Build graph
     struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed);
 
-    // Constant vector for (1 - x) operation
-    struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed);
-
-    for (int i = 0; i < n_embed; i++) {
-        *((float *) ones->data + i) = 1.0F;
-    }
-
     // x = self.w.emb.weight[token]
     struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
     struct ggml_tensor * x = ggml_get_rows(ctx, model->emb, token_index);
@@ -343,17 +336,17 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
             struct ggml_tensor * xk = ggml_add(
                 ctx,
                 ggml_mul(ctx, x0, layer.att_time_mix_k),
-                ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_k))
+                ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_k))
             );
             struct ggml_tensor * xv = ggml_add(
                 ctx,
                 ggml_mul(ctx, x0, layer.att_time_mix_v),
-                ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_v))
+                ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_v))
             );
             struct ggml_tensor * xr = ggml_add(
                 ctx,
                 ggml_mul(ctx, x0, layer.att_time_mix_r),
-                ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_r))
+                ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_r))
             );
             // state[5 * i + 1] = x
             state_parts[5 * i + 1] = x0;
@@ -442,12 +435,12 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
             struct ggml_tensor * xk = ggml_add(
                 ctx,
                 ggml_mul(ctx, x0, layer.ffn_time_mix_k),
-                ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.ffn_time_mix_k))
+                ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_k))
             );
             struct ggml_tensor * xr = ggml_add(
                 ctx,
                 ggml_mul(ctx, x0, layer.ffn_time_mix_r),
-                ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.ffn_time_mix_r))
+                ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_r))
             );
             // state[5 * i + 0] = x
             state_parts[5 * i + 0] = x0;