Skip to content

Commit

Permalink
Added BalancedSampler for classification
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig committed Nov 22, 2024
1 parent c8545a7 commit 65ce5ad
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 5 deletions.
8 changes: 6 additions & 2 deletions configs/vision/pathology/online/classification/bach.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,15 @@ data:
init_args:
<<: *DATASET_ARGS
split: val
samplers:
train:
class_path: eva.core.data.samplers.BalancedSampler
init_args:
num_samples: 10
dataloaders:
train:
batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256}
batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 1}
num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4}
shuffle: true
val:
batch_size: *BATCH_SIZE
num_workers: *N_DATA_WORKERS
2 changes: 1 addition & 1 deletion src/eva/core/data/datamodules/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _initialize_dataloaders(

dataloaders = []
for dataset in datasets:
if sampler and isinstance(sampler, samplers_lib.SamplerWithDataSource):
if sampler is not None and isinstance(sampler, samplers_lib.SamplerWithDataSource):
sampler.set_dataset(dataset) # type: ignore
dataloaders.append(dataloader(dataset, sampler=sampler))
return dataloaders
2 changes: 2 additions & 0 deletions src/eva/core/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
MultiEmbeddingsClassificationDataset,
)
from eva.core.data.datasets.dataset import TorchDataset
from eva.core.data.datasets.typings import DataSample

__all__ = [
"Dataset",
"MapDataset",
"EmbeddingsClassificationDataset",
"MultiEmbeddingsClassificationDataset",
"TorchDataset",
"DataSample",
]
18 changes: 18 additions & 0 deletions src/eva/core/data/datasets/typings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Typing definitions for the datasets module."""

from typing import Any, Dict, NamedTuple

import torch


class DataSample(NamedTuple):
"""The default input batch data scheme."""

data: torch.Tensor
"""The data batch."""

targets: torch.Tensor | None = None
"""The target batch."""

metadata: Dict[str, Any] | None = None
"""The associated metadata."""
3 changes: 2 additions & 1 deletion src/eva/core/data/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Data samplers API."""

from eva.core.data.samplers.classification.balanced import BalancedSampler
from eva.core.data.samplers.random import RandomSampler
from eva.core.data.samplers.sampler import Sampler, SamplerWithDataSource

__all__ = ["Sampler", "SamplerWithDataSource", "RandomSampler"]
__all__ = ["Sampler", "SamplerWithDataSource", "RandomSampler", "BalancedSampler"]
5 changes: 5 additions & 0 deletions src/eva/core/data/samplers/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Classification data samplers API."""

from eva.core.data.samplers.classification.balanced import BalancedSampler

__all__ = ["BalancedSampler"]
94 changes: 94 additions & 0 deletions src/eva/core/data/samplers/classification/balanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Random sampler for data loading."""

from collections import defaultdict
from typing import Dict, Iterator, List

import numpy as np
from typing_extensions import override

from eva.core.data import datasets
from eva.core.data.datasets.typings import DataSample
from eva.core.data.samplers.sampler import SamplerWithDataSource


class BalancedSampler(SamplerWithDataSource[int]):
"""Balanced class sampler for data loading.
The sampler ensures that:
1. Each class has the same number of samples
2. Samples within each class are randomly selected
3. Samples of different classes appear in random order
"""

def __init__(self, num_samples: int, replacement: bool = False, seed: int | None = None):
"""Initializes the balanced sampler.
Args:
num_samples: The number of samples to draw per class.
replacement: samples are drawn on-demand with replacement if ``True``, default=``False``
seed: Random seed for reproducibility. If None, sampling will be random
but not reproducible between runs.
"""
self._num_samples = num_samples
self._replacement = replacement
self._class_indices: Dict[int, List[int]] = defaultdict(list)
self._random_generator = np.random.default_rng(seed)

def __len__(self) -> int:
"""Returns the total number of samples."""
return self._num_samples * len(self._class_indices)

def __iter__(self) -> Iterator[int]:
"""Creates an iterator that yields indices in a class balanced way.
Returns:
Iterator yielding dataset indices.
"""
indices = []

for class_idx in self._class_indices:
class_indices = self._class_indices[class_idx]
sampled_indices = self._random_generator.choice(
class_indices, size=self._num_samples, replace=self._replacement
).tolist()
indices.extend(sampled_indices)

self._random_generator.shuffle(indices)

return iter(indices)

@override
def set_dataset(self, data_source: datasets.MapDataset):
"""Sets the dataset and builds class indices.
Args:
data_source: The dataset to sample from.
Raises:
ValueError: If the dataset doesn't have targets or if any class has
fewer samples than `num_samples` and `replacement` is `False`.
"""
super().set_dataset(data_source)
self._make_indices()

def _make_indices(self):
"""Builds indices for each class in the dataset."""
self._class_indices.clear()

for idx in range(len(self.data_source)):
_, target, _ = DataSample(*self.data_source[idx])
if target is None:
raise ValueError("The dataset must return non-empty targets.")
if target.numel() != 1:
raise ValueError("The dataset must return a single & scalar target.")

class_idx = int(target.item())
self._class_indices[class_idx].append(idx)

if not self._replacement:
for class_idx, indices in self._class_indices.items():
if len(indices) < self._num_samples:
raise ValueError(
f"Class {class_idx} has only {len(indices)} samples, "
f"which is less than the required {self._num_samples} samples."
)
1 change: 0 additions & 1 deletion src/eva/core/data/samplers/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ class RandomSampler(data.RandomSampler, SamplerWithDataSource[int]):
"""Samples elements randomly."""

data_source: datasets.MapDataset # type: ignore
replacement: bool

def __init__(
self, replacement: bool = False, num_samples: Optional[int] = None, generator=None
Expand Down

0 comments on commit 65ce5ad

Please sign in to comment.