Add a use_safetensors arg to TFPreTrainedModel.from_pretrained() (huggingface#28511)

* Add a use_safetensors arg to TFPreTrainedModel.from_pretrained()

* One more catch!

* One more one more catch
Rocketknight1 authored and wgifford committed Jan 21, 2024
1 parent 3c71365 commit 7827e41
Showing 1 changed file with 16 additions and 3 deletions.
src/transformers/modeling_tf_utils.py
@@ -2508,6 +2508,7 @@ def from_pretrained(
         local_files_only: bool = False,
         token: Optional[Union[str, bool]] = None,
         revision: str = "main",
+        use_safetensors: bool = None,
         **kwargs,
     ):
         r"""
@@ -2601,6 +2602,9 @@ def from_pretrained(
                 A function that is called to transform the names of weights during the PyTorch to TensorFlow
                 crossloading process. This is not necessary for most models, but is useful to allow composite models to
                 be crossloaded correctly.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
+                is not installed, it will be set to `False`.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                 `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -2673,6 +2677,9 @@ def from_pretrained(
             logger.info("Offline mode: forcing local_files_only=True")
             local_files_only = True
 
+        if use_safetensors is None and not is_safetensors_available():
+            use_safetensors = False
+
         # Load config if we don't provide a configuration
         if not isinstance(config, PretrainedConfig):
             config_path = config if config is not None else pretrained_model_name_or_path
@@ -2712,7 +2719,7 @@ def from_pretrained(
                     # Load from a sharded PyTorch checkpoint
                     archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
                     is_sharded = True
-                elif is_safetensors_available() and os.path.isfile(
+                elif use_safetensors is not False and os.path.isfile(
                     os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
                 ):
                     # Load from a safetensors checkpoint
@@ -2724,14 +2731,20 @@ def from_pretrained(
                     # Load from a sharded TF 2.0 checkpoint
                     archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
                     is_sharded = True
-                elif is_safetensors_available() and os.path.isfile(
+                elif use_safetensors is not False and os.path.isfile(
                     os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
                 ):
                     # Load from a sharded safetensors checkpoint
                     archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
                     is_sharded = True
                     raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
                 # At this stage we don't have a weight file so we will raise an error.
+                elif use_safetensors:
+                    raise EnvironmentError(
+                        f"Error no file named {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}. "
+                        f"Please make sure that the model has been saved with `safe_serialization=True` or do not "
+                        f"set `use_safetensors=True`."
+                    )
                 elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile(
                     os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
                 ):
@@ -2758,7 +2771,7 @@ def from_pretrained(
             # set correct filename
             if from_pt:
                 filename = WEIGHTS_NAME
-            elif is_safetensors_available():
+            elif use_safetensors is not False:
                 filename = SAFE_WEIGHTS_NAME
             else:
                 filename = TF2_WEIGHTS_NAME
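A minimal usage sketch of the new flag, assuming a transformers install that includes this change plus TensorFlow; `TFAutoModel` and the `bert-base-uncased` checkpoint are only illustrative choices, not part of the commit:

from transformers import TFAutoModel

# use_safetensors=True: require a safetensors checkpoint; per the diff above,
# an EnvironmentError is raised if no safetensors weights are found.
model = TFAutoModel.from_pretrained("bert-base-uncased", use_safetensors=True)

# use_safetensors=False: skip safetensors entirely and fall back to the
# TF2 (.h5) or PyTorch weight files.
model = TFAutoModel.from_pretrained("bert-base-uncased", use_safetensors=False)

# Default (None): prefer safetensors when the safetensors library is installed
# and the checkpoint ships one; otherwise behave as if use_safetensors=False.
model = TFAutoModel.from_pretrained("bert-base-uncased")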
