Skip to content

Commit

Permalink
feat: implement ExcludeObjectIdFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
unexcellent committed Nov 16, 2024
1 parent c565dfc commit c61ba1a
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 0 deletions.
2 changes: 2 additions & 0 deletions raillabel/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .exclude_annotation_id_filter import ExcludeAnnotationIdFilter
from .exclude_annotation_type_filter import ExcludeAnnotationTypeFilter
from .exclude_frame_id_filter import ExcludeFrameIdFilter
from .exclude_object_id_filter import ExcludeObjectIdFilter
from .filter import filter_
from .include_annotation_id_filter import IncludeAnnotationIdFilter
from .include_annotation_type_filter import IncludeAnnotationTypeFilter
Expand All @@ -24,4 +25,5 @@
"IncludeAnnotationTypeFilter",
"ExcludeAnnotationTypeFilter",
"IncludeObjectIdFilter",
"ExcludeObjectIdFilter",
]
22 changes: 22 additions & 0 deletions raillabel/filter/exclude_object_id_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright DB InfraGO AG and contributors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass
from uuid import UUID

from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d

from ._filter_abc import _AnnotationLevelFilter


@dataclass
class ExcludeObjectIdFilter(_AnnotationLevelFilter):
"""Filter out all annotations in the scene, that do have matching object ids."""

object_ids: list[UUID]

def passes_filter(self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool:
"""Assess if an annotation passes this filter."""
return annotation.object_id not in self.object_ids
15 changes: 15 additions & 0 deletions tests/filter/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,5 +123,20 @@ def test_include_object_ids():
assert actual == SceneBuilder.empty().add_bbox(object_name="person_0001").result


def test_exclude_object_ids():
scene = (
SceneBuilder.empty()
.add_bbox(object_name="person_0001")
.add_cuboid(object_name="train_0001")
.result
)
filters = [
raillabel.filter.ExcludeObjectIdFilter([UUID("5c59aad4-0000-4000-0000-000000000001")])
]

actual = raillabel.filter.filter_(scene, filters)
assert actual == SceneBuilder.empty().add_bbox(object_name="person_0001").result


if __name__ == "__main__":
pytest.main([__file__, "-vv"])

0 comments on commit c61ba1a

Please sign in to comment.