
Fix weights not properly initialized due to shape mismatch #28073

Closed
ydshieh wants to merge 10 commits

Conversation

@ydshieh ydshieh (Collaborator) commented Dec 15, 2023

Currently, if there is a weight shape mismatch between the model and the checkpoint, and ignore_mismatched_sizes=True, those weights won't get initialized by the model's _init_weights method and can end up with extreme values like 1e37.

This can cause training to produce a nan loss from the very beginning and make no progress.

One example is running src/transformers/modeling_utils.py (with ignore_mismatched_sizes=True added).

We usually set ignore_mismatched_sizes=True when we want to reuse an existing classification model for another task with a different number of targets.

This PR aims to fix this issue.
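For context, a minimal reproduction sketch of the scenario (the checkpoint name, num_labels value, and head attribute name are only illustrative; any classification checkpoint whose head size differs from the requested num_labels would do):

    from transformers import AutoModelForSequenceClassification

    # Load a checkpoint whose classification head has a different number of labels than
    # requested; the head weights then have mismatched shapes and are skipped at load time.
    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased-finetuned-sst-2-english",  # checkpoint with a 2-label head
        num_labels=10,                                       # different number of targets
        ignore_mismatched_sizes=True,                        # skip the mismatched head weights
    )

    # On the buggy code path described above, the skipped head weights are never passed
    # through `_init_weights`, so they can contain garbage values (e.g. around 1e37),
    # which leads to nan loss from the first training step.
    print(model.classifier.weight.abs().max())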


@ydshieh ydshieh (Collaborator, Author) left a comment

Some comments

Comment on lines +3980 to +3982

    _loaded_keys = [
        checkpoint_key_to_model_key(k, remove_prefix_from_model, add_prefix_to_model) for k in loaded_keys
    ]

Just simplified by using the new helper method checkpoint_key_to_model_key introduced above.

Comment on lines +4024 to +4026

    model_key = checkpoint_key_to_model_key(
        checkpoint_key, remove_prefix_from_model, add_prefix_to_model
    )

Same as above.

Comment on lines +4173 to +4180

    if _fast_init:
        mismatched_model_keys = [
            checkpoint_key_to_model_key(x[0], remove_prefix_from_model, add_prefix_to_model)
            for x in mismatched_keys
        ]
        set_initialized_submodules(model, mismatched_model_keys, loaded=False)
        # This will only initialize submodules that are re-marked as `not loaded` above due to mismatched
        model.apply(model._initialize_weights)

Look at the mismatched_model_keys and initialize the weights for them

    # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
    if _fast_init:

    def checkpoint_key_to_model_key(key, remove_prefix_from_model, add_prefix_to_model):
        model_key = _fix_key(key)

In #16920, original_loaded_keys is used when calling _find_mismatched_keys. Before that PR, loaded_keys (the one that had gone through _fix_key) was used.

That PR forgot to recompute the model_key via _fix_key inside _find_mismatched_keys.

Here, we fix this issue too.
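For reference, _fix_key is the small key-renaming helper in modeling_utils.py; roughly (sketched from memory, the exact implementation may differ):

    def _fix_key(key):
        # Rename legacy parameter names found in some older checkpoints.
        if "beta" in key:
            return key.replace("beta", "bias")
        if "gamma" in key:
            return key.replace("gamma", "weight")
        return key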

Comment on lines 3969 to +3976

    if remove_prefix_from_model:
        _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
        # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
        model_key = f"{prefix}.{key}"
    elif add_prefix_to_model:
        _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
    else:
        _loaded_keys = loaded_keys
        # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
        model_key = key[len(prefix) + 1 :]

    return model_key

We do this conversion in several places, so an inner helper function now serves them all (see the illustrative sketch below).
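For illustration (the prefix and keys below are hypothetical), the conversion adds or strips the base-model prefix so that a checkpoint key can be compared against the model's own parameter names:

    prefix = "bert"

    # remove_prefix_from_model: the model's keys carry the prefix but the checkpoint's don't,
    # so the prefix is added to the checkpoint key.
    key = "encoder.layer.0.attention.self.query.weight"
    model_key = f"{prefix}.{key}"        # "bert.encoder.layer.0.attention.self.query.weight"

    # add_prefix_to_model: the checkpoint's keys carry the prefix but the model's don't,
    # so the prefix is stripped from the checkpoint key.
    key = "bert.encoder.layer.0.attention.self.query.weight"
    model_key = key[len(prefix) + 1 :]   # "encoder.layer.0.attention.self.query.weight"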

Comment on lines +555 to +559

    not_loaded_keys = [
        k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")
    ]
    if set(module.state_dict().keys()) == set(not_loaded_keys):
        module._is_hf_initialized = False

We want to be able to mark some modules as not initialized when some checkpoint/model weights have different shapes, something we can only detect later during loading.

In this case, currently on main, we have:

  • this block marks them as initialized (but does not actually initialize the weights)
  • later, when the model weight would be set from the checkpoint weight, it is skipped because the shapes don't match (if ignore_mismatched_sizes=True)
    • so those weights are neither loaded nor initialized by the model's _init_weights, can end up with extreme values like 1e37, and cause training issues (see the sketch below)
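A simplified sketch of the flag mechanism at play (not the actual transformers implementation; the helper name and initializer values are illustrative):

    import torch.nn as nn

    def initialize_weights(module):
        # Modules flagged as already initialized (i.e. expected to be filled from the
        # checkpoint) are skipped entirely by fast init.
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

    # Buggy path: the head is flagged as initialized, but its checkpoint weight is later
    # skipped due to the shape mismatch, so it is neither loaded nor re-initialized.
    head = nn.Linear(768, 10)
    head._is_hf_initialized = True
    head.apply(initialize_weights)   # does nothing

    # Fix: reset the flag for mismatched modules, then apply the init again.
    head._is_hf_initialized = False
    head.apply(initialize_weights)   # now actually initializes the head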

@ydshieh ydshieh (Collaborator, Author) Dec 15, 2023

The condition set(module.state_dict().keys()) == set(not_loaded_keys) makes sure we only re-initialize a module when none of its weights has been loaded (due to shape mismatches).

If some weights of a module are loaded from the checkpoint while others are not loaded due to shape mismatch, we don't really have a way to proceed: a model's _init_weights method can only operate on a whole module, not on its individual weights.
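For context, a typical _init_weights implementation (sketched here, not copied from any particular model; the std value is illustrative) receives a whole nn.Module and initializes all of its parameters at once, so there is no hook for initializing a single mismatched tensor in isolation:

    import torch.nn as nn

    def _init_weights(self, module):
        # Operates on a whole module: every parameter of a matching module is initialized.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)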

@ydshieh ydshieh marked this pull request as ready for review December 15, 2023 16:23
@ydshieh ydshieh (Collaborator, Author) commented Dec 15, 2023

As this touches the core file, requesting 2 reviews 🙏

@ydshieh ydshieh changed the title from "fix not init" to "Fix weights not properly initialized due to shape mismatch" Dec 15, 2023
@ydshieh ydshieh (Collaborator, Author) commented Dec 15, 2023

I guess I have to add some tests for this case. Changing to draft for now.

@ydshieh ydshieh marked this pull request as draft December 15, 2023 17:29
@ydshieh ydshieh closed this Dec 18, 2023