Skip to content

Commit

Permalink
refactor: filter_
Browse files Browse the repository at this point in the history
  • Loading branch information
unexcellent committed Nov 16, 2024
1 parent 02db99a commit f41ce29
Showing 1 changed file with 72 additions and 36 deletions.
108 changes: 72 additions & 36 deletions raillabel/filter/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,50 +6,33 @@
from copy import deepcopy
from uuid import UUID

from raillabel.format import Bbox, Cuboid, Frame, Poly2d, Poly3d, Scene, Seg3d
from raillabel.format import (
Bbox,
Camera,
Cuboid,
Frame,
GpsImu,
Lidar,
Object,
OtherSensor,
Poly2d,
Poly3d,
Radar,
Scene,
Seg3d,
)

from ._filter_abc import _AnnotationLevelFilter, _FilterAbc, _FrameLevelFilter


def filter_(scene: Scene, filters: list[_FilterAbc]) -> Scene:
"""Return a scene with filters applied to annotations, frame, sensors and objects."""
filtered_scene = Scene(metadata=deepcopy(scene.metadata))

used_sensor_ids = set()
used_object_ids = set()

frame_filters, annotation_filters = _separate_filters(filters)

for frame_id, frame in scene.frames.items():
if not _frame_passes_all_filters(frame_id, frame, frame_filters):
continue

filtered_frame = Frame(
timestamp=deepcopy(frame.timestamp),
sensors=deepcopy(frame.sensors),
frame_data=deepcopy(frame.frame_data),
)

for annotation_id, annotation in frame.annotations.items():
if not _annotation_passes_all_filters(annotation_id, annotation, annotation_filters):
continue

filtered_frame.annotations[annotation_id] = deepcopy(annotation)
used_sensor_ids.add(annotation.sensor_id)
used_object_ids.add(annotation.object_id)

filtered_scene.frames[frame_id] = filtered_frame

filtered_scene.sensors = {
sensor_id: deepcopy(sensor)
for sensor_id, sensor in scene.sensors.items()
if sensor_id in used_sensor_ids
}
filtered_scene.objects = {
object_id: deepcopy(object_)
for object_id, object_ in scene.objects.items()
if object_id in used_object_ids
}
filtered_scene = Scene(metadata=deepcopy(scene.metadata))
filtered_scene.frames = _filter_frames(scene.frames, frame_filters, annotation_filters)
filtered_scene.sensors = _get_used_sensors(scene, filtered_scene)
filtered_scene.objects = _get_used_objects(scene, filtered_scene)

return filtered_scene

Expand All @@ -69,6 +52,37 @@ def _separate_filters(
return frame_filters, annotation_filters


def _filter_frames(
frames: dict[int, Frame],
frame_filters: list[_FrameLevelFilter],
annotation_filters: list[_AnnotationLevelFilter],
) -> dict[int, Frame]:
filtered_frames = {}

for frame_id, frame in frames.items():
if _frame_passes_all_filters(frame_id, frame, frame_filters):
filtered_frames[frame_id] = Frame(
timestamp=deepcopy(frame.timestamp),
sensors=deepcopy(frame.sensors),
frame_data=deepcopy(frame.frame_data),
annotations=_filter_annotations(frame, annotation_filters),
)

return filtered_frames


def _filter_annotations(
frame: Frame, annotation_filters: list[_AnnotationLevelFilter]
) -> dict[UUID, Bbox | Cuboid | Poly2d | Poly3d | Seg3d]:
annotations = {}

for annotation_id, annotation in frame.annotations.items():
if _annotation_passes_all_filters(annotation_id, annotation, annotation_filters):
annotations[annotation_id] = deepcopy(annotation)

return annotations


def _frame_passes_all_filters(
frame_id: int, frame: Frame, frame_filters: list[_FrameLevelFilter]
) -> bool:
Expand All @@ -81,3 +95,25 @@ def _annotation_passes_all_filters(
annotation_filters: list[_AnnotationLevelFilter],
) -> bool:
return all(filter_.passes_filter(annotation_id, annotation) for filter_ in annotation_filters)


def _get_used_sensors(
scene: Scene, filtered_scene: Scene
) -> dict[str, Camera | Lidar | Radar | GpsImu | OtherSensor]:
used_sensors = {}
for frame in filtered_scene.frames.values():
for annotation in frame.annotations.values():
if annotation.sensor_id not in used_sensors:
used_sensors[annotation.sensor_id] = deepcopy(scene.sensors[annotation.sensor_id])

return used_sensors


def _get_used_objects(scene: Scene, filtered_scene: Scene) -> dict[UUID, Object]:
used_objects = {}
for frame in filtered_scene.frames.values():
for annotation in frame.annotations.values():
if annotation.object_id not in used_objects:
used_objects[annotation.object_id] = deepcopy(scene.objects[annotation.object_id])

return used_objects

0 comments on commit f41ce29

Please sign in to comment.