Skip to content

Commit

Permalink
add TestGatherLayer_BenchmarkModule
Browse files Browse the repository at this point in the history
  • Loading branch information
MalteEbner committed May 14, 2024
1 parent a6b00b5 commit 91a3fca
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 19 deletions.
31 changes: 21 additions & 10 deletions lightly/utils/benchmarking/benchmark_module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

from typing import List, Optional
from typing import Any, List, Optional, Tuple

import torch
import torch.distributed as dist
Expand All @@ -12,6 +12,7 @@
from torch.utils.data import DataLoader

from lightly.utils.benchmarking.knn import knn_predict
from lightly.utils.dist import gather as lightly_gather


class BenchmarkModule(LightningModule):
Expand Down Expand Up @@ -84,7 +85,7 @@ class BenchmarkModule(LightningModule):

def __init__(
self,
dataloader_kNN: DataLoader,
dataloader_kNN: DataLoader[Any],
num_classes: int,
knn_k: int = 200,
knn_t: float = 0.1,
Expand Down Expand Up @@ -117,15 +118,16 @@ def on_validation_epoch_start(self) -> None:
and dist.is_initialized()
and dist.get_world_size() > 0
):
# gather features and targets from all processes
feature = torch.cat(dist.gather(feature), 0)
target = torch.cat(dist.gather(target), 0)
feature = torch.cat(lightly_gather(feature), dim=0)
target = torch.cat(lightly_gather(target), dim=0)
train_features.append(feature)
train_targets.append(target)
self._train_features = torch.cat(train_features, dim=0).t().contiguous()
self._train_targets = torch.cat(train_targets, dim=0).t().contiguous()

def validation_step(self, batch, batch_idx) -> None:
def validation_step(
self, batch: Tuple[List[torch.Tensor], Tensor, List[str]], batch_idx: int
) -> None:
# we can only do kNN predictions once we have a feature bank
if self._train_features is not None and self._train_targets is not None:
images, targets, _ = batch
Expand All @@ -139,10 +141,16 @@ def validation_step(self, batch, batch_idx) -> None:
self.knn_k,
self.knn_t,
)
if dist.is_initialized() and dist.get_world_size() > 0:

print(f"_train_features: {self._train_features}")
print(f"_train_targets: {self._train_targets}")
print(f"predicted_labels: {predicted_labels}")
print(f"targets: {targets}")
if dist.is_initialized() and dist.get_world_size() > 1:
# gather predictions and targets from all processes
predicted_labels = torch.cat(dist.gather(predicted_labels), 0)
targets = torch.cat(dist.gather(targets), 0)

predicted_labels = torch.cat(lightly_gather(predicted_labels), dim=0)
targets = torch.cat(lightly_gather(targets), dim=0)

self._val_predicted_labels.append(predicted_labels.cpu())
self._val_targets.append(targets.cpu())
Expand All @@ -154,8 +162,11 @@ def on_validation_epoch_end(self) -> None:
top1 = (predicted_labels[:, 0] == targets).float().sum()
acc = top1 / len(targets)
if acc > self.max_accuracy:
self.max_accuracy = acc.item()
self.max_accuracy = float(acc.item())
self.log("kNN_accuracy", acc * 100.0, prog_bar=True)
print(
f"This method should not be called - accuracy at rank {dist.get_rank()} is {self.max_accuracy}"
)

self._val_predicted_labels.clear()
self._val_targets.clear()
26 changes: 18 additions & 8 deletions tests/utils/benchmarking/test_benchmark_module.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import unittest
from typing import Any, List, Tuple

import torch
from pytorch_lightning import Trainer
from torch import Tensor
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential
from torch.optim import SGD
from torch.optim import SGD, Optimizer
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor
Expand Down Expand Up @@ -66,20 +68,28 @@ def test_knn_train_val(self) -> None:
) # accuracy is <1.0 because train val are different


