Move graph building into its own function (#69)
A step towards #50 (loading models from memory), among other things.
This commit is contained in:
parent b61d94aef0
commit 3ca9c7f785

rwkv.cpp | 385 lines changed
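For context before the diff: the substance of the change is that graph construction now happens in a standalone rwkv_build_graph, operating on an already-populated rwkv_model, instead of inline in rwkv_init_from_file. A minimal sketch of the flow this enables, assuming a hypothetical load_model helper that fills a rwkv_model from any source (a file today, a memory buffer once #50 lands) — an illustration, not code from the commit:

    // Hypothetical flow enabled by the refactor; load_model is an assumed
    // helper, not part of rwkv.cpp. Error handling is elided for brevity.
    std::unique_ptr<struct rwkv_model> model = load_model(ctx, source);

    // Graph construction is now independent of where the weights came from,
    // and could be redone (e.g. with a different thread count) without reloading.
    struct rwkv_graph graph;
    if (!rwkv_build_graph(ctx, model.get(), n_threads, &graph)) {
        // rwkv_build_graph reports the failure and returns false.
    }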
@@ -244,22 +244,23 @@ struct ggml_tensor * rwkv_max(ggml_context * ctx, struct ggml_tensor * x, struct
 struct ggml_tensor * rwkv_layer_norm(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) {
     // LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias`
     // Looks like ggml_norm does the first part, we only need to apply weight & bias.
-    x = ggml_norm(ctx, x);
-    x = ggml_mul(ctx, x, weight);
-    x = ggml_add_inplace(ctx, x, bias);
-    return x;
+    return ggml_add_inplace(ctx, ggml_mul(ctx, ggml_norm(ctx, x), weight), bias);
 }

 // --- Implementation ---

+struct rwkv_graph {
+    struct ggml_tensor * state;
+    std::unique_ptr<struct ggml_tensor * []> state_parts;
+    struct ggml_tensor * token_index;
+    struct ggml_tensor * logits;
+    std::unique_ptr<struct ggml_cgraph> cgraph;
+};
+
 struct rwkv_context {
     std::unique_ptr<struct rwkv_model> model;
-    struct ggml_tensor * token_index;
-    struct ggml_tensor * state;
-    struct ggml_tensor ** state_parts;
-    struct ggml_tensor * logits;
     struct ggml_context * ctx;
-    std::unique_ptr<struct ggml_cgraph> graph;
+    struct rwkv_graph graph;
     enum rwkv_error_flags last_error;
     bool print_errors;
 };
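In standard notation, the LayerNorm this helper assembles (ggml_norm supplying the normalization, the two remaining ops the affine part) is:

$$\mathrm{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + 10^{-5}}} \odot w + b$$

where $\mu$ and $\sigma^2$ are the mean and variance of $x$, and $w$, $b$ are the per-channel weight and bias.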
@@ -280,6 +281,164 @@ enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) {
     return value;
 }

+bool rwkv_build_graph(struct ggml_context * ctx, struct rwkv_model * model, const uint32_t n_threads, struct rwkv_graph * out) {
+    std::unique_ptr<struct ggml_cgraph> cgraph(new(std::nothrow) struct ggml_cgraph());
+    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, cgraph.get(), "Failed to allocate graph");
+    cgraph->n_threads = n_threads;
+
+    size_t n_embed = model->n_embed, n_layer = model->n_layer;
+    struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed);
+
+    // We collect parts of the new state here. Each part is an (n_embed)-element vector.
+    std::unique_ptr<struct ggml_tensor * []> state_parts(new(std::nothrow) ggml_tensor * [n_layer * 5]);
+    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, state_parts.get(), "Failed to allocate state parts");
+
+    // x = self.w.emb.weight[token]
+    struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+    struct ggml_tensor * x = ggml_get_rows(ctx, model->emb, token_index);
+
+    // x = self.layer_norm(x, self.w.blocks[0].ln0)
+    x = rwkv_layer_norm(ctx, x, model->ln0_weight, model->ln0_bias);
+
+    for (size_t i = 0; i < n_layer; i++) {
+        struct rwkv_layer layer = model->layers[i];
+        size_t part_index = i * 5;
+        size_t state_part_size = n_embed * sizeof(float);
+
+        // RWKV/time mixing
+        {
+            // self.layer_norm(x, self.w.blocks[i].ln1)
+            struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
+
+            // x_prev = state[5 * i + 1]
+            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (part_index + 1) * state_part_size);
+            // aa = state[5 * i + 2]
+            struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (part_index + 2) * state_part_size);
+            // bb = state[5 * i + 3]
+            struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (part_index + 3) * state_part_size);
+            // pp = state[5 * i + 4]
+            struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (part_index + 4) * state_part_size);
+
+            // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
+            struct ggml_tensor * xk = ggml_add_inplace(ctx,
+                ggml_mul(ctx, x0, layer.att_time_mix_k),
+                ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k))
+            );
+
+            // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
+            struct ggml_tensor * xv = ggml_add_inplace(ctx,
+                ggml_mul(ctx, x0, layer.att_time_mix_v),
+                ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v))
+            );
+
+            // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
+            struct ggml_tensor * xr = ggml_add_inplace(ctx,
+                ggml_mul(ctx, x0, layer.att_time_mix_r),
+                ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r))
+            );
+
+            // r = torch.sigmoid(rw @ xr)
+            struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr));
+            // k = kw @ xk
+            struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk);
+            // v = vw @ xv
+            struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv);
+
+            // ww = time_first + k
+            struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k);
+            // qq = torch.maximum(pp, ww)
+            struct ggml_tensor * qq = rwkv_max(ctx, pp, ww);
+            // e1 = torch.exp(pp - qq)
+            struct ggml_tensor * e1 = rwkv_exp(ctx, ggml_sub(ctx, pp, qq));
+            // e2 = torch.exp(ww - qq)
+            struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq));
+
+            // a = e1 * aa + e2 * v
+            struct ggml_tensor * a = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v));
+            // b = e1 * bb + e2
+            struct ggml_tensor * b = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2);
+
+            // ww = pp + time_decay
+            ww = ggml_add_inplace(ctx, pp, layer.att_time_decay);
+            // qq = torch.maximum(ww, k)
+            qq = rwkv_max(ctx, ww, k);
+            // e1 = torch.exp(ww - qq)
+            e1 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq));
+            // e2 = torch.exp(k - qq)
+            e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq));
+
+            // state[5 * i + 1] = x0
+            // state[5 * i + 2] = e1 * aa + e2 * v
+            // state[5 * i + 3] = e1 * bb + e2
+            // state[5 * i + 4] = qq
+            state_parts[part_index + 1] = x0;
+            state_parts[part_index + 2] = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v));
+            state_parts[part_index + 3] = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2);
+            state_parts[part_index + 4] = qq;
+
+            // wkv = a / b
+            struct ggml_tensor * wkv = ggml_div(ctx, a, b);
+
+            // ow @ (r * wkv)
+            x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv)));
+        }
+
+        // FFN/channel mixing
+        {
+            // self.layer_norm(x, self.w.blocks[i].ln2)
+            struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias);
+
+            // x_prev = state[5 * i + 0]
+            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, part_index * state_part_size);
+
+            // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
+            struct ggml_tensor * xk = ggml_add_inplace(
+                ctx,
+                ggml_mul(ctx, x0, layer.ffn_time_mix_k),
+                ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k))
+            );
+
+            // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
+            struct ggml_tensor * xr = ggml_add_inplace(
+                ctx,
+                ggml_mul(ctx, x0, layer.ffn_time_mix_r),
+                ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r))
+            );
+
+            // state[5 * i + 0] = x0
+            state_parts[part_index] = x0;
+
+            // r = torch.sigmoid(rw @ xr)
+            struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr));
+
+            // k = torch.square(torch.relu(kw @ xk))
+            struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk)));
+
+            // r * (vw @ k)
+            x = ggml_add_inplace(ctx, x, ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)));
+        }
+    }
+
+    // x = self.layer_norm(x, self.w.ln_out)
+    x = rwkv_layer_norm(ctx, x, model->ln_out_weight, model->ln_out_bias);
+
+    // x = (self.w.head.weight @ x).float()
+    struct ggml_tensor * logits = ggml_mul_mat(ctx, model->head, x);
+
+    ggml_build_forward_expand(cgraph.get(), logits);
+
+    for (uint32_t i = 0; i < n_layer * 5; i++)
+        ggml_build_forward_expand(cgraph.get(), state_parts[i]);
+
+    out->state = state;
+    out->state_parts = std::move(state_parts);
+    out->token_index = token_index;
+    out->logits = logits;
+    out->cgraph = std::move(cgraph);
+    return true;
+}
+
 struct rwkv_file_guard {
     FILE * file;
     ~rwkv_file_guard() { if (file) fclose(file); }
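The qq/e1/e2 sequence in the time-mixing block above is the usual max-subtraction trick for evaluating a ratio of exponentials without overflow. Writing $u$ for time_first, $p$ for the running exponent pp, and $q = \max(p,\, u + k)$, the block computes

$$wkv = \frac{a}{b} = \frac{e^{p-q}\,aa + e^{u+k-q}\,v}{e^{p-q}\,bb + e^{u+k-q}}$$

The common scale factor $e^{-q}$ cancels in the ratio, so every exponent passed to rwkv_exp is $\le 0$ and cannot overflow. The same trick is then repeated with $q = \max(p + \mathrm{time\_decay},\, k)$ to produce the updated aa, bb, and pp state parts.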
@@ -418,192 +577,16 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]);
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]);

-    uint32_t n_embed = model->n_embed;
-    uint32_t n_layer = model->n_layer;
-
-    // Build graph
-    struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed);
-
-    // x = self.w.emb.weight[token]
-    struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
-    struct ggml_tensor * x = ggml_get_rows(ctx, model->emb, token_index);
-
-    // x = self.layer_norm(x, self.w.blocks[0].ln0)
-    x = rwkv_layer_norm(ctx, x, model->ln0_weight, model->ln0_bias);
-
-    // We collect parts of new state here. Each part is (n_embed) vector.
-    struct ggml_tensor ** state_parts = new ggml_tensor * [n_layer * 5];
-
-    for (uint32_t i = 0; i < n_layer; i++) {
-        auto layer = model->layers[i];
-
-        // RWKV/time mixing
-        {
-            // self.layer_norm(x, self.w.blocks[i].ln1)
-            struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
-            // state[5 * i + 1]
-            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * sizeof(float));
-            // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
-            // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
-            // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
-            struct ggml_tensor * xk = ggml_add_inplace(
-                ctx,
-                ggml_mul(ctx, x0, layer.att_time_mix_k),
-                ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k))
-            );
-            struct ggml_tensor * xv = ggml_add_inplace(
-                ctx,
-                ggml_mul(ctx, x0, layer.att_time_mix_v),
-                ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v))
-            );
-            struct ggml_tensor * xr = ggml_add_inplace(
-                ctx,
-                ggml_mul(ctx, x0, layer.att_time_mix_r),
-                ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r))
-            );
-            // state[5 * i + 1] = x
-            state_parts[5 * i + 1] = x0;
-
-            // r = torch.sigmoid(rw @ xr)
-            struct ggml_tensor * r = rwkv_sigmoid(
-                ctx,
-                ggml_mul_mat(ctx, layer.att_receptance, xr)
-            );
-            // k = kw @ xk
-            struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk);
-            // v = vw @ xv
-            struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv);
-
-            // aa = state[5 * i + 2]
-            // bb = state[5 * i + 3]
-            // pp = state[5 * i + 4]
-            struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * sizeof(float));
-            struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * sizeof(float));
-            struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * sizeof(float));
-
-            // ww = time_first + k
-            struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k);
-            // qq = torch.maximum(pp, ww)
-            struct ggml_tensor * qq = rwkv_max(ctx, pp, ww);
-            // e1 = torch.exp(pp - qq)
-            struct ggml_tensor * e1 = rwkv_exp(ctx, ggml_sub(ctx, pp, qq));
-            // e2 = torch.exp(ww - qq)
-            struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq));
-            // a = e1 * aa + e2 * v
-            struct ggml_tensor * a = ggml_add_inplace(
-                ctx,
-                ggml_mul(ctx, e1, aa),
-                ggml_mul(ctx, e2, v)
-            );
-            // b = e1 * bb + e2
-            struct ggml_tensor * b = ggml_add_inplace(
-                ctx,
-                ggml_mul(ctx, e1, bb),
-                e2
-            );
-            // wkv = a / b
-            struct ggml_tensor * wkv = ggml_div(ctx, a, b);
-            // ww = pp + time_decay
-            ww = ggml_add(ctx, pp, layer.att_time_decay);
-            // qq = torch.maximum(ww, k)
-            qq = rwkv_max(ctx, ww, k);
-            // e1 = torch.exp(ww - qq)
-            e1 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq));
-            // e2 = torch.exp(k - qq)
-            e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq));
-            // state[5 * i + 2] = e1 * aa + e2 * v
-            state_parts[5 * i + 2] = ggml_add_inplace(
-                ctx,
-                ggml_mul(ctx, e1, aa),
-                ggml_mul(ctx, e2, v)
-            );
-            // state[5 * i + 3] = e1 * bb + e2
-            state_parts[5 * i + 3] = ggml_add_inplace(
-                ctx,
-                ggml_mul(ctx, e1, bb),
-                e2
-            );
-            // state[5 * i + 4] = qq
-            state_parts[5 * i + 4] = qq;
-            // ow @ (r * wkv)
-            x = ggml_add_inplace(
-                ctx,
-                x,
-                ggml_mul_mat(
-                    ctx,
-                    layer.att_output,
-                    ggml_mul(ctx, r, wkv)
-                )
-            );
-        }
-
-        // FFN/channel mixing
-        {
-            // self.layer_norm(x, self.w.blocks[i].ln2)
-            struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias);
-            // state[5 * i + 0]
-            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * sizeof(float));
-            // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
-            // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
-            struct ggml_tensor * xk = ggml_add_inplace(
-                ctx,
-                ggml_mul(ctx, x0, layer.ffn_time_mix_k),
-                ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k))
-            );
-            struct ggml_tensor * xr = ggml_add_inplace(
-                ctx,
-                ggml_mul(ctx, x0, layer.ffn_time_mix_r),
-                ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r))
-            );
-            // state[5 * i + 0] = x
-            state_parts[5 * i + 0] = x0;
-
-            // r = torch.sigmoid(rw @ xr)
-            struct ggml_tensor * r = rwkv_sigmoid(
-                ctx,
-                ggml_mul_mat(ctx, layer.ffn_receptance, xr)
-            );
-            // k = torch.square(torch.relu(kw @ xk))
-            struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(
-                ctx,
-                ggml_mul_mat(ctx, layer.ffn_key, xk)
-            ));
-            // r * (vw @ k)
-            x = ggml_add_inplace(
-                ctx,
-                x,
-                ggml_mul(
-                    ctx,
-                    r,
-                    ggml_mul_mat(ctx, layer.ffn_value, k)
-                )
-            );
-        }
-    }
-
-    // x = self.layer_norm(x, self.w.ln_out)
-    x = rwkv_layer_norm(ctx, x, model->ln_out_weight, model->ln_out_bias);
-
-    // x = (self.w.head.weight @ x).float()
-    struct ggml_tensor * logits = ggml_mul_mat(ctx, model->head, x);
-
-    std::unique_ptr<struct ggml_cgraph> graph(new(std::nothrow) struct ggml_cgraph());
-    RWKV_ASSERT_NULL_MSG(RWKV_ERROR_GRAPH | RWKV_ERROR_ALLOC, graph.get(), "Failed to allocate graph");
-
-    ggml_build_forward_expand(graph.get(), logits);
-
-    for (uint32_t i = 0; i < n_layer * 5; i++)
-        ggml_build_forward_expand(graph.get(), state_parts[i]);
-
-    graph->n_threads = n_threads;
+    struct rwkv_graph graph;
+    RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_graph(ctx, model.get(), n_threads, &graph));

     std::unique_ptr<struct rwkv_context> rwkv_ctx(new(std::nothrow) struct rwkv_context());
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx.get(), "Failed to allocate context");
     rwkv_ctx->model = std::move(model);
-    rwkv_ctx->token_index = token_index;
-    rwkv_ctx->state = state;
-    rwkv_ctx->state_parts = state_parts;
-    rwkv_ctx->logits = logits;
     rwkv_ctx->ctx = ctx;
     rwkv_ctx->graph = std::move(graph);
     rwkv_ctx->last_error = RWKV_ERROR_NONE;
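Both the old inline code and the new rwkv_build_graph use the same flat state layout: one F32 tensor of n_layer * 5 * n_embed elements, with layer i owning five consecutive (n_embed)-element parts. A self-contained sketch of the offset arithmetic (the enum names are illustrative, not identifiers from rwkv.cpp):

    #include <cstddef>
    #include <cstdio>

    // Role of each of the five per-layer state parts, per the views above.
    enum StatePart { FFN_X_PREV = 0, ATT_X_PREV = 1, ATT_AA = 2, ATT_BB = 3, ATT_PP = 4 };

    // Matches (part_index + j) * state_part_size in rwkv_build_graph,
    // where part_index = layer * 5 and state_part_size = n_embed * sizeof(float).
    static size_t state_offset_bytes(size_t layer, StatePart part, size_t n_embed) {
        return (layer * 5 + part) * n_embed * sizeof(float);
    }

    int main() {
        const size_t n_embed = 768, layer = 2;  // example dimensions
        std::printf("aa of layer %zu starts at byte %zu\n",
                    layer, state_offset_bytes(layer, ATT_AA, n_embed));
    }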
@@ -627,40 +610,40 @@ bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const floa
     RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, logits_out != NULL, "logits_out is NULL");
     RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < ctx->model->n_vocab, "Token is out of range 0..%d", ctx->model->n_vocab - 1);

-    uint32_t n_layer = ctx->model->n_layer;
-    uint32_t n_embed = ctx->model->n_embed;
+    const struct rwkv_graph * graph = &ctx->graph;
+    size_t n_layer = ctx->model->n_layer;
+    size_t n_embed = ctx->model->n_embed;

-    ggml_set_i32_1d(ctx->token_index, 0, token);
+    ggml_set_i32_1d(graph->token_index, 0, token);

     if (state_in == NULL) {
-        ggml_set_f32(ctx->state, 0.0F);
+        ggml_set_f32(graph->state, 0.0F);

-        for (uint64_t i = 0; i < n_layer; i++) {
+        for (size_t i = 0; i < n_layer; i++) {
             // state[5 * i + 4] = -1e30
             ggml_set_f32(
-                ggml_view_1d(ctx->ctx, ctx->state, n_embed, (5 * i + 4) * n_embed * sizeof(float)),
+                ggml_view_1d(ctx->ctx, graph->state, n_embed, (5 * i + 4) * n_embed * sizeof(float)),
                 -1e30F
             );
         }
     } else {
-        memcpy(ctx->state->data, state_in, ctx->state->ne[0] * sizeof(float));
+        memcpy(graph->state->data, state_in, graph->state->ne[0] * sizeof(float));
     }

-    ggml_graph_compute(ctx->ctx, ctx->graph.get());
+    ggml_graph_compute(ctx->ctx, graph->cgraph.get());

-    for (uint32_t i = 0; i < n_layer * 5; i++) {
-        struct ggml_tensor * part = ctx->state_parts[i];
+    for (size_t i = 0; i < n_layer * 5; i++) {
+        struct ggml_tensor * part = graph->state_parts[i];
         memcpy(state_out + i * n_embed, part->data, part->ne[0] * sizeof(float));
     }

-    memcpy(logits_out, ctx->logits->data, ctx->logits->ne[0] * sizeof(float));
+    memcpy(logits_out, graph->logits->data, graph->logits->ne[0] * sizeof(float));

     return true;
 }

 void rwkv_free(struct rwkv_context * ctx) {
     std::unique_ptr<struct rwkv_context> rwkv_ctx(ctx);
-    delete[] ctx->state_parts;
     ggml_free(ctx->ctx);
 }
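For reference, a hedged usage sketch of the rwkv_eval loop after this change. The (ctx, token, state_in, state_out, logits_out) parameter order is inferred from the truncated signature and the body above, so treat it as an assumption:

    #include <cstdint>
    #include <vector>

    // Assumed parameter order: rwkv_eval(ctx, token, state_in, state_out, logits_out).
    // Per the body above, state_in == NULL requests a fresh state
    // (zeroed, with the pp parts set to -1e30).
    void feed_tokens(struct rwkv_context * ctx, const std::vector<uint32_t> & tokens,
                     size_t n_layer, size_t n_embed, size_t n_vocab) {
        std::vector<float> state(n_layer * 5 * n_embed);
        std::vector<float> logits(n_vocab);

        const float * state_in = NULL;  // NULL only on the first call
        for (uint32_t token : tokens) {
            if (!rwkv_eval(ctx, token, state_in, state.data(), logits.data())) {
                return;  // inspect rwkv_get_last_error(ctx)
            }
            state_in = state.data();  // feed the updated state back on the next step
        }
    }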