[FILE FORMAT CHANGED] Use ggml_get_rows to get embedding

commit fe98c94a63
parent 16ec7a5c18
@@ -14,6 +14,8 @@
 
 // --- Utilities ---
 
+#define F32_SIZE 4
+
 // Checks that x is not false. If it is false, prints fancy message to stderr and aborts the execution.
 #define RWKV_ASSERT(x, ...) \
     do { \
@@ -263,8 +265,9 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
         std::string key(key_length, 0);
         RWKV_ASSERT(fread(&key[0], 1, key_length, file) == key_length, "Failed to read parameter key");
 
-        // TODO Use ggml_type_size
-        size_t element_size = data_type == 0 ? 4 : 2;
+        size_t element_size = data_type == 0 ?
+            ggml_type_size(GGML_TYPE_F32) :
+            ggml_type_size(GGML_TYPE_F16);
         size_t byte_count = element_count * element_size;
 
         RWKV_ASSERT(fread(tensor->data, 1, byte_count, file) == byte_count, "Failed to read parameter data");
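The hunk above resolves the old TODO: element sizes now come from ggml's own type table instead of hard-coded literals. A minimal sketch of the idea, assuming only the two data types this file format defines (the helper name is illustrative, not from the commit):

    #include <stdint.h>
    #include "ggml.h"

    // Maps the file header's data_type flag (0 = float32, 1 = float16) to
    // an element size in bytes via ggml_type_size, replacing the literals
    // 4 and 2.
    static size_t element_size_for(int32_t data_type) {
        enum ggml_type type = data_type == 0 ? GGML_TYPE_F32 : GGML_TYPE_F16;
        return ggml_type_size(type);
    }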
@@ -319,8 +322,8 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
     // Verify order of dimensions
     struct ggml_tensor * emb = model->emb;
     RWKV_ASSERT(emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims);
-    RWKV_ASSERT(emb->ne[0] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
-    RWKV_ASSERT(emb->ne[1] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
+    RWKV_ASSERT(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %d", emb->ne[0]);
+    RWKV_ASSERT(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %d", emb->ne[1]);
 }
 
 // --- Operators ---
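The swapped checks reflect the file format change: ggml orders ne[] from the fastest-varying dimension outward, the reverse of PyTorch's tensor.shape, so an emb.weight of PyTorch shape [n_vocab, n_embed] written row-major shows up in ggml with ne[0] == n_embed and ne[1] == n_vocab. A small sketch of the correspondence (a hypothetical helper, assuming 2-D row-major data):

    #include <assert.h>
    #include "ggml.h"

    // For a 2-D tensor written in PyTorch's row-major order with shape
    // [rows, cols], ggml sees ne[0] == cols (the contiguous dimension)
    // and ne[1] == rows.
    static void check_2d_shape(const struct ggml_tensor * t, int rows, int cols) {
        assert(t->n_dims == 2);
        assert(t->ne[0] == cols);
        assert(t->ne[1] == rows);
    }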
@@ -380,13 +383,13 @@ int main(int argc, char ** argv) {
 
         for (int i = 0; i < n_layer; i++) {
             // state[5 * i + 4] = -1e30
-            int32_t offset_in_bytes = (5 * i + 4) * n_embed * 4;
+            int32_t offset_in_bytes = (5 * i + 4) * n_embed * F32_SIZE;
             struct ggml_tensor * state_part = ggml_view_1d(ctx, state, n_embed, offset_in_bytes);
             ggml_set_f32(state_part, -1e30F);
         }
     } else {
         RWKV_LOG("Loading state from %s", state_in_path);
-        int32_t state_file_size = state_element_count * 4;
+        int32_t state_file_size = state_element_count * F32_SIZE;
 
         FILE * state_in_file = fopen(state_in_path, "rb");
         RWKV_ASSERT(state_in_file != NULL, "Failed to open file %s", state_in_path);
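For reference, the offsets being rewritten all follow the same layout: the mutable state is one flat F32 buffer holding five vectors of n_embed floats per layer. A sketch of the offset computation that the F32_SIZE substitutions make explicit (the helper name is illustrative):

    #include <stddef.h>

    #define F32_SIZE 4  // as introduced in the first hunk

    // Byte offset of state part `part` (0..4) of layer `layer` inside the
    // flat state buffer of 5 * n_layer * n_embed floats.
    static size_t state_part_offset(int layer, int part, int n_embed) {
        return (size_t) (5 * layer + part) * n_embed * F32_SIZE;
    }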
@@ -400,10 +403,9 @@ int main(int argc, char ** argv) {
     // --- Evaluate model ---
 
     // x = self.w.emb.weight[token]
-    // TODO Replace with ggml_get_rows or similar
-    struct ggml_tensor * one_hot = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_vocab, 1);
-    ggml_set_f32_1d(one_hot, token, 1.0F);
-    struct ggml_tensor * x = ggml_mul_mat(ctx, model.emb, one_hot);
+    struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+    ggml_set_i32_1d(token_index, 0, token);
+    struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, token_index);
 
     // x = self.layer_norm(x, self.w.blocks[0].ln0)
     x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);
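This hunk is the heart of the commit: instead of materializing an n_vocab-long one-hot vector and multiplying it against the embedding matrix, the lookup is now a direct row fetch, O(n_embed) rather than O(n_vocab * n_embed). A self-contained sketch of the new pattern (the wrapper name is illustrative):

    #include <stdint.h>
    #include "ggml.h"

    // Embedding lookup via ggml_get_rows: selects row `token` of `emb`,
    // where emb has ne[0] == n_embed and ne[1] == n_vocab. The index
    // tensor must be I32; here it holds a single token id.
    static struct ggml_tensor * lookup_embedding(struct ggml_context * ctx,
                                                 struct ggml_tensor * emb,
                                                 int32_t token) {
        struct ggml_tensor * index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
        ggml_set_i32_1d(index, 0, token);
        return ggml_get_rows(ctx, emb, index);
    }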
@@ -419,7 +421,7 @@ int main(int argc, char ** argv) {
             // self.layer_norm(x, self.w.blocks[i].ln1)
             struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);
             // state[5 * i + 1]
-            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * 4);
+            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * F32_SIZE);
             // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k)
             // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v)
             // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r)
@@ -454,9 +456,9 @@ int main(int argc, char ** argv) {
             // aa = state[5 * i + 2]
             // bb = state[5 * i + 3]
             // pp = state[5 * i + 4]
-            struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * 4);
-            struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * 4);
-            struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * 4);
+            struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * F32_SIZE);
+            struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * F32_SIZE);
+            struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * F32_SIZE);
 
             // ww = time_first + k
             struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k);
@@ -519,7 +521,7 @@ int main(int argc, char ** argv) {
             // self.layer_norm(x, self.w.blocks[i].ln2)
             struct ggml_tensor * x0 = ggml_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias);
             // state[5 * i + 0]
-            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * 4);
+            struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * F32_SIZE);
             // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
             // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
             struct ggml_tensor * xk = ggml_add(
@@ -577,17 +579,17 @@ int main(int argc, char ** argv) {
 
     // Update state
     for (int i = 0; i < n_layer * 5; i++) {
-        struct ggml_tensor * state_part_src = state_parts[i];
-        struct ggml_tensor * state_part_dest = ggml_view_1d(ctx, state, n_embed, i * n_embed * 4);
+        struct ggml_tensor * src = state_parts[i];
+        struct ggml_tensor * dest = ggml_view_1d(ctx, state, n_embed, i * n_embed * F32_SIZE);
 
         for (int j = 0; j < n_embed; j++) {
-            ggml_set_f32_1d(state_part_dest, j, ggml_get_f32_1d(state_part_src, j));
+            ggml_set_f32_1d(dest, j, ggml_get_f32_1d(src, j));
        }
     }
 
     {
         RWKV_LOG("Saving state to %s", state_out_path);
-        int32_t state_file_size = state_element_count * 4;
+        int32_t state_file_size = state_element_count * F32_SIZE;
 
         FILE * state_out_file = fopen(state_out_path, "wb");
         RWKV_ASSERT(state_out_file != NULL, "Failed to open file %s", state_out_path);
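A side note on the copy loop above (only the variable names changed in this hunk): since both src and dest appear to be contiguous F32 vectors of n_embed elements, the element-wise copy could equivalently be a single memcpy once the graph has been computed. A hedged sketch under that contiguity assumption:

    #include <string.h>
    #include "ggml.h"

    // Bulk copy of one state part: assuming both tensors are contiguous
    // F32 vectors of n_embed elements, copying the raw buffers matches
    // the ggml_get_f32_1d/ggml_set_f32_1d loop.
    static void copy_state_part(struct ggml_tensor * dest,
                                struct ggml_tensor * src,
                                int n_embed) {
        memcpy(dest->data, src->data, (size_t) n_embed * sizeof(float));
    }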
@@ -599,7 +601,7 @@ int main(int argc, char ** argv) {
 
     {
         RWKV_LOG("Saving logits to %s", logits_out_path);
-        int32_t logits_file_size = n_vocab * 4;
+        int32_t logits_file_size = n_vocab * F32_SIZE;
 
         FILE * logits_out_file = fopen(logits_out_path, "wb");
         RWKV_ASSERT(logits_out_file != NULL, "Failed to open file %s", logits_out_path);
@@ -72,7 +72,7 @@ def main() -> None:
         print(f'Actual logits: {actual_logits}')
         print('Difference per token: %.8f' % (difference,))
 
-        assert abs(difference) <= 0.00005, 'Difference is too big'
+        assert abs(difference) <= 0.000005, 'Difference is too big'
 
     # Check small token amount first to avoid waiting too long before seeing that model is broken
     compare_logits(tokens[:4])
@@ -23,6 +23,7 @@
 #   int32 key_length;
 #   // 0 if float32, 1 if float16.
 #   int32 data_type;
+#   // Same values and order as in PyTorch's tensor.shape
 #   int32[dim_count] shape;
 #   // Keys are like "emb.weight", "block.0.ln1.weight".
 #   uint8[key_length] key_utf8;
@@ -89,10 +90,6 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
             if data_type == 'float16' and len(tensor.shape) > 1:
                 tensor = tensor.half()
 
-            if k == 'emb.weight':
-                # Allows embedding matrix to be multiplied by one-hot vector
-                tensor = torch.permute(tensor, dims=[i for i in reversed(range(len(tensor.shape)))]).contiguous()
-
             shape = tensor.shape
 
             print(f'Writing {k}, shape {shape}, type {tensor.dtype}')
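The removed permute was only ever needed so a one-hot vector could be matrix-multiplied against the embedding; with ggml_get_rows indexing rows directly, the converter can write emb.weight in its native row-major [n_vocab, n_embed] layout. A sketch of why no transpose is needed (a hypothetical helper over the raw float data):

    #include <stddef.h>

    // With emb.weight stored row-major as [n_vocab][n_embed], row `token`
    // of the buffer is exactly that token's embedding, which is what
    // ggml_get_rows fetches; no transpose is required at conversion time.
    static const float * embedding_row(const float * emb_data,
                                       int n_embed, int token) {
        return emb_data + (size_t) token * n_embed;
    }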
@@ -153,10 +150,10 @@ def test() -> None:
             2,
             10,
             0,
-            2, 3,
+            3, 2,
             'emb.weight'.encode('utf-8'),
-            1.0, 3.0, 5.0,
-            2.0, 4.0, 6.0,
+            1.0, 2.0, 3.0,
+            4.0, 5.0, 6.0,
             # blocks.0.ln1.weight
             1,
             19,