diff --git a/rwkv.cpp b/rwkv.cpp index 44f3952..3e004b6 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -311,13 +311,6 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr // Build graph struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed); - // Constant vector for (1 - x) operation - struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed); - - for (int i = 0; i < n_embed; i++) { - *((float *) ones->data + i) = 1.0F; - } - // x = self.w.emb.weight[token] struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); struct ggml_tensor * x = ggml_get_rows(ctx, model->emb, token_index); @@ -343,17 +336,17 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr struct ggml_tensor * xk = ggml_add( ctx, ggml_mul(ctx, x0, layer.att_time_mix_k), - ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_k)) + ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_k)) ); struct ggml_tensor * xv = ggml_add( ctx, ggml_mul(ctx, x0, layer.att_time_mix_v), - ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_v)) + ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_v)) ); struct ggml_tensor * xr = ggml_add( ctx, ggml_mul(ctx, x0, layer.att_time_mix_r), - ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_r)) + ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_r)) ); // state[5 * i + 1] = x state_parts[5 * i + 1] = x0; @@ -442,12 +435,12 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr struct ggml_tensor * xk = ggml_add( ctx, ggml_mul(ctx, x0, layer.ffn_time_mix_k), - ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.ffn_time_mix_k)) + ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_k)) ); struct ggml_tensor * xr = ggml_add( ctx, ggml_mul(ctx, x0, layer.ffn_time_mix_r), - ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.ffn_time_mix_r)) + ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_r)) ); // state[5 * i + 0] = x state_parts[5 * i + 0] = x0;