-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added BalancedSampler for classification
- Loading branch information
Showing
8 changed files
with
128 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
"""Typing definitions for the datasets module.""" | ||
|
||
from typing import Any, Dict, NamedTuple | ||
|
||
import torch | ||
|
||
|
||
class DataSample(NamedTuple): | ||
"""The default input batch data scheme.""" | ||
|
||
data: torch.Tensor | ||
"""The data batch.""" | ||
|
||
targets: torch.Tensor | None = None | ||
"""The target batch.""" | ||
|
||
metadata: Dict[str, Any] | None = None | ||
"""The associated metadata.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
"""Data samplers API.""" | ||
|
||
from eva.core.data.samplers.classification.balanced import BalancedSampler | ||
from eva.core.data.samplers.random import RandomSampler | ||
from eva.core.data.samplers.sampler import Sampler, SamplerWithDataSource | ||
|
||
__all__ = ["Sampler", "SamplerWithDataSource", "RandomSampler"] | ||
__all__ = ["Sampler", "SamplerWithDataSource", "RandomSampler", "BalancedSampler"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""Classification data samplers API.""" | ||
|
||
from eva.core.data.samplers.classification.balanced import BalancedSampler | ||
|
||
__all__ = ["BalancedSampler"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
"""Random sampler for data loading.""" | ||
|
||
from collections import defaultdict | ||
from typing import Dict, Iterator, List | ||
|
||
import numpy as np | ||
from typing_extensions import override | ||
|
||
from eva.core.data import datasets | ||
from eva.core.data.datasets.typings import DataSample | ||
from eva.core.data.samplers.sampler import SamplerWithDataSource | ||
|
||
|
||
class BalancedSampler(SamplerWithDataSource[int]): | ||
"""Balanced class sampler for data loading. | ||
The sampler ensures that: | ||
1. Each class has the same number of samples | ||
2. Samples within each class are randomly selected | ||
3. Samples of different classes appear in random order | ||
""" | ||
|
||
def __init__(self, num_samples: int, replacement: bool = False, seed: int | None = None): | ||
"""Initializes the balanced sampler. | ||
Args: | ||
num_samples: The number of samples to draw per class. | ||
replacement: samples are drawn on-demand with replacement if ``True``, default=``False`` | ||
seed: Random seed for reproducibility. If None, sampling will be random | ||
but not reproducible between runs. | ||
""" | ||
self._num_samples = num_samples | ||
self._replacement = replacement | ||
self._class_indices: Dict[int, List[int]] = defaultdict(list) | ||
self._random_generator = np.random.default_rng(seed) | ||
|
||
def __len__(self) -> int: | ||
"""Returns the total number of samples.""" | ||
return self._num_samples * len(self._class_indices) | ||
|
||
def __iter__(self) -> Iterator[int]: | ||
"""Creates an iterator that yields indices in a class balanced way. | ||
Returns: | ||
Iterator yielding dataset indices. | ||
""" | ||
indices = [] | ||
|
||
for class_idx in self._class_indices: | ||
class_indices = self._class_indices[class_idx] | ||
sampled_indices = self._random_generator.choice( | ||
class_indices, size=self._num_samples, replace=self._replacement | ||
).tolist() | ||
indices.extend(sampled_indices) | ||
|
||
self._random_generator.shuffle(indices) | ||
|
||
return iter(indices) | ||
|
||
@override | ||
def set_dataset(self, data_source: datasets.MapDataset): | ||
"""Sets the dataset and builds class indices. | ||
Args: | ||
data_source: The dataset to sample from. | ||
Raises: | ||
ValueError: If the dataset doesn't have targets or if any class has | ||
fewer samples than `num_samples` and `replacement` is `False`. | ||
""" | ||
super().set_dataset(data_source) | ||
self._make_indices() | ||
|
||
def _make_indices(self): | ||
"""Builds indices for each class in the dataset.""" | ||
self._class_indices.clear() | ||
|
||
for idx in range(len(self.data_source)): | ||
_, target, _ = DataSample(*self.data_source[idx]) | ||
if target is None: | ||
raise ValueError("The dataset must return non-empty targets.") | ||
if target.numel() != 1: | ||
raise ValueError("The dataset must return a single & scalar target.") | ||
|
||
class_idx = int(target.item()) | ||
self._class_indices[class_idx].append(idx) | ||
|
||
if not self._replacement: | ||
for class_idx, indices in self._class_indices.items(): | ||
if len(indices) < self._num_samples: | ||
raise ValueError( | ||
f"Class {class_idx} has only {len(indices)} samples, " | ||
f"which is less than the required {self._num_samples} samples." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters