llama: rwkv6: Fix tensor loading for 7B/14B models
Signed-off-by: Molly Sophia <[email protected]>
MollySophia committed Aug 13, 2024
1 parent 48605aa · commit e7d35a3
Showing 1 changed file with 3 additions and 4 deletions.
src/llama.cpp
@@ -7726,10 +7726,9 @@ static bool llm_load_tensors(
 model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
 model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});

-// TODO: Parameterize this
-const int time_mix_extra_dim = 32;
-const int time_decay_extra_dim = 64;
-const int head_size = 64;
+const int time_mix_extra_dim = (n_embd == 4096) ? 64 : 32;
+const int time_decay_extra_dim = (n_embd == 4096) ? 128 : 64;
+const int head_size = hparams.wkv_head_size;
 const int attn_hidden_size = n_embd;
 const int ffn_size = (int)(n_embd * 3.5 / 32) * 32;
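
For context: the old code hardcoded dimensions that only fit the smaller RWKV6 checkpoints. The 7B/14B models (which, per the commit title, are the ones with n_embd == 4096) use wider time-mix and time-decay LoRA ranks, and the WKV head size is now read from the model's hyperparameters (hparams.wkv_head_size) instead of being assumed to be 64. Below is a minimal standalone sketch of the selection logic this commit introduces; the names rwkv6_hparams and rwkv6_print_dims are illustrative, not llama.cpp symbols, and in the real code these values come from the GGUF metadata inside llm_load_tensors.

#include <cstdint>
#include <cstdio>

// Illustrative hyperparameter struct; in llama.cpp these fields
// live in llama_hparams and are loaded from the GGUF file.
struct rwkv6_hparams {
    int64_t n_embd;        // embedding width (4096 for RWKV6 7B/14B)
    int64_t wkv_head_size; // read from the model file, no longer assumed to be 64
};

// Mirrors the dimension selection from this commit: models with
// n_embd == 4096 get wider time-mix/time-decay LoRA ranks.
static void rwkv6_print_dims(const rwkv6_hparams & hp) {
    const int time_mix_extra_dim   = (hp.n_embd == 4096) ? 64  : 32;
    const int time_decay_extra_dim = (hp.n_embd == 4096) ? 128 : 64;
    const int head_size            = (int) hp.wkv_head_size;
    const int attn_hidden_size     = (int) hp.n_embd;
    // FFN width: 3.5x the embedding width, rounded down to a multiple of 32.
    const int ffn_size = (int)(hp.n_embd * 3.5 / 32) * 32;

    std::printf("time_mix_extra_dim=%d time_decay_extra_dim=%d head_size=%d "
                "attn_hidden_size=%d ffn_size=%d\n",
                time_mix_extra_dim, time_decay_extra_dim, head_size,
                attn_hidden_size, ffn_size);
}

int main() {
    rwkv6_print_dims({4096, 64}); // e.g. an RWKV6 7B/14B checkpoint
    rwkv6_print_dims({2048, 64}); // e.g. a smaller RWKV6 variant
}

A single width check suffices here because the 7B and 14B checkpoints share the same embedding width and differ only in layer count; head_size is taken from hparams so that checkpoints with a non-default WKV head size load correctly.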
