From ff8db808bb70daa9791cd676763787c804272251 Mon Sep 17 00:00:00 2001
From: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Date: Mon, 22 Jul 2024 18:29:50 +0100
Subject: [PATCH] Don't default to other weights file when use_safetensors=True
 (#31874)

* Don't default to other weights file when use_safetensors=True

* Add tests

* Update tests/utils/test_modeling_utils.py

* Add clarifying comments to tests

* Update tests/utils/test_modeling_utils.py

* Update tests/utils/test_modeling_utils.py

---
 src/transformers/modeling_utils.py | 15 ++++---
 tests/utils/test_modeling_utils.py | 66 ++++++++++++++++++++++++++++++
 2 files changed, 75 insertions(+), 6 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index ce0086d1e3bcd8..a2cea6dcdc2483 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3395,14 +3395,14 @@ def from_pretrained(
                         pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
                     )
                     is_sharded = True
-                elif os.path.isfile(
+                elif not use_safetensors and os.path.isfile(
                     os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
                 ):
                     # Load from a PyTorch checkpoint
                     archive_file = os.path.join(
                         pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
                     )
-                elif os.path.isfile(
+                elif not use_safetensors and os.path.isfile(
                     os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
                 ):
                     # Load from a sharded PyTorch checkpoint
@@ -3411,15 +3411,18 @@ def from_pretrained(
                     )
                     is_sharded = True
                 # At this stage we don't have a weight file so we will raise an error.
-                elif os.path.isfile(
-                    os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
-                ) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
+                elif not use_safetensors and (
+                    os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
+                    or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
+                ):
                     raise EnvironmentError(
                         f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
                         f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
                         " `from_tf=True` to load this model from those weights."
                     )
-                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
+                elif not use_safetensors and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+                ):
                     raise EnvironmentError(
                         f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
                         f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index ec39c1428b28c8..c47f26cffa2d83 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -815,6 +815,72 @@ def test_checkpoint_variant_local_sharded_safe(self):
         for p1, p2 in zip(model.parameters(), new_model.parameters()):
             self.assertTrue(torch.allclose(p1, p2))
 
+    def test_checkpoint_loading_only_safetensors_available(self):
+        # Test that the loading behaviour is as expected when only safetensor checkpoints are available
+        # - We can load the model with use_safetensors=True
+        # - We can load the model without specifying use_safetensors i.e. we search for the available checkpoint,
+        #   preferring safetensors
+        # - We cannot load the model with use_safetensors=False
+        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=True)
+
+            weights_index_name = ".".join(SAFE_WEIGHTS_INDEX_NAME.split(".")[:-1] + ["json"])
+            weights_index_file = os.path.join(tmp_dir, weights_index_name)
+            self.assertTrue(os.path.isfile(weights_index_file))
+
+            for i in range(1, 5):
+                weights_name = f"model-0000{i}-of-00005" + ".safetensors"
+                weights_name_file = os.path.join(tmp_dir, weights_name)
+                self.assertTrue(os.path.isfile(weights_name_file))
+
+            # Setting use_safetensors=False should raise an error as the checkpoint was saved with safetensors=True
+            with self.assertRaises(OSError):
+                _ = BertModel.from_pretrained(tmp_dir, use_safetensors=False)
+
+            # We can load the model with use_safetensors=True
+            new_model = BertModel.from_pretrained(tmp_dir, use_safetensors=True)
+
+            # We can load the model without specifying use_safetensors
+            new_model = BertModel.from_pretrained(tmp_dir)
+
+        for p1, p2 in zip(model.parameters(), new_model.parameters()):
+            self.assertTrue(torch.allclose(p1, p2))
+
+    def test_checkpoint_loading_only_pytorch_bin_available(self):
+        # Test that the loading behaviour is as expected when only pytorch checkpoints are available
+        # - We can load the model with use_safetensors=False
+        # - We can load the model without specifying use_safetensors i.e. we search for the available checkpoint,
+        #   preferring safetensors but falling back to pytorch
+        # - We cannot load the model with use_safetensors=True
+        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=False)
+
+            weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["json"])
+            weights_index_file = os.path.join(tmp_dir, weights_index_name)
+            self.assertTrue(os.path.isfile(weights_index_file))
+
+            for i in range(1, 5):
+                weights_name = WEIGHTS_NAME.split(".")[0].split("_")[0] + f"_model-0000{i}-of-00005" + ".bin"
+                weights_name_file = os.path.join(tmp_dir, weights_name)
+                self.assertTrue(os.path.isfile(weights_name_file))
+
+            # Setting use_safetensors=True should raise an error as the checkpoint was saved with safetensors=False
+            with self.assertRaises(OSError):
+                _ = BertModel.from_pretrained(tmp_dir, use_safetensors=True)
+
+            # We can load the model with use_safetensors=False
+            new_model = BertModel.from_pretrained(tmp_dir, use_safetensors=False)
+
+            # We can load the model without specifying use_safetensors
+            new_model = BertModel.from_pretrained(tmp_dir)
+
+        for p1, p2 in zip(model.parameters(), new_model.parameters()):
+            self.assertTrue(torch.allclose(p1, p2))
+
     def test_checkpoint_variant_hub(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             with self.assertRaises(EnvironmentError):