Skip to content

Commit

Permalink
feat: implement Object.to_json()
Browse files Browse the repository at this point in the history
  • Loading branch information
unexcellent committed Nov 14, 2024
1 parent 43d91f9 commit 69e08a8
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 3 deletions.
78 changes: 77 additions & 1 deletion raillabel/format/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,16 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING
from uuid import UUID

from raillabel.json_format import JSONObject
from raillabel.json_format import JSONElementDataPointer, JSONFrameInterval, JSONObject

from ._attributes import _attributes_to_json
from .frame_interval import FrameInterval

if TYPE_CHECKING:
from .frame import Frame


@dataclass
Expand All @@ -26,3 +34,71 @@ def from_json(cls, json: JSONObject) -> Object:
name=json.name,
type=json.type,
)

def to_json(self, object_id: UUID, frames: dict[int, Frame]) -> JSONObject:
"""Export this object into the RailLabel JSON format."""
return JSONObject(
name=self.name,
type=self.type,
frame_intervals=_frame_intervals_to_json(object_id, frames),
object_data_pointers=_object_data_pointers_to_json(object_id, self.type, frames),
)


def _frame_intervals_to_json(object_id: UUID, frames: dict[int, Frame]) -> list[JSONFrameInterval]:
frames_with_this_object = set()

for frame_id, frame in frames.items():
for annotation in frame.annotations.values():
if annotation.object_id == object_id:
frames_with_this_object.add(frame_id)
continue

return [fi.to_json() for fi in FrameInterval.from_frame_ids(list(frames_with_this_object))]


def _object_data_pointers_to_json(
object_id: UUID, object_type: str, frames: dict[int, Frame]
) -> dict[str, JSONElementDataPointer]:
pointers_raw = {}

for frame_id, frame in frames.items():
for annotation in [ann for ann in frame.annotations.values() if ann.object_id == object_id]:
annotation_name = annotation.name(object_type)
if annotation_name not in pointers_raw:
pointers_raw[annotation_name] = {
"frame_intervals": set(),
"type": annotation_name.split("__")[1],
"attribute_pointers": {},
}

pointers_raw[annotation_name]["frame_intervals"].add(frame_id) # type: ignore
json_attributes = _attributes_to_json(annotation.attributes)

if json_attributes is None:
continue

for attribute in json_attributes.boolean: # type: ignore
pointers_raw[annotation_name]["attribute_pointers"][attribute.name] = "boolean" # type: ignore

for attribute in json_attributes.num: # type: ignore
pointers_raw[annotation_name]["attribute_pointers"][attribute.name] = "num" # type: ignore

for attribute in json_attributes.text: # type: ignore
pointers_raw[annotation_name]["attribute_pointers"][attribute.name] = "text" # type: ignore

for attribute in json_attributes.vec: # type: ignore
pointers_raw[annotation_name]["attribute_pointers"][attribute.name] = "vec" # type: ignore

object_data_pointers = {}
for annotation_name, object_data_pointer in pointers_raw.items():
object_data_pointers[annotation_name] = JSONElementDataPointer(
type=object_data_pointer["type"],
frame_intervals=[
fi.to_json()
for fi in FrameInterval.from_frame_ids(list(object_data_pointer["frame_intervals"])) # type: ignore
],
attribute_pointers=object_data_pointer["attribute_pointers"],
)

return object_data_pointers
45 changes: 43 additions & 2 deletions tests/test_raillabel/format/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest

from raillabel.json_format import JSONObject
from raillabel.json_format import JSONObject, JSONFrameInterval, JSONElementDataPointer
from raillabel.format import Object

# == Fixtures =========================
Expand All @@ -26,6 +26,42 @@ def object_person_json() -> JSONObject:
return JSONObject(
name="person_0032",
type="person",
frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)],
object_data_pointers={
"rgb_middle__bbox__person": JSONElementDataPointer(
frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)],
type="bbox",
attribute_pointers={
"has_red_hat": "boolean",
"has_green_hat": "boolean",
"number_of_red_clothing_items": "num",
"color_of_hat": "text",
"clothing_items": "vec",
},
),
"lidar__cuboid__person": JSONElementDataPointer(
frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)],
type="cuboid",
attribute_pointers={
"has_red_hat": "boolean",
"has_green_hat": "boolean",
"number_of_red_clothing_items": "num",
"color_of_hat": "text",
"clothing_items": "vec",
},
),
"lidar__vec__person": JSONElementDataPointer(
frame_intervals=[JSONFrameInterval(frame_start=1, frame_end=1)],
type="vec",
attribute_pointers={
"has_red_hat": "boolean",
"has_green_hat": "boolean",
"number_of_red_clothing_items": "num",
"color_of_hat": "text",
"clothing_items": "vec",
},
),
},
)


Expand Down Expand Up @@ -76,5 +112,10 @@ def test_from_json__track(object_track, object_track_json):
assert actual == object_track


def test_to_json__person(object_person, object_person_json, object_person_id, frame):
actual = object_person.to_json(object_person_id, {1: frame})
assert actual == object_person_json


if __name__ == "__main__":
pytest.main([__file__, "-v"])
pytest.main([__file__, "-vv"])

0 comments on commit 69e08a8

Please sign in to comment.