Skip to content

Commit

Permalink
Convert layer norm weights
Browse files — browse the repository at this point in the history
  • Loading branch information
2015aroras committed Oct 29, 2024
1 parent f8349e7 commit 08948af
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/transformers/models/olmo/convert_olmo_weights_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def write_model(
f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight,
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"],
f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight,
f"model.layers.{layer_i}.input_layernorm.weight": loaded.get(f"transformer.blocks.{layer_i}.attn_norm.weight"),
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded.get(f"transformer.blocks.{layer_i}.ff_norm.weight"),
}

state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
Expand All @@ -134,6 +136,7 @@ def write_model(
# TODO: Deal with weight-tying
state_dict = {
"model.embed_tokens.weight": loaded["transformer.wte.weight"],
"model.norm.weight": loaded.get("transformer.ln_f.weight"),
"lm_head.weight": loaded["transformer.ff_out.weight"]
if "transformer.ff_out.weight" in loaded
else loaded["transformer.wte.weight"],
Expand Down

0 comments on commit 08948af

Please sign in to comment.