Skip to content

Commit

Permalink
1,100%!
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Jul 9, 2024
1 parent e3a7d9b commit 321ac8f
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,11 @@ def load(module: nn.Module, state_dict, prefix=""):
if child is not None:
load(child, state_dict, prefix + name + ".")

load(model_to_load, state_dict, prefix=start_prefix)
# Adjust and remove our `start_prefix` as we don't need it anymore
state_dict = {
key[len(start_prefix) :] if key.startswith(start_prefix) else key: value for key, value in state_dict.items()
}
model_to_load.load_state_dict(state_dict, assign=True, strict=False)
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
# it's safe to delete it.
del state_dict
Expand Down

0 comments on commit 321ac8f

Please sign in to comment.