Finally, FP32 inference
This commit is contained in:
parent
61c6b1a4e0
commit
6fe9486cee
|
@ -107,8 +107,7 @@ void print_tensor(struct ggml_tensor * tensor, char * name) {
|
||||||
void compute_graph(struct ggml_context * ctx, struct ggml_tensor * tensor) {
|
void compute_graph(struct ggml_context * ctx, struct ggml_tensor * tensor) {
|
||||||
struct ggml_cgraph graph = ggml_build_forward(tensor);
|
struct ggml_cgraph graph = ggml_build_forward(tensor);
|
||||||
|
|
||||||
// TODO Move to script arguments
|
graph.n_threads = 1;
|
||||||
graph.n_threads = std::max(1, (int32_t) std::thread::hardware_concurrency() / 2);
|
|
||||||
|
|
||||||
ggml_graph_compute(ctx, &graph);
|
ggml_graph_compute(ctx, &graph);
|
||||||
}
|
}
|
||||||
|
@ -252,7 +251,10 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
|
||||||
read_int32(file, &x);
|
read_int32(file, &x);
|
||||||
read_int32(file, &y);
|
read_int32(file, &y);
|
||||||
element_count = x * y;
|
element_count = x * y;
|
||||||
// Not a typo, dimensions should be reversed here
|
// Dimension order is reversed here:
|
||||||
|
// * PyTorch shape is (x rows, y columns)
|
||||||
|
// * ggml shape is (y elements in a row, x elements in a column)
|
||||||
|
// Both shapes represent the same tensor.
|
||||||
tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y, x);
|
tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y, x);
|
||||||
} else {
|
} else {
|
||||||
abort();
|
abort();
|
||||||
|
@ -376,7 +378,6 @@ int main(int argc, char ** argv) {
|
||||||
RWKV_LOG("Creating new state");
|
RWKV_LOG("Creating new state");
|
||||||
ggml_set_f32(state, 0.0F);
|
ggml_set_f32(state, 0.0F);
|
||||||
|
|
||||||
// TODO Verify correctness
|
|
||||||
for (int i = 0; i < n_layer; i++) {
|
for (int i = 0; i < n_layer; i++) {
|
||||||
// state[5 * i + 4] = -1e30
|
// state[5 * i + 4] = -1e30
|
||||||
int32_t offset_in_bytes = (5 * i + 4) * n_embed * 4;
|
int32_t offset_in_bytes = (5 * i + 4) * n_embed * 4;
|
||||||
|
@ -407,6 +408,9 @@ int main(int argc, char ** argv) {
|
||||||
// x = self.layer_norm(x, self.w.blocks[0].ln0)
|
// x = self.layer_norm(x, self.w.blocks[0].ln0)
|
||||||
x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);
|
x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);
|
||||||
|
|
||||||
|
// We collect parts of new state here. Each part is (n_embed) vector.
|
||||||
|
struct ggml_tensor ** state_parts = new ggml_tensor * [5 * n_layer];
|
||||||
|
|
||||||
for (int i = 0; i < n_layer; i++) {
|
for (int i = 0; i < n_layer; i++) {
|
||||||
auto layer = model.layers[i];
|
auto layer = model.layers[i];
|
||||||
|
|
||||||
|
@ -435,7 +439,7 @@ int main(int argc, char ** argv) {
|
||||||
ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_r))
|
ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.att_time_mix_r))
|
||||||
);
|
);
|
||||||
// state[5 * i + 1] = x
|
// state[5 * i + 1] = x
|
||||||
ggml_cpy(ctx, x0, x_prev);
|
state_parts[5 * i + 1] = x0;
|
||||||
|
|
||||||
// r = torch.sigmoid(rw @ xr)
|
// r = torch.sigmoid(rw @ xr)
|
||||||
struct ggml_tensor * r = ggml_sigmoid(
|
struct ggml_tensor * r = ggml_sigmoid(
|
||||||
|
@ -485,22 +489,19 @@ int main(int argc, char ** argv) {
|
||||||
// e2 = torch.exp(k - qq)
|
// e2 = torch.exp(k - qq)
|
||||||
e2 = ggml_exp(ctx, ggml_sub(ctx, k, qq));
|
e2 = ggml_exp(ctx, ggml_sub(ctx, k, qq));
|
||||||
// state[5 * i + 2] = e1 * aa + e2 * v
|
// state[5 * i + 2] = e1 * aa + e2 * v
|
||||||
// TODO Must save result
|
state_parts[5 * i + 2] = ggml_add(
|
||||||
ggml_cpy(ctx, ggml_add(
|
|
||||||
ctx,
|
ctx,
|
||||||
ggml_mul(ctx, e1, aa),
|
ggml_mul(ctx, e1, aa),
|
||||||
ggml_mul(ctx, e2, v)
|
ggml_mul(ctx, e2, v)
|
||||||
), aa);
|
);
|
||||||
// state[5 * i + 3] = e1 * bb + e2
|
// state[5 * i + 3] = e1 * bb + e2
|
||||||
// TODO Must save result
|
state_parts[5 * i + 3] = ggml_add(
|
||||||
ggml_cpy(ctx, ggml_add(
|
|
||||||
ctx,
|
ctx,
|
||||||
ggml_mul(ctx, e1, bb),
|
ggml_mul(ctx, e1, bb),
|
||||||
e2
|
e2
|
||||||
), bb);
|
);
|
||||||
// state[5 * i + 4] = qq
|
// state[5 * i + 4] = qq
|
||||||
// TODO Must save result
|
state_parts[5 * i + 4] = qq;
|
||||||
ggml_cpy(ctx, qq, pp);
|
|
||||||
// ow @ (r * wkv)
|
// ow @ (r * wkv)
|
||||||
x = ggml_add(
|
x = ggml_add(
|
||||||
ctx,
|
ctx,
|
||||||
|
@ -532,8 +533,7 @@ int main(int argc, char ** argv) {
|
||||||
ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_r))
|
ggml_mul(ctx, x_prev, ggml_1_minus_x(ctx, layer.ffn_time_mix_r))
|
||||||
);
|
);
|
||||||
// state[5 * i + 0] = x
|
// state[5 * i + 0] = x
|
||||||
// TODO Must save result
|
state_parts[5 * i + 0] = x0;
|
||||||
ggml_cpy(ctx, x0, x_prev);
|
|
||||||
|
|
||||||
// r = torch.sigmoid(rw @ xr)
|
// r = torch.sigmoid(rw @ xr)
|
||||||
struct ggml_tensor * r = ggml_sigmoid(
|
struct ggml_tensor * r = ggml_sigmoid(
|
||||||
|
@ -564,9 +564,26 @@ int main(int argc, char ** argv) {
|
||||||
// x = (self.w.head.weight @ x).float()
|
// x = (self.w.head.weight @ x).float()
|
||||||
struct ggml_tensor * logits = ggml_mul_mat(ctx, model.head, x);
|
struct ggml_tensor * logits = ggml_mul_mat(ctx, model.head, x);
|
||||||
|
|
||||||
compute_graph(ctx, logits);
|
struct ggml_cgraph graph = ggml_build_forward(logits);
|
||||||
|
|
||||||
PRINT_TENSOR(logits);
|
for (int i = 0; i < n_layer * 5; i++) {
|
||||||
|
ggml_build_forward_expand(&graph, state_parts[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO Move to script arguments
|
||||||
|
graph.n_threads = std::max(1, (int32_t) std::thread::hardware_concurrency() / 2);
|
||||||
|
|
||||||
|
ggml_graph_compute(ctx, &graph);
|
||||||
|
|
||||||
|
// Update state
|
||||||
|
for (int i = 0; i < n_layer * 5; i++) {
|
||||||
|
struct ggml_tensor * state_part_src = state_parts[i];
|
||||||
|
struct ggml_tensor * state_part_dest = ggml_view_1d(ctx, state, n_embed, i * n_embed * 4);
|
||||||
|
|
||||||
|
for (int j = 0; j < n_embed; j++) {
|
||||||
|
ggml_set_f32_1d(state_part_dest, j, ggml_get_f32_1d(state_part_src, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
RWKV_LOG("Saving state to %s", state_out_path);
|
RWKV_LOG("Saving state to %s", state_out_path);
|
||||||
|
|
|
@ -18,8 +18,9 @@ def parse_args():
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
|
token_count: int = 64
|
||||||
# It's not important what exactly these tokens are; just that output of both model matches.
|
# It's not important what exactly these tokens are; just that output of both model matches.
|
||||||
tokens: List[int] = [(i + 1) for i in range(32)]
|
tokens: List[int] = [(i + 1) for i in range(token_count)]
|
||||||
state_path: str = './state.bin'
|
state_path: str = './state.bin'
|
||||||
logits_path: str = './logits.bin'
|
logits_path: str = './logits.bin'
|
||||||
|
|
||||||
|
@ -27,9 +28,11 @@ def main() -> None:
|
||||||
|
|
||||||
ref_logits, ref_state = None, None
|
ref_logits, ref_state = None, None
|
||||||
|
|
||||||
for token in tokens:
|
for i in range(token_count):
|
||||||
|
token: int = tokens[i]
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(f'--- Token {token} ---')
|
print(f'--- {i + 1}/{token_count} ---')
|
||||||
|
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
[
|
[
|
||||||
|
@ -55,8 +58,9 @@ def main() -> None:
|
||||||
print(f'Actual logits: {actual_logits}')
|
print(f'Actual logits: {actual_logits}')
|
||||||
print('Difference per token: %.8f' % (difference,))
|
print('Difference per token: %.8f' % (difference,))
|
||||||
|
|
||||||
assert abs(difference) <= 0.000001, 'Difference is too big'
|
assert abs(difference) <= 0.00005, 'Difference is too big'
|
||||||
|
|
||||||
|
print()
|
||||||
print('Test passes')
|
print('Test passes')
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue