diff --git a/raillabel/filter/__init__.py b/raillabel/filter/__init__.py index 4d90bb1..80c61f1 100644 --- a/raillabel/filter/__init__.py +++ b/raillabel/filter/__init__.py @@ -10,6 +10,7 @@ from .include_annotation_id_filter import IncludeAnnotationIdFilter from .include_annotation_type_filter import IncludeAnnotationTypeFilter from .include_frame_id_filter import IncludeFrameIdFilter +from .include_object_id_filter import IncludeObjectIdFilter from .start_time_filter import StartTimeFilter __all__ = [ @@ -22,4 +23,5 @@ "ExcludeAnnotationIdFilter", "IncludeAnnotationTypeFilter", "ExcludeAnnotationTypeFilter", + "IncludeObjectIdFilter", ] diff --git a/raillabel/filter/include_object_id_filter.py b/raillabel/filter/include_object_id_filter.py new file mode 100644 index 0000000..96aa7e5 --- /dev/null +++ b/raillabel/filter/include_object_id_filter.py @@ -0,0 +1,22 @@ +# Copyright DB InfraGO AG and contributors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from uuid import UUID + +from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d + +from ._filter_abc import _AnnotationLevelFilter + + +@dataclass +class IncludeObjectIdFilter(_AnnotationLevelFilter): + """Filter out all annotations in the scene, that do NOT have matching object ids.""" + + object_ids: list[UUID] + + def passes_filter(self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool: + """Assess if an annotation passes this filter.""" + return annotation.object_id in self.object_ids diff --git a/tests/filter/test_filter.py b/tests/filter/test_filter.py index 8671a42..1b6d5f0 100644 --- a/tests/filter/test_filter.py +++ b/tests/filter/test_filter.py @@ -108,5 +108,20 @@ def test_exclude_annotation_type(): assert actual == SceneBuilder.empty().add_bbox().result +def test_include_object_ids(): + scene = ( + SceneBuilder.empty() + .add_bbox(object_name="person_0001") + .add_cuboid(object_name="train_0001") + .result + ) + filters = [ + raillabel.filter.IncludeObjectIdFilter([UUID("5c59aad4-0000-4000-0000-000000000000")]) + ] + + actual = raillabel.filter.filter_(scene, filters) + assert actual == SceneBuilder.empty().add_bbox(object_name="person_0001").result + + if __name__ == "__main__": pytest.main([__file__, "-vv"])