Skip to content

Commit

Permalink
feat: implement IncludeObjectTypeFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
unexcellent committed Nov 16, 2024
1 parent c61ba1a commit a8b5667
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 21 deletions.
2 changes: 2 additions & 0 deletions raillabel/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .include_annotation_type_filter import IncludeAnnotationTypeFilter
from .include_frame_id_filter import IncludeFrameIdFilter
from .include_object_id_filter import IncludeObjectIdFilter
from .include_object_type_filter import IncludeObjectTypeFilter
from .start_time_filter import StartTimeFilter

__all__ = [
Expand All @@ -26,4 +27,5 @@
"ExcludeAnnotationTypeFilter",
"IncludeObjectIdFilter",
"ExcludeObjectIdFilter",
"IncludeObjectTypeFilter",
]
4 changes: 2 additions & 2 deletions raillabel/filter/_filter_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from uuid import UUID

from raillabel.format import Bbox, Cuboid, Frame, Poly2d, Poly3d, Seg3d
from raillabel.format import Bbox, Cuboid, Frame, Poly2d, Poly3d, Scene, Seg3d


class _FilterAbc(ABC):
Expand All @@ -17,7 +17,7 @@ class _AnnotationLevelFilter(_FilterAbc):

@abstractmethod
def passes_filter(
self, annotation_id: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d
self, annotation_id: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, scene: Scene
) -> bool:
"""Assess if an annotation passes this filter."""
raise NotImplementedError
Expand Down
6 changes: 4 additions & 2 deletions raillabel/filter/exclude_annotation_id_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from uuid import UUID

from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d
from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d

from ._filter_abc import _AnnotationLevelFilter

Expand All @@ -17,6 +17,8 @@ class ExcludeAnnotationIdFilter(_AnnotationLevelFilter):

annotation_ids: set[UUID] | list[UUID]

def passes_filter(self, annotation_id: UUID, _: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool:
def passes_filter(
self, annotation_id: UUID, _: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, __: Scene
) -> bool:
"""Assess if an annotation passes this filter."""
return annotation_id not in self.annotation_ids
6 changes: 4 additions & 2 deletions raillabel/filter/exclude_annotation_type_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Literal
from uuid import UUID

from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d
from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d

from ._filter_abc import _AnnotationLevelFilter

Expand All @@ -21,7 +21,9 @@ class ExcludeAnnotationTypeFilter(_AnnotationLevelFilter):
| list[Literal["bbox", "cuboid", "poly2d", "poly3d", "seg3d"]]
)

