[Mamba] from pretrained issue with self.embeddings (#29851)
* nit

* update

* oups

* Update src/transformers/models/mamba/modeling_mamba.py

Co-authored-by: Lysandre Debut <[email protected]>

---------

Co-authored-by: Lysandre Debut <[email protected]>
2 people authored and Ita Zaporozhets committed May 14, 2024
1 parent c37849c commit d6453e7
Showing 1 changed file with 7 additions and 0 deletions.
src/transformers/models/mamba/modeling_mamba.py (7 additions, 0 deletions)
@@ -501,8 +501,15 @@ def __init__(self, config):
         self.gradient_checkpointing = False
         self.norm_f = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
         # Initialize weights and apply final processing
+        self._register_load_state_dict_pre_hook(self.load_hook)
         self.post_init()
 
+    def load_hook(self, state_dict, prefix, *args):
+        for k in state_dict:
+            if "embedding." in k:
+                state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
+                break
+
     def get_input_embeddings(self):
         return self.embeddings

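The added load_hook is a load_state_dict pre-hook: it runs before the checkpoint keys are matched against the module and renames any legacy "embedding." key to "embeddings.", so checkpoints saved before the attribute rename still load through from_pretrained. Below is a minimal sketch of the same pattern on a toy module, using the same private PyTorch call (_register_load_state_dict_pre_hook) that the commit uses; the class name ToyBackbone, the helper _rename_legacy_keys, and the checkpoint dict are invented for illustration and are not part of the commit.

import torch
import torch.nn as nn


class ToyBackbone(nn.Module):
    """Toy stand-in for the model class: the attribute `embedding` was renamed to `embeddings`."""

    def __init__(self, vocab_size=10, hidden_size=4):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, hidden_size)
        # Same registration call as in the commit: the hook runs before
        # load_state_dict matches keys, so legacy keys can be rewritten in place.
        self._register_load_state_dict_pre_hook(self._rename_legacy_keys)

    def _rename_legacy_keys(self, state_dict, prefix, *args):
        # Snapshot the keys with list() so the dict can be mutated while iterating;
        # the committed load_hook instead breaks out of the loop after the first rename.
        for k in list(state_dict):
            if "embedding." in k:
                state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)


# A checkpoint saved under the old attribute name still loads with strict key checking.
legacy_ckpt = {"embedding.weight": torch.zeros(10, 4)}
model = ToyBackbone()
model.load_state_dict(legacy_ckpt)  # the hook remaps the key to "embeddings.weight"

One detail: the committed load_hook breaks after the first matching key, which avoids modifying the dict while it is still being iterated; the sketch snapshots the keys with list() for the same reason.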
