Add OLMoE #1

Merged · 5 commits · Dec 15, 2024
12 changes: 9 additions & 3 deletions transformer_lens/HookedTransformer.py
@@ -154,14 +154,17 @@ def __init__(
if "phi" in self.cfg.tokenizer_name.lower():
use_fast = False
huggingface_token = os.environ.get("HF_TOKEN", None)
add_bos_token = False if self.cfg.original_architecture == "OlmoForCausalLM" else True
add_bos_token = self.cfg.original_architecture not in [
"OlmoForCausalLM",
"OlmoeForCausalLM",
]
self.set_tokenizer(
AutoTokenizer.from_pretrained(
self.cfg.tokenizer_name,
trust_remote_code=self.cfg.trust_remote_code,
use_fast=use_fast,
token=huggingface_token,
add_bos_token=add_bos_token
add_bos_token=add_bos_token,
),
default_padding_side=default_padding_side,
)
@@ -689,7 +692,10 @@ def set_tokenizer(
# tokenizers like LlamaTokenizer are different when bos token is automatically/manually
# prepended, and add_bos_token cannot be dynamically controlled after initialization
# (https://github.com/huggingface/transformers/issues/25886).
if self.cfg.original_architecture != "OlmoForCausalLM":
if self.cfg.original_architecture not in [
"OlmoForCausalLM",
"OlmoeForCausalLM",
]:
tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
else:
tokenizer_with_bos = tokenizer
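For context, a minimal usage sketch of what this tokenizer change affects. The model name is one of the OLMoE checkpoints registered later in this PR; `HookedTransformer.from_pretrained` is the standard TransformerLens entry point, and this snippet is an illustration rather than part of the diff:

```python
from transformer_lens import HookedTransformer

# One of the OLMoE checkpoints added to OFFICIAL_MODEL_NAMES in this PR.
model = HookedTransformer.from_pretrained("allenai/OLMoE-1B-7B-0924")

# The underlying HF tokenizer is now constructed with add_bos_token=False for
# OLMoE (as for OLMo); BOS prepending is managed by TransformerLens itself.
tokens = model.to_tokens("Mixture-of-experts models route each token to a few experts.")
logits = model(tokens)
print(tokens.shape, logits.shape)
```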
4 changes: 2 additions & 2 deletions transformer_lens/HookedTransformerConfig.py
@@ -192,8 +192,7 @@ class HookedTransformerConfig:
NTK_by_parts_factor (float): The overall factor used in the "NTK-by-parts" method that
affects the rate of change between low and high-frequency interpolation strategies.
Defaults to 8.0.


norm_topk_prob (bool): Whether to normalize the top-k probabilities in the MoE layer.
"""

n_layers: int
@@ -262,6 +261,7 @@ class HookedTransformerConfig:
NTK_by_parts_low_freq_factor: float = 1.0
NTK_by_parts_high_freq_factor: float = 4.0
NTK_by_parts_factor: float = 8.0
norm_topk_prob: bool = False

def __post_init__(self):
if self.n_heads == -1:
28 changes: 28 additions & 0 deletions transformer_lens/components/abstract_attention.py
@@ -16,6 +16,7 @@
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear
from transformer_lens.utils import get_offset_position_ids
from transformer_lens.components.rms_norm import RMSNorm

if is_bitsandbytes_available():
import bitsandbytes as bnb
@@ -140,6 +141,10 @@ def __init__(
# will be overwritten by the child T5Attention class
self.has_relative_attention_bias = False

if self.cfg.original_architecture == "OlmoeForCausalLM":
self.q_norm = RMSNorm(cfg, cfg.d_model)
self.k_norm = RMSNorm(cfg, cfg.d_head * cfg.n_key_value_heads)

@property
def OV(self) -> FactoredMatrix:
"""
@@ -195,6 +200,29 @@ def forward(

q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

# OLMoE uses QK-norm.
if self.cfg.original_architecture == "OlmoeForCausalLM":
q = einops.rearrange(
self.q_norm(
einops.rearrange(
q,
"batch pos head_index d_head -> batch pos (head_index d_head)",
)
),
"batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head",
head_index=q.shape[2],
)
k = einops.rearrange(
self.k_norm(
einops.rearrange(
k,
"batch pos head_index d_head -> batch pos (head_index d_head)",
)
),
"batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head",
head_index=k.shape[2],
)

if past_kv_cache_entry is not None:
# Appends the new keys and values to the cached values, and automatically updates the cache
kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
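To make the reshaping around the new QK-norm easier to follow, here is a self-contained sketch of the same flatten → normalise → unflatten round-trip on a dummy query tensor. A bare-bones functional RMSNorm stands in for `transformer_lens.components.rms_norm.RMSNorm`, and the shapes are illustrative:

```python
import einops
import torch


def rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Plain RMSNorm: divide by the root-mean-square, then scale by a learned weight.
    scale = (x.pow(2).mean(-1, keepdim=True) + eps).rsqrt()
    return x * scale * w


batch, pos, n_heads, d_head = 2, 5, 16, 64
q = torch.randn(batch, pos, n_heads, d_head)
q_norm_w = torch.ones(n_heads * d_head)  # stands in for q_norm.w

# OLMoE applies q_norm/k_norm over the flattened (head_index d_head) dimension,
# so the per-head tensor is flattened, normalised, and reshaped back.
q_flat = einops.rearrange(q, "batch pos head_index d_head -> batch pos (head_index d_head)")
q_normed = einops.rearrange(
    rms_norm(q_flat, q_norm_w),
    "batch pos (head_index d_head) -> batch pos head_index d_head",
    head_index=n_heads,
)
print(q_normed.shape)  # torch.Size([2, 5, 16, 64])
```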
3 changes: 2 additions & 1 deletion transformer_lens/components/mlps/moe.py
@@ -88,7 +88,8 @@ def forward(
# both are [batch, pos, experts_per_token]
weights = self.hook_expert_weights(F.softmax(gate_logits, dim=1, dtype=torch.float))
weights, expert_indices = torch.topk(weights, self.experts_per_token, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)
if self.cfg.norm_topk_prob:
weights /= weights.sum(dim=-1, keepdim=True)
expert_indices = self.hook_expert_indices(expert_indices)
weights = weights.to(x.dtype)

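A standalone sketch of the routing step that this flag gates, on dummy gate logits. The softmax here is taken over the expert dimension for clarity; only the conditional renormalisation mirrors the change above:

```python
import torch
import torch.nn.functional as F

num_experts, experts_per_token = 8, 2
norm_topk_prob = True  # corresponds to cfg.norm_topk_prob

gate_logits = torch.randn(2, 5, num_experts)  # [batch, pos, num_experts]

# Softmax over the expert dimension, then keep the top-k experts per token.
weights = F.softmax(gate_logits, dim=-1, dtype=torch.float)
weights, expert_indices = torch.topk(weights, experts_per_token, dim=-1)

# With norm_topk_prob the surviving k weights are rescaled to sum to 1;
# without it they keep their raw softmax mass (which sums to less than 1).
if norm_topk_prob:
    weights /= weights.sum(dim=-1, keepdim=True)

print(weights.sum(dim=-1))  # all ones when norm_topk_prob is True
```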
30 changes: 30 additions & 0 deletions transformer_lens/loading_from_pretrained.py
@@ -36,6 +36,7 @@
convert_neo_weights,
convert_neox_weights,
convert_olmo_weights,
convert_olmoe_weights,
convert_opt_weights,
convert_phi3_weights,
convert_phi_weights,
@@ -245,6 +246,9 @@
"allenai/OLMo-1B-0724-hf",
"allenai/OLMo-7B-Instruct-hf",
"allenai/OLMo-7B-SFT-hf",
"allenai/OLMoE-1B-7B-0924",
"allenai/OLMoE-1B-7B-0924-SFT",
"allenai/OLMoE-1B-7B-0924-Instruct",
]
"""Official model names for models on HuggingFace."""

@@ -1469,6 +1473,30 @@ def convert_hf_model_config(model_name: str, **kwargs):
"positional_embedding_type": "rotary",
"gated_mlp": True,
}
elif architecture == "OlmoeForCausalLM":
cfg_dict = {
"d_model": hf_config.hidden_size,
"d_head": hf_config.hidden_size // hf_config.num_attention_heads,
"n_heads": hf_config.num_attention_heads,
"d_mlp": hf_config.intermediate_size,
"n_layers": hf_config.num_hidden_layers,
"n_ctx": hf_config.max_position_embeddings,
"eps": hf_config.rms_norm_eps,
"d_vocab": hf_config.vocab_size,
"act_fn": hf_config.hidden_act,
"num_experts": hf_config.num_experts,
"experts_per_token": hf_config.num_experts_per_tok,
"norm_topk_prob": hf_config.norm_topk_prob,
"n_key_value_heads": hf_config.num_key_value_heads,
"rotary_base": hf_config.rope_theta,
"tie_word_embeddings": hf_config.tie_word_embeddings,
"initializer_range": hf_config.initializer_range,
"positional_embedding_type": "rotary",
"rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
"final_rms": True,
"gated_mlp": True,
"normalization_type": "LN",
}
elif architecture == "T5ForConditionalGeneration":
cfg_dict = {
"d_model": hf_config.d_model,
@@ -1889,6 +1917,8 @@ def get_pretrained_state_dict(
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "OlmoForCausalLM":
state_dict = convert_olmo_weights(hf_model, cfg)
elif cfg.original_architecture == "OlmoeForCausalLM":
state_dict = convert_olmoe_weights(hf_model, cfg)
else:
raise ValueError(
f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
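For reference, the Hugging Face config fields that the new `OlmoeForCausalLM` branch reads can be inspected directly with transformers' `AutoConfig` (assuming a transformers version that ships the OLMoE architecture):

```python
from transformers import AutoConfig

# Download only the config; the fields below are the ones the new branch maps
# onto HookedTransformerConfig entries.
hf_config = AutoConfig.from_pretrained("allenai/OLMoE-1B-7B-0924")

for field in [
    "hidden_size",
    "num_attention_heads",
    "intermediate_size",
    "num_hidden_layers",
    "num_experts",
    "num_experts_per_tok",
    "norm_topk_prob",
    "num_key_value_heads",
    "rope_theta",
]:
    print(field, getattr(hf_config, field))
```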
1 change: 1 addition & 0 deletions transformer_lens/pretrained/weight_conversions/__init__.py
@@ -19,3 +19,4 @@
from .t5 import convert_t5_weights
from .neel_solu_old import convert_neel_solu_old_weights
from .olmo import convert_olmo_weights
from .olmoe import convert_olmoe_weights
68 changes: 68 additions & 0 deletions transformer_lens/pretrained/weight_conversions/olmoe.py
@@ -0,0 +1,68 @@
import einops
import torch

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def convert_olmoe_weights(olmoe, cfg: HookedTransformerConfig):
state_dict = {}

assert cfg.n_key_value_heads is not None
assert cfg.d_mlp is not None
assert cfg.num_experts is not None

state_dict["embed.W_E"] = olmoe.model.embed_tokens.weight

for l in range(cfg.n_layers):
olmoe_layer = olmoe.model.layers[l]
state_dict[f"blocks.{l}.ln1.w"] = olmoe_layer.input_layernorm.weight

W_Q = olmoe_layer.self_attn.q_proj.weight
W_K = olmoe_layer.self_attn.k_proj.weight
W_V = olmoe_layer.self_attn.v_proj.weight
W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
state_dict[f"blocks.{l}.attn._W_K"] = W_K
state_dict[f"blocks.{l}.attn._W_V"] = W_V
state_dict[f"blocks.{l}.attn.q_norm.w"] = olmoe_layer.self_attn.q_norm.weight
state_dict[f"blocks.{l}.attn.k_norm.w"] = olmoe_layer.self_attn.k_norm.weight

state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
cfg.n_heads, cfg.d_head, dtype=cfg.dtype
)
state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
)
state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
)

W_O = olmoe_layer.self_attn.o_proj.weight
W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
state_dict[f"blocks.{l}.attn.W_O"] = W_O

state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

state_dict[f"blocks.{l}.ln2.w"] = olmoe_layer.post_attention_layernorm.weight

state_dict[f"blocks.{l}.mlp.W_gate.weight"] = olmoe_layer.mlp.gate.weight

for e in range(cfg.num_experts):
state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = (
olmoe_layer.mlp.experts[e].up_proj.weight
)
state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = (
olmoe_layer.mlp.experts[e].gate_proj.weight
)
state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = (
olmoe_layer.mlp.experts[e].down_proj.weight
)

state_dict["ln_final.w"] = olmoe.model.norm.weight

state_dict["unembed.W_U"] = olmoe.lm_head.weight.T
state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)

return state_dict
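Finally, a rough sketch of how this converter might be exercised end to end, mirroring the dispatch added to `get_pretrained_state_dict`. Loading the full HF model is memory-heavy, and the use of `convert_hf_model_config` plus `HookedTransformerConfig.from_dict` here assumes those helpers accept the config dict built above; treat it as an illustration, not a recommended workflow:

```python
import torch
from transformers import AutoModelForCausalLM

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.loading_from_pretrained import convert_hf_model_config
from transformer_lens.pretrained.weight_conversions import convert_olmoe_weights

model_name = "allenai/OLMoE-1B-7B-0924"

# Build the TransformerLens config from the HF config, then load the HF weights.
cfg = HookedTransformerConfig.from_dict(convert_hf_model_config(model_name))
hf_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

state_dict = convert_olmoe_weights(hf_model, cfg)
print(state_dict["blocks.0.attn.W_Q"].shape)  # [n_heads, d_model, d_head]
```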