[FILE FORMAT CHANGED] Use ggml_get_rows to get embedding

parent 16ec7a5c18
commit fe98c94a63
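The embedding lookup switches from multiplying emb.weight by a one-hot vector to ggml_get_rows, and because the one-hot trick required emb.weight to be stored transposed, the converter now writes that tensor as-is, which changes the file format. A minimal sketch of the new lookup, assuming a valid ggml_context ctx, the loaded model and an int token as in the diff below:

    // Put the token id into a one-element int32 tensor; ggml_get_rows then
    // selects the matching row of emb.weight (one vocabulary entry per row).
    struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
    ggml_set_i32_1d(token_index, 0, token);
    struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, token_index);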
@@ -14,6 +14,8 @@
 
 // --- Utilities ---
 
+#define F32_SIZE 4
+
 // Checks that x is not false. If it is false, prints fancy message to stderr and aborts the execution.
 #define RWKV_ASSERT(x, ...) \
     do { \
@@ -263,8 +265,9 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
         std::string key(key_length, 0);
         RWKV_ASSERT(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key");
 
-        // TODO Use ggml_type_size
-        size_t element_size = data_type == 0 ? 4 : 2;
+        size_t element_size = data_type == 0 ?
+            ggml_type_size(GGML_TYPE_F32) :
+            ggml_type_size(GGML_TYPE_F16);
         size_t byte_count = element_count * element_size;
 
         RWKV_ASSERT(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data");
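The hard-coded element sizes give way to ggml_type_size, which reports the per-element byte width for a ggml type; for GGML_TYPE_F32 and GGML_TYPE_F16 these are the same 4 and 2 bytes the old code used. A quick sanity sketch, assuming ggml.h and <cassert> are included:

    // ggml_type_size returns bytes per element; these match the removed literals.
    assert(ggml_type_size(GGML_TYPE_F32) == 4);
    assert(ggml_type_size(GGML_TYPE_F16) == 2);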
@@ -319,8 +322,8 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
     // Verify order of dimensions
     struct ggml_tensor * emb = model->emb;
     RWKV_ASSERT(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims);
-    RWKV_ASSERT(emb->ne[0] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
-    RWKV_ASSERT(emb->ne[1] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
+    RWKV_ASSERT(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
+    RWKV_ASSERT(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
 }
 
 // --- Operators ---
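The swapped checks follow ggml's dimension order: ne[0] is the fastest-varying dimension (the length of a row), so an embedding matrix laid out with one vocabulary entry per row has ne[0] == n_embed and ne[1] == n_vocab, the reverse of PyTorch's (n_vocab, n_embed) shape. Illustrative sketch only, not part of the diff:

    // Creating the matrix as (ne0 = n_embed, ne1 = n_vocab) yields n_vocab rows
    // of n_embed floats each, which is the layout ggml_get_rows expects.
    struct ggml_tensor * emb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, n_vocab);
    // emb->ne[0] == n_embed, emb->ne[1] == n_vocab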
@@ -380,13 +383,13 @@ int main(int argc, char ** argv) {
 
         for (int i = 0; i < n_layer; i++) {
             // state[5 * i + 4] = -1e30
-            int32_t offset_in_bytes = (5 * i + 4) * n_embed * 4;
+            int32_t offset_in_bytes = (5 * i + 4) * n_embed * F32_SIZE;
             struct ggml_tensor * state_part = ggml_view_1d(ctx, state, n_embed, offset_in_bytes);
             ggml_set_f32(state_part, -1e30F);
         }
     } else {
         RWKV_LOG("Loading state from %s", state_in_path);
-        int32_t state_file_size = state_element_count * 4;
+        int32_t state_file_size = state_element_count * F32_SIZE;
 
         FILE * state_in_file = fopen(state_in_path, "rb");
         RWKV_ASSERT(state_in_file != NULL, "Failed to open file %s", state_in_path);
@@ -400,10 +403,9 @@ int main(int argc, char ** argv) {
     // --- Evaluate model ---
 
     // x = self.w.emb.weight[token]
-    // TODO Replace with ggml_get_rows or similar
-    struct ggml_tensor * one_hot = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_vocab, 1);
-    ggml_set_f32_1d(one_hot, token, 1.0F);
-    struct ggml_tensor * x = ggml_mul_mat(ctx, model.emb, one_hot);
+    struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+    ggml_set_i32_1d(token_index, 0, token);
+    struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, token_index);
 
     // x = self.layer_norm(x, self.w.blocks[0].ln0)
     x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);
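ggml_get_rows is not limited to a single index: a longer int32 tensor yields one output row per index, which could later be used to feed several tokens in one pass. This is not part of the commit, just a sketch of the API; n_tokens and tokens are hypothetical names:

    // Hypothetical batched lookup: one embedding row per requested token id.
    struct ggml_tensor * token_indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
    for (int t = 0; t < n_tokens; t++) {
        ggml_set_i32_1d(token_indices, t, tokens[t]);
    }
    struct ggml_tensor * xs = ggml_get_rows(ctx, model.emb, token_indices);  // n_tokens rows of n_embed floats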
@@ -419,7 +421,7 @@ int main(int argc, char ** argv) {
         // self.layer_norm(x, self.w.blocks[i].ln1)
         struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
         // state[5 * i + 1]
-        struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * 4);
+        struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * F32_SIZE);
         // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
         // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
         // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
@@ -454,9 +456,9 @@ int main(int argc, char ** argv) {
         // aa = state[5 * i + 2]
         // bb = state[5 * i + 3]
         // pp = state[5 * i + 4]
-        struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * 4);
-        struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * 4);
-        struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * 4);
+        struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * F32_SIZE);
+        struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * F32_SIZE);
+        struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * F32_SIZE);
 
         // ww = time_first + k
         struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k);
@@ -519,7 +521,7 @@ int main(int argc, char ** argv) {
         // self.layer_norm(x, self.w.blocks[i].ln2)
         struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias);
         // state[5 * i + 0]
-        struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * 4);
+        struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * F32_SIZE);
         // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
         // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
         struct ggml_tensor * xk = ggml_add(
@@ -577,17 +579,17 @@ int main(int argc, char ** argv) {
 
     // Update state
     for (int i = 0; i < n_layer * 5; i++) {
-        struct ggml_tensor * state_part_src = state_parts[i];
-        struct ggml_tensor * state_part_dest = ggml_view_1d(ctx, state, n_embed, i * n_embed * 4);
+        struct ggml_tensor * src = state_parts[i];
+        struct ggml_tensor * dest = ggml_view_1d(ctx, state, n_embed, i * n_embed * F32_SIZE);
 
         for (int j = 0; j < n_embed; j++) {
-            ggml_set_f32_1d(state_part_dest, j, ggml_get_f32_1d(state_part_src, j));
+            ggml_set_f32_1d(dest, j, ggml_get_f32_1d(src, j));
         }
     }
 
     {
         RWKV_LOG("Saving state to %s", state_out_path);
-        int32_t state_file_size = state_element_count * 4;
+        int32_t state_file_size = state_element_count * F32_SIZE;
 
         FILE * state_out_file = fopen(state_out_path, "wb");
         RWKV_ASSERT(state_out_file != NULL, "Failed to open file %s", state_out_path);
@@ -599,7 +601,7 @@ int main(int argc, char ** argv) {
 
     {
         RWKV_LOG("Saving logits to %s", logits_out_path);
-        int32_t logits_file_size = n_vocab * 4;
+        int32_t logits_file_size = n_vocab * F32_SIZE;
 
         FILE * logits_out_file = fopen(logits_out_path, "wb");
         RWKV_ASSERT(logits_out_file != NULL, "Failed to open file %s", logits_out_path);
@@ -72,7 +72,7 @@ def main() -> None:
         print(f'Actual logits: {actual_logits}')
         print('Difference per token: %.8f' % (difference,))
 
-        assert abs(difference) <= 0.00005, 'Difference is too big'
+        assert abs(difference) <= 0.000005, 'Difference is too big'
 
     # Check small token amount first to avoid waiting too long before seeing that model is broken
     compare_logits(tokens[:4])
@@ -23,6 +23,7 @@
 # int32 key_length;
 # // 0 if float32, 1 if float16.
 # int32 data_type;
+# // Same values and order as in PyTorch's tensor.shape
 # int32[dim_count] shape;
 # // Keys are like "emb.weight", "block.0.ln1.weight".
 # uint8[key_length] key_utf8;
@@ -89,10 +90,6 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
         if data_type == 'float16' and len(tensor.shape) > 1:
             tensor = tensor.half()
 
-        if k == 'emb.weight':
-            # Allows embedding matrix to be multiplied by one-hot vector
-            tensor = torch.permute(tensor, dims=[i for i in reversed(range(len(tensor.shape)))]).contiguous()
-
         shape = tensor.shape
 
         print(f'Writing {k}, shape {shape}, type {tensor.dtype}')
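Dropping the permute is the actual file-format change: emb.weight is now written exactly as it appears in the PyTorch state dict, one row per vocabulary entry, since ggml_get_rows indexes rows directly and no longer needs the matrix pre-transposed for a one-hot multiplication. This is the incompatibility the commit title flags.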
@@ -153,10 +150,10 @@ def test() -> None:
         2,
         10,
         0,
-        2, 3,
+        3, 2,
         'emb.weight'.encode('utf-8'),
-        1.0, 3.0, 5.0,
-        2.0, 4.0, 6.0,
+        1.0, 2.0, 3.0,
+        4.0, 5.0, 6.0,
         # blocks.0.ln1.weight
         1,
         19,
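In the converter's test data, emb.weight is a 3x2 matrix with rows [1, 2], [3, 4], [5, 6]. Before this commit the test file stored its transpose (shape 2, 3 with data 1, 3, 5 / 2, 4, 6); now the shape and row order are written unchanged, matching the "same order as PyTorch's tensor.shape" rule added to the format comment above.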