diff --git a/examples/main_rwkv/main_rwkv.cpp b/examples/main_rwkv/main_rwkv.cpp index b428cc0..9288e6f 100644 --- a/examples/main_rwkv/main_rwkv.cpp +++ b/examples/main_rwkv/main_rwkv.cpp @@ -424,7 +424,7 @@ int main(int argc, char ** argv) { // --- Evaluate model --- - struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_embed); + struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed); ggml_set_f32(ones, 1.0F); // x = self.w.emb.weight[token] @@ -437,7 +437,7 @@ int main(int argc, char ** argv) { x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias); // For token 123 after ln0, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765] - // Prints [[-0.419416 1.169782 0.779827 ... -1.183806 -0.871573 -0.276483]] + // Prints (768, 1), [[-0.419416 1.169782 0.779827 ... -1.183806 -0.871573 -0.276483]] COMPUTE_AND_PRINT_TENSOR(ctx, x); for (int i = 0; i < n_layer; i++) { @@ -445,9 +445,108 @@ int main(int argc, char ** argv) { // RWKV/time mixing { + // self.layer_norm(x, self.w.blocks[i].ln1) struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); - // TODO Implement - x = ggml_add(ctx, x, x0); + // state[5 * i + 1] + struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * 4); + // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) + // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) + // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) + struct ggml_tensor * xk = ggml_add( + ctx, + ggml_mul(ctx, x0, layer.att_time_mix_k), + ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_k)) + ); + struct ggml_tensor * xv = ggml_add( + ctx, + ggml_mul(ctx, x0, layer.att_time_mix_v), + ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_v)) + ); + struct ggml_tensor * xr = ggml_add( + ctx, + ggml_mul(ctx, x0, layer.att_time_mix_r), + ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_r)) + ); + // state[5 * i + 1] = x + ggml_cpy(ctx, x0, x_prev); + + // r = torch.sigmoid(rw @ xr) + struct ggml_tensor * r = ggml_sigmoid( + ctx, + ggml_mul_mat(ctx, layer.att_receptance, xr) + ); + // k = kw @ xk + struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk); + // v = vw @ xv + struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv); + + // aa = state[5 * i + 2] + // bb = state[5 * i + 3] + // pp = state[5 * i + 4] + struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * 4); + struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * 4); + struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * 4); + + // ww = time_first + k + struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k); + // qq = torch.maximum(pp, ww) + // TODO Implement element-wise max in ggml + struct ggml_tensor * qq = pp; + // e1 = torch.exp(pp - qq) + // TODO Implement element-wise exp in ggml + struct ggml_tensor * e1 = ggml_sub(ctx, pp, qq); + // e2 = torch.exp(ww - qq) + // TODO Use exp + struct ggml_tensor * e2 = ggml_sub(ctx, ww, qq); + // a = e1 * aa + e2 * v + struct ggml_tensor * a = ggml_add( + ctx, + ggml_mul(ctx, e1, aa), + ggml_mul(ctx, e2, v) + ); + // b = e1 * bb + e2 + struct ggml_tensor * b = ggml_add( + ctx, + ggml_mul(ctx, e1, bb), + e2 + ); + // wkv = a / b + struct ggml_tensor * wkv = ggml_div(ctx, a, b); + // ww = pp + time_decay + ww = ggml_add(ctx, pp, layer.att_time_decay); + // qq = torch.maximum(ww, k) + // TODO Use max + qq = ww; + // e1 = torch.exp(ww - qq) + // TODO Use exp + e1 = ggml_sub(ctx, ww, qq); + // e2 = torch.exp(k - qq) + // TODO Use exp + e2 = ggml_sub(ctx, k, qq); + // state[5 * i + 2] = e1 * aa + e2 * v + ggml_cpy(ctx, ggml_add( + ctx, + ggml_mul(ctx, e1, aa), + ggml_mul(ctx, e2, v) + ), aa); + // state[5 * i + 3] = e1 * bb + e2 + ggml_cpy(ctx, ggml_add( + ctx, + ggml_mul(ctx, e1, bb), + e2 + ), bb); + // state[5 * i + 4] = qq + ggml_cpy(ctx, qq, pp); + // ow @ (r * wkv) + x = ggml_add( + ctx, + x, + ggml_mul_mat( + ctx, + layer.att_output, + ggml_mul(ctx, r, wkv) + ) + ); } // FFN/channel mixing @@ -455,8 +554,7 @@ int main(int argc, char ** argv) { // self.layer_norm(x, self.w.blocks[i].ln2) struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias); // state[5 * i + 0] - int32_t offset_in_bytes = (5 * i + 0) * n_embed * 4; - struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, offset_in_bytes); + struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * 4); // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) struct ggml_tensor * xk = ggml_add( @@ -478,16 +576,20 @@ int main(int argc, char ** argv) { ggml_mul_mat(ctx, layer.ffn_receptance, xr) ); // k = torch.square(torch.relu(kw @ xk)) - // TODO Does not work; shape mismatch - //struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu( - // ctx, - // ggml_mul_mat(ctx, layer.ffn_key, xk) - //)); + struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu( + ctx, + ggml_mul_mat(ctx, layer.ffn_key, xk) + )); // r * (vw @ k) - // TODO Does not work; shape mismatch - // x0 = ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); - // x = x + self.channel_mixing(...) - x = ggml_add(ctx, x, x0); + x = ggml_add( + ctx, + x, + ggml_mul( + ctx, + r, + ggml_mul_mat(ctx, layer.ffn_value, k) + ) + ); } } @@ -495,11 +597,11 @@ int main(int argc, char ** argv) { x = ggml_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias); // x = (self.w.head.weight @ x).float() - // TODO Is transpose copying the tensor? - struct ggml_tensor * logits = ggml_mul_mat(ctx, x, ggml_transpose(ctx, model.head)); + struct ggml_tensor * logits = ggml_mul_mat(ctx, model.head, x); compute_graph(ctx, logits); + // TODO -nan(ind) -nan(ind) ... (maybe implement exp/max first?) PRINT_TENSOR(logits); ggml_free(ctx);