diff --git a/raillabel/filter/__init__.py b/raillabel/filter/__init__.py index 80c61f1..370f3f5 100644 --- a/raillabel/filter/__init__.py +++ b/raillabel/filter/__init__.py @@ -6,6 +6,7 @@ from .exclude_annotation_id_filter import ExcludeAnnotationIdFilter from .exclude_annotation_type_filter import ExcludeAnnotationTypeFilter from .exclude_frame_id_filter import ExcludeFrameIdFilter +from .exclude_object_id_filter import ExcludeObjectIdFilter from .filter import filter_ from .include_annotation_id_filter import IncludeAnnotationIdFilter from .include_annotation_type_filter import IncludeAnnotationTypeFilter @@ -24,4 +25,5 @@ "IncludeAnnotationTypeFilter", "ExcludeAnnotationTypeFilter", "IncludeObjectIdFilter", + "ExcludeObjectIdFilter", ] diff --git a/raillabel/filter/exclude_object_id_filter.py b/raillabel/filter/exclude_object_id_filter.py new file mode 100644 index 0000000..9ed7963 --- /dev/null +++ b/raillabel/filter/exclude_object_id_filter.py @@ -0,0 +1,22 @@ +# Copyright DB InfraGO AG and contributors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from uuid import UUID + +from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d + +from ._filter_abc import _AnnotationLevelFilter + + +@dataclass +class ExcludeObjectIdFilter(_AnnotationLevelFilter): + """Filter out all annotations in the scene, that do have matching object ids.""" + + object_ids: list[UUID] + + def passes_filter(self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool: + """Assess if an annotation passes this filter.""" + return annotation.object_id not in self.object_ids diff --git a/tests/filter/test_filter.py b/tests/filter/test_filter.py index 1b6d5f0..f196dcf 100644 --- a/tests/filter/test_filter.py +++ b/tests/filter/test_filter.py @@ -123,5 +123,20 @@ def test_include_object_ids(): assert actual == SceneBuilder.empty().add_bbox(object_name="person_0001").result +def test_exclude_object_ids(): + scene = ( + SceneBuilder.empty() + .add_bbox(object_name="person_0001") + .add_cuboid(object_name="train_0001") + .result + ) + filters = [ + raillabel.filter.ExcludeObjectIdFilter([UUID("5c59aad4-0000-4000-0000-000000000001")]) + ] + + actual = raillabel.filter.filter_(scene, filters) + assert actual == SceneBuilder.empty().add_bbox(object_name="person_0001").result + + if __name__ == "__main__": pytest.main([__file__, "-vv"])