
Commit d3e48d0
Clean up and prettify
muellerzr committed Apr 24, 2024
1 parent 3d19c62 commit d3e48d0
Showing 3 changed files with 10 additions and 7 deletions.
src/transformers/pytorch_utils.py (1 change: 1 addition & 0 deletions)
@@ -28,6 +28,7 @@

 parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)

+is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
 is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
 is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
 is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
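The flag added above follows the module's existing pattern: torch.__version__ is parsed once at import time and the result is compared against a pinned release, so call sites can test a plain boolean instead of re-parsing the version string. A minimal, self-contained sketch of that pattern, assuming only torch and packaging are installed (the print is illustrative):

import torch
from packaging import version

# base_version strips any local or pre-release suffix, e.g. "2.3.0+cu121" -> "2.3.0".
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)

# Computed once at import time; callers just test the boolean.
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")

if is_torch_greater_or_equal_than_2_3:
    print("running on torch >= 2.3")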
src/transformers/trainer.py (10 changes: 6 additions & 4 deletions)
@@ -36,7 +36,6 @@
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union


-
 # Integrations must be imported before ML frameworks:
 # isort: off
 from .integrations import (
@@ -70,7 +69,11 @@
     MODEL_MAPPING_NAMES,
 )
 from .optimization import Adafactor, get_scheduler
-from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
+from .pytorch_utils import (
+    ALL_LAYERNORM_LAYERS,
+    is_torch_greater_or_equal_than_1_13,
+    is_torch_greater_or_equal_than_2_3,
+)
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,
@@ -160,7 +163,6 @@
     is_torch_npu_available,
     is_torch_xla_available,
     logging,
-    get_torch_version,
     strtobool,
 )
 from .utils.quantization_config import QuantizationMethod
@@ -622,7 +624,7 @@ def __init__(
         if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
             if args.device == torch.device("cpu"):
                 if args.fp16:
-                    if version.parse(get_torch_version()) < version.parse("2.3.0"):
+                    if is_torch_greater_or_equal_than_2_3:
                         raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                 else:
                     args.half_precision_backend = "cpu_amp"
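The trainer.py changes are mostly mechanical: the single-line pytorch_utils import becomes a parenthesized block, get_torch_version drops out of the .utils import, and the runtime version.parse(get_torch_version()) comparison in __init__ gives way to the precomputed flag. A small sketch contrasting the two styles of check, assuming a transformers install that already contains this commit (variable names are illustrative):

from packaging import version

from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
from transformers.utils import get_torch_version

# Style being removed: parse the version string at the call site.
old_style_too_old = version.parse(get_torch_version()) < version.parse("2.3.0")

# Style being adopted: reuse the boolean computed once when pytorch_utils is imported.
new_style_too_old = not is_torch_greater_or_equal_than_2_3

# The two agree on release builds of torch; dev or pre-release builds can differ
# because base_version drops the pre-release suffix before comparing.
print(old_style_too_old, new_style_too_old)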
src/transformers/training_args.py (6 changes: 3 additions & 3 deletions)
@@ -56,7 +56,7 @@
     requires_backends,
 )
 from .utils.generic import strtobool
-from .utils.import_utils import is_optimum_neuron_available, get_torch_version
+from .utils.import_utils import is_optimum_neuron_available


 logger = logging.get_logger(__name__)
@@ -67,7 +67,7 @@
     import torch
     import torch.distributed as dist

-    from .pytorch_utils import is_torch_greater_or_equal_than_2_0
+    from .pytorch_utils import is_torch_greater_or_equal_than_2_0, is_torch_greater_or_equal_than_2_3

 if is_accelerate_available():
     from accelerate.state import AcceleratorState, PartialState
@@ -1618,7 +1618,7 @@ def __post_init__(self):
         if (
             self.framework == "pt"
             and is_torch_available()
-            and (self.device.type == "cpu" and (version.parse(get_torch_version()) < version.parse("2.3.0")))
+            and (self.device.type == "cpu" and is_torch_greater_or_equal_than_2_3)
             and (self.device.type != "cuda")
             and (self.device.type != "mlu")
             and (self.device.type != "npu")
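The training_args.py hunks mirror the trainer.py ones: the now-unused get_torch_version import goes away, the torch-gated import block pulls in the new flag, and the chained __post_init__ condition uses it directly. A standalone sketch of such a device/precision guard, assuming (as the removed version.parse(...) < "2.3.0" check implied) that fp16 on CPU needs torch >= 2.3; the helper name and error message are illustrative, not taken from the file:

import torch
from packaging import version

is_torch_greater_or_equal_than_2_3 = (
    version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.3")
)


def check_fp16_device(device_type: str, fp16: bool) -> None:
    # Hypothetical helper: raise if fp16 is requested on a CPU whose torch build
    # is too old to support it (the >= 2.3 cutoff follows the removed check).
    if fp16 and device_type == "cpu" and not is_torch_greater_or_equal_than_2_3:
        raise ValueError("fp16 mixed precision on CPU requires torch >= 2.3")


try:
    check_fp16_device("cpu", fp16=True)
except ValueError as err:
    print(err)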
