Skip to content

Commit

Permalink
fix: make consolidation after filtering behave as intended
Browse files Browse the repository at this point in the history
  • Loading branch information
Christian committed Aug 18, 2023
1 parent d2fcd7a commit bdad5df
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 12 deletions.
40 changes: 33 additions & 7 deletions osm4gpd/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ class OSMFile:
ways: list[WayGroup] = field(default_factory=list)
relations: list[RelationGroup] = field(default_factory=list)

# Private field that remembers the tags passed to `filter` so they can be
# applied again during consolidation: filtering before consolidation can leave
# non-matching geometries behind because some way or relation references them.
_filter: set[str] | None = None

@classmethod
def from_file(cls, fp: Path | str) -> OSMFile:
if isinstance(fp, str):
Expand Down Expand Up @@ -97,39 +102,60 @@ def filter(self, *, tags: set[str]) -> OSMFile:
self.ways, tags=tags, references=references
)
self.nodes, _ = filter_groups(self.nodes, tags=tags, references=references)

self._filter = tags
return self

def _consolidate_nodes(self) -> gpd.GeoDataFrame:
    """Concatenate the consolidated frames of all non-empty node groups.

    Each non-empty group in ``self.nodes`` is passed to the module-level
    ``consolidate_nodes`` helper (defined outside this view) and the results
    are concatenated.

    Returns:
        The concatenation of the per-group results.

    Raises:
        ValueError: If there is no non-empty node group.
    """
    node_frames = [
        consolidate_nodes(group) for group in self.nodes if not group.is_empty()
    ]
    if not node_frames:
        raise ValueError("Nothing to consolidate.")
    return pd.concat(node_frames)

def _consolidate_ways(self, *, nodes: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
    """Concatenate the consolidated frames of all non-empty way groups.

    Each non-empty group in ``self.ways`` is passed, together with ``nodes``,
    to the module-level ``consolidate_ways`` helper (defined outside this
    view).

    Args:
        nodes: The already consolidated nodes, forwarded to
            ``consolidate_ways``.

    Returns:
        The concatenation of the per-group results, or an empty
        ``GeoDataFrame`` when there is no non-empty way group.
    """
    way_frames = [
        consolidate_ways(group, nodes=nodes)
        for group in self.ways
        if not group.is_empty()
    ]
    # Unlike nodes, missing ways are not an error: fall back to an empty frame.
    return pd.concat(way_frames) if way_frames else gpd.GeoDataFrame()

def _consolidate_relations(
    self, *, nodes: gpd.GeoDataFrame, ways: gpd.GeoDataFrame
) -> gpd.GeoDataFrame:
    """Concatenate the consolidated frames of all non-empty relation groups.

    Each non-empty group in ``self.relations`` is passed, together with
    ``ways`` and ``nodes``, to the module-level ``consolidate_relations``
    helper (defined outside this view).

    Args:
        nodes: The already consolidated nodes, forwarded to the helper.
        ways: The already consolidated ways, forwarded to the helper.

    Returns:
        Only the relation rows (or an empty ``GeoDataFrame`` when there is no
        non-empty relation group). Nodes and ways are deliberately not merged
        in here, because ``consolidate`` concatenates all three frames itself.
    """
    relation_frames = [
        consolidate_relations(group, ways=ways, nodes=nodes)
        for group in self.relations
        if not group.is_empty()
    ]
    if not relation_frames:
        return gpd.GeoDataFrame()
    return pd.concat(relation_frames)

def consolidate(self) -> gpd.GeoDataFrame:
    """Combine all nodes, ways and relations into a single frame.

    When ``filter`` has stored tags in ``self._filter``, the combined frame
    is reduced to the rows that carry a value for at least one of those tags,
    and columns that are left without any value are dropped afterwards.

    Returns:
        The concatenated and, if a filter was set, filtered frame.

    Raises:
        ValueError: Propagated from ``_consolidate_nodes`` when there is
            nothing to consolidate.
    """
    nodes = self._consolidate_nodes()
    ways = self._consolidate_ways(nodes=nodes)
    relations = self._consolidate_relations(nodes=nodes, ways=ways)

    gdf = pd.concat([nodes, ways, relations])

    if self._filter is None:
        return gdf

    # Only look at filter tags that actually exist as columns: selecting a
    # missing label with a list raises KeyError, which would break filters
    # that name a tag not occurring anywhere in the consolidated data.
    present_tags = [tag for tag in self._filter if tag in gdf.columns]

    # Keep rows that match at least one filter category; pre-consolidation
    # filtering can leave non-matching geometries that are merely referenced
    # by some way or relation.
    gdf = gdf[gdf[present_tags].notna().any(axis=1)]

    # Drop columns that became all NA from filtering.
    return gdf[gdf.columns[~gdf.isna().all()]]


# header_bbox = box(
Expand Down
47 changes: 42 additions & 5 deletions tests/test_consolidate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
@pytest.mark.parametrize(
"filename,tags,expected_shape",
[
("extract", {"amenity"}, (1184, 157)),
("isle_of_man", {"name"}, (102488, 1109)),
("malta", {"amenity"}, (19247, 354)),
("malta", {"car_wash"}, (10, 8)),
("extract", {"amenity"}, (314, 155)),
("isle_of_man", {"name"}, (10827, 1062)),
("malta", {"amenity"}, (5396, 342)),
("malta", {"car_wash"}, (1, 8)),
],
)
def test_osm_file_can_be_consolidated(
def test_osm_file_can_be_consolidated_after_filtering(
filename: str,
tags: set[str],
expected_shape: tuple[int, int],
Expand All @@ -25,3 +25,40 @@ def test_osm_file_can_be_consolidated(
)

assert gdf.shape == expected_shape


@pytest.mark.parametrize(
    "filename,expected_shape",
    [("extract", (1184, 157))],
)
def test_osm_file_can_be_consolidated_without_filtering(
    filename: str,
    expected_shape: tuple[int, int],
    request: pytest.FixtureRequest,
) -> None:
    """Consolidating an unfiltered file yields a frame of the expected shape."""
    # The parametrized name is resolved to the fixture of the same name.
    fixture_value = request.getfixturevalue(filename)
    osm_file = OSMFile.from_file(fixture_value)

    consolidated = osm_file.consolidate()

    assert consolidated.shape == expected_shape


@pytest.mark.parametrize(
    "filename,tags",
    [
        ("extract", {"amenity"}),
        ("malta", {"car_wash"}),
    ],
)
def test_consolidation_respects_filtering(
    filename: str,
    tags: set[str],
    request: pytest.FixtureRequest,
) -> None:
    """Every consolidated row must carry a value for at least one filter tag."""
    gdf = (
        OSMFile.from_file(request.getfixturevalue(filename))
        .filter(tags=tags)
        .consolidate()
    )

    # `consolidate` keeps rows matching ANY filter tag and drops columns that
    # end up all NA, so check per row across the tag columns that are present
    # instead of requiring every tag column to be fully populated. For the
    # single-tag cases above this is the same check as before.
    present_tags = [tag for tag in tags if tag in gdf.columns]
    assert gdf[present_tags].notna().any(axis=1).all()

0 comments on commit bdad5df

Please sign in to comment.