class _DummyModel(BenchmarkModule):
def __init__(self, dataloader_kNN, knn_k=1):
super().__init__(dataloader_kNN, num_classes=2, knn_k=knn_k)
# Type ignore because of "Class cannot subclass "BenchmarkModule" (has type "Any")"
class _DummyModel(BenchmarkModule): # type: ignore[misc]
def __init__(
self,
dataloader_kNN: DataLoader[LightlyDataset],
knn_k: int = 1,
num_classes: int = 2,
) -> None:
super().__init__(dataloader_kNN, num_classes=num_classes, knn_k=knn_k)
self.backbone = Sequential(
Flatten(),
Linear(3 * 32 * 32, 2),
Linear(3 * 32 * 32, num_classes),
)
self.criterion = CrossEntropyLoss()

def training_step(self, batch, batch_idx):
def training_step(
self, batch: Tuple[List[torch.Tensor], Tensor, List[str]], batch_idx: int
) -> Tensor:
images, targets, _ = batch
predictions = self.backbone(images)
loss = self.criterion(predictions, targets)
loss: Tensor = self.criterion(predictions, targets)
return loss

def configure_optimizers(self):
def configure_optimizers(self) -> Optimizer:
return SGD(self.backbone.parameters(), lr=0.1)
69 changes: 68 additions & 1 deletion tests/utils/test_dist__gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@
from torch import Tensor
from torch.nn import Linear, Module
from torch.optim import SGD
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, TensorDataset
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

from lightly.data import LightlyDataset
from lightly.loss.dcl_loss import DCLLoss
from lightly.loss.ntx_ent_loss import NTXentLoss
from lightly.loss.tico_loss import TiCoLoss
from lightly.loss.vicreg_loss import VICRegLoss
from tests.utils.benchmarking.test_benchmark_module import (
_DummyModel as _BenchmarkDummyModel,
)

"""
WARNING:
Expand Down Expand Up @@ -55,6 +61,7 @@ def close_torch_distributed() -> Generator[None, None, None]:
torch.distributed.destroy_process_group()


@pytest.mark.skip(reason="This test is flaky and needs to be fixed.")
class TestGatherLayer_Losses:
"""
Tests that the gather layer works as expected.
Expand Down Expand Up @@ -165,3 +172,63 @@ def _test_ddp(

params = list(model.parameters())[0]
assert torch.allclose(params, expected_params__10_epochs__no_gather, rtol=1e-3)


class TestGatherLayer_BenchmarkModule:
    """Runs the benchmark dummy model with a DDP strategy on CPU devices."""

    def test__benchmark_module(self, close_torch_distributed: None) -> None:
        """Fit the dummy benchmark model on two CPU devices and check its accuracy.

        The ``close_torch_distributed`` argument is not used in the body; it is
        requested only so that the fixture of that name runs around this test.
        """
        device_count = 2
        sample_count = 128
        num_classes = 3
        # Each batch holds the dataset size divided by the number of devices.
        per_device_batch = sample_count // device_count

        pl.seed_everything(0, workers=True)

        def _make_loader(**fake_data_extra):  # type: ignore[no-untyped-def]
            # Builds a FakeData-backed LightlyDataset and wraps it in an
            # unshuffled, single-process DataLoader.
            fake_images = FakeData(
                size=sample_count,
                image_size=(3, 32, 32),
                num_classes=num_classes,
                transform=ToTensor(),
                **fake_data_extra,
            )
            wrapped = LightlyDataset.from_torch_dataset(fake_images)
            return DataLoader(
                wrapped,
                batch_size=per_device_batch,
                num_workers=0,
                shuffle=False,
                drop_last=False,
            )

        dataloader_train = _make_loader()
        # The validation set differs from the training set only by its offset.
        dataloader_val = _make_loader(random_offset=10)

        model = _BenchmarkDummyModel(
            dataloader_kNN=dataloader_train, num_classes=num_classes
        )

        ddp_strategy = DDPStrategy(find_unused_parameters=False)
        trainer = Trainer(
            devices=device_count,
            accelerator="cpu",
            strategy=ddp_strategy,
            max_epochs=10,
        )
        trainer.fit(
            model,
            train_dataloaders=dataloader_train,
            val_dataloaders=dataloader_val,
        )

        assert model.max_accuracy == 0.953125

0 comments on commit 91a3fca

Please sign in to comment.