From fd3cc1f99de11276b980e87f55a1ead10765518d Mon Sep 17 00:00:00 2001 From: unexcellent <> Date: Sat, 16 Nov 2024 17:36:34 +0100 Subject: [PATCH] feat: implement ExcludeAnnotationTypeFilter --- raillabel/filter/__init__.py | 2 + .../filter/exclude_annotation_type_filter.py | 46 +++++++++++++++++++ tests/filter/test_filter.py | 8 ++++ 3 files changed, 56 insertions(+) create mode 100644 raillabel/filter/exclude_annotation_type_filter.py diff --git a/raillabel/filter/__init__.py b/raillabel/filter/__init__.py index bc7e92c..4d90bb1 100644 --- a/raillabel/filter/__init__.py +++ b/raillabel/filter/__init__.py @@ -4,6 +4,7 @@ from .end_time_filter import EndTimeFilter from .exclude_annotation_id_filter import ExcludeAnnotationIdFilter +from .exclude_annotation_type_filter import ExcludeAnnotationTypeFilter from .exclude_frame_id_filter import ExcludeFrameIdFilter from .filter import filter_ from .include_annotation_id_filter import IncludeAnnotationIdFilter @@ -20,4 +21,5 @@ "IncludeAnnotationIdFilter", "ExcludeAnnotationIdFilter", "IncludeAnnotationTypeFilter", + "ExcludeAnnotationTypeFilter", ] diff --git a/raillabel/filter/exclude_annotation_type_filter.py b/raillabel/filter/exclude_annotation_type_filter.py new file mode 100644 index 0000000..89e2357 --- /dev/null +++ b/raillabel/filter/exclude_annotation_type_filter.py @@ -0,0 +1,46 @@ +# Copyright DB InfraGO AG and contributors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal +from uuid import UUID + +from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d + +from ._filter_abc import _AnnotationLevelFilter + + +@dataclass +class ExcludeAnnotationTypeFilter(_AnnotationLevelFilter): + """Filter out all annotations in the scene, that do have the type (like bbox or cuboid).""" + + annotation_types: ( + set[Literal["bbox", "cuboid", "poly2d", "poly3d", "seg3d"]] + | list[Literal["bbox", "cuboid", "poly2d", "poly3d", "seg3d"]] + ) + + def passes_filter(self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool: + """Assess if an annotation passes this filter.""" + annotation_type_str = None + + if isinstance(annotation, Bbox): + annotation_type_str = "bbox" + + elif isinstance(annotation, Cuboid): + annotation_type_str = "cuboid" + + elif isinstance(annotation, Poly2d): + annotation_type_str = "poly2d" + + elif isinstance(annotation, Poly3d): + annotation_type_str = "poly3d" + + elif isinstance(annotation, Seg3d): + annotation_type_str = "seg3d" + + else: + raise TypeError + + return annotation_type_str not in self.annotation_types diff --git a/tests/filter/test_filter.py b/tests/filter/test_filter.py index 8da2db2..8671a42 100644 --- a/tests/filter/test_filter.py +++ b/tests/filter/test_filter.py @@ -100,5 +100,13 @@ def test_include_annotation_type(): assert actual == SceneBuilder.empty().add_bbox().result +def test_exclude_annotation_type(): + scene = SceneBuilder.empty().add_bbox().add_cuboid().result + filters = [raillabel.filter.ExcludeAnnotationTypeFilter(["cuboid"])] + + actual = raillabel.filter.filter_(scene, filters) + assert actual == SceneBuilder.empty().add_bbox().result + + if __name__ == "__main__": pytest.main([__file__, "-vv"])