diff --git a/raillabel/filter/__init__.py b/raillabel/filter/__init__.py index 260c5a9..e5a1aee 100644 --- a/raillabel/filter/__init__.py +++ b/raillabel/filter/__init__.py @@ -3,6 +3,7 @@ """Package for the raillabel filter functionality.""" from .end_time_filter import EndTimeFilter +from .exclude_annotation_id_filter import ExcludeAnnotationIdFilter from .exclude_frame_id_filter import ExcludeFrameIdFilter from .filter import filter_ from .include_annotation_id_filter import IncludeAnnotationIdFilter @@ -16,4 +17,5 @@ "StartTimeFilter", "EndTimeFilter", "IncludeAnnotationIdFilter", + "ExcludeAnnotationIdFilter", ] diff --git a/raillabel/filter/exclude_annotation_id_filter.py b/raillabel/filter/exclude_annotation_id_filter.py new file mode 100644 index 0000000..7973e50 --- /dev/null +++ b/raillabel/filter/exclude_annotation_id_filter.py @@ -0,0 +1,22 @@ +# Copyright DB InfraGO AG and contributors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from uuid import UUID + +from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d + +from ._filter_abc import _AnnotationLevelFilter + + +@dataclass +class ExcludeAnnotationIdFilter(_AnnotationLevelFilter): + """Filter out all annotations in the scene, that do have disallowed ids.""" + + annotation_ids: set[UUID] | list[UUID] + + def passes_filter(self, annotation_id: UUID, _: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool: + """Assess if an annotation passes this filter.""" + return annotation_id not in self.annotation_ids diff --git a/tests/filter/test_filter.py b/tests/filter/test_filter.py index 4a4ccc3..df8f386 100644 --- a/tests/filter/test_filter.py +++ b/tests/filter/test_filter.py @@ -70,5 +70,27 @@ def test_include_annotation_ids(): ) +def test_exclude_annotation_ids(): + scene = ( + SceneBuilder.empty() + .add_bbox(uid="6c95543d-0000-4000-0000-000000000000") + .add_bbox(uid="6c95543d-0000-4000-0000-000000000001") + .add_bbox(uid="6c95543d-0000-4000-0000-000000000002") + .result + ) + filters = [ + raillabel.filter.ExcludeAnnotationIdFilter([UUID("6c95543d-0000-4000-0000-000000000001")]) + ] + + actual = raillabel.filter.filter_(scene, filters) + assert ( + actual + == SceneBuilder.empty() + .add_bbox(uid="6c95543d-0000-4000-0000-000000000000") + .add_bbox(uid="6c95543d-0000-4000-0000-000000000002") + .result + ) + + if __name__ == "__main__": pytest.main([__file__, "-vv"])