Skip to content

Commit

Permalink
Add is_torchvision_available
Browse files Browse the repository at this point in the history
  • Loading branch information
NielsRogge committed Jan 5, 2023
1 parent dbb96ae commit 113cc26
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 45 deletions.
49 changes: 35 additions & 14 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
is_timm_available,
is_tokenizers_available,
is_torch_available,
is_torchvision_available,
is_vision_available,
logging,
)
Expand Down Expand Up @@ -574,6 +575,7 @@
"is_tokenizers_available",
"is_torch_available",
"is_torch_tpu_available",
"is_torchvision_available",
"is_vision_available",
"logging",
],
Expand Down Expand Up @@ -852,6 +854,25 @@
]
)

# Torchvision-backed objects
try:
    if not is_torchvision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_torchvision_objects

    # BUGFIX: both the key and the dir() target previously read
    # "dummmy_torchvision_objects" (three m's). The dir() call raised a
    # NameError at import time whenever torchvision was missing, and the
    # misspelled key would never match the real utils submodule.
    _import_structure["utils.dummy_torchvision_objects"] = [
        name for name in dir(dummy_torchvision_objects) if not name.startswith("_")
    ]
else:
    # Torchvision present: expose the real DETA modeling objects lazily.
    _import_structure["models.deta"].extend(
        [
            "DETA_PRETRAINED_MODEL_ARCHIVE_LIST",
            "DetaForObjectDetection",
            "DetaModel",
            "DetaPreTrainedModel",
        ]
    )

# PyTorch-backed objects
try:
Expand Down Expand Up @@ -1315,14 +1336,6 @@
"DeiTPreTrainedModel",
]
)
_import_structure["models.deta"].extend(
[
"DETA_PRETRAINED_MODEL_ARCHIVE_LIST",
"DetaForObjectDetection",
"DetaModel",
"DetaPreTrainedModel",
]
)
_import_structure["models.dinat"].extend(
[
"DINAT_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -3932,6 +3945,7 @@
is_tokenizers_available,
is_torch_available,
is_torch_tpu_available,
is_torchvision_available,
is_vision_available,
logging,
)
Expand Down Expand Up @@ -4150,6 +4164,19 @@
TableTransformerPreTrainedModel,
)

# Gate the eager DETA imports on torchvision availability; when torchvision is
# missing, fall back to the generated dummy objects, which raise an
# informative error on first use instead of failing at import time.
try:
    if not is_torchvision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils.dummy_torchvision_objects import *
else:
    from .models.deta import (
        DETA_PRETRAINED_MODEL_ARCHIVE_LIST,
        DetaForObjectDetection,
        DetaModel,
        DetaPreTrainedModel,
    )

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
Expand Down Expand Up @@ -4542,12 +4569,6 @@
DeiTModel,
DeiTPreTrainedModel,
)
from .models.deta import (
DETA_PRETRAINED_MODEL_ARCHIVE_LIST,
DetaForObjectDetection,
DetaModel,
DetaPreTrainedModel,
)
from .models.dinat import (
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST,
DinatBackbone,
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/deta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torchvision_available, is_vision_available


_import_structure = {
Expand All @@ -34,7 +34,7 @@
_import_structure["image_processing_deta"] = ["DetaImageProcessor"]

try:
if not is_torch_available():
if not is_torchvision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
Expand All @@ -59,7 +59,7 @@
from .image_processing_deta import DetaImageProcessor

try:
if not is_torch_available():
if not is_torchvision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
Expand Down
11 changes: 11 additions & 0 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
is_torch_tpu_available,
is_torchaudio_available,
is_torchdynamo_available,
is_torchvision_available,
is_vision_available,
)

Expand Down Expand Up @@ -305,6 +306,16 @@ def require_torch(test_case):
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)


def require_torchvision(test_case):
    """
    Decorator marking a test that requires Torchvision.
    These tests are skipped when Torchvision isn't installed.
    """
    skip_without_torchvision = unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")
    return skip_without_torchvision(test_case)


def require_torch_or_tf(test_case):
"""
Decorator marking a test that requires PyTorch or TensorFlow.
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@
is_torchaudio_available,
is_torchdistx_available,
is_torchdynamo_available,
is_torchvision_available,
is_training_run_on_sagemaker,
is_vision_available,
requires_backends,
Expand Down
24 changes: 0 additions & 24 deletions src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1953,30 +1953,6 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


DETA_PRETRAINED_MODEL_ARCHIVE_LIST = None


class DetaForObjectDetection(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class DetaModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class DetaPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


DINAT_PRETRAINED_MODEL_ARCHIVE_LIST = None


Expand Down
27 changes: 27 additions & 0 deletions src/transformers/utils/dummy_torchvision_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends


# Placeholder for the real archive list; None signals "torchvision missing".
DETA_PRETRAINED_MODEL_ARCHIVE_LIST = None


class DetaForObjectDetection(metaclass=DummyObject):
    """Dummy stand-in for DetaForObjectDetection when torchvision is absent."""

    _backends = ["torchvision"]

    def __init__(self, *args, **kwargs):
        # Raises an error explaining that torchvision must be installed.
        requires_backends(self, ["torchvision"])


class DetaModel(metaclass=DummyObject):
    """Dummy stand-in for DetaModel when torchvision is absent."""

    _backends = ["torchvision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torchvision"])


class DetaPreTrainedModel(metaclass=DummyObject):
    """Dummy stand-in for DetaPreTrainedModel when torchvision is absent."""

    _backends = ["torchvision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torchvision"])
13 changes: 13 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,10 @@ def is_torch_available():
return _torch_available


def is_torchvision_available():
    """Return True if the `torchvision` package can be imported, False otherwise."""
    torchvision_spec = importlib.util.find_spec("torchvision")
    return torchvision_spec is not None


def is_pyctcdecode_available():
return _pyctcdecode_available

Expand Down Expand Up @@ -792,6 +796,14 @@ def is_jumanpp_available():
Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
TORCHVISION_IMPORT_ERROR = """
{0} requires the Torchvision library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
PYTORCH_IMPORT_ERROR_WITH_TF = """
{0} requires the PyTorch library but it was not found in your environment.
Expand Down Expand Up @@ -998,6 +1010,7 @@ def is_jumanpp_available():
("natten", (is_natten_available, NATTEN_IMPORT_ERROR)),
("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
("vision", (is_vision_available, VISION_IMPORT_ERROR)),
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
Expand Down
9 changes: 5 additions & 4 deletions tests/models/deta/test_modeling_deta.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import unittest
from typing import Dict, List, Tuple

from transformers import DetaConfig, is_torch_available, is_vision_available
from transformers import DetaConfig, is_torch_available, is_torchvision_available, is_vision_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import require_torchvision, require_vision, slow, torch_device

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
Expand All @@ -32,6 +32,7 @@
if is_torch_available():
import torch

if is_torchvision_available():
from transformers import DetaForObjectDetection, DetaModel


Expand Down Expand Up @@ -165,7 +166,7 @@ def create_and_check_deta_object_detection_head_model(self, config, pixel_values
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))


@require_torch
@require_torchvision
class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (DetaModel, DetaForObjectDetection) if is_torch_available() else ()
is_encoder_decoder = True
Expand Down Expand Up @@ -506,7 +507,7 @@ def prepare_img():
return image


@require_torch
@require_torchvision
@require_vision
@slow
class DetaModelIntegrationTests(unittest.TestCase):
Expand Down

0 comments on commit 113cc26

Please sign in to comment.