From 0c60daa79f0d5e33bd892c0701a9c00c658f1a25 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Fri, 15 Dec 2023 14:48:47 +0100 Subject: [PATCH 01/10] fix --- src/transformers/modeling_utils.py | 42 +++++++++++++++++++----------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7e5d3e54e619e8..282d8c1295455e 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -539,7 +539,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): ) -def set_initialized_submodules(model, state_dict_keys): +def set_initialized_submodules(model, state_dict_keys, loaded=True): """ Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state dict. @@ -547,7 +547,7 @@ def set_initialized_submodules(model, state_dict_keys): for module_name, module in model.named_modules(): loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")] if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0: - module._is_hf_initialized = True + module._is_hf_initialized = loaded def _load_state_dict_into_model(model_to_load, state_dict, start_prefix): @@ -3955,14 +3955,22 @@ def _fix_key(key): model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype) ) - # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights. - if _fast_init: + def checkpoint_key_to_model_key(key, remove_prefix_from_model, add_prefix_to_model): + model_key = _fix_key(key) if remove_prefix_from_model: - _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys] + # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. + model_key = f"{prefix}.{key}" elif add_prefix_to_model: - _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys] - else: - _loaded_keys = loaded_keys + # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. + model_key = key[len(prefix) + 1 :] + + return model_key + + # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights. + if _fast_init: + _loaded_keys = [ + checkpoint_key_to_model_key(k, remove_prefix_from_model, add_prefix_to_model) for k in loaded_keys + ] set_initialized_submodules(model, _loaded_keys) # This will only initialize submodules that are not marked as initialized by the line above. model.apply(model._initialize_weights) @@ -4004,13 +4012,9 @@ def _find_mismatched_keys( # If the checkpoint is sharded, we may not have the key here. if checkpoint_key not in state_dict: continue - model_key = checkpoint_key - if remove_prefix_from_model: - # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. - model_key = f"{prefix}.{checkpoint_key}" - elif add_prefix_to_model: - # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. 
- model_key = ".".join(checkpoint_key.split(".")[1:]) + model_key = checkpoint_key_to_model_key( + checkpoint_key, remove_prefix_from_model, add_prefix_to_model + ) if ( model_key in model_state_dict @@ -4157,6 +4161,14 @@ def _find_mismatched_keys( load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder) shutil.rmtree(state_dict_folder) + if _fast_init: + mismatched_model_keys = [ + checkpoint_key_to_model_key(k, remove_prefix_from_model, add_prefix_to_model) for k in mismatched_keys + ] + set_initialized_submodules(model, mismatched_model_keys, loaded=False) + # This will only initialize submodules that are re-marked as `not loaded` above due to mismatched + model.apply(model._initialize_weights) + if len(error_msgs) > 0: error_msg = "\n\t".join(error_msgs) if "size mismatch" in error_msg: From dec51b6c11c5d13dcbd816ed43bf3eb31cae31f1 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Fri, 15 Dec 2023 14:53:12 +0100 Subject: [PATCH 02/10] fix --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 282d8c1295455e..0378697547ad69 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4163,7 +4163,7 @@ def _find_mismatched_keys( if _fast_init: mismatched_model_keys = [ - checkpoint_key_to_model_key(k, remove_prefix_from_model, add_prefix_to_model) for k in mismatched_keys + checkpoint_key_to_model_key(x[0], remove_prefix_from_model, add_prefix_to_model) for x in mismatched_keys ] set_initialized_submodules(model, mismatched_model_keys, loaded=False) # This will only initialize submodules that are re-marked as `not loaded` above due to mismatched From 0668589af07d848aa7c01a6dd3836441c5275eb5 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Fri, 15 Dec 2023 15:28:20 +0100 Subject: [PATCH 03/10] fix --- src/transformers/modeling_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0378697547ad69..52647abd132c5d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4163,7 +4163,8 @@ def _find_mismatched_keys( if _fast_init: mismatched_model_keys = [ - checkpoint_key_to_model_key(x[0], remove_prefix_from_model, add_prefix_to_model) for x in mismatched_keys + checkpoint_key_to_model_key(x[0], remove_prefix_from_model, add_prefix_to_model) + for x in mismatched_keys ] set_initialized_submodules(model, mismatched_model_keys, loaded=False) # This will only initialize submodules that are re-marked as `not loaded` above due to mismatched From 2f58425101817f2c6c162a2710cf45a991e79a87 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Fri, 15 Dec 2023 16:03:51 +0100 Subject: [PATCH 04/10] fix --- src/transformers/modeling_utils.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 52647abd132c5d..57bb1b6f8ff9f0 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -545,9 +545,18 @@ def set_initialized_submodules(model, state_dict_keys, loaded=True): dict. 
""" for module_name, module in model.named_modules(): - loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")] - if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0: - module._is_hf_initialized = loaded + if loaded: + loaded_keys = [ + k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.") + ] + if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0: + module._is_hf_initialized = loaded + else: + not_loaded_keys = [ + k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.") + ] + if len(set(module.state_dict().keys()).intersection(not_loaded_keys)) == 0: + module._is_hf_initialized = False def _load_state_dict_into_model(model_to_load, state_dict, start_prefix): From 67ba4c56cf5512a5d9452c5e2f55127bec6ce937 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Fri, 15 Dec 2023 16:10:31 +0100 Subject: [PATCH 05/10] fix --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 57bb1b6f8ff9f0..d7fbb9c278d12d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -555,7 +555,7 @@ def set_initialized_submodules(model, state_dict_keys, loaded=True): not_loaded_keys = [ k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.") ] - if len(set(module.state_dict().keys()).intersection(not_loaded_keys)) == 0: + if len(set(module.state_dict().keys()).intersection(not_loaded_keys)) > 0: module._is_hf_initialized = False From b79c54a0a3674708401e0f7c531860e117a95d33 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Fri, 15 Dec 2023 17:10:49 +0100 Subject: [PATCH 06/10] fix --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d7fbb9c278d12d..bc996dca369c9f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -555,7 +555,7 @@ def set_initialized_submodules(model, state_dict_keys, loaded=True): not_loaded_keys = [ k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.") ] - if len(set(module.state_dict().keys()).intersection(not_loaded_keys)) > 0: + if set(module.state_dict().keys()) == set(not_loaded_keys): module._is_hf_initialized = False From 69e984170abb8bd86e3ddda0af0bd7ec62e303d2 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Mon, 18 Dec 2023 14:17:01 +0100 Subject: [PATCH 07/10] add test --- tests/test_modeling_common.py | 42 +++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 85e69300516164..bac52f0e14530a 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2889,6 +2889,48 @@ def test_load_with_mismatched_shapes(self): else: new_model_without_prefix(input_ids) + def test_mismatched_shapes_have_properly_initialized_weights(self): + if not self.test_mismatched_shapes: + return + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if model_class.__name__ not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES): + continue + + with self.subTest(msg=f"Testing {model_class}"): + with tempfile.TemporaryDirectory() as tmp_dir: + model = model_class(config) + 
model.save_pretrained(tmp_dir) + + # Fails when we don't set ignore_mismatched_sizes=True + with self.assertRaises(RuntimeError): + new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42) + + logger = logging.get_logger("transformers.modeling_utils") + + with CaptureLogger(logger) as cl: + new_model = AutoModelForSequenceClassification.from_pretrained( + tmp_dir, num_labels=42, ignore_mismatched_sizes=True + ) + self.assertIn("the shapes did not match", cl.out) + + with CaptureLogger(logger) as cl: + new_model_2 = AutoModelForSequenceClassification.from_pretrained( + tmp_dir, + num_labels=42, + ignore_mismatched_sizes=True, + ) + self.assertIn("the shapes did not match", cl.out) + + # The classifier heads of `new_model` and `new_model_2` should contain different weight values. + diff_found = False + for key in new_model.state_dict(): + if not torch.allclose(new_model.state_dict()[key], new_model_2.state_dict()[key], atol=1e-9): + diff_found = True + break + self.assertTrue(diff_found) + def test_model_is_small(self): # Just a consistency check to make sure we are not running tests on 80M parameter models. config, _ = self.model_tester.prepare_config_and_inputs_for_common() From 16a4ad8c7f185c61e34c18173d3a1ac636732e51 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Mon, 18 Dec 2023 15:06:24 +0100 Subject: [PATCH 08/10] fix --- tests/test_modeling_common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index bac52f0e14530a..0c25737bff7e82 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2910,12 +2910,14 @@ def test_mismatched_shapes_have_properly_initialized_weights(self): logger = logging.get_logger("transformers.modeling_utils") with CaptureLogger(logger) as cl: + torch.manual_seed(0) new_model = AutoModelForSequenceClassification.from_pretrained( tmp_dir, num_labels=42, ignore_mismatched_sizes=True ) self.assertIn("the shapes did not match", cl.out) with CaptureLogger(logger) as cl: + torch.manual_seed(0) new_model_2 = AutoModelForSequenceClassification.from_pretrained( tmp_dir, num_labels=42, From 2ebe97444a6077d93c2c2522c5815db4003b1a84 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Mon, 18 Dec 2023 15:12:00 +0100 Subject: [PATCH 09/10] fix --- tests/test_modeling_common.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 0c25737bff7e82..a01f648ad640b4 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2930,8 +2930,7 @@ def test_mismatched_shapes_have_properly_initialized_weights(self): for key in new_model.state_dict(): if not torch.allclose(new_model.state_dict()[key], new_model_2.state_dict()[key], atol=1e-9): diff_found = True - break - self.assertTrue(diff_found) + self.assertFalse(diff_found) def test_model_is_small(self): # Just a consistency check to make sure we are not running tests on 80M parameter models. 
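A note on the modeling_utils side of this series (patches 01-06): checkpoint keys are first normalized to model-side names by the new `checkpoint_key_to_model_key` helper (adding or stripping the base-model `prefix` depending on which side carries it), and when `ignore_mismatched_sizes=True` discards some of those keys, the affected submodules are re-flagged with `set_initialized_submodules(..., loaded=False)` so that `model.apply(model._initialize_weights)` runs for them again instead of leaving `torch.empty` memory in place. Below is a minimal, self-contained sketch of that flag logic only, using a toy two-module model; it is an illustration of the idea, not the transformers implementation.

    from torch import nn


    def set_initialized_submodules(model, state_dict_keys, loaded=True):
        # loaded=True: flag every submodule whose weights are all covered by `state_dict_keys`.
        # loaded=False: clear the flag on every submodule whose weights are exactly the
        # mismatched keys, so a later `model.apply(model._initialize_weights)` re-initializes them.
        for module_name, module in model.named_modules():
            keys = {k[len(module_name) + 1 :] for k in state_dict_keys if k.startswith(f"{module_name}.")}
            if loaded:
                if not set(module.state_dict().keys()) - keys:
                    module._is_hf_initialized = True
            elif set(module.state_dict().keys()) == keys:
                module._is_hf_initialized = False


    model = nn.Module()
    model.encoder = nn.Linear(4, 4)
    model.classifier = nn.Linear(4, 2)  # pretend this head has a different shape in the checkpoint

    # Every checkpoint key has a counterpart in the model, so both submodules are flagged ...
    set_initialized_submodules(model, ["encoder.weight", "encoder.bias", "classifier.weight", "classifier.bias"])
    # ... but the classifier weights are then found to be mismatched, so its flag is cleared again.
    set_initialized_submodules(model, ["classifier.weight", "classifier.bias"], loaded=False)
    print(model.encoder._is_hf_initialized, model.classifier._is_hf_initialized)  # True False
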
From 69af6d10cfe7f286965fd683d97ca7707536439e Mon Sep 17 00:00:00 2001 From: ydshieh Date: Mon, 18 Dec 2023 15:26:23 +0100 Subject: [PATCH 10/10] fix --- tests/test_modeling_common.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a01f648ad640b4..f5b43819f161ed 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2894,13 +2894,15 @@ def test_mismatched_shapes_have_properly_initialized_weights(self): return config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: if model_class.__name__ not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES): continue with self.subTest(msg=f"Testing {model_class}"): with tempfile.TemporaryDirectory() as tmp_dir: - model = model_class(config) + model = model_class(configs_no_init) model.save_pretrained(tmp_dir) # Fails when we don't set ignore_mismatched_sizes=True @@ -2910,27 +2912,18 @@ def test_mismatched_shapes_have_properly_initialized_weights(self): logger = logging.get_logger("transformers.modeling_utils") with CaptureLogger(logger) as cl: - torch.manual_seed(0) new_model = AutoModelForSequenceClassification.from_pretrained( tmp_dir, num_labels=42, ignore_mismatched_sizes=True ) self.assertIn("the shapes did not match", cl.out) - with CaptureLogger(logger) as cl: - torch.manual_seed(0) - new_model_2 = AutoModelForSequenceClassification.from_pretrained( - tmp_dir, - num_labels=42, - ignore_mismatched_sizes=True, - ) - self.assertIn("the shapes did not match", cl.out) - - # The classifier heads of `new_model` and `new_model_2` should contain different weight values. - diff_found = False - for key in new_model.state_dict(): - if not torch.allclose(new_model.state_dict()[key], new_model_2.state_dict()[key], atol=1e-9): - diff_found = True - self.assertFalse(diff_found) + for name, param in new_model.named_parameters(): + if param.requires_grad: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) def test_model_is_small(self): # Just a consistency check to make sure we are not running tests on 80M parameter models.
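
The final shape of the test (patch 10) no longer compares two reloads against each other: it builds the model from a zero-init config (`_config_zero_init`), reloads it with `num_labels=42` and `ignore_mismatched_sizes=True`, and asserts that every trainable parameter's mean rounds to exactly 0.0 or 1.0, which leftover `torch.empty` memory essentially never does. Below is a minimal sketch of just that check, assuming a toy `TinyHead` module and using `1e-10` as a stand-in for the zeroed initializer range; it is not the transformers test itself.

    from torch import nn


    class TinyHead(nn.Module):
        # Toy stand-in for a freshly resized classification head; not a transformers model.
        def __init__(self, hidden_size=8, num_labels=42):
            super().__init__()
            self.norm = nn.LayerNorm(hidden_size)
            self.classifier = nn.Linear(hidden_size, num_labels)
            # Mimic a zero-init config: weights drawn with a vanishingly small std, biases zeroed.
            nn.init.normal_(self.classifier.weight, mean=0.0, std=1e-10)
            nn.init.zeros_(self.classifier.bias)


    def check_properly_initialized(model):
        # Same arithmetic as the assertion in PATCH 10: a properly initialized parameter has a
        # mean of exactly 0.0 (near-zero weights, zero biases) or 1.0 (LayerNorm scales).
        for name, param in model.named_parameters():
            if param.requires_grad:
                mean = ((param.data.mean() * 1e9).round() / 1e9).item()
                assert mean in [0.0, 1.0], f"Parameter {name} seems not properly initialized"


    check_properly_initialized(TinyHead())  # passes: every mean rounds to exactly 0.0 or 1.0

Checking rounded means against {0.0, 1.0} is deliberately coarse: a parameter still holding arbitrary uninitialized memory is astronomically unlikely to land on either value, so the test catches the original bug without depending on seeds or on a second reload.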