Implement time mixing, fix matrix shape mismatch

This commit is contained in:
saharNooby 2023-03-30 20:29:41 +04:00
parent 873cb954d0
commit 56bf4fc856
1 changed file with 119 additions and 17 deletions

View File

@ -424,7 +424,7 @@ int main(int argc, char ** argv) {
// --- Evaluate model --- // --- Evaluate model ---
struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_embed); struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed);
ggml_set_f32(ones, 1.0F); ggml_set_f32(ones, 1.0F);
// x = self.w.emb.weight[token] // x = self.w.emb.weight[token]
@ -437,7 +437,7 @@ int main(int argc, char ** argv) {
x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias); x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);
// For token 123 after ln0, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765] // For token 123 after ln0, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765]
// Prints [[-0.419416 1.169782 0.779827 ... -1.183806 -0.871573 -0.276483]] // Prints (768, 1), [[-0.419416 1.169782 0.779827 ... -1.183806 -0.871573 -0.276483]]
COMPUTE_AND_PRINT_TENSOR(ctx, x); COMPUTE_AND_PRINT_TENSOR(ctx, x);
for (int i = 0; i < n_layer; i++) { for (int i = 0; i < n_layer; i++) {
@ -445,9 +445,108 @@ int main(int argc, char ** argv) {
// RWKV/time mixing // RWKV/time mixing
{ {
// self.layer_norm(x, self.w.blocks[i].ln1)
struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
// TODO Implement // state[5 * i + 1]
x = ggml_add(ctx, x, x0); struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * 4);
// xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
// xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
// xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
struct ggml_tensor * xk = ggml_add(
ctx,
ggml_mul(ctx, x0, layer.att_time_mix_k),
ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_k))
);
struct ggml_tensor * xv = ggml_add(
ctx,
ggml_mul(ctx, x0, layer.att_time_mix_v),
ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_v))
);
struct ggml_tensor * xr = ggml_add(
ctx,
ggml_mul(ctx, x0, layer.att_time_mix_r),
ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_r))
);
// state[5 * i + 1] = x
ggml_cpy(ctx, x0, x_prev);
// r = torch.sigmoid(rw @ xr)
struct ggml_tensor * r = ggml_sigmoid(
ctx,
ggml_mul_mat(ctx, layer.att_receptance, xr)
);
// k = kw @ xk
struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk);
// v = vw @ xv
struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv);
// aa = state[5 * i + 2]
// bb = state[5 * i + 3]
// pp = state[5 * i + 4]
struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * 4);
struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * 4);
struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * 4);
// ww = time_first + k
struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k);
// qq = torch.maximum(pp, ww)
// TODO Implement element-wise max in ggml
struct ggml_tensor * qq = pp;
// e1 = torch.exp(pp - qq)
// TODO Implement element-wise exp in ggml
struct ggml_tensor * e1 = ggml_sub(ctx, pp, qq);
// e2 = torch.exp(ww - qq)
// TODO Use exp
struct ggml_tensor * e2 = ggml_sub(ctx, ww, qq);
// a = e1 * aa + e2 * v
struct ggml_tensor * a = ggml_add(
ctx,
ggml_mul(ctx, e1, aa),
ggml_mul(ctx, e2, v)
);
// b = e1 * bb + e2
struct ggml_tensor * b = ggml_add(
ctx,
ggml_mul(ctx, e1, bb),
e2
);
// wkv = a / b
struct ggml_tensor * wkv = ggml_div(ctx, a, b);
// ww = pp + time_decay
ww = ggml_add(ctx, pp, layer.att_time_decay);
// qq = torch.maximum(ww, k)
// TODO Use max
qq = ww;
// e1 = torch.exp(ww - qq)
// TODO Use exp
e1 = ggml_sub(ctx, ww, qq);
// e2 = torch.exp(k - qq)
// TODO Use exp
e2 = ggml_sub(ctx, k, qq);
// state[5 * i + 2] = e1 * aa + e2 * v
ggml_cpy(ctx, ggml_add(
ctx,
ggml_mul(ctx, e1, aa),
ggml_mul(ctx, e2, v)
), aa);
// state[5 * i + 3] = e1 * bb + e2
ggml_cpy(ctx, ggml_add(
ctx,
ggml_mul(ctx, e1, bb),
e2
), bb);
// state[5 * i + 4] = qq
ggml_cpy(ctx, qq, pp);
// ow @ (r * wkv)
x = ggml_add(
ctx,
x,
ggml_mul_mat(
ctx,
layer.att_output,
ggml_mul(ctx, r, wkv)
)
);
} }
// FFN/channel mixing // FFN/channel mixing
@ -455,8 +554,7 @@ int main(int argc, char ** argv) {
// self.layer_norm(x, self.w.blocks[i].ln2) // self.layer_norm(x, self.w.blocks[i].ln2)
struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias); struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias);
// state[5 * i + 0] // state[5 * i + 0]
int32_t offset_in_bytes = (5 * i + 0) * n_embed * 4; struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * 4);
struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, offset_in_bytes);
// xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
// xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
struct ggml_tensor * xk = ggml_add( struct ggml_tensor * xk = ggml_add(
@ -478,16 +576,20 @@ int main(int argc, char ** argv) {
ggml_mul_mat(ctx, layer.ffn_receptance, xr) ggml_mul_mat(ctx, layer.ffn_receptance, xr)
); );
// k = torch.square(torch.relu(kw @ xk)) // k = torch.square(torch.relu(kw @ xk))
// TODO Does not work; shape mismatch struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(
//struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu( ctx,
// ctx, ggml_mul_mat(ctx, layer.ffn_key, xk)
// ggml_mul_mat(ctx, layer.ffn_key, xk) ));
//));
// r * (vw @ k) // r * (vw @ k)
// TODO Does not work; shape mismatch x = ggml_add(
// x0 = ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); ctx,
// x = x + self.channel_mixing(...) x,
x = ggml_add(ctx, x, x0); ggml_mul(
ctx,
r,
ggml_mul_mat(ctx, layer.ffn_value, k)
)
);
} }
} }
@ -495,11 +597,11 @@ int main(int argc, char ** argv) {
x = ggml_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias); x = ggml_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias);
// x = (self.w.head.weight @ x).float() // x = (self.w.head.weight @ x).float()
// TODO Is transpose copying the tensor? struct ggml_tensor * logits = ggml_mul_mat(ctx, model.head, x);
struct ggml_tensor * logits = ggml_mul_mat(ctx, x, ggml_transpose(ctx, model.head));
compute_graph(ctx, logits); compute_graph(ctx, logits);
// TODO -nan(ind) -nan(ind) ... (maybe implement exp/max first?)
PRINT_TENSOR(logits); PRINT_TENSOR(logits);
ggml_free(ctx); ggml_free(ctx);