From 1777dbaef2009a6b2b55fcf713ee39f73c829095 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 27 Jan 2022 23:48:09 -0600 Subject: [PATCH] Fix len for empty GeoDataset --- setup.cfg | 4 ++-- tests/datasets/test_geo.py | 18 ++++++++++++------ torchgeo/datasets/geo.py | 2 +- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/setup.cfg b/setup.cfg index fc7001efbd1..d8644c98a67 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,8 +47,8 @@ install_requires = pytorch-lightning>=1.3 # rasterio 1.0.16+ required for CRS support rasterio>=1.0.16 - # rtree 0.5+ required for 3D index support - rtree>=0.5 + # rtree 0.9.4+ required for Index.get_size + rtree>=0.9.4 # scikit-learn 0.18+ required for sklearn.model_selection module scikit-learn>=0.18 # segmentation-models-pytorch 0.2+ required for smp.losses module diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index b71b08eca93..491b7294574 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -333,17 +333,20 @@ def test_vision_dataset(self) -> None: def test_different_crs(self) -> None: ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005)) ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616)) - IntersectionDataset(ds1, ds2) + ds = IntersectionDataset(ds1, ds2) + assert len(ds) == 0 def test_different_res(self) -> None: ds1 = CustomGeoDataset(res=1) ds2 = CustomGeoDataset(res=2) - IntersectionDataset(ds1, ds2) + ds = IntersectionDataset(ds1, ds2) + assert len(ds) == 1 def test_no_overlap(self) -> None: ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5)) ds2 = CustomGeoDataset(BoundingBox(6, 7, 8, 9, 10, 11)) - IntersectionDataset(ds1, ds2) + ds = IntersectionDataset(ds1, ds2) + assert len(ds) == 0 def test_invalid_query(self, dataset: IntersectionDataset) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) @@ -382,17 +385,20 @@ def test_vision_dataset(self) -> None: def test_different_crs(self) -> None: ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005)) ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616)) - UnionDataset(ds1, ds2) + ds = UnionDataset(ds1, ds2) + assert len(ds) == 2 def test_different_res(self) -> None: ds1 = CustomGeoDataset(res=1) ds2 = CustomGeoDataset(res=2) - UnionDataset(ds1, ds2) + ds = UnionDataset(ds1, ds2) + assert len(ds) == 2 def test_no_overlap(self) -> None: ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5)) ds2 = CustomGeoDataset(BoundingBox(6, 7, 8, 9, 10, 11)) - UnionDataset(ds1, ds2) + ds = UnionDataset(ds1, ds2) + assert len(ds) == 2 def test_invalid_query(self, dataset: UnionDataset) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 9386a5e3a61..a6c0d073199 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -161,7 +161,7 @@ def __len__(self) -> int: Returns: length of the dataset """ - count: int = self.index.count(self.index.bounds) + count: int = self.index.get_size() return count def __str__(self) -> str: