Skip to content

Commit

Permalink
fix: Update Gemma to reflect upstream HF changes (#596)
Browse files Browse the repository at this point in the history
* update activation function to tanh approximation

* keep RMSNorm calcs in float32 and match cfg dtype for embedding scaling

* formatting

* keep mypy happy

* formatting
  • Loading branch information
cmathw authored May 15, 2024
1 parent 3d6dbbb commit 0fd85b9
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions transformer_lens/loading_from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
"n_ctx": 8192,
"eps": 1e-06,
"d_vocab": 256000,
"act_fn": "gelu",
"act_fn": "gelu_new",
"initializer_range": 0.02,
"normalization_type": "RMS",
"rotary_base": 10000.0,
Expand All @@ -1130,7 +1130,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
"n_ctx": 8192,
"eps": 1e-06,
"d_vocab": 256000,
"act_fn": "gelu",
"act_fn": "gelu_new",
"initializer_range": 0.02,
"normalization_type": "RMS",
"rotary_base": 10000.0,
Expand Down Expand Up @@ -2592,19 +2592,22 @@ def convert_phi_weights(phi, cfg: HookedTransformerConfig):
def convert_gemma_weights(gemma, cfg: HookedTransformerConfig):
state_dict = {}

assert cfg.n_key_value_heads is not None # mypy
assert cfg.d_mlp is not None # mypy
assert cfg.n_key_value_heads is not None # keep mypy happy
assert cfg.d_mlp is not None # keep mypy happy

# Gemma Models scale embeddings by multiplying by sqrt(d_model)
state_dict["embed.W_E"] = gemma.model.embed_tokens.weight * (cfg.d_model**0.5)
# Gemma Models scale embeddings by multiplying by sqrt(d_model), use hidden state type to match
# HF implementation
state_dict["embed.W_E"] = gemma.model.embed_tokens.weight * torch.tensor(
cfg.d_model**0.5, dtype=cfg.dtype
)

# Gemma has no biases anywhere
for l in range(cfg.n_layers):
# GemmaRMSNorm adds 1 to weights before multiplying by input
# GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32
state_dict[f"blocks.{l}.ln1.w"] = gemma.model.layers[
l
].input_layernorm.weight + torch.ones_like(
gemma.model.layers[l].input_layernorm.weight, dtype=cfg.dtype
].input_layernorm.weight.float() + torch.ones_like(
gemma.model.layers[l].input_layernorm.weight, dtype=torch.float32
)

W_Q = gemma.model.layers[l].self_attn.q_proj.weight
Expand All @@ -2631,11 +2634,11 @@ def convert_gemma_weights(gemma, cfg: HookedTransformerConfig):

state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

# GemmaRMSNorm adds 1 to weights before multiplying by input
# GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32
state_dict[f"blocks.{l}.ln2.w"] = gemma.model.layers[
l
].post_attention_layernorm.weight + torch.ones_like(
gemma.model.norm.weight, dtype=cfg.dtype
].post_attention_layernorm.weight.float() + torch.ones_like(
gemma.model.norm.weight, dtype=torch.float32
)

state_dict[f"blocks.{l}.mlp.W_in"] = gemma.model.layers[l].mlp.up_proj.weight.T
Expand All @@ -2645,9 +2648,9 @@ def convert_gemma_weights(gemma, cfg: HookedTransformerConfig):
state_dict[f"blocks.{l}.mlp.W_out"] = gemma.model.layers[l].mlp.down_proj.weight.T
state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

# GemmaRMSNorm adds 1 to weights before multiplying by input
state_dict["ln_final.w"] = gemma.model.norm.weight + torch.ones_like(
gemma.model.norm.weight, dtype=cfg.dtype
# GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32
state_dict["ln_final.w"] = gemma.model.norm.weight.float() + torch.ones_like(
gemma.model.norm.weight, dtype=torch.float32
)

state_dict["unembed.W_U"] = gemma.lm_head.weight.T
Expand Down

0 comments on commit 0fd85b9

Please sign in to comment.