diff --git a/raillabel/filter/__init__.py b/raillabel/filter/__init__.py index e5a1aee..bc7e92c 100644 --- a/raillabel/filter/__init__.py +++ b/raillabel/filter/__init__.py @@ -7,6 +7,7 @@ from .exclude_frame_id_filter import ExcludeFrameIdFilter from .filter import filter_ from .include_annotation_id_filter import IncludeAnnotationIdFilter +from .include_annotation_type_filter import IncludeAnnotationTypeFilter from .include_frame_id_filter import IncludeFrameIdFilter from .start_time_filter import StartTimeFilter @@ -18,4 +19,5 @@ "EndTimeFilter", "IncludeAnnotationIdFilter", "ExcludeAnnotationIdFilter", + "IncludeAnnotationTypeFilter", ] diff --git a/raillabel/filter/filter.py b/raillabel/filter/filter.py index c8bce4a..7ad3edd 100644 --- a/raillabel/filter/filter.py +++ b/raillabel/filter/filter.py @@ -13,11 +13,10 @@ def filter_(scene: Scene, filters: list[_FilterAbc]) -> Scene: """Return a scene with filters applied to annotations, frame, sensors and objects.""" - filtered_scene = Scene( - metadata=deepcopy(scene.metadata), - sensors=deepcopy(scene.sensors), - objects=deepcopy(scene.objects), - ) + filtered_scene = Scene(metadata=deepcopy(scene.metadata)) + + used_sensor_ids = set() + used_object_ids = set() frame_filters, annotation_filters = _separate_filters(filters) @@ -32,11 +31,26 @@ def filter_(scene: Scene, filters: list[_FilterAbc]) -> Scene: ) for annotation_id, annotation in frame.annotations.items(): - if _annotation_passes_all_filters(annotation_id, annotation, annotation_filters): - filtered_frame.annotations[annotation_id] = deepcopy(annotation) + if not _annotation_passes_all_filters(annotation_id, annotation, annotation_filters): + continue + + filtered_frame.annotations[annotation_id] = deepcopy(annotation) + used_sensor_ids.add(annotation.sensor_id) + used_object_ids.add(annotation.object_id) filtered_scene.frames[frame_id] = filtered_frame + filtered_scene.sensors = { + sensor_id: deepcopy(sensor) + for sensor_id, sensor in scene.sensors.items() + if sensor_id in used_sensor_ids + } + filtered_scene.objects = { + object_id: deepcopy(object_) + for object_id, object_ in scene.objects.items() + if object_id in used_object_ids + } + return filtered_scene diff --git a/raillabel/filter/include_annotation_type_filter.py b/raillabel/filter/include_annotation_type_filter.py new file mode 100644 index 0000000..5e4908e --- /dev/null +++ b/raillabel/filter/include_annotation_type_filter.py @@ -0,0 +1,46 @@ +# Copyright DB InfraGO AG and contributors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal +from uuid import UUID + +from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d + +from ._filter_abc import _AnnotationLevelFilter + + +@dataclass +class IncludeAnnotationTypeFilter(_AnnotationLevelFilter): + """Filter out all annotations in the scene, that do NOT have the type (like bbox or cuboid).""" + + annotation_types: ( + set[Literal["bbox", "cuboid", "poly2d", "poly3d", "seg3d"]] + | list[Literal["bbox", "cuboid", "poly2d", "poly3d", "seg3d"]] + ) + + def passes_filter(self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool: + """Assess if an annotation passes this filter.""" + annotation_type_str = None + + if isinstance(annotation, Bbox): + annotation_type_str = "bbox" + + elif isinstance(annotation, Cuboid): + annotation_type_str = "cuboid" + + elif isinstance(annotation, Poly2d): + annotation_type_str = "poly2d" + + elif isinstance(annotation, Poly3d): + annotation_type_str = "poly3d" + + elif isinstance(annotation, Seg3d): + annotation_type_str = "seg3d" + + else: + raise TypeError + + return annotation_type_str in self.annotation_types diff --git a/tests/filter/test_filter.py b/tests/filter/test_filter.py index df8f386..8da2db2 100644 --- a/tests/filter/test_filter.py +++ b/tests/filter/test_filter.py @@ -92,5 +92,13 @@ def test_exclude_annotation_ids(): ) +def test_include_annotation_type(): + scene = SceneBuilder.empty().add_bbox().add_cuboid().result + filters = [raillabel.filter.IncludeAnnotationTypeFilter(["bbox"])] + + actual = raillabel.filter.filter_(scene, filters) + assert actual == SceneBuilder.empty().add_bbox().result + + if __name__ == "__main__": pytest.main([__file__, "-vv"])