Skip to content

Commit

Permalink
feat: move uai format over
Browse files Browse the repository at this point in the history
  • Loading branch information
tklockau committed Oct 4, 2023
1 parent 081c391 commit 57a195c
Show file tree
Hide file tree
Showing 38 changed files with 5,067 additions and 0 deletions.
17 changes: 17 additions & 0 deletions raillabel_providerkit/format/understand_ai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright DB Netz AG and contributors
# SPDX-License-Identifier: Apache-2.0
"""Module containing all relevant understand.ai format classes."""

from .bounding_box_2d import BoundingBox2d
from .bounding_box_3d import BoundingBox3d
from .coordinate_system import CoordinateSystem
from .frame import Frame
from .metadata import Metadata
from .point_3d import Point3d
from .polygon_2d import Polygon2d
from .polyline_2d import Polyline2d
from .quaternion import Quaternion
from .scene import Scene
from .segmentation_3d import Segmentation3d
from .sensor_reference import SensorReference
from .size_3d import Size3d
73 changes: 73 additions & 0 deletions raillabel_providerkit/format/understand_ai/_annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright DB Netz AG and contributors
# SPDX-License-Identifier: Apache-2.0

import typing as t
from abc import ABC, abstractmethod, abstractproperty
from dataclasses import dataclass
from uuid import UUID

from ..._util._attribute_type import AttributeType
from ._translation import translate_class_id, translate_sensor_id
from .sensor_reference import SensorReference


@dataclass
class _Annotation(ABC):

id: UUID
object_id: UUID
class_name: str
attributes: dict
sensor: SensorReference

@property
@abstractproperty
def OPENLABEL_ID(self) -> t.List[str]:
raise NotImplementedError

@classmethod
@abstractmethod
def fromdict(cls, data_dict: t.Dict) -> t.Type["_Annotation"]:
raise NotImplementedError

def to_raillabel(self) -> t.Tuple[dict, str, str, dict]:
"""Convert to a raillabel compatible dict.
Returns
-------
annotation: dict
Dictionary valid for the raillabel schema.
object_id: str
Friendly identifier of the object this annotation belongs to.
class_name: str
Friendly identifier of the class the annotated object belongs to.
sensor_reference: dict
Dictionary of the sensor reference.
"""

return (
{
"name": str(self.id),
"val": self._val_to_raillabel(),
"coordinate_system": translate_sensor_id(self.sensor.type),
"attributes": self._attributes_to_raillabel(),
},
str(self.object_id),
translate_class_id(self.class_name),
self.sensor.to_raillabel()[1],
)

def _attributes_to_raillabel(self) -> dict:

attributes = {}

for attr_name, attr_value in self.attributes.items():

attr_type = AttributeType.from_value(type(attr_value)).value

if attr_type not in attributes:
attributes[attr_type] = []

attributes[attr_type].append({"name": attr_name, "val": attr_value})

return attributes
93 changes: 93 additions & 0 deletions raillabel_providerkit/format/understand_ai/_translation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright DB Netz AG and contributors
# SPDX-License-Identifier: Apache-2.0

import json
from pathlib import Path


def translate_sensor_id(original_sensor_id: str) -> str:
"""Translate deprecated sensor ids to the correct ones.
Parameters
----------
original_sensor_id : str
Original id of the sensor.
Returns
-------
str
Translated id or original_sensor_id, if no translation could be found.
"""
return TRANSLATION["streams"].get(original_sensor_id, original_sensor_id)


def translate_class_id(original_class_id: str) -> str:
"""Translate deprecated class ids to the correct ones.
Parameters
----------
original_class_id : str
Original id of the class.
Returns
-------
str
Translated id or original_class_id, if no translation could be found.
"""
return TRANSLATION["classes"].get(original_class_id, original_class_id)


def fetch_sensor_type(sensor_id: str) -> str:
"""Fetch sensor type from translation file.
Parameters
----------
sensor_id : str
Id of the sensor.
Returns
-------
str
Sensor type or 'other' if sensor_id not found in translation.json.
"""
return TRANSLATION["stream_types"].get(sensor_id, "other")


def fetch_sensor_resolutions(sensor_id: str) -> dict:
"""Fetch sensor resolution from translation file.
Parameters
----------
sensor_id : str
Id of the sensor.
Returns
-------
dict
Dictionary containing the resolution information. Key 'x' contains the width in pixels,
key 'y' contains the height in pixels. If the sensor is a radar, 'resolution_px_per_m' is
also included.
"""
return TRANSLATION["stream_resolutions"].get(
sensor_id, {"x": None, "y": None, "resolution_px_per_m": None}
)


def _load_translation():
"""Load the translation file when the module is imported.
This prevents it from beeing loaded for every annotation.
"""

global TRANSLATION

translatiion_path = (
Path(__file__).parent.parent.parent / "load_" / "loader_classes" / "translation.json"
)
with translatiion_path.open() as translation_file:
TRANSLATION = json.load(translation_file)


