Fix weights not properly initialized due to shape mismatch #28073
@@ -539,15 +539,24 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
         )


-def set_initialized_submodules(model, state_dict_keys):
+def set_initialized_submodules(model, state_dict_keys, loaded=True):
     """
     Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
     dict.
     """
     for module_name, module in model.named_modules():
-        loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")]
-        if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
-            module._is_hf_initialized = True
+        if loaded:
+            loaded_keys = [
+                k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")
+            ]
+            if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
+                module._is_hf_initialized = loaded
+        else:
+            not_loaded_keys = [
+                k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")
+            ]
+            if set(module.state_dict().keys()) == set(not_loaded_keys):
+                module._is_hf_initialized = False


 def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
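A minimal sketch of what the two modes of `set_initialized_submodules` do on a toy module (the module names and shapes below are made up for illustration; `set_initialized_submodules` is the module-level helper in `transformers.modeling_utils`, and the `loaded=False` branch only exists with this PR applied):

```python
import torch.nn as nn

from transformers.modeling_utils import set_initialized_submodules

# Toy model: pretend the backbone weights load cleanly from the checkpoint,
# while the classifier weights later turn out to be shape-mismatched.
model = nn.ModuleDict({"backbone": nn.Linear(4, 4), "classifier": nn.Linear(4, 2)})

# Keys that were actually loaded from the checkpoint.
set_initialized_submodules(model, ["backbone.weight", "backbone.bias"])
assert model["backbone"]._is_hf_initialized  # every weight of this submodule came from the checkpoint
assert not hasattr(model["classifier"], "_is_hf_initialized")

# New with this PR: keys later found to be shape-mismatched are passed back with
# loaded=False, clearing the flag so `model._initialize_weights` runs on them again.
model["classifier"]._is_hf_initialized = True  # pretend it was marked earlier
set_initialized_submodules(model, ["classifier.weight", "classifier.bias"], loaded=False)
assert model["classifier"]._is_hf_initialized is False
```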
@@ -3955,14 +3964,22 @@ def _fix_key(key):
                             model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype)
                         )

-        # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
-        if _fast_init:
+        def checkpoint_key_to_model_key(key, remove_prefix_from_model, add_prefix_to_model):
+            model_key = _fix_key(key)
Review comment: In #16920, […] That PR forgot to recompute the […]. Here, we fix this issue too.
             if remove_prefix_from_model:
-                _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
+                # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
+                model_key = f"{prefix}.{key}"
             elif add_prefix_to_model:
-                _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
-            else:
-                _loaded_keys = loaded_keys
+                # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
+                model_key = key[len(prefix) + 1 :]
+
+            return model_key
Review comment on lines 3969 to +3976: We do this conversion in several places, so an (inner) function to serve them.
+
+        # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
+        if _fast_init:
+            _loaded_keys = [
+                checkpoint_key_to_model_key(k, remove_prefix_from_model, add_prefix_to_model) for k in loaded_keys
+            ]
Review comment on lines +3980 to +3982: just simplified by the new method.
             set_initialized_submodules(model, _loaded_keys)
             # This will only initialize submodules that are not marked as initialized by the line above.
             model.apply(model._initialize_weights)
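To make the prefix handling of the new inner helper concrete, here is a hedged, standalone re-statement of what `checkpoint_key_to_model_key` computes; the key names and the `"bert"` prefix are made up for the example, and the real helper additionally passes the key through `_fix_key` first:

```python
# Standalone sketch of the conversion performed by the inner helper above
# (simplified: the real helper also runs the key through `_fix_key` first).
def checkpoint_key_to_model_key_sketch(key, prefix, remove_prefix_from_model, add_prefix_to_model):
    if remove_prefix_from_model:
        # Model keys carry the base-model prefix but checkpoint keys don't: add it.
        return f"{prefix}.{key}"
    if add_prefix_to_model:
        # Checkpoint keys carry the prefix but model keys don't: strip it.
        return key[len(prefix) + 1 :]
    return key

# Hypothetical keys, with "bert" as the base-model prefix:
print(checkpoint_key_to_model_key_sketch("encoder.layer.0.attention.self.query.weight", "bert", True, False))
# -> "bert.encoder.layer.0.attention.self.query.weight"
print(checkpoint_key_to_model_key_sketch("bert.embeddings.word_embeddings.weight", "bert", False, True))
# -> "embeddings.word_embeddings.weight"
```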
@@ -4004,13 +4021,9 @@ def _find_mismatched_keys(
                     # If the checkpoint is sharded, we may not have the key here.
                     if checkpoint_key not in state_dict:
                         continue
-                    model_key = checkpoint_key
-                    if remove_prefix_from_model:
-                        # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
-                        model_key = f"{prefix}.{checkpoint_key}"
-                    elif add_prefix_to_model:
-                        # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
-                        model_key = ".".join(checkpoint_key.split(".")[1:])
+                    model_key = checkpoint_key_to_model_key(
+                        checkpoint_key, remove_prefix_from_model, add_prefix_to_model
+                    )
Review comment on lines +4024 to +4026: same.
                     if (
                         model_key in model_state_dict
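For orientation (a rough paraphrase of the surrounding function, not part of this diff): the condition above guards a shape comparison that records mismatches as `(checkpoint_key, checkpoint_shape, model_shape)` tuples, which is why the new block further down takes `x[0]` to recover the checkpoint key.

```python
import torch

# Toy stand-ins for the real dicts (hypothetical 2-label checkpoint vs. 10-label model).
state_dict = {"classifier.weight": torch.zeros(2, 4)}        # checkpoint tensor
model_state_dict = {"classifier.weight": torch.zeros(10, 4)}  # model tensor
checkpoint_key = model_key = "classifier.weight"
mismatched_keys = []

# Rough paraphrase of the bookkeeping inside `_find_mismatched_keys`:
if (
    model_key in model_state_dict
    and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
    # Record the mismatch and drop the checkpoint tensor so it is not loaded.
    mismatched_keys.append(
        (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
    )
    del state_dict[checkpoint_key]

print(mismatched_keys)  # [('classifier.weight', torch.Size([2, 4]), torch.Size([10, 4]))]
```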
@@ -4157,6 +4170,15 @@ def _find_mismatched_keys(
             load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
             shutil.rmtree(state_dict_folder)

+        if _fast_init:
+            mismatched_model_keys = [
+                checkpoint_key_to_model_key(x[0], remove_prefix_from_model, add_prefix_to_model)
+                for x in mismatched_keys
+            ]
+            set_initialized_submodules(model, mismatched_model_keys, loaded=False)
+            # This will only initialize submodules that are re-marked as `not loaded` above due to mismatched
+            model.apply(model._initialize_weights)
+
Review comment on lines +4173 to +4180: Look at the […]
         if len(error_msgs) > 0:
             error_msg = "\n\t".join(error_msgs)
             if "size mismatch" in error_msg:
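As an end-to-end illustration of when the new `loaded=False` path fires (this scenario is not part of the diff, and the checkpoint name is just a plausible example): mismatched keys typically arise from `from_pretrained(..., ignore_mismatched_sizes=True)` with a resized head, and before this fix the skipped head tensors could be left as the `torch.empty` placeholders created during fast init instead of being freshly initialized.

```python
import torch
from transformers import AutoModelForSequenceClassification

# Hypothetical repro: the checkpoint's classification head has 2 labels, we ask for 10,
# so `classifier.weight` / `classifier.bias` are reported as mismatched and skipped.
# With this PR they are re-initialized by `_init_weights` instead of keeping the
# uninitialized placeholder tensors.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english",  # assumed example checkpoint
    num_labels=10,
    ignore_mismatched_sizes=True,
)

# Sanity check: the re-initialized head should hold small, finite values.
print(torch.isfinite(model.classifier.weight).all(), model.classifier.weight.abs().mean())
```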
Review comment: We want to be able to set some module(s) as not initialized in the case where some checkpoint/model weight(s) have different shape(s), which we can only find out later. In this case, currently on main, we have: […]
Review comment: The condition `set(module.state_dict().keys()) == set(not_loaded_keys)` is to make sure we only re-initialize that module when none of its weights was loaded (due to shape mismatch). If some weights of a module are loaded from the checkpoint but others are not loaded due to shape mismatch, we don't really have a way to proceed: a model's `_init_weights` method can only operate on a module, not on its individual weights.
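To see why this has to work at module granularity, here is a rough sketch of the per-module init hook pattern used by `PreTrainedModel` subclasses (simplified, not the actual transformers implementation): the hook receives a whole `nn.Module` and resets all of its parameters at once, so there is no natural place to re-initialize just one mismatched tensor inside a module.

```python
import torch.nn as nn

# Rough sketch of a typical `_init_weights` hook (the real one lives on the model class
# and is applied via `model.apply(...)`).
def _init_weights(module):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()

linear = nn.Linear(4, 2)
_init_weights(linear)  # resets BOTH weight and bias; there is no per-tensor variant,
                       # hence a module with only some mismatched weights cannot be handled.
```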