Implement exp, max, 1_minus_x, sigmoid operators in ggml

saharNooby 2023-03-31 19:04:35 +04:00
parent fe272dc3d3
commit 01d667f066
3 changed files with 509 additions and 43 deletions

main_rwkv.cpp

@@ -323,31 +323,23 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * model) {
 // --- Operators ---

+// TODO Fuse and benchmark
 struct ggml_tensor * ggml_layer_norm(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) {
-    // LayerNorm in RWKV is:
-    // x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias
-    // Looks like ggml_norm does the first part, we only need to apply weight & bias
+    // LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias`
+    // Looks like ggml_norm does the first part, we only need to apply weight & bias.
     x = ggml_norm(ctx, x);
     x = ggml_mul(ctx, x, weight);
     x = ggml_add(ctx, x, bias);
     return x;
 }

-// TODO Fuse and benchmark
-struct ggml_tensor * ggml_sigmoid(ggml_context * ctx, struct ggml_tensor * x) {
-    // ggml has no native sigmoid, but silu(x) / x can be an approximation
-    x = ggml_silu(ctx, x);
-    x = ggml_div(ctx, x, x);
-    return x;
-}

 // --- Script ---

 // Usage: main_rwkv.exe "C:\model.bin" <token index> "C:\state_in.bin" "C:\state_out.bin" "C:\logits_out.bin"
 // Token index is 0-based.
 // To start from new state, pass empty string instead of input state file path.
 int main(int argc, char ** argv) {
+    ggml_run_test_suite();
+
     RWKV_ASSERT(argc - 1 == 5, "Expected 5 arguments, got %d", argc - 1);

     char * model_path = argv[1];
     char * token_s = argv[2];
@@ -408,9 +400,6 @@ int main(int argc, char ** argv) {
     // --- Evaluate model ---

-    struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed);
-    ggml_set_f32(ones, 1.0F);

     // x = self.w.emb.weight[token]
     // TODO Replace with ggml_get_rows or similar
     struct ggml_tensor * one_hot = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_vocab, 1);
@@ -433,27 +422,33 @@ int main(int argc, char ** argv) {
         struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
         // state[5 * i + 1]
         struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * 4);
+        COMPUTE_AND_PRINT_TENSOR(ctx, x_prev);

         // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
         // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
         // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
         struct ggml_tensor * xk = ggml_add(
             ctx,
             ggml_mul(ctx, x0, layer.att_time_mix_k),
-            ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_k))
+            ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_k))
         );
         struct ggml_tensor * xv = ggml_add(
             ctx,
             ggml_mul(ctx, x0, layer.att_time_mix_v),
-            ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_v))
+            ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_v))
         );
         struct ggml_tensor * xr = ggml_add(
             ctx,
             ggml_mul(ctx, x0, layer.att_time_mix_r),
-            ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_r))
+            ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_r))
         );
         // state[5 * i + 1] = x
         ggml_cpy(ctx, x0, x_prev);
+        COMPUTE_AND_PRINT_TENSOR(ctx, xk);
+        COMPUTE_AND_PRINT_TENSOR(ctx, xv);
+        COMPUTE_AND_PRINT_TENSOR(ctx, xr);
+        COMPUTE_AND_PRINT_TENSOR(ctx, x_prev);

         // r = torch.sigmoid(rw @ xr)
         struct ggml_tensor * r = ggml_sigmoid(
             ctx,
@@ -474,14 +469,11 @@ int main(int argc, char ** argv) {
         // ww = time_first + k
         struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k);
         // qq = torch.maximum(pp, ww)
-        // TODO Implement element-wise max in ggml
-        struct ggml_tensor * qq = pp;
+        struct ggml_tensor * qq = ggml_max(ctx, pp, ww);
         // e1 = torch.exp(pp - qq)
-        // TODO Implement element-wise exp in ggml
-        struct ggml_tensor * e1 = ggml_sub(ctx, pp, qq);
+        struct ggml_tensor * e1 = ggml_exp(ctx, ggml_sub(ctx, pp, qq));
         // e2 = torch.exp(ww - qq)
-        // TODO Use exp
-        struct ggml_tensor * e2 = ggml_sub(ctx, ww, qq);
+        struct ggml_tensor * e2 = ggml_exp(ctx, ggml_sub(ctx, ww, qq));
         // a = e1 * aa + e2 * v
         struct ggml_tensor * a = ggml_add(
             ctx,
@@ -499,27 +491,27 @@ int main(int argc, char ** argv) {
         // ww = pp + time_decay
         ww = ggml_add(ctx, pp, layer.att_time_decay);
         // qq = torch.maximum(ww, k)
-        // TODO Use max
-        qq = ww;
+        qq = ggml_max(ctx, ww, k);
         // e1 = torch.exp(ww - qq)
-        // TODO Use exp
-        e1 = ggml_sub(ctx, ww, qq);
+        e1 = ggml_exp(ctx, ggml_sub(ctx, ww, qq));
         // e2 = torch.exp(k - qq)
-        // TODO Use exp
-        e2 = ggml_sub(ctx, k, qq);
+        e2 = ggml_exp(ctx, ggml_sub(ctx, k, qq));
         // state[5 * i + 2] = e1 * aa + e2 * v
-        // todo must save result
         ggml_cpy(ctx, ggml_add(
             ctx,
             ggml_mul(ctx, e1, aa),
             ggml_mul(ctx, e2, v)
         ), aa);
         // state[5 * i + 3] = e1 * bb + e2
-        // todo must save result
         ggml_cpy(ctx, ggml_add(
             ctx,
             ggml_mul(ctx, e1, bb),
             e2
         ), bb);
         // state[5 * i + 4] = qq
-        // todo must save result
         ggml_cpy(ctx, qq, pp);
         // ow @ (r * wkv)
         x = ggml_add(
@@ -531,6 +523,8 @@ int main(int argc, char ** argv) {
                 ggml_mul(ctx, r, wkv)
             )
         );
+        RWKV_LOG("RWKV %d completed", i);
+        COMPUTE_AND_PRINT_TENSOR(ctx, x);
         }

         // FFN/channel mixing
@@ -544,14 +538,15 @@ int main(int argc, char ** argv) {
         struct ggml_tensor * xk = ggml_add(
             ctx,
             ggml_mul(ctx, x0, layer.ffn_time_mix_k),
-            ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.ffn_time_mix_k))
+            ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_k))
         );
         struct ggml_tensor * xr = ggml_add(
             ctx,
             ggml_mul(ctx, x0, layer.ffn_time_mix_r),
-            ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.ffn_time_mix_r))
+            ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_r))
         );
         // state[5 * i + 0] = x
-        // todo must save result
         ggml_cpy(ctx, x0, x_prev);

         // r = torch.sigmoid(rw @ xr)
@@ -574,6 +569,8 @@ int main(int argc, char ** argv) {
                 ggml_mul_mat(ctx, layer.ffn_value, k)
             )
         );
+        RWKV_LOG("FFN %d completed", i);
+        COMPUTE_AND_PRINT_TENSOR(ctx, x);
         }
     }
@@ -588,6 +585,8 @@ int main(int argc, char ** argv) {
     // TODO -nan(ind) -nan(ind) ... (maybe implement exp/max first?)
     PRINT_TENSOR(logits);

+    // TODO Save new state and logits
+
     ggml_free(ctx);

     RWKV_LOG("OK");
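Note on the math behind the new operators: the attention hunks above implement the numerically stable form of the WKV update from the RWKV reference code quoted in the comments. Subtracting the running maximum before exponentiating keeps every exponent non-positive, so `e1` and `e2` stay within (0, 1] and `exp` cannot overflow — which is what the `-nan(ind)` TODO above is about. As a sketch (LaTeX notation, with `p` standing for `pp` and `w` for `ww`):

    q = \max(p, w), \qquad e_1 = e^{p - q}, \qquad e_2 = e^{w - q}

    wkv = \frac{e_1 \cdot aa + e_2 \cdot v}{e_1 \cdot bb + e_2}

Both exponents are \le 0 by construction, and the same trick is applied a second time when `time_decay` is folded into the state.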

ggml.c

@@ -1603,15 +1603,18 @@ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) {
 inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
 inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
 inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
 inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
 inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
 inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
+inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
+inline static void ggml_vec_1_minus_x_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1 - x[i]; }
 inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
 inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
+inline static void ggml_vec_element_wise_max_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = max(x[i], y[i]); }
 inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
 #ifdef GGML_SIMD

@@ -2304,6 +2307,17 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 }
 #endif
+// Sigmoid function
+inline static float ggml_sigmoid_f32(float x) {
+    return 1.0F / (1.0F + expf(-x));
+}
+
+inline static void ggml_vec_sigmoid_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_sigmoid_f32(x[i]);
+    }
+}
 // Sigmoid Linear Unit (SiLU) function
 inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));

@@ -2457,7 +2471,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "FLASH_FF",
 };

-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",

@@ -2501,7 +2515,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_ff(x)",
 };

-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");

 //
 // ggml object

@@ -3866,6 +3880,54 @@ struct ggml_tensor * ggml_neg_inplace(
     return ggml_neg_impl(ctx, a, true);
 }
+// ggml_exp
+
+struct ggml_tensor * ggml_exp(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_EXP;
+    result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_1_minus_x
+
+struct ggml_tensor * ggml_1_minus_x(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_1_MINUS_X;
+    result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+// ggml_max
+
+struct ggml_tensor * ggml_max(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_MAX;
+    result->grad = (a->grad || b->grad) ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
 // ggml_step

 struct ggml_tensor * ggml_step_impl(

@@ -3968,6 +4030,21 @@ struct ggml_tensor * ggml_gelu_inplace(
     return ggml_gelu_impl(ctx, a, true);
 }
+// ggml_sigmoid
+
+struct ggml_tensor * ggml_sigmoid(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SIGMOID;
+    result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
 // ggml_silu

 struct ggml_tensor * ggml_silu_impl(

@@ -5546,6 +5623,154 @@ static void ggml_compute_forward_neg(
     }
 }
+// ggml_compute_forward_exp
+
+static void ggml_compute_forward_exp_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_exp_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_exp(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_exp_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_1_minus_x
+
+static void ggml_compute_forward_1_minus_x_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_1_minus_x_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_1_minus_x(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_1_minus_x_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_max
+
+static void ggml_compute_forward_max_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_element_wise_max_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_max(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_max_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
 // ggml_compute_forward_step

 static void ggml_compute_forward_step_f32(

@@ -5709,6 +5934,54 @@ static void ggml_compute_forward_gelu(
     //printf("XXXXXXXX gelu\n");
 }
+// ggml_compute_forward_sigmoid
+
+static void ggml_compute_forward_sigmoid_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sigmoid_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sigmoid(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sigmoid_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
 // ggml_compute_forward_silu

 static void ggml_compute_forward_silu_f32(

@@ -8423,6 +8696,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_neg(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_EXP:
+            {
+                ggml_compute_forward_exp(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_1_MINUS_X:
+            {
+                ggml_compute_forward_1_minus_x(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_MAX:
+            {
+                ggml_compute_forward_max(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_STEP:
             {
                 ggml_compute_forward_step(params, tensor->src0, tensor);
@@ -8435,6 +8720,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_gelu(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SIGMOID:
+            {
+                ggml_compute_forward_sigmoid(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_SILU:
             {
                 ggml_compute_forward_silu(params, tensor->src0, tensor);

@@ -8660,6 +8949,18 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
                 src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
             }
         } break;
+        case GGML_OP_EXP:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_1_MINUS_X:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_MAX:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_STEP:
             {
                 if (src0->grad) {

@@ -8681,6 +8982,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_SIGMOID:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SILU:
             {
                 GGML_ASSERT(false); // TODO: not implemented

@@ -9101,8 +9406,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
             case GGML_OP_ABS:
             case GGML_OP_SGN:
             case GGML_OP_NEG:
+            case GGML_OP_EXP:
+            case GGML_OP_1_MINUS_X:
+            case GGML_OP_MAX:
             case GGML_OP_STEP:
             case GGML_OP_RELU:
+            case GGML_OP_SIGMOID:
                 {
                     node->n_tasks = 1;
                 } break;
@@ -10469,3 +10778,100 @@ int ggml_cpu_has_vsx(void) {
 }

 ////////////////////////////////////////////////////////////////////////////////
+#define GGML_TEST_SET_ELEMENT_F32(tensor, i, value) *(float *) ((char *) tensor->data + 4 * i) = value
+
+#define GGML_TEST_ASSERT_ELEMENT_F32(tensor, i, expected_value) do {\
+        float actual = *(float *) ((char *) tensor->data + 4 * i);\
+        if (fabs(actual - expected_value) >= 0.0001F) {\
+            fprintf(stderr, "*** Assertion failed ***\n");\
+            fprintf(stderr, "At %s[%d]: expected %f, actual %f\n", #tensor, i, expected_value, actual);\
+            fprintf(stderr, "%s:%d\n", __FILE__, __LINE__);\
+            abort();\
+        }\
+    } while (0)
+
+void ggml_run_test_suite() {
+    fprintf(stderr, "Running ggml test suite...\n");
+
+    struct ggml_init_params params;
+    params.mem_size = 16 * 1024;
+    params.mem_buffer = NULL;
+
+    struct ggml_context * ctx = ggml_init(params);
+
+    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 2);
+    GGML_TEST_SET_ELEMENT_F32(a, 0, 1.0051F);
+    GGML_TEST_SET_ELEMENT_F32(a, 1, 1.0484F);
+    GGML_TEST_SET_ELEMENT_F32(a, 2, -0.4361F);
+    GGML_TEST_SET_ELEMENT_F32(a, 3, -0.6984F);
+    GGML_TEST_SET_ELEMENT_F32(a, 4, 1.7310F);
+    GGML_TEST_SET_ELEMENT_F32(a, 5, -0.0446F);
+
+    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 2);
+    GGML_TEST_SET_ELEMENT_F32(b, 0, -0.2566F);
+    GGML_TEST_SET_ELEMENT_F32(b, 1, -0.1412F);
+    GGML_TEST_SET_ELEMENT_F32(b, 2, 1.6200F);
+    GGML_TEST_SET_ELEMENT_F32(b, 3, 0.5156F);
+    GGML_TEST_SET_ELEMENT_F32(b, 4, -0.3934F);
+    GGML_TEST_SET_ELEMENT_F32(b, 5, -0.0694F);
+
+    // Test against torch.exp(a)
+    struct ggml_tensor * exp_a = ggml_exp(ctx, a);
+    struct ggml_cgraph graph = ggml_build_forward(exp_a);
+    graph.n_threads = 2;
+    ggml_graph_compute(ctx, &graph);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 0, 2.7322F);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 1, 2.8531F);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 2, 0.6466F);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 3, 0.4974F);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 4, 5.6463F);
+    GGML_TEST_ASSERT_ELEMENT_F32(exp_a, 5, 0.9564F);
+
+    // Test against (1 - a) in PyTorch
+    struct ggml_tensor * one_minus_a = ggml_1_minus_x(ctx, a);
+    graph = ggml_build_forward(one_minus_a);
+    graph.n_threads = 2;
+    ggml_graph_compute(ctx, &graph);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 0, -0.0051F);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 1, -0.0484F);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 2, 1.4361F);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 3, 1.6984F);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 4, -0.7310F);
+    GGML_TEST_ASSERT_ELEMENT_F32(one_minus_a, 5, 1.0446F);
+
+    // Test against torch.sigmoid(a)
+    struct ggml_tensor * sigmoid_a = ggml_sigmoid(ctx, a);
+    graph = ggml_build_forward(sigmoid_a);
+    graph.n_threads = 2;
+    ggml_graph_compute(ctx, &graph);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 0, 0.7321F);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 1, 0.7405F);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 2, 0.3927F);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 3, 0.3322F);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 4, 0.8495F);
+    GGML_TEST_ASSERT_ELEMENT_F32(sigmoid_a, 5, 0.4889F);
+
+    // Test against torch.maximum(a, b)
+    struct ggml_tensor * max_a_b = ggml_max(ctx, a, b);
+    graph = ggml_build_forward(max_a_b);
+    graph.n_threads = 2;
+    ggml_graph_compute(ctx, &graph);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 0, 1.0051F);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 1, 1.0484F);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 2, 1.6200F);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 3, 0.5156F);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 4, 1.7310F);
+    GGML_TEST_ASSERT_ELEMENT_F32(max_a_b, 5, -0.0446F);
+
+    ggml_free(ctx);
+
+    fprintf(stderr, "All ggml tests pass\n");
+}
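For completeness, a minimal standalone sketch of driving the new operators from user code — not part of the diff, and assuming only the `ggml.h` declarations added in this commit plus the pre-existing `ggml_set_f32_1d`/`ggml_get_f32_1d` accessors — following the same build-then-compute pattern as the test suite above:

#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params;
    params.mem_size   = 16 * 1024 * 1024;
    params.mem_buffer = NULL;

    struct ggml_context * ctx = ggml_init(params);

    // x = [0, 1, 2, 3]
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    for (int i = 0; i < 4; i++) {
        ggml_set_f32_1d(x, i, (float) i);
    }

    // y = sigmoid(1 - exp(x)): an arbitrary composition of the new operators.
    // Each call only records a graph node; nothing is computed yet.
    struct ggml_tensor * y = ggml_sigmoid(ctx, ggml_1_minus_x(ctx, ggml_exp(ctx, x)));

    // Build the graph from the result tensor and run the forward pass.
    struct ggml_cgraph graph = ggml_build_forward(y);
    graph.n_threads = 1; // the new ops are single-threaded (n_tasks = 1)
    ggml_graph_compute(ctx, &graph);

    for (int i = 0; i < 4; i++) {
        printf("y[%d] = %f\n", i, ggml_get_f32_1d(y, i));
    }

    ggml_free(ctx);
    return 0;
}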

ggml.h

@@ -167,6 +167,32 @@
 //
 // TODO
 //
+// ## Adding new operators
+//
+// Suppose you want to add an `e^x` unary operator. The following steps need to be done:
+//
+// In `ggml.h`:
+//
+// 1. Add member `GGML_OP_EXP` to the `ggml_op` enum.
+// 2. Declare the operator function: `struct ggml_tensor * ggml_exp(struct ggml_context * ctx, struct ggml_tensor * x);`.
+//
+// In `ggml.c`:
+//
+// 1. Implement the `ggml_exp` function: it will create the result tensor and set its operator and arguments.
+// 2. Create the forward computation function for FP32, `ggml_compute_forward_exp_f32`: it will do the actual computation.
+// 3. If needed, create forward computation functions for other types: FP16, INT32, etc.
+// 4. Create the forward dispatch function `ggml_compute_forward_exp`: it will dispatch the call based on the tensor data type.
+// 5. Add `case GGML_OP_EXP`:
+//    - to `ggml_compute_forward`, calling the forward dispatch function there;
+//    - to `ggml_compute_backward`, adding `GGML_ASSERT(false)` there;
+//    - to `ggml_graph_compute`, adding `node->n_tasks = 1` there.
+// 6. Fix all assertions that check the value of `GGML_OP_COUNT`: you've added 1 operator, so increment the asserted value by one.
+//
+// When in doubt, consult the code of existing operators similar to the one you're implementing.
+// The resulting operator will work for the forward pass, but will lack a backward implementation and multi-threading support.
+//
+// TODO Implementing backward pass
+// TODO Implementing multi-threading
 //

 #ifdef __cplusplus
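To make the checklist above concrete, here is a hedged skeleton of the pieces a hypothetical element-wise `GGML_OP_SQUARE` operator would add (every `*square*` name is invented for this example; the real operators in this commit follow exactly this shape):

// ggml.h: add GGML_OP_SQUARE to enum ggml_op (and bump the two
// GGML_OP_COUNT static_asserts in ggml.c), then declare the API:
struct ggml_tensor * ggml_square(struct ggml_context * ctx, struct ggml_tensor * a);

// ggml.c: operator constructor -- records the op and its argument in the graph.
struct ggml_tensor * ggml_square(
        struct ggml_context * ctx,
        struct ggml_tensor * a) {
    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);

    result->op   = GGML_OP_SQUARE;
    result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL;
    result->src0 = a;
    result->src1 = NULL;

    return result;
}

// ggml.c: FP32 forward computation, one contiguous row at a time.
static void ggml_compute_forward_square_f32(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        struct ggml_tensor * dst) {
    assert(params->ith == 0);
    assert(ggml_are_same_shape(src0, dst));

    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
    }

    const int n  = ggml_nrows(src0);
    const int nc = src0->ne[0];

    for (int i = 0; i < n; i++) {
        const float * x = (const float *) ((char *) src0->data + i*(src0->nb[1]));
        float       * y = (float *)       ((char *)  dst->data + i*( dst->nb[1]));
        for (int j = 0; j < nc; j++) {
            y[j] = x[j] * x[j];
        }
    }
}

// Remaining wiring, exactly as the diff above does for GGML_OP_EXP:
// - a ggml_compute_forward_square wrapper that switches on src0->type;
// - `case GGML_OP_SQUARE:` in ggml_compute_forward (call the wrapper),
//   in ggml_compute_backward (GGML_ASSERT(false) until implemented),
//   and in ggml_graph_compute (node->n_tasks = 1).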
@@ -225,9 +251,22 @@ enum ggml_op {
     GGML_OP_ABS,
     GGML_OP_SGN,
     GGML_OP_NEG,
+    // Element-wise exponential function `e^x`.
+    // Same as `torch.exp(x)` from PyTorch.
+    GGML_OP_EXP,
+    // Element-wise `1 - x`.
+    GGML_OP_1_MINUS_X,
+    // Element-wise maximum of 2 values. Argument shapes must match.
+    // Same as `torch.maximum(x, y)` from PyTorch.
+    GGML_OP_MAX,
     GGML_OP_STEP,
     GGML_OP_RELU,
     GGML_OP_GELU,
+    // Element-wise sigmoid activation `1 / (1 + e^-x)`, also called logistic function.
+    // Same as `torch.sigmoid(x)` from PyTorch.
+    GGML_OP_SIGMOID,
     GGML_OP_SILU,
     GGML_OP_NORM, // normalize
     GGML_OP_RMS_NORM,

@@ -463,6 +502,19 @@ struct ggml_tensor * ggml_neg(
         struct ggml_context * ctx,
         struct ggml_tensor * a);
+struct ggml_tensor * ggml_exp(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_1_minus_x(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a);
+
+struct ggml_tensor * ggml_max(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b);
 struct ggml_tensor * ggml_step(
         struct ggml_context * ctx,
         struct ggml_tensor * a);

@@ -476,6 +528,10 @@ struct ggml_tensor * ggml_gelu(
         struct ggml_context * ctx,
         struct ggml_tensor * a);
+struct ggml_tensor * ggml_sigmoid(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a);
 struct ggml_tensor * ggml_silu(
         struct ggml_context * ctx,
         struct ggml_tensor * a);

@@ -768,6 +824,11 @@ int ggml_cpu_has_blas(void);
 int ggml_cpu_has_sse3(void);
 int ggml_cpu_has_vsx(void);
+// Run test suite for ggml.
+// Exits normally if all tests pass.
+// Aborts execution if any test fails.
+void ggml_run_test_suite();
 #ifdef __cplusplus
 }
 #endif