Skip to content

Commit

Permalink
Fix len for empty GeoDataset (#374)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart authored Jan 29, 2022
1 parent 27f270b commit 60009b3
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ install_requires =
pytorch-lightning>=1.3
# rasterio 1.0.16+ required for CRS support
rasterio>=1.0.16
# rtree 0.5+ required for 3D index support
rtree>=0.5
# rtree 0.9.4+ required for Index.get_size
rtree>=0.9.4
# scikit-learn 0.18+ required for sklearn.model_selection module
scikit-learn>=0.18
# segmentation-models-pytorch 0.2+ required for smp.losses module
Expand Down
18 changes: 12 additions & 6 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,17 +333,20 @@ def test_vision_dataset(self) -> None:
def test_different_crs(self) -> None:
ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005))
ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616))
IntersectionDataset(ds1, ds2)
ds = IntersectionDataset(ds1, ds2)
assert len(ds) == 0

def test_different_res(self) -> None:
ds1 = CustomGeoDataset(res=1)
ds2 = CustomGeoDataset(res=2)
IntersectionDataset(ds1, ds2)
ds = IntersectionDataset(ds1, ds2)
assert len(ds) == 1

def test_no_overlap(self) -> None:
ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5))
ds2 = CustomGeoDataset(BoundingBox(6, 7, 8, 9, 10, 11))
IntersectionDataset(ds1, ds2)
ds = IntersectionDataset(ds1, ds2)
assert len(ds) == 0

def test_invalid_query(self, dataset: IntersectionDataset) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
Expand Down Expand Up @@ -382,17 +385,20 @@ def test_vision_dataset(self) -> None:
def test_different_crs(self) -> None:
ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005))
ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616))
UnionDataset(ds1, ds2)
ds = UnionDataset(ds1, ds2)
assert len(ds) == 2

def test_different_res(self) -> None:
ds1 = CustomGeoDataset(res=1)
ds2 = CustomGeoDataset(res=2)
UnionDataset(ds1, ds2)
ds = UnionDataset(ds1, ds2)
assert len(ds) == 2

def test_no_overlap(self) -> None:
ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5))
ds2 = CustomGeoDataset(BoundingBox(6, 7, 8, 9, 10, 11))
UnionDataset(ds1, ds2)
ds = UnionDataset(ds1, ds2)
assert len(ds) == 2

def test_invalid_query(self, dataset: UnionDataset) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def __len__(self) -> int:
Returns:
length of the dataset
"""
count: int = self.index.count(self.index.bounds)
count: int = self.index.get_size()
return count

def __str__(self) -> str:
Expand Down

0 comments on commit 60009b3

Please sign in to comment.