From dec84b3211992e20daabe7bcd7e9534b2cc7cc01 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 15 Dec 2023 16:01:18 +0100 Subject: [PATCH] make torch.load a bit safer (#27282) * make torch.load a bit safer * Fixes --------- Co-authored-by: Lysandre --- src/transformers/convert_pytorch_checkpoint_to_tf2.py | 2 +- src/transformers/modeling_flax_pytorch_utils.py | 4 ++-- src/transformers/modeling_tf_pytorch_utils.py | 2 +- src/transformers/modeling_utils.py | 4 ++-- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 2 +- src/transformers/trainer.py | 8 ++++---- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/transformers/convert_pytorch_checkpoint_to_tf2.py b/src/transformers/convert_pytorch_checkpoint_to_tf2.py index f1358408a5cb57..f300b0bb92c661 100755 --- a/src/transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/src/transformers/convert_pytorch_checkpoint_to_tf2.py @@ -329,7 +329,7 @@ def convert_pt_checkpoint_to_tf( if compare_with_pt_model: tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network - state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu") + state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu", weights_only=True) pt_model = pt_model_class.from_pretrained( pretrained_model_name_or_path=None, config=config, state_dict=state_dict ) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index f78c4e78c78ba8..f6014d7c208ab6 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -68,7 +68,7 @@ def load_pytorch_checkpoint_in_flax_state_dict( for k in f.keys(): pt_state_dict[k] = f.get_tensor(k) else: - pt_state_dict = torch.load(pt_path, map_location="cpu") + pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True) logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.") flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model) @@ -249,7 +249,7 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model): flax_state_dict = {} for shard_file in shard_filenames: # load using msgpack utils - pt_state_dict = torch.load(shard_file) + pt_state_dict = torch.load(shard_file, weights_only=True) pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} model_prefix = flax_model.base_model_prefix diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index c599b795bf1932..aca1b9e4d9dccf 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -186,7 +186,7 @@ def load_pytorch_checkpoint_in_tf2_model( if pt_path.endswith(".safetensors"): state_dict = safe_load_file(pt_path) else: - state_dict = torch.load(pt_path, map_location="cpu") + state_dict = torch.load(pt_path, map_location="cpu", weights_only=True) pt_state_dict.update(state_dict) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7e5d3e54e619e8..8be9709d072afe 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -480,7 +480,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True): error_message += f"\nMissing key(s): {str_unexpected_keys}." raise RuntimeError(error_message) - loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu") + loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True) for shard_file in shard_files: state_dict = loader(os.path.join(folder, shard_file)) @@ -516,7 +516,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): else: map_location = "cpu" - return torch.load(checkpoint_file, map_location=map_location) + return torch.load(checkpoint_file, map_location=map_location, weights_only=True) except Exception as e: try: with open(checkpoint_file) as f: diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 3d97e7c73d3522..ddfa2e21263f0f 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1333,7 +1333,7 @@ def load_adapter(self, target_lang: str, force_load=True, **kwargs): cache_dir=cache_dir, ) - state_dict = torch.load(weight_path, map_location="cpu") + state_dict = torch.load(weight_path, map_location="cpu", weights_only=True) except EnvironmentError: # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 3a4ff5528047ae..0b56488907fc17 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2086,7 +2086,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None): logger.warning( "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported." ) - state_dict = torch.load(weights_file, map_location="cpu") + state_dict = torch.load(weights_file, map_location="cpu", weights_only=True) # Required for smp to not auto-translate state_dict from hf to smp (is already smp). state_dict["_smp_is_partial"] = False load_result = model.load_state_dict(state_dict, strict=True) @@ -2099,7 +2099,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None): if self.args.save_safetensors and os.path.isfile(safe_weights_file): state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu") else: - state_dict = torch.load(weights_file, map_location="cpu") + state_dict = torch.load(weights_file, map_location="cpu", weights_only=True) # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 # which takes *args instead of **kwargs @@ -2167,7 +2167,7 @@ def _load_best_model(self): if self.args.save_safetensors and os.path.isfile(best_safe_model_path): state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") else: - state_dict = torch.load(best_model_path, map_location="cpu") + state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True) state_dict["_smp_is_partial"] = False load_result = model.load_state_dict(state_dict, strict=True) @@ -2196,7 +2196,7 @@ def _load_best_model(self): if self.args.save_safetensors and os.path.isfile(best_safe_model_path): state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") else: - state_dict = torch.load(best_model_path, map_location="cpu") + state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True) # If the model is on the GPU, it still works! # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963