Implement exp, max, 1_minus_x, sigmoid operators in ggml
parent fe272dc3d3
commit 01d667f066
@@ -323,31 +323,23 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
 // --- Operators ---
 
 // TODO Fuse and benchmark
 struct ggml_tensor * ggml_layer_norm(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) {
-    // LayerNorm in RWKV is:
-    // x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias
-    // Looks like ggml_norm does the first part, we only need to apply weight & bias
+    // LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias`
+    // Looks like ggml_norm does the first part, we only need to apply weight & bias.
     x = ggml_norm(ctx, x);
     x = ggml_mul(ctx, x, weight);
     x = ggml_add(ctx, x, bias);
     return x;
 }
 
-// TODO Fuse and benchmark
-struct ggml_tensor * ggml_sigmoid(ggml_context * ctx, struct ggml_tensor * x) {
-    // ggml has no native sigmoid, but silu(x) / x can be an approximation
-    x = ggml_silu(ctx, x);
-    x = ggml_div(ctx, x, x);
-    return x;
-}
-
 // --- Script ---
 
 // Usage: main_rwkv.exe "C:\model.bin" <token index> "C:\state_in.bin" "C:\state_out.bin" "C:\logits_out.bin"
 // Token index is 0-based.
 // To start from new state, pass empty string instead of input state file path.
 int main(int argc, char ** argv) {
+    ggml_run_test_suite();
+
     RWKV_ASSERT(argc - 1 == 5, "Expected 5 arguments, got %d", argc - 1);
     char * model_path = argv[1];
     char * token_s = argv[2];
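A side note on the removed sigmoid wrapper: it leaned on the identity sigmoid(x) = silu(x) / x, but because it overwrote `x` with silu(x) first, the subsequent `ggml_div(ctx, x, x)` divided silu(x) by itself and always produced 1, which is one reason a native sigmoid operator was needed. A minimal standalone check of the underlying identity (an editorial sketch in plain C, not part of the commit):

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        float x = 0.7F;
        float silu    = x / (1.0F + expf(-x));    // silu(x) = x * sigmoid(x)
        float sigmoid = 1.0F / (1.0F + expf(-x));
        // Both print ~0.668: silu(x) / x really is sigmoid(x), provided the
        // original x is still available for the division.
        printf("silu(x)/x = %f, sigmoid(x) = %f\n", silu / x, sigmoid);
        return 0;
    }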
@@ -408,9 +400,6 @@ int main(int argc, char ** argv) {
     // --- Evaluate model ---
 
-    struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed);
-    ggml_set_f32(ones, 1.0F);
-
     // x = self.w.emb.weight[token]
     // TODO Replace with ggml_get_rows or similar
     struct ggml_tensor * one_hot = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_vocab, 1);
@@ -433,27 +422,33 @@ int main(int argc, char ** argv) {
         struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
         // state[5 * i + 1]
         struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * 4);
+        COMPUTE_AND_PRINT_TENSOR(ctx, x_prev);
         // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
         // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
         // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
         struct ggml_tensor * xk = ggml_add(
             ctx,
             ggml_mul(ctx, x0, layer.att_time_mix_k),
-            ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_k))
+            ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_k))
         );
         struct ggml_tensor * xv = ggml_add(
             ctx,
             ggml_mul(ctx, x0, layer.att_time_mix_v),
-            ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_v))
+            ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_v))
         );
         struct ggml_tensor * xr = ggml_add(
             ctx,
             ggml_mul(ctx, x0, layer.att_time_mix_r),
-            ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_r))
+            ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_r))
         );
         // state[5 * i + 1] = x
         ggml_cpy(ctx, x0, x_prev);
 
+        COMPUTE_AND_PRINT_TENSOR(ctx, xk);
+        COMPUTE_AND_PRINT_TENSOR(ctx, xv);
+        COMPUTE_AND_PRINT_TENSOR(ctx, xr);
+        COMPUTE_AND_PRINT_TENSOR(ctx, x_prev);
+
         // r = torch.sigmoid(rw @ xr)
         struct ggml_tensor * r = ggml_sigmoid(
             ctx,
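The three mixes above are plain element-wise linear interpolation between the current and previous token activations; `ggml_1_minus_x` just removes the need for the `ones` helper tensor deleted earlier. As a scalar sketch (editorial, not commit code):

    // lerp between previous and current activation, weighted by the learned mix
    static float time_mix(float x0, float x_prev, float mix) {
        return x0 * mix + x_prev * (1.0F - mix);   // == x_prev + mix * (x0 - x_prev)
    }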
@@ -474,14 +469,11 @@ int main(int argc, char ** argv) {
         // ww = time_first + k
         struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k);
         // qq = torch.maximum(pp, ww)
-        // TODO Implement element-wise max in ggml
-        struct ggml_tensor * qq = pp;
+        struct ggml_tensor * qq = ggml_max(ctx, pp, ww);
         // e1 = torch.exp(pp - qq)
-        // TODO Implement element-wise exp in ggml
-        struct ggml_tensor * e1 = ggml_sub(ctx, pp, qq);
+        struct ggml_tensor * e1 = ggml_exp(ctx, ggml_sub(ctx, pp, qq));
         // e2 = torch.exp(ww - qq)
-        // TODO Use exp
-        struct ggml_tensor * e2 = ggml_sub(ctx, ww, qq);
+        struct ggml_tensor * e2 = ggml_exp(ctx, ggml_sub(ctx, ww, qq));
         // a = e1 * aa + e2 * v
         struct ggml_tensor * a = ggml_add(
             ctx,
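Subtracting the running maximum `qq` before exponentiating is the standard trick for keeping the WKV mixture numerically stable: both exponents become non-positive, so `expf` cannot overflow, and since numerator and denominator are rescaled by the same `expf(-qq)` factor, the final ratio is unchanged. A scalar sketch of the pattern (editorial, assuming single float values rather than tensors):

    #include <math.h>

    // Numerator of the WKV mixture, rescaled by expf(-qq); the denominator
    // (e1 * bb + e2) is rescaled identically, so the quotient stays the same.
    static float wkv_numerator(float pp, float ww, float aa, float v) {
        float qq = fmaxf(pp, ww);
        float e1 = expf(pp - qq);   // in (0, 1]
        float e2 = expf(ww - qq);   // in (0, 1]
        return e1 * aa + e2 * v;
    }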
@@ -499,27 +491,27 @@ int main(int argc, char ** argv) {
         // ww = pp + time_decay
         ww = ggml_add(ctx, pp, layer.att_time_decay);
         // qq = torch.maximum(ww, k)
-        // TODO Use max
-        qq = ww;
+        qq = ggml_max(ctx, ww, k);
         // e1 = torch.exp(ww - qq)
-        // TODO Use exp
-        e1 = ggml_sub(ctx, ww, qq);
+        e1 = ggml_exp(ctx, ggml_sub(ctx, ww, qq));
         // e2 = torch.exp(k - qq)
-        // TODO Use exp
-        e2 = ggml_sub(ctx, k, qq);
+        e2 = ggml_exp(ctx, ggml_sub(ctx, k, qq));
         // state[5 * i + 2] = e1 * aa + e2 * v
+        // todo must save result
         ggml_cpy(ctx, ggml_add(
             ctx,
             ggml_mul(ctx, e1, aa),
             ggml_mul(ctx, e2, v)
         ), aa);
         // state[5 * i + 3] = e1 * bb + e2
+        // todo must save result
         ggml_cpy(ctx, ggml_add(
             ctx,
             ggml_mul(ctx, e1, bb),
             e2
         ), bb);
         // state[5 * i + 4] = qq
+        // todo must save result
         ggml_cpy(ctx, qq, pp);
         // ow @ (r * wkv)
         x = ggml_add(
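A note on the `todo must save result` markers: `ggml_cpy` does not mutate anything at call time; it returns a new graph node, and that node is only executed if it is reachable from the tensor handed to `ggml_build_forward`. A hypothetical shape of the fix (editorial sketch; the variable name and the eventual graph wiring are assumptions, not from the commit):

    // Keep the node ggml_cpy returns and make it reachable from the graph root,
    // otherwise the state write-back never runs during ggml_graph_compute.
    struct ggml_tensor * new_aa = ggml_cpy(ctx, ggml_add(
        ctx,
        ggml_mul(ctx, e1, aa),
        ggml_mul(ctx, e2, v)
    ), aa);
    // ... later: build the forward graph from a tensor that depends on new_aa.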
@@ -531,6 +523,8 @@ int main(int argc, char ** argv) {
                 ggml_mul(ctx, r, wkv)
             )
         );
+        RWKV_LOG("RWKV %d completed", i);
+        COMPUTE_AND_PRINT_TENSOR(ctx, x);
         }
 
         // FFN/channel mixing
@@ -544,14 +538,15 @@ int main(int argc, char ** argv) {
         struct ggml_tensor * xk = ggml_add(
             ctx,
             ggml_mul(ctx, x0, layer.ffn_time_mix_k),
-            ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.ffn_time_mix_k))
+            ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_k))
         );
         struct ggml_tensor * xr = ggml_add(
             ctx,
             ggml_mul(ctx, x0, layer.ffn_time_mix_r),
-            ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.ffn_time_mix_r))
+            ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_r))
         );
         // state[5 * i + 0] = x
+        // todo must save result
         ggml_cpy(ctx, x0, x_prev);
 
         // r = torch.sigmoid(rw @ xr)
@@ -574,6 +569,8 @@ int main(int argc, char ** argv) {
                 ggml_mul_mat(ctx, layer.ffn_value, k)
             )
         );
+        RWKV_LOG("FFN %d completed", i);
+        COMPUTE_AND_PRINT_TENSOR(ctx, x);
         }
     }
 
@@ -588,6 +585,8 @@ int main(int argc, char ** argv) {
+    // TODO -nan(ind) -nan(ind) ... (maybe implement exp/max first?)
+    PRINT_TENSOR(logits);
 
     // TODO Save new state and logits
 
     ggml_free(ctx);
 
     RWKV_LOG("OK");
ggml.c (428 changed lines)
@@ -1603,15 +1603,18 @@ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) {
 inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 
 inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
 inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
 inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
 inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
 inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
 inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
+inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
+inline static void ggml_vec_1_minus_x_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1 - x[i]; }
 inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
 inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
+inline static void ggml_vec_element_wise_max_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = max(x[i], y[i]); }
 
 inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
 #ifdef GGML_SIMD
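One portability caveat in the new vector primitives: `max` is not a standard C function, so `ggml_vec_element_wise_max_f32` compiles only where a `max` macro happens to be defined (for example, some MSVC headers provide one). A portable variant would use `fmaxf` from math.h (editorial sketch; the `_portable` name is hypothetical, not part of the commit):

    #include <math.h>

    inline static void ggml_vec_element_wise_max_f32_portable(const int n, float * z, const float * x, const float * y) {
        for (int i = 0; i < n; ++i) z[i] = fmaxf(x[i], y[i]);   // fmaxf is standard C99
    }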
@@ -2304,6 +2307,17 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 }
 #endif
 
+// Sigmoid function
+inline static float ggml_sigmoid_f32(float x) {
+    return 1.0F / (1.0F + expf(-x));
+}
+
+inline static void ggml_vec_sigmoid_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_sigmoid_f32(x[i]);
+    }
+}
+
 // Sigmoid Linear Unit (SiLU) function
 inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
@@ -2457,7 +2471,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "FLASH_FF",
 };
 
-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2501,7 +2515,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_ff(x)",
 };
 
-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
 
 //
 // ggml object
@@ -3866,6 +3880,54 @@ struct ggml_tensor * ggml_neg_inplace(
     return ggml_neg_impl(ctx, a, true);
 }
 
+// ggml_exp
+
+struct ggml_tensor * ggml_exp(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_EXP;
+    result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_1_minus_x
+
+struct ggml_tensor * ggml_1_minus_x(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_1_MINUS_X;
+    result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_max
+
+struct ggml_tensor * ggml_max(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_MAX;
+    result->grad = (a->grad || b->grad) ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_step
 
 struct ggml_tensor * ggml_step_impl(
@@ -3968,6 +4030,21 @@ struct ggml_tensor * ggml_gelu_inplace(
     return ggml_gelu_impl(ctx, a, true);
 }
 
+// ggml_sigmoid
+
+struct ggml_tensor * ggml_sigmoid(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_SIGMOID;
+    result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
 // ggml_silu
 
 struct ggml_tensor * ggml_silu_impl(
@@ -5546,6 +5623,154 @@ static void ggml_compute_forward_neg(
     }
 }
 
+// ggml_compute_forward_exp
+
+static void ggml_compute_forward_exp_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_exp_f32(nc,
+                (float *) ((char *) dst->data + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_exp(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_exp_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_1_minus_x
+
+static void ggml_compute_forward_1_minus_x_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_1_minus_x_f32(nc,
+                (float *) ((char *) dst->data + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_1_minus_x(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_1_minus_x_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_max
+
+static void ggml_compute_forward_max_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_element_wise_max_f32(nc,
+                (float *) ((char *) dst->data + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_max(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_max_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_step
 
 static void ggml_compute_forward_step_f32(
@@ -5709,6 +5934,54 @@ static void ggml_compute_forward_gelu(
     //printf("XXXXXXXX gelu\n");
 }
 
+// ggml_compute_forward_sigmoid
+
+static void ggml_compute_forward_sigmoid_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sigmoid_f32(nc,
+                (float *) ((char *) dst->data + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sigmoid(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sigmoid_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_silu
 
 static void ggml_compute_forward_silu_f32(
@@ -8423,6 +8696,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_neg(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_EXP:
+            {
+                ggml_compute_forward_exp(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_1_MINUS_X:
+            {
+                ggml_compute_forward_1_minus_x(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_MAX:
+            {
+                ggml_compute_forward_max(params, tensor->src0, tensor->src1, tensor);
+            } break;
        case GGML_OP_STEP:
            {
                ggml_compute_forward_step(params, tensor->src0, tensor);
@@ -8435,6 +8720,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gelu(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SIGMOID:
+            {
+                ggml_compute_forward_sigmoid(params, tensor->src0, tensor);
+            } break;
        case GGML_OP_SILU:
            {
                ggml_compute_forward_silu(params, tensor->src0, tensor);
@@ -8660,6 +8949,18 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
             }
         } break;
+        case GGML_OP_EXP:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_1_MINUS_X:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_MAX:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
        case GGML_OP_STEP:
            {
                if (src0->grad) {
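The backward cases are stubs for now. For GGML_OP_EXP the gradient is simple, since d/dx exp(x) = exp(x) and the forward result already holds that value. A hypothetical implementation (editorial sketch modeled on the neighboring cases, assuming ggml_add_impl mirrors the ggml_sub_impl call above; not part of the commit):

    case GGML_OP_EXP:
        {
            if (src0->grad) {
                // dL/dx = exp(x) * dL/dy, and `tensor` is exp(src0) from the forward pass
                src0->grad = ggml_add_impl(ctx,
                        src0->grad,
                        ggml_mul(ctx, tensor, tensor->grad),
                        inplace);
            }
        } break;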
@@ -8681,6 +8982,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_SIGMOID:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
        case GGML_OP_SILU:
            {
                GGML_ASSERT(false); // TODO: not implemented
@@ -9101,8 +9406,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 case GGML_OP_ABS:
                 case GGML_OP_SGN:
                 case GGML_OP_NEG:
+                case GGML_OP_EXP:
+                case GGML_OP_1_MINUS_X:
+                case GGML_OP_MAX:
                 case GGML_OP_STEP:
                 case GGML_OP_RELU:
+                case GGML_OP_SIGMOID:
                     {
                         node->n_tasks = 1;
                     } break;
@@ -10469,3 +10778,100 @@ int ggml_cpu_has_vsx(void) {
 }
 
+////////////////////////////////////////////////////////////////////////////////
+
+#define GGML_TEST_SET_ELEMENT_F32(tensor, i, value) *(float *) ((char *) tensor->data + 4 * i) = value
+
+#define GGML_TEST_ASSERT_ELEMENT_F32(tensor, i, expected_value) do {\
+    float actual = *(float *) ((char *) tensor->data + 4 * i);\
+    if (fabs(actual - expected_value) >= 0.0001F) {\
+        fprintf(stderr, "*** Assertion failed ***\n");\
+        fprintf(stderr, "At %s[%d]: expected %f, actual %f\n", #tensor, i, expected_value, actual);\
+        fprintf(stderr, "%s:%d\n", __FILE__, __LINE__);\
+        abort();\
+    }\
+} while (0)
+
+void ggml_run_test_suite() {
+    fprintf(stderr, "Running ggml test suite...\n");
+
+    struct ggml_init_params params;
+    params.mem_size = 16 * 1024;
+    params.mem_buffer = NULL;
+    struct ggml_context * ctx = ggml_init(params);
+
+    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 2);
+    GGML_TEST_SET_ELEMENT_F32(a, 0, 1.0051F);
+    GGML_TEST_SET_ELEMENT_F32(a, 1, 1.0484F);
+    GGML_TEST_SET_ELEMENT_F32(a, 2, -0.4361F);
+    GGML_TEST_SET_ELEMENT_F32(a, 3, -0.6984F);
+    GGML_TEST_SET_ELEMENT_F32(a, 4, 1.7310F);
+    GGML_TEST_SET_ELEMENT_F32(a, 5, -0.0446F);
+
+    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 2);
+    GGML_TEST_SET_ELEMENT_F32(b, 0, -0.2566F);
+    GGML_TEST_SET_ELEMENT_F32(b, 1, -0.1412F);
+    GGML_TEST_SET_ELEMENT_F32(b, 2, 1.6200F);
+    GGML_TEST_SET_ELEMENT_F32(b, 3, 0.5156F);
+    GGML_TEST_SET_ELEMENT_F32(b, 4, -0.3934F);
+    GGML_TEST_SET_ELEMENT_F32(b, 5, -0.0694F);
+
+    // Test against torch.exp(a)
+    struct ggml_tensor * exp_a = ggml_exp(ctx, a);
+
+    struct ggml_cgraph graph = ggml_build_forward(exp_a);
+    graph.n_threads = 2;
+    ggml_graph_compute(ctx, &graph);
+
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 0, 2.7322F);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 1, 2.8531F);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 2, 0.6466F);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 3, 0.4974F);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 4, 5.6463F);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 5, 0.9564F);
+
+    // Test against (1 - a) in PyTorch
+    struct ggml_tensor * one_minus_a = ggml_1_minus_x(ctx, a);
+
+    graph = ggml_build_forward(one_minus_a);
+    graph.n_threads = 2;
+    ggml_graph_compute(ctx, &graph);
+
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 0, -0.0051F);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 1, -0.0484F);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 2, 1.4361F);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 3, 1.6984F);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 4, -0.7310F);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 5, 1.0446F);
+
+    // Test against torch.sigmoid(a)
+    struct ggml_tensor * sigmoid_a = ggml_sigmoid(ctx, a);
+
+    graph = ggml_build_forward(sigmoid_a);
+    graph.n_threads = 2;
+    ggml_graph_compute(ctx, &graph);
+
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 0, 0.7321F);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 1, 0.7405F);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 2, 0.3927F);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 3, 0.3322F);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 4, 0.8495F);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 5, 0.4889F);
+
+    // Test against torch.maximum(a, b)
+    struct ggml_tensor * max_a_b = ggml_max(ctx, a, b);
+
+    graph = ggml_build_forward(max_a_b);
+    graph.n_threads = 2;
+    ggml_graph_compute(ctx, &graph);
+
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 0, 1.0051F);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 1, 1.0484F);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 2, 1.6200F);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 3, 0.5156F);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 4, 1.7310F);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 5, -0.0446F);
+
+    ggml_free(ctx);
+
+    fprintf(stderr, "All ggml tests pass\n");
+}
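One detail worth noting in the test macros: indexing with `4 * i` hard-codes sizeof(float) and assumes the tensor is contiguous, which holds for the freshly created 3x2 F32 tensors used here. A stride-aware equivalent (editorial sketch; `get_element_f32` is hypothetical, not part of the commit):

    // Reads element (row, col) honoring the tensor's byte strides nb[0] and nb[1];
    // for a contiguous F32 tensor this matches the macro's `4 * i` arithmetic.
    static float get_element_f32(const struct ggml_tensor * t, int row, int col) {
        return *(float *) ((char *) t->data + row * t->nb[1] + col * t->nb[0]);
    }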
ggml.h (61 changed lines)
@@ -167,6 +167,32 @@
 //
 // TODO
 //
+// ## Adding new operators
+//
+// Suppose you want to add an e^x unary operator. The following steps need to be done:
+//
+// In `ggml.h`:
+//
+// 1. Add member `GGML_OP_EXP` to the `ggml_op` enum.
+// 2. Declare the operator function: `struct ggml_tensor * ggml_exp(struct ggml_context * ctx, struct ggml_tensor * x);`.
+//
+// In `ggml.c`:
+//
+// 1. Implement the `ggml_exp` function: it will create the result tensor and set its operator and arguments.
+// 2. Create the forward computation function for FP32, `ggml_compute_forward_exp_f32`: it will do the actual computation.
+// 3. If needed, create forward computation functions for other types: FP16, INT32, etc.
+// 4. Create the forward dispatch function `ggml_compute_forward_exp`: it will dispatch the call based on the tensor data type.
+// 5. Add `case GGML_OP_EXP`:
+//    - to `ggml_compute_forward` and call the forward dispatch function there.
+//    - to `ggml_compute_backward` and add `GGML_ASSERT(false)` there.
+//    - to `ggml_graph_compute` and add `node->n_tasks = 1` there.
+// 6. Fix all assertions that check the value of `GGML_OP_COUNT`: you've added 1 operator, so increment the asserted value by one.
+//
+// When in doubt, consult the code of existing operators similar to the one you're implementing.
+// The resulting operator will work for the forward pass, but will lack a backward implementation and multi-threading support.
+//
+// TODO Implementing backward pass
+// TODO Implementing multi-threading
+//
 
 #ifdef __cplusplus
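To make the recipe concrete, here is a minimal end-to-end use of one of the operators this commit adds, using only calls that appear elsewhere in the diff (an editorial sketch, not part of the commit):

    struct ggml_init_params params;
    params.mem_size = 16 * 1024;
    params.mem_buffer = NULL;
    struct ggml_context * ctx = ggml_init(params);

    // Build a graph that applies the new ggml_exp operator to a small tensor.
    struct ggml_tensor * t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_set_f32(t, 0.5F);
    struct ggml_tensor * e = ggml_exp(ctx, t);

    struct ggml_cgraph graph = ggml_build_forward(e);
    graph.n_threads = 1;
    ggml_graph_compute(ctx, &graph);
    // Each of the 4 elements of e now holds expf(0.5F), about 1.6487F.

    ggml_free(ctx);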
@@ -225,9 +251,22 @@ enum ggml_op {
     GGML_OP_ABS,
     GGML_OP_SGN,
     GGML_OP_NEG,
+    // Element-wise exponential function `e^x`.
+    // Same as `torch.exp(x)` from PyTorch.
+    GGML_OP_EXP,
+    // Element-wise `1 - x`.
+    GGML_OP_1_MINUS_X,
+
+    // Element-wise maximum of 2 values. Argument shapes must match.
+    // Same as `torch.maximum(x, y)` from PyTorch.
+    GGML_OP_MAX,
+
     GGML_OP_STEP,
     GGML_OP_RELU,
     GGML_OP_GELU,
+    // Element-wise sigmoid activation `1 / (1 + e^-x)`, also called logistic function.
+    // Same as `torch.sigmoid(x)` from PyTorch.
+    GGML_OP_SIGMOID,
     GGML_OP_SILU,
     GGML_OP_NORM, // normalize
     GGML_OP_RMS_NORM,
@@ -463,6 +502,19 @@ struct ggml_tensor * ggml_neg(
         struct ggml_context * ctx,
         struct ggml_tensor * a);
 
+struct ggml_tensor * ggml_exp(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_1_minus_x(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_max(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b);
+
 struct ggml_tensor * ggml_step(
         struct ggml_context * ctx,
         struct ggml_tensor * a);
@@ -476,6 +528,10 @@ struct ggml_tensor * ggml_gelu(
         struct ggml_context * ctx,
         struct ggml_tensor * a);
 
+struct ggml_tensor * ggml_sigmoid(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a);
+
 struct ggml_tensor * ggml_silu(
         struct ggml_context * ctx,
         struct ggml_tensor * a);
@@ -768,6 +824,11 @@ int ggml_cpu_has_blas(void);
 int ggml_cpu_has_sse3(void);
 int ggml_cpu_has_vsx(void);
 
+// Run test suite for ggml.
+// Exits normally if all tests pass.
+// Aborts execution if any test fails.
+void ggml_run_test_suite();
+
 #ifdef __cplusplus
 }
 #endif