Update Gemma to reflect upstream HF changes #596

Merged · 5 commits · May 15, 2024
33 changes: 18 additions & 15 deletions transformer_lens/loading_from_pretrained.py
@@ -1108,7 +1108,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "n_ctx": 8192,
             "eps": 1e-06,
             "d_vocab": 256000,
-            "act_fn": "gelu",
+            "act_fn": "gelu_new",
             "initializer_range": 0.02,
             "normalization_type": "RMS",
             "rotary_base": 10000.0,
@@ -1130,7 +1130,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "n_ctx": 8192,
             "eps": 1e-06,
             "d_vocab": 256000,
-            "act_fn": "gelu",
+            "act_fn": "gelu_new",
             "initializer_range": 0.02,
             "normalization_type": "RMS",
             "rotary_base": 10000.0,
@@ -2592,19 +2592,22 @@ def convert_phi_weights(phi, cfg: HookedTransformerConfig):
 def convert_gemma_weights(gemma, cfg: HookedTransformerConfig):
     state_dict = {}

-    assert cfg.n_key_value_heads is not None  # mypy
-    assert cfg.d_mlp is not None  # mypy
+    assert cfg.n_key_value_heads is not None  # keep mypy happy
+    assert cfg.d_mlp is not None  # keep mypy happy

-    # Gemma Models scale embeddings by multiplying by sqrt(d_model)
-    state_dict["embed.W_E"] = gemma.model.embed_tokens.weight * (cfg.d_model**0.5)
+    # Gemma Models scale embeddings by multiplying by sqrt(d_model), use hidden state type to match
+    # HF implementation
+    state_dict["embed.W_E"] = gemma.model.embed_tokens.weight * torch.tensor(
+        cfg.d_model**0.5, dtype=cfg.dtype
+    )

     # Gemma has no biases anywhere
     for l in range(cfg.n_layers):
-        # GemmaRMSNorm adds 1 to weights before multiplying by input
+        # GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32
         state_dict[f"blocks.{l}.ln1.w"] = gemma.model.layers[
             l
-        ].input_layernorm.weight + torch.ones_like(
-            gemma.model.layers[l].input_layernorm.weight, dtype=cfg.dtype
+        ].input_layernorm.weight.float() + torch.ones_like(
+            gemma.model.layers[l].input_layernorm.weight, dtype=torch.float32
         )

         W_Q = gemma.model.layers[l].self_attn.q_proj.weight
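The embedding change is subtler. A minimal sketch, assuming bfloat16 weights and `d_model = 2048` (Gemma-2B-ish values picked purely for illustration), of why the sqrt(d_model) factor is now materialised as a tensor in `cfg.dtype` rather than applied as a Python float: HF's Gemma rounds the normalizer to the hidden-state dtype before multiplying, so doing the same keeps the converted embeddings numerically in line with it.

```python
import torch

d_model = 2048  # Gemma-2B width, used purely for illustration
W_E = torch.randn(8, d_model, dtype=torch.bfloat16)

# Old behaviour: multiply by the full-precision Python float.
scaled_old = W_E * (d_model**0.5)

# New behaviour: round the normalizer to the model dtype first, as HF does.
normalizer = torch.tensor(d_model**0.5, dtype=torch.bfloat16)
scaled_new = W_E * normalizer

print(d_model**0.5, normalizer.item())        # 45.2548... vs its bfloat16 rounding
print((scaled_old - scaled_new).abs().max())  # typically a small nonzero mismatch
```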
@@ -2631,11 +2634,11 @@ def convert_gemma_weights(gemma, cfg: HookedTransformerConfig):

         state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

-        # GemmaRMSNorm adds 1 to weights before multiplying by input
+        # GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32
         state_dict[f"blocks.{l}.ln2.w"] = gemma.model.layers[
             l
-        ].post_attention_layernorm.weight + torch.ones_like(
-            gemma.model.norm.weight, dtype=cfg.dtype
+        ].post_attention_layernorm.weight.float() + torch.ones_like(
+            gemma.model.norm.weight, dtype=torch.float32
         )

         state_dict[f"blocks.{l}.mlp.W_in"] = gemma.model.layers[l].mlp.up_proj.weight.T
@@ -2645,9 +2648,9 @@ def convert_gemma_weights(gemma, cfg: HookedTransformerConfig):
         state_dict[f"blocks.{l}.mlp.W_out"] = gemma.model.layers[l].mlp.down_proj.weight.T
         state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

-    # GemmaRMSNorm adds 1 to weights before multiplying by input
-    state_dict["ln_final.w"] = gemma.model.norm.weight + torch.ones_like(
-        gemma.model.norm.weight, dtype=cfg.dtype
+    # GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32
+    state_dict["ln_final.w"] = gemma.model.norm.weight.float() + torch.ones_like(
+        gemma.model.norm.weight, dtype=torch.float32
     )

     state_dict["unembed.W_U"] = gemma.lm_head.weight.T