diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d7fbb9c278d12d..bc996dca369c9f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -555,7 +555,7 @@ def set_initialized_submodules(model, state_dict_keys, loaded=True): not_loaded_keys = [ k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.") ] - if len(set(module.state_dict().keys()).intersection(not_loaded_keys)) > 0: + if set(module.state_dict().keys()) == set(not_loaded_keys): module._is_hf_initialized = False