Commit

Run make style
2015aroras committed Oct 29, 2024
1 parent 02b7158 commit 30844ec
Showing 1 changed file with 15 additions and 7 deletions.
src/transformers/models/olmo/convert_olmo_weights_to_hf.py (22 changes: 15 additions & 7 deletions)
@@ -22,9 +22,11 @@
 import torch
 import yaml
 from tokenizers import Tokenizer
+
 from transformers import OlmoConfig, OlmoForCausalLM
 from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast
 
+
 """
 Sample usage:
 ```
@@ -113,16 +115,22 @@ def write_model(
             f"model.layers.{layer_i}.self_attn.q_proj.weight": q_proj_weight,
             f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight,
             f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight,
-            f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[
-                f"transformer.blocks.{layer_i}.attn_out.weight"
-            ],
-            f"model.layers.{layer_i}.self_attn.q_norm.weight": loaded.get(f"transformer.blocks.{layer_i}.q_norm.weight"),
-            f"model.layers.{layer_i}.self_attn.k_norm.weight": loaded.get(f"transformer.blocks.{layer_i}.k_norm.weight"),
+            f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.blocks.{layer_i}.attn_out.weight"],
+            f"model.layers.{layer_i}.self_attn.q_norm.weight": loaded.get(
+                f"transformer.blocks.{layer_i}.q_norm.weight"
+            ),
+            f"model.layers.{layer_i}.self_attn.k_norm.weight": loaded.get(
+                f"transformer.blocks.{layer_i}.k_norm.weight"
+            ),
             f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight,
             f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"],
             f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight,
-            f"model.layers.{layer_i}.input_layernorm.weight": loaded.get(f"transformer.blocks.{layer_i}.attn_norm.weight"),
-            f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded.get(f"transformer.blocks.{layer_i}.ff_norm.weight"),
+            f"model.layers.{layer_i}.input_layernorm.weight": loaded.get(
+                f"transformer.blocks.{layer_i}.attn_norm.weight"
+            ),
+            f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded.get(
+                f"transformer.blocks.{layer_i}.ff_norm.weight"
+            ),
         }
 
         state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
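Note (not part of the commit): the dictionary touched by this diff is the per-layer key remapping from the original OLMo checkpoint layout (transformer.blocks.{i}.*) to the Hugging Face layout (model.layers.{i}.*). The sketch below restates that mapping as a standalone helper for readability; the helper name and its tensor arguments are assumptions for illustration, loaded stands for the raw OLMo state dict, and the q/k/v and gate/up projection tensors are assumed to have been split out earlier in the conversion script.

# Minimal sketch (assumption-based), mirroring the mapping shown in the diff above.
def map_layer_keys(loaded, layer_i, q_proj_weight, k_proj_weight, v_proj_weight, gate_proj_weight, up_proj_weight):
    hf = f"model.layers.{layer_i}"            # Hugging Face key prefix
    olmo = f"transformer.blocks.{layer_i}"    # original OLMo key prefix
    return {
        f"{hf}.self_attn.q_proj.weight": q_proj_weight,
        f"{hf}.self_attn.k_proj.weight": k_proj_weight,
        f"{hf}.self_attn.v_proj.weight": v_proj_weight,
        f"{hf}.self_attn.o_proj.weight": loaded[f"{olmo}.attn_out.weight"],
        # .get() returns None for checkpoints that do not carry QK-norm weights.
        f"{hf}.self_attn.q_norm.weight": loaded.get(f"{olmo}.q_norm.weight"),
        f"{hf}.self_attn.k_norm.weight": loaded.get(f"{olmo}.k_norm.weight"),
        f"{hf}.mlp.gate_proj.weight": gate_proj_weight,
        f"{hf}.mlp.down_proj.weight": loaded[f"{olmo}.ff_out.weight"],
        f"{hf}.mlp.up_proj.weight": up_proj_weight,
        f"{hf}.input_layernorm.weight": loaded.get(f"{olmo}.attn_norm.weight"),
        f"{hf}.post_attention_layernorm.weight": loaded.get(f"{olmo}.ff_norm.weight"),
    }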