From f0a0a6872840b33102be61dd94469352f51715df Mon Sep 17 00:00:00 2001
From: Joel Burget
Date: Fri, 13 Dec 2024 21:20:10 -0800
Subject: [PATCH] Reimplement OLMoE changes. Originally from
 https://github.com/TransformerLensOrg/TransformerLens/pull/718.

---
 transformer_lens/components/mlps/moe.py       |  3 +-
 transformer_lens/loading_from_pretrained.py   | 34 ++++++++++
 .../pretrained/weight_conversions/__init__.py |  1 +
 .../pretrained/weight_conversions/olmoe.py    | 64 +++++++++++++++++++
 4 files changed, 101 insertions(+), 1 deletion(-)
 create mode 100644 transformer_lens/pretrained/weight_conversions/olmoe.py

diff --git a/transformer_lens/components/mlps/moe.py b/transformer_lens/components/mlps/moe.py
index e01f25ee9..6354108dc 100644
--- a/transformer_lens/components/mlps/moe.py
+++ b/transformer_lens/components/mlps/moe.py
@@ -88,7 +88,8 @@ def forward(
         # both are [batch, pos, experts_per_token]
         weights = self.hook_expert_weights(F.softmax(gate_logits, dim=1, dtype=torch.float))
         weights, expert_indices = torch.topk(weights, self.experts_per_token, dim=-1)
-        weights /= weights.sum(dim=-1, keepdim=True)
+        if self.cfg.original_architecture != "OlmoeForCausalLM":
+            weights /= weights.sum(dim=-1, keepdim=True)
         expert_indices = self.hook_expert_indices(expert_indices)
         weights = weights.to(x.dtype)
 
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 43282ad6a..21e0f98b9 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -36,6 +36,7 @@
     convert_neo_weights,
     convert_neox_weights,
     convert_olmo_weights,
+    convert_olmoe_weights,
     convert_opt_weights,
     convert_phi3_weights,
     convert_phi_weights,
@@ -245,6 +246,9 @@
     "allenai/OLMo-1B-0724-hf",
     "allenai/OLMo-7B-Instruct-hf",
     "allenai/OLMo-7B-SFT-hf",
+    "allenai/OLMoE-1B-7B-0924",
+    "allenai/OLMoE-1B-7B-0924-SFT",
+    "allenai/OLMoE-1B-7B-0924-Instruct",
 ]
 """Official model names for models on HuggingFace."""
 
@@ -1469,6 +1473,34 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "positional_embedding_type": "rotary",
             "gated_mlp": True,
         }
+    elif architecture == "OlmoeForCausalLM":
+        cfg_dict = {
+            "d_model": hf_config.hidden_size,
+            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
+            "n_heads": hf_config.num_attention_heads,
+            "d_mlp": hf_config.intermediate_size,
+            "n_layers": hf_config.num_hidden_layers,
+            "n_ctx": hf_config.max_position_embeddings,
+            "eps": hf_config.rms_norm_eps,
+            "d_vocab": hf_config.vocab_size,
+            "act_fn": hf_config.hidden_act,
+            "num_experts": hf_config.num_experts,
+            "experts_per_token": hf_config.num_experts_per_tok,
+            # TODO: implement!
+ # "router_aux_loss_coef": hf_config.router_aux_loss_coef, + # "router_z_loss_coef": hf_config.router_z_loss_coef, + # "norm_topk_prob": hf_config.norm_topk_prob, + # end + "n_key_value_heads": hf_config.num_key_value_heads, + "rotary_base": hf_config.rope_theta, + "tie_word_embeddings": hf_config.tie_word_embeddings, + "initializer_range": hf_config.initializer_range, + "positional_embedding_type": "rotary", + "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, + "final_rms": True, + "gated_mlp": True, + "normalization_type": "RMS", + } elif architecture == "T5ForConditionalGeneration": cfg_dict = { "d_model": hf_config.d_model, @@ -1889,6 +1921,8 @@ def get_pretrained_state_dict( state_dict = convert_gemma_weights(hf_model, cfg) elif cfg.original_architecture == "OlmoForCausalLM": state_dict = convert_olmo_weights(hf_model, cfg) + elif cfg.original_architecture == "OlmoeForCausalLM": + state_dict = convert_olmoe_weights(hf_model, cfg) else: raise ValueError( f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py index 8f942e46d..bb2146832 100644 --- a/transformer_lens/pretrained/weight_conversions/__init__.py +++ b/transformer_lens/pretrained/weight_conversions/__init__.py @@ -19,3 +19,4 @@ from .t5 import convert_t5_weights from .neel_solu_old import convert_neel_solu_old_weights from .olmo import convert_olmo_weights +from .olmoe import convert_olmoe_weights diff --git a/transformer_lens/pretrained/weight_conversions/olmoe.py b/transformer_lens/pretrained/weight_conversions/olmoe.py new file mode 100644 index 000000000..02dc1972e --- /dev/null +++ b/transformer_lens/pretrained/weight_conversions/olmoe.py @@ -0,0 +1,64 @@ +import einops +import torch + +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + + +def convert_olmoe_weights(olmoe, cfg: HookedTransformerConfig): + state_dict = {} + + assert cfg.n_key_value_heads is not None + assert cfg.d_mlp is not None + assert cfg.num_experts is not None + + state_dict["embed.W_E"] = olmoe.model.embed_tokens.weight + + for l in range(cfg.n_layers): + olmoe_layer = olmoe.model.layers[l] + state_dict[f"blocks.{l}.ln1.w"] = olmoe_layer.input_layernorm.weight + + W_Q = olmoe.model.layers[l].self_attn.q_proj.weight + W_K = olmoe.model.layers[l].self_attn.k_proj.weight + W_V = olmoe.model.layers[l].self_attn.v_proj.weight + W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) + W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads) + W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads) + state_dict[f"blocks.{l}.attn.W_Q"] = W_Q + state_dict[f"blocks.{l}.attn._W_K"] = W_K + state_dict[f"blocks.{l}.attn._W_V"] = W_V + + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( + cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype + ) + state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros( + cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype + ) + + W_O = olmoe_layer.self_attn.o_proj.weight + W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) + state_dict[f"blocks.{l}.attn.W_O"] = W_O + + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + 
state_dict[f"blocks.{l}.ln2.w"] = olmoe_layer.post_attention_layernorm.weight + + state_dict[f"blocks.{l}.mlp.W_gate.weight"] = olmoe_layer.mlp.gate.weight + + for e in range(cfg.num_experts): + state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = olmoe_layer.mlp.experts[ + e + ].up_proj.weight + state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = olmoe_layer.mlp.experts[ + e + ].gate_proj.weight + state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = olmoe_layer.mlp.experts[ + e + ].down_proj.weight + + state_dict["ln_final.w"] = olmoe.model.norm.weight + + state_dict["unembed.W_U"] = olmoe.lm_head.weight.T + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) + + return state_dict