Skip to content

Commit

Permalink
feat: Transform.from_json()
Browse files Browse the repository at this point in the history
  • Loading branch information
tklockau committed Nov 4, 2024
1 parent 3beca64 commit f40ac8c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 15 deletions.
15 changes: 15 additions & 0 deletions raillabel/format/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from dataclasses import dataclass

from raillabel.json_format import JSONTransformData

from .point3d import Point3d
from .quaternion import Quaternion

Expand All @@ -25,6 +27,19 @@ class Transform:
pos: Point3d
quat: Quaternion

@classmethod
def from_json(cls, json: JSONTransformData) -> Transform:
"""Construct an instant of this class from RailLabel JSON data."""
return Transform(
pos=Point3d(x=json.translation[0], y=json.translation[1], z=json.translation[2]),
quat=Quaternion(
x=json.quaternion[0],
y=json.quaternion[1],
z=json.quaternion[2],
w=json.quaternion[3],
),
)

@classmethod
def fromdict(cls, data_dict: dict) -> Transform:
"""Generate a Transform object from a dict.
Expand Down
12 changes: 6 additions & 6 deletions tests/test_raillabel/format/test_point3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@

@pytest.fixture
def point3d_dict() -> dict:
return [420, 3.14, 0]
return [419, 3.14, 0]


@pytest.fixture
def point3d() -> dict:
return Point3d(420, 3.14, 0)
return Point3d(419, 3.14, 0)


@pytest.fixture
Expand All @@ -40,21 +40,21 @@ def point3d_another() -> dict:


def test_fromdict():
point3d = Point3d.fromdict([420, 3.14, 0])
point3d = Point3d.fromdict([419, 3.14, 0])

assert point3d.x == 420
assert point3d.x == 419
assert point3d.y == 3.14
assert point3d.z == 0


def test_asdict():
point3d = Point3d(
x=420,
x=419,
y=3.14,
z=0,
)

assert point3d.asdict() == [420, 3.14, 0]
assert point3d.asdict() == [419, 3.14, 0]


if __name__ == "__main__":
Expand Down
23 changes: 14 additions & 9 deletions tests/test_raillabel/format/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,11 @@

from __future__ import annotations

import os
import sys
from pathlib import Path

import pytest

sys.path.insert(1, str(Path(__file__).parent.parent.parent.parent.parent))

from raillabel.format import Transform
from raillabel.json_format import JSONTransformData


# == Fixtures =========================

Expand All @@ -22,13 +18,23 @@ def transform_dict(point3d_dict, quaternion_dict) -> dict:


@pytest.fixture
def transform(point3d, quaternion) -> dict:
def transform_json(point3d_dict, quaternion_dict) -> JSONTransformData:
return JSONTransformData(translation=point3d_dict, quaternion=quaternion_dict)


@pytest.fixture
def transform(point3d, quaternion) -> Transform:
return Transform(pos=point3d, quat=quaternion)


# == Tests ============================


def test_from_json(transform_json, transform):
actual = Transform.from_json(transform_json)
assert actual == transform


def test_fromdict(point3d, point3d_dict, quaternion, quaternion_dict):
transform = Transform.fromdict({"translation": point3d_dict, "quaternion": quaternion_dict})

Expand All @@ -43,5 +49,4 @@ def test_asdict(point3d, point3d_dict, quaternion, quaternion_dict):


if __name__ == "__main__":
os.system("clear")
pytest.main([__file__, "--disable-pytest-warnings", "--cache-clear", "-v"])
pytest.main([__file__, "-v"])

0 comments on commit f40ac8c

Please sign in to comment.