Implement time mixing, fix matrix shape mismatch
commit 56bf4fc856
parent 873cb954d0
@@ -424,7 +424,7 @@ int main(int argc, char ** argv) {
     // --- Evaluate model ---

-    struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_embed);
+    struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed);
     ggml_set_f32(ones, 1.0F);

     // x = self.w.emb.weight[token]
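Why the one-line type change matters: `ones` is consumed below via `ggml_sub(ctx, ones, layer.att_time_mix_k)` to build the `(1 - time_mix)` factors, and ggml's element-wise kernels expect F32 operands there, while `ggml_set_f32` would fill an I32 tensor with integer payloads. A minimal sketch of the pattern, using stock ggml calls only, with `ctx`, `n_embed`, and `layer` as in the surrounding code:

    // ggml has no scalar-minus-tensor op, so a constant tensor of ones is
    // built once and reused for every (1 - time_mix) term.
    struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed);
    ggml_set_f32(ones, 1.0F);
    // Element-wise 1 - time_mix_k; both operands are F32 of length n_embed.
    struct ggml_tensor * one_minus_mix_k = ggml_sub(ctx, ones, layer.att_time_mix_k);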
@@ -437,7 +437,7 @@ int main(int argc, char ** argv) {
     x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);

     // For token 123 after ln0, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765]
-    // Prints [[-0.419416 1.169782 0.779827 ... -1.183806 -0.871573 -0.276483]]
+    // Prints (768, 1), [[-0.419416 1.169782 0.779827 ... -1.183806 -0.871573 -0.276483]]
     COMPUTE_AND_PRINT_TENSOR(ctx, x);

     for (int i = 0; i < n_layer; i++) {
@@ -445,9 +445,108 @@ int main(int argc, char ** argv) {

         // RWKV/time mixing
         {
             // self.layer_norm(x, self.w.blocks[i].ln1)
             struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
-            // TODO Implement
-            x = ggml_add(ctx, x, x0);
+            // state[5 * i + 1]
+            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * 4);
+            // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
+            // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
+            // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
+            struct ggml_tensor * xk = ggml_add(
+                ctx,
+                ggml_mul(ctx, x0, layer.att_time_mix_k),
+                ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_k))
+            );
+            struct ggml_tensor * xv = ggml_add(
+                ctx,
+                ggml_mul(ctx, x0, layer.att_time_mix_v),
+                ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_v))
+            );
+            struct ggml_tensor * xr = ggml_add(
+                ctx,
+                ggml_mul(ctx, x0, layer.att_time_mix_r),
+                ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_r))
+            );
+            // state[5 * i + 1] = x
+            ggml_cpy(ctx, x0, x_prev);
+
+            // r = torch.sigmoid(rw @ xr)
+            struct ggml_tensor * r = ggml_sigmoid(
+                ctx,
+                ggml_mul_mat(ctx, layer.att_receptance, xr)
+            );
+            // k = kw @ xk
+            struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk);
+            // v = vw @ xv
+            struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv);
+
+            // aa = state[5 * i + 2]
+            // bb = state[5 * i + 3]
+            // pp = state[5 * i + 4]
+            struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * 4);
+            struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * 4);
+            struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * 4);
+
+            // ww = time_first + k
+            struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k);
+            // qq = torch.maximum(pp, ww)
+            // TODO Implement element-wise max in ggml
+            struct ggml_tensor * qq = pp;
+            // e1 = torch.exp(pp - qq)
+            // TODO Implement element-wise exp in ggml
+            struct ggml_tensor * e1 = ggml_sub(ctx, pp, qq);
+            // e2 = torch.exp(ww - qq)
+            // TODO Use exp
+            struct ggml_tensor * e2 = ggml_sub(ctx, ww, qq);
+            // a = e1 * aa + e2 * v
+            struct ggml_tensor * a = ggml_add(
+                ctx,
+                ggml_mul(ctx, e1, aa),
+                ggml_mul(ctx, e2, v)
+            );
+            // b = e1 * bb + e2
+            struct ggml_tensor * b = ggml_add(
+                ctx,
+                ggml_mul(ctx, e1, bb),
+                e2
+            );
+            // wkv = a / b
+            struct ggml_tensor * wkv = ggml_div(ctx, a, b);
+            // ww = pp + time_decay
+            ww = ggml_add(ctx, pp, layer.att_time_decay);
+            // qq = torch.maximum(ww, k)
+            // TODO Use max
+            qq = ww;
+            // e1 = torch.exp(ww - qq)
+            // TODO Use exp
+            e1 = ggml_sub(ctx, ww, qq);
+            // e2 = torch.exp(k - qq)
+            // TODO Use exp
+            e2 = ggml_sub(ctx, k, qq);
+            // state[5 * i + 2] = e1 * aa + e2 * v
+            ggml_cpy(ctx, ggml_add(
+                ctx,
+                ggml_mul(ctx, e1, aa),
+                ggml_mul(ctx, e2, v)
+            ), aa);
+            // state[5 * i + 3] = e1 * bb + e2
+            ggml_cpy(ctx, ggml_add(
+                ctx,
+                ggml_mul(ctx, e1, bb),
+                e2
+            ), bb);
+            // state[5 * i + 4] = qq
+            ggml_cpy(ctx, qq, pp);
+            // ow @ (r * wkv)
+            x = ggml_add(
+                ctx,
+                x,
+                ggml_mul_mat(
+                    ctx,
+                    layer.att_output,
+                    ggml_mul(ctx, r, wkv)
+                )
+            );
+        }

         // FFN/channel mixing
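For orientation, the block above mirrors RWKV's numerically safe weighted average: with ww = time_first + k, the target is wkv = (e^pp * aa + e^ww * v) / (e^pp * bb + e^ww), and the running maximum qq is subtracted from every exponent so exp never overflows. With the max/exp placeholders left as-is, that structure collapses, which is what the `-nan(ind)` TODO at the end of the diff points at. One possible way to fill the TODOs without patching ggml's kernels is the map operators, assuming the ggml revision in use ships `ggml_map_unary_f32`/`ggml_map_binary_f32`; the `ggml_sigmoid` call above suggests the project already carries custom operators, so dedicated `ggml_exp`/`ggml_max` ops would be the other obvious route. A hedged sketch:

    #include <math.h>  // expf, fmaxf

    // Element-wise callbacks for ggml's map operators.
    // n is the element count; the pointers are contiguous F32 buffers.
    static void op_exp(const int n, float * dst, const float * src) {
        for (int i = 0; i < n; i++) {
            dst[i] = expf(src[i]);
        }
    }

    static void op_max(const int n, float * dst, const float * src0, const float * src1) {
        for (int i = 0; i < n; i++) {
            dst[i] = fmaxf(src0[i], src1[i]);
        }
    }

    // qq = torch.maximum(pp, ww)
    struct ggml_tensor * qq = ggml_map_binary_f32(ctx, pp, ww, op_max);
    // e1 = torch.exp(pp - qq)
    struct ggml_tensor * e1 = ggml_map_unary_f32(ctx, ggml_sub(ctx, pp, qq), op_exp);
    // e2 = torch.exp(ww - qq)
    struct ggml_tensor * e2 = ggml_map_unary_f32(ctx, ggml_sub(ctx, ww, qq), op_exp);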
@@ -455,8 +554,7 @@ int main(int argc, char ** argv) {
             // self.layer_norm(x, self.w.blocks[i].ln2)
             struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias);
             // state[5 * i + 0]
-            int32_t offset_in_bytes = (5 * i + 0) * n_embed * 4;
-            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, offset_in_bytes);
+            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * 4);
             // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
             // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
             struct ggml_tensor * xk = ggml_add(
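A small aside on the view cleanup above: the `* 4` hard-codes `sizeof(float)`. A sketch of the same slice with the element size queried from the tensor instead (`ggml_element_size` is a stock ggml helper):

    // Same state slice, without hard-coding the 4-byte F32 element size.
    size_t offset = (size_t) (5 * i + 0) * n_embed * ggml_element_size(state);
    struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, offset);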
@@ -478,16 +576,20 @@ int main(int argc, char ** argv) {
                 ggml_mul_mat(ctx, layer.ffn_receptance, xr)
             );
             // k = torch.square(torch.relu(kw @ xk))
-            // TODO Does not work; shape mismatch
-            //struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(
-            //    ctx,
-            //    ggml_mul_mat(ctx, layer.ffn_key, xk)
-            //));
+            struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(
+                ctx,
+                ggml_mul_mat(ctx, layer.ffn_key, xk)
+            ));
             // r * (vw @ k)
-            // TODO Does not work; shape mismatch
-            // x0 = ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k));
             // x = x + self.channel_mixing(...)
-            x = ggml_add(ctx, x, x0);
+            x = ggml_add(
+                ctx,
+                x,
+                ggml_mul(
+                    ctx,
+                    r,
+                    ggml_mul_mat(ctx, layer.ffn_value, k)
+                )
+            );
         }
     }
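One caveat worth flagging on the state writes: `ggml_cpy` only adds a node to the compute graph, and its result is discarded in both mixing blocks. If the graph is later built from `logits` alone, nothing depends on those copy nodes, so the state slices are never actually written. A hedged sketch of one way to force them in, assuming the classic `ggml_build_forward`/`ggml_graph_compute` API; `state_copies`/`n_state_copies` are a hypothetical list collected while building the layers:

    // Build the graph from logits, then splice in the otherwise-unreachable
    // state-update copy nodes before computing.
    struct ggml_cgraph graph = ggml_build_forward(logits);
    for (int j = 0; j < n_state_copies; j++) {
        ggml_build_forward_expand(&graph, state_copies[j]);
    }
    ggml_graph_compute(ctx, &graph);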
@@ -495,11 +597,11 @@ int main(int argc, char ** argv) {
     x = ggml_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias);

     // x = (self.w.head.weight @ x).float()
-    // TODO Is transpose copying the tensor?
-    struct ggml_tensor * logits = ggml_mul_mat(ctx, x, ggml_transpose(ctx, model.head));
+    struct ggml_tensor * logits = ggml_mul_mat(ctx, model.head, x);

     compute_graph(ctx, logits);

+    // TODO -nan(ind) -nan(ind) ... (maybe implement exp/max first?)
     PRINT_TENSOR(logits);

     ggml_free(ctx);
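The head fix, like the FFN fixes above, follows from ggml's matrix-multiply convention: shapes store the fastest dimension first, `ggml_mul_mat(ctx, a, b)` requires `a->ne[0] == b->ne[0]`, and the result has shape `[a->ne[1], b->ne[1]]`. Assuming `head` was loaded as ne = [n_embed, n_vocab], the usual ggml layout for a PyTorch [n_vocab, n_embed] weight, the weight simply goes first and no transpose is needed:

    // head: ne = [n_embed, n_vocab]; x: ne = [n_embed, 1]
    // -> logits: ne = [n_vocab, 1], i.e. head @ x
    struct ggml_tensor * logits = ggml_mul_mat(ctx, model.head, x);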