Skip to content

Commit

Permalink
feat: implement Scene.to_json()
Browse files Browse the repository at this point in the history
  • Loading branch information
unexcellent committed Nov 14, 2024
1 parent 69e08a8 commit e404739
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 1 deletion.
13 changes: 13 additions & 0 deletions raillabel/format/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,16 @@ def from_json(cls, json: JSONMetadata) -> Metadata:
setattr(metadata, extra_field, extra_value)

return metadata

def to_json(self) -> JSONMetadata:
"""Export this object into the RailLabel JSON format."""
return JSONMetadata(
schema_version=self.schema_version,
name=self.name,
subschema_version=self.subschema_version,
exporter_version=self.exporter_version,
file_version=self.file_version,
tagged_file=self.tagged_file,
annotator=self.annotator,
comment=self.comment,
)
38 changes: 38 additions & 0 deletions raillabel/format/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
JSONFrame,
JSONObject,
JSONScene,
JSONSceneContent,
JSONStreamCamera,
JSONStreamOther,
JSONStreamRadar,
)

from .camera import Camera
from .frame import Frame
from .frame_interval import FrameInterval
from .gps_imu import GpsImu
from .lidar import Lidar
from .metadata import Metadata
Expand Down Expand Up @@ -52,6 +54,27 @@ def from_json(cls, json: JSONScene) -> Scene:
frames=_frames_from_json(json.openlabel.frames),
)

def to_json(self) -> JSONScene:
"""Export this scene into the RailLabel JSON format."""
return JSONScene(
openlabel=JSONSceneContent(
metadata=self.metadata.to_json(),
streams={
sensor_id: sensor.to_json()[0] for sensor_id, sensor in self.sensors.items()
},
coordinate_systems=_coordinate_systems_to_json(self.sensors),
objects={
obj_id: obj.to_json(obj_id, self.frames) for obj_id, obj in self.objects.items()
},
frames={
frame_id: frame.to_json(self.objects) for frame_id, frame in self.frames.items()
},
frame_intervals=[
fi.to_json() for fi in FrameInterval.from_frame_ids(list(self.frames.keys()))
],
)
)


def _sensors_from_json(
json_streams: dict[str, JSONStreamCamera | JSONStreamOther | JSONStreamRadar] | None,
Expand Down Expand Up @@ -95,3 +118,18 @@ def _frames_from_json(json_frames: dict[int, JSONFrame] | None) -> dict[int, Fra
return {}

return {frame_id: Frame.from_json(json_frame) for frame_id, json_frame in json_frames.items()}


def _coordinate_systems_to_json(
sensors: dict[str, Camera | Lidar | Radar | GpsImu | OtherSensor],
) -> dict[str, JSONCoordinateSystem]:
json_coordinate_systems = {
sensor_id: sensor.to_json()[1] for sensor_id, sensor in sensors.items()
}
json_coordinate_systems["base"] = JSONCoordinateSystem(
parent="",
type="local",
pose_wrt_parent=None,
children=list(json_coordinate_systems.keys()),
)
return json_coordinate_systems
5 changes: 5 additions & 0 deletions tests/test_raillabel/format/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,10 @@ def test_from_json__extra_fields():
assert actual.ADDITIONAL_OBJECT == {"first_field": 2, "second_field": [1, 2, 3]}


def test_to_json(metadata, metadata_json):
actual = metadata.to_json()
assert actual == metadata_json


if __name__ == "__main__":
pytest.main([__file__, "-v"])
30 changes: 30 additions & 0 deletions tests/test_raillabel/format/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,31 @@ def object_track_json() -> JSONObject:
return JSONObject(
name="track_0001",
type="track",
frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)],
object_data_pointers={
"rgb_middle__poly2d__track": JSONElementDataPointer(
frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)],
type="poly2d",
attribute_pointers={
"has_red_hat": "boolean",
"has_green_hat": "boolean",
"number_of_red_clothing_items": "num",
"color_of_hat": "text",
"clothing_items": "vec",
},
),
"lidar__poly3d__track": JSONElementDataPointer(
frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)],
type="poly3d",
attribute_pointers={
"has_red_hat": "boolean",
"has_green_hat": "boolean",
"number_of_red_clothing_items": "num",
"color_of_hat": "text",
"clothing_items": "vec",
},
),
},
)


Expand Down Expand Up @@ -117,5 +142,10 @@ def test_to_json__person(object_person, object_person_json, object_person_id, fr
assert actual == object_person_json


def test_to_json__track(object_track, object_track_json, object_track_id, frame):
actual = object_track.to_json(object_track_id, {1: frame})
assert actual == object_track_json


if __name__ == "__main__":
pytest.main([__file__, "-vv"])
13 changes: 12 additions & 1 deletion tests/test_raillabel/format/test_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import pytest

from raillabel.format import Scene
from raillabel.json_format import JSONScene, JSONSceneContent, JSONCoordinateSystem
from raillabel.json_format import (
JSONScene,
JSONSceneContent,
JSONCoordinateSystem,
JSONFrameInterval,
)

# == Fixtures =========================

Expand Down Expand Up @@ -47,6 +52,7 @@ def scene_json(
object_track_id: object_track_json,
},
frames={1: frame_json},
frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)],
)
)

Expand Down Expand Up @@ -86,5 +92,10 @@ def test_from_json(scene, scene_json):
assert actual == scene


def test_to_json(scene, scene_json):
actual = scene.to_json()
assert actual == scene_json


if __name__ == "__main__":
pytest.main([__file__, "-v"])

0 comments on commit e404739

Please sign in to comment.