Skip to content

Commit

Permalink
refactor: add typings to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vectorvp committed Nov 27, 2024
1 parent c71dc71 commit c385598
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 83 deletions.
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,6 @@ exclude = '''(?x)(
tests/utils/benchmarking/test_linear_classifier.py |
tests/utils/benchmarking/test_metric_callback.py |
tests/utils/test_dist.py |
tests/models/modules/test_masked_autoencoder.py |
tests/models/test_ModelUtils.py |
tests/models/test_ProjectionHeads.py |
tests/conftest.py |
tests/api_workflow/test_api_workflow_selection.py |
tests/api_workflow/test_api_workflow_datasets.py |
Expand Down
36 changes: 21 additions & 15 deletions tests/models/modules/test_masked_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
dependency.torchvision_vit_available(), "Torchvision ViT not available"
)
class TestMAEEncoder(unittest.TestCase):
def _vit(self):
def _vit(self) -> torchvision.models.vision_transformer.VisionTransformer:
return torchvision.models.vision_transformer.vit_b_32(progress=False)

def test_from_vit(self):
def test_from_vit(self) -> None:
MAEEncoder.from_vit_encoder(self._vit().encoder)

def _test_forward(self, device, batch_size=8, seed=0):
def _test_forward(
self, device: torch.device, batch_size: int = 8, seed: int = 0
) -> None:
torch.manual_seed(seed)
vit = self._vit()
encoder = MAEEncoder.from_vit_encoder(vit.encoder).to(device)
Expand All @@ -42,25 +44,27 @@ def _test_forward(self, device, batch_size=8, seed=0):
# output must have reasonable numbers
self.assertTrue(torch.all(torch.not_equal(out, torch.inf)))

def test_forward(self):
self._test_forward(torch.device("cpu"))
def test_forward(self) -> None:
self._test_forward(device=torch.device("cpu"))

@unittest.skipUnless(torch.cuda.is_available(), "Cuda not available.")
def test_forward_cuda(self):
def test_forward_cuda(self) -> None:
self._test_forward(torch.device("cuda"))


@unittest.skipUnless(
dependency.torchvision_vit_available(), "Torchvision ViT not available"
)
class TestMAEBackbone(unittest.TestCase):
def _vit(self):
def _vit(self) -> torchvision.models.vision_transformer.VisionTransformer:
return torchvision.models.vision_transformer.vit_b_32(progress=False)

def test_from_vit(self):
def test_from_vit(self) -> None:
MAEBackbone.from_vit(self._vit())

def _test_forward(self, device, batch_size=8, seed=0):
def _test_forward(
self, device: torch.device, batch_size: int = 8, seed: int = 0
) -> None:
torch.manual_seed(seed)
vit = self._vit()
backbone = MAEBackbone.from_vit(vit).to(device)
Expand All @@ -80,11 +84,11 @@ def _test_forward(self, device, batch_size=8, seed=0):
# output must have reasonable numbers
self.assertTrue(torch.all(torch.not_equal(class_tokens, torch.inf)))

def test_forward(self):
def test_forward(self) -> None:
self._test_forward(torch.device("cpu"))

@unittest.skipUnless(torch.cuda.is_available(), "Cuda not available.")
def test_forward_cuda(self):
def test_forward_cuda(self) -> None:
self._test_forward(torch.device("cuda"))

def test_images_to_tokens(self) -> None:
Expand All @@ -102,7 +106,7 @@ def test_images_to_tokens(self) -> None:
dependency.torchvision_vit_available(), "Torchvision ViT not available"
)
class TestMAEDecoder(unittest.TestCase):
def test_init(self):
def test_init(self) -> None:
MAEDecoder(
seq_length=50,
num_layers=2,
Expand All @@ -113,7 +117,9 @@ def test_init(self):
out_dim=3 * 32**2,
)

def _test_forward(self, device, batch_size=8, seed=0):
def _test_forward(
self, device: torch.device, batch_size: int = 8, seed: int = 0
) -> None:
torch.manual_seed(seed)
seq_length = 50
embed_input_dim = 128
Expand All @@ -137,9 +143,9 @@ def _test_forward(self, device, batch_size=8, seed=0):
# output must have reasonable numbers
self.assertTrue(torch.all(torch.not_equal(predictions, torch.inf)))

def test_forward(self):
def test_forward(self) -> None:
self._test_forward(torch.device("cpu"))

@unittest.skipUnless(torch.cuda.is_available(), "Cuda not available.")
def test_forward_cuda(self):
def test_forward_cuda(self) -> None:
self._test_forward(torch.device("cuda"))
Loading

0 comments on commit c385598

Please sign in to comment.