Skip to content

Commit

Permalink
feat: implement ExcludeObjectTypeFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
unexcellent committed Nov 16, 2024
1 parent a8b5667 commit ad8bb58
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 0 deletions.
2 changes: 2 additions & 0 deletions raillabel/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .exclude_annotation_type_filter import ExcludeAnnotationTypeFilter
from .exclude_frame_id_filter import ExcludeFrameIdFilter
from .exclude_object_id_filter import ExcludeObjectIdFilter
from .exclude_object_type_filter import ExcludeObjectTypeFilter
from .filter import filter_
from .include_annotation_id_filter import IncludeAnnotationIdFilter
from .include_annotation_type_filter import IncludeAnnotationTypeFilter
Expand All @@ -28,4 +29,5 @@
"IncludeObjectIdFilter",
"ExcludeObjectIdFilter",
"IncludeObjectTypeFilter",
"ExcludeObjectTypeFilter",
]
24 changes: 24 additions & 0 deletions raillabel/filter/exclude_object_type_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright DB InfraGO AG and contributors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass
from uuid import UUID

from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d

from ._filter_abc import _AnnotationLevelFilter


@dataclass
class ExcludeObjectTypeFilter(_AnnotationLevelFilter):
"""Filter out all annotations in the scene, that do match the type (like 'person')."""

object_types: list[str]

def passes_filter(
self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, scene: Scene
) -> bool:
"""Assess if an annotation passes this filter."""
return scene.objects[annotation.object_id].type not in self.object_types
13 changes: 13 additions & 0 deletions tests/filter/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,5 +151,18 @@ def test_include_object_types():
assert actual == SceneBuilder.empty().add_bbox(object_name="person_0001").result


def test_exclude_object_types():
scene = (
SceneBuilder.empty()
.add_bbox(object_name="person_0001")
.add_cuboid(object_name="train_0001")
.result
)
filters = [raillabel.filter.ExcludeObjectTypeFilter(["train"])]

actual = raillabel.filter.filter_(scene, filters)
assert actual == SceneBuilder.empty().add_bbox(object_name="person_0001").result


if __name__ == "__main__":
pytest.main([__file__, "-vv"])

0 comments on commit ad8bb58

Please sign in to comment.