From dcd09a0cd254009f4f934381454ff9fdfbbff454 Mon Sep 17 00:00:00 2001
From: Joe Hamman
Date: Fri, 13 Sep 2024 15:55:31 -0700
Subject: [PATCH] Refactor PyTorch dataset (#202)

Co-authored-by: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Anderson Banihirwe
---
 xbatcher/loaders/torch.py            | 66 +++++++++++++++------
 xbatcher/tests/test_torch_loaders.py | 87 +++++++++++++++++++++++++---
 2 files changed, 128 insertions(+), 25 deletions(-)

diff --git a/xbatcher/loaders/torch.py b/xbatcher/loaders/torch.py
index 77ebcb8..a5ac85a 100644
--- a/xbatcher/loaders/torch.py
+++ b/xbatcher/loaders/torch.py
@@ -1,5 +1,11 @@
+from __future__ import annotations
+
 from collections.abc import Callable
-from typing import Any
+from types import ModuleType
+
+import xarray as xr
+
+from xbatcher import BatchGenerator
 
 try:
     import torch
@@ -9,6 +15,13 @@
         'install PyTorch to proceed.'
     ) from exc
 
+try:
+    import dask
+except ImportError:
+    dask: ModuleType | None = None  # type: ignore[no-redef]
+
+T_DataArrayOrSet = xr.DataArray | xr.Dataset
+
 # Notes:
 # This module includes two PyTorch datasets.
 #  - The MapDataset provides an indexable interface
@@ -20,13 +33,22 @@
 # - need to test with additional dataset parameters (e.g. transforms)
 
 
+def to_tensor(xr_obj: T_DataArrayOrSet) -> torch.Tensor:
+    """Convert a DataArray or Dataset to a torch.Tensor."""
+    if isinstance(xr_obj, xr.Dataset):
+        xr_obj = xr_obj.to_array().squeeze(dim='variable')
+    if isinstance(xr_obj, xr.DataArray):
+        xr_obj = xr_obj.data
+    return torch.tensor(xr_obj)
+
+
 class MapDataset(torch.utils.data.Dataset):
     def __init__(
         self,
-        X_generator,
-        y_generator,
-        transform: Callable | None = None,
-        target_transform: Callable | None = None,
+        X_generator: BatchGenerator,
+        y_generator: BatchGenerator | None = None,
+        transform: Callable[[T_DataArrayOrSet], torch.Tensor] = to_tensor,
+        target_transform: Callable[[T_DataArrayOrSet], torch.Tensor] = to_tensor,
     ) -> None:
         """
         PyTorch Dataset adapter for Xbatcher
@@ -35,10 +57,8 @@ def __init__(
         ----------
         X_generator : xbatcher.BatchGenerator
-        y_generator : xbatcher.BatchGenerator
-        transform : callable, optional
-            A function/transform that takes in an array and returns a transformed version.
-        target_transform : callable, optional
-            A function/transform that takes in the target and transforms it.
+        y_generator : xbatcher.BatchGenerator, optional
+        transform, target_transform : callable, optional
+            A function/transform that takes in an xarray object and returns a transformed version as a torch.Tensor.
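+
+        Examples
+        --------
+        A minimal, illustrative sketch (``da`` and ``X_gen`` are example names):
+
+        >>> import numpy as np
+        >>> import xarray as xr
+        >>> from xbatcher import BatchGenerator
+        >>> from xbatcher.loaders.torch import MapDataset
+        >>> da = xr.DataArray(np.arange(100), dims='sample')
+        >>> X_gen = BatchGenerator(da, input_dims={'sample': 10})
+        >>> dataset = MapDataset(X_gen)  # no y_generator: __getitem__ returns a single tensor
+        >>> dataset[0].shape
+        torch.Size([10])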
""" self.X_generator = X_generator self.y_generator = y_generator @@ -48,7 +68,7 @@ def __init__( def __len__(self) -> int: return len(self.X_generator) - def __getitem__(self, idx) -> tuple[Any, Any]: + def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor: if torch.is_tensor(idx): idx = idx.tolist() if len(idx) == 1: @@ -58,15 +78,27 @@ def __getitem__(self, idx) -> tuple[Any, Any]: f'{type(self).__name__}.__getitem__ currently requires a single integer key' ) - X_batch = self.X_generator[idx].torch.to_tensor() - y_batch = self.y_generator[idx].torch.to_tensor() + # generate batch (or batches) + if self.y_generator is not None: + X_batch, y_batch = self.X_generator[idx], self.y_generator[idx] + else: + X_batch, y_batch = self.X_generator[idx], None + + # load batch (or batches) with dask if possible + if dask is not None: + X_batch, y_batch = dask.compute(X_batch, y_batch) + + # apply transformation(s) + X_batch_tensor = self.transform(X_batch) + if y_batch is not None: + y_batch_tensor = self.target_transform(y_batch) - if self.transform: - X_batch = self.transform(X_batch) + assert isinstance(X_batch_tensor, torch.Tensor), self.transform - if self.target_transform: - y_batch = self.target_transform(y_batch) - return X_batch, y_batch + if y_batch is None: + return X_batch_tensor + assert isinstance(y_batch_tensor, torch.Tensor) + return X_batch_tensor, y_batch_tensor class IterableDataset(torch.utils.data.IterableDataset): diff --git a/xbatcher/tests/test_torch_loaders.py b/xbatcher/tests/test_torch_loaders.py index 7b44fa1..db2b968 100644 --- a/xbatcher/tests/test_torch_loaders.py +++ b/xbatcher/tests/test_torch_loaders.py @@ -1,15 +1,41 @@ +from importlib import reload + import numpy as np import pytest import xarray as xr from xbatcher import BatchGenerator -from xbatcher.loaders.torch import IterableDataset, MapDataset +from xbatcher.loaders.torch import IterableDataset, MapDataset, to_tensor torch = pytest.importorskip('torch') -@pytest.fixture(scope='module') -def ds_xy(): +def test_import_torch_failure(monkeypatch): + import sys + + import xbatcher.loaders + + monkeypatch.setitem(sys.modules, 'torch', None) + + with pytest.raises(ImportError) as excinfo: + reload(xbatcher.loaders.torch) + + assert 'install PyTorch to proceed' in str(excinfo.value) + + +def test_import_dask_failure(monkeypatch): + import sys + + import xbatcher.loaders + + monkeypatch.setitem(sys.modules, 'dask', None) + reload(xbatcher.loaders.torch) + + assert xbatcher.loaders.torch.dask is None + + +@pytest.fixture(scope='module', params=[True, False]) +def ds_xy(request): n_samples = 100 n_features = 5 ds = xr.Dataset( @@ -21,9 +47,54 @@ def ds_xy(): 'y': (['sample'], np.random.random(n_samples)), }, ) + + if request.param: + ds = ds.chunk({'sample': 10}) + return ds +@pytest.mark.parametrize('x_var', ['x', ['x']]) +def test_map_dataset_without_y(ds_xy, x_var) -> None: + x = ds_xy[x_var] + + x_gen = BatchGenerator(x, {'sample': 10}) + + dataset = MapDataset(x_gen) + + # test __getitem__ + x_batch = dataset[0] + assert x_batch.shape == (10, 5) # type: ignore[union-attr] + assert isinstance(x_batch, torch.Tensor) + + idx = torch.tensor([0]) + x_batch = dataset[idx] + assert x_batch.shape == (10, 5) + assert isinstance(x_batch, torch.Tensor) + + with pytest.raises(NotImplementedError): + idx = torch.tensor([0, 1]) + x_batch = dataset[idx] + + # test __len__ + assert len(dataset) == len(x_gen) + + # test integration with torch DataLoader + loader = torch.utils.data.DataLoader(dataset, 
+
+    for x_batch in loader:
+        assert x_batch.shape == (10, 5)  # type: ignore[union-attr]
+        assert isinstance(x_batch, torch.Tensor)
+
+    # Check that the array shape of the last generator item matches the last batch from the loader
+    assert tuple(x_gen[-1].sizes.values()) == x_batch.shape  # type: ignore[union-attr]
+    # Check that array values from the last generator item and the last batch are the same
+    gen_array = (
+        x_gen[-1].to_array().squeeze() if hasattr(x_gen[-1], 'to_array') else x_gen[-1]
+    )
+    np.testing.assert_array_equal(gen_array, x_batch)  # type: ignore
+
+
 @pytest.mark.parametrize(
     ('x_var', 'y_var'),
     [
@@ -31,7 +102,7 @@ def ds_xy():
         (['x'], ['y']),  # xr.Dataset
     ],
 )
-def test_map_dataset(ds_xy, x_var, y_var):
+def test_map_dataset(ds_xy, x_var, y_var) -> None:
     x = ds_xy[x_var]
     y = ds_xy[y_var]
 
@@ -73,7 +144,7 @@ def test_map_dataset(ds_xy, x_var, y_var):
     gen_array = (
         x_gen[-1].to_array().squeeze() if hasattr(x_gen[-1], 'to_array') else x_gen[-1]
     )
-    np.testing.assert_array_equal(gen_array, x_batch)
+    np.testing.assert_array_equal(gen_array, x_batch)  # type: ignore
 
 
 @pytest.mark.parametrize(
@@ -83,7 +154,7 @@
         (['x'], ['y']),  # xr.Dataset
     ],
 )
-def test_map_dataset_with_transform(ds_xy, x_var, y_var):
+def test_map_dataset_with_transform(ds_xy, x_var, y_var) -> None:
     x = ds_xy[x_var]
     y = ds_xy[y_var]
 
@@ -91,10 +162,10 @@ def test_map_dataset_with_transform(ds_xy, x_var, y_var):
     x_gen = BatchGenerator(x, {'sample': 10})
     y_gen = BatchGenerator(y, {'sample': 10})
 
     def x_transform(batch):
-        return batch * 0 + 1
+        return to_tensor(batch * 0 + 1)
 
     def y_transform(batch):
-        return batch * 0 - 1
+        return to_tensor(batch * 0 - 1)
 
     dataset = MapDataset(
         x_gen, y_gen, transform=x_transform, target_transform=y_transform
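
Usage sketch (illustrative; the names `ds`, `X_gen`, and `y_gen` are example names): end-to-end use of the refactored MapDataset with a torch DataLoader, relying on the default `to_tensor` transforms introduced by this patch.

    import numpy as np
    import torch
    import xarray as xr

    from xbatcher import BatchGenerator
    from xbatcher.loaders.torch import MapDataset

    # illustrative toy data matching the shapes used in the tests above
    ds = xr.Dataset(
        {
            'x': (['sample', 'feature'], np.random.random((100, 5))),
            'y': (['sample'], np.random.random(100)),
        }
    )

    X_gen = BatchGenerator(ds['x'], input_dims={'sample': 10})
    y_gen = BatchGenerator(ds['y'], input_dims={'sample': 10})
    dataset = MapDataset(X_gen, y_gen)  # transform/target_transform default to to_tensor

    # batch_size=None because xbatcher already batches: the DataLoader then
    # yields each (X, y) pair unchanged instead of collating samples itself
    loader = torch.utils.data.DataLoader(dataset, batch_size=None)
    for X_batch, y_batch in loader:
        assert X_batch.shape == (10, 5)
        assert y_batch.shape == (10,)

With `y_gen` omitted, each item is a single tensor rather than an (X, y) pair, matching the `test_map_dataset_without_y` test above.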