diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 7e5d3e54e619e8..bc996dca369c9f 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -539,15 +539,24 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             )
 
 
-def set_initialized_submodules(model, state_dict_keys):
+def set_initialized_submodules(model, state_dict_keys, loaded=True):
     """
     Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
     dict.
     """
     for module_name, module in model.named_modules():
-        loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")]
-        if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
-            module._is_hf_initialized = True
+        if loaded:
+            loaded_keys = [
+                k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")
+            ]
+            if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
+                module._is_hf_initialized = loaded
+        else:
+            not_loaded_keys = [
+                k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")
+            ]
+            if set(module.state_dict().keys()) == set(not_loaded_keys):
+                module._is_hf_initialized = False
 
 
 def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
@@ -3955,14 +3964,22 @@ def _fix_key(key):
                             model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype)
                         )
 
-        # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
-        if _fast_init:
+        def checkpoint_key_to_model_key(key, remove_prefix_from_model, add_prefix_to_model):
+            model_key = _fix_key(key)
             if remove_prefix_from_model:
-                _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
+                # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
+                model_key = f"{prefix}.{key}"
             elif add_prefix_to_model:
-                _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
-            else:
-                _loaded_keys = loaded_keys
+                # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
+                model_key = key[len(prefix) + 1 :]
+
+            return model_key
+
+        # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
+        if _fast_init:
+            _loaded_keys = [
+                checkpoint_key_to_model_key(k, remove_prefix_from_model, add_prefix_to_model) for k in loaded_keys
+            ]
             set_initialized_submodules(model, _loaded_keys)
             # This will only initialize submodules that are not marked as initialized by the line above.
             model.apply(model._initialize_weights)
@@ -4004,13 +4021,9 @@ def _find_mismatched_keys(
                     # If the checkpoint is sharded, we may not have the key here.
                     if checkpoint_key not in state_dict:
                         continue
-                    model_key = checkpoint_key
-                    if remove_prefix_from_model:
-                        # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
-                        model_key = f"{prefix}.{checkpoint_key}"
-                    elif add_prefix_to_model:
-                        # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
-                        model_key = ".".join(checkpoint_key.split(".")[1:])
+                    model_key = checkpoint_key_to_model_key(
+                        checkpoint_key, remove_prefix_from_model, add_prefix_to_model
+                    )
 
                     if (
                         model_key in model_state_dict
@@ -4157,6 +4170,15 @@ def _find_mismatched_keys(
                 load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
                 shutil.rmtree(state_dict_folder)
 
+        if _fast_init:
+            mismatched_model_keys = [
+                checkpoint_key_to_model_key(x[0], remove_prefix_from_model, add_prefix_to_model)
+                for x in mismatched_keys
+            ]
+            set_initialized_submodules(model, mismatched_model_keys, loaded=False)
+            # This will only initialize submodules that are re-marked as `not loaded` above due to mismatched shapes.
+            model.apply(model._initialize_weights)
+
         if len(error_msgs) > 0:
             error_msg = "\n\t".join(error_msgs)
             if "size mismatch" in error_msg:
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 85e69300516164..f5b43819f161ed 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2889,6 +2889,42 @@ def test_load_with_mismatched_shapes(self):
                     else:
                         new_model_without_prefix(input_ids)
 
+    def test_mismatched_shapes_have_properly_initialized_weights(self):
+        if not self.test_mismatched_shapes:
+            return
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        configs_no_init = _config_zero_init(config)
+
+        for model_class in self.all_model_classes:
+            if model_class.__name__ not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
+                continue
+
+            with self.subTest(msg=f"Testing {model_class}"):
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    model = model_class(configs_no_init)
+                    model.save_pretrained(tmp_dir)
+
+                    # Fails when we don't set ignore_mismatched_sizes=True
+                    with self.assertRaises(RuntimeError):
+                        new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
+
+                    logger = logging.get_logger("transformers.modeling_utils")
+
+                    with CaptureLogger(logger) as cl:
+                        new_model = AutoModelForSequenceClassification.from_pretrained(
+                            tmp_dir, num_labels=42, ignore_mismatched_sizes=True
+                        )
+                    self.assertIn("the shapes did not match", cl.out)
+
+                    for name, param in new_model.named_parameters():
+                        if param.requires_grad:
+                            self.assertIn(
+                                ((param.data.mean() * 1e9).round() / 1e9).item(),
+                                [0.0, 1.0],
+                                msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                            )
+
     def test_model_is_small(self):
         # Just a consistency check to make sure we are not running tests on 80M parameter models.
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()