diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index c5b80c617c85de..08f9fc82643ac8 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -186,7 +186,9 @@
     is_training_run_on_sagemaker,
     is_vision_available,
     requires_backends,
+    tf_required,
     torch_only_method,
+    torch_required,
 )
 from .peft_utils import (
     ADAPTER_CONFIG_NAME,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 837fb24af42a61..e2920b291a3462 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -24,7 +24,7 @@
 import sys
 import warnings
 from collections import OrderedDict
-from functools import lru_cache
+from functools import lru_cache, wraps
 from itertools import chain
 from types import ModuleType
 from typing import Any, Tuple, Union
@@ -1222,6 +1222,40 @@ def __getattribute__(cls, key):
         requires_backends(cls, cls._backends)


+def torch_required(func):
+    warnings.warn(
+        "The method `torch_required` is deprecated and will be removed in v4.36. Use `requires_backends` instead.",
+        FutureWarning,
+    )
+
+    # Chose a different decorator name than in tests so it's clear they are not the same.
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if is_torch_available():
+            return func(*args, **kwargs)
+        else:
+            raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
+
+    return wrapper
+
+
+def tf_required(func):
+    warnings.warn(
+        "The method `tf_required` is deprecated and will be removed in v4.36. Use `requires_backends` instead.",
+        FutureWarning,
+    )
+
+    # Chose a different decorator name than in tests so it's clear they are not the same.
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if is_tf_available():
+            return func(*args, **kwargs)
+        else:
+            raise ImportError(f"Method `{func.__name__}` requires TF.")
+
+    return wrapper
+
+
 def is_torch_fx_proxy(x):
     if is_torch_fx_available():
         import torch.fx
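
For context, a minimal usage sketch of the restored decorators, assuming the re-exports added above; the `DeviceCounter` class and its method names are hypothetical, not part of the diff:

    from transformers.utils import tf_required, torch_required


    class DeviceCounter:
        @torch_required
        def torch_gpu_count(self):
            # Reached only when PyTorch is importable; otherwise the
            # wrapper raises ImportError before this body runs.
            import torch

            return torch.cuda.device_count()

        @tf_required
        def tf_gpu_count(self):
            # Same guard, keyed on TensorFlow availability instead.
            import tensorflow as tf

            return len(tf.config.list_physical_devices("GPU"))

Note that the FutureWarning fires once, when the decorator is applied (i.e. at import time of the decorated module), while the backend availability check runs on every call through the wrapper.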