Skip to content

Commit

Permalink
fix: Updated the is_torch_mps_available() function to include `min_…
Browse files Browse the repository at this point in the history
…version` argument (#32545)

* Fixed wrong argument in is_torch_mps_available() function call.

* Fixed wrong argument in is_torch_mps_available() function call.

* sorted the import.

* Fixed wrong argument in is_torch_mps_available() function call.

* Fixed wrong argument in is_torch_mps_available() function call.

* Update src/transformers/utils/import_utils.py

Co-authored-by: Arthur <[email protected]>

* removed extra space.

* Added type hint for the min_version parameter.

* Added missing import.

---------

Co-authored-by: Arthur <[email protected]>
  • Loading branch information
Sai-Suraj-27 and ArthurZucker authored Aug 12, 2024
1 parent f1c8542 commit 2a5a6ad
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from functools import lru_cache
from itertools import chain
from types import ModuleType
from typing import Any, Tuple, Union
from typing import Any, Optional, Tuple, Union

from packaging import version

Expand Down Expand Up @@ -420,12 +420,16 @@ def is_mambapy_available():
return False


def is_torch_mps_available():
def is_torch_mps_available(min_version: Optional[str] = None):
if is_torch_available():
import torch

if hasattr(torch.backends, "mps"):
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
backend_available = torch.backends.mps.is_available() and torch.backends.mps.is_built()
if min_version is not None:
flag = version.parse(_torch_version) >= version.parse(min_version)
backend_available = backend_available and flag
return backend_available
return False


Expand Down

0 comments on commit 2a5a6ad

Please sign in to comment.