Skip to content

Commit

Permalink
feat: implement IncludeAnnotationTypeFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
unexcellent committed Nov 16, 2024
1 parent 2cc8bd3 commit 02db99a
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 7 deletions.
2 changes: 2 additions & 0 deletions raillabel/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .exclude_frame_id_filter import ExcludeFrameIdFilter
from .filter import filter_
from .include_annotation_id_filter import IncludeAnnotationIdFilter
from .include_annotation_type_filter import IncludeAnnotationTypeFilter
from .include_frame_id_filter import IncludeFrameIdFilter
from .start_time_filter import StartTimeFilter

Expand All @@ -18,4 +19,5 @@
"EndTimeFilter",
"IncludeAnnotationIdFilter",
"ExcludeAnnotationIdFilter",
"IncludeAnnotationTypeFilter",
]
28 changes: 21 additions & 7 deletions raillabel/filter/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@

def filter_(scene: Scene, filters: list[_FilterAbc]) -> Scene:
"""Return a scene with filters applied to annotations, frame, sensors and objects."""
filtered_scene = Scene(
metadata=deepcopy(scene.metadata),
sensors=deepcopy(scene.sensors),
objects=deepcopy(scene.objects),
)
filtered_scene = Scene(metadata=deepcopy(scene.metadata))

used_sensor_ids = set()
used_object_ids = set()

frame_filters, annotation_filters = _separate_filters(filters)

Expand All @@ -32,11 +31,26 @@ def filter_(scene: Scene, filters: list[_FilterAbc]) -> Scene:
)

for annotation_id, annotation in frame.annotations.items():
if _annotation_passes_all_filters(annotation_id, annotation, annotation_filters):
filtered_frame.annotations[annotation_id] = deepcopy(annotation)
if not _annotation_passes_all_filters(annotation_id, annotation, annotation_filters):
continue

filtered_frame.annotations[annotation_id] = deepcopy(annotation)
used_sensor_ids.add(annotation.sensor_id)
used_object_ids.add(annotation.object_id)

filtered_scene.frames[frame_id] = filtered_frame

filtered_scene.sensors = {
sensor_id: deepcopy(sensor)
for sensor_id, sensor in scene.sensors.items()
if sensor_id in used_sensor_ids
}
filtered_scene.objects = {
object_id: deepcopy(object_)
for object_id, object_ in scene.objects.items()
if object_id in used_object_ids
}

return filtered_scene


Expand Down
46 changes: 46 additions & 0 deletions raillabel/filter/include_annotation_type_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright DB InfraGO AG and contributors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal
from uuid import UUID

from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d

from ._filter_abc import _AnnotationLevelFilter


@dataclass
class IncludeAnnotationTypeFilter(_AnnotationLevelFilter):
"""Filter out all annotations in the scene, that do NOT have the type (like bbox or cuboid)."""

annotation_types: (
set[Literal["bbox", "cuboid", "poly2d", "poly3d", "seg3d"]]
| list[Literal["bbox", "cuboid", "poly2d", "poly3d", "seg3d"]]
)

def passes_filter(self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool:
"""Assess if an annotation passes this filter."""
annotation_type_str = None

if isinstance(annotation, Bbox):
annotation_type_str = "bbox"

elif isinstance(annotation, Cuboid):
annotation_type_str = "cuboid"

elif isinstance(annotation, Poly2d):
annotation_type_str = "poly2d"

elif isinstance(annotation, Poly3d):
annotation_type_str = "poly3d"

elif isinstance(annotation, Seg3d):
annotation_type_str = "seg3d"

else:
raise TypeError

return annotation_type_str in self.annotation_types
8 changes: 8 additions & 0 deletions tests/filter/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,13 @@ def test_exclude_annotation_ids():
)


def test_include_annotation_type():
scene = SceneBuilder.empty().add_bbox().add_cuboid().result
filters = [raillabel.filter.IncludeAnnotationTypeFilter(["bbox"])]

actual = raillabel.filter.filter_(scene, filters)
assert actual == SceneBuilder.empty().add_bbox().result


if __name__ == "__main__":
pytest.main([__file__, "-vv"])

0 comments on commit 02db99a

Please sign in to comment.