Fix weights not properly initialized due to shape mismatch #28073
Conversation
Some comments
_loaded_keys = [
    checkpoint_key_to_model_key(k, remove_prefix_from_model, add_prefix_to_model) for k in loaded_keys
]
Just simplified by the new method `checkpoint_key_to_model_key` above.
model_key = checkpoint_key_to_model_key(
    checkpoint_key, remove_prefix_from_model, add_prefix_to_model
)
same
if _fast_init:
    mismatched_model_keys = [
        checkpoint_key_to_model_key(x[0], remove_prefix_from_model, add_prefix_to_model)
        for x in mismatched_keys
    ]
    set_initialized_submodules(model, mismatched_model_keys, loaded=False)
    # This will only initialize submodules that are re-marked as `not loaded` above due to mismatched shapes
    model.apply(model._initialize_weights)
Look at the `mismatched_model_keys` and initialize the weights for them.
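For context, here is a minimal sketch of the mechanism this relies on, assuming `model.apply(model._initialize_weights)` skips any submodule whose `_is_hf_initialized` flag is already set (as the inline comment suggests). Everything below is illustrative, not the actual transformers internals:

```python
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.classifier = nn.Linear(4, 2)

    def _init_weights(self, module):
        # Hypothetical initializer, standing in for a real model's `_init_weights`.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            module.bias.data.zero_()

    def _initialize_weights(self, module):
        # Skip submodules already flagged as initialized/loaded, so a second
        # `model.apply(...)` pass is a no-op for them.
        if getattr(module, "_is_hf_initialized", False):
            return
        self._init_weights(module)
        module._is_hf_initialized = True

model = TinyModel()
# Pretend every submodule was marked as loaded during checkpoint loading ...
for m in model.modules():
    m._is_hf_initialized = True
# ... then re-mark the mismatched one as *not* loaded, as the diff does via
# `set_initialized_submodules(..., loaded=False)`.
model.classifier._is_hf_initialized = False
# Only `classifier` gets (re-)initialized here.
model.apply(model._initialize_weights)
```

Re-marking only the mismatched submodules keeps the second `apply` pass cheap: every other submodule short-circuits on the flag.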
# retrieve uninitialized modules and initialize them before maybe overriding them with the pretrained weights.
if _fast_init:

def checkpoint_key_to_model_key(key, remove_prefix_from_model, add_prefix_to_model):
    model_key = _fix_key(key)
In #16920, `original_loaded_keys` is used when calling `_find_mismatched_keys`. Before that PR, `loaded_keys` was used (which had gone through `_fix_key`). That PR forgot to recompute the `model_key` via `_fix_key` inside `_find_mismatched_keys`.
Here, we fix this issue too.
    if remove_prefix_from_model:
        _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
        # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
        model_key = f"{prefix}.{key}"
    elif add_prefix_to_model:
        _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
        # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
        model_key = key[len(prefix) + 1 :]
    else:
        _loaded_keys = loaded_keys

    return model_key
We do this conversion in several places, so an (inner) function serves them all.
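For illustration, a self-contained version of that conversion (the real helper is an inner function closing over `prefix` and using `_fix_key`; the `gamma`/`beta` renaming below is only an assumption about what `_fix_key` normalizes):

```python
def checkpoint_key_to_model_key(key, prefix, remove_prefix_from_model, add_prefix_to_model):
    # Stand-in for `_fix_key`: normalize legacy parameter names (an assumption
    # about its behaviour, not the actual implementation).
    model_key = key.replace("gamma", "weight").replace("beta", "bias")
    if remove_prefix_from_model:
        # The model key starts with `prefix` but the checkpoint key doesn't, so we add it.
        model_key = f"{prefix}.{model_key}"
    elif add_prefix_to_model:
        # The model key doesn't start with `prefix` but the checkpoint key does, so we remove it.
        model_key = model_key[len(prefix) + 1 :]
    return model_key

# Example: a checkpoint saved with the base-model prefix (e.g. `bert.`) being
# loaded into a bare base model, so the prefix must be stripped.
print(checkpoint_key_to_model_key(
    "bert.encoder.layer.0.output.dense.weight",
    prefix="bert",
    remove_prefix_from_model=False,
    add_prefix_to_model=True,
))  # -> "encoder.layer.0.output.dense.weight"
```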
not_loaded_keys = [
    k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")
]
if set(module.state_dict().keys()) == set(not_loaded_keys):
    module._is_hf_initialized = False
We want to be able to mark some module(s) as not initialized when some checkpoint/model weight(s) have different shape(s), which we can only find out later.
In this case, currently on main, we have:
- this block marks them as initialized (but doesn't actually initialize the weights)
- later, when we want to set the model weight from the checkpoint weight, it is skipped because the shapes don't match (if `ignore_mismatched_sizes=True`)
- so those weights are never initialized by the model's `_init_weights`, could get crazy values like `1e37`, and cause training issues.
The condition `set(module.state_dict().keys()) == set(not_loaded_keys)` is to make sure we only re-initialize a module when none of its weights was loaded (due to shape mismatch).
If some weights of a module are loaded from the checkpoint but others are not loaded due to shape mismatch, we don't really have a way to proceed, since a model's `_init_weights` method can only operate on a whole module, not on its individual weights.
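Roughly, the guard boils down to something like the sketch below (simplified, with the `loaded` handling reduced to flipping the flag; the function name and signature here are not the ones in the PR):

```python
import torch.nn as nn

def mark_not_initialized(model: nn.Module, not_loaded_keys: list[str]) -> None:
    """Flip `_is_hf_initialized` back to False only when *every* weight of a
    submodule is among the not-loaded (mismatched) keys."""
    for module_name, module in model.named_modules():
        module_not_loaded = [
            k[len(module_name) + 1 :] for k in not_loaded_keys if k.startswith(f"{module_name}.")
        ]
        if set(module.state_dict().keys()) == set(module_not_loaded):
            module._is_hf_initialized = False

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for m in model.modules():
    m._is_hf_initialized = True

# Both weight and bias of the second Linear mismatched -> it is re-marked.
mark_not_initialized(model, ["1.weight", "1.bias"])
print(model[1]._is_hf_initialized)  # False

# Only the weight of the first Linear mismatched -> left as initialized.
mark_not_initialized(model, ["0.weight"])
print(model[0]._is_hf_initialized)  # True
```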
As this touches the core file, requesting 2 reviews 🙏
Guess I have to add some tests for this case. Changing to draft for now.
Currently, if there is a weight shape mismatch between the model and the checkpoint, and if `ignore_mismatched_sizes=True`, those weights won't get initialized by the model's `_init_weights` method and could get crazy values like `1e37`. This can make the training get a nan loss value from the beginning and never make any progress.

One example is by running src/transformers/modeling_utils.py (adding `ignore_mismatched_sizes=True`).

We usually set `ignore_mismatched_sizes=True` when we want to perform classification tasks using an existing model but for another task with a different number of targets. This PR aims to fix that issue.
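For reference, a typical way to hit this path looks like the following sketch (the checkpoint name and label count are placeholders; any sequence-classification checkpoint whose head has a different number of labels will do):

```python
from transformers import AutoModelForSequenceClassification

# The checkpoint already has a 2-label classification head; asking for 7 labels
# makes the head weights mismatched, so they are skipped during loading.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english",
    num_labels=7,
    ignore_mismatched_sizes=True,
)

# Before this fix, the skipped head could keep garbage values (on the order of 1e37),
# which yields nan losses from the first step; after the fix, `_init_weights` runs on it.
print(model.classifier.weight.abs().max())
```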