From fe98c94a635416a7f97aa6b9eabb56215ca1b281 Mon Sep 17 00:00:00 2001
From: saharNooby
Date: Sat, 1 Apr 2023 11:28:32 +0400
Subject: [PATCH] [FILE FORMAT CHANGED] Use ggml_get_rows to get embedding

---
 examples/main_rwkv/main_rwkv.cpp              | 42 ++++++++++---------
 ...mpare_cpp_with_reference_implementation.py |  2 +-
 rwkv/convert_pytorch_rwkv_to_ggml.py          | 11 ++---
 3 files changed, 27 insertions(+), 28 deletions(-)

diff --git a/examples/main_rwkv/main_rwkv.cpp b/examples/main_rwkv/main_rwkv.cpp
index 6dcc3fd..d72cd92 100644
--- a/examples/main_rwkv/main_rwkv.cpp
+++ b/examples/main_rwkv/main_rwkv.cpp
@@ -14,6 +14,8 @@
 
 // --- Utilities ---
 
+#define F32_SIZE 4
+
 // Checks that x is not false. If it is false, prints fancy message to stderr and aborts the execution.
 #define RWKV_ASSERT(x, ...) \
     do { \
@@ -263,8 +265,9 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
         std::string key(key_length, 0);
         RWKV_ASSERT(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key");
 
-        // TODO Use ggml_type_size
-        size_t element_size = data_type == 0 ? 4 : 2;
+        size_t element_size = data_type == 0 ?
+            ggml_type_size(GGML_TYPE_F32) :
+            ggml_type_size(GGML_TYPE_F16);
 
         size_t byte_count = element_count * element_size;
         RWKV_ASSERT(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data");
@@ -319,8 +322,8 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
     // Verify order of dimensions
     struct ggml_tensor * emb = model->emb;
     RWKV_ASSERT(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims);
-    RWKV_ASSERT(emb->ne[0] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
-    RWKV_ASSERT(emb->ne[1] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
+    RWKV_ASSERT(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
+    RWKV_ASSERT(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
 }
 
 // --- Operators ---
@@ -380,13 +383,13 @@ int main(int argc, char ** argv) {
         for (int i = 0; i < n_layer; i++) {
             // state[5 * i + 4] = -1e30
-            int32_t offset_in_bytes = (5 * i + 4) * n_embed * 4;
+            int32_t offset_in_bytes = (5 * i + 4) * n_embed * F32_SIZE;
             struct ggml_tensor * state_part = ggml_view_1d(ctx, state, n_embed, offset_in_bytes);
             ggml_set_f32(state_part, -1e30F);
         }
     } else {
         RWKV_LOG("Loading state from %s", state_in_path);
 
-        int32_t state_file_size = state_element_count * 4;
+        int32_t state_file_size = state_element_count * F32_SIZE;
 
         FILE * state_in_file = fopen(state_in_path, "rb");
         RWKV_ASSERT(state_in_file != NULL, "Failed to open file %s", state_in_path);
@@ -400,10 +403,9 @@ int main(int argc, char ** argv) {
     // --- Evaluate model ---
 
     // x = self.w.emb.weight[token]
-    // TODO Replace with ggml_get_rows or similar
-    struct ggml_tensor * one_hot = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_vocab, 1);
-    ggml_set_f32_1d(one_hot, token, 1.0F);
-    struct ggml_tensor * x = ggml_mul_mat(ctx, model.emb, one_hot);
+    struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+    ggml_set_i32_1d(token_index, 0, token);
+    struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, token_index);
 
     // x = self.layer_norm(x, self.w.blocks[0].ln0)
     x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);
@@ -419,7 +421,7 @@ int main(int argc, char ** argv) {
             // self.layer_norm(x, self.w.blocks[i].ln1)
             struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
             // state[5 * i + 1]
-            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * 4);
+            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * F32_SIZE);
             // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
             // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
             // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
@@ -454,9 +456,9 @@ int main(int argc, char ** argv) {
             // aa = state[5 * i + 2]
             // bb = state[5 * i + 3]
             // pp = state[5 * i + 4]
-            struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * 4);
-            struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * 4);
-            struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * 4);
+            struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * F32_SIZE);
+            struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * F32_SIZE);
+            struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * F32_SIZE);
 
             // ww = time_first + k
             struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k);
@@ -519,7 +521,7 @@ int main(int argc, char ** argv) {
             // self.layer_norm(x, self.w.blocks[i].ln2)
             struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias);
             // state[5 * i + 0]
-            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * 4);
+            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * F32_SIZE);
             // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
             // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
             struct ggml_tensor * xk = ggml_add(
@@ -577,17 +579,17 @@ int main(int argc, char ** argv) {
     // Update state
     for (int i = 0; i < n_layer * 5; i++) {
-        struct ggml_tensor * state_part_src = state_parts[i];
-        struct ggml_tensor * state_part_dest = ggml_view_1d(ctx, state, n_embed, i * n_embed * 4);
+        struct ggml_tensor * src = state_parts[i];
+        struct ggml_tensor * dest = ggml_view_1d(ctx, state, n_embed, i * n_embed * F32_SIZE);
 
         for (int j = 0; j < n_embed; j++) {
-            ggml_set_f32_1d(state_part_dest, j, ggml_get_f32_1d(state_part_src, j));
+            ggml_set_f32_1d(dest, j, ggml_get_f32_1d(src, j));
         }
     }
 
     {
         RWKV_LOG("Saving state to %s", state_out_path);
 
-        int32_t state_file_size = state_element_count * 4;
+        int32_t state_file_size = state_element_count * F32_SIZE;
 
         FILE * state_out_file = fopen(state_out_path, "wb");
         RWKV_ASSERT(state_out_file != NULL, "Failed to open file %s", state_out_path);
@@ -599,7 +601,7 @@ int main(int argc, char ** argv) {
     {
         RWKV_LOG("Saving logits to %s", logits_out_path);
 
-        int32_t logits_file_size = n_vocab * 4;
+        int32_t logits_file_size = n_vocab * F32_SIZE;
 
         FILE * logits_out_file = fopen(logits_out_path, "wb");
         RWKV_ASSERT(logits_out_file != NULL, "Failed to open file %s", logits_out_path);
diff --git a/rwkv/compare_cpp_with_reference_implementation.py b/rwkv/compare_cpp_with_reference_implementation.py
index 0e08e67..d8252e2 100644
--- a/rwkv/compare_cpp_with_reference_implementation.py
+++ b/rwkv/compare_cpp_with_reference_implementation.py
@@ -72,7 +72,7 @@ def main() -> None:
         print(f'Actual logits: {actual_logits}')
         print('Difference per token: %.8f' % (difference,))
 
-        assert abs(difference) <= 0.00005, 'Difference is too big'
+        assert abs(difference) <= 0.000005, 'Difference is too big'
 
     # Check small token amount first to avoid waiting too long before seeing that model is broken
     compare_logits(tokens[:4])
diff --git a/rwkv/convert_pytorch_rwkv_to_ggml.py b/rwkv/convert_pytorch_rwkv_to_ggml.py
index 7c1a339..2ff8ea1 100644
--- a/rwkv/convert_pytorch_rwkv_to_ggml.py
+++ b/rwkv/convert_pytorch_rwkv_to_ggml.py
@@ -23,6 +23,7 @@
 #   int32 key_length;
 #   // 0 if float32, 1 if float16.
 #   int32 data_type;
+#   // Same values and order as in PyTorch's tensor.shape
 #   int32[dim_count] shape;
 #   // Keys are like "emb.weight", "block.0.ln1.weight".
 #   uint8[key_length] key_utf8;
@@ -89,10 +90,6 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
             if data_type == 'float16' and len(tensor.shape) > 1:
                 tensor = tensor.half()
 
-            if k == 'emb.weight':
-                # Allows embedding matrix to be multiplied by one-hot vector
-                tensor = torch.permute(tensor, dims=[i for i in reversed(range(len(tensor.shape)))]).contiguous()
-
             shape = tensor.shape
 
             print(f'Writing {k}, shape {shape}, type {tensor.dtype}')
@@ -153,10 +150,10 @@ def test() -> None:
         2,
         10,
         0,
-        2, 3,
+        3, 2,
         'emb.weight'.encode('utf-8'),
-        1.0, 3.0, 5.0,
-        2.0, 4.0, 6.0,
+        1.0, 2.0, 3.0,
+        4.0, 5.0, 6.0,
         # blocks.0.ln1.weight
         1,
         19,
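
The heart of this patch is replacing the one-hot matrix multiplication with ggml_get_rows for the embedding lookup; that is also why the converter no longer transposes emb.weight and why the loader now expects ne[0] == n_embed and ne[1] == n_vocab. The stand-alone sketch below is not part of the patch: it shows the new lookup in isolation on a toy embedding matrix, assuming the ggml API of this commit's vintage (ggml_build_forward / ggml_graph_compute taking the context) and made-up sizes.

    // Sketch only: a toy embedding matrix laid out the way the new loader
    // asserts expect it (ne[0] == n_embed, ne[1] == n_vocab), and a single
    // row lookup with ggml_get_rows, mirroring the patched main_rwkv.cpp.
    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = { 16 * 1024 * 1024, NULL };
        struct ggml_context * ctx = ggml_init(params);

        const int n_vocab = 10;
        const int n_embed = 4;
        const int token = 3;

        struct ggml_tensor * emb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, n_vocab);

        // Fill row r with r.0, r.1, r.2, r.3 so rows are easy to tell apart.
        for (int r = 0; r < n_vocab; r++) {
            for (int c = 0; c < n_embed; c++) {
                ggml_set_f32_1d(emb, r * n_embed + c, (float) r + 0.1F * c);
            }
        }

        // x = emb[token], as in the patched code.
        struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
        ggml_set_i32_1d(token_index, 0, token);
        struct ggml_tensor * x = ggml_get_rows(ctx, emb, token_index);

        struct ggml_cgraph graph = ggml_build_forward(x);
        graph.n_threads = 1;
        ggml_graph_compute(ctx, &graph);

        // Prints "3.00 3.10 3.20 3.30": the row selected for token 3.
        for (int c = 0; c < n_embed; c++) {
            printf("%.2f ", ggml_get_f32_1d(x, c));
        }
        printf("\n");

        ggml_free(ctx);
        return 0;
    }

Because ggml_get_rows indexes rows of the matrix directly, the embedding no longer has to be stored transposed at conversion time, which is what makes this a file format change.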
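
The F32_SIZE constant introduced at the top of main_rwkv.cpp makes explicit that ggml_view_1d takes its offset in bytes, so every offset into the flat state tensor is multiplied by the float size. Another small sketch, again not part of the patch and using made-up sizes, shows how one per-layer state part is addressed and initialized the same way the patched code does.

    // Sketch only: the RWKV state as one flat F32 tensor of n_layer * 5 parts,
    // each n_embed floats long, addressed through byte-offset views.
    #include <stdio.h>
    #include "ggml.h"

    #define F32_SIZE 4 // same constant the patch introduces; equals sizeof(float)

    int main(void) {
        struct ggml_init_params params = { 16 * 1024 * 1024, NULL };
        struct ggml_context * ctx = ggml_init(params);

        const int n_layer = 2;
        const int n_embed = 8;

        struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed);

        // state[5 * i + 4] is the "pp" part of layer i; ggml_view_1d expects
        // the offset in bytes, hence the multiplication by F32_SIZE.
        const int i = 1;
        struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * F32_SIZE);
        ggml_set_f32(pp, -1e30F); // initialized exactly as the patch does

        // Prints -1e+30: the write through the view landed at the right spot.
        printf("%g\n", ggml_get_f32_1d(state, (5 * i + 4) * n_embed));

        ggml_free(ctx);
        return 0;
    }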