Make ln0 work correctly
parent 2f51451561
commit 873cb954d0
@@ -116,6 +116,9 @@ void print_tensor(struct ggml_tensor * tensor, char * name) {
 // Prints tensor name, dimensionality, shape and part of its contents.
 #define PRINT_TENSOR(x) print_tensor(x, #x)
 
+// Same as above, but additionally computes tensor graph before printing the tensor.
+#define COMPUTE_AND_PRINT_TENSOR(ctx, x) do { compute_graph(ctx, x); print_tensor(x, #x); } while (0)
+
 // Computes value of the tensor and all tensors it depends on.
 void compute_graph(struct ggml_context * ctx, struct ggml_tensor * tensor) {
     struct ggml_cgraph graph = ggml_build_forward(tensor);
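Usage note: the new macro is meant for quick inspection while debugging. A minimal sketch follows; the tensors a, b and c here are hypothetical, not part of this commit, and a valid ggml_context * ctx is assumed.

// Hypothetical debugging snippet: build a tiny graph, then let the new macro
// compute it and print the result in one step.
struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
ggml_set_f32(a, 2.0F);
ggml_set_f32(b, 3.0F);
struct ggml_tensor * c = ggml_add(ctx, a, b);
// Expands to: do { compute_graph(ctx, c); print_tensor(c, "c"); } while (0)
COMPUTE_AND_PRINT_TENSOR(ctx, c);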
@@ -238,7 +241,7 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
             break;
     }
 
-    RWKV_ASSERT(dim_count == 1 || dim_count == 2 || dim_count == 3, "Unsupported dimension count %d", dim_count);
+    RWKV_ASSERT(dim_count == 1 || dim_count == 2, "Unsupported dimension count %d", dim_count);
 
     int32_t key_length;
     read_int32(file, &key_length);
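Note: the definition of RWKV_ASSERT is not part of this diff. One plausible shape, assuming it prints the formatted message and aborts on failure (requires <stdio.h> and <stdlib.h>):

// Hypothetical definition, consistent with how the assertion is called above.
#define RWKV_ASSERT(x, ...) \
    do { \
        if (!(x)) { \
            fprintf(stderr, "*** Assert failed: "); \
            fprintf(stderr, __VA_ARGS__); \
            fprintf(stderr, "\n"); \
            abort(); \
        } \
    } while (0)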
@@ -267,13 +270,6 @@ void load_rwkv_model(ggml_context * ctx, char * file_path, struct rwkv_model * m
         element_count = x * y;
         // Not a typo, dimensions should be reversed here
         tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y, x);
-    } else if (dim_count == 3) {
-        read_int32(file, &x);
-        read_int32(file, &y);
-        read_int32(file, &z);
-        element_count = x * y * z;
-        // Not a typo, dimensions should be reversed here
-        tensor = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, z, y, x);
     } else {
         abort();
     }
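Background on the "dimensions should be reversed" comment: ggml lists dimensions fastest-varying first (ne[0] is the contiguous one), while shapes serialized from PyTorch list the slowest-varying dimension first. A sketch of the invariant for the surviving 2-D branch, assuming <assert.h> and with x and y being the values read from the file:

// A matrix serialized with PyTorch shape (x, y) is created in ggml with the
// dimensions swapped, so the contiguous dimension ne[0] ends up equal to y.
struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y, x);
assert(t->ne[0] == y);
assert(t->ne[1] == x);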
@@ -432,23 +428,17 @@ int main(int argc, char ** argv) {
     ggml_set_f32(ones, 1.0F);
 
     // x = self.w.emb.weight[token]
-    struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
-    ggml_set_i32_1d(token_index, 0, token);
-    // TODO Is transpose copying the tensor?
-    struct ggml_tensor * x = ggml_get_rows(ctx, ggml_transpose(ctx, model.emb), token_index);
-
-    compute_graph(ctx, x);
-    // For token 123, should be [-0.1836, 0.4434, 0.3848 ... -0.4102, -0.3164, -0.1826]
-    // TODO NOT CORRECT
-    PRINT_TENSOR(x);
+    // TODO Replace with ggml_get_rows or similar
+    struct ggml_tensor * one_hot = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_vocab, 1);
+    ggml_set_f32_1d(one_hot, token, 1.0F);
+    struct ggml_tensor * x = ggml_mul_mat(ctx, model.emb, one_hot);
 
     // x = self.layer_norm(x, self.w.blocks[0].ln0)
     x = ggml_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);
 
-    compute_graph(ctx, x);
-    // For token 123, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765]
-    // TODO NOT CORRECT
-    PRINT_TENSOR(x);
+    // For token 123 after ln0, should be [-0.4194, 1.1698, 0.7798 ... -1.1838, -0.8716, -0.2765]
+    // Prints [[-0.419416 1.169782 0.779827 ... -1.183806 -0.871573 -0.276483]]
+    COMPUTE_AND_PRINT_TENSOR(ctx, x);
 
     for (int i = 0; i < n_layer; i++) {
         auto layer = model.layers[i];
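Note: the replacement lookup relies on the identity that multiplying a matrix by a one-hot vector selects a single column. A plain-C reference of what the ggml_mul_mat call above computes; the function name and the n_embed parameter are illustrative, only n_vocab appears in the diff:

// Plain-C sketch of the one-hot lookup: writes component i of the token's
// embedding into x[i]. emb holds n_embed rows, each of length n_vocab.
void one_hot_lookup(const float * emb, const float * one_hot, float * x,
                    int n_embed, int n_vocab) {
    for (int i = 0; i < n_embed; i++) {
        float sum = 0.0F;
        for (int k = 0; k < n_vocab; k++) {
            // one_hot[k] is 1.0F only at the token index, 0.0F elsewhere.
            sum += emb[i * n_vocab + k] * one_hot[k];
        }
        x[i] = sum; // equals emb[i * n_vocab + token]
    }
}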
@@ -89,8 +89,9 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
         if data_type == 'float16' and len(tensor.shape) > 1:
             tensor = tensor.half()
 
-        # ggml stores tensor values in other way than PyTorch, need to flip dimension order here
-        tensor = torch.permute(tensor, dims=[i for i in reversed(range(len(tensor.shape)))]).contiguous()
+        if k == 'emb.weight':
+            # Allows embedding matrix to be multiplied by one-hot vector
+            tensor = torch.permute(tensor, dims=[i for i in reversed(range(len(tensor.shape)))]).contiguous()
 
         shape = tensor.shape
 
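Note: with this converter change, only emb.weight is written with flipped dimensions; all other tensors keep their PyTorch order on disk and are reversed exactly once by the loader's 2-D branch above. Combined, emb.weight reaches ggml with n_vocab as its fastest-varying dimension, which is the layout the one-hot multiplication in main() expects. A sanity-check sketch; n_embed is an assumed name, not taken from the diff:

// After conversion and loading, the embedding matrix should expose n_vocab
// as ne[0] so ggml_mul_mat can consume the one-hot vector directly.
assert(model.emb->ne[0] == n_vocab);
assert(model.emb->ne[1] == n_embed);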