From 38ff85a8f3867f2e0e94d8ca0dfbe01d8d510a5d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicolas=20K=C3=A4nzig?= <36882833+nkaenzig@users.noreply.github.com>
Date: Mon, 18 Mar 2024 17:02:25 +0100
Subject: [PATCH] Refactoring `EmbeddingsWriter` tests (#293)

---
 tests/eva/core/callbacks/conftest.py          | 37 ++++++++++
 .../core/callbacks/writers/test_embeddings.py | 68 ++++++------------
 tests/eva/core/models/modules/conftest.py     |  2 +-
 3 files changed, 60 insertions(+), 47 deletions(-)
 create mode 100644 tests/eva/core/callbacks/conftest.py

diff --git a/tests/eva/core/callbacks/conftest.py b/tests/eva/core/callbacks/conftest.py
new file mode 100644
index 00000000..e2014405
--- /dev/null
+++ b/tests/eva/core/callbacks/conftest.py
@@ -0,0 +1,37 @@
+"""Shared configuration and fixtures for callbacks unit tests."""
+
+import pytest
+
+from eva.core.data import dataloaders, datamodules, datasets
+
+
+@pytest.fixture(scope="function")
+def datamodule(
+    dataset: datasets.TorchDataset,
+    dataloader: dataloaders.DataLoader,
+) -> datamodules.DataModule:
+    """Returns a dummy datamodule fixture."""
+    return datamodules.DataModule(
+        datasets=datamodules.DatasetsSchema(
+            train=dataset,
+            val=dataset,
+            predict=dataset,
+        ),
+        dataloaders=datamodules.DataloadersSchema(
+            train=dataloader,
+            val=dataloader,
+            predict=dataloader,
+        ),
+    )
+
+
+@pytest.fixture(scope="function")
+def dataloader(batch_size: int) -> dataloaders.DataLoader:
+    """Test dataloader fixture."""
+    return dataloaders.DataLoader(
+        batch_size=batch_size,
+        num_workers=0,
+        pin_memory=False,
+        persistent_workers=False,
+        prefetch_factor=None,
+    )
diff --git a/tests/eva/core/callbacks/writers/test_embeddings.py b/tests/eva/core/callbacks/writers/test_embeddings.py
index 3990f453..0256a2ba 100644
--- a/tests/eva/core/callbacks/writers/test_embeddings.py
+++ b/tests/eva/core/callbacks/writers/test_embeddings.py
@@ -1,6 +1,5 @@
 """Tests the embeddings writer."""
 
-import itertools
 import os
 import random
 import tempfile
@@ -15,18 +14,28 @@
 from typing_extensions import override
 
 from eva.core.callbacks import writers
-from eva.core.data import dataloaders, datamodules, datasets
+from eva.core.data import datamodules, datasets
 from eva.core.models import modules
 
+SAMPLE_SHAPE = 32
 
-@pytest.mark.parametrize("batch_size, n_samples", list(itertools.product([5, 8], [7, 16])))
+
+@pytest.mark.parametrize(
+    "batch_size, n_samples",
+    [
+        (5, 7),
+        (8, 16),
+    ],
+)
 def test_embeddings_writer(datamodule: datamodules.DataModule, model: modules.HeadModule) -> None:
     """Tests the embeddings writer callback."""
     with tempfile.TemporaryDirectory() as output_dir:
         trainer = pl.Trainer(
             logger=False,
             callbacks=writers.EmbeddingsWriter(
-                output_dir=output_dir, dataloader_idx_map={0: "train", 1: "val", 2: "test"}
+                output_dir=output_dir,
+                dataloader_idx_map={0: "train", 1: "val", 2: "test"},
+                backbone=nn.Flatten(),
             ),
         )
         all_predictions = trainer.predict(
@@ -66,64 +75,31 @@ def test_embeddings_writer(datamodule: datamodules.DataModule, model: modules.He
 
 
 @pytest.fixture(scope="function")
-def model(input_shape: int = 32, n_classes: int = 4) -> modules.HeadModule:
+def model(n_classes: int = 4) -> modules.HeadModule:
     """Returns a HeadModule model fixture."""
     return modules.HeadModule(
-        head=nn.Linear(input_shape, n_classes),
+        head=nn.Linear(SAMPLE_SHAPE, n_classes),
         criterion=nn.CrossEntropyLoss(),
-        backbone=nn.Flatten(),
-    )
-
-
-@pytest.fixture(scope="function")
-def datamodule(
-    dataset: List[datasets.TorchDataset],
-    dataloader: dataloaders.DataLoader,
-) -> datamodules.DataModule:
-    """Returns a dummy classification datamodule fixture."""
-    return datamodules.DataModule(
-        datasets=datamodules.DatasetsSchema(
-            predict=dataset,
-        ),
-        dataloaders=datamodules.DataloadersSchema(
-            predict=dataloader,
-        ),
+        backbone=None,
    )
 
 
 @pytest.fixture(scope="function")
 def dataset(
     n_samples: int,
-    sample_shape: int = 32,
-) -> List[datasets.TorchDataset]:
+) -> List[datasets.Dataset]:
     """Fake dataset fixture."""
-    train_dataset = FakeDataset(split="train", length=n_samples, size=sample_shape)
-    val_dataset = FakeDataset(split="val", length=n_samples, size=sample_shape)
-    test_dataset = FakeDataset(split="test", length=n_samples, size=sample_shape)
-    return [train_dataset, val_dataset, test_dataset]
-
+    train_dataset = FakeDataset(split="train", length=n_samples, size=SAMPLE_SHAPE)
+    val_dataset = FakeDataset(split="val", length=n_samples, size=SAMPLE_SHAPE)
+    test_dataset = FakeDataset(split="test", length=n_samples, size=SAMPLE_SHAPE)
 
-@pytest.fixture(scope="function")
-def dataloader(batch_size: int) -> dataloaders.DataLoader:
-    """Test dataloader fixture."""
-    return dataloaders.DataLoader(
-        batch_size=batch_size,
-        num_workers=0,
-        pin_memory=False,
-        persistent_workers=False,
-        prefetch_factor=None,
-    )
+    return [train_dataset, val_dataset, test_dataset]
 
 
 class FakeDataset(boring_classes.RandomDataset, datasets.Dataset):
     """Fake prediction dataset."""
 
-    def __init__(
-        self,
-        split: Literal["train", "val", "test"],
-        size: int = 32,
-        length: int = 10,
-    ) -> None:
+    def __init__(self, split: Literal["train", "val", "test"], size: int = 32, length: int = 10):
         """Initializes the dataset."""
         super().__init__(size=size, length=length)
         self._split = split
diff --git a/tests/eva/core/models/modules/conftest.py b/tests/eva/core/models/modules/conftest.py
index fbcff042..077f5551 100644
--- a/tests/eva/core/models/modules/conftest.py
+++ b/tests/eva/core/models/modules/conftest.py
@@ -16,7 +16,7 @@ def datamodule(
     dataset_fixture: str,
     dataloader: dataloaders.DataLoader,
 ) -> datamodules.DataModule:
-    """Returns a dummy classification datamodule fixture."""
+    """Returns a dummy datamodule fixture."""
     dataset = request.getfixturevalue(dataset_fixture)
     return datamodules.DataModule(
         datasets=datamodules.DatasetsSchema(
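Reviewer note: the refactoring leans on two pytest behaviors that are easy to miss in review. First, fixtures defined in tests/eva/core/callbacks/conftest.py are resolved by name for every test module under that directory, which is what lets the duplicated `datamodule` and `dataloader` fixtures be deleted from test_embeddings.py without adding any import. Second, the `batch_size` and `n_samples` arguments requested by the `dataloader` and `dataset` fixtures are supplied by the `@pytest.mark.parametrize` marker on the test itself. The snippet below is a minimal, self-contained sketch of that wiring only; `dataset` and `batches` are hypothetical stand-ins, not code from this patch.

    import pytest


    @pytest.fixture(scope="function")
    def dataset(n_samples: int) -> list:
        # Stand-in for the FakeDataset fixture: `n_samples` is resolved from
        # the parametrize marker on the requesting test, not from a fixture.
        return list(range(n_samples))


    @pytest.fixture(scope="function")
    def batches(dataset: list, batch_size: int) -> list:
        # Stand-in for the DataLoader fixture: chunks the dataset into batches.
        return [dataset[i : i + batch_size] for i in range(0, len(dataset), batch_size)]


    @pytest.mark.parametrize("batch_size, n_samples", [(5, 7), (8, 16)])
    def test_batching(batches: list, batch_size: int, n_samples: int) -> None:
        # Mirrors how test_embeddings_writer receives its parametrized fixtures.
        assert sum(len(batch) for batch in batches) == n_samples
        assert all(len(batch) <= batch_size for batch in batches)

Placed in any test module below a conftest.py that defines `dataset` and `batches`, the test body is unchanged, which is why moving the fixtures up is transparent to the tests that use them. Separately, the diff moves the `nn.Flatten()` backbone from the `HeadModule` fixture into the `EmbeddingsWriter` callback, which suggests the writer under test now applies the backbone itself when producing the embeddings it writes.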