Implement time mixing, fix matrix shape mismatch
This commit is contained in:
		
							parent
							
								
									873cb954d0
								
							
						
					
					
						commit
						56bf4fc856
					
				|  | @ -424,7 +424,7 @@ int main(int argc, char ** argv) { | ||||||
| 
 | 
 | ||||||
|     // --- Evaluate model ---
 |     // --- Evaluate model ---
 | ||||||
| 
 | 
 | ||||||
|     struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_embed); |     struct ggml_tensor * ones = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed); | ||||||
|     ggml_set_f32(ones, 1.0F); |     ggml_set_f32(ones, 1.0F); | ||||||
| 
 | 
 | ||||||
|     // x = self.w.emb.weight[token]
 |     // x = self.w.emb.weight[token]
 | ||||||
|  | @ -437,7 +437,7 @@ int main(int argc, char ** argv) { | ||||||
|     x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias); |     x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias); | ||||||
| 
 | 
 | ||||||
|     // For token 123 after ln0, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765]
 |     // For token 123 after ln0, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765]
 | ||||||
|     // Prints [[-0.419416 1.169782 0.779827 ... -1.183806 -0.871573 -0.276483]]
 |     // Prints (768, 1), [[-0.419416 1.169782 0.779827 ... -1.183806 -0.871573 -0.276483]]
 | ||||||
|     COMPUTE_AND_PRINT_TENSOR(ctx, x); |     COMPUTE_AND_PRINT_TENSOR(ctx, x); | ||||||
| 
 | 
 | ||||||
|     for (int i = 0; i < n_layer; i++) { |     for (int i = 0; i < n_layer; i++) { | ||||||
|  | @ -445,9 +445,108 @@ int main(int argc, char ** argv) { | ||||||
| 
 | 
 | ||||||
|         // RWKV/time mixing
 |         // RWKV/time mixing
 | ||||||
|         { |         { | ||||||
|  |             // self.layer_norm(x, self.w.blocks[i].ln1)
 | ||||||
|             struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); |             struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); | ||||||
|             // TODO Implement
 |             // state[5 * i + 1]
 | ||||||
|             x = ggml_add(ctx, x, x0); |             struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * 4); | ||||||
|  |             // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
 | ||||||
|  |             // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
 | ||||||
|  |             // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
 | ||||||
|  |             struct ggml_tensor * xk = ggml_add( | ||||||
|  |                 ctx, | ||||||
|  |                 ggml_mul(ctx, x0, layer.att_time_mix_k), | ||||||
|  |                 ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_k)) | ||||||
|  |             ); | ||||||
|  |             struct ggml_tensor * xv = ggml_add( | ||||||
|  |                 ctx, | ||||||
|  |                 ggml_mul(ctx, x0, layer.att_time_mix_v), | ||||||
|  |                 ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_v)) | ||||||
|  |             ); | ||||||
|  |             struct ggml_tensor * xr = ggml_add( | ||||||
|  |                 ctx, | ||||||
|  |                 ggml_mul(ctx, x0, layer.att_time_mix_r), | ||||||
|  |                 ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.att_time_mix_r)) | ||||||
|  |             ); | ||||||
|  |             // state[5 * i + 1] = x
 | ||||||
|  |             ggml_cpy(ctx, x0, x_prev); | ||||||
|  | 
 | ||||||
|  |             // r = torch.sigmoid(rw @ xr)
 | ||||||
|  |             struct ggml_tensor * r = ggml_sigmoid( | ||||||
|  |                 ctx, | ||||||
|  |                 ggml_mul_mat(ctx, layer.att_receptance, xr) | ||||||
|  |             ); | ||||||
|  |             // k = kw @ xk
 | ||||||
|  |             struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk); | ||||||
|  |             // v = vw @ xv
 | ||||||
|  |             struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv); | ||||||
|  | 
 | ||||||
|  |             // aa = state[5 * i + 2]
 | ||||||
|  |             // bb = state[5 * i + 3]
 | ||||||
|  |             // pp = state[5 * i + 4]
 | ||||||
|  |             struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * 4); | ||||||
|  |             struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * 4); | ||||||
|  |             struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * 4); | ||||||
|  | 
 | ||||||
|  |             // ww = time_first + k
 | ||||||
|  |             struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k); | ||||||
|  |             // qq = torch.maximum(pp, ww)
 | ||||||
|  |             // TODO Implement element-wise max in ggml
 | ||||||
|  |             struct ggml_tensor * qq = pp; | ||||||
|  |             // e1 = torch.exp(pp - qq)
 | ||||||
|  |             // TODO Implement element-wise exp in ggml
 | ||||||
|  |             struct ggml_tensor * e1 = ggml_sub(ctx, pp, qq); | ||||||
|  |             // e2 = torch.exp(ww - qq)
 | ||||||
|  |             // TODO Use exp
 | ||||||
|  |             struct ggml_tensor * e2 = ggml_sub(ctx, ww, qq); | ||||||
|  |             // a = e1 * aa + e2 * v
 | ||||||
|  |             struct ggml_tensor * a = ggml_add( | ||||||
|  |                 ctx, | ||||||
|  |                 ggml_mul(ctx, e1, aa), | ||||||
|  |                 ggml_mul(ctx, e2, v) | ||||||
|  |             ); | ||||||
|  |             // b = e1 * bb + e2
 | ||||||
|  |             struct ggml_tensor * b = ggml_add( | ||||||
|  |                 ctx, | ||||||
|  |                 ggml_mul(ctx, e1, bb), | ||||||
|  |                 e2 | ||||||
|  |             ); | ||||||
|  |             // wkv = a / b
 | ||||||
