From 21ed2863571cdff0b492272733ce26247a5b6e49 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Wed, 11 Oct 2023 14:56:32 +0200 Subject: [PATCH] fix --- src/transformers/utils/import_utils.py | 28 +++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 837fb24af42a61..79684bd3b23bea 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -24,7 +24,7 @@ import sys import warnings from collections import OrderedDict -from functools import lru_cache +from functools import lru_cache, wraps from itertools import chain from types import ModuleType from typing import Any, Tuple, Union @@ -1222,6 +1222,32 @@ def __getattribute__(cls, key): requires_backends(cls, cls._backends) +def torch_required(func): + # Chose a different decorator name than in tests so it's clear they are not the same. + warnings.warn("The method `torch_required` is deprecated and will be removed in v4.36.", FutureWarning) + @wraps(func) + def wrapper(*args, **kwargs): + if is_torch_available(): + return func(*args, **kwargs) + else: + raise ImportError(f"Method `{func.__name__}` requires PyTorch.") + + return wrapper + + +def tf_required(func): + # Chose a different decorator name than in tests so it's clear they are not the same. + warnings.warn("The method `tf_required` is deprecated and will be removed in v4.36.", FutureWarning) + @wraps(func) + def wrapper(*args, **kwargs): + if is_tf_available(): + return func(*args, **kwargs) + else: + raise ImportError(f"Method `{func.__name__}` requires TF.") + + return wrapper + + def is_torch_fx_proxy(x): if is_torch_fx_available(): import torch.fx