Refactor Pytorch dataset (#202)
Co-authored-by: Anderson Banihirwe <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Sep 13, 2024
1 parent f86f47b commit dcd09a0
Showing 2 changed files with 128 additions and 25 deletions.
66 changes: 49 additions & 17 deletions xbatcher/loaders/torch.py
@@ -1,5 +1,11 @@
 from __future__ import annotations
 
 from collections.abc import Callable
-from typing import Any
+from types import ModuleType
 
+import xarray as xr
+
+from xbatcher import BatchGenerator
+
 try:
     import torch
@@ -9,6 +15,13 @@
         'install PyTorch to proceed.'
     ) from exc
 
+try:
+    import dask
+except ImportError:
+    dask: ModuleType | None = None  # type: ignore[no-redef]
+
+T_DataArrayOrSet = xr.DataArray | xr.Dataset
+
 # Notes:
 # This module includes two PyTorch datasets.
 # - The MapDataset provides an indexable interface
@@ -20,13 +33,22 @@
 # - need to test with additional dataset parameters (e.g. transforms)
 
 
+def to_tensor(xr_obj: T_DataArrayOrSet) -> torch.Tensor:
+    """Convert this DataArray or Dataset to a torch.Tensor"""
+    if isinstance(xr_obj, xr.Dataset):
+        xr_obj = xr_obj.to_array().squeeze(dim='variable')
+    if isinstance(xr_obj, xr.DataArray):
+        xr_obj = xr_obj.data
+    return torch.tensor(xr_obj)
+
+
 class MapDataset(torch.utils.data.Dataset):
     def __init__(
         self,
-        X_generator,
-        y_generator,
-        transform: Callable | None = None,
-        target_transform: Callable | None = None,
+        X_generator: BatchGenerator,
+        y_generator: BatchGenerator | None = None,
+        transform: Callable[[T_DataArrayOrSet], torch.Tensor] = to_tensor,
+        target_transform: Callable[[T_DataArrayOrSet], torch.Tensor] = to_tensor,
     ) -> None:
         """
         PyTorch Dataset adapter for Xbatcher
@@ -35,10 +57,8 @@ def __init__(
         ----------
         X_generator : xbatcher.BatchGenerator
         y_generator : xbatcher.BatchGenerator
-        transform : callable, optional
-            A function/transform that takes in an array and returns a transformed version.
-        target_transform : callable, optional
-            A function/transform that takes in the target and transforms it.
+        transform, target_transform : callable, optional
+            A function/transform that takes in an Xarray object and returns a transformed version in the form of a torch.Tensor.
         """
         self.X_generator = X_generator
         self.y_generator = y_generator
@@ -48,7 +68,7 @@ def __init__(
     def __len__(self) -> int:
         return len(self.X_generator)
 
-    def __getitem__(self, idx) -> tuple[Any, Any]:
+    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
         if torch.is_tensor(idx):
             idx = idx.tolist()
             if len(idx) == 1:
@@ -58,15 +78,27 @@ def __getitem__(self, idx) -> tuple[Any, Any]:
                 f'{type(self).__name__}.__getitem__ currently requires a single integer key'
             )
 
-        X_batch = self.X_generator[idx].torch.to_tensor()
-        y_batch = self.y_generator[idx].torch.to_tensor()
+        # generate batch (or batches)
+        if self.y_generator is not None:
+            X_batch, y_batch = self.X_generator[idx], self.y_generator[idx]
+        else:
+            X_batch, y_batch = self.X_generator[idx], None
 
+        # load batch (or batches) with dask if possible
+        if dask is not None:
+            X_batch, y_batch = dask.compute(X_batch, y_batch)
+
+        # apply transformation(s)
+        X_batch_tensor = self.transform(X_batch)
+        if y_batch is not None:
+            y_batch_tensor = self.target_transform(y_batch)
+
-        if self.transform:
-            X_batch = self.transform(X_batch)
+        assert isinstance(X_batch_tensor, torch.Tensor), self.transform
 
-        if self.target_transform:
-            y_batch = self.target_transform(y_batch)
-        return X_batch, y_batch
+        if y_batch is None:
+            return X_batch_tensor
+        assert isinstance(y_batch_tensor, torch.Tensor)
+        return X_batch_tensor, y_batch_tensor
 
 
 class IterableDataset(torch.utils.data.IterableDataset):
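
For orientation, here is a minimal usage sketch of the refactored MapDataset (not part of the commit; it assumes this version of xbatcher plus NumPy and PyTorch, and mirrors the shapes used in the tests below):

import numpy as np
import torch
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher.loaders.torch import MapDataset

# Small dataset mirroring the test fixture in this commit.
ds = xr.Dataset({'x': (['sample', 'feature'], np.random.random((100, 5)))})

# y_generator is now optional; an X-only dataset yields a single tensor,
# with to_tensor applied as the default transform.
x_gen = BatchGenerator(ds['x'], {'sample': 10})
dataset = MapDataset(x_gen)

x_batch = dataset[0]
assert isinstance(x_batch, torch.Tensor)
assert x_batch.shape == (10, 5)

Passing a second BatchGenerator as y_generator makes indexing return an (X, y) pair of tensors instead.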
87 changes: 79 additions & 8 deletions xbatcher/tests/test_torch_loaders.py
@@ -1,15 +1,41 @@
+from importlib import reload
+
 import numpy as np
 import pytest
 import xarray as xr
 
 from xbatcher import BatchGenerator
-from xbatcher.loaders.torch import IterableDataset, MapDataset
+from xbatcher.loaders.torch import IterableDataset, MapDataset, to_tensor
 
 torch = pytest.importorskip('torch')
 
 
-@pytest.fixture(scope='module')
-def ds_xy():
+def test_import_torch_failure(monkeypatch):
+    import sys
+
+    import xbatcher.loaders
+
+    monkeypatch.setitem(sys.modules, 'torch', None)
+
+    with pytest.raises(ImportError) as excinfo:
+        reload(xbatcher.loaders.torch)
+
+    assert 'install PyTorch to proceed' in str(excinfo.value)
+
+
+def test_import_dask_failure(monkeypatch):
+    import sys
+
+    import xbatcher.loaders
+
+    monkeypatch.setitem(sys.modules, 'dask', None)
+    reload(xbatcher.loaders.torch)
+
+    assert xbatcher.loaders.torch.dask is None
+
+
+@pytest.fixture(scope='module', params=[True, False])
+def ds_xy(request):
     n_samples = 100
     n_features = 5
     ds = xr.Dataset(
@@ -21,17 +47,62 @@ def ds_xy():
             'y': (['sample'], np.random.random(n_samples)),
         },
     )
+
+    if request.param:
+        ds = ds.chunk({'sample': 10})
+
     return ds
 
 
+@pytest.mark.parametrize('x_var', ['x', ['x']])
+def test_map_dataset_without_y(ds_xy, x_var) -> None:
+    x = ds_xy[x_var]
+
+    x_gen = BatchGenerator(x, {'sample': 10})
+
+    dataset = MapDataset(x_gen)
+
+    # test __getitem__
+    x_batch = dataset[0]
+    assert x_batch.shape == (10, 5)  # type: ignore[union-attr]
+    assert isinstance(x_batch, torch.Tensor)
+
+    idx = torch.tensor([0])
+    x_batch = dataset[idx]
+    assert x_batch.shape == (10, 5)
+    assert isinstance(x_batch, torch.Tensor)
+
+    with pytest.raises(NotImplementedError):
+        idx = torch.tensor([0, 1])
+        x_batch = dataset[idx]
+
+    # test __len__
+    assert len(dataset) == len(x_gen)
+
+    # test integration with torch DataLoader
+    loader = torch.utils.data.DataLoader(dataset, batch_size=None)
+
+    for x_batch in loader:
+        assert x_batch.shape == (10, 5)  # type: ignore[union-attr]
+        assert isinstance(x_batch, torch.Tensor)
+
+    # Check that array shape of last item in generator is same as the batch image
+    assert tuple(x_gen[-1].sizes.values()) == x_batch.shape  # type: ignore[union-attr]
+    # Check that array values from last item in generator and batch are the same
+    gen_array = (
+        x_gen[-1].to_array().squeeze() if hasattr(x_gen[-1], 'to_array') else x_gen[-1]
+    )
+    np.testing.assert_array_equal(gen_array, x_batch)  # type: ignore
 
 
 @pytest.mark.parametrize(
     ('x_var', 'y_var'),
     [
         ('x', 'y'),  # xr.DataArray
         (['x'], ['y']),  # xr.Dataset
     ],
 )
-def test_map_dataset(ds_xy, x_var, y_var):
+def test_map_dataset(ds_xy, x_var, y_var) -> None:
     x = ds_xy[x_var]
     y = ds_xy[y_var]
 
@@ -73,7 +144,7 @@ def test_map_dataset(ds_xy, x_var, y_var):
     gen_array = (
         x_gen[-1].to_array().squeeze() if hasattr(x_gen[-1], 'to_array') else x_gen[-1]
     )
-    np.testing.assert_array_equal(gen_array, x_batch)
+    np.testing.assert_array_equal(gen_array, x_batch)  # type: ignore


@@ -83,18 +154,18 @@ def test_map_dataset(ds_xy, x_var, y_var):
 @pytest.mark.parametrize(
     ('x_var', 'y_var'),
     [
         ('x', 'y'),  # xr.DataArray
         (['x'], ['y']),  # xr.Dataset
     ],
 )
-def test_map_dataset_with_transform(ds_xy, x_var, y_var):
+def test_map_dataset_with_transform(ds_xy, x_var, y_var) -> None:
     x = ds_xy[x_var]
     y = ds_xy[y_var]
 
     x_gen = BatchGenerator(x, {'sample': 10})
     y_gen = BatchGenerator(y, {'sample': 10})
 
     def x_transform(batch):
-        return batch * 0 + 1
+        return to_tensor(batch * 0 + 1)
 
     def y_transform(batch):
-        return batch * 0 - 1
+        return to_tensor(batch * 0 - 1)
 
     dataset = MapDataset(
         x_gen, y_gen, transform=x_transform, target_transform=y_transform

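A note on the dask branch exercised above: dask.compute() materializes any dask-backed arguments and passes all other objects through unchanged, which is why MapDataset.__getitem__ can hand it an unchunked batch or a None target. A minimal sketch, assuming dask[array] is installed:

import dask
import dask.array as da

lazy = da.ones((10, 5), chunks=(5, 5))  # dask-backed array
computed, passthrough = dask.compute(lazy, None)

assert computed.shape == (10, 5)  # materialized as a NumPy array
assert passthrough is None  # non-dask objects pass through unchanged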