diff --git a/raillabel/filter/__init__.py b/raillabel/filter/__init__.py index 8fb800a..1d2d5a4 100644 --- a/raillabel/filter/__init__.py +++ b/raillabel/filter/__init__.py @@ -7,6 +7,7 @@ from .exclude_annotation_type_filter import ExcludeAnnotationTypeFilter from .exclude_frame_id_filter import ExcludeFrameIdFilter from .exclude_object_id_filter import ExcludeObjectIdFilter +from .exclude_object_type_filter import ExcludeObjectTypeFilter from .filter import filter_ from .include_annotation_id_filter import IncludeAnnotationIdFilter from .include_annotation_type_filter import IncludeAnnotationTypeFilter @@ -28,4 +29,5 @@ "IncludeObjectIdFilter", "ExcludeObjectIdFilter", "IncludeObjectTypeFilter", + "ExcludeObjectTypeFilter", ] diff --git a/raillabel/filter/exclude_object_type_filter.py b/raillabel/filter/exclude_object_type_filter.py new file mode 100644 index 0000000..bdd26d9 --- /dev/null +++ b/raillabel/filter/exclude_object_type_filter.py @@ -0,0 +1,24 @@ +# Copyright DB InfraGO AG and contributors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from uuid import UUID + +from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d + +from ._filter_abc import _AnnotationLevelFilter + + +@dataclass +class ExcludeObjectTypeFilter(_AnnotationLevelFilter): + """Filter out all annotations in the scene, that do match the type (like 'person').""" + + object_types: list[str] + + def passes_filter( + self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, scene: Scene + ) -> bool: + """Assess if an annotation passes this filter.""" + return scene.objects[annotation.object_id].type not in self.object_types diff --git a/tests/filter/test_filter.py b/tests/filter/test_filter.py index 5a20f43..3e73045 100644 --- a/tests/filter/test_filter.py +++ b/tests/filter/test_filter.py @@ -151,5 +151,18 @@ def test_include_object_types(): assert actual == SceneBuilder.empty().add_bbox(object_name="person_0001").result +def test_exclude_object_types(): + scene = ( + SceneBuilder.empty() + .add_bbox(object_name="person_0001") + .add_cuboid(object_name="train_0001") + .result + ) + filters = [raillabel.filter.ExcludeObjectTypeFilter(["train"])] + + actual = raillabel.filter.filter_(scene, filters) + assert actual == SceneBuilder.empty().add_bbox(object_name="person_0001").result + + if __name__ == "__main__": pytest.main([__file__, "-vv"])