Refactoring EmbeddingsWriter tests (#293)
nkaenzig authored Mar 18, 2024
1 parent 5749186 commit 38ff85a
Showing 3 changed files with 60 additions and 47 deletions.
37 changes: 37 additions & 0 deletions tests/eva/core/callbacks/conftest.py
@@ -0,0 +1,37 @@
+"""Shared configuration and fixtures for callbacks unit tests."""
+
+import pytest
+
+from eva.core.data import dataloaders, datamodules, datasets
+
+
+@pytest.fixture(scope="function")
+def datamodule(
+    dataset: datasets.TorchDataset,
+    dataloader: dataloaders.DataLoader,
+) -> datamodules.DataModule:
+    """Returns a dummy datamodule fixture."""
+    return datamodules.DataModule(
+        datasets=datamodules.DatasetsSchema(
+            train=dataset,
+            val=dataset,
+            predict=dataset,
+        ),
+        dataloaders=datamodules.DataloadersSchema(
+            train=dataloader,
+            val=dataloader,
+            predict=dataloader,
+        ),
+    )
+
+
+@pytest.fixture(scope="function")
+def dataloader(batch_size: int) -> dataloaders.DataLoader:
+    """Test dataloader fixture."""
+    return dataloaders.DataLoader(
+        batch_size=batch_size,
+        num_workers=0,
+        pin_memory=False,
+        persistent_workers=False,
+        prefetch_factor=None,
+    )
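
Aside: the two fixtures above compose through pytest's dependency injection. The datamodule fixture requests dataset and dataloader, and dataloader requests batch_size, so a test module under tests/eva/core/callbacks/ only has to supply those leaves. A minimal sketch of a consuming test (hypothetical, not part of this commit; it assumes eva's DataModule accepts an arbitrary torch dataset):

import pytest
import torch
from torch.utils import data as torch_data

from eva.core.data import datamodules


@pytest.fixture(scope="function")
def dataset() -> torch_data.Dataset:
    # Supplies the "dataset" leaf requested by the shared datamodule fixture.
    return torch_data.TensorDataset(torch.randn(8, 32), torch.randint(0, 4, (8,)))


@pytest.mark.parametrize("batch_size", [4])
def test_uses_shared_datamodule(datamodule: datamodules.DataModule) -> None:
    # datamodule and dataloader resolve from the conftest above; pytest feeds
    # the parametrized batch_size value into the dataloader fixture.
    assert datamodule is not None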
68 changes: 22 additions & 46 deletions tests/eva/core/callbacks/writers/test_embeddings.py
@@ -1,6 +1,5 @@
 """Tests the embeddings writer."""
 
-import itertools
 import os
 import random
 import tempfile
@@ -15,18 +14,28 @@
 from typing_extensions import override
 
 from eva.core.callbacks import writers
-from eva.core.data import dataloaders, datamodules, datasets
+from eva.core.data import datamodules, datasets
 from eva.core.models import modules
 
+SAMPLE_SHAPE = 32
 
-@pytest.mark.parametrize("batch_size, n_samples", list(itertools.product([5, 8], [7, 16])))
+
+@pytest.mark.parametrize(
+    "batch_size, n_samples",
+    [
+        (5, 7),
+        (8, 16),
+    ],
+)
 def test_embeddings_writer(datamodule: datamodules.DataModule, model: modules.HeadModule) -> None:
     """Tests the embeddings writer callback."""
     with tempfile.TemporaryDirectory() as output_dir:
         trainer = pl.Trainer(
             logger=False,
             callbacks=writers.EmbeddingsWriter(
-                output_dir=output_dir, dataloader_idx_map={0: "train", 1: "val", 2: "test"}
+                output_dir=output_dir,
+                dataloader_idx_map={0: "train", 1: "val", 2: "test"},
+                backbone=nn.Flatten(),
             ),
         )
         all_predictions = trainer.predict(
@@ -66,64 +75,31 @@ def test_embeddings_writer(datamodule: datamodules.DataModule, model: modules.HeadModule) -> None:
 
 
 @pytest.fixture(scope="function")
-def model(input_shape: int = 32, n_classes: int = 4) -> modules.HeadModule:
+def model(n_classes: int = 4) -> modules.HeadModule:
     """Returns a HeadModule model fixture."""
     return modules.HeadModule(
-        head=nn.Linear(input_shape, n_classes),
+        head=nn.Linear(SAMPLE_SHAPE, n_classes),
         criterion=nn.CrossEntropyLoss(),
-        backbone=nn.Flatten(),
     )
 
 
-@pytest.fixture(scope="function")
-def datamodule(
-    dataset: List[datasets.TorchDataset],
-    dataloader: dataloaders.DataLoader,
-) -> datamodules.DataModule:
-    """Returns a dummy classification datamodule fixture."""
-    return datamodules.DataModule(
-        datasets=datamodules.DatasetsSchema(
-            predict=dataset,
-        ),
-        dataloaders=datamodules.DataloadersSchema(
-            predict=dataloader,
-        ),
-        backbone=None,
-    )
-
-
 @pytest.fixture(scope="function")
 def dataset(
     n_samples: int,
-    sample_shape: int = 32,
-) -> List[datasets.TorchDataset]:
+) -> List[datasets.Dataset]:
     """Fake dataset fixture."""
-    train_dataset = FakeDataset(split="train", length=n_samples, size=sample_shape)
-    val_dataset = FakeDataset(split="val", length=n_samples, size=sample_shape)
-    test_dataset = FakeDataset(split="test", length=n_samples, size=sample_shape)
-    return [train_dataset, val_dataset, test_dataset]
+
+    train_dataset = FakeDataset(split="train", length=n_samples, size=SAMPLE_SHAPE)
+    val_dataset = FakeDataset(split="val", length=n_samples, size=SAMPLE_SHAPE)
+    test_dataset = FakeDataset(split="test", length=n_samples, size=SAMPLE_SHAPE)
 
-@pytest.fixture(scope="function")
-def dataloader(batch_size: int) -> dataloaders.DataLoader:
-    """Test dataloader fixture."""
-    return dataloaders.DataLoader(
-        batch_size=batch_size,
-        num_workers=0,
-        pin_memory=False,
-        persistent_workers=False,
-        prefetch_factor=None,
-    )
+    return [train_dataset, val_dataset, test_dataset]
 
 
 class FakeDataset(boring_classes.RandomDataset, datasets.Dataset):
     """Fake prediction dataset."""
 
-    def __init__(
-        self,
-        split: Literal["train", "val", "test"],
-        size: int = 32,
-        length: int = 10,
-    ) -> None:
+    def __init__(self, split: Literal["train", "val", "test"], size: int = 32, length: int = 10):
         """Initializes the dataset."""
         super().__init__(size=size, length=length)
         self._split = split
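
Two things worth noting about this refactor. First, the backbone now goes to EmbeddingsWriter instead of HeadModule, so the writer computes the embeddings itself. Second, the parametrization is no longer a cross product: list(itertools.product([5, 8], [7, 16])) expanded to four (batch_size, n_samples) combinations, while the new explicit list runs exactly two, apparently chosen so that one case has a ragged final batch (7 % 5 != 0) and one divides evenly (16 % 8 == 0). A quick sanity check of the difference:

import itertools

# Old decorator: full cross product, i.e. four test invocations.
old_params = list(itertools.product([5, 8], [7, 16]))
assert old_params == [(5, 7), (5, 16), (8, 7), (8, 16)]

# New decorator: two pinned pairs, a strict subset of the old grid.
new_params = [(5, 7), (8, 16)]
assert all(pair in old_params for pair in new_params)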
2 changes: 1 addition & 1 deletion tests/eva/core/models/modules/conftest.py
@@ -16,7 +16,7 @@ def datamodule(
     dataset_fixture: str,
     dataloader: dataloaders.DataLoader,
 ) -> datamodules.DataModule:
-    """Returns a dummy classification datamodule fixture."""
+    """Returns a dummy datamodule fixture."""
     dataset = request.getfixturevalue(dataset_fixture)
     return datamodules.DataModule(
         datasets=datamodules.DatasetsSchema(
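
For context, the unchanged lines around this hunk resolve the dataset dynamically: the test parametrizes dataset_fixture with a fixture name, and the datamodule fixture looks that fixture up at runtime via request.getfixturevalue. A self-contained sketch of the same pytest pattern, with illustrative fixture names only:

import pytest


@pytest.fixture
def classification_dataset() -> list:
    # One of several dataset fixtures selectable by name.
    return [0, 1, 2]


@pytest.fixture
def dataset(request: pytest.FixtureRequest, dataset_fixture: str) -> list:
    # Resolves whichever dataset fixture the test referenced by name.
    return request.getfixturevalue(dataset_fixture)


@pytest.mark.parametrize("dataset_fixture", ["classification_dataset"])
def test_dynamic_fixture_lookup(dataset: list) -> None:
    assert dataset == [0, 1, 2]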
