Skip to content

Commit

Permalink
Reimplement OLMoE changes.
Browse files Browse the repository at this point in the history
Originally from TransformerLensOrg#718.
  • Loading branch information
joelburget committed Dec 14, 2024
1 parent 1b34ccd commit f0a0a68
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 1 deletion.
3 changes: 2 additions & 1 deletion transformer_lens/components/mlps/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def forward(
# both are [batch, pos, experts_per_token]
weights = self.hook_expert_weights(F.softmax(gate_logits, dim=1, dtype=torch.float))
weights, expert_indices = torch.topk(weights, self.experts_per_token, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)
if self.cfg.original_architecture != "OlmoeForCausalLM":
weights /= weights.sum(dim=-1, keepdim=True)
expert_indices = self.hook_expert_indices(expert_indices)
weights = weights.to(x.dtype)

Expand Down
34 changes: 34 additions & 0 deletions transformer_lens/loading_from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
convert_neo_weights,
convert_neox_weights,
convert_olmo_weights,
convert_olmoe_weights,
convert_opt_weights,
convert_phi3_weights,
convert_phi_weights,
Expand Down Expand Up @@ -245,6 +246,9 @@
"allenai/OLMo-1B-0724-hf",
"allenai/OLMo-7B-Instruct-hf",
"allenai/OLMo-7B-SFT-hf",
"allenai/OLMoE-1B-7B-0924",
"allenai/OLMoE-1B-7B-0924-SFT",
"allenai/OLMoE-1B-7B-0924-Instruct",
]
"""Official model names for models on HuggingFace."""

Expand Down Expand Up @@ -1469,6 +1473,34 @@ def convert_hf_model_config(model_name: str, **kwargs):
"positional_embedding_type": "rotary",
"gated_mlp": True,
}
elif architecture == "OlmoeForCausalLM":
cfg_dict = {
"d_model": hf_config.hidden_size,
"d_head": hf_config.hidden_size // hf_config.num_attention_heads,
"n_heads": hf_config.num_attention_heads,
"d_mlp": hf_config.intermediate_size,
"n_layers": hf_config.num_hidden_layers,
"n_ctx": hf_config.max_position_embeddings,
"eps": hf_config.rms_norm_eps,
"d_vocab": hf_config.vocab_size,
"act_fn": hf_config.hidden_act,
"num_experts": hf_config.num_experts,
"experts_per_token": hf_config.num_experts_per_tok,
# TODO: implement!
# "router_aux_loss_coef": hf_config.router_aux_loss_coef,
# "router_z_loss_coef": hf_config.router_z_loss_coef,
# "norm_topk_prob": hf_config.norm_topk_prob,
# end
"n_key_value_heads": hf_config.num_key_value_heads,
"rotary_base": hf_config.rope_theta,
"tie_word_embeddings": hf_config.tie_word_embeddings,
"initializer_range": hf_config.initializer_range,
"positional_embedding_type": "rotary",
"rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
"final_rms": True,
"gated_mlp": True,
"normalization_type": "RMS",
}
elif architecture == "T5ForConditionalGeneration":
cfg_dict = {
"d_model": hf_config.d_model,
Expand Down Expand Up @@ -1889,6 +1921,8 @@ def get_pretrained_state_dict(
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "OlmoForCausalLM":
state_dict = convert_olmo_weights(hf_model, cfg)
elif cfg.original_architecture == "OlmoeForCausalLM":
state_dict = convert_olmoe_weights(hf_model, cfg)
else:
raise ValueError(
f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
Expand Down
1 change: 1 addition & 0 deletions transformer_lens/pretrained/weight_conversions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from .t5 import convert_t5_weights
from .neel_solu_old import convert_neel_solu_old_weights
from .olmo import convert_olmo_weights
from .olmoe import convert_olmoe_weights
64 changes: 64 additions & 0 deletions transformer_lens/pretrained/weight_conversions/olmoe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import einops
import torch

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def convert_olmoe_weights(olmoe, cfg: HookedTransformerConfig):
state_dict = {}

assert cfg.n_key_value_heads is not None
assert cfg.d_mlp is not None
assert cfg.num_experts is not None

state_dict["embed.W_E"] = olmoe.model.embed_tokens.weight

for l in range(cfg.n_layers):
olmoe_layer = olmoe.model.layers[l]
state_dict[f"blocks.{l}.ln1.w"] = olmoe_layer.input_layernorm.weight

W_Q = olmoe.model.layers[l].self_attn.q_proj.weight
W_K = olmoe.model.layers[l].self_attn.k_proj.weight
W_V = olmoe.model.layers[l].self_attn.v_proj.weight
W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
state_dict[f"blocks.{l}.attn._W_K"] = W_K
state_dict[f"blocks.{l}.attn._W_V"] = W_V

state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
)
state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
)

W_O = olmoe_layer.self_attn.o_proj.weight
W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
state_dict[f"blocks.{l}.attn.W_O"] = W_O

state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

state_dict[f"blocks.{l}.ln2.w"] = olmoe_layer.post_attention_layernorm.weight

state_dict[f"blocks.{l}.mlp.W_gate.weight"] = olmoe_layer.mlp.gate.weight

for e in range(cfg.num_experts):
state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = olmoe_layer.mlp.experts[
e
].up_proj.weight
state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = olmoe_layer.mlp.experts[
e
].gate_proj.weight
state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = olmoe_layer.mlp.experts[
e
].down_proj.weight

state_dict["ln_final.w"] = olmoe.model.norm.weight

state_dict["unembed.W_U"] = olmoe.lm_head.weight.T
state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)

return state_dict

0 comments on commit f0a0a68

Please sign in to comment.