Skip to content

Commit

Permalink
add nnhead tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ioangatop committed Jan 22, 2024
1 parent bd8fe9c commit 7745674
Show file tree
Hide file tree
Showing 79 changed files with 174 additions and 7 deletions.
1 change: 1 addition & 0 deletions lightning_logs/version_0/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_1/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_10/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_11/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_12/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_13/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_14/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_15/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_16/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_17/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_18/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_19/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_2/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_20/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
2 changes: 2 additions & 0 deletions lightning_logs/version_20/metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
val/AverageLoss,epoch,step
1.3971446752548218,0,31
1 change: 1 addition & 0 deletions lightning_logs/version_21/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_22/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
2 changes: 2 additions & 0 deletions lightning_logs/version_22/metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
val/AverageLoss,epoch,step
1.6166414022445679,0,31
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_23/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
2 changes: 2 additions & 0 deletions lightning_logs/version_23/metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
val/AverageLoss,epoch,step
1.580564022064209,0,3
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_24/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
2 changes: 2 additions & 0 deletions lightning_logs/version_24/metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
val/AverageLoss,step,epoch
1.4413771629333496,3,0
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_25/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
2 changes: 2 additions & 0 deletions lightning_logs/version_25/metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
val/AverageLoss,step,epoch
1.6885499954223633,3,0
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_26/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
2 changes: 2 additions & 0 deletions lightning_logs/version_26/metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
val/AverageLoss,step,epoch
1.5723896026611328,3,0
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_27/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
2 changes: 2 additions & 0 deletions lightning_logs/version_27/metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
val/AverageLoss,epoch,step
1.7757924795150757,0,3
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_28/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
2 changes: 2 additions & 0 deletions lightning_logs/version_28/metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
val/AverageLoss,epoch,step
1.443694829940796,0,3
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_29/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
2 changes: 2 additions & 0 deletions lightning_logs/version_29/metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
step,val/AverageLoss,epoch
3,1.7681691646575928,0
1 change: 1 addition & 0 deletions lightning_logs/version_3/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_30/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
2 changes: 2 additions & 0 deletions lightning_logs/version_30/metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
val/AverageLoss,step,epoch
1.5971572399139404,3,0
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_31/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
2 changes: 2 additions & 0 deletions lightning_logs/version_31/metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
val/AverageLoss,epoch,step
1.7049283981323242,0,3
1 change: 1 addition & 0 deletions lightning_logs/version_32/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_33/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_34/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_35/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_36/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
2 changes: 2 additions & 0 deletions lightning_logs/version_36/metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
val/AverageLoss,step,epoch
1.0843629837036133,3,0
1 change: 1 addition & 0 deletions lightning_logs/version_37/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_38/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
2 changes: 2 additions & 0 deletions lightning_logs/version_38/metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
val/AverageLoss,step,epoch
1.4186309576034546,3,0
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_39/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
2 changes: 2 additions & 0 deletions lightning_logs/version_39/metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
step,epoch,val/AverageLoss
3,0,1.7707195281982422
1 change: 1 addition & 0 deletions lightning_logs/version_4/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_5/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_6/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_7/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_8/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_9/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ exclude_dirs = [".venv", "tests/**"]
[tool.pyright]
pythonVersion = "3.10"
reportInvalidStringEscapeSequence = false
reportIncompatibleMethodOverride = false
exclude = [
"__pypackages__",
".nox",
Expand Down
3 changes: 2 additions & 1 deletion src/eva/data/datamodules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Datamodules API."""
from eva.data.datamodules.datamodule import DataModule
from eva.data.datamodules.schemas import DataloadersSchema, DatasetsSchema

__all__ = ["DataModule"]
__all__ = ["DataModule", "DataloadersSchema", "DatasetsSchema"]
3 changes: 2 additions & 1 deletion src/eva/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Models API."""
from eva.models.module import ModelModule
from eva.models.nnhead import NNHead

__all__ = ["ModelModule"]
__all__ = ["ModelModule", "NNHead"]
10 changes: 5 additions & 5 deletions src/eva/models/nnhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def configure_optimizers(self) -> Any:
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

@override
def forward(self, *args, tensor: torch.Tensor, **kwargs) -> torch.Tensor:
features = tensor if self.backbone is None else self.model.forward(tensor)
def forward(self, tensor: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
features = tensor if self.backbone is None else self.backbone(tensor)
return self.head(features.flatten(start_dim=1))

@override
Expand All @@ -74,15 +74,15 @@ def on_fit_start(self) -> None:
_utils.deactivate_requires_grad(self.backbone)

@override
def training_step(self, *args, batch: INPUT_BATCH, **kwargs) -> STEP_OUTPUT:
def training_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
return self._batch_step(batch)

@override
def validation_step(self, *args, batch: INPUT_BATCH, **kwargs) -> STEP_OUTPUT:
def validation_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
return self._batch_step(batch)

@override
def test_step(self, *args, batch: INPUT_BATCH, **kwargs) -> STEP_OUTPUT:
def test_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
return self._batch_step(batch)

@override
Expand Down
4 changes: 4 additions & 0 deletions src/eva/trainers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Trainers API."""
from eva.trainers.trainer import Trainer

__all__ = ["Trainer"]
5 changes: 5 additions & 0 deletions src/eva/trainers/trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Core trainer module."""
from pytorch_lightning import trainer

Trainer = trainer.Trainer
"""Core trainer class."""
87 changes: 87 additions & 0 deletions tests/eva/models/test_nnhead.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Tests the NNHead module."""
import math
from typing import Tuple

import pytest
import torch
from torch import nn
from torch.utils import data as torch_data

from eva import metrics, models, trainers
from eva.data import dataloaders, datamodules, datasets


def test_nnhead_fit(
model: models.NNHead,
datamodule: datamodules.DataModule,
trainer: trainers.Trainer,
) -> None:
"""Tests the nnhead fit pipeline."""
initial_head_weights = model.head.weight.clone()
trainer.fit(model, datamodule=datamodule)
assert trainer.logged_metrics["train/AverageLoss"] > 0
assert trainer.logged_metrics["val/AverageLoss"] > 0
assert not torch.all(torch.eq(initial_head_weights, model.head.weight))


@pytest.fixture(scope="function")
def model(input_shape: Tuple[int, ...] = (3, 8, 8), n_classes: int = 4) -> models.NNHead:
"""Returns a NNHead model fixture."""
return models.NNHead(
head=nn.Linear(math.prod(input_shape), n_classes),
criterion=nn.CrossEntropyLoss(),
backbone=nn.Flatten(),
metrics=metrics.core.MetricsSchema(
common=metrics.AverageLoss(),
),
)


@pytest.fixture(scope="function")
def datamodule(
classification_dataset: datasets.Dataset,
dataloader: dataloaders.DataLoader,
) -> datamodules.DataModule:
"""Returns a dummy classification datamodule fixture."""
return datamodules.DataModule(
datasets=datamodules.DatasetsSchema(
train=classification_dataset,
val=classification_dataset,
),
dataloaders=datamodules.DataloadersSchema(
train=dataloader,
val=dataloader,
),
)


@pytest.fixture(scope="function")
def trainer(max_epochs: int = 1) -> trainers.Trainer:
"""Returns a model trainer fixture."""
return trainers.Trainer(max_epochs=max_epochs, accelerator="cpu")


@pytest.fixture(scope="function")
def classification_dataset(
n_samples: int = 4,
input_shape: Tuple[int, ...] = (3, 8, 8),
target_shape: Tuple[int, ...] = (),
n_classes: int = 4,
) -> datasets.Dataset:
"""Dummy classification dataset fixture."""
return torch_data.TensorDataset(
torch.randn((n_samples,) + input_shape),
torch.randint(n_classes, (n_samples,) + target_shape, dtype=torch.long),
)


@pytest.fixture(scope="function")
def dataloader(batch_size: int = 1) -> dataloaders.DataLoader:
"""Test dataloader fixture."""
return dataloaders.DataLoader(
batch_size=batch_size,
num_workers=0,
pin_memory=False,
persistent_workers=False,
prefetch_factor=None,
)

0 comments on commit 7745674

Please sign in to comment.