Skip to content

Commit

Permalink
feat: Implement rail side order check
Browse files Browse the repository at this point in the history
  • Loading branch information
nalquas committed Nov 11, 2024
1 parent b509084 commit 23fc8c0
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import numpy as np
import raillabel

from raillabel_providerkit._util._filters import filter_sensor_uids_by_type
Expand Down Expand Up @@ -42,6 +43,9 @@ def validate_rail_side(scene: raillabel.Scene) -> list[str]:
# Count rails per track
counts_per_track = _count_rails_per_track_in_frame(frame)

# Find rail x limits per track
track_limits_per_track = _get_track_limits_in_frame(frame)

# Add errors if there is more than one left or right rail
for object_uid, (left_count, right_count) in counts_per_track.items():
if left_count > 1:
Expand All @@ -55,6 +59,18 @@ def validate_rail_side(scene: raillabel.Scene) -> list[str]:
f" object_uid {object_uid} has more than one ({right_count}) right rail."
)

# If left and right rails exist, check if the track has its rails swapped
if left_count >= 1 and right_count >= 1:
# Add errors if any track has its rails swapped
(max_x_of_left, min_x_of_right) = track_limits_per_track[object_uid]
if max_x_of_left > min_x_of_right:
errors.append(
f"In sensor {camera} frame {frame_uid}, the track with"
f" object_uid {object_uid} has its rails swapped."
f" The right-most left rail has x={max_x_of_left} while"
f" the left-most right rail has x={min_x_of_right}."
)

return errors


Expand All @@ -63,13 +79,11 @@ def _count_rails_per_track_in_frame(frame: raillabel.format.Frame) -> dict[str,
counts: dict[str, tuple[int, int]] = {}

# For each track, count the left and right rails
for object_uid, _annotations in frame.object_data.items():
for object_uid, unfiltered_annotations in frame.object_data.items():
# Ensure we work only on Poly2d annotations
poly2ds: list[raillabel.format.Poly2d] = [
annotation
for annotation in _annotations.values()
if isinstance(annotation, raillabel.format.Poly2d)
]
poly2ds: list[raillabel.format.Poly2d] = _filter_for_poly2ds(
list(unfiltered_annotations.values())
)

# Count left and right rails
left_count: int = 0
Expand All @@ -89,3 +103,46 @@ def _count_rails_per_track_in_frame(frame: raillabel.format.Frame) -> dict[str,

# Return results
return counts


def _get_track_limits_in_frame(frame: raillabel.format.Frame) -> dict[str, tuple[float, float]]:
# For each track, the largest x of any left rail and the smallest x of any right rail is stored
# as a tuple (max_x_of_left, min_x_of_right)
track_limits: dict[str, tuple[float, float]] = {}

for object_uid, unfiltered_annotations in frame.object_data.items():
# Ensure we work only on Poly2d annotations
poly2ds: list[raillabel.format.Poly2d] = _filter_for_poly2ds(
list(unfiltered_annotations.values())
)

# Get the largest x of any left rail and the smallest x of any right rail
max_x_of_left: float = float("-inf")
min_x_of_right: float = float("inf")
for poly2d in poly2ds:
rail_x_values: list[float] = [point.x for point in poly2d.points]
match poly2d.attributes["railSide"]:
case "leftRail":
max_x_of_rail_points: float = np.max(rail_x_values)
max_x_of_left = max(max_x_of_rail_points, max_x_of_left)
case "rightRail":
min_x_of_rail_points: float = np.min(rail_x_values)
min_x_of_right = min(min_x_of_rail_points, min_x_of_right)
case _:
# NOTE: This is ignored because it is covered by validate_onthology
continue

# Store the calculated limits of current track
track_limits[object_uid] = (max_x_of_left, min_x_of_right)

return track_limits


def _filter_for_poly2ds(
unfiltered_annotations: list[type[raillabel.format._ObjectAnnotation]],
) -> list[raillabel.format.Poly2d]:
return [
annotation
for annotation in unfiltered_annotations
if isinstance(annotation, raillabel.format.Poly2d)
]
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from raillabel_providerkit.validation.validate_rail_side.validate_rail_side import (
validate_rail_side,
_count_rails_per_track_in_frame,
_get_track_limits_in_frame,
_filter_for_poly2ds,
)


Expand Down Expand Up @@ -138,6 +140,51 @@ def test_count_rails_per_track_in_frame__many_rails_for_two_tracks(
assert results[object2.uid] == (LEFT_COUNT, RIGHT_COUNT)


def test_get_track_limits_in_frame__empty(empty_frame):
frame = empty_frame
results = _get_track_limits_in_frame(frame)
assert len(results) == 0


def test_get_track_limits_in_frame__one_track_two_rails(
empty_frame, example_camera_1, example_track_1
):
frame = empty_frame
sensor = example_camera_1
object = example_track_1

MAX_X_OF_LEFT = 42
MIN_X_OF_RIGHT = 73

frame.annotations["325b1f55-a2ef-475f-a780-13e1a9e823c3"] = raillabel.format.Poly2d(
uid="325b1f55-a2ef-475f-a780-13e1a9e823c3",
object=object,
sensor=sensor,
points=[
raillabel.format.Point2d(0, 0),
raillabel.format.Point2d(MAX_X_OF_LEFT, 1),
],
closed=False,
attributes={"railSide": "leftRail"},
)
frame.annotations["be7d136a-8364-4fbd-b098-6f4a21205d22"] = raillabel.format.Poly2d(
uid="be7d136a-8364-4fbd-b098-6f4a21205d22",
object=object,
sensor=sensor,
points=[
raillabel.format.Point2d(1000, 0),
raillabel.format.Point2d(MIN_X_OF_RIGHT, 1),
],
closed=False,
attributes={"railSide": "rightRail"},
)

results = _get_track_limits_in_frame(frame)
assert len(results) == 1
assert object.uid in results.keys()
assert results[object.uid] == (MAX_X_OF_LEFT, MIN_X_OF_RIGHT)


def test_validate_rail_side__no_errors(empty_scene, empty_frame, example_camera_1, example_track_1):
scene = empty_scene
object = example_track_1
Expand Down

0 comments on commit 23fc8c0

Please sign in to comment.