From 53a12a4e59bd5606c7021b9530489574423cc60f Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 28 Mar 2023 20:27:45 -0500 Subject: [PATCH] IntersectionDataset: better error message when no overlap (#1192) * IntersectionDataset: better error message when no overlap * Update split tests * Document error --- tests/datasets/test_geo.py | 21 +++++++--- tests/datasets/test_splits.py | 77 ++++++++++++++++++----------------- torchgeo/datasets/geo.py | 4 ++ 3 files changed, 59 insertions(+), 43 deletions(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 158117c9d88..ff0b35c3c3a 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -405,10 +405,20 @@ def test_nongeo_dataset(self) -> None: IntersectionDataset(ds1, ds2) # type: ignore[arg-type] def test_different_crs(self) -> None: - ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005)) - ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616)) + ds1 = CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 1), crs=CRS.from_epsg(3005)) + ds2 = CustomGeoDataset( + BoundingBox( + -3547229.913123814, + 6360089.518213182, + -3547229.913123814, + 6360089.518213182, + -3547229.913123814, + 6360089.518213182, + ), + crs=CRS.from_epsg(32616), + ) ds = IntersectionDataset(ds1, ds2) - assert len(ds) == 0 + assert len(ds) == 1 def test_different_res(self) -> None: ds1 = CustomGeoDataset(res=1) @@ -419,8 +429,9 @@ def test_different_res(self) -> None: def test_no_overlap(self) -> None: ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5)) ds2 = CustomGeoDataset(BoundingBox(6, 7, 8, 9, 10, 11)) - ds = IntersectionDataset(ds1, ds2) - assert len(ds) == 0 + msg = "Datasets have no spatiotemporal intersection" + with pytest.raises(RuntimeError, match=msg): + IntersectionDataset(ds1, ds2) def test_invalid_query(self, dataset: IntersectionDataset) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index c03cfbd103e..6e7dfe31406 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -18,6 +18,23 @@ ) +def total_area(dataset: GeoDataset) -> float: + total_area = 0.0 + for hit in dataset.index.intersection(dataset.index.bounds, objects=True): + total_area += BoundingBox(*hit.bounds).area + + return total_area + + +def no_overlap(ds1: GeoDataset, ds2: GeoDataset) -> bool: + try: + ds = ds1 & ds2 + except RuntimeError: + return True + else: + return isclose(total_area(ds), 0) + + class CustomGeoDataset(GeoDataset): def __init__( self, @@ -66,11 +83,9 @@ def test_random_bbox_assignment( assert len(test_ds) == expected_lengths[2] # No overlap - assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0) - assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0) - assert len(test_ds & train_ds) == 0 or isclose( - _get_total_area(test_ds & train_ds), 0 - ) + assert no_overlap(train_ds, val_ds) + assert no_overlap(val_ds, test_ds) + assert no_overlap(test_ds, train_ds) # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds @@ -93,14 +108,6 @@ def test_random_bbox_assignment_invalid_inputs() -> None: random_bbox_assignment(CustomGeoDataset(), lengths=[1 / 2, 3 / 4, -1 / 4]) -def _get_total_area(dataset: GeoDataset) -> float: - total_area = 0.0 - for hit in dataset.index.intersection(dataset.index.bounds, objects=True): - total_area += BoundingBox(*hit.bounds).area - - return total_area - - def test_random_bbox_splitting() -> None: ds = CustomGeoDataset( [ @@ -111,14 +118,14 @@ def test_random_bbox_splitting() -> None: ] ) - ds_area = _get_total_area(ds) + ds_area = total_area(ds) train_ds, val_ds, test_ds = random_bbox_splitting( ds, fractions=[1 / 2, 1 / 4, 1 / 4] ) - train_ds_area = _get_total_area(train_ds) - val_ds_area = _get_total_area(val_ds) - test_ds_area = _get_total_area(test_ds) + train_ds_area = total_area(train_ds) + val_ds_area = total_area(val_ds) + test_ds_area = total_area(test_ds) # Check datasets areas assert train_ds_area == ds_area / 2 @@ -126,15 +133,13 @@ def test_random_bbox_splitting() -> None: assert test_ds_area == ds_area / 4 # No overlap - assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0) - assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0) - assert len(test_ds & train_ds) == 0 or isclose( - _get_total_area(test_ds & train_ds), 0 - ) + assert no_overlap(train_ds, val_ds) + assert no_overlap(val_ds, test_ds) + assert no_overlap(test_ds, train_ds) # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds - assert isclose(_get_total_area(train_ds | val_ds | test_ds), ds_area) + assert isclose(total_area(train_ds | val_ds | test_ds), ds_area) # Test __get_item__ x = train_ds[train_ds.bounds] @@ -168,15 +173,13 @@ def test_random_grid_cell_assignment() -> None: assert len(test_ds) == floor(1 / 4 * 2 * 5**2) # No overlap - assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0) - assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0) - assert len(test_ds & train_ds) == 0 or isclose( - _get_total_area(test_ds & train_ds), 0 - ) + assert no_overlap(train_ds, val_ds) + assert no_overlap(val_ds, test_ds) + assert no_overlap(test_ds, train_ds) # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds - assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds)) + assert isclose(total_area(train_ds | val_ds | test_ds), total_area(ds)) # Test __get_item__ x = train_ds[train_ds.bounds] @@ -219,15 +222,13 @@ def test_roi_split() -> None: assert len(test_ds) == 1 # No overlap - assert len(train_ds & val_ds) == 0 or isclose(_get_total_area(train_ds & val_ds), 0) - assert len(val_ds & test_ds) == 0 or isclose(_get_total_area(val_ds & test_ds), 0) - assert len(test_ds & train_ds) == 0 or isclose( - _get_total_area(test_ds & train_ds), 0 - ) + assert no_overlap(train_ds, val_ds) + assert no_overlap(val_ds, test_ds) + assert no_overlap(test_ds, train_ds) # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds - assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds)) + assert isclose(total_area(train_ds | val_ds | test_ds), total_area(ds)) # Test __get_item__ x = train_ds[train_ds.bounds] @@ -273,9 +274,9 @@ def test_time_series_split( assert len(test_ds) == expected_lengths[2] # No overlap - assert len(train_ds & val_ds) == 0 - assert len(val_ds & test_ds) == 0 - assert len(test_ds & train_ds) == 0 + assert no_overlap(train_ds, val_ds) + assert no_overlap(val_ds, test_ds) + assert no_overlap(test_ds, train_ds) # Union equals original assert (train_ds | val_ds | test_ds).bounds == ds.bounds diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 18b3cd9310c..faafaa2e81a 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -811,6 +811,7 @@ def __init__( entry and returns a transformed version Raises: + RuntimeError: if datasets have no spatiotemporal intersection ValueError: if either dataset is not a :class:`GeoDataset` .. versionadded:: 0.4 @@ -855,6 +856,9 @@ def _merge_dataset_indices(self) -> None: self.index.insert(i, tuple(box1 & box2)) i += 1 + if i == 0: + raise RuntimeError("Datasets have no spatiotemporal intersection") + def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: """Retrieve image and metadata indexed by query.