Skip to content

Commit

Permalink
fix: adapt first test to new raillabel version
Browse files Browse the repository at this point in the history
  • Loading branch information
tklockau committed Nov 18, 2024
1 parent 02d622f commit 8d112fb
Show file tree
Hide file tree
Showing 8 changed files with 490 additions and 1,293 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ classifiers = [
dependencies = [
"jsonschema>=4.4.0",
"fastjsonschema>=2.16.2",
"raillabel>=3.1.0, <4.0.0",
"raillabel>=4.0.0",
"pyyaml>=6.0.0",
"numpy>=1.24.4",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import raillabel
from raillabel.filter import IncludeObjectTypeFilter, IncludeSensorIdFilter, IncludeAnnotationTypeFilter, IncludeSensorTypeFilter, IncludeAttributesFilter

from raillabel_providerkit._util._filters import filter_sensor_uids_by_type

Expand All @@ -27,16 +28,16 @@ def validate_rail_side(scene: raillabel.Scene) -> list[str]:
errors: list[str] = []

# Get a list of camera uids
cameras = filter_sensor_uids_by_type(
list(scene.sensors.values()), raillabel.format.SensorType.CAMERA
)
cameras = list(scene.filter([IncludeSensorTypeFilter("camera")]).sensors.keys())

# Check per camera
for camera in cameras:
# Filter scene for track annotations in the selected camera sensor
filtered_scene = raillabel.filter(
scene, include_object_types=["track"], include_sensors=[camera]
)
filtered_scene = scene.filter([
IncludeObjectTypeFilter(["track"]),
IncludeSensorIdFilter([camera]),
IncludeAnnotationTypeFilter(["poly2d"]),
])

# Check per frame
for frame_uid, frame in filtered_scene.frames.items():
Expand All @@ -45,33 +46,34 @@ def validate_rail_side(scene: raillabel.Scene) -> list[str]:

# Add errors if there is more than one left or right rail
for object_uid, (left_count, right_count) in counts_per_track.items():
if left_count > 1:
errors.append(
f"In sensor {camera} frame {frame_uid}, the track with"
f" object_uid {object_uid} has more than one ({left_count}) left rail."
)
if right_count > 1:
errors.append(
f"In sensor {camera} frame {frame_uid}, the track with"
f" object_uid {object_uid} has more than one ({right_count}) right rail."
)
if left_count > 1 or right_count > 1:
if left_count > 1:
errors.append(
f"In sensor {camera} frame {frame_uid}, the track with"
f" object_uid {object_uid} has more than one ({left_count}) left rail."
)
if right_count > 1:
errors.append(
f"In sensor {camera} frame {frame_uid}, the track with"
f" object_uid {object_uid} has more than one ({right_count}) right rail."
)
continue

# If exactly one left and right rail exists, check if the track has its rails swapped
# or intersects with itself
if left_count == 1 == right_count:
# Get the two annotations in question
left_rail: raillabel.format.Poly2d | None = _get_track_from_frame(
frame, object_uid, "leftRail"
)
right_rail: raillabel.format.Poly2d | None = _get_track_from_frame(
frame, object_uid, "rightRail"
)
if left_rail is None or right_rail is None:
continue

swap_error: str | None = _check_rails_for_swap(left_rail, right_rail, frame_uid)
if swap_error is not None:
errors.append(swap_error)
# Get the two annotations in question
left_rail: raillabel.format.Poly2d | None = _get_track_from_frame(
frame, object_uid, "leftRail"
)
right_rail: raillabel.format.Poly2d | None = _get_track_from_frame(
frame, object_uid, "rightRail"
)
if left_rail is None or right_rail is None:
continue

swap_error: str | None = _check_rails_for_swap(left_rail, right_rail, frame_uid)
if swap_error is not None:
errors.append(swap_error)

return errors

Expand All @@ -82,7 +84,7 @@ def _check_rails_for_swap(
frame_uid: str | int = "unknown",
) -> str | None:
# Ensure the rails belong to the same track
if left_rail.object.uid != right_rail.object.uid:
if left_rail.object_id != right_rail.object_id:
return None

max_common_y = _find_max_common_y(left_rail, right_rail)
Expand All @@ -94,8 +96,8 @@ def _check_rails_for_swap(
if left_x is None or right_x is None:
return None

object_uid = left_rail.object.uid
sensor_uid = left_rail.sensor.uid if left_rail.sensor is not None else "unknown"
object_uid = left_rail.object_id
sensor_uid = left_rail.sensor_id if left_rail.sensor_id is not None else "unknown"

if left_x >= right_x:
return (
Expand All @@ -117,34 +119,31 @@ def _check_rails_for_swap(


def _count_rails_per_track_in_frame(frame: raillabel.format.Frame) -> dict[str, tuple[int, int]]:
# For each track, the left and right rail counts are stored as a tuple (left, right)
counts: dict[str, tuple[int, int]] = {}
# For each track, the left and right rail counts are stored as a list (left, right)
counts: dict[str, list[int, int]] = {}

# For each track, count the left and right rails
for object_uid, unfiltered_annotations in frame.object_data.items():
# Ensure we work only on Poly2d annotations
poly2ds: list[raillabel.format.Poly2d] = _filter_for_poly2ds(
list(unfiltered_annotations.values())
)

# Count left and right rails
left_count: int = 0
right_count: int = 0
for poly2d in poly2ds:
rail_side = poly2d.attributes["railSide"]
if rail_side == "leftRail":
left_count += 1
elif rail_side == "rightRail":
right_count += 1
else:
# NOTE: This is ignored because it is covered by validate_onthology
continue

# Store counts of current track
counts[object_uid] = (left_count, right_count)
unfiltered_annotations = list(frame.annotations.values())
# Ensure we work only on Poly2d annotations
poly2ds: list[raillabel.format.Poly2d] = _filter_for_poly2ds(unfiltered_annotations)

# Count left and right rails
for poly2d in poly2ds:
object_id = poly2d.object_id
if object_id not in counts:
counts[object_id] = [0, 0]

rail_side = poly2d.attributes["railSide"]
if rail_side == "leftRail":
counts[object_id][0] += 1
elif rail_side == "rightRail":
counts[object_id][1] += 1
else:
# NOTE: This is ignored because it is covered by validate_onthology
continue

# Return results
return counts
return {key: tuple(value) for key, value in counts.items()}


def _filter_for_poly2ds(
Expand Down Expand Up @@ -269,11 +268,8 @@ def _find_x_by_y(y: float, poly2d: raillabel.format.Poly2d) -> float | None:
def _get_track_from_frame(
frame: raillabel.format.Frame, object_uid: str, rail_side: str
) -> raillabel.format.Poly2d | None:
if object_uid not in frame.object_data:
return None

for annotation in frame.object_data[object_uid].values():
if not isinstance(annotation, raillabel.format.Poly2d):
for annotation in frame.annotations.values():
if annotation.object_id != object_uid:
continue

if "railSide" not in annotation.attributes:
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def default_frame(empty_annotation) -> raillabel.format.Frame:

@pytest.fixture
def empty_frame() -> raillabel.format.Frame:
return raillabel.format.Frame(uid=0, timestamp=None, sensors={}, frame_data={}, annotations={})
return raillabel.format.Frame(timestamp=None, sensors={}, frame_data={}, annotations={})


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion tests/test_raillabel_providerkit/_util/test_warning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
from raillabel._util._warning import _warning, _WarningsLogger
from raillabel_providerkit._util._warning import _warning, _WarningsLogger


def test_issue_warning():
Expand Down
9 changes: 1 addition & 8 deletions tests/test_raillabel_providerkit/validation/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,2 @@
# Copyright DB InfraGO AG and contributors
# SPDX-License-Identifier: Apache-2.0

from validate_onthology.test_validate_onthology import (
demo_onthology,
invalid_onthology_scene,
metadata,
valid_onthology_scene,
)
# SPDX-License-Identifier: Apache-2.0

This file was deleted.

Loading

0 comments on commit 8d112fb

Please sign in to comment.