[FILE FORMAT CHANGED] Use ggml_get_rows to get embedding

This commit is contained in:
saharNooby 2023-04-01 11:28:32 +04:00
parent 16ec7a5c18
commit fe98c94a63
3 changed files with 27 additions and 28 deletions

View File

@ -14,6 +14,8 @@
// --- Utilities --- // --- Utilities ---
#define F32_SIZE 4
// Checks that x is not false. If it is false, prints fancy message to stderr and aborts the execution. // Checks that x is not false. If it is false, prints fancy message to stderr and aborts the execution.
#define RWKV_ASSERT(x, ...) \ #define RWKV_ASSERT(x, ...) \
do { \ do { \
@ -263,8 +265,9 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
std::string key(key_length, 0); std::string key(key_length, 0);
RWKV_ASSERT(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key"); RWKV_ASSERT(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key");
// TODO Use ggml_type_size size_t element_size = data_type == 0 ?
size_t element_size = data_type == 0 ? 4 : 2; ggml_type_size(GGML_TYPE_F32) :
ggml_type_size(GGML_TYPE_F16);
size_t byte_count = element_count * element_size; size_t byte_count = element_count * element_size;
RWKV_ASSERT(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data"); RWKV_ASSERT(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data");
@ -319,8 +322,8 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
// Verify order of dimensions // Verify order of dimensions
struct ggml_tensor * emb = model->emb; struct ggml_tensor * emb = model->emb;
RWKV_ASSERT(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims); RWKV_ASSERT(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims);
RWKV_ASSERT(emb->ne[0] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[0]); RWKV_ASSERT(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
RWKV_ASSERT(emb->ne[1] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[1]); RWKV_ASSERT(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
} }
// --- Operators --- // --- Operators ---
@ -380,13 +383,13 @@ int main(int argc, char ** argv) {
for (int i = 0; i < n_layer; i++) { for (int i = 0; i < n_layer; i++) {
// state[5 * i + 4] = -1e30 // state[5 * i + 4] = -1e30
int32_t offset_in_bytes = (5 * i + 4) * n_embed * 4; int32_t offset_in_bytes = (5 * i + 4) * n_embed * F32_SIZE;
struct ggml_tensor * state_part = ggml_view_1d(ctx, state, n_embed, offset_in_bytes); struct ggml_tensor * state_part = ggml_view_1d(ctx, state, n_embed, offset_in_bytes);
ggml_set_f32(state_part, -1e30F); ggml_set_f32(state_part, -1e30F);
} }
} else { } else {
RWKV_LOG("Loading state from %s", state_in_path); RWKV_LOG("Loading state from %s", state_in_path);
int32_t state_file_size = state_element_count * 4; int32_t state_file_size = state_element_count * F32_SIZE;
FILE * state_in_file = fopen(state_in_path, "rb"); FILE * state_in_file = fopen(state_in_path, "rb");
RWKV_ASSERT(state_in_file != NULL, "Failed to open file %s", state_in_path); RWKV_ASSERT(state_in_file != NULL, "Failed to open file %s", state_in_path);
@ -400,10 +403,9 @@ int main(int argc, char ** argv) {
// --- Evaluate model --- // --- Evaluate model ---
// x = self.w.emb.weight[token] // x = self.w.emb.weight[token]
// TODO Replace with ggml_get_rows or similar struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
struct ggml_tensor * one_hot = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_vocab, 1); ggml_set_i32_1d(token_index, 0, token);
ggml_set_f32_1d(one_hot, token, 1.0F); struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, token_index);
struct ggml_tensor * x = ggml_mul_mat(ctx, model.emb, one_hot);
// x = self.layer_norm(x, self.w.blocks[0].ln0) // x = self.layer_norm(x, self.w.blocks[0].ln0)
x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias); x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);
@ -419,7 +421,7 @@ int main(int argc, char ** argv) {
// self.layer_norm(x, self.w.blocks[i].ln1) // self.layer_norm(x, self.w.blocks[i].ln1)
struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
// state[5 * i + 1] // state[5 * i + 1]
struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * 4); struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * F32_SIZE);
// xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
// xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
// xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
@ -454,9 +456,9 @@ int main(int argc, char ** argv) {
// aa = state[5 * i + 2] // aa = state[5 * i + 2]
// bb = state[5 * i + 3] // bb = state[5 * i + 3]
// pp = state[5 * i + 4] // pp = state[5 * i + 4]
struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * 4); struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * F32_SIZE);
struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * 4); struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * F32_SIZE);
struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * 4); struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * F32_SIZE);
// ww = time_first + k // ww = time_first + k
struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k); struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k);
@ -519,7 +521,7 @@ int main(int argc, char ** argv) {
// self.layer_norm(x, self.w.blocks[i].ln2) // self.layer_norm(x, self.w.blocks[i].ln2)
struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias); struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias);
// state[5 * i + 0] // state[5 * i + 0]
struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * 4); struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * F32_SIZE);
// xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
// xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
struct ggml_tensor * xk = ggml_add( struct ggml_tensor * xk = ggml_add(
@ -577,17 +579,17 @@ int main(int argc, char ** argv) {
// Update state // Update state
for (int i = 0; i < n_layer * 5; i++) { for (int i = 0; i < n_layer * 5; i++) {
struct ggml_tensor * state_part_src = state_parts[i]; struct ggml_tensor * src = state_parts[i];
struct ggml_tensor * state_part_dest = ggml_view_1d(ctx, state, n_embed, i * n_embed * 4); struct ggml_tensor * dest = ggml_view_1d(ctx, state, n_embed, i * n_embed * F32_SIZE);
for (int j = 0; j < n_embed; j++) { for (int j = 0; j < n_embed; j++) {
ggml_set_f32_1d(state_part_dest, j, ggml_get_f32_1d(state_part_src, j)); ggml_set_f32_1d(dest, j, ggml_get_f32_1d(src, j));
} }
} }
{ {
RWKV_LOG("Saving state to %s", state_out_path); RWKV_LOG("Saving state to %s", state_out_path);
int32_t state_file_size = state_element_count * 4; int32_t state_file_size = state_element_count * F32_SIZE;
FILE * state_out_file = fopen(state_out_path, "wb"); FILE * state_out_file = fopen(state_out_path, "wb");
RWKV_ASSERT(state_out_file != NULL, "Failed to open file %s", state_out_path); RWKV_ASSERT(state_out_file != NULL, "Failed to open file %s", state_out_path);
@ -599,7 +601,7 @@ int main(int argc, char ** argv) {
{ {
RWKV_LOG("Saving logits to %s", logits_out_path); RWKV_LOG("Saving logits to %s", logits_out_path);
int32_t logits_file_size = n_vocab * 4; int32_t logits_file_size = n_vocab * F32_SIZE;
FILE * logits_out_file = fopen(logits_out_path, "wb"); FILE * logits_out_file = fopen(logits_out_path, "wb");
RWKV_ASSERT(logits_out_file != NULL, "Failed to open file %s", logits_out_path); RWKV_ASSERT(logits_out_file != NULL, "Failed to open file %s", logits_out_path);

View File

@ -72,7 +72,7 @@ def main() -> None:
print(f'Actual logits: {actual_logits}') print(f'Actual logits: {actual_logits}')
print('Difference per token: %.8f' % (difference,)) print('Difference per token: %.8f' % (difference,))
assert abs(difference) <= 0.00005, 'Difference is too big' assert abs(difference) <= 0.000005, 'Difference is too big'
# Check small token amount first to avoid waiting too long before seeing that model is broken # Check small token amount first to avoid waiting too long before seeing that model is broken
compare_logits(tokens[:4]) compare_logits(tokens[:4])

View File

@ -23,6 +23,7 @@
# int32 key_length; # int32 key_length;
# // 0 if float32, 1 if float16. # // 0 if float32, 1 if float16.
# int32 data_type; # int32 data_type;
# // Same values and order as in PyTorch's tensor.shape
# int32[dim_count] shape; # int32[dim_count] shape;
# // Keys are like "emb.weight", "block.0.ln1.weight". # // Keys are like "emb.weight", "block.0.ln1.weight".
# uint8[key_length] key_utf8; # uint8[key_length] key_utf8;
@ -89,10 +90,6 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
if data_type == 'float16' and len(tensor.shape) > 1: if data_type == 'float16' and len(tensor.shape) > 1:
tensor = tensor.half() tensor = tensor.half()
if k == 'emb.weight':
# Allows embedding matrix to be multiplied by one-hot vector
tensor = torch.permute(tensor, dims=[i for i in reversed(range(len(tensor.shape)))]).contiguous()
shape = tensor.shape shape = tensor.shape
print(f'Writing {k}, shape {shape}, type {tensor.dtype}') print(f'Writing {k}, shape {shape}, type {tensor.dtype}')
@ -153,10 +150,10 @@ def test() -> None:
2, 2,
10, 10,
0, 0,
2, 3, 3, 2,
'emb.weight'.encode('utf-8'), 'emb.weight'.encode('utf-8'),
1.0, 3.0, 5.0, 1.0, 2.0, 3.0,
2.0, 4.0, 6.0, 4.0, 5.0, 6.0,
# blocks.0.ln1.weight # blocks.0.ln1.weight
1, 1,
19, 19,