Revert "Replace ggml_1_minus_x with ggml_sub"

This reverts commit 189ad78a0d.
This commit is contained in:
saharNooby 2023-04-17 16:47:11 +04:00
parent 189ad78a0d
commit a96ec01b1a
1 changed files with 5 additions and 12 deletions

View File

@ -311,13 +311,6 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
// Build graph // Build graph
struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed); struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed);
// Constant vector for (1 - x) operation
struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed);
for (int i = 0; i < n_embed; i++) {
*((float *) ones->data + i) = 1.0F;
}
// x = self.w.emb.weight[token] // x = self.w.emb.weight[token]
struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
struct ggml_tensor * x = ggml_get_rows(ctx, model->emb, token_index); struct ggml_tensor * x = ggml_get_rows(ctx, model->emb, token_index);
@ -343,17 +336,17 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
struct ggml_tensor * xk = ggml_add( struct ggml_tensor * xk = ggml_add(
ctx, ctx,
ggml_mul(ctx, x0, layer.att_time_mix_k), ggml_mul(ctx, x0, layer.att_time_mix_k),
ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_k)) ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_k))
); );
struct ggml_tensor * xv = ggml_add( struct ggml_tensor * xv = ggml_add(
ctx, ctx,
ggml_mul(ctx, x0, layer.att_time_mix_v), ggml_mul(ctx, x0, layer.att_time_mix_v),
ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_v)) ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_v))
); );
struct ggml_tensor * xr = ggml_add( struct ggml_tensor * xr = ggml_add(
ctx, ctx,
ggml_mul(ctx, x0, layer.att_time_mix_r), ggml_mul(ctx, x0, layer.att_time_mix_r),
ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_r)) ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_r))
); );
// state[5 * i + 1] = x // state[5 * i + 1] = x
state_parts[5 * i + 1] = x0; state_parts[5 * i + 1] = x0;
@ -442,12 +435,12 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
struct ggml_tensor * xk = ggml_add( struct ggml_tensor * xk = ggml_add(
ctx, ctx,
ggml_mul(ctx, x0, layer.ffn_time_mix_k), ggml_mul(ctx, x0, layer.ffn_time_mix_k),
ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.ffn_time_mix_k)) ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_k))
); );
struct ggml_tensor * xr = ggml_add( struct ggml_tensor * xr = ggml_add(
ctx, ctx,
ggml_mul(ctx, x0, layer.ffn_time_mix_r), ggml_mul(ctx, x0, layer.ffn_time_mix_r),
ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.ffn_time_mix_r)) ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_r))
); );
// state[5 * i + 0] = x // state[5 * i + 0] = x
state_parts[5 * i + 0] = x0; state_parts[5 * i + 0] = x0;