Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't default to other weights file when use_safetensors=True #31874

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3395,14 +3395,14 @@ def from_pretrained(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
)
is_sharded = True
elif os.path.isfile(
elif not use_safetensors and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
):
# Load from a PyTorch checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
)
elif os.path.isfile(
elif not use_safetensors and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
):
# Load from a sharded PyTorch checkpoint
Expand All @@ -3411,15 +3411,18 @@ def from_pretrained(
)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
elif not use_safetensors and (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use if from_tf and ... here, to correspond to (several lines above)

            if is_local:
                if from_tf and os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
                ):

os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
):
raise EnvironmentError(
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
" `from_tf=True` to load this model from those weights."
)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
elif not use_safetensors and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
):
Comment on lines +3423 to +3425
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

raise EnvironmentError(
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
Expand Down
66 changes: 66 additions & 0 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,72 @@ def test_checkpoint_variant_local_sharded_safe(self):
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))

def test_checkpoint_loading_only_safetensors_available(self):
# Test that the loading behaviour is as expected when only safetensor checkpoints are available
# - We can load the model with use_safetensors=True
# - We can load the model without specifying use_safetensors i.e. we search for the available checkpoint,
# preferring safetensors
# - We cannot load the model with use_safetensors=False
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=True)

weights_index_name = ".".join(SAFE_WEIGHTS_INDEX_NAME.split(".")[:-1] + ["json"])
weights_index_file = os.path.join(tmp_dir, weights_index_name)
self.assertTrue(os.path.isfile(weights_index_file))

for i in range(1, 5):
weights_name = f"model-0000{i}-of-00005" + ".safetensors"
weights_name_file = os.path.join(tmp_dir, weights_name)
self.assertTrue(os.path.isfile(weights_name_file))

# Setting use_safetensors=False should raise an error as the checkpoint was saved with safetensors=True
with self.assertRaises(OSError):
_ = BertModel.from_pretrained(tmp_dir, use_safetensors=False)

# We can load the model with use_safetensors=True
new_model = BertModel.from_pretrained(tmp_dir, use_safetensors=True)

# We can load the model without specifying use_safetensors
new_model = BertModel.from_pretrained(tmp_dir)

for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))

def test_checkpoint_loading_only_pytorch_bin_available(self):
# Test that the loading behaviour is as expected when only pytorch checkpoints are available
# - We can load the model with use_safetensors=False
# - We can load the model without specifying use_safetensors i.e. we search for the available checkpoint,
# preferring safetensors but falling back to pytorch
# - We cannot load the model with use_safetensors=True
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=False)

weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["json"])
weights_index_file = os.path.join(tmp_dir, weights_index_name)
self.assertTrue(os.path.isfile(weights_index_file))

for i in range(1, 5):
weights_name = WEIGHTS_NAME.split(".")[0].split("_")[0] + f"_model-0000{i}-of-00005" + ".bin"
weights_name_file = os.path.join(tmp_dir, weights_name)
self.assertTrue(os.path.isfile(weights_name_file))

# Setting use_safetensors=True should raise an error as the checkpoint was saved with safetensors=False
with self.assertRaises(OSError):
_ = BertModel.from_pretrained(tmp_dir, use_safetensors=True)

# We can load the model with use_safetensors=False
new_model = BertModel.from_pretrained(tmp_dir, use_safetensors=False)

# We can load the model without specifying use_safetensors
new_model = BertModel.from_pretrained(tmp_dir)

for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))

def test_checkpoint_variant_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(EnvironmentError):
Expand Down
Loading