make torch.load a bit safer (#27282)
* make torch.load a bit safer

* Fixes

---------

Co-authored-by: Lysandre <[email protected]>
julien-c and LysandreJik authored Dec 15, 2023
1 parent 74cae67 commit dec84b3
Showing 6 changed files with 11 additions and 11 deletions.
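Context for the change (not part of the diff): torch.load deserializes with Python's pickle module, so loading an untrusted checkpoint can execute arbitrary code. Passing weights_only=True (available since PyTorch 1.13) switches to a restricted unpickler that only accepts tensors, primitive types, and standard containers. A minimal sketch of the safe path, with an illustrative file name:

import torch

# Save a plain state dict, then reload it with the restricted unpickler.
# weights_only=True admits tensors, primitives, and standard containers,
# but refuses arbitrary pickled objects.
state = {"weight": torch.randn(2, 2), "bias": torch.zeros(2)}
torch.save(state, "checkpoint.bin")

loaded = torch.load("checkpoint.bin", map_location="cpu", weights_only=True)
print(loaded["weight"].shape)  # torch.Size([2, 2])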
2 changes: 1 addition & 1 deletion src/transformers/convert_pytorch_checkpoint_to_tf2.py
@@ -329,7 +329,7 @@ def convert_pt_checkpoint_to_tf(
     if compare_with_pt_model:
         tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network

-        state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
+        state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu", weights_only=True)
         pt_model = pt_model_class.from_pretrained(
             pretrained_model_name_or_path=None, config=config, state_dict=state_dict
         )
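To see what the extra flag actually guards against, here is a hedged illustration of the pickle attack surface: an object whose __reduce__ method smuggles a callable into the checkpoint. The class and file names are made up for the demo, and at the time of this commit weights_only still defaulted to False.

import torch

class NotAWeight:
    # pickle records this reduction and replays it at load time
    def __reduce__(self):
        return (print, ("this could have been any code",))

torch.save({"w": torch.ones(1), "payload": NotAWeight()}, "tampered.bin")

# A plain torch.load (weights_only=False) would execute the recorded call
# while unpickling; the restricted unpickler rejects the unknown global.
try:
    torch.load("tampered.bin", map_location="cpu", weights_only=True)
except Exception as err:
    print(type(err).__name__)  # UnpicklingError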
4 changes: 2 additions & 2 deletions src/transformers/modeling_flax_pytorch_utils.py
@@ -68,7 +68,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
         for k in f.keys():
             pt_state_dict[k] = f.get_tensor(k)
     else:
-        pt_state_dict = torch.load(pt_path, map_location="cpu")
+        pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
     logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

     flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
@@ -249,7 +249,7 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
     flax_state_dict = {}
     for shard_file in shard_filenames:
         # load using msgpack utils
-        pt_state_dict = torch.load(shard_file)
+        pt_state_dict = torch.load(shard_file, weights_only=True)
         pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

         model_prefix = flax_model.base_model_prefix
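The sharded path above loads each shard with the restricted unpickler and then leaves PyTorch land immediately: every tensor becomes a NumPy array before the Flax mapping happens. A small sketch of that hand-off, with a made-up two-entry shard standing in for the loaded file:

import torch

# Stand-in for one shard loaded via torch.load(..., weights_only=True).
pt_state_dict = {"encoder.weight": torch.randn(4, 3), "encoder.bias": torch.zeros(3)}

# Same conversion step as in the diff: detach to NumPy so the rest of the
# pipeline (renaming keys, building the Flax dict) is framework-neutral.
np_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
print({k: a.shape for k, a in np_state_dict.items()})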
2 changes: 1 addition & 1 deletion src/transformers/modeling_tf_pytorch_utils.py
@@ -186,7 +186,7 @@ def load_pytorch_checkpoint_in_tf2_model(
         if pt_path.endswith(".safetensors"):
             state_dict = safe_load_file(pt_path)
         else:
-            state_dict = torch.load(pt_path, map_location="cpu")
+            state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)

         pt_state_dict.update(state_dict)
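All of these call sites also pin map_location="cpu", which keeps loading independent of the device the checkpoint was saved from: storages are remapped to host memory instead of being restored onto a GPU that may not exist on this machine. A minimal sketch (the file name is illustrative):

import torch

tensor = torch.arange(6, dtype=torch.float32).reshape(2, 3)
torch.save({"w": tensor}, "weights.bin")

# map_location="cpu" forces every storage onto the CPU at load time,
# even if the checkpoint was written from CUDA tensors.
state = torch.load("weights.bin", map_location="cpu", weights_only=True)
print(state["w"].device)  # cpu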
4 changes: 2 additions & 2 deletions src/transformers/modeling_utils.py
@@ -480,7 +480,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
         error_message += f"\nMissing key(s): {str_unexpected_keys}."
         raise RuntimeError(error_message)

-    loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu")
+    loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)

@hjenryin commented on Jan 10, 2024 (this comment has been minimized):
See #27282 (comment)
BTW, when I revert this from source, everything seems fine.


     for shard_file in shard_files:
         state_dict = loader(os.path.join(folder, shard_file))
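The minimized comment above is worth pausing on: weights_only=True is stricter, so checkpoints that bundle arbitrary Python objects next to their tensors stop loading, which is presumably what @hjenryin ran into. A hedged reconstruction of that failure mode, using the same partial-based loader shape as the diff (file names and the Namespace payload are made up):

from functools import partial
import argparse
import torch

loader = partial(torch.load, map_location="cpu", weights_only=True)

# Tensors plus builtin containers and primitives load fine...
torch.save({"w": torch.ones(2), "step": 10}, "shard-ok.bin")
print(loader("shard-ok.bin")["step"])  # 10

# ...but a non-allowlisted object (argparse.Namespace here) is rejected.
torch.save({"args": argparse.Namespace(lr=1e-3)}, "shard-bad.bin")
try:
    loader("shard-bad.bin")
except Exception as err:
    print(type(err).__name__)  # UnpicklingError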
@@ -516,7 +516,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
         else:
             map_location = "cpu"

-        return torch.load(checkpoint_file, map_location=map_location)
+        return torch.load(checkpoint_file, map_location=map_location, weights_only=True)
     except Exception as e:
         try:
             with open(checkpoint_file) as f:
2 changes: 1 addition & 1 deletion src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -1333,7 +1333,7 @@ def load_adapter(self, target_lang: str, force_load=True, **kwargs):
                     cache_dir=cache_dir,
                 )

-                state_dict = torch.load(weight_path, map_location="cpu")
+                state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)

             except EnvironmentError:
                 # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
8 changes: 4 additions & 4 deletions src/transformers/trainer.py
@@ -2086,7 +2086,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
                 logger.warning(
                     "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
                 )
-                state_dict = torch.load(weights_file, map_location="cpu")
+                state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
                 # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                 state_dict["_smp_is_partial"] = False
                 load_result = model.load_state_dict(state_dict, strict=True)
@@ -2099,7 +2099,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
             if self.args.save_safetensors and os.path.isfile(safe_weights_file):
                 state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
             else:
-                state_dict = torch.load(weights_file, map_location="cpu")
+                state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)

             # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
             # which takes *args instead of **kwargs
@@ -2167,7 +2167,7 @@ def _load_best_model(self):
                 if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                     state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                 else:
-                    state_dict = torch.load(best_model_path, map_location="cpu")
+                    state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)

                 state_dict["_smp_is_partial"] = False
                 load_result = model.load_state_dict(state_dict, strict=True)
@@ -2196,7 +2196,7 @@ def _load_best_model(self):
                 if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                     state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                 else:
-                    state_dict = torch.load(best_model_path, map_location="cpu")
+                    state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)

                 # If the model is on the GPU, it still works!
                 # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
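The trainer hunks all repeat one branch: prefer the safetensors file when it exists, otherwise fall back to torch.load with the restricted unpickler. A sketch of that pattern as a standalone helper (the function name is ours, not the trainer's):

import os
import torch
from safetensors.torch import load_file as safe_load_file

def load_weights(path: str) -> dict:
    # Prefer the pickle-free safetensors format; otherwise fall back to
    # torch.load with the restricted unpickler.
    if path.endswith(".safetensors") and os.path.isfile(path):
        return safe_load_file(path, device="cpu")
    return torch.load(path, map_location="cpu", weights_only=True)

safetensors sidesteps pickle entirely, so it needs no equivalent of weights_only; the flag only matters on the torch.load fallback.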
