diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index cab0b0d4aec72b..ae6c0627bb2677 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -28,6 +28,7 @@
 
 parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
 
+is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
 is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
 is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
 is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e5411bb579f41a..a0551aea59f370 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -36,7 +36,6 @@
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
-
 
 
 # Integrations must be imported before ML frameworks:
 # isort: off
 from .integrations import (
@@ -70,7 +69,11 @@
     MODEL_MAPPING_NAMES,
 )
 from .optimization import Adafactor, get_scheduler
-from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
+from .pytorch_utils import (
+    ALL_LAYERNORM_LAYERS,
+    is_torch_greater_or_equal_than_1_13,
+    is_torch_greater_or_equal_than_2_3,
+)
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,
@@ -160,7 +163,6 @@
     is_torch_npu_available,
     is_torch_xla_available,
     logging,
-    get_torch_version,
     strtobool,
 )
 from .utils.quantization_config import QuantizationMethod
@@ -622,7 +624,7 @@ def __init__(
         if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
             if args.device == torch.device("cpu"):
                 if args.fp16:
-                    if version.parse(get_torch_version()) < version.parse("2.3.0"):
+                    if not is_torch_greater_or_equal_than_2_3:
                         raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                 else:
                     args.half_precision_backend = "cpu_amp"
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 944ee4434af89f..6e971bd9f2997d 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -56,7 +56,7 @@
     requires_backends,
 )
 from .utils.generic import strtobool
-from .utils.import_utils import is_optimum_neuron_available, get_torch_version
+from .utils.import_utils import is_optimum_neuron_available
 
 
 logger = logging.get_logger(__name__)
@@ -67,7 +67,7 @@
     import torch
     import torch.distributed as dist
 
-    from .pytorch_utils import is_torch_greater_or_equal_than_2_0
+    from .pytorch_utils import is_torch_greater_or_equal_than_2_0, is_torch_greater_or_equal_than_2_3
 
 if is_accelerate_available():
     from accelerate.state import AcceleratorState, PartialState
@@ -1618,7 +1618,7 @@ def __post_init__(self):
         if (
             self.framework == "pt"
             and is_torch_available()
-            and (self.device.type == "cpu" and (version.parse(get_torch_version()) < version.parse("2.3.0")))
+            and (self.device.type == "cpu" and not is_torch_greater_or_equal_than_2_3)
             and (self.device.type != "cuda")
             and (self.device.type != "mlu")
             and (self.device.type != "npu")
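
For context, a minimal self-contained sketch of how the new `is_torch_greater_or_equal_than_2_3` gate is meant to behave (an illustration, not part of the patch: `check_cpu_fp16_supported` is a hypothetical stand-in for the Trainer's fp16-on-CPU check, and the flag is re-derived locally instead of being imported from `transformers.pytorch_utils`):

import torch
from packaging import version

# Same derivation as the flag added in pytorch_utils.py above.
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")


def check_cpu_fp16_supported() -> None:
    # fp16 on CPU only works on torch >= 2.3, so the guard raises when the
    # flag is False -- the same shape as the updated check in trainer.py.
    if not is_torch_greater_or_equal_than_2_3:
        raise ValueError("Tried to use `fp16` but it is not supported on cpu")


check_cpu_fp16_supported()  # raises on torch < 2.3, does nothing otherwise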