Skip to content

Commit

Permalink
Don't default to other weights file when use_safetensors=True (huggin…
Browse files Browse the repository at this point in the history
…gface#31874)

* Don't default to other weights file when use_safetensors=True

* Add tests

* Update tests/utils/test_modeling_utils.py

* Add clarifying comments to tests

* Update tests/utils/test_modeling_utils.py

* Update tests/utils/test_modeling_utils.py
  • Loading branch information
amyeroberts authored and MHRDYN7 committed Jul 23, 2024
1 parent 30d8ae9 commit ff8db80
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3395,14 +3395,14 @@ def from_pretrained(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
)
is_sharded = True
elif os.path.isfile(
elif not use_safetensors and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
):
# Load from a PyTorch checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
)
elif os.path.isfile(
elif not use_safetensors and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
):
# Load from a sharded PyTorch checkpoint
Expand All @@ -3411,15 +3411,18 @@ def from_pretrained(
)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
elif not use_safetensors and (
os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
):
raise EnvironmentError(
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
" `from_tf=True` to load this model from those weights."
)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
elif not use_safetensors and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
):
raise EnvironmentError(
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
Expand Down
66 changes: 66 additions & 0 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,72 @@ def test_checkpoint_variant_local_sharded_safe(self):
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))

def test_checkpoint_loading_only_safetensors_available(self):
# Test that the loading behaviour is as expected when only safetensor checkpoints are available
# - We can load the model with use_safetensors=True
# - We can load the model without specifying use_safetensors i.e. we search for the available checkpoint,
# preferring safetensors
# - We cannot load the model with use_safetensors=False
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=True)

weights_index_name = ".".join(SAFE_WEIGHTS_INDEX_NAME.split(".")[:-1] + ["json"])
weights_index_file = os.path.join(tmp_dir, weights_index_name)
self.assertTrue(os.path.isfile(weights_index_file))

for i in range(1, 5):
weights_name = f"model-0000{i}-of-00005" + ".safetensors"
weights_name_file = os.path.join(tmp_dir, weights_name)
self.assertTrue(os.path.isfile(weights_name_file))

# Setting use_safetensors=False should raise an error as the checkpoint was saved with safetensors=True
with self.assertRaises(OSError):
_ = BertModel.from_pretrained(tmp_dir, use_safetensors=False)

# We can load the model with use_safetensors=True
new_model = BertModel.from_pretrained(tmp_dir, use_safetensors=True)

# We can load the model without specifying use_safetensors
new_model = BertModel.from_pretrained(tmp_dir)

for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))

def test_checkpoint_loading_only_pytorch_bin_available(self):
# Test that the loading behaviour is as expected when only pytorch checkpoints are available
# - We can load the model with use_safetensors=False
# - We can load the model without specifying use_safetensors i.e. we search for the available checkpoint,
# preferring safetensors but falling back to pytorch
# - We cannot load the model with use_safetensors=True
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=False)

weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["json"])
weights_index_file = os.path.join(tmp_dir, weights_index_name)
self.assertTrue(os.path.isfile(weights_index_file))

for i in range(1, 5):
weights_name = WEIGHTS_NAME.split(".")[0].split("_")[0] + f"_model-0000{i}-of-00005" + ".bin"
weights_name_file = os.path.join(tmp_dir, weights_name)
self.assertTrue(os.path.isfile(weights_name_file))

# Setting use_safetensors=True should raise an error as the checkpoint was saved with safetensors=False
with self.assertRaises(OSError):
_ = BertModel.from_pretrained(tmp_dir, use_safetensors=True)

# We can load the model with use_safetensors=False
new_model = BertModel.from_pretrained(tmp_dir, use_safetensors=False)

# We can load the model without specifying use_safetensors
new_model = BertModel.from_pretrained(tmp_dir)

for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))

def test_checkpoint_variant_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(EnvironmentError):
Expand Down

0 comments on commit ff8db80

Please sign in to comment.