Implement time mixing, fix matrix shape mismatch
This commit is contained in:
parent
873cb954d0
commit
56bf4fc856
|
@ -424,7 +424,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// --- Evaluate model ---
|
// --- Evaluate model ---
|
||||||
|
|
||||||
struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_embed);
|
struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed);
|
||||||
ggml_set_f32(ones, 1.0F);
|
ggml_set_f32(ones, 1.0F);
|
||||||
|
|
||||||
// x = self.w.emb.weight[token]
|
// x = self.w.emb.weight[token]
|
||||||
|
@ -437,7 +437,7 @@ int main(int argc, char ** argv) {
|
||||||
x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);
|
x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);
|
||||||
|
|
||||||
// For token 123 after ln0, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765]
|
// For token 123 after ln0, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765]
|
||||||
// Prints [[-0.419416 1.169782 0.779827 ... -1.183806 -0.871573 -0.276483]]
|
// Prints (768, 1), [[-0.419416 1.169782 0.779827 ... -1.183806 -0.871573 -0.276483]]
|
||||||
COMPUTE_AND_PRINT_TENSOR(ctx, x);
|
COMPUTE_AND_PRINT_TENSOR(ctx, x);
|
||||||
|
|
||||||
for (int i = 0; i < n_layer; i++) {
|
for (int i = 0; i < n_layer; i++) {
|
||||||
|
@ -445,9 +445,108 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// RWKV/time mixing
|
// RWKV/time mixing
|
||||||
{
|
{
|
||||||
|
// self.layer_norm(x, self.w.blocks[i].ln1)
|
||||||
struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
|
struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
|
||||||
// TODO Implement
|
// state[5 * i + 1]
|
||||||
x = ggml_add(ctx, x, x0);
|
struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * 4);
|
||||||
|
// xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
|
||||||
|
// xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
|
||||||
|
// xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
|
||||||
|
struct ggml_tensor * xk = ggml_add(
|
||||||
|
ctx,
|
||||||
|
ggml_mul(ctx, x0, layer.att_time_mix_k),
|
||||||
|
ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_k))
|
||||||
|
);
|
||||||
|
struct ggml_tensor * xv = ggml_add(
|
||||||
|
ctx,
|
||||||
|
ggml_mul(ctx, x0, layer.att_time_mix_v),
|
||||||
|
ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_v))
|
||||||
|
);
|
||||||
|
struct ggml_tensor * xr = ggml_add(
|
||||||
|
ctx,
|
||||||
|
ggml_mul(ctx, x0, layer.att_time_mix_r),
|
||||||
|
ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_r))
|
||||||
|
);
|
||||||
|
// state[5 * i + 1] = x
|
||||||
|
ggml_cpy(ctx, x0, x_prev);
|
||||||
|
|
||||||
|
// r = torch.sigmoid(rw @ xr)
|
||||||
|
struct ggml_tensor * r = ggml_sigmoid(
|
||||||
|
ctx,
|
||||||
|
ggml_mul_mat(ctx, layer.att_receptance, xr)
|
||||||
|
);
|
||||||
|
// k = kw @ xk
|
||||||
|
struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk);
|
||||||
|
// v = vw @ xv
|
||||||
|
struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv);
|
||||||
|
|
||||||
|
// aa = state[5 * i + 2]
|
||||||
|
// bb = state[5 * i + 3]
|
||||||
|
// pp = state[5 * i + 4]
|
||||||
|
struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * 4);
|
||||||
|
struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * 4);
|
||||||
|
struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * 4);
|
||||||
|
|
||||||
|
// ww = time_first + k
|
||||||
|
struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k);
|
||||||
|
// qq = torch.maximum(pp, ww)
|
||||||
|
// TODO Implement element-wise max in ggml
|
||||||
|
struct ggml_tensor * qq = pp;
|
||||||
|
// e1 = torch.exp(pp - qq)
|
||||||
|
// TODO Implement element-wise exp in ggml
|
||||||
|
struct ggml_tensor * e1 = ggml_sub(ctx, pp, qq);
|
||||||
|
// e2 = torch.exp(ww - qq)
|
||||||
|
// TODO Use exp
|
||||||
|
struct ggml_tensor * e2 = ggml_sub(ctx, ww, qq);
|
||||||
|
// a = e1 * aa + e2 * v
|
||||||
|
struct ggml_tensor * a = ggml_add(
|
||||||
|
ctx,
|
||||||
|
ggml_mul(ctx, e1, aa),
|
||||||
|
ggml_mul(ctx, e2, v)
|
||||||
|
);
|
||||||
|
// b = e1 * bb + e2
|
||||||
|
struct ggml_tensor * b = ggml_add(
|
||||||
|
ctx,
|
||||||
|
ggml_mul(ctx, e1, bb),
|
||||||
|
e2
|
||||||
|
);
|
||||||
|
// wkv = a / b
|
||||||
|
struct ggml_tensor * wkv = ggml_div(ctx, a, b);
|
||||||
|
// ww = pp + time_decay
|
||||||
|
ww = ggml_add(ctx, pp, layer.att_time_decay);
|
||||||
|
// qq = torch.maximum(ww, k)
|
||||||
|
// TODO Use max
|
||||||
|
qq = ww;
|
||||||
|
// e1 = torch.exp(ww - qq)
|
||||||
|
// TODO Use exp
|
||||||
|
e1 = ggml_sub(ctx, ww, qq);
|
||||||
|
// e2 = torch.exp(k - qq)
|
||||||
|
// TODO Use exp
|
||||||
|
e2 = ggml_sub(ctx, k, qq);
|
||||||
|
// state[5 * i + 2] = e1 * aa + e2 * v
|
||||||
|
ggml_cpy(ctx, ggml_add(
|
||||||
|
ctx,
|
||||||
|
ggml_mul(ctx, e1, aa),
|
||||||
|
ggml_mul(ctx, e2, v)
|
||||||
|
), aa);
|
||||||
|
// state[5 * i + 3] = e1 * bb + e2
|
||||||
|
ggml_cpy(ctx, ggml_add(
|
||||||
|
ctx,
|
||||||
|
ggml_mul(ctx, e1, bb),
|
||||||
|
e2
|
||||||
|
), bb);
|
||||||
|
// state[5 * i + 4] = qq
|
||||||
|
ggml_cpy(ctx, qq, pp);
|
||||||
|
// ow @ (r * wkv)
|
||||||
|
x = ggml_add(
|
||||||
|
ctx,
|
||||||
|
x,
|
||||||
|
ggml_mul_mat(
|
||||||
|
ctx,
|
||||||
|
layer.att_output,
|
||||||
|
ggml_mul(ctx, r, wkv)
|
||||||
|
)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// FFN/channel mixing
|
// FFN/channel mixing
|
||||||
|
@ -455,8 +554,7 @@ int main(int argc, char ** argv) {
|
||||||
// self.layer_norm(x, self.w.blocks[i].ln2)
|
// self.layer_norm(x, self.w.blocks[i].ln2)
|
||||||
struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias);
|
struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias);
|
||||||
// state[5 * i + 0]
|
// state[5 * i + 0]
|
||||||
int32_t offset_in_bytes = (5 * i + 0) * n_embed * 4;
|
struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * 4);
|
||||||
struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, offset_in_bytes);
|
|
||||||
// xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
|
// xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
|
||||||
// xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
|
// xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
|
||||||
struct ggml_tensor * xk = ggml_add(
|
struct ggml_tensor * xk = ggml_add(
|
||||||
|
@ -478,16 +576,20 @@ int main(int argc, char ** argv) {
|
||||||
ggml_mul_mat(ctx, layer.ffn_receptance, xr)
|
ggml_mul_mat(ctx, layer.ffn_receptance, xr)
|
||||||
);
|
);
|
||||||
// k = torch.square(torch.relu(kw @ xk))
|
// k = torch.square(torch.relu(kw @ xk))
|
||||||
// TODO Does not work; shape mismatch
|
struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(
|
||||||
//struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(
|
ctx,
|
||||||
// ctx,
|
ggml_mul_mat(ctx, layer.ffn_key, xk)
|
||||||
// ggml_mul_mat(ctx, layer.ffn_key, xk)
|
));
|
||||||
//));
|
|
||||||
// r * (vw @ k)
|
// r * (vw @ k)
|
||||||
// TODO Does not work; shape mismatch
|
x = ggml_add(
|
||||||
// x0 = ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k));
|
ctx,
|
||||||
// x = x + self.channel_mixing(...)
|
x,
|
||||||
x = ggml_add(ctx, x, x0);
|
ggml_mul(
|
||||||
|
ctx,
|
||||||
|
r,
|
||||||
|
ggml_mul_mat(ctx, layer.ffn_value, k)
|
||||||
|
)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -495,11 +597,11 @@ int main(int argc, char ** argv) {
|
||||||
x = ggml_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias);
|
x = ggml_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias);
|
||||||
|
|
||||||
// x = (self.w.head.weight @ x).float()
|
// x = (self.w.head.weight @ x).float()
|
||||||
// TODO Is transpose copying the tensor?
|
struct ggml_tensor * logits = ggml_mul_mat(ctx, model.head, x);
|
||||||
struct ggml_tensor * logits = ggml_mul_mat(ctx, x, ggml_transpose(ctx, model.head));
|
|
||||||
|
|
||||||
compute_graph(ctx, logits);
|
compute_graph(ctx, logits);
|
||||||
|
|
||||||
|
// TODO -nan(ind) -nan(ind) ... (maybe implement exp/max first?)
|
||||||
PRINT_TENSOR(logits);
|
PRINT_TENSOR(logits);
|
||||||
|
|
||||||
ggml_free(ctx);
|
ggml_free(ctx);
|
||||||
|
|
Loading…
Reference in New Issue