Move graph building into its own function (#69)

Step towards #50 and loading models from memory, among other things.
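
For context, this refactor does not change the public API; a minimal caller sketch follows (assuming the rwkv.h entry points of this era, including the rwkv_get_*_buffer_element_count size helpers, which are not shown in this diff — the model path and token ids are hypothetical):

#include "rwkv.h"
#include <vector>

int main() {
    // Hypothetical model path; 4 compute threads.
    struct rwkv_context * ctx = rwkv_init_from_file("model.bin", 4);
    if (!ctx) return 1;

    // State is n_layer * 5 * n_embed floats; logits is n_vocab floats.
    std::vector<float> state(rwkv_get_state_buffer_element_count(ctx));
    std::vector<float> logits(rwkv_get_logits_buffer_element_count(ctx));

    // Passing NULL as state_in zero-initializes the state (see rwkv_eval in this diff).
    rwkv_eval(ctx, 0, NULL, state.data(), logits.data());
    // Later tokens feed the previous state back in.
    rwkv_eval(ctx, 1, state.data(), state.data(), logits.data());

    rwkv_free(ctx);
    return 0;
}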
LoganDark 2023-05-26 05:30:07 -07:00 committed by GitHub
parent b61d94aef0
commit 3ca9c7f785
1 changed file with 187 additions and 204 deletions

rwkv.cpp

@@ -244,22 +244,23 @@ struct ggml_tensor * rwkv_max(ggml_context * ctx, struct ggml_tensor * x, struct
struct ggml_tensor * rwkv_layer_norm(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) {
// LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias`
// Looks like ggml_norm does the first part, we only need to apply weight & bias.
x = ggml_norm(ctx, x);
x = ggml_mul(ctx, x, weight);
x = ggml_add_inplace(ctx, x, bias);
return x;
return ggml_add_inplace(ctx, ggml_mul(ctx, ggml_norm(ctx, x), weight), bias);
}
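
As a plain-scalar reference for what this graph fragment encodes (an illustrative sketch only, not part of the commit; eps = 1e-5 per the comment above, which matches what ggml_norm uses):

#include <math.h>
#include <stddef.h>

// x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias, elementwise.
void layer_norm_ref(const float * x, const float * weight, const float * bias, float * out, size_t n) {
    float mean = 0.0f;
    for (size_t i = 0; i < n; i++) mean += x[i];
    mean /= n;
    float var = 0.0f;
    for (size_t i = 0; i < n; i++) var += (x[i] - mean) * (x[i] - mean);
    var /= n;
    const float inv = 1.0f / sqrtf(var + 1e-5f); // ggml_norm covers everything up to here
    for (size_t i = 0; i < n; i++) out[i] = (x[i] - mean) * inv * weight[i] + bias[i];
}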
// --- Implementation ---
struct rwkv_graph {
struct ggml_tensor * state;
std::unique_ptr<struct ggml_tensor * []> state_parts;
struct ggml_tensor * token_index;
struct ggml_tensor * logits;
std::unique_ptr<struct ggml_cgraph> cgraph;
};
struct rwkv_context {
std::unique_ptr<struct rwkv_model> model;
struct ggml_tensor * token_index;
struct ggml_tensor * state;
struct ggml_tensor ** state_parts;
struct ggml_tensor * logits;
struct ggml_context * ctx;
std::unique_ptr<struct ggml_cgraph> graph;
struct rwkv_graph graph;
enum rwkv_error_flags last_error;
bool print_errors;
};
@@ -280,6 +281,164 @@ enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) {
return value;
}
bool rwkv_build_graph(struct ggml_context * ctx, struct rwkv_model * model, const uint32_t n_threads, struct rwkv_graph * out) {
std::unique_ptr<struct ggml_cgraph> cgraph(new(std::nothrow) struct ggml_cgraph());
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, cgraph.get(), "Failed to allocate graph");
cgraph->n_threads = n_threads;
size_t n_embed = model->n_embed, n_layer = model->n_layer;
struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed);
// We collect parts of new state here. Each part is (n_embed) vector.
std::unique_ptr<struct ggml_tensor * []> state_parts(new(std::nothrow) ggml_tensor * [n_layer * 5]);
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, state_parts.get(), "Failed to allocate state parts");
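// Layout of the state, as read and written below (five (n_embed) parts per layer):
//   state[5 * i + 0] = previous x for FFN/channel mixing
//   state[5 * i + 1] = previous x for RWKV/time mixing
//   state[5 * i + 2] = aa, wkv numerator accumulator (stored scaled by exp(-pp))
//   state[5 * i + 3] = bb, wkv denominator accumulator (stored scaled by exp(-pp))
//   state[5 * i + 4] = pp, running maximum of the exponents, for numerical stability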
// x = self.w.emb.weight[token]
struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
struct ggml_tensor * x = ggml_get_rows(ctx, model->emb, token_index);
// x = self.layer_norm(x, self.w.blocks[0].ln0)
x = rwkv_layer_norm(ctx, x, model->ln0_weight, model->ln0_bias);
for (size_t i = 0; i < n_layer; i++) {
struct rwkv_layer layer = model->layers[i];
size_t part_index = i * 5;
size_t state_part_size = n_embed * sizeof(float);
// RWKV/time mixing
{
// self.layer_norm(x, self.w.blocks[i].ln1)
struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
// x0 = state[5 * i + 1]
struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (part_index + 1) * state_part_size);
// aa = state[5 * i + 2]
struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (part_index + 2) * state_part_size);
// bb = state[5 * i + 3]
struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (part_index + 3) * state_part_size);
// pp = state[5 * i + 4]
struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (part_index + 4) * state_part_size);
// xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
struct ggml_tensor * xk = ggml_add_inplace(ctx,
ggml_mul(ctx, x0, layer.att_time_mix_k),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k))
);
// xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
struct ggml_tensor * xv = ggml_add_inplace(ctx,
ggml_mul(ctx, x0, layer.att_time_mix_v),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v))
);
// xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
struct ggml_tensor * xr = ggml_add_inplace(ctx,
ggml_mul(ctx, x0, layer.att_time_mix_r),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r))
);
// r = torch.sigmoid(rw @ xr)
struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr));
// k = kw @ xk
struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk);
// v = vw @ xv
struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv);
// ww = time_first + k
struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k);
// qq = torch.maximum(pp, ww)
struct ggml_tensor * qq = rwkv_max(ctx, pp, ww);
// e1 = torch.exp(pp - qq)
struct ggml_tensor * e1 = rwkv_exp(ctx, ggml_sub(ctx, pp, qq));
// e2 = torch.exp(ww - qq)
struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq));
// a = e1 * aa + e2 * v
struct ggml_tensor * a = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v));
// b = e1 * bb + e2
struct ggml_tensor * b = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2);
// ww = pp + time_decay
ww = ggml_add_inplace(ctx, pp, layer.att_time_decay);
// qq = torch.maximum(ww, k)
qq = rwkv_max(ctx, ww, k);
// e1 = torch.exp(ww - qq)
e1 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq));
// e2 = torch.exp(k - qq)
e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq));
// state[5 * i + 1] = x0
// state[5 * i + 2] = e1 * aa + e2 * v
// state[5 * i + 3] = e1 * bb + e2
// state[5 * i + 4] = qq
state_parts[part_index + 1] = x0;
state_parts[part_index + 2] = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v));
state_parts[part_index + 3] = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2);
state_parts[part_index + 4] = qq;
// wkv = a / b
struct ggml_tensor * wkv = ggml_div(ctx, a, b);
// ow @ (r * wkv)
x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv)));
}
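// The pp/qq bookkeeping above is the usual max-subtraction trick: aa and bb hold the
// wkv numerator and denominator scaled by exp(-pp), and every exponential is taken
// relative to the running maximum, so its argument is <= 0 and cannot overflow.
// In the notation of the comments:
//
//   $$ q = \max(pp, ww), \quad e_1 = e^{pp - q}, \quad e_2 = e^{ww - q}, \quad
//      wkv = \frac{e_1 \cdot aa + e_2 \cdot v}{e_1 \cdot bb + e_2} $$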
// FFN/channel mixing
{
// self.layer_norm(x, self.w.blocks[i].ln2)
struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias);
// x_prev = state[5 * i + 0]
struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, part_index * state_part_size);
// xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
struct ggml_tensor * xk = ggml_add_inplace(
ctx,
ggml_mul(ctx, x0, layer.ffn_time_mix_k),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k))
);
// xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
struct ggml_tensor * xr = ggml_add_inplace(
ctx,
ggml_mul(ctx, x0, layer.ffn_time_mix_r),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r))
);
// state[5 * i + 0] = x
state_parts[part_index] = x0;
// r = torch.sigmoid(rw @ xr)
struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr));
// k = torch.square(torch.relu(kw @ xk))
struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk)));
// r * (vw @ k)
x = ggml_add_inplace(ctx, x, ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)));
}
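// Written out, the channel-mixing update above is (with sigma the sigmoid,
// the square applied elementwise, and xk/xr the time-shifted mixes computed above):
//
//   $$ x \mathrel{+}= \sigma(rw \cdot xr) \odot \big(vw \cdot \mathrm{relu}(kw \cdot xk)^2\big) $$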
}
// x = self.layer_norm(x, self.w.ln_out)
x = rwkv_layer_norm(ctx, x, model->ln_out_weight, model->ln_out_bias);
// x = (self.w.head.weight @ x).float()
struct ggml_tensor * logits = ggml_mul_mat(ctx, model->head, x);
ggml_build_forward_expand(cgraph.get(), logits);
for (uint32_t i = 0; i < n_layer * 5; i++)
ggml_build_forward_expand(cgraph.get(), state_parts[i]);
out->state = state;
out->state_parts = std::move(state_parts);
out->token_index = token_index;
out->logits = logits;
out->cgraph = std::move(cgraph);
return true;
}
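
The point of the split is that a graph can now be (re)built against an already-loaded model without touching the file again — the direction the commit message gestures at with #50 and loading models from memory. A hypothetical sketch, mirroring the call site later in this diff and using only names defined here:

// Rebuild the computation graph for an existing context (hypothetical use).
struct rwkv_graph new_graph;
if (rwkv_build_graph(rwkv_ctx->ctx, rwkv_ctx->model.get(), n_threads, &new_graph)) {
    rwkv_ctx->graph = std::move(new_graph);
}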
struct rwkv_file_guard {
FILE * file;
~rwkv_file_guard() { if (file) fclose(file); }
@@ -418,192 +577,16 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]);
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]);
uint32_t n_embed = model->n_embed;
uint32_t n_layer = model->n_layer;
size_t n_embed = model->n_embed;
size_t n_layer = model->n_layer;
// Build graph
struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed);
// x = self.w.emb.weight[token]
struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
struct ggml_tensor * x = ggml_get_rows(ctx, model->emb, token_index);
// x = self.layer_norm(x, self.w.blocks[0].ln0)
x = rwkv_layer_norm(ctx, x, model->ln0_weight, model->ln0_bias);
// We collect parts of new state here. Each part is (n_embed) vector.
struct ggml_tensor ** state_parts = new ggml_tensor * [n_layer * 5];
for (uint32_t i = 0; i < n_layer; i++) {
auto layer = model->layers[i];
// RWKV/time mixing
{
// self.layer_norm(x, self.w.blocks[i].ln1)
struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
// state[5 * i + 1]
struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * sizeof(float));
// xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
// xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
// xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
struct ggml_tensor * xk = ggml_add_inplace(
ctx,
ggml_mul(ctx, x0, layer.att_time_mix_k),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k))
);
struct ggml_tensor * xv = ggml_add_inplace(
ctx,
ggml_mul(ctx, x0, layer.att_time_mix_v),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v))
);
struct ggml_tensor * xr = ggml_add_inplace(
ctx,
ggml_mul(ctx, x0, layer.att_time_mix_r),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r))
);
// state[5 * i + 1] = x
state_parts[5 * i + 1] = x0;
// r = torch.sigmoid(rw @ xr)
struct ggml_tensor * r = rwkv_sigmoid(
ctx,
ggml_mul_mat(ctx, layer.att_receptance, xr)
);
// k = kw @ xk
struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk);
// v = vw @ xv
struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv);
// aa = state[5 * i + 2]
// bb = state[5 * i + 3]
// pp = state[5 * i + 4]
struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * sizeof(float));
struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * sizeof(float));
struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * sizeof(float));
// ww = time_first + k
struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k);
// qq = torch.maximum(pp, ww)
struct ggml_tensor * qq = rwkv_max(ctx, pp, ww);
// e1 = torch.exp(pp - qq)
struct ggml_tensor * e1 = rwkv_exp(ctx, ggml_sub(ctx, pp, qq));
// e2 = torch.exp(ww - qq)
struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq));
// a = e1 * aa + e2 * v
struct ggml_tensor * a = ggml_add_inplace(
ctx,
ggml_mul(ctx, e1, aa),
ggml_mul(ctx, e2, v)
);
// b = e1 * bb + e2
struct ggml_tensor * b = ggml_add_inplace(
ctx,
ggml_mul(ctx, e1, bb),
e2
);
// wkv = a / b
struct ggml_tensor * wkv = ggml_div(ctx, a, b);
// ww = pp + time_decay
ww = ggml_add(ctx, pp, layer.att_time_decay);
// qq = torch.maximum(ww, k)
qq = rwkv_max(ctx, ww, k);
// e1 = torch.exp(ww - qq)
e1 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq));
// e2 = torch.exp(k - qq)
e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq));
// state[5 * i + 2] = e1 * aa + e2 * v
state_parts[5 * i + 2] = ggml_add_inplace(
ctx,
ggml_mul(ctx, e1, aa),
ggml_mul(ctx, e2, v)
);
// state[5 * i + 3] = e1 * bb + e2
state_parts[5 * i + 3] = ggml_add_inplace(
ctx,
ggml_mul(ctx, e1, bb),
e2
);
// state[5 * i + 4] = qq
state_parts[5 * i + 4] = qq;
// ow @ (r * wkv)
x = ggml_add_inplace(
ctx,
x,
ggml_mul_mat(
ctx,
layer.att_output,
ggml_mul(ctx, r, wkv)
)
);
}
// FFN/channel mixing
{
// self.layer_norm(x, self.w.blocks[i].ln2)
struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias);
// state[5 * i + 0]
struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * sizeof(float));
// xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
// xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
struct ggml_tensor * xk = ggml_add_inplace(
ctx,
ggml_mul(ctx, x0, layer.ffn_time_mix_k),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k))
);
struct ggml_tensor * xr = ggml_add_inplace(
ctx,
ggml_mul(ctx, x0, layer.ffn_time_mix_r),
ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r))
);
// state[5 * i + 0] = x
state_parts[5 * i + 0] = x0;
// r = torch.sigmoid(rw @ xr)
struct ggml_tensor * r = rwkv_sigmoid(
ctx,
ggml_mul_mat(ctx, layer.ffn_receptance, xr)
);
// k = torch.square(torch.relu(kw @ xk))
struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(
ctx,
ggml_mul_mat(ctx, layer.ffn_key, xk)
));
// r * (vw @ k)
x = ggml_add_inplace(
ctx,
x,
ggml_mul(
ctx,
r,
ggml_mul_mat(ctx, layer.ffn_value, k)
)
);
}
}
// x = self.layer_norm(x, self.w.ln_out)
x = rwkv_layer_norm(ctx, x, model->ln_out_weight, model->ln_out_bias);
// x = (self.w.head.weight @ x).float()
struct ggml_tensor * logits = ggml_mul_mat(ctx, model->head, x);
std::unique_ptr<struct ggml_cgraph> graph(new(std::nothrow) struct ggml_cgraph());
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_GRAPH | RWKV_ERROR_ALLOC, graph.get(), "Failed to allocate graph");
ggml_build_forward_expand(graph.get(), logits);
for (uint32_t i = 0; i < n_layer * 5; i++)
ggml_build_forward_expand(graph.get(), state_parts[i]);
graph->n_threads = n_threads;
struct rwkv_graph graph;
RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_graph(ctx, model.get(), n_threads, &graph));
std::unique_ptr<struct rwkv_context> rwkv_ctx(new(std::nothrow) struct rwkv_context());
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx.get(), "Failed to allocate context");
rwkv_ctx->model = std::move(model);
rwkv_ctx->token_index = token_index;
rwkv_ctx->state = state;
rwkv_ctx->state_parts = state_parts;
rwkv_ctx->logits = logits;
rwkv_ctx->ctx = ctx;
rwkv_ctx->graph = std::move(graph);
rwkv_ctx->last_error = RWKV_ERROR_NONE;
@@ -627,40 +610,40 @@ bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const floa
RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, logits_out != NULL, "logits_out is NULL");
RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < ctx->model->n_vocab, "Token is out of range 0..%d", ctx->model->n_vocab - 1);
uint32_t n_layer = ctx->model->n_layer;
uint32_t n_embed = ctx->model->n_embed;
const struct rwkv_graph * graph = &ctx->graph;
size_t n_layer = ctx->model->n_layer;
size_t n_embed = ctx->model->n_embed;
ggml_set_i32_1d(ctx->token_index, 0, token);
ggml_set_i32_1d(graph->token_index, 0, token);
if (state_in == NULL) {
ggml_set_f32(ctx->state, 0.0F);
ggml_set_f32(graph->state, 0.0F);
for (uint64_t i = 0; i < n_layer; i++) {
for (size_t i = 0; i < n_layer; i++) {
// state[5 * i + 4] = -1e30
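// (-1e30 stands in for -infinity: pp is a running max of exponents, so an
// effectively minus-infinite pp makes the first token dominate the time-mixing
// update via e1 ~ 0, e2 = 1.)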
ggml_set_f32(
ggml_view_1d(ctx->ctx, ctx->state, n_embed, (5 * i + 4) * n_embed * sizeof(float)),
ggml_view_1d(ctx->ctx, graph->state, n_embed, (5 * i + 4) * n_embed * sizeof(float)),
-1e30F
);
}
} else {
memcpy(ctx->state->data, state_in, ctx->state->ne[0] * sizeof(float));
memcpy(graph->state->data, state_in, graph->state->ne[0] * sizeof(float));
}
ggml_graph_compute(ctx->ctx, ctx->graph.get());
ggml_graph_compute(ctx->ctx, graph->cgraph.get());
for (uint32_t i = 0; i < n_layer * 5; i++) {
struct ggml_tensor * part = ctx->state_parts[i];
for (size_t i = 0; i < n_layer * 5; i++) {
struct ggml_tensor * part = graph->state_parts[i];
memcpy(state_out + i * n_embed, part->data, part->ne[0] * sizeof(float));
}
memcpy(logits_out, ctx->logits->data, ctx->logits->ne[0] * sizeof(float));
memcpy(logits_out, graph->logits->data, graph->logits->ne[0] * sizeof(float));
return true;
}
void rwkv_free(struct rwkv_context * ctx) {
std::unique_ptr<struct rwkv_context> rwkv_ctx(ctx);
delete[] ctx->state_parts;
ggml_free(ctx->ctx);
}