|  |             struct ggml_tensor * wkv = ggml_div(ctx, a, b); | ||||||
|  |             // ww = pp + time_decay
 | ||||||
|  |             ww = ggml_add(ctx, pp, layer.att_time_decay); | ||||||
|  |             // qq = torch.maximum(ww, k)
 | ||||||
|  |             // TODO Use max
 | ||||||
|  |             qq = ww; | ||||||
|  |             // e1 = torch.exp(ww - qq)
 | ||||||
|  |             // TODO Use exp
 | ||||||
|  |             e1 = ggml_sub(ctx, ww, qq); | ||||||
|  |             // e2 = torch.exp(k - qq)
 | ||||||
|  |             // TODO Use exp
 | ||||||
|  |             e2 = ggml_sub(ctx, k, qq); | ||||||
|  |             // state[5 * i + 2] = e1 * aa + e2 * v
 | ||||||
|  |             ggml_cpy(ctx, ggml_add( | ||||||
|  |                 ctx, | ||||||
|  |                 ggml_mul(ctx, e1, aa), | ||||||
|  |                 ggml_mul(ctx, e2, v) | ||||||
|  |             ), aa); | ||||||
|  |             // state[5 * i + 3] = e1 * bb + e2
 | ||||||
|  |             ggml_cpy(ctx, ggml_add( | ||||||
|  |                 ctx, | ||||||
|  |                 ggml_mul(ctx, e1, bb), | ||||||
|  |                 e2 | ||||||
|  |             ), bb); | ||||||
|  |             // state[5 * i + 4] = qq
 | ||||||
|  |             ggml_cpy(ctx, qq, pp); | ||||||
|  |             // ow @ (r * wkv)
 | ||||||
|  |             x = ggml_add( | ||||||
|  |                 ctx, | ||||||
|  |                 x, | ||||||
|  |                 ggml_mul_mat( | ||||||
|  |                     ctx, | ||||||
|  |                     layer.att_output, | ||||||
|  |                     ggml_mul(ctx, r, wkv) | ||||||
|  |                 ) | ||||||
|  |             ); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // FFN/channel mixing
 |         // FFN/channel mixing
 | ||||||
|  | @ -455,8 +554,7 @@ int main(int argc, char ** argv) { | ||||||
|             // self.layer_norm(x, self.w.blocks[i].ln2)
 |             // self.layer_norm(x, self.w.blocks[i].ln2)
 | ||||||
|             struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias); |             struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias); | ||||||
|             // state[5 * i + 0]
 |             // state[5 * i + 0]
 | ||||||
|             int32_t offset_in_bytes = (5 * i + 0) * n_embed * 4; |             struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * 4); | ||||||
|             struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, offset_in_bytes); |  | ||||||
|             // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
 |             // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
 | ||||||
|             // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
 |             // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
 | ||||||
|             struct ggml_tensor * xk = ggml_add( |             struct ggml_tensor * xk = ggml_add( | ||||||
|  | @ -478,16 +576,20 @@ int main(int argc, char ** argv) { | ||||||
|                 ggml_mul_mat(ctx, layer.ffn_receptance, xr) |                 ggml_mul_mat(ctx, layer.ffn_receptance, xr) | ||||||
|             ); |             ); | ||||||
|             // k = torch.square(torch.relu(kw @ xk))
 |             // k = torch.square(torch.relu(kw @ xk))
 | ||||||
|             // TODO Does not work; shape mismatch
 |             struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu( | ||||||
|             //struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(
 |                 ctx, | ||||||
|             //    ctx,
 |                 ggml_mul_mat(ctx, layer.ffn_key, xk) | ||||||
|             //    ggml_mul_mat(ctx, layer.ffn_key, xk)
 |             )); | ||||||
|             //));
 |  | ||||||
|             // r * (vw @ k)
 |             // r * (vw @ k)
 | ||||||
|             // TODO Does not work; shape mismatch
 |             x = ggml_add( | ||||||
|             // x0 = ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k));
 |                 ctx, | ||||||
|             // x = x + self.channel_mixing(...)
 |                 x, | ||||||
|             x = ggml_add(ctx, x, x0); |                 ggml_mul( | ||||||
|  |                     ctx, | ||||||
|  |                     r, | ||||||
|  |                     ggml_mul_mat(ctx, layer.ffn_value, k) | ||||||
|  |                 ) | ||||||
|  |             ); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | @ -495,11 +597,11 @@ int main(int argc, char ** argv) { | ||||||
|     x = ggml_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias); |     x = ggml_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias); | ||||||
| 
 | 
 | ||||||
|     // x = (self.w.head.weight @ x).float()
 |     // x = (self.w.head.weight @ x).float()
 | ||||||
|     // TODO Is transpose copying the tensor?
 |     struct ggml_tensor * logits = ggml_mul_mat(ctx, model.head, x); | ||||||
|     struct ggml_tensor * logits = ggml_mul_mat(ctx, x, ggml_transpose(ctx, model.head)); |  | ||||||
| 
 | 
 | ||||||
|     compute_graph(ctx, logits); |     compute_graph(ctx, logits); | ||||||
| 
 | 
 | ||||||
|  |     // TODO -nan(ind) -nan(ind) ... (maybe implement exp/max first?)
 | ||||||
|     PRINT_TENSOR(logits); |     PRINT_TENSOR(logits); | ||||||
| 
 | 
 | ||||||
|     ggml_free(ctx); |     ggml_free(ctx); | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue