diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 3fd067edfc5b06..46add00b018e3a 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -272,6 +272,25 @@ def _get_fsdp_ckpt_kwargs():
         return {}
 
 
+def safe_globals():
+    # Starting from version 2.4 PyTorch introduces a check for the objects loaded
+    # with torch.load(weights_only=True). Starting from 2.6 weights_only=True becomes
+    # a default and requires allowlisting of objects being loaded.
+    # See: https://github.com/pytorch/pytorch/pull/137602
+    # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
+    # See: https://github.com/huggingface/accelerate/pull/3036
+    if version.parse(torch.__version__).release < version.parse("2.6").release:
+        return contextlib.nullcontext()
+
+    np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core
+    allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype]
+    # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for
+    # all versions of numpy
+    allowlist += [type(np.dtype(np.uint32))]
+
+    return torch.serialization.safe_globals(allowlist)
+
+
 if TYPE_CHECKING:
     import optuna
 
@@ -3055,7 +3074,8 @@ def _load_rng_state(self, checkpoint):
             )
             return
 
-        checkpoint_rng_state = torch.load(rng_file)
+        with safe_globals():
+            checkpoint_rng_state = torch.load(rng_file)
         random.setstate(checkpoint_rng_state["python"])
         np.random.set_state(checkpoint_rng_state["numpy"])
         torch.random.set_rng_state(checkpoint_rng_state["cpu"])