Skip to content

Commit

Permalink
feat: implement Frame.to_json()
Browse files Browse the repository at this point in the history
  • Loading branch information
unexcellent committed Nov 14, 2024
1 parent d1bd3b9 commit 07428e1
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 17 deletions.
8 changes: 8 additions & 0 deletions raillabel/format/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,13 @@
from __future__ import annotations


def _empty_list_to_none(collection: list | None) -> list | None:
if collection is None:
return None
if len(collection) == 0:
return None
return collection


def _flatten_list(list_of_tuples: list[tuple]) -> list:
return [item for tup in list_of_tuples for item in tup]
87 changes: 85 additions & 2 deletions raillabel/format/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,19 @@
from decimal import Decimal
from uuid import UUID

from raillabel.json_format import JSONFrame, JSONFrameProperties, JSONObjectData

from raillabel.json_format import (
JSONAnnotations,
JSONFrame,
JSONFrameData,
JSONFrameProperties,
JSONObjectData,
)

from ._util import _empty_list_to_none
from .bbox import Bbox
from .cuboid import Cuboid
from .num import Num
from .object import Object
from .poly2d import Poly2d
from .poly3d import Poly3d
from .seg3d import Seg3d
Expand Down Expand Up @@ -45,6 +53,19 @@ def from_json(cls, json: JSONFrame) -> Frame:
annotations=_annotations_from_json(json.objects),
)

def to_json(self, objects: dict[UUID, Object]) -> JSONFrame:
"""Export this object into the RailLabel JSON format."""
return JSONFrame(
frame_properties=JSONFrameProperties(
timestamp=self.timestamp,
streams={
sensor_id: sensor_ref.to_json() for sensor_id, sensor_ref in self.sensors.items()
},
frame_data=JSONFrameData(num=[num.to_json() for num in self.frame_data.values()]),
),
objects=_objects_to_json(self.annotations, objects),
)


def _timestamp_from_dict(frame_properties: JSONFrameProperties | None) -> Decimal | None:
if frame_properties is None:
Expand Down Expand Up @@ -113,3 +134,65 @@ def _resolve_none_to_empty_list(optional_list: list | None) -> list:
if optional_list is None:
return []
return optional_list


def _objects_to_json(
annotations: dict[UUID, Bbox | Cuboid | Poly2d | Poly3d | Seg3d], objects: dict[UUID, Object]
) -> dict[str, JSONObjectData] | None:
if len(annotations) == 0:
return None

object_data = {}

for ann_id, annotation in annotations.items():
object_id = str(annotation.object_id)

if object_id not in object_data:
object_data[object_id] = JSONObjectData(
object_data=JSONAnnotations(
bbox=[],
cuboid=[],
poly2d=[],
poly3d=[],
vec=[],
)
)

json_annotation = annotation.to_json(ann_id, objects[UUID(object_id)].type)

if isinstance(annotation, Bbox):
object_data[object_id].object_data.bbox.append(json_annotation) # type: ignore

elif isinstance(annotation, Cuboid):
object_data[object_id].object_data.cuboid.append(json_annotation) # type: ignore

elif isinstance(annotation, Poly2d):
object_data[object_id].object_data.poly2d.append(json_annotation) # type: ignore

elif isinstance(annotation, Poly3d):
object_data[object_id].object_data.poly3d.append(json_annotation) # type: ignore

elif isinstance(annotation, Seg3d):
object_data[object_id].object_data.vec.append(json_annotation) # type: ignore

else:
raise TypeError

for object_id in object_data:
object_data[object_id].object_data.bbox = _empty_list_to_none(
object_data[object_id].object_data.bbox
)
object_data[object_id].object_data.cuboid = _empty_list_to_none(
object_data[object_id].object_data.cuboid
)
object_data[object_id].object_data.poly2d = _empty_list_to_none(
object_data[object_id].object_data.poly2d
)
object_data[object_id].object_data.poly3d = _empty_list_to_none(
object_data[object_id].object_data.poly3d
)
object_data[object_id].object_data.vec = _empty_list_to_none(
object_data[object_id].object_data.vec
)

return object_data
10 changes: 7 additions & 3 deletions raillabel/format/num.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ class Num:
val: float
"This is the value of the number object."

sensor_id: str | None
id: UUID | None = None
"The unique identifyer of the Num."

sensor_id: str | None = None
"A reference to the sensor, this value is represented in."

@classmethod
Expand All @@ -28,14 +31,15 @@ def from_json(cls, json: JSONNum) -> Num:
return Num(
name=json.name,
val=json.val,
id=json.uid,
sensor_id=json.coordinate_system,
)

def to_json(self, uid: UUID) -> JSONNum:
def to_json(self) -> JSONNum:
"""Export this object into the RailLabel JSON format."""
return JSONNum(
name=self.name,
val=self.val,
coordinate_system=self.sensor_id,
uid=uid,
uid=self.id,
)
1 change: 1 addition & 0 deletions tests/test_raillabel/format/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .test_metadata import metadata, metadata_json
from .test_num import num, num_json, num_id
from .test_object import (
objects,
object_person,
object_person_json,
object_person_id,
Expand Down
7 changes: 6 additions & 1 deletion tests/test_raillabel/format/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,10 @@ def test_from_json(frame, frame_json):
assert actual == frame


def test_to_json(frame, frame_json, objects):
actual = frame.to_json(objects)
assert actual == frame_json


if __name__ == "__main__":
pytest.main([__file__, "-v"])
pytest.main([__file__, "-vv"])
7 changes: 4 additions & 3 deletions tests/test_raillabel/format/test_num.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ def num_id() -> UUID:


@pytest.fixture
def num() -> Num:
def num(num_id) -> Num:
return Num(
sensor_id="gps_imu",
name="velocity",
val=49.21321,
id=num_id,
)


Expand All @@ -45,8 +46,8 @@ def test_from_json(num, num_json):
assert actual == num


def test_to_json(num, num_json, num_id):
actual = num.to_json(num_id)
def test_to_json(num, num_json):
actual = num.to_json()
assert actual == num_json


Expand Down
8 changes: 8 additions & 0 deletions tests/test_raillabel/format/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
# == Fixtures =========================


@pytest.fixture
def objects(object_person, object_person_id, object_track, object_track_id) -> dict[UUID, Object]:
return {
object_person_id: object_person,
object_track_id: object_track,
}


@pytest.fixture
def object_person_json() -> JSONObject:
return JSONObject(
Expand Down
8 changes: 4 additions & 4 deletions tests/test_raillabel/format/test_poly2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def poly2d_json(
) -> JSONPoly2d:
return JSONPoly2d(
uid="013e7b34-62E5-435c-9412-87318c50f6d8",
name="rgb_middle__poly2d__person",
name="rgb_middle__poly2d__track",
closed=True,
mode="MODE_POLY2D_ABSOLUTE",
val=point2d_json + another_point2d_json,
Expand Down Expand Up @@ -60,12 +60,12 @@ def test_from_json(poly2d, poly2d_json, object_track_id):


def test_name(poly2d):
actual = poly2d.name("person")
assert actual == "rgb_middle__poly2d__person"
actual = poly2d.name("track")
assert actual == "rgb_middle__poly2d__track"


def test_to_json(poly2d, poly2d_json, poly2d_id):
actual = poly2d.to_json(poly2d_id, object_type="person")
actual = poly2d.to_json(poly2d_id, object_type="track")
assert actual == poly2d_json


Expand Down
8 changes: 4 additions & 4 deletions tests/test_raillabel/format/test_poly3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def poly3d_json(
) -> JSONPoly3d:
return JSONPoly3d(
uid="0da87210-46F1-40e5-b661-20ea1c392f50",
name="lidar__poly3d__person",
name="lidar__poly3d__track",
closed=True,
val=point3d_json + another_point3d_json,
coordinate_system="lidar",
Expand Down Expand Up @@ -59,12 +59,12 @@ def test_from_json(poly3d, poly3d_json, object_track_id):


def test_name(poly3d):
actual = poly3d.name("person")
assert actual == "lidar__poly3d__person"
actual = poly3d.name("track")
assert actual == "lidar__poly3d__track"


def test_to_json(poly3d, poly3d_json, poly3d_id):
actual = poly3d.to_json(poly3d_id, object_type="person")
actual = poly3d.to_json(poly3d_id, object_type="track")
assert actual == poly3d_json


Expand Down

0 comments on commit 07428e1

Please sign in to comment.