diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4d59635c90e97b..7b5c9290dbc2cf 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -708,7 +708,11 @@ def load(module: nn.Module, state_dict, prefix=""): if child is not None: load(child, state_dict, prefix + name + ".") - load(model_to_load, state_dict, prefix=start_prefix) + # Adjust and remove our `start_prefix` as we don't need it anymore + state_dict = { + key[len(start_prefix) :] if key.startswith(start_prefix) else key: value for key, value in state_dict.items() + } + model_to_load.load_state_dict(state_dict, assign=True, strict=False) # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so # it's safe to delete it. del state_dict