64 lines
1.6 KiB
C
64 lines
1.6 KiB
C
#include <rwkv.h>
|
|
|
|
#include <stdlib.h>
|
|
#include <stdio.h>
|
|
#include <string.h>
|
|
|
|
int main() {
|
|
struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2);
|
|
|
|
if (!ctx) {
|
|
enum rwkv_error_flags error = rwkv_get_last_error(NULL);
|
|
fprintf(stderr, "Unexpected error 0x%.8X\n", error);
|
|
return EXIT_FAILURE;
|
|
}
|
|
|
|
float * state = calloc(rwkv_get_state_buffer_element_count(ctx), sizeof(float));
|
|
float * logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));
|
|
|
|
if (!state || !logits) {
|
|
fprintf(stderr, "Failed to allocate state/logits\n");
|
|
return EXIT_FAILURE;
|
|
}
|
|
|
|
// 0xd1 or 209 is space (0x20 or \u0120 in tokenizer)
|
|
const unsigned char * prompt = "hello\xd1world";
|
|
|
|
rwkv_eval(ctx, prompt[0], NULL, state, logits);
|
|
|
|
for (const unsigned char * token = prompt + 1; *token != 0; token++) {
|
|
rwkv_eval(ctx, *token, state, state, logits);
|
|
}
|
|
|
|
float * expected_logits = logits;
|
|
logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));
|
|
|
|
if (!logits) {
|
|
fprintf(stderr, "Failed to allocate state/logits\n");
|
|
return EXIT_FAILURE;
|
|
}
|
|
|
|
struct rwkv_context * ctx2 = rwkv_clone_context(ctx, 2);
|
|
|
|
rwkv_eval(ctx, prompt[0], NULL, state, logits);
|
|
|
|
for (const unsigned char * token = prompt + 1; *token != 0; token++) {
|
|
rwkv_eval(ctx, *token, state, state, logits);
|
|
}
|
|
|
|
if (memcmp(expected_logits, logits, rwkv_get_logits_buffer_element_count(ctx) * sizeof(float))) {
|
|
fprintf(stderr, "results not identical :(\n");
|
|
return EXIT_FAILURE;
|
|
} else {
|
|
fprintf(stdout, "Results identical, success!\n");
|
|
}
|
|
|
|
rwkv_free(ctx);
|
|
rwkv_free(ctx2);
|
|
|
|
free(expected_logits);
|
|
free(logits);
|
|
free(state);
|
|
|
|
return EXIT_SUCCESS;
|
|
} |