Skip to content

Commit

Permalink
feat: implement IncludeAnnotationIdFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
unexcellent committed Nov 16, 2024
1 parent 82327b5 commit 0453c81
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 13 deletions.
2 changes: 2 additions & 0 deletions raillabel/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .end_time_filter import EndTimeFilter
from .exclude_frame_id_filter import ExcludeFrameIdFilter
from .filter import filter_
from .include_annotation_id_filter import IncludeAnnotationIdFilter
from .include_frame_id_filter import IncludeFrameIdFilter
from .start_time_filter import StartTimeFilter

Expand All @@ -14,4 +15,5 @@
"ExcludeFrameIdFilter",
"StartTimeFilter",
"EndTimeFilter",
"IncludeAnnotationIdFilter",
]
15 changes: 14 additions & 1 deletion raillabel/filter/_filter_abc.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
# Copyright DB InfraGO AG and contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from abc import ABC, abstractmethod
from uuid import UUID

from raillabel.format import Frame
from raillabel.format import Bbox, Cuboid, Frame, Poly2d, Poly3d, Seg3d


class _FilterAbc(ABC):
"""Base class of all filter classes regardless of level."""


class _AnnotationLevelFilter(_FilterAbc):
"""Base class of all filter classes applied to the annotations."""

@abstractmethod
def passes_filter(
self, annotation_id: UUID, annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d
) -> bool:
"""Assess if an annotation passes this filter."""
raise NotImplementedError


class _FrameLevelFilter(_FilterAbc):
"""Base class of all filter classes applied to the frames."""

Expand Down
57 changes: 48 additions & 9 deletions raillabel/filter/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,66 @@
from __future__ import annotations

from copy import deepcopy
from uuid import UUID

from raillabel.format import Scene
from raillabel.format import Bbox, Cuboid, Frame, Poly2d, Poly3d, Scene, Seg3d

from ._filter_abc import _FrameLevelFilter
from ._filter_abc import _AnnotationLevelFilter, _FilterAbc, _FrameLevelFilter


def filter_(scene: Scene, filters: list[_FrameLevelFilter]) -> Scene:
def filter_(scene: Scene, filters: list[_FilterAbc]) -> Scene:
"""Return a scene with filters applied to annotations, frame, sensors and objects."""
filtered_scene = Scene(
metadata=deepcopy(scene.metadata),
sensors=deepcopy(scene.sensors),
objects=deepcopy(scene.objects),
)

frame_filters, annotation_filters = _separate_filters(filters)

for frame_id, frame in scene.frames.items():
frame_passes_filters = True
for filter_ in filters:
if not filter_.passes_filter(frame_id, frame):
frame_passes_filters = False
if not _frame_passes_all_filters(frame_id, frame, frame_filters):
continue

filtered_frame = Frame(
timestamp=deepcopy(frame.timestamp),
sensors=deepcopy(frame.sensors),
frame_data=deepcopy(frame.frame_data),
)

for annotation_id, annotation in frame.annotations.items():
if _annotation_passes_all_filters(annotation_id, annotation, annotation_filters):
filtered_frame.annotations[annotation_id] = deepcopy(annotation)

if frame_passes_filters:
filtered_scene.frames[frame_id] = deepcopy(frame)
filtered_scene.frames[frame_id] = filtered_frame

return filtered_scene


def _separate_filters(
all_filters: list[_FilterAbc],
) -> tuple[list[_FrameLevelFilter], list[_AnnotationLevelFilter]]:
frame_filters = []
annotation_filters = []
for filter_ in all_filters:
if isinstance(filter_, _FrameLevelFilter):
frame_filters.append(filter_)

if isinstance(filter_, _AnnotationLevelFilter):
annotation_filters.append(filter_)

return frame_filters, annotation_filters


def _frame_passes_all_filters(
frame_id: int, frame: Frame, frame_filters: list[_FrameLevelFilter]
) -> bool:
return all(filter_.passes_filter(frame_id, frame) for filter_ in frame_filters)


def _annotation_passes_all_filters(
annotation_id: UUID,
annotation: Bbox | Cuboid | Poly2d | Poly3d | Seg3d,
annotation_filters: list[_AnnotationLevelFilter],
) -> bool:
return all(filter_.passes_filter(annotation_id, annotation) for filter_ in annotation_filters)
22 changes: 22 additions & 0 deletions raillabel/filter/include_annotation_id_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright DB InfraGO AG and contributors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass
from uuid import UUID

from raillabel.format import Bbox, Cuboid, Poly2d, Poly3d, Seg3d

from ._filter_abc import _AnnotationLevelFilter


@dataclass
class IncludeAnnotationIdFilter(_AnnotationLevelFilter):
"""Filter out all annotations in the scene, that do NOT have the correct ids."""

annotation_ids: set[UUID] | list[UUID]

def passes_filter(self, annotation_id: UUID, _: Bbox | Cuboid | Poly2d | Poly3d | Seg3d) -> bool:
"""Assess if an annotation passes this filter."""
return annotation_id in self.annotation_ids
35 changes: 32 additions & 3 deletions tests/filter/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@

from __future__ import annotations

from uuid import UUID

import pytest

import raillabel
from raillabel.scene_builder import SceneBuilder


def test_include_frames():
def test_include_frame_ids():
scene = SceneBuilder.empty().add_frame(1).add_frame(2).add_frame(3).result
filters = [raillabel.filter.IncludeFrameIdFilter([1, 3])]

actual = raillabel.filter.filter_(scene, filters)
assert actual == SceneBuilder.empty().add_frame(1).add_frame(3).result


def test_exclude_frames():
def test_exclude_frame_ids():
scene = SceneBuilder.empty().add_frame(1).add_frame(2).add_frame(3).result
filters = [raillabel.filter.ExcludeFrameIdFilter([2])]

Expand All @@ -33,13 +35,40 @@ def test_start_time():
assert actual == SceneBuilder.empty().add_frame(2, 200).add_frame(3, 300).result


def test_start_time():
def test_end_time():
scene = SceneBuilder.empty().add_frame(1, 100).add_frame(2, 200).add_frame(3, 300).result
filters = [raillabel.filter.EndTimeFilter(250)]

actual = raillabel.filter.filter_(scene, filters)
assert actual == SceneBuilder.empty().add_frame(1, 100).add_frame(2, 200).result


def test_include_annotation_ids():
scene = (
SceneBuilder.empty()
.add_bbox(uid="6c95543d-0000-4000-0000-000000000000")
.add_bbox(uid="6c95543d-0000-4000-0000-000000000001")
.add_bbox(uid="6c95543d-0000-4000-0000-000000000002")
.result
)
filters = [
raillabel.filter.IncludeAnnotationIdFilter(
[
UUID("6c95543d-0000-4000-0000-000000000000"),
UUID("6c95543d-0000-4000-0000-000000000002"),
]
)
]

actual = raillabel.filter.filter_(scene, filters)
assert (
actual
== SceneBuilder.empty()
.add_bbox(uid="6c95543d-0000-4000-0000-000000000000")
.add_bbox(uid="6c95543d-0000-4000-0000-000000000002")
.result
)


if __name__ == "__main__":
pytest.main([__file__, "-vv"])

0 comments on commit 0453c81

Please sign in to comment.