Skip to content

Commit

Permalink
feat: implement ExcludeAnnotationIdFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
unexcellent committed Nov 16, 2024
1 parent 0453c81 commit 2cc8bd3
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
2 changes: 2 additions & 0 deletions raillabel/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""Package for the raillabel filter functionality."""

from .end_time_filter import EndTimeFilter
from .exclude_annotation_id_filter import ExcludeAnnotationIdFilter
from .exclude_frame_id_filter import ExcludeFrameIdFilter
from .filter import filter_
from .include_annotation_id_filter import IncludeAnnotationIdFilter
Expand All @@ -16,4 +17,5 @@
"StartTimeFilter",
"EndTimeFilter",
"IncludeAnnotationIdFilter",
"ExcludeAnnotationIdFilter",
]
22 changes: 22 additions & 0 deletions raillabel/filter/exclude_annotation_id_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright DB InfraGO AG and contributors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass
from uuid import UUID

from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d

from ._filter_abc import _AnnotationLevelFilter


@dataclass
class ExcludeAnnotationIdFilter(_AnnotationLevelFilter):
"""Filter out all annotations in the scene, that do have disallowed ids."""

annotation_ids: set[UUID] | list[UUID]

def passes_filter(self, annotation_id: UUID, _: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool:
"""Assess if an annotation passes this filter."""
return annotation_id not in self.annotation_ids
22 changes: 22 additions & 0 deletions tests/filter/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,27 @@ def test_include_annotation_ids():
)


def test_exclude_annotation_ids():
scene = (
SceneBuilder.empty()
.add_bbox(uid="6c95543d-0000-4000-0000-000000000000")
.add_bbox(uid="6c95543d-0000-4000-0000-000000000001")
.add_bbox(uid="6c95543d-0000-4000-0000-000000000002")
.result
)
filters = [
raillabel.filter.ExcludeAnnotationIdFilter([UUID("6c95543d-0000-4000-0000-000000000001")])
]

actual = raillabel.filter.filter_(scene, filters)
assert (
actual
== SceneBuilder.empty()
.add_bbox(uid="6c95543d-0000-4000-0000-000000000000")
.add_bbox(uid="6c95543d-0000-4000-0000-000000000002")
.result
)


if __name__ == "__main__":
pytest.main([__file__, "-vv"])

0 comments on commit 2cc8bd3

Please sign in to comment.