From 6fe9486cee59a0cee20511c022a85fb091db507f Mon Sep 17 00:00:00 2001
From: saharNooby
Date: Sat, 1 Apr 2023 10:06:39 +0400
Subject: [PATCH] Finally, FP32 inference

---
 examples/main_rwkv/main_rwkv.cpp              | 51 ++++++++++++-------
 ...mpare_cpp_with_reference_implementation.py | 12 +++--
 2 files changed, 42 insertions(+), 21 deletions(-)

diff --git a/examples/main_rwkv/main_rwkv.cpp b/examples/main_rwkv/main_rwkv.cpp
index be88c65..6dcc3fd 100644
--- a/examples/main_rwkv/main_rwkv.cpp
+++ b/examples/main_rwkv/main_rwkv.cpp
@@ -107,8 +107,7 @@ void print_tensor(struct ggml_tensor * tensor, char * name) {
 
 void compute_graph(struct ggml_context * ctx, struct ggml_tensor * tensor) {
     struct ggml_cgraph graph = ggml_build_forward(tensor);
-    // TODO Move to script arguments
-    graph.n_threads = std::max(1, (int32_t) std::thread::hardware_concurrency() / 2);
+    graph.n_threads = 1;
 
     ggml_graph_compute(ctx, &graph);
 }
@@ -252,7 +251,10 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
             read_int32(file, &x);
             read_int32(file, &y);
             element_count = x * y;
-            // Not a typo, dimensions should be reversed here
+            // Dimension order is reversed here:
+            // * PyTorch shape is (x rows, y columns)
+            // * ggml shape is (y elements in a row, x elements in a column)
+            // Both shapes represent the same tensor.
             tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y, x);
         } else {
             abort();
         }
@@ -376,7 +378,6 @@ int main(int argc, char ** argv) {
         RWKV_LOG("Creating new state");
         ggml_set_f32(state, 0.0F);
 
-        // TODO Verify correctness
         for (int i = 0; i < n_layer; i++) {
             // state[5 * i + 4] = -1e30
             int32_t offset_in_bytes = (5 * i + 4) * n_embed * 4;
@@ -407,6 +408,9 @@ int main(int argc, char ** argv) {
     // x = self.layer_norm(x, self.w.blocks[0].ln0)
     x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);
 
+    // We collect parts of new state here. Each part is (n_embed) vector.
+    struct ggml_tensor ** state_parts = new ggml_tensor * [5 * n_layer];
+
     for (int i = 0; i < n_layer; i++) {
         auto layer = model.layers[i];
 
@@ -435,7 +439,7 @@ int main(int argc, char ** argv) {
             ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_r))
         );
         // state[5 * i + 1] = x
-        ggml_cpy(ctx, x0, x_prev);
+        state_parts[5 * i + 1] = x0;
 
         // r = torch.sigmoid(rw @ xr)
         struct ggml_tensor * r = ggml_sigmoid(
@@ -485,22 +489,19 @@ int main(int argc, char ** argv) {
         // e2 = torch.exp(k - qq)
         e2 = ggml_exp(ctx, ggml_sub(ctx, k, qq));
         // state[5 * i + 2] = e1 * aa + e2 * v
-        // TODO Must save result
-        ggml_cpy(ctx, ggml_add(
+        state_parts[5 * i + 2] = ggml_add(
             ctx,
             ggml_mul(ctx, e1, aa),
             ggml_mul(ctx, e2, v)
-        ), aa);
+        );
         // state[5 * i + 3] = e1 * bb + e2
-        // TODO Must save result
-        ggml_cpy(ctx, ggml_add(
+        state_parts[5 * i + 3] = ggml_add(
             ctx,
             ggml_mul(ctx, e1, bb),
             e2
-        ), bb);
+        );
         // state[5 * i + 4] = qq
-        // TODO Must save result
-        ggml_cpy(ctx, qq, pp);
+        state_parts[5 * i + 4] = qq;
         // ow @ (r * wkv)
         x = ggml_add(
             ctx,
@@ -532,8 +533,7 @@ int main(int argc, char ** argv) {
            ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_r))
         );
         // state[5 * i + 0] = x
-        // TODO Must save result
-        ggml_cpy(ctx, x0, x_prev);
+        state_parts[5 * i + 0] = x0;
 
         // r = torch.sigmoid(rw @ xr)
         struct ggml_tensor * r = ggml_sigmoid(
@@ -564,9 +564,26 @@ int main(int argc, char ** argv) {
     // x = (self.w.head.weight @ x).float()
     struct ggml_tensor * logits = ggml_mul_mat(ctx, model.head, x);
 
-    compute_graph(ctx, logits);
+    struct ggml_cgraph graph = ggml_build_forward(logits);
 
-    PRINT_TENSOR(logits);
+    for (int i = 0; i < n_layer * 5; i++) {
+        ggml_build_forward_expand(&graph, state_parts[i]);
+    }
+
+    // TODO Move to script arguments
+    graph.n_threads = std::max(1, (int32_t) std::thread::hardware_concurrency() / 2);
+
+    ggml_graph_compute(ctx, &graph);
+
+    // Update state
+    for (int i = 0; i < n_layer * 5; i++) {
+        struct ggml_tensor * state_part_src = state_parts[i];
+        struct ggml_tensor * state_part_dest = ggml_view_1d(ctx, state, n_embed, i * n_embed * 4);
+
+        for (int j = 0; j < n_embed; j++) {
+            ggml_set_f32_1d(state_part_dest, j, ggml_get_f32_1d(state_part_src, j));
+        }
+    }
 
     {
         RWKV_LOG("Saving state to %s", state_out_path);
diff --git a/rwkv/compare_cpp_with_reference_implementation.py b/rwkv/compare_cpp_with_reference_implementation.py
index 7edc07a..8c7f080 100644
--- a/rwkv/compare_cpp_with_reference_implementation.py
+++ b/rwkv/compare_cpp_with_reference_implementation.py
@@ -18,8 +18,9 @@ def parse_args():
 def main() -> None:
     args = parse_args()
 
+    token_count: int = 64
     # It's not important what exactly these tokens are; just that output of both model matches.
-    tokens: List[int] = [(i + 1) for i in range(32)]
+    tokens: List[int] = [(i + 1) for i in range(token_count)]
 
     state_path: str = './state.bin'
     logits_path: str = './logits.bin'
@@ -27,9 +28,11 @@ def main() -> None:
 
     ref_logits, ref_state = None, None
 
-    for token in tokens:
+    for i in range(token_count):
+        token: int = tokens[i]
+
         print()
-        print(f'--- Token {token} ---')
+        print(f'--- {i + 1}/{token_count} ---')
 
         subprocess.run(
             [
@@ -55,8 +58,9 @@ def main() -> None:
         print(f'Actual logits: {actual_logits}')
         print('Difference per token: %.8f' % (difference,))
 
-        assert abs(difference) <= 0.000001, 'Difference is too big'
+        assert abs(difference) <= 0.00005, 'Difference is too big'
 
+    print()
     print('Test passes')
 
 if __name__ == "__main__":
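
A note on the dimension-order comment added in load_rwkv_model above: the mapping can be summarized in a small standalone sketch. This is illustrative only and not part of the patch; the helper name load_2d_weight_f32 is made up here, and the only ggml call used, ggml_new_tensor_2d, is the one already present in the diff.

    #include "ggml.h"

    // A PyTorch weight of shape (x rows, y columns) is created on the ggml
    // side as ggml_new_tensor_2d(ctx, type, y, x): ne[0] counts elements in
    // a row, ne[1] counts elements in a column. Both shapes describe the
    // same underlying tensor.
    static struct ggml_tensor * load_2d_weight_f32(struct ggml_context * ctx, int32_t x, int32_t y) {
        // Not a typo: dimensions are intentionally reversed relative to PyTorch.
        return ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y, x);
    }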
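
Similarly, the state-update flow introduced at the end of main_rwkv.cpp can be read as one self-contained function. This is a sketch using only ggml calls that already appear in the diff; the function name compute_logits_and_update_state and the n_threads parameter are hypothetical, and state is the flat FP32 tensor of 5 * n_layer * n_embed elements used throughout main_rwkv.cpp.

    #include "ggml.h"

    // Build the forward graph from the logits, force evaluation of every
    // per-layer state part, run the graph once, then copy the computed parts
    // back into the flat state buffer. Offsets are in bytes, hence "* 4"
    // for FP32 elements.
    static void compute_logits_and_update_state(
            struct ggml_context * ctx,
            struct ggml_tensor * logits,
            struct ggml_tensor * state,
            struct ggml_tensor ** state_parts,
            int n_layer,
            int n_embed,
            int n_threads) {
        struct ggml_cgraph graph = ggml_build_forward(logits);

        // Not every state part is an ancestor of the logits node, so each
        // part is added to the graph explicitly to make sure it is computed.
        for (int i = 0; i < n_layer * 5; i++) {
            ggml_build_forward_expand(&graph, state_parts[i]);
        }

        graph.n_threads = n_threads;
        ggml_graph_compute(ctx, &graph);

        // Write each (n_embed)-element part into its slot of the state tensor.
        for (int i = 0; i < n_layer * 5; i++) {
            struct ggml_tensor * src = state_parts[i];
            struct ggml_tensor * dest = ggml_view_1d(ctx, state, n_embed, i * n_embed * 4);

            for (int j = 0; j < n_embed; j++) {
                ggml_set_f32_1d(dest, j, ggml_get_f32_1d(src, j));
            }
        }
    }

In the patch itself this logic stays inline in main(), right after the head matmul, with the thread count still hard-coded behind the "TODO Move to script arguments" comment.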