diff --git a/raillabel/format/quaternion.py b/raillabel/format/quaternion.py index ecff8d0..366747a 100644 --- a/raillabel/format/quaternion.py +++ b/raillabel/format/quaternion.py @@ -8,25 +8,24 @@ @dataclass class Quaternion: - """A quaternion. - - Parameters - ---------- - x: float or int - The x component of the quaternion. - y: float or int - The y component of the quaternion. - z: float or int - The z component of the quaternion. - w: float or int - The w component of the quaternion. - - """ + """A rotation represented by a quaternion.""" x: float + "The x component of the quaternion." + y: float + "The y component of the quaternion." + z: float + "The z component of the quaternion." + w: float + "The omega component of the quaternion." + + @classmethod + def from_json(cls, json: tuple[float, float, float, float]) -> Quaternion: + """Construct an instant of this class from RailLabel JSON data.""" + return Quaternion(x=json[0], y=json[1], z=json[2], w=json[3]) @classmethod def fromdict(cls, data_dict: dict) -> Quaternion: diff --git a/raillabel/format/transform.py b/raillabel/format/transform.py index 22c2cc3..c8cd55a 100644 --- a/raillabel/format/transform.py +++ b/raillabel/format/transform.py @@ -26,12 +26,7 @@ def from_json(cls, json: JSONTransformData) -> Transform: """Construct an instant of this class from RailLabel JSON data.""" return Transform( position=Point3d.from_json(json.translation), - quaternion=Quaternion( - x=json.quaternion[0], - y=json.quaternion[1], - z=json.quaternion[2], - w=json.quaternion[3], - ), + quaternion=Quaternion.from_json(json.quaternion), ) @classmethod diff --git a/tests/test_raillabel/format/test_quaternion.py b/tests/test_raillabel/format/test_quaternion.py index 8291881..d508e85 100644 --- a/tests/test_raillabel/format/test_quaternion.py +++ b/tests/test_raillabel/format/test_quaternion.py @@ -29,6 +29,11 @@ def quaternion() -> dict: # == Tests ============================ +def test_from_json(quaternion, quaternion_dict): + actual = Quaternion.from_json(quaternion_dict) + assert actual == quaternion + + def test_fromdict(): quaternion = Quaternion.fromdict([0.75318325, -0.10270147, 0.21430262, -0.61338551])