Skip to content

Commit

Permalink
Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV
Browse files Browse the repository at this point in the history
Signed-off-by: Molly Sophia <[email protected]>
  • Loading branch information
MollySophia committed Aug 12, 2024
1 parent e6f08bc commit d1931ea
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 19 deletions.
2 changes: 1 addition & 1 deletion convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2705,7 +2705,7 @@ class StarCoder2Model(Model):

@Model.register("Rwkv6ForCausalLM")
class RwkvModel(Model):
model_arch = gguf.MODEL_ARCH.RWKV
model_arch = gguf.MODEL_ARCH.RWKV6

def set_vocab(self):
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
Expand Down
6 changes: 3 additions & 3 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ class MODEL_ARCH(IntEnum):
GEMMA = auto()
GEMMA2 = auto()
STARCODER2 = auto()
RWKV = auto()
RWKV6 = auto()
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
Expand Down Expand Up @@ -362,7 +362,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV: "rwkv",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
Expand Down Expand Up @@ -903,7 +903,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.RWKV: [
MODEL_ARCH.RWKV6: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down
30 changes: 15 additions & 15 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ enum llm_arch {
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
LLM_ARCH_JAIS,
LLM_ARCH_RWKV,
LLM_ARCH_RWKV6,
LLM_ARCH_UNKNOWN,
};

Expand Down Expand Up @@ -256,7 +256,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_RWKV, "rwkv" },
{ LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

Expand Down Expand Up @@ -1328,7 +1328,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
},
},
{
LLM_ARCH_RWKV,
LLM_ARCH_RWKV6,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
Expand Down Expand Up @@ -3052,7 +3052,7 @@ static bool llama_kv_cache_init(
cache.has_shift = false;

// TODO: find a nicer way to add other recurrent model architectures
cache.recurrent = model.arch == LLM_ARCH_MAMBA || model.arch == LLM_ARCH_RWKV;
cache.recurrent = model.arch == LLM_ARCH_MAMBA || model.arch == LLM_ARCH_RWKV6;
cache.v_trans = !cache.recurrent && !cparams.flash_attn;

cache.head = 0;
Expand Down Expand Up @@ -5353,7 +5353,7 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_RWKV:
case LLM_ARCH_RWKV6:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
Expand Down Expand Up @@ -7705,7 +7705,7 @@ static bool llm_load_tensors(
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
}
} break;
case LLM_ARCH_RWKV:
case LLM_ARCH_RWKV6:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

Expand Down Expand Up @@ -8560,7 +8560,7 @@ static struct ggml_tensor * llm_build_kv(
}


static struct ggml_tensor * llm_build_time_mix(
static struct ggml_tensor * llm_build_time_mix_rwkv6(
struct ggml_context * ctx,
const struct llama_layer * layer,
struct ggml_tensor * cur,
Expand Down Expand Up @@ -8721,7 +8721,7 @@ static struct ggml_tensor * llm_build_time_mix(
return ggml_mul_mat(ctx, layer->time_mix_output, cur);
}

static struct ggml_tensor * llm_build_channel_mix(
static struct ggml_tensor * llm_build_channel_mix_rwkv6(
struct ggml_context * ctx,
const struct llama_layer * layer,
struct ggml_tensor * cur,
Expand Down Expand Up @@ -14139,7 +14139,7 @@ struct llm_build_context {
return gf;
}

ggml_cgraph * build_rwkv() {
ggml_cgraph * build_rwkv6() {
ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

// Token shift state dimensions should be 2 * n_emb
Expand Down Expand Up @@ -14187,7 +14187,7 @@ struct llm_build_context {
n_embd, n_tokens
);

cur = ggml_add(ctx0, cur, llm_build_time_mix(ctx0, layer, x_norm, x_prev, &wkv_states, state_seq));
cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(ctx0, layer, x_norm, x_prev, &wkv_states, state_seq));
ggml_build_forward_expand(gf, cur);
ggml_build_forward_expand(
gf,
Expand Down Expand Up @@ -14223,7 +14223,7 @@ struct llm_build_context {
ggml_view_1d(ctx0, tmp, n_embd * n_tokens, 0),
n_embd, n_tokens
);
cur = ggml_add(ctx0, cur, llm_build_channel_mix(ctx0, layer, x_norm, x_prev));
cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(ctx0, layer, x_norm, x_prev));
ggml_build_forward_expand(gf, cur);
ggml_build_forward_expand(
gf,
Expand Down Expand Up @@ -14528,9 +14528,9 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_jais();
} break;
case LLM_ARCH_RWKV:
case LLM_ARCH_RWKV6:
{
result = llm.build_rwkv();
result = llm.build_rwkv6();
} break;
default:
GGML_ABORT("fatal error");
Expand Down Expand Up @@ -17255,7 +17255,7 @@ struct llama_context * llama_new_context_with_model(
ggml_type type_v = params.type_v;

// Mamba and RWKV only need a constant number of KV cache cells per sequence
if (model->arch == LLM_ARCH_MAMBA || model->arch == LLM_ARCH_RWKV) {
if (model->arch == LLM_ARCH_MAMBA || model->arch == LLM_ARCH_RWKV6) {
// Mamba and RWKV need at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
Expand Down Expand Up @@ -17565,7 +17565,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_T5:
case LLM_ARCH_T5ENCODER:
case LLM_ARCH_JAIS:
case LLM_ARCH_RWKV:
case LLM_ARCH_RWKV6:
return LLAMA_ROPE_TYPE_NONE;

// use what we call a normal RoPE, operating on pairs of consecutive head values
Expand Down

0 comments on commit d1931ea

Please sign in to comment.