From 994750c86f155e133bc02ef2acf2b394b2c14300 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 25 Mar 2024 21:15:35 +0900 Subject: [PATCH 1/4] nit --- src/transformers/models/mamba/modeling_mamba.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 0e233ae4304c80..54ce15e5656822 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -501,8 +501,13 @@ def __init__(self, config): self.gradient_checkpointing = False self.norm_f = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) # Initialize weights and apply final processing + self._register_load_state_dict_pre_hook(self.load_hook) self.post_init() + def load_hook(self, state_dict, prefix, *args): + if "embedding" in state_dict: + state_dict["embeddings"] = state_dict.pop("embedding", None) + def get_input_embeddings(self): return self.embeddings From e64ed0965fb8d0d10fb817b9c455d959ddc3d73e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 25 Mar 2024 21:18:56 +0900 Subject: [PATCH 2/4] update --- src/transformers/models/mamba/modeling_mamba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 54ce15e5656822..434cb2002d556f 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -505,8 +505,8 @@ def __init__(self, config): self.post_init() def load_hook(self, state_dict, prefix, *args): - if "embedding" in state_dict: - state_dict["embeddings"] = state_dict.pop("embedding", None) + if "backbone.embeddings.weight" in state_dict: + state_dict["backbone.embeddings.weight"] = state_dict.pop("backbone.embedding.weight", None) def get_input_embeddings(self): return self.embeddings From a33ba7e6d5ec040042e062189c36241ccea88c1c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 25 Mar 2024 21:22:26 +0900 Subject: [PATCH 3/4] oups --- src/transformers/models/mamba/modeling_mamba.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 434cb2002d556f..690094faf8986a 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -505,8 +505,10 @@ def __init__(self, config): self.post_init() def load_hook(self, state_dict, prefix, *args): - if "backbone.embeddings.weight" in state_dict: - state_dict["backbone.embeddings.weight"] = state_dict.pop("backbone.embedding.weight", None) + for k in state_dict: + if "embedding" in k: + state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) + break def get_input_embeddings(self): return self.embeddings From f7c106d340ac9baad70493099755626a29e0f147 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 28 Mar 2024 08:56:39 +0100 Subject: [PATCH 4/4] Update src/transformers/models/mamba/modeling_mamba.py Co-authored-by: Lysandre Debut --- src/transformers/models/mamba/modeling_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 690094faf8986a..72b2f39fa7413c 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -506,7 +506,7 @@ def __init__(self, config): def load_hook(self, state_dict, prefix, *args): for k in state_dict: - if "embedding" in k: + if "embedding." in k: state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) break