diff --git a/osm4gpd/parse.py b/osm4gpd/parse.py index f312359..bd6a965 100644 --- a/osm4gpd/parse.py +++ b/osm4gpd/parse.py @@ -65,6 +65,11 @@ class OSMFile: ways: list[WayGroup] = field(default_factory=list) relations: list[RelationGroup] = field(default_factory=list) + # protected property that is used to store the arguments to filter + # for later use during consolidation, since pre-consolidation filtering + # can leave non-matching geometries that are referenced by some way or relation + _filter: set[str] | None = None + @classmethod def from_file(cls, fp: Path | str) -> OSMFile: if isinstance(fp, str): @@ -97,17 +102,20 @@ def filter(self, *, tags: set[str]) -> OSMFile: self.ways, tags=tags, references=references ) self.nodes, _ = filter_groups(self.nodes, tags=tags, references=references) + + self._filter = tags return self - def consolidate(self) -> gpd.GeoDataFrame: + def _consolidate_nodes(self) -> gpd.GeoDataFrame: _node_parts = [ consolidate_nodes(nodes) for nodes in self.nodes if not nodes.is_empty() ] if len(_node_parts) > 0: - nodes = pd.concat(_node_parts) + return pd.concat(_node_parts) else: raise ValueError("Nothing to consolidate.") + def _consolidate_ways(self, *, nodes: gpd.GeoDataFrame) -> gpd.GeoDataFrame: _way_parts = [ consolidate_ways(ways, nodes=nodes) for ways in self.ways @@ -115,10 +123,13 @@ def consolidate(self) -> gpd.GeoDataFrame: ] if len(_way_parts) > 0: - ways = pd.concat(_way_parts) + return pd.concat(_way_parts) else: - ways = gpd.GeoDataFrame() + return gpd.GeoDataFrame() + def _consolidate_relations( + self, *, nodes: gpd.GeoDataFrame, ways: gpd.GeoDataFrame + ) -> gpd.GeoDataFrame: _relation_parts = [ consolidate_relations(relations, ways=ways, nodes=nodes) for relations in self.relations @@ -126,10 +137,25 @@ def consolidate(self) -> gpd.GeoDataFrame: ] if len(_relation_parts) > 0: - relations = pd.concat(_relation_parts) - return pd.concat([nodes, ways, relations]) + return pd.concat(_relation_parts) else: - return pd.concat([nodes, ways]) + return gpd.GeoDataFrame() + + def consolidate(self) -> gpd.GeoDataFrame: + nodes = self._consolidate_nodes() + ways = self._consolidate_ways(nodes=nodes) + relations = self._consolidate_relations(nodes=nodes, ways=ways) + + gdf = pd.concat([nodes, ways, relations]) + + if self._filter is not None: + # filter for rows that match a filter category + gdf = gdf[gdf[list(self._filter)].notna().any(axis=1)] + + # drop columns that became all NA from filtering + gdf = gdf[gdf.columns[~gdf.isna().all()]] + + return gdf # header_bbox = box( diff --git a/tests/test_consolidate.py b/tests/test_consolidate.py index 91deba2..f229ff1 100644 --- a/tests/test_consolidate.py +++ b/tests/test_consolidate.py @@ -6,13 +6,13 @@ @pytest.mark.parametrize( "filename,tags,expected_shape", [ - ("extract", {"amenity"}, (1184, 157)), - ("isle_of_man", {"name"}, (102488, 1109)), - ("malta", {"amenity"}, (19247, 354)), - ("malta", {"car_wash"}, (10, 8)), + ("extract", {"amenity"}, (314, 155)), + ("isle_of_man", {"name"}, (10827, 1062)), + ("malta", {"amenity"}, (5396, 342)), + ("malta", {"car_wash"}, (1, 8)), ], ) -def test_osm_file_can_be_consolidated( +def test_osm_file_can_be_consolidated_after_filtering( filename: str, tags: set[str], expected_shape: tuple[int, int], @@ -25,3 +25,40 @@ def test_osm_file_can_be_consolidated( ) assert gdf.shape == expected_shape + + +@pytest.mark.parametrize( + "filename,expected_shape", + [ + ("extract", (1184, 157)), + ], +) +def test_osm_file_can_be_consolidated_without_filtering( + filename: str, + expected_shape: tuple[int, int], + request: pytest.FixtureRequest, +) -> None: + gdf = OSMFile.from_file(request.getfixturevalue(filename)).consolidate() + + assert gdf.shape == expected_shape + + +@pytest.mark.parametrize( + "filename,tags", + [ + ("extract", {"amenity"}), + ("malta", {"car_wash"}), + ], +) +def test_consolidation_respects_filtering( + filename: str, + tags: set[str], + request: pytest.FixtureRequest, +) -> None: + gdf = ( + OSMFile.from_file(request.getfixturevalue(filename)) + .filter(tags=tags) + .consolidate() + ) + + assert gdf[list(tags)].isna().sum().sum() == 0