diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py
index 00e51e50909cd3..4c3cfaa48d5175 100644
--- a/src/transformers/models/mamba/modeling_mamba.py
+++ b/src/transformers/models/mamba/modeling_mamba.py
@@ -501,8 +501,15 @@ def __init__(self, config):
         self.gradient_checkpointing = False
         self.norm_f = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
         # Initialize weights and apply final processing
+        self._register_load_state_dict_pre_hook(self.load_hook)
         self.post_init()
 
+    def load_hook(self, state_dict, prefix, *args):
+        for k in state_dict:
+            if "embedding." in k:
+                state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
+                break
+
     def get_input_embeddings(self):
         return self.embeddings
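
The hunk above registers a load_state_dict pre-hook so that checkpoints saved with the old singular "embedding." key still load into the renamed "embeddings." attribute. Below is a minimal standalone sketch of the same mechanism, not part of the diff: the NewModel class and the _rename_legacy_keys helper are hypothetical names used only for illustration, and it relies on torch's private _register_load_state_dict_pre_hook just as the diff does.

    import torch
    import torch.nn as nn


    class NewModel(nn.Module):
        """Hypothetical module whose weight attribute was renamed to `embeddings`."""

        def __init__(self):
            super().__init__()
            self.embeddings = nn.Embedding(10, 4)
            # The hook edits incoming state_dicts in place before they are applied.
            self._register_load_state_dict_pre_hook(self._rename_legacy_keys)

        def _rename_legacy_keys(self, state_dict, prefix, *args):
            # Rewrite old "embedding." keys to the new "embeddings." name.
            for k in list(state_dict):
                if "embedding." in k:
                    state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)


    # Simulate a checkpoint produced before the rename.
    legacy_state_dict = {"embedding.weight": torch.zeros(10, 4)}
    model = NewModel()
    model.load_state_dict(legacy_state_dict)  # loads cleanly thanks to the pre-hook

Because the hook only rewrites key names in the incoming state_dict, checkpoints saved after the rename already use the new key and pass through untouched.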