From fedb34fd3e33d73d1d54363d321d6d421c950535 Mon Sep 17 00:00:00 2001 From: cmathw <108584265+cmathw@users.noreply.github.com> Date: Wed, 15 May 2024 21:52:22 +0100 Subject: [PATCH] fix: Update Gemma to reflect upstream HF changes (#596) * update activation function to tanh approximation * keep RMSNorm calcs in float32 and match cfg dtype for embedding scaling * formatting * keep mypy happy * formatting --- transformer_lens/loading_from_pretrained.py | 33 +++++++++++---------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 0a8d132cc..afb560d77 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1108,7 +1108,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "n_ctx": 8192, "eps": 1e-06, "d_vocab": 256000, - "act_fn": "gelu", + "act_fn": "gelu_new", "initializer_range": 0.02, "normalization_type": "RMS", "rotary_base": 10000.0, @@ -1130,7 +1130,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "n_ctx": 8192, "eps": 1e-06, "d_vocab": 256000, - "act_fn": "gelu", + "act_fn": "gelu_new", "initializer_range": 0.02, "normalization_type": "RMS", "rotary_base": 10000.0, @@ -2592,19 +2592,22 @@ def convert_phi_weights(phi, cfg: HookedTransformerConfig): def convert_gemma_weights(gemma, cfg: HookedTransformerConfig): state_dict = {} - assert cfg.n_key_value_heads is not None # mypy - assert cfg.d_mlp is not None # mypy + assert cfg.n_key_value_heads is not None # keep mypy happy + assert cfg.d_mlp is not None # keep mypy happy - # Gemma Models scale embeddings by multiplying by sqrt(d_model) - state_dict["embed.W_E"] = gemma.model.embed_tokens.weight * (cfg.d_model**0.5) + # Gemma Models scale embeddings by multiplying by sqrt(d_model), use hidden state type to match + # HF implementation + state_dict["embed.W_E"] = gemma.model.embed_tokens.weight * torch.tensor( + cfg.d_model**0.5, dtype=cfg.dtype + ) # Gemma has no biases anywhere for l in range(cfg.n_layers): - # GemmaRMSNorm adds 1 to weights before multiplying by input + # GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32 state_dict[f"blocks.{l}.ln1.w"] = gemma.model.layers[ l - ].input_layernorm.weight + torch.ones_like( - gemma.model.layers[l].input_layernorm.weight, dtype=cfg.dtype + ].input_layernorm.weight.float() + torch.ones_like( + gemma.model.layers[l].input_layernorm.weight, dtype=torch.float32 ) W_Q = gemma.model.layers[l].self_attn.q_proj.weight @@ -2631,11 +2634,11 @@ def convert_gemma_weights(gemma, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) - # GemmaRMSNorm adds 1 to weights before multiplying by input + # GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32 state_dict[f"blocks.{l}.ln2.w"] = gemma.model.layers[ l - ].post_attention_layernorm.weight + torch.ones_like( - gemma.model.norm.weight, dtype=cfg.dtype + ].post_attention_layernorm.weight.float() + torch.ones_like( + gemma.model.norm.weight, dtype=torch.float32 ) state_dict[f"blocks.{l}.mlp.W_in"] = gemma.model.layers[l].mlp.up_proj.weight.T @@ -2645,9 +2648,9 @@ def convert_gemma_weights(gemma, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.mlp.W_out"] = gemma.model.layers[l].mlp.down_proj.weight.T state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) - # GemmaRMSNorm adds 1 to weights before multiplying by input - state_dict["ln_final.w"] = gemma.model.norm.weight + torch.ones_like( - gemma.model.norm.weight, dtype=cfg.dtype + # GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32 + state_dict["ln_final.w"] = gemma.model.norm.weight.float() + torch.ones_like( + gemma.model.norm.weight, dtype=torch.float32 ) state_dict["unembed.W_U"] = gemma.lm_head.weight.T