def passes_filter(self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool:
def passes_filter(
self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, __: Scene
) -> bool:
"""Assess if an annotation passes this filter."""
annotation_type_str = None

Expand Down
6 changes: 4 additions & 2 deletions raillabel/filter/exclude_object_id_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from uuid import UUID

from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d
from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d

from ._filter_abc import _AnnotationLevelFilter

Expand All @@ -17,6 +17,8 @@ class ExcludeObjectIdFilter(_AnnotationLevelFilter):

object_ids: list[UUID]

def passes_filter(self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool:
def passes_filter(
self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, __: Scene
) -> bool:
"""Assess if an annotation passes this filter."""
return annotation.object_id not in self.object_ids
17 changes: 10 additions & 7 deletions raillabel/filter/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def filter_(scene: Scene, filters: list[_FilterAbc]) -> Scene:
frame_filters, annotation_filters = _separate_filters(filters)

filtered_scene = Scene(metadata=deepcopy(scene.metadata))
filtered_scene.frames = _filter_frames(scene.frames, frame_filters, annotation_filters)
filtered_scene.frames = _filter_frames(scene, frame_filters, annotation_filters)
filtered_scene.sensors = _get_used_sensors(scene, filtered_scene)
filtered_scene.objects = _get_used_objects(scene, filtered_scene)

Expand All @@ -53,31 +53,31 @@ def _separate_filters(


def _filter_frames(
frames: dict[int, Frame],
scene: Scene,
frame_filters: list[_FrameLevelFilter],
annotation_filters: list[_AnnotationLevelFilter],
) -> dict[int, Frame]:
filtered_frames = {}

for frame_id, frame in frames.items():
for frame_id, frame in scene.frames.items():
if _frame_passes_all_filters(frame_id, frame, frame_filters):
filtered_frames[frame_id] = Frame(
timestamp=deepcopy(frame.timestamp),
sensors=deepcopy(frame.sensors),
frame_data=deepcopy(frame.frame_data),
annotations=_filter_annotations(frame, annotation_filters),
annotations=_filter_annotations(frame, annotation_filters, scene),
)

return filtered_frames


def _filter_annotations(
frame: Frame, annotation_filters: list[_AnnotationLevelFilter]
frame: Frame, annotation_filters: list[_AnnotationLevelFilter], scene: Scene
) -> dict[UUID, Bbox | Cuboid | Poly2d | Poly3d | Seg3d]:
annotations = {}

for annotation_id, annotation in frame.annotations.items():
if _annotation_passes_all_filters(annotation_id, annotation, annotation_filters):
if _annotation_passes_all_filters(annotation_id, annotation, annotation_filters, scene):
annotations[annotation_id] = deepcopy(annotation)

return annotations
Expand All @@ -93,8 +93,11 @@ def _annotation_passes_all_filters(
annotation_id: UUID,
annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d,
annotation_filters: list[_AnnotationLevelFilter],
scene: Scene,
) -> bool:
return all(filter_.passes_filter(annotation_id, annotation) for filter_ in annotation_filters)
return all(
filter_.passes_filter(annotation_id, annotation, scene) for filter_ in annotation_filters
)


def _get_used_sensors(
Expand Down
6 changes: 4 additions & 2 deletions raillabel/filter/include_annotation_id_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from uuid import UUID

from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d
from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d

from ._filter_abc import _AnnotationLevelFilter

Expand All @@ -17,6 +17,8 @@ class IncludeAnnotationIdFilter(_AnnotationLevelFilter):

annotation_ids: set[UUID] | list[UUID]

def passes_filter(self, annotation_id: UUID, _: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool:
def passes_filter(
self, annotation_id: UUID, _: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, __: Scene
) -> bool:
"""Assess if an annotation passes this filter."""
return annotation_id in self.annotation_ids
6 changes: 4 additions & 2 deletions raillabel/filter/include_annotation_type_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Literal
from uuid import UUID

from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d
from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d

from ._filter_abc import _AnnotationLevelFilter

Expand All @@ -21,7 +21,9 @@ class IncludeAnnotationTypeFilter(_AnnotationLevelFilter):
| list[Literal["bbox", "cuboid", "poly2d", "poly3d", "seg3d"]]
)

def passes_filter(self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool:
def passes_filter(
self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, __: Scene
) -> bool:
"""Assess if an annotation passes this filter."""
annotation_type_str = None

Expand Down
6 changes: 4 additions & 2 deletions raillabel/filter/include_object_id_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from uuid import UUID

from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d
from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d

from ._filter_abc import _AnnotationLevelFilter

Expand All @@ -17,6 +17,8 @@ class IncludeObjectIdFilter(_AnnotationLevelFilter):

object_ids: list[UUID]

def passes_filter(self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool:
def passes_filter(
self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, __: Scene
) -> bool:
"""Assess if an annotation passes this filter."""
return annotation.object_id in self.object_ids
24 changes: 24 additions & 0 deletions raillabel/filter/include_object_type_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright DB InfraGO AG and contributors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass
from uuid import UUID

from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Scene, Seg3d

from ._filter_abc import _AnnotationLevelFilter


@dataclass
class IncludeObjectTypeFilter(_AnnotationLevelFilter):
"""Filter out all annotations in the scene, that do NOT match the type (like 'person')."""

object_types: list[str]

def passes_filter(
self, _: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d, scene: Scene
) -> bool:
"""Assess if an annotation passes this filter."""
return scene.objects[annotation.object_id].type in self.object_types
13 changes: 13 additions & 0 deletions tests/filter/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,5 +138,18 @@ def test_exclude_object_ids():
assert actual == SceneBuilder.empty().add_bbox(object_name="person_0001").result


def test_include_object_types():
scene = (
SceneBuilder.empty()
.add_bbox(object_name="person_0001")
.add_cuboid(object_name="train_0001")
.result
)
filters = [raillabel.filter.IncludeObjectTypeFilter(["person"])]

actual = raillabel.filter.filter_(scene, filters)
assert actual == SceneBuilder.empty().add_bbox(object_name="person_0001").result


if __name__ == "__main__":
pytest.main([__file__, "-vv"])

0 comments on commit a8b5667

Please sign in to comment.