diff --git a/pyproject.toml b/pyproject.toml index 52c0c7c..a405f5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ classifiers = [ dependencies = [ "jsonschema>=4.4.0", "fastjsonschema>=2.16.2", - "raillabel==4.0.0", + "raillabel==4.1.0", "pyyaml>=6.0.0", "numpy>=1.24.4", "pydantic<3.0.0", diff --git a/raillabel_providerkit/validation/__init__.py b/raillabel_providerkit/validation/__init__.py index 28410b5..255559b 100644 --- a/raillabel_providerkit/validation/__init__.py +++ b/raillabel_providerkit/validation/__init__.py @@ -5,6 +5,7 @@ from .issue import Issue, IssueIdentifiers, IssueType from .validate_empty_frames.validate_empty_frames import validate_empty_frames from .validate_onthology.validate_onthology import validate_onthology +from .validate_rail_side.validate_rail_side import validate_rail_side from .validate_schema import validate_schema __all__ = [ @@ -13,5 +14,6 @@ "IssueType", "validate_empty_frames", "validate_onthology", + "validate_rail_side", "validate_schema", ] diff --git a/raillabel_providerkit/validation/validate.py b/raillabel_providerkit/validation/validate.py index 9c2b48c..7488744 100644 --- a/raillabel_providerkit/validation/validate.py +++ b/raillabel_providerkit/validation/validate.py @@ -8,15 +8,22 @@ from raillabel_providerkit.validation import Issue -from . import validate_empty_frames, validate_schema +from . import validate_empty_frames, validate_rail_side, validate_schema -def validate(scene_dict: dict, validate_for_empty_frames: bool = True) -> list[Issue]: +def validate( + scene_dict: dict, + validate_for_empty_frames: bool = True, + validate_for_rail_side_order: bool = True, +) -> list[Issue]: """Validate a scene based on the Deutsche Bahn Requirements. Args: scene_dict: The scene as a dictionary directly from `json.load()` in the raillabel format. - validate_for_empty_frames (optional): If True, the scene is validated for empty frames. + validate_for_empty_frames (optional): If True, issues are returned if the scene contains + frames without annotations. Default is True. + validate_for_rail_side_order: If True, issues are returned if the scene contains track with + a mismatching rail side order. Default is True. Returns: List of all requirement errors in the scene. If an empty list is returned, then there are no @@ -32,4 +39,7 @@ def validate(scene_dict: dict, validate_for_empty_frames: bool = True) -> list[I if validate_for_empty_frames: errors.extend(validate_empty_frames(scene)) + if validate_for_rail_side_order: + errors.extend(validate_rail_side(scene)) + return errors diff --git a/tests/validation/test_validate.py b/tests/validation/test_validate.py index b321978..893a8a4 100644 --- a/tests/validation/test_validate.py +++ b/tests/validation/test_validate.py @@ -5,28 +5,69 @@ import pytest from raillabel.scene_builder import SceneBuilder +from raillabel.format import Poly2d, Point2d from raillabel_providerkit import validate -def test_no_errors_in_empty_scene(): +def test_no_issues_in_empty_scene(): scene_dict = {"openlabel": {"metadata": {"schema_version": "1.0.0"}}} actual = validate(scene_dict) assert len(actual) == 0 -def test_schema_errors(): +def test_schema_issues(): scene_dict = {"openlabel": {}} actual = validate(scene_dict) assert len(actual) == 1 -def test_empty_frame_errors(): +def test_empty_frame_issues(): scene_dict = json.loads(SceneBuilder.empty().add_frame().result.to_json().model_dump_json()) actual = validate(scene_dict) assert len(actual) == 1 +def test_rail_side_issues(ignore_uuid): + SENSOR_ID = "rgb_center" + scene = ( + SceneBuilder.empty() + .add_annotation( + annotation=Poly2d( + points=[ + Point2d(0, 0), + Point2d(0, 1), + ], + closed=False, + attributes={"railSide": "rightRail"}, + object_id=ignore_uuid, + sensor_id="IGNORE_THIS", + ), + object_name="track_0001", + sensor_id=SENSOR_ID, + ) + .add_annotation( + annotation=Poly2d( + points=[ + Point2d(1, 0), + Point2d(1, 1), + ], + closed=False, + attributes={"railSide": "leftRail"}, + object_id=ignore_uuid, + sensor_id="IGNORE_THIS", + ), + object_name="track_0001", + sensor_id=SENSOR_ID, + ) + .result + ) + scene_dict = json.loads(scene.to_json().model_dump_json()) + + actual = validate(scene_dict) + assert len(actual) == 1 + + if __name__ == "__main__": pytest.main([__file__, "--disable-pytest-warnings", "--cache-clear", "-v"])