diff --git a/lightly/utils/dependency.py b/lightly/utils/dependency.py index 34c533bb7..bdd36186c 100644 --- a/lightly/utils/dependency.py +++ b/lightly/utils/dependency.py @@ -4,8 +4,7 @@ @functools.lru_cache(maxsize=1) def torchvision_vit_available() -> bool: try: - # Requires torchvision >=0.12 - import torchvision.models.vision_transformer + import torchvision.models.vision_transformer # Requires torchvision >=0.12 except ( RuntimeError, # Different CUDA versions for torch and torchvision OSError, # Different CUDA versions for torch and torchvision (old) @@ -19,8 +18,8 @@ def torchvision_vit_available() -> bool: @functools.lru_cache(maxsize=1) def timm_vit_available() -> bool: try: - # Requires timm >= 0.9.9 - import timm.models.vision_transformer + import timm.models.vision_transformer # Requires timm >= 0.3.3 + from timm.layers import LayerType # Requires timm >= 0.9.9 except ImportError: return False else: