diff --git a/raillabel/format/_util.py b/raillabel/format/_util.py index dfb567e..02f5e57 100644 --- a/raillabel/format/_util.py +++ b/raillabel/format/_util.py @@ -4,5 +4,13 @@ from __future__ import annotations +def _empty_list_to_none(collection: list | None) -> list | None: + if collection is None: + return None + if len(collection) == 0: + return None + return collection + + def _flatten_list(list_of_tuples: list[tuple]) -> list: return [item for tup in list_of_tuples for item in tup] diff --git a/raillabel/format/frame.py b/raillabel/format/frame.py index 686715e..91fb7c3 100644 --- a/raillabel/format/frame.py +++ b/raillabel/format/frame.py @@ -7,11 +7,19 @@ from decimal import Decimal from uuid import UUID -from raillabel.json_format import JSONFrame, JSONFrameProperties, JSONObjectData - +from raillabel.json_format import ( + JSONAnnotations, + JSONFrame, + JSONFrameData, + JSONFrameProperties, + JSONObjectData, +) + +from ._util import _empty_list_to_none from .bbox import Bbox from .cuboid import Cuboid from .num import Num +from .object import Object from .poly2d import Poly2d from .poly3d import Poly3d from .seg3d import Seg3d @@ -45,6 +53,19 @@ def from_json(cls, json: JSONFrame) -> Frame: annotations=_annotations_from_json(json.objects), ) + def to_json(self, objects: dict[UUID, Object]) -> JSONFrame: + """Export this object into the RailLabel JSON format.""" + return JSONFrame( + frame_properties=JSONFrameProperties( + timestamp=self.timestamp, + streams={ + sensor_id: sensor_ref.to_json() for sensor_id, sensor_ref in self.sensors.items() + }, + frame_data=JSONFrameData(num=[num.to_json() for num in self.frame_data.values()]), + ), + objects=_objects_to_json(self.annotations, objects), + ) + def _timestamp_from_dict(frame_properties: JSONFrameProperties | None) -> Decimal | None: if frame_properties is None: @@ -113,3 +134,65 @@ def _resolve_none_to_empty_list(optional_list: list | None) -> list: if optional_list is None: return [] return optional_list + + +def _objects_to_json( + annotations: dict[UUID, Bbox | Cuboid | Poly2d | Poly3d | Seg3d], objects: dict[UUID, Object] +) -> dict[str, JSONObjectData] | None: + if len(annotations) == 0: + return None + + object_data = {} + + for ann_id, annotation in annotations.items(): + object_id = str(annotation.object_id) + + if object_id not in object_data: + object_data[object_id] = JSONObjectData( + object_data=JSONAnnotations( + bbox=[], + cuboid=[], + poly2d=[], + poly3d=[], + vec=[], + ) + ) + + json_annotation = annotation.to_json(ann_id, objects[UUID(object_id)].type) + + if isinstance(annotation, Bbox): + object_data[object_id].object_data.bbox.append(json_annotation) # type: ignore + + elif isinstance(annotation, Cuboid): + object_data[object_id].object_data.cuboid.append(json_annotation) # type: ignore + + elif isinstance(annotation, Poly2d): + object_data[object_id].object_data.poly2d.append(json_annotation) # type: ignore + + elif isinstance(annotation, Poly3d): + object_data[object_id].object_data.poly3d.append(json_annotation) # type: ignore + + elif isinstance(annotation, Seg3d): + object_data[object_id].object_data.vec.append(json_annotation) # type: ignore + + else: + raise TypeError + + for object_id in object_data: + object_data[object_id].object_data.bbox = _empty_list_to_none( + object_data[object_id].object_data.bbox + ) + object_data[object_id].object_data.cuboid = _empty_list_to_none( + object_data[object_id].object_data.cuboid + ) + object_data[object_id].object_data.poly2d = _empty_list_to_none( + object_data[object_id].object_data.poly2d + ) + object_data[object_id].object_data.poly3d = _empty_list_to_none( + object_data[object_id].object_data.poly3d + ) + object_data[object_id].object_data.vec = _empty_list_to_none( + object_data[object_id].object_data.vec + ) + + return object_data diff --git a/raillabel/format/num.py b/raillabel/format/num.py index acd2d62..e77d1c0 100644 --- a/raillabel/format/num.py +++ b/raillabel/format/num.py @@ -19,7 +19,10 @@ class Num: val: float "This is the value of the number object." - sensor_id: str | None + id: UUID | None = None + "The unique identifyer of the Num." + + sensor_id: str | None = None "A reference to the sensor, this value is represented in." @classmethod @@ -28,14 +31,15 @@ def from_json(cls, json: JSONNum) -> Num: return Num( name=json.name, val=json.val, + id=json.uid, sensor_id=json.coordinate_system, ) - def to_json(self, uid: UUID) -> JSONNum: + def to_json(self) -> JSONNum: """Export this object into the RailLabel JSON format.""" return JSONNum( name=self.name, val=self.val, coordinate_system=self.sensor_id, - uid=uid, + uid=self.id, ) diff --git a/tests/test_raillabel/format/conftest.py b/tests/test_raillabel/format/conftest.py index 8ba0856..4f75bde 100644 --- a/tests/test_raillabel/format/conftest.py +++ b/tests/test_raillabel/format/conftest.py @@ -12,6 +12,7 @@ from .test_metadata import metadata, metadata_json from .test_num import num, num_json, num_id from .test_object import ( + objects, object_person, object_person_json, object_person_id, diff --git a/tests/test_raillabel/format/test_frame.py b/tests/test_raillabel/format/test_frame.py index 07cd720..1b85ba6 100644 --- a/tests/test_raillabel/format/test_frame.py +++ b/tests/test_raillabel/format/test_frame.py @@ -98,5 +98,10 @@ def test_from_json(frame, frame_json): assert actual == frame +def test_to_json(frame, frame_json, objects): + actual = frame.to_json(objects) + assert actual == frame_json + + if __name__ == "__main__": - pytest.main([__file__, "-v"]) + pytest.main([__file__, "-vv"]) diff --git a/tests/test_raillabel/format/test_num.py b/tests/test_raillabel/format/test_num.py index 092c700..2779c86 100644 --- a/tests/test_raillabel/format/test_num.py +++ b/tests/test_raillabel/format/test_num.py @@ -29,11 +29,12 @@ def num_id() -> UUID: @pytest.fixture -def num() -> Num: +def num(num_id) -> Num: return Num( sensor_id="gps_imu", name="velocity", val=49.21321, + id=num_id, ) @@ -45,8 +46,8 @@ def test_from_json(num, num_json): assert actual == num -def test_to_json(num, num_json, num_id): - actual = num.to_json(num_id) +def test_to_json(num, num_json): + actual = num.to_json() assert actual == num_json diff --git a/tests/test_raillabel/format/test_object.py b/tests/test_raillabel/format/test_object.py index ae6b0da..9931826 100644 --- a/tests/test_raillabel/format/test_object.py +++ b/tests/test_raillabel/format/test_object.py @@ -13,6 +13,14 @@ # == Fixtures ========================= +@pytest.fixture +def objects(object_person, object_person_id, object_track, object_track_id) -> dict[UUID, Object]: + return { + object_person_id: object_person, + object_track_id: object_track, + } + + @pytest.fixture def object_person_json() -> JSONObject: return JSONObject( diff --git a/tests/test_raillabel/format/test_poly2d.py b/tests/test_raillabel/format/test_poly2d.py index 11da96a..c193197 100644 --- a/tests/test_raillabel/format/test_poly2d.py +++ b/tests/test_raillabel/format/test_poly2d.py @@ -21,7 +21,7 @@ def poly2d_json( ) -> JSONPoly2d: return JSONPoly2d( uid="013e7b34-62E5-435c-9412-87318c50f6d8", - name="rgb_middle__poly2d__person", + name="rgb_middle__poly2d__track", closed=True, mode="MODE_POLY2D_ABSOLUTE", val=point2d_json + another_point2d_json, @@ -60,12 +60,12 @@ def test_from_json(poly2d, poly2d_json, object_track_id): def test_name(poly2d): - actual = poly2d.name("person") - assert actual == "rgb_middle__poly2d__person" + actual = poly2d.name("track") + assert actual == "rgb_middle__poly2d__track" def test_to_json(poly2d, poly2d_json, poly2d_id): - actual = poly2d.to_json(poly2d_id, object_type="person") + actual = poly2d.to_json(poly2d_id, object_type="track") assert actual == poly2d_json diff --git a/tests/test_raillabel/format/test_poly3d.py b/tests/test_raillabel/format/test_poly3d.py index 66288f5..e62d77e 100644 --- a/tests/test_raillabel/format/test_poly3d.py +++ b/tests/test_raillabel/format/test_poly3d.py @@ -21,7 +21,7 @@ def poly3d_json( ) -> JSONPoly3d: return JSONPoly3d( uid="0da87210-46F1-40e5-b661-20ea1c392f50", - name="lidar__poly3d__person", + name="lidar__poly3d__track", closed=True, val=point3d_json + another_point3d_json, coordinate_system="lidar", @@ -59,12 +59,12 @@ def test_from_json(poly3d, poly3d_json, object_track_id): def test_name(poly3d): - actual = poly3d.name("person") - assert actual == "lidar__poly3d__person" + actual = poly3d.name("track") + assert actual == "lidar__poly3d__track" def test_to_json(poly3d, poly3d_json, poly3d_id): - actual = poly3d.to_json(poly3d_id, object_type="person") + actual = poly3d.to_json(poly3d_id, object_type="track") assert actual == poly3d_json