diff --git a/raillabel/format/metadata.py b/raillabel/format/metadata.py index 24a23c7..b8ae018 100644 --- a/raillabel/format/metadata.py +++ b/raillabel/format/metadata.py @@ -55,3 +55,16 @@ def from_json(cls, json: JSONMetadata) -> Metadata: setattr(metadata, extra_field, extra_value) return metadata + + def to_json(self) -> JSONMetadata: + """Export this object into the RailLabel JSON format.""" + return JSONMetadata( + schema_version=self.schema_version, + name=self.name, + subschema_version=self.subschema_version, + exporter_version=self.exporter_version, + file_version=self.file_version, + tagged_file=self.tagged_file, + annotator=self.annotator, + comment=self.comment, + ) diff --git a/raillabel/format/scene.py b/raillabel/format/scene.py index c88f4a5..9220794 100644 --- a/raillabel/format/scene.py +++ b/raillabel/format/scene.py @@ -11,6 +11,7 @@ JSONFrame, JSONObject, JSONScene, + JSONSceneContent, JSONStreamCamera, JSONStreamOther, JSONStreamRadar, @@ -18,6 +19,7 @@ from .camera import Camera from .frame import Frame +from .frame_interval import FrameInterval from .gps_imu import GpsImu from .lidar import Lidar from .metadata import Metadata @@ -52,6 +54,27 @@ def from_json(cls, json: JSONScene) -> Scene: frames=_frames_from_json(json.openlabel.frames), ) + def to_json(self) -> JSONScene: + """Export this scene into the RailLabel JSON format.""" + return JSONScene( + openlabel=JSONSceneContent( + metadata=self.metadata.to_json(), + streams={ + sensor_id: sensor.to_json()[0] for sensor_id, sensor in self.sensors.items() + }, + coordinate_systems=_coordinate_systems_to_json(self.sensors), + objects={ + obj_id: obj.to_json(obj_id, self.frames) for obj_id, obj in self.objects.items() + }, + frames={ + frame_id: frame.to_json(self.objects) for frame_id, frame in self.frames.items() + }, + frame_intervals=[ + fi.to_json() for fi in FrameInterval.from_frame_ids(list(self.frames.keys())) + ], + ) + ) + def _sensors_from_json( json_streams: dict[str, JSONStreamCamera | JSONStreamOther | JSONStreamRadar] | None, @@ -95,3 +118,18 @@ def _frames_from_json(json_frames: dict[int, JSONFrame] | None) -> dict[int, Fra return {} return {frame_id: Frame.from_json(json_frame) for frame_id, json_frame in json_frames.items()} + + +def _coordinate_systems_to_json( + sensors: dict[str, Camera | Lidar | Radar | GpsImu | OtherSensor], +) -> dict[str, JSONCoordinateSystem]: + json_coordinate_systems = { + sensor_id: sensor.to_json()[1] for sensor_id, sensor in sensors.items() + } + json_coordinate_systems["base"] = JSONCoordinateSystem( + parent="", + type="local", + pose_wrt_parent=None, + children=list(json_coordinate_systems.keys()), + ) + return json_coordinate_systems diff --git a/tests/test_raillabel/format/test_metadata.py b/tests/test_raillabel/format/test_metadata.py index 6f67c0c..6ba5379 100644 --- a/tests/test_raillabel/format/test_metadata.py +++ b/tests/test_raillabel/format/test_metadata.py @@ -61,5 +61,10 @@ def test_from_json__extra_fields(): assert actual.ADDITIONAL_OBJECT == {"first_field": 2, "second_field": [1, 2, 3]} +def test_to_json(metadata, metadata_json): + actual = metadata.to_json() + assert actual == metadata_json + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_raillabel/format/test_object.py b/tests/test_raillabel/format/test_object.py index 98410b1..a7b69a1 100644 --- a/tests/test_raillabel/format/test_object.py +++ b/tests/test_raillabel/format/test_object.py @@ -83,6 +83,31 @@ def object_track_json() -> JSONObject: return JSONObject( name="track_0001", type="track", + frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)], + object_data_pointers={ + "rgb_middle__poly2d__track": JSONElementDataPointer( + frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)], + type="poly2d", + attribute_pointers={ + "has_red_hat": "boolean", + "has_green_hat": "boolean", + "number_of_red_clothing_items": "num", + "color_of_hat": "text", + "clothing_items": "vec", + }, + ), + "lidar__poly3d__track": JSONElementDataPointer( + frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)], + type="poly3d", + attribute_pointers={ + "has_red_hat": "boolean", + "has_green_hat": "boolean", + "number_of_red_clothing_items": "num", + "color_of_hat": "text", + "clothing_items": "vec", + }, + ), + }, ) @@ -117,5 +142,10 @@ def test_to_json__person(object_person, object_person_json, object_person_id, fr assert actual == object_person_json +def test_to_json__track(object_track, object_track_json, object_track_id, frame): + actual = object_track.to_json(object_track_id, {1: frame}) + assert actual == object_track_json + + if __name__ == "__main__": pytest.main([__file__, "-vv"]) diff --git a/tests/test_raillabel/format/test_scene.py b/tests/test_raillabel/format/test_scene.py index c8a0f63..b7dbe24 100644 --- a/tests/test_raillabel/format/test_scene.py +++ b/tests/test_raillabel/format/test_scene.py @@ -6,7 +6,12 @@ import pytest from raillabel.format import Scene -from raillabel.json_format import JSONScene, JSONSceneContent, JSONCoordinateSystem +from raillabel.json_format import ( + JSONScene, + JSONSceneContent, + JSONCoordinateSystem, + JSONFrameInterval, +) # == Fixtures ========================= @@ -47,6 +52,7 @@ def scene_json( object_track_id: object_track_json, }, frames={1: frame_json}, + frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)], ) ) @@ -86,5 +92,10 @@ def test_from_json(scene, scene_json): assert actual == scene +def test_to_json(scene, scene_json): + actual = scene.to_json() + assert actual == scene_json + + if __name__ == "__main__": pytest.main([__file__, "-v"])