From 65ce5ad6e886d5d0af2cb48696d90faf693aeec2 Mon Sep 17 00:00:00 2001
From: Nicolas Kaenzig
Date: Fri, 22 Nov 2024 14:18:55 +0100
Subject: [PATCH] Added BalancedSampler for classification

---
 .../pathology/online/classification/bach.yaml |   6 +-
 src/eva/core/data/datamodules/datamodule.py   |   2 +-
 src/eva/core/data/datasets/__init__.py        |   2 +
 src/eva/core/data/datasets/typings.py         |  18 +++
 src/eva/core/data/samplers/__init__.py        |   3 +-
 .../data/samplers/classification/__init__.py  |   5 +
 .../data/samplers/classification/balanced.py  | 113 ++++++++++++++++++
 src/eva/core/data/samplers/random.py          |   1 -
 8 files changed, 146 insertions(+), 4 deletions(-)
 create mode 100644 src/eva/core/data/datasets/typings.py
 create mode 100644 src/eva/core/data/samplers/classification/__init__.py
 create mode 100644 src/eva/core/data/samplers/classification/balanced.py

diff --git a/configs/vision/pathology/online/classification/bach.yaml b/configs/vision/pathology/online/classification/bach.yaml
index 1719d821..d82df9e7 100644
--- a/configs/vision/pathology/online/classification/bach.yaml
+++ b/configs/vision/pathology/online/classification/bach.yaml
@@ -84,11 +84,15 @@ data:
         init_args:
           <<: *DATASET_ARGS
           split: val
+    samplers:
+      train:
+        class_path: eva.core.data.samplers.BalancedSampler
+        init_args:
+          num_samples: 10
     dataloaders:
       train:
         batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256}
         num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4}
-        shuffle: true
       val:
         batch_size: *BATCH_SIZE
         num_workers: *N_DATA_WORKERS
diff --git a/src/eva/core/data/datamodules/datamodule.py b/src/eva/core/data/datamodules/datamodule.py
index 589a9f0d..0b03711e 100644
--- a/src/eva/core/data/datamodules/datamodule.py
+++ b/src/eva/core/data/datamodules/datamodule.py
@@ -130,7 +130,7 @@ def _initialize_dataloaders(
 
         dataloaders = []
         for dataset in datasets:
-            if sampler and isinstance(sampler, samplers_lib.SamplerWithDataSource):
+            if sampler is not None and isinstance(sampler, samplers_lib.SamplerWithDataSource):
                 sampler.set_dataset(dataset)  # type: ignore
             dataloaders.append(dataloader(dataset, sampler=sampler))
         return dataloaders
diff --git a/src/eva/core/data/datasets/__init__.py b/src/eva/core/data/datasets/__init__.py
index 690e56e5..c5e36682 100644
--- a/src/eva/core/data/datasets/__init__.py
+++ b/src/eva/core/data/datasets/__init__.py
@@ -6,6 +6,7 @@
     MultiEmbeddingsClassificationDataset,
 )
 from eva.core.data.datasets.dataset import TorchDataset
+from eva.core.data.datasets.typings import DataSample
 
 __all__ = [
     "Dataset",
@@ -13,4 +14,5 @@
     "EmbeddingsClassificationDataset",
     "MultiEmbeddingsClassificationDataset",
     "TorchDataset",
+    "DataSample",
 ]
diff --git a/src/eva/core/data/datasets/typings.py b/src/eva/core/data/datasets/typings.py
new file mode 100644
index 00000000..465b23e2
--- /dev/null
+++ b/src/eva/core/data/datasets/typings.py
@@ -0,0 +1,18 @@
+"""Typing definitions for the datasets module."""
+
+from typing import Any, Dict, NamedTuple
+
+import torch
+
+
+class DataSample(NamedTuple):
+    """The default input batch data scheme."""
+
+    data: torch.Tensor
+    """The data batch."""
+
+    targets: torch.Tensor | None = None
+    """The target batch."""
+
+    metadata: Dict[str, Any] | None = None
+    """The associated metadata."""
diff --git a/src/eva/core/data/samplers/__init__.py b/src/eva/core/data/samplers/__init__.py
index b7a8559b..7586d533 100644
--- a/src/eva/core/data/samplers/__init__.py
+++ b/src/eva/core/data/samplers/__init__.py
@@ -1,6 +1,7 @@
 """Data samplers API."""
 
+from eva.core.data.samplers.classification.balanced import BalancedSampler
 from eva.core.data.samplers.random import RandomSampler
 from eva.core.data.samplers.sampler import Sampler, SamplerWithDataSource
 
-__all__ = ["Sampler", "SamplerWithDataSource", "RandomSampler"]
+__all__ = ["Sampler", "SamplerWithDataSource", "RandomSampler", "BalancedSampler"]
diff --git a/src/eva/core/data/samplers/classification/__init__.py b/src/eva/core/data/samplers/classification/__init__.py
new file mode 100644
index 00000000..c68235bc
--- /dev/null
+++ b/src/eva/core/data/samplers/classification/__init__.py
@@ -0,0 +1,5 @@
+"""Classification data samplers API."""
+
+from eva.core.data.samplers.classification.balanced import BalancedSampler
+
+__all__ = ["BalancedSampler"]
diff --git a/src/eva/core/data/samplers/classification/balanced.py b/src/eva/core/data/samplers/classification/balanced.py
new file mode 100644
index 00000000..dcf697b3
--- /dev/null
+++ b/src/eva/core/data/samplers/classification/balanced.py
@@ -0,0 +1,113 @@
+"""Balanced class sampler for data loading."""
+
+from collections import defaultdict
+from typing import Dict, Iterator, List
+
+import numpy as np
+from typing_extensions import override
+
+from eva.core.data import datasets
+from eva.core.data.datasets.typings import DataSample
+from eva.core.data.samplers.sampler import SamplerWithDataSource
+
+
+class BalancedSampler(SamplerWithDataSource[int]):
+    """Balanced class sampler for data loading.
+
+    The sampler ensures that:
+    1. Each class has the same number of samples
+    2. Samples within each class are randomly selected
+    3. Samples of different classes appear in random order
+    """
+
+    def __init__(self, num_samples: int, replacement: bool = False, seed: int | None = None):
+        """Initializes the balanced sampler.
+
+        Args:
+            num_samples: The number of samples to draw per class.
+            replacement: Whether to draw the samples of each class with
+                replacement. Defaults to ``False``.
+            seed: Random seed for reproducibility. If None, sampling will be random
+                but not reproducible between runs.
+
+        Raises:
+            ValueError: If `num_samples` is not a positive number.
+        """
+        if num_samples <= 0:
+            raise ValueError(f"`num_samples` must be positive, got {num_samples}.")
+
+        self._num_samples = num_samples
+        self._replacement = replacement
+        self._class_indices: Dict[int, List[int]] = defaultdict(list)
+        self._random_generator = np.random.default_rng(seed)
+
+    def __len__(self) -> int:
+        """Returns the total number of samples drawn per epoch.
+
+        This is `num_samples` times the number of classes found by the last
+        `set_dataset` call, and therefore zero until a dataset has been set.
+        """
+        return self._num_samples * len(self._class_indices)
+
+    def __iter__(self) -> Iterator[int]:
+        """Creates an iterator that yields indices in a class balanced way.
+
+        A new random selection is drawn every time the sampler is iterated.
+
+        Returns:
+            Iterator yielding dataset indices.
+        """
+        indices: List[int] = []
+        for class_indices in self._class_indices.values():
+            sampled_indices = self._random_generator.choice(
+                class_indices, size=self._num_samples, replace=self._replacement
+            )
+            indices.extend(sampled_indices.tolist())
+
+        # Mix the classes, otherwise the indices would be grouped by class.
+        self._random_generator.shuffle(indices)
+        return iter(indices)
+
+    @override
+    def set_dataset(self, data_source: datasets.MapDataset):
+        """Sets the dataset and builds class indices.
+
+        Args:
+            data_source: The dataset to sample from.
+
+        Raises:
+            ValueError: If the dataset doesn't have targets or if any class has
+                fewer samples than `num_samples` and `replacement` is `False`.
+        """
+        super().set_dataset(data_source)
+        self._make_indices()
+
+    def _make_indices(self):
+        """Builds indices for each class in the dataset.
+
+        Every item of the dataset is fetched once in order to read its target.
+
+        Raises:
+            ValueError: If an item has no target, if a target is not a single
+                scalar value, or if `replacement` is `False` and a class has
+                fewer than `num_samples` samples.
+        """
+        self._class_indices.clear()
+
+        for idx in range(len(self.data_source)):
+            _, target, _ = DataSample(*self.data_source[idx])
+            if target is None:
+                raise ValueError("The dataset must return non-empty targets.")
+            if target.numel() != 1:
+                raise ValueError("The dataset must return a single & scalar target.")
+
+            class_idx = int(target.item())
+            self._class_indices[class_idx].append(idx)
+
+        if not self._replacement:
+            for class_idx, indices in self._class_indices.items():
+                if len(indices) < self._num_samples:
+                    raise ValueError(
+                        f"Class {class_idx} has only {len(indices)} samples, "
+                        f"which is less than the required {self._num_samples} samples."
+                    )
diff --git a/src/eva/core/data/samplers/random.py b/src/eva/core/data/samplers/random.py
index ea3c6749..415b8ae3 100644
--- a/src/eva/core/data/samplers/random.py
+++ b/src/eva/core/data/samplers/random.py
@@ -13,7 +13,6 @@ class RandomSampler(data.RandomSampler, SamplerWithDataSource[int]):
     """Samples elements randomly."""
 
     data_source: datasets.MapDataset  # type: ignore
-    replacement: bool
 
     def __init__(
         self, replacement: bool = False, num_samples: Optional[int] = None, generator=None