From 08948af227f562cff7dfd86081cc76d3fab4f21f Mon Sep 17 00:00:00 2001
From: Shane A
Date: Mon, 28 Oct 2024 17:36:30 -0700
Subject: [PATCH] Convert layer norm weights

---
 src/transformers/models/olmo/convert_olmo_weights_to_hf.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/transformers/models/olmo/convert_olmo_weights_to_hf.py b/src/transformers/models/olmo/convert_olmo_weights_to_hf.py
index 73ebdfc88e8433..8801a9fe191721 100644
--- a/src/transformers/models/olmo/convert_olmo_weights_to_hf.py
+++ b/src/transformers/models/olmo/convert_olmo_weights_to_hf.py
@@ -119,6 +119,8 @@ def write_model(
             f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight,
             f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"],
             f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight,
+            f"model.layers.{layer_i}.input_layernorm.weight": loaded.get(f"transformer.blocks.{layer_i}.attn_norm.weight"),
+            f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded.get(f"transformer.blocks.{layer_i}.ff_norm.weight"),
         }
 
         state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
@@ -134,6 +136,7 @@ def write_model(
     # TODO: Deal with weight-tying
     state_dict = {
         "model.embed_tokens.weight": loaded["transformer.wte.weight"],
+        "model.norm.weight": loaded.get("transformer.ln_f.weight"),
         "lm_head.weight": loaded["transformer.ff_out.weight"]
         if "transformer.ff_out.weight" in loaded
         else loaded["transformer.wte.weight"],
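
For reference, the added lines copy the OLMo checkpoint's layer norm tensors onto the matching Hugging Face parameter names, using dict.get so a missing key yields None rather than raising KeyError. A minimal standalone sketch of that mapping, with the helper names olmo_to_hf_norm_keys and convert_norm_weights invented purely for illustration, might look like this:

    # Illustrative sketch only; the helper names below are not part of the patch.
    def olmo_to_hf_norm_keys(layer_i: int) -> dict:
        """Map per-layer OLMo layer norm tensor names to their HF counterparts."""
        return {
            f"transformer.blocks.{layer_i}.attn_norm.weight": f"model.layers.{layer_i}.input_layernorm.weight",
            f"transformer.blocks.{layer_i}.ff_norm.weight": f"model.layers.{layer_i}.post_attention_layernorm.weight",
        }

    def convert_norm_weights(loaded: dict, n_layers: int) -> dict:
        """Collect layer norm weights from an OLMo checkpoint dict, if present."""
        state_dict = {}
        for layer_i in range(n_layers):
            for olmo_key, hf_key in olmo_to_hf_norm_keys(layer_i).items():
                # dict.get returns None when the checkpoint carries no affine norm weights
                state_dict[hf_key] = loaded.get(olmo_key)
        # Final norm applied before the LM head
        state_dict["model.norm.weight"] = loaded.get("transformer.ln_f.weight")
        return state_dict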