Enable fp16 on CPU (#30459)
* Check removing flag for torch

* LLM oops

* Getting there...

* More discoveries

* Change

* Clean up and prettify

* Logic check

* Not
muellerzr authored Apr 24, 2024
1 parent d1d94d7 commit 5c57463
Showing 3 changed files with 10 additions and 3 deletions.
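In practice, relaxing the gate means a Trainer run with fp16=True on a CPU-only machine is no longer rejected outright once the installed PyTorch is 2.3 or newer. A minimal sketch of such a configuration, assuming torch >= 2.3 and a transformers build that includes this commit (the output directory is a placeholder):

from transformers import TrainingArguments

# Hypothetical CPU-only fp16 setup; assumes torch >= 2.3 and a transformers
# build containing this commit. On older torch, __post_init__ still raises.
args = TrainingArguments(
    output_dir="out",  # placeholder path
    use_cpu=True,      # force CPU, so args.device.type == "cpu"
    fp16=True,         # previously rejected on CPU, now allowed with torch >= 2.3
)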
1 change: 1 addition & 0 deletions src/transformers/pytorch_utils.py
@@ -28,6 +28,7 @@
 
 parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
 
+is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
 is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
 is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
 is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
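The new flag follows the same pattern as the existing version gates in pytorch_utils.py; a standalone sketch of how it evaluates (the version helper comes from the packaging library, which is what that module already uses):

import torch
from packaging import version

# Parse the installed torch version once, then compare its base version
# against the minimum release, exactly as in the flags above.
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")

print(is_torch_greater_or_equal_than_2_3)  # True on torch 2.3.x and later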
9 changes: 7 additions & 2 deletions src/transformers/trainer.py
@@ -69,7 +69,11 @@
     MODEL_MAPPING_NAMES,
 )
 from .optimization import Adafactor, get_scheduler
-from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
+from .pytorch_utils import (
+    ALL_LAYERNORM_LAYERS,
+    is_torch_greater_or_equal_than_1_13,
+    is_torch_greater_or_equal_than_2_3,
+)
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,
@@ -620,7 +624,8 @@ def __init__(
         if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
             if args.device == torch.device("cpu"):
                 if args.fp16:
-                    raise ValueError("Tried to use `fp16` but it is not supported on cpu")
+                    if not is_torch_greater_or_equal_than_2_3:
+                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                 else:
                     args.half_precision_backend = "cpu_amp"
             logger.info(f"Using {args.half_precision_backend} half precision backend")
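The torch 2.3 floor presumably reflects CPU autocast gaining float16 support around that release (earlier versions only took bfloat16 on CPU). A minimal plain-PyTorch sketch of fp16 autocast on CPU, with a placeholder model rather than anything from this commit:

import torch

model = torch.nn.Linear(8, 2)  # placeholder model
x = torch.randn(4, 8)

# Assumes a torch build where CPU autocast accepts float16 (the reason for the
# 2.3 gate above); older torch warns or errors instead of autocasting.
with torch.autocast(device_type="cpu", dtype=torch.float16):
    out = model(x)

print(out.dtype)  # expected torch.float16 when fp16 CPU autocast is supported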
3 changes: 2 additions & 1 deletion src/transformers/training_args.py
@@ -67,7 +67,7 @@
     import torch
     import torch.distributed as dist
 
-    from .pytorch_utils import is_torch_greater_or_equal_than_2_0
+    from .pytorch_utils import is_torch_greater_or_equal_than_2_0, is_torch_greater_or_equal_than_2_3
 
 if is_accelerate_available():
     from accelerate.state import AcceleratorState, PartialState
@@ -1618,6 +1618,7 @@ def __post_init__(self):
         if (
             self.framework == "pt"
             and is_torch_available()
+            and (self.device.type == "cpu" and not is_torch_greater_or_equal_than_2_3)
             and (self.device.type != "cuda")
             and (self.device.type != "mlu")
             and (self.device.type != "npu")
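Restated outside of __post_init__, the added clause only blocks fp16 when the device is a CPU running torch older than 2.3; the real condition continues with further device checks beyond the truncated hunk. An illustrative helper (the function name is made up for this sketch):

import torch
from packaging import version

_TORCH_GE_2_3 = version.parse(
    version.parse(torch.__version__).base_version
) >= version.parse("2.3")


def fp16_blocked_on_cpu(device_type: str) -> bool:
    # Mirrors the added clause: only CPU devices on torch < 2.3 trip the error.
    return device_type == "cpu" and not _TORCH_GE_2_3


print(fp16_blocked_on_cpu("cpu"))   # False on torch >= 2.3, True otherwise
print(fp16_blocked_on_cpu("cuda"))  # False: the CPU clause never applies to CUDA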