TRANSLATION = {}

_load_translation()
79 changes: 79 additions & 0 deletions raillabel_providerkit/format/understand_ai/bounding_box_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright DB Netz AG and contributors
# SPDX-License-Identifier: Apache-2.0

import typing as t
from dataclasses import dataclass
from uuid import UUID

from ._annotation import _Annotation
from .sensor_reference import SensorReference


@dataclass
class BoundingBox2d(_Annotation):
"""A 2d bounding box.
Parameters
----------
id: uuid.UUID
Unique identifier of the annotation.
object_id: uuid.UUID
Unique identifier of the object this annotation refers to. Used for tracking.
class_name: str
Name of the class this annotation belongs to.
attributes: dict[str, str or list]
Key value pairs of attributes with the keys beeing the friendly identifier of the
attribute and the value beeing the attribute value.
sensor: raillabel.format.understand_ai.SensorReference
Information about the sensor this annotation is labeled in.
x_min: float
Left corner of the bounding box in pixels.
y_min: float
Top corner of the bounding box in pixels.
x_max: float
Right corner of the bounding box in pixels.
y_max: float
Bottom corner of the bounding box in pixels.
"""

x_min: float
y_min: float
x_max: float
y_max: float

OPENLABEL_ID = "bbox"

@classmethod
def fromdict(cls, data_dict: t.Dict) -> "BoundingBox2d":
"""Generate a BoundingBox2d from a dictionary in the UAI format.
Parameters
----------
data_dict: dict
Understand.AI T4 format dictionary containing the data_dict.
Returns
-------
BoundingBox2d
Converted 2d bounding box.
"""

return BoundingBox2d(
id=UUID(data_dict["id"]),
object_id=UUID(data_dict["objectId"]),
class_name=data_dict["className"],
x_min=data_dict["geometry"]["xMin"],
y_min=data_dict["geometry"]["yMin"],
x_max=data_dict["geometry"]["xMax"],
y_max=data_dict["geometry"]["yMax"],
attributes=data_dict["attributes"],
sensor=SensorReference.fromdict(data_dict["sensor"]),
)

def _val_to_raillabel(self) -> list:
return [
(self.x_max + self.x_min) / 2,
(self.y_max + self.y_min) / 2,
abs(self.x_max - self.x_min),
abs(self.y_max - self.y_min),
]
84 changes: 84 additions & 0 deletions raillabel_providerkit/format/understand_ai/bounding_box_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright DB Netz AG and contributors
# SPDX-License-Identifier: Apache-2.0

import typing as t
from dataclasses import dataclass
from uuid import UUID

from ._annotation import _Annotation
from .point_3d import Point3d
from .quaternion import Quaternion
from .sensor_reference import SensorReference
from .size_3d import Size3d


@dataclass
class BoundingBox3d(_Annotation):
"""A 3d bounding box.
Parameters
----------
id: uuid.UUID
Unique identifier of the annotation.
object_id: uuid.UUID
Unique identifier of the object this annotation refers to. Used for tracking.
class_name: str
Name of the class this annotation belongs to.
attributes: dict[str, str or list]
Key value pairs of attributes with the keys beeing the friendly identifier of the
attribute and the value beeing the attribute value.
sensor: raillabel.format.understand_ai.SensorReference
Information about the sensor this annotation is labeled in.
center: raillabel.format.understand_ai.Point3d
Center position of the bounding box.
size: raillabel.format.understand_ai.Size3d
3d size of the bounding box.
quaternion: raillabel.format.understand_ai.Quaternion
Rotation quaternion of the bounding box.
"""

center: Point3d
size: Size3d
quaternion: Quaternion

OPENLABEL_ID = "cuboid"

@classmethod
def fromdict(cls, data_dict: t.Dict) -> "BoundingBox3d":
"""Generate a BoundingBox3d from a dictionary in the UAI format.
Parameters
----------
data_dict: dict
Understand.AI T4 format dictionary containing the data_dict.
Returns
-------
BoundingBox3d
Converted 3d bounding box.
"""

return BoundingBox3d(
id=UUID(data_dict["id"]),
object_id=UUID(data_dict["objectId"]),
class_name=data_dict["className"],
center=Point3d.fromdict(data_dict["geometry"]["center"]),
size=Size3d.fromdict(data_dict["geometry"]["size"]),
quaternion=Quaternion.fromdict(data_dict["geometry"]["quaternion"]),
attributes=data_dict["attributes"],
sensor=SensorReference.fromdict(data_dict["sensor"]),
)

def _val_to_raillabel(self) -> list:
return [
float(self.center.x),
float(self.center.y),
float(self.center.z),
float(self.quaternion.x),
float(self.quaternion.y),
float(self.quaternion.z),
float(self.quaternion.w),
float(self.size.width),
float(self.size.length),
float(self.size.height),
]
Loading

0 comments on commit 57a195c

Please sign in to comment.