Make ln0 work correctly

saharNooby 2023-03-30 20:01:26 +04:00
parent 2f51451561
commit 873cb954d0
2 changed files with 14 additions and 23 deletions


@@ -116,6 +116,9 @@ void print_tensor(struct ggml_tensor * tensor, char * name) {
// Prints tensor name, dimensionality, shape and part of its contents.
#define PRINT_TENSOR(x) print_tensor(x, #x)
+// Same as above, but additionally computes tensor graph before printing the tensor.
+#define COMPUTE_AND_PRINT_TENSOR(ctx, x) do { compute_graph(ctx, x); print_tensor(x, #x); } while (0)
// Computes value of the tensor and all tensors it depends on.
void compute_graph(struct ggml_context * ctx, struct ggml_tensor * tensor) {
struct ggml_cgraph graph = ggml_build_forward(tensor);
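The body of compute_graph is cut off by the hunk above. As a rough sketch only (the remainder is a guess based on the ggml API of that period, not the file's actual code), it presumably finishes by evaluating the freshly built forward graph:

    // Sketch, not the verbatim function body: build the forward graph that ends in
    // `tensor`, then evaluate it so tensor->data holds up-to-date values for printing.
    void compute_graph(struct ggml_context * ctx, struct ggml_tensor * tensor) {
        struct ggml_cgraph graph = ggml_build_forward(tensor);
        graph.n_threads = 1;                // single-threaded is enough for debug prints
        ggml_graph_compute(ctx, &graph);
    }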
@@ -238,7 +241,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
break;
}
-RWKV_ASSERT(dim_count == 1 || dim_count == 2 || dim_count == 3, "Unsupported dimension count %d", dim_count);
+RWKV_ASSERT(dim_count == 1 || dim_count == 2, "Unsupported dimension count %d", dim_count);
int32_t key_length;
read_int32(file, &key_length);
@@ -267,13 +270,6 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
element_count = x * y;
// Not a typo, dimensions should be reversed here
tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y, x);
-} else if (dim_count == 3) {
-read_int32(file, &x);
-read_int32(file, &y);
-read_int32(file, &z);
-element_count = x * y * z;
-// Not a typo, dimensions should be reversed here
-tensor = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, z, y, x);
} else {
abort();
}
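The "dimensions should be reversed" comments come from how ggml orders dimensions: ne[0] is the innermost, contiguous dimension, whereas the converter writes the PyTorch shape outermost-first. A minimal illustration with made-up sizes (ctx is assumed to be an already-initialized ggml_context; none of this is code from the file):

    #include "ggml.h"
    // A row-major PyTorch float32 tensor of shape (3, 5) is 3 rows of 5 contiguous values.
    // The matching ggml tensor swaps the two numbers but keeps the same flat layout,
    // so the raw values can be read from the file straight into the tensor's data.
    void shape_example(struct ggml_context * ctx) {
        int32_t x = 3, y = 5; // shape as stored in the converted file
        struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y, x);
        // t->ne[0] == 5 (columns, contiguous), t->ne[1] == 3 (rows)
    }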
@@ -432,23 +428,17 @@ int main(int argc, char ** argv) {
ggml_set_f32(ones, 1.0F);
// x = self.w.emb.weight[token]
-struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
-ggml_set_i32_1d(token_index, 0, token);
-// TODO Is transpose copying the tensor?
-struct ggml_tensor * x = ggml_get_rows(ctx, ggml_transpose(ctx, model.emb), token_index);
-compute_graph(ctx, x);
-// For token 123, should be [-0.1836, 0.4434, 0.3848 ... -0.4102, -0.3164, -0.1826]
-// TODO NOT CORRECT
-PRINT_TENSOR(x);
+// TODO Replace with ggml_get_rows or similar
+struct ggml_tensor * one_hot = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_vocab, 1);
+ggml_set_f32_1d(one_hot, token, 1.0F);
+struct ggml_tensor * x = ggml_mul_mat(ctx, model.emb, one_hot);
// x = self.layer_norm(x, self.w.blocks[0].ln0)
x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);
-compute_graph(ctx, x);
-// For token 123, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765]
-// TODO NOT CORRECT
-PRINT_TENSOR(x);
+// For token 123 after ln0, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765]
+// Prints [[-0.419416 1.169782 0.779827 ... -1.183806 -0.871573 -0.276483]]
+COMPUTE_AND_PRINT_TENSOR(ctx, x);
for (int i = 0; i < n_layer; i++) {
auto layer = model.layers[i];
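Two notes on the new embedding/ln0 code above. First, because emb.weight is exported transposed (see the converter change below), multiplying it by a one-hot column of length n_vocab picks out exactly one embedding row, i.e. it is the ggml counterpart of self.w.emb.weight[token]. Second, ggml_layer_norm is a helper defined elsewhere in this file, not a stock ggml op; its definition is not part of this diff, but given PyTorch's layer_norm semantics it plausibly composes the standard ggml ops like this (a sketch only, with eps handling left to ggml_norm's default):

    // Hypothetical sketch of a layer-norm helper built from ggml ops;
    // the repo's actual ggml_layer_norm may differ in details.
    static struct ggml_tensor * layer_norm_sketch(struct ggml_context * ctx, struct ggml_tensor * x,
            struct ggml_tensor * weight, struct ggml_tensor * bias) {
        x = ggml_norm(ctx, x);        // (x - mean(x)) / sqrt(var(x) + eps), per row
        x = ggml_mul(ctx, x, weight); // elementwise scale by ln0 weight
        x = ggml_add(ctx, x, bias);   // elementwise shift by ln0 bias
        return x;
    }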


@@ -89,7 +89,8 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
if data_type == 'float16' and len(tensor.shape) > 1:
tensor = tensor.half()
-# ggml stores tensor values in other way than PyTorch, need to flip dimension order here
+if k == 'emb.weight':
+# Allows embedding matrix to be multiplied by one-hot vector
+tensor = torch.permute(tensor, dims=[i for i in reversed(range(len(tensor.shape)))]).contiguous()
shape = tensor.shape
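Tracing the shapes shows why this permute is what makes the one-hot multiplication in main() work. Using hypothetical RWKV-169M sizes (n_vocab = 50277, n_embed = 768) purely for illustration:

    // emb.weight in PyTorch:               shape (n_vocab, n_embed) = (50277, 768)
    // after torch.permute + .contiguous(): shape (768, 50277), written to the file as x = 768, y = 50277
    // after the reversed-dimension load:   ggml tensor with ne[0] = 50277, ne[1] = 768
    // so, with one_hot of ne = {50277, 1}:
    //     struct ggml_tensor * x = ggml_mul_mat(ctx, model.emb, one_hot); // ne = {768, 1}
    // which is the 768-dimensional embedding of the one token set to 1.0 in one_hot.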