llama: rwkv6: Add kv time_mix_extra_dim and time_decay_extra_dim
Signed-off-by: Molly Sophia <[email protected]>
MollySophia committed Aug 25, 2024
1 parent 13c6145 commit 8e2e9aa
Showing 4 changed files with 24 additions and 2 deletions.
4 changes: 4 additions & 0 deletions convert_hf_to_gguf.py
@@ -2754,6 +2754,8 @@ def set_gguf_parameters(self):
layer_norm_eps = self.hparams["layer_norm_epsilon"]
rescale_every_n_layers = self.hparams["rescale_every"]
intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else int((hidden_size * 3.5) // 32 * 32)
time_mix_extra_dim = 64 if hidden_size == 4096 else 32
time_decay_extra_dim = 128 if hidden_size == 4096 else 64

# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
@@ -2762,6 +2764,8 @@ def set_gguf_parameters(self):
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_file_type(self.ftype)

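The converter derives both extra dimensions from the hidden size instead of reading them from the HF config, then writes them into the GGUF header so the loader no longer has to hardcode the same values (see the src/llama.cpp hunk below). A minimal standalone sketch of the heuristic, with a hypothetical helper name that is not part of the diff:

def rwkv6_extra_dims(hidden_size: int) -> tuple[int, int]:
    # Mirrors the heuristic in set_gguf_parameters() above; the helper itself is illustrative.
    time_mix_extra_dim = 64 if hidden_size == 4096 else 32
    time_decay_extra_dim = 128 if hidden_size == 4096 else 64
    return time_mix_extra_dim, time_decay_extra_dim

# Only hidden_size == 4096 gets the larger pair; every other width gets the smaller one.
assert rwkv6_extra_dims(4096) == (64, 128)
assert rwkv6_extra_dims(2048) == (32, 64)
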
2 changes: 2 additions & 0 deletions gguf-py/gguf/constants.py
@@ -95,6 +95,8 @@ class LLM:
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers"
TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim"
TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim"

class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
6 changes: 6 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -673,6 +673,12 @@ def add_expert_weights_scale(self, value: float) -> None:
def add_rescale_every_n_layers(self, count: int) -> None:
self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)

def add_time_mix_extra_dim(self, dim: int) -> None:
self.add_uint32(Keys.LLM.TIME_MIX_EXTRA_DIM.format(arch=self.arch), dim)

def add_time_decay_extra_dim(self, dim: int) -> None:
self.add_uint32(Keys.LLM.TIME_DECAY_EXTRA_DIM.format(arch=self.arch), dim)

def add_wkv_head_size(self, size: int) -> None:
self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)

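Both helpers follow the pattern of the surrounding add_* methods: format the architecture into the Keys.LLM template and store the value as a uint32. A rough usage sketch, assuming "rwkv6" as the architecture name and placeholder values; not taken from the diff:

from gguf import GGUFWriter

# Illustrative only: write a metadata-only GGUF carrying the two new RWKV6 keys.
writer = GGUFWriter("out.gguf", arch="rwkv6")  # placeholder path, assumed arch name
writer.add_time_mix_extra_dim(32)              # -> "rwkv6.time_mix_extra_dim" as uint32
writer.add_time_decay_extra_dim(64)            # -> "rwkv6.time_decay_extra_dim" as uint32
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()
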
14 changes: 12 additions & 2 deletions src/llama.cpp
@@ -298,6 +298,8 @@ enum llm_kv {
LLM_KV_ATTN_LOGIT_SOFTCAPPING,
LLM_KV_FINAL_LOGIT_SOFTCAPPING,
LLM_KV_RESCALE_EVERY_N_LAYERS,
LLM_KV_TIME_MIX_EXTRA_DIM,
LLM_KV_TIME_DECAY_EXTRA_DIM,

LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -400,6 +402,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
{ LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
{ LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" },
{ LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" },
{ LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" },

{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -2296,6 +2300,8 @@ struct llama_hparams {

// for RWKV
uint32_t rescale_every_n_layers = 0;
uint32_t time_mix_extra_dim = 0;
uint32_t time_decay_extra_dim = 0;
uint32_t wkv_head_size = 0;

float rope_attn_factor = 1.0f;
@@ -2362,6 +2368,8 @@ struct llama_hparams {
if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;

if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true;
if (this->time_mix_extra_dim != other.time_mix_extra_dim) return true;
if (this->time_decay_extra_dim != other.time_decay_extra_dim) return true;
if (this->wkv_head_size != other.wkv_head_size) return true;

if (this->dec_start_token_id != other.dec_start_token_id) return true;
@@ -5909,6 +5917,8 @@ static void llm_load_hparams(
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false);

switch (hparams.n_layer) {
@@ -8364,8 +8374,8 @@ static bool llm_load_tensors(
model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});

-const int time_mix_extra_dim = (n_embd == 4096) ? 64 : 32;
-const int time_decay_extra_dim = (n_embd == 4096) ? 128 : 64;
+const int time_mix_extra_dim = hparams.time_mix_extra_dim;
+const int time_decay_extra_dim = hparams.time_decay_extra_dim;
const int head_size = hparams.wkv_head_size;
const int attn_hidden_size = n_embd;
const int ffn_size = hparams.n_ff_arr[0];
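
Since llm_load_hparams reads both keys as required (no trailing false passed to ml.get_key), files converted before this change would lack them and would need to be re-converted. A quick way to confirm a converted file carries the keys is to read the metadata back with gguf-py; a sketch, assuming "rwkv6" as the arch prefix and a placeholder file name:

from gguf import GGUFReader

reader = GGUFReader("rwkv6-model.gguf")  # placeholder path
for key in ("rwkv6.time_mix_extra_dim", "rwkv6.time_decay_extra_dim"):
    field = reader.get_field(key)
    if field is None:
        raise KeyError(f"missing metadata key: {key}")
    # For a scalar field, the last data part holds the single uint32 value.
    print(key, "=", int(field.parts[field.data[-1]][0]))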
