Fix weights not properly initialized due to shape mismatch #28073

Closed · wants to merge 10 commits
Changes from 6 commits
56 changes: 39 additions & 17 deletions src/transformers/modeling_utils.py
@@ -539,15 +539,24 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
)


def set_initialized_submodules(model, state_dict_keys):
def set_initialized_submodules(model, state_dict_keys, loaded=True):
"""
Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
dict.
"""
for module_name, module in model.named_modules():
loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")]
if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
module._is_hf_initialized = True
if loaded:
loaded_keys = [
k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")
]
if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
module._is_hf_initialized = loaded
else:
not_loaded_keys = [
k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")
]
if set(module.state_dict().keys()) == set(not_loaded_keys):
module._is_hf_initialized = False
Comment on lines +555 to +559 (Collaborator Author):
We want to be able to mark some module(s) as not initialized when a checkpoint weight and the corresponding model weight have different shapes, a mismatch we can only detect later.

In this case, currently on main, we have:

  • this block marks them as initialized (but does not actually initialize the weights)
  • later, when we try to copy the checkpoint weight into the model weight, the copy is skipped because the shapes don't match (when ignore_mismatched_sizes=True)
    • so those weights are never initialized by the model's _init_weights and can end up with extreme values like 1e37, causing training issues.
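To make the failure mode concrete, here is a minimal, hypothetical sketch of the re-marking step (ToyModel and the mismatched keys are made up; the real logic is the loaded=False branch of set_initialized_submodules in modeling_utils.py):

```python
from torch import nn


class ToyModel(nn.Module):
    def __init__(self, hidden=4):
        super().__init__()
        self.embed = nn.Linear(2, hidden)
        self.head = nn.Linear(hidden, 3)


model = ToyModel()

# Hypothetical outcome of loading: every weight of `head` was skipped because
# its checkpoint shape differs from the model's shape.
mismatched_keys = ["head.weight", "head.bias"]

# Re-mark a module as uninitialized only if *all* of its weights were skipped.
for name, module in model.named_modules():
    not_loaded = {k[len(name) + 1 :] for k in mismatched_keys if k.startswith(f"{name}.")}
    if not_loaded and set(module.state_dict().keys()) == not_loaded:
        module._is_hf_initialized = False

print(getattr(model.head, "_is_hf_initialized", None))  # False
# In the real flow, a second model.apply(model._initialize_weights) pass then
# re-initializes `model.head` instead of leaving it with arbitrary values.
```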

@ydshieh (Collaborator Author) commented on Dec 15, 2023:

The condition set(module.state_dict().keys()) == set(not_loaded_keys) makes sure we only re-initialize a module when none of its weights was loaded from the checkpoint (i.e. all of them were skipped due to shape mismatch).

If some weights of a module are loaded from the checkpoint but others are skipped due to shape mismatch, we don't really have a way to proceed: a model's _init_weights method can only operate on a whole module, not on its individual weights.
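A hypothetical partial-mismatch example, to show why strict set equality (rather than a subset check) is used here:

```python
from torch import nn

classifier = nn.Linear(4, 3)

# Suppose the checkpoint's "classifier.weight" had a different shape (skipped),
# while "classifier.bias" matched and was loaded.
not_loaded_keys = {"weight"}
module_keys = set(classifier.state_dict().keys())  # {"weight", "bias"}

# Only re-mark the module when every one of its weights was skipped; otherwise
# re-running _init_weights on the whole module would clobber the bias that was
# just loaded from the checkpoint.
should_reinit = module_keys == not_loaded_keys
print(should_reinit)  # False: partial mismatch, nothing we can safely do per module
```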



def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
@@ -3955,14 +3964,22 @@ def _fix_key(key):
model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype)
)

# retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init:
def checkpoint_key_to_model_key(key, remove_prefix_from_model, add_prefix_to_model):
model_key = _fix_key(key)
Collaborator Author:
In #16920, original_loaded_keys is used when calling _find_mismatched_keys. Before that PR, loaded_keys was used (the one that had gone through _fix_key).

That PR forgot to recompute the model_key via _fix_key inside _find_mismatched_keys.

Here, we fix this issue too.
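A small illustration of the kind of key this matters for (the gamma/beta renaming reflects the long-standing behavior of _fix_key; the key names themselves are made up):

```python
# Simplified stand-in for transformers' _fix_key: legacy TF-style parameter
# names are renamed to their PyTorch equivalents.
def _fix_key(key):
    if "beta" in key:
        return key.replace("beta", "bias")
    if "gamma" in key:
        return key.replace("gamma", "weight")
    return key


checkpoint_key = "encoder.layer.0.LayerNorm.gamma"
model_state_dict = {"encoder.layer.0.LayerNorm.weight": None}

# Comparing the raw checkpoint key misses the model entry, so the weight would
# not be reported as mismatched even if its shape differs.
print(checkpoint_key in model_state_dict)            # False
print(_fix_key(checkpoint_key) in model_state_dict)  # True
```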

if remove_prefix_from_model:
_loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
model_key = f"{prefix}.{key}"
elif add_prefix_to_model:
_loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
else:
_loaded_keys = loaded_keys
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = key[len(prefix) + 1 :]

return model_key
Comment on lines 3969 to +3976 (Collaborator Author):
We do this conversion in several places, so an inner function now serves them all.


# retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init:
_loaded_keys = [
checkpoint_key_to_model_key(k, remove_prefix_from_model, add_prefix_to_model) for k in loaded_keys
]
Comment on lines +3980 to +3982 (Collaborator Author):
Just simplified by the new helper checkpoint_key_to_model_key introduced above.

set_initialized_submodules(model, _loaded_keys)
# This will only initialize submodules that are not marked as initialized by the line above.
model.apply(model._initialize_weights)
@@ -4004,13 +4021,9 @@ def _find_mismatched_keys(
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
model_key = checkpoint_key
if remove_prefix_from_model:
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
model_key = f"{prefix}.{checkpoint_key}"
elif add_prefix_to_model:
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = ".".join(checkpoint_key.split(".")[1:])
model_key = checkpoint_key_to_model_key(
checkpoint_key, remove_prefix_from_model, add_prefix_to_model
)
Comment on lines +4024 to +4026 (Collaborator Author):
Same simplification as above.


if (
model_key in model_state_dict
@@ -4157,6 +4170,15 @@ def _find_mismatched_keys(
load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
shutil.rmtree(state_dict_folder)

if _fast_init:
mismatched_model_keys = [
checkpoint_key_to_model_key(x[0], remove_prefix_from_model, add_prefix_to_model)
for x in mismatched_keys
]
set_initialized_submodules(model, mismatched_model_keys, loaded=False)
# This will only initialize submodules that are re-marked as `not loaded` above due to mismatched
model.apply(model._initialize_weights)
Comment on lines +4173 to +4180 (Collaborator Author):
Look at mismatched_model_keys and initialize the weights of the corresponding modules.
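For context, a sketch of the user-facing scenario this pass fixes (the checkpoint id and label count are only an example; any classification checkpoint whose head size differs from num_labels goes down this path when ignore_mismatched_sizes=True):

```python
from transformers import AutoModelForSequenceClassification

# The checkpoint's classification head was trained with 2 labels; asking for 7
# makes its weights a shape mismatch, so they are skipped while loading.
model = AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-SST-2",
    num_labels=7,
    ignore_mismatched_sizes=True,
)

# With this PR, the skipped head is re-marked as uninitialized and re-run
# through _init_weights, so it holds sane random values instead of
# uninitialized memory (e.g. magnitudes around 1e37).
print(model.classifier.weight.abs().max())
```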


if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
if "size mismatch" in error_msg: