#include #include #include #include int main() { struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); if (!ctx) { enum rwkv_error_flags error = rwkv_get_last_error(NULL); fprintf(stderr, "Unexpected error 0x%.8X\n", error); return EXIT_FAILURE; } float * state = calloc(rwkv_get_state_buffer_element_count(ctx), sizeof(float)); float * logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float)); if (!state || !logits) { fprintf(stderr, "Failed to allocate state/logits\n"); return EXIT_FAILURE; } // 0xd1 or 209 is space (0x20 or \u0120 in tokenizer) const unsigned char * prompt = "hello\xd1world"; rwkv_eval(ctx, prompt[0], NULL, state, logits); for (const unsigned char * token = prompt + 1; *token != 0; token++) { rwkv_eval(ctx, *token, state, state, logits); } float * expected_logits = logits; logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float)); if (!logits) { fprintf(stderr, "Failed to allocate state/logits\n"); return EXIT_FAILURE; } struct rwkv_context * ctx2 = rwkv_clone_context(ctx, 2); rwkv_eval(ctx, prompt[0], NULL, state, logits); for (const unsigned char * token = prompt + 1; *token != 0; token++) { rwkv_eval(ctx, *token, state, state, logits); } if (memcmp(expected_logits, logits, rwkv_get_logits_buffer_element_count(ctx) * sizeof(float))) { fprintf(stderr, "results not identical :(\n"); return EXIT_FAILURE; } else { fprintf(stdout, "Results identical, success!\n"); } rwkv_free(ctx); rwkv_free(ctx2); free(expected_logits); free(logits); free(state); return EXIT_SUCCESS; }