From 873cb954d0a0010c184d4ba8dc3e48ff6ecd688b Mon Sep 17 00:00:00 2001
From: saharNooby
Date: Thu, 30 Mar 2023 20:01:26 +0400
Subject: [PATCH] Make ln0 work correctly

---
 examples/main_rwkv/main_rwkv.cpp     | 32 ++++++++++------------------
 rwkv/convert_pytorch_rwkv_to_ggml.py |  5 +++--
 2 files changed, 14 insertions(+), 23 deletions(-)

diff --git a/examples/main_rwkv/main_rwkv.cpp b/examples/main_rwkv/main_rwkv.cpp
index 797420a..b428cc0 100644
--- a/examples/main_rwkv/main_rwkv.cpp
+++ b/examples/main_rwkv/main_rwkv.cpp
@@ -116,6 +116,9 @@ void print_tensor(struct ggml_tensor * tensor, char * name) {
 // Prints tensor name, dimensionality, shape and part of its contents.
 #define PRINT_TENSOR(x) print_tensor(x, #x)
 
+// Same as above, but additionally computes tensor graph before printing the tensor.
+#define COMPUTE_AND_PRINT_TENSOR(ctx, x) do { compute_graph(ctx, x); print_tensor(x, #x); } while (0)
+
 // Computes value of the tensor and all tensors it depends on.
 void compute_graph(struct ggml_context * ctx, struct ggml_tensor * tensor) {
     struct ggml_cgraph graph = ggml_build_forward(tensor);
@@ -238,7 +241,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
             break;
         }
 
-        RWKV_ASSERT(dim_count == 1 || dim_count == 2 || dim_count == 3, "Unsupported dimension count %d", dim_count);
+        RWKV_ASSERT(dim_count == 1 || dim_count == 2, "Unsupported dimension count %d", dim_count);
 
         int32_t key_length;
         read_int32(file, &key_length);
@@ -267,13 +270,6 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
             element_count = x * y;
             // Not a typo, dimensions should be reversed here
             tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y, x);
-        } else if (dim_count == 3) {
-            read_int32(file, &x);
-            read_int32(file, &y);
-            read_int32(file, &z);
-            element_count = x * y * z;
-            // Not a typo, dimensions should be reversed here
-            tensor = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, z, y, x);
         } else {
             abort();
         }
@@ -432,23 +428,17 @@ int main(int argc, char ** argv) {
     ggml_set_f32(ones, 1.0F);
 
     // x = self.w.emb.weight[token]
-    struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
-    ggml_set_i32_1d(token_index, 0, token);
-    // TODO Is transpose copying the tensor?
-    struct ggml_tensor * x = ggml_get_rows(ctx, ggml_transpose(ctx, model.emb), token_index);
-
-    compute_graph(ctx, x);
-    // For token 123, should be [-0.1836, 0.4434, 0.3848 ... -0.4102, -0.3164, -0.1826]
-    // TODO NOT CORRECT
-    PRINT_TENSOR(x);
+    // TODO Replace with ggml_get_rows or similar
+    struct ggml_tensor * one_hot = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_vocab, 1);
+    ggml_set_f32_1d(one_hot, token, 1.0F);
+    struct ggml_tensor * x = ggml_mul_mat(ctx, model.emb, one_hot);
 
     // x = self.layer_norm(x, self.w.blocks[0].ln0)
     x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);
 
-    compute_graph(ctx, x);
-    // For token 123, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765]
-    // TODO NOT CORRECT
-    PRINT_TENSOR(x);
+    // For token 123 after ln0, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765]
+    // Prints [[-0.419416 1.169782 0.779827 ... -1.183806 -0.871573 -0.276483]]
+    COMPUTE_AND_PRINT_TENSOR(ctx, x);
 
     for (int i = 0; i < n_layer; i++) {
         auto layer = model.layers[i];
diff --git a/rwkv/convert_pytorch_rwkv_to_ggml.py b/rwkv/convert_pytorch_rwkv_to_ggml.py
index 68c740f..55f70b9 100644
--- a/rwkv/convert_pytorch_rwkv_to_ggml.py
+++ b/rwkv/convert_pytorch_rwkv_to_ggml.py
@@ -89,8 +89,9 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
         if data_type == 'float16' and len(tensor.shape) > 1:
             tensor = tensor.half()
 
-        # ggml stores tensor values in other way than PyTorch, need to flip dimension order here
-        tensor = torch.permute(tensor, dims=[i for i in reversed(range(len(tensor.shape)))]).contiguous()
+        if k == 'emb.weight':
+            # Allows embedding matrix to be multiplied by one-hot vector
+            tensor = torch.permute(tensor, dims=[i for i in reversed(range(len(tensor.shape)))]).contiguous()
 
         shape = tensor.shape
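
Note on the one-hot trick above: the converter now stores emb.weight transposed, so multiplying it by a one-hot column vector selects exactly the embedding of `token`, which is what ggml_mul_mat(ctx, model.emb, one_hot) computes. Below is a minimal PyTorch sketch of this equivalence; the toy sizes and the names n_vocab/n_embed are assumptions for illustration, not values from the patch.

    # Sketch of the embedding-lookup-as-matmul trick; toy sizes are assumed.
    import torch

    n_vocab, n_embed, token = 10, 4, 3

    # Converter side: emb.weight is [n_vocab, n_embed] in PyTorch and is
    # permuted to [n_embed, n_vocab] before writing, mirroring the
    # `if k == 'emb.weight'` branch above.
    emb = torch.randn(n_vocab, n_embed)
    emb_permuted = torch.permute(emb, dims=[1, 0]).contiguous()

    # Inference side: a one-hot vector for `token`, as built with
    # ggml_set_f32_1d in main_rwkv.cpp.
    one_hot = torch.zeros(n_vocab)
    one_hot[token] = 1.0

    # The matrix-vector product selects column `token` of the permuted
    # matrix, i.e. the original embedding row.
    x = emb_permuted @ one_hot

    assert torch.allclose(x, emb[token])  # same as self.w.emb.weight[token]

Per the patch, only emb.weight gets this treatment in the converter; the blanket permute of every tensor is removed, along with the 3-dimensional tensor path in the loader that it required.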
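The ln0 step itself, ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias), is this repo's helper corresponding to self.layer_norm(x, self.w.blocks[0].ln0) in the Python reference. A hedged sketch of what that step computes, assuming the helper matches standard layer norm over the embedding dimension (normalize to zero mean and unit variance, then apply the learned scale and shift); eps and exact reduction details of the ggml side may differ:

    # Assumed semantics of the ln0 step, not a transcription of the helper.
    import torch
    import torch.nn.functional as F

    n_embed = 4  # toy size, assumed for illustration
    x = torch.randn(n_embed)
    ln0_weight = torch.randn(n_embed)
    ln0_bias = torch.randn(n_embed)

    out = F.layer_norm(x, (n_embed,), weight=ln0_weight, bias=ln0_bias)

    # The same computation written out by hand.
    manual = (x - x.mean()) / torch.sqrt(x.var(unbiased=False) + 1e-5)
    manual = manual * ln0_weight + ln0_bias

    assert torch.allclose(out, manual)

With both changes in place, the printed output for token 123 matches the expected post-ln0 values noted in the comment, which is what the commit title refers to.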