-# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
-# All rights reserved.
-#
-# SPDX-License-Identifier: BSD-3-Clause
-
-"""Sub-module containing utilities for various math operations."""
-
-# needed to import for allowing type-hinting: torch.Tensor | np.ndarray
-from __future__ import annotations
-
-import numpy as np
-import torch
-import torch.nn.functional
-from typing import Literal
-
-"""
-General
-"""
-
-
-@torch.jit.script
-def scale_transform(x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
- """Normalizes a given input tensor to a range of [-1, 1].
-
- .. note::
- It uses pytorch broadcasting functionality to deal with batched input.
-
- Args:
- x: Input tensor of shape (N, dims).
- lower: The minimum value of the tensor. Shape is (N, dims) or (dims,).
- upper: The maximum value of the tensor. Shape is (N, dims) or (dims,).
-
- Returns:
- Normalized transform of the tensor. Shape is (N, dims).
- """
- # default value of center
- offset = (lower + upper) * 0.5
- # return normalized tensor
- return 2 * (x - offset) / (upper - lower)
-
-
-@torch.jit.script
-def unscale_transform(x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
- """De-normalizes a given input tensor from range of [-1, 1] to (lower, upper).
-
- .. note::
- It uses pytorch broadcasting functionality to deal with batched input.
-
- Args:
- x: Input tensor of shape (N, dims).
- lower: The minimum value of the tensor. Shape is (N, dims) or (dims,).
- upper: The maximum value of the tensor. Shape is (N, dims) or (dims,).
-
- Returns:
- De-normalized transform of the tensor. Shape is (N, dims).
- """
- # default value of center
- offset = (lower + upper) * 0.5
- # return normalized tensor
- return x * (upper - lower) * 0.5 + offset
-
-
-@torch.jit.script
-def saturate(x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
- """Clamps a given input tensor to (lower, upper).
-
- It uses pytorch broadcasting functionality to deal with batched input.
-
- Args:
- x: Input tensor of shape (N, dims).
- lower: The minimum value of the tensor. Shape is (N, dims) or (dims,).
- upper: The maximum value of the tensor. Shape is (N, dims) or (dims,).
-
- Returns:
- Clamped transform of the tensor. Shape is (N, dims).
- """
- return torch.max(torch.min(x, upper), lower)
-
-
-@torch.jit.script
-def normalize(x: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
- """Normalizes a given input tensor to unit length.
-
- Args:
- x: Input tensor of shape (N, dims).
- eps: A small value to avoid division by zero. Defaults to 1e-9.
-
- Returns:
- Normalized tensor of shape (N, dims).
- """
- return x / x.norm(p=2, dim=-1).clamp(min=eps, max=None).unsqueeze(-1)
-
-
-@torch.jit.script
-def wrap_to_pi(angles: torch.Tensor) -> torch.Tensor:
- """Wraps input angles (in radians) to the range [-pi, pi].
-
- Args:
- angles: Input angles of any shape.
-
- Returns:
- Angles in the range [-pi, pi].
- """
- angles = angles.clone()
- angles %= 2 * torch.pi
- angles -= 2 * torch.pi * (angles > torch.pi)
- return angles
-
-
-@torch.jit.script
-def copysign(mag: float, other: torch.Tensor) -> torch.Tensor:
- """Create a new floating-point tensor with the magnitude of input and the sign of other, element-wise.
-
- Note:
- The implementation follows from `torch.copysign`. The function allows a scalar magnitude.
-
- Args:
- mag: The magnitude scalar.
- other: The tensor containing values whose signbits are applied to magnitude.
-
- Returns:
- The output tensor.
- """
- mag = torch.tensor(mag, device=other.device, dtype=torch.float).repeat(other.shape[0])
- return torch.abs(mag) * torch.sign(other)
-
-
-"""
-Rotation
-"""
-
-
-@torch.jit.script
-def matrix_from_quat(quaternions: torch.Tensor) -> torch.Tensor:
- """Convert rotations given as quaternions to rotation matrices.
-
- Args:
- quaternions: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
-
- Returns:
- Rotation matrices. The shape is (..., 3, 3).
-
- Reference:
- https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L41-L70
- """
- r, i, j, k = torch.unbind(quaternions, -1)
- # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
- two_s = 2.0 / (quaternions * quaternions).sum(-1)
-
- o = torch.stack(
- (
- 1 - two_s * (j * j + k * k),
- two_s * (i * j - k * r),
- two_s * (i * k + j * r),
- two_s * (i * j + k * r),
- 1 - two_s * (i * i + k * k),
- two_s * (j * k - i * r),
- two_s * (i * k - j * r),
- two_s * (j * k + i * r),
- 1 - two_s * (i * i + j * j),
- ),
- -1,
- )
- return o.reshape(quaternions.shape[:-1] + (3, 3))
-
-
-[文档]def convert_quat(quat: torch.Tensor | np.ndarray, to: Literal["xyzw", "wxyz"] = "xyzw") -> torch.Tensor | np.ndarray:
-
"""Converts quaternion from one convention to another.
-
-
The convention to convert TO is specified as an optional argument. If to == 'xyzw',
-
then the input is in 'wxyz' format, and vice-versa.
-
-
Args:
-
quat: The quaternion of shape (..., 4).
-
to: Convention to convert quaternion to.. Defaults to "xyzw".
-
-
Returns:
-
The converted quaternion in specified convention.
-
-
Raises:
-
ValueError: Invalid input argument `to`, i.e. not "xyzw" or "wxyz".
-
ValueError: Invalid shape of input `quat`, i.e. not (..., 4,).
-
"""
-
# check input is correct
-
if quat.shape[-1] != 4:
-
msg = f"Expected input quaternion shape mismatch: {quat.shape} != (..., 4)."
-
raise ValueError(msg)
-
if to not in ["xyzw", "wxyz"]:
-
msg = f"Expected input argument `to` to be 'xyzw' or 'wxyz'. Received: {to}."
-
raise ValueError(msg)
-
# check if input is numpy array (we support this backend since some classes use numpy)
-
if isinstance(quat, np.ndarray):
-
# use numpy functions
-
if to == "xyzw":
-
# wxyz -> xyzw
-
return np.roll(quat, -1, axis=-1)
-
else:
-
# xyzw -> wxyz
-
return np.roll(quat, 1, axis=-1)
-
else:
-
# convert to torch (sanity check)
-
if not isinstance(quat, torch.Tensor):
-
quat = torch.tensor(quat, dtype=float)
-
# convert to specified quaternion type
-
if to == "xyzw":
-
# wxyz -> xyzw
-
return quat.roll(-1, dims=-1)
-
else:
-
# xyzw -> wxyz
-
return quat.roll(1, dims=-1)
-
-
-@torch.jit.script
-def quat_conjugate(q: torch.Tensor) -> torch.Tensor:
- """Computes the conjugate of a quaternion.
-
- Args:
- q: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
-
- Returns:
- The conjugate quaternion in (w, x, y, z). Shape is (..., 4).
- """
- shape = q.shape
- q = q.reshape(-1, 4)
- return torch.cat((q[:, 0:1], -q[:, 1:]), dim=-1).view(shape)
-
-
-@torch.jit.script
-def quat_inv(q: torch.Tensor) -> torch.Tensor:
- """Compute the inverse of a quaternion.
-
- Args:
- q: The quaternion orientation in (w, x, y, z). Shape is (N, 4).
-
- Returns:
- The inverse quaternion in (w, x, y, z). Shape is (N, 4).
- """
- return normalize(quat_conjugate(q))
-
-
-@torch.jit.script
-def quat_from_euler_xyz(roll: torch.Tensor, pitch: torch.Tensor, yaw: torch.Tensor) -> torch.Tensor:
- """Convert rotations given as Euler angles in radians to Quaternions.
-
- Note:
- The euler angles are assumed in XYZ convention.
-
- Args:
- roll: Rotation around x-axis (in radians). Shape is (N,).
- pitch: Rotation around y-axis (in radians). Shape is (N,).
- yaw: Rotation around z-axis (in radians). Shape is (N,).
-
- Returns:
- The quaternion in (w, x, y, z). Shape is (N, 4).
- """
- cy = torch.cos(yaw * 0.5)
- sy = torch.sin(yaw * 0.5)
- cr = torch.cos(roll * 0.5)
- sr = torch.sin(roll * 0.5)
- cp = torch.cos(pitch * 0.5)
- sp = torch.sin(pitch * 0.5)
- # compute quaternion
- qw = cy * cr * cp + sy * sr * sp
- qx = cy * sr * cp - sy * cr * sp
- qy = cy * cr * sp + sy * sr * cp
- qz = sy * cr * cp - cy * sr * sp
-
- return torch.stack([qw, qx, qy, qz], dim=-1)
-
-
-@torch.jit.script
-def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
- """Returns torch.sqrt(torch.max(0, x)) but with a zero sub-gradient where x is 0.
-
- Reference:
- https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L91-L99
- """
- ret = torch.zeros_like(x)
- positive_mask = x > 0
- ret[positive_mask] = torch.sqrt(x[positive_mask])
- return ret
-
-
-@torch.jit.script
-def quat_from_matrix(matrix: torch.Tensor) -> torch.Tensor:
- """Convert rotations given as rotation matrices to quaternions.
-
- Args:
- matrix: The rotation matrices. Shape is (..., 3, 3).
-
- Returns:
- The quaternion in (w, x, y, z). Shape is (..., 4).
-
- Reference:
- https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L102-L161
- """
- if matrix.size(-1) != 3 or matrix.size(-2) != 3:
- raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
-
- batch_dim = matrix.shape[:-2]
- m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
-
- q_abs = _sqrt_positive_part(
- torch.stack(
- [
- 1.0 + m00 + m11 + m22,
- 1.0 + m00 - m11 - m22,
- 1.0 - m00 + m11 - m22,
- 1.0 - m00 - m11 + m22,
- ],
- dim=-1,
- )
- )
-
- # we produce the desired quaternion multiplied by each of r, i, j, k
- quat_by_rijk = torch.stack(
- [
- # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
- torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
- # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
- torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
- # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
- torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
- # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
- torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
- ],
- dim=-2,
- )
-
- # We floor here at 0.1 but the exact level is not important; if q_abs is small,
- # the candidate won't be picked.
- flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
- quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
-
- # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
- # forall i; we pick the best-conditioned one (with the largest denominator)
- return quat_candidates[torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
- batch_dim + (4,)
- )
-
-
-def _axis_angle_rotation(axis: Literal["X", "Y", "Z"], angle: torch.Tensor) -> torch.Tensor:
- """Return the rotation matrices for one of the rotations about an axis of which Euler angles describe,
- for each value of the angle given.
-
- Args:
- axis: Axis label "X" or "Y or "Z".
- angle: Euler angles in radians of any shape.
-
- Returns:
- Rotation matrices. Shape is (..., 3, 3).
-
- Reference:
- https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L164-L191
- """
- cos = torch.cos(angle)
- sin = torch.sin(angle)
- one = torch.ones_like(angle)
- zero = torch.zeros_like(angle)
-
- if axis == "X":
- R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
- elif axis == "Y":
- R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
- elif axis == "Z":
- R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
- else:
- raise ValueError("letter must be either X, Y or Z.")
-
- return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
-
-
-[文档]def matrix_from_euler(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
-
"""
-
Convert rotations given as Euler angles in radians to rotation matrices.
-
-
Args:
-
euler_angles: Euler angles in radians. Shape is (..., 3).
-
convention: Convention string of three uppercase letters from {"X", "Y", and "Z"}.
-
For example, "XYZ" means that the rotations should be applied first about x,
-
then y, then z.
-
-
Returns:
-
Rotation matrices. Shape is (..., 3, 3).
-
-
Reference:
-
https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L194-L220
-
"""
-
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
-
raise ValueError("Invalid input euler angles.")
-
if len(convention) != 3:
-
raise ValueError("Convention must have 3 letters.")
-
if convention[1] in (convention[0], convention[2]):
-
raise ValueError(f"Invalid convention {convention}.")
-
for letter in convention:
-
if letter not in ("X", "Y", "Z"):
-
raise ValueError(f"Invalid letter {letter} in convention string.")
-
matrices = [_axis_angle_rotation(c, e) for c, e in zip(convention, torch.unbind(euler_angles, -1))]
-
# return functools.reduce(torch.matmul, matrices)
-
return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
-
-
-@torch.jit.script
-def euler_xyz_from_quat(quat: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Convert rotations given as quaternions to Euler angles in radians.
-
- Note:
- The euler angles are assumed in XYZ convention.
-
- Args:
- quat: The quaternion orientation in (w, x, y, z). Shape is (N, 4).
-
- Returns:
- A tuple containing roll-pitch-yaw. Each element is a tensor of shape (N,).
-
- Reference:
- https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles
- """
- q_w, q_x, q_y, q_z = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3]
- # roll (x-axis rotation)
- sin_roll = 2.0 * (q_w * q_x + q_y * q_z)
- cos_roll = 1 - 2 * (q_x * q_x + q_y * q_y)
- roll = torch.atan2(sin_roll, cos_roll)
-
- # pitch (y-axis rotation)
- sin_pitch = 2.0 * (q_w * q_y - q_z * q_x)
- pitch = torch.where(torch.abs(sin_pitch) >= 1, copysign(torch.pi / 2.0, sin_pitch), torch.asin(sin_pitch))
-
- # yaw (z-axis rotation)
- sin_yaw = 2.0 * (q_w * q_z + q_x * q_y)
- cos_yaw = 1 - 2 * (q_y * q_y + q_z * q_z)
- yaw = torch.atan2(sin_yaw, cos_yaw)
-
- return roll % (2 * torch.pi), pitch % (2 * torch.pi), yaw % (2 * torch.pi) # TODO: why not wrap_to_pi here ?
-
-
-@torch.jit.script
-def quat_unique(q: torch.Tensor) -> torch.Tensor:
- """Convert a unit quaternion to a standard form where the real part is non-negative.
-
- Quaternion representations have a singularity since ``q`` and ``-q`` represent the same
- rotation. This function ensures the real part of the quaternion is non-negative.
-
- Args:
- q: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
-
- Returns:
- Standardized quaternions. Shape is (..., 4).
- """
- return torch.where(q[..., 0:1] < 0, -q, q)
-
-
-@torch.jit.script
-def quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
- """Multiply two quaternions together.
-
- Args:
- q1: The first quaternion in (w, x, y, z). Shape is (..., 4).
- q2: The second quaternion in (w, x, y, z). Shape is (..., 4).
-
- Returns:
- The product of the two quaternions in (w, x, y, z). Shape is (..., 4).
-
- Raises:
- ValueError: Input shapes of ``q1`` and ``q2`` are not matching.
- """
- # check input is correct
- if q1.shape != q2.shape:
- msg = f"Expected input quaternion shape mismatch: {q1.shape} != {q2.shape}."
- raise ValueError(msg)
- # reshape to (N, 4) for multiplication
- shape = q1.shape
- q1 = q1.reshape(-1, 4)
- q2 = q2.reshape(-1, 4)
- # extract components from quaternions
- w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3]
- w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3]
- # perform multiplication
- ww = (z1 + x1) * (x2 + y2)
- yy = (w1 - y1) * (w2 + z2)
- zz = (w1 + y1) * (w2 - z2)
- xx = ww + yy + zz
- qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
- w = qq - ww + (z1 - y1) * (y2 - z2)
- x = qq - xx + (x1 + w1) * (x2 + w2)
- y = qq - yy + (w1 - x1) * (y2 + z2)
- z = qq - zz + (z1 + y1) * (w2 - x2)
-
- return torch.stack([w, x, y, z], dim=-1).view(shape)
-
-
-@torch.jit.script
-def quat_box_minus(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
- """The box-minus operator (quaternion difference) between two quaternions.
-
- Args:
- q1: The first quaternion in (w, x, y, z). Shape is (N, 4).
- q2: The second quaternion in (w, x, y, z). Shape is (N, 4).
-
- Returns:
- The difference between the two quaternions. Shape is (N, 3).
- """
- quat_diff = quat_mul(q1, quat_conjugate(q2)) # q1 * q2^-1
- re = quat_diff[:, 0] # real part, q = [w, x, y, z] = [re, im]
- im = quat_diff[:, 1:] # imaginary part
- norm_im = torch.norm(im, dim=1)
- scale = 2.0 * torch.where(norm_im > 1.0e-7, torch.atan2(norm_im, re) / norm_im, torch.sign(re))
- return scale.unsqueeze(-1) * im
-
-
-@torch.jit.script
-def yaw_quat(quat: torch.Tensor) -> torch.Tensor:
- """Extract the yaw component of a quaternion.
-
- Args:
- quat: The orientation in (w, x, y, z). Shape is (..., 4)
-
- Returns:
- A quaternion with only yaw component.
- """
- shape = quat.shape
- quat_yaw = quat.clone().view(-1, 4)
- qw = quat_yaw[:, 0]
- qx = quat_yaw[:, 1]
- qy = quat_yaw[:, 2]
- qz = quat_yaw[:, 3]
- yaw = torch.atan2(2 * (qw * qz + qx * qy), 1 - 2 * (qy * qy + qz * qz))
- quat_yaw[:] = 0.0
- quat_yaw[:, 3] = torch.sin(yaw / 2)
- quat_yaw[:, 0] = torch.cos(yaw / 2)
- quat_yaw = normalize(quat_yaw)
- return quat_yaw.view(shape)
-
-
-@torch.jit.script
-def quat_apply(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
- """Apply a quaternion rotation to a vector.
-
- Args:
- quat: The quaternion in (w, x, y, z). Shape is (..., 4).
- vec: The vector in (x, y, z). Shape is (..., 3).
-
- Returns:
- The rotated vector in (x, y, z). Shape is (..., 3).
- """
- # store shape
- shape = vec.shape
- # reshape to (N, 3) for multiplication
- quat = quat.reshape(-1, 4)
- vec = vec.reshape(-1, 3)
- # extract components from quaternions
- xyz = quat[:, 1:]
- t = xyz.cross(vec, dim=-1) * 2
- return (vec + quat[:, 0:1] * t + xyz.cross(t, dim=-1)).view(shape)
-
-
-@torch.jit.script
-def quat_apply_yaw(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
- """Rotate a vector only around the yaw-direction.
-
- Args:
- quat: The orientation in (w, x, y, z). Shape is (N, 4).
- vec: The vector in (x, y, z). Shape is (N, 3).
-
- Returns:
- The rotated vector in (x, y, z). Shape is (N, 3).
- """
- quat_yaw = yaw_quat(quat)
- return quat_apply(quat_yaw, vec)
-
-
-@torch.jit.script
-def quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
- """Rotate a vector by a quaternion.
-
- Args:
- q: The quaternion in (w, x, y, z). Shape is (N, 4).
- v: The vector in (x, y, z). Shape is (N, 3).
-
- Returns:
- The rotated vector in (x, y, z). Shape is (N, 3).
- """
- shape = q.shape
- q_w = q[:, 0]
- q_vec = q[:, 1:]
- a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
- b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
- c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
- return a + b + c
-
-
-@torch.jit.script
-def quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
- """Rotate a vector by the inverse of a quaternion.
-
- Args:
- q: The quaternion in (w, x, y, z). Shape is (N, 4).
- v: The vector in (x, y, z). Shape is (N, 3).
-
- Returns:
- The rotated vector in (x, y, z). Shape is (N, 3).
- """
- shape = q.shape
- q_w = q[:, 0]
- q_vec = q[:, 1:]
- a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
- b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
- c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
- return a - b + c
-
-
-@torch.jit.script
-def quat_from_angle_axis(angle: torch.Tensor, axis: torch.Tensor) -> torch.Tensor:
- """Convert rotations given as angle-axis to quaternions.
-
- Args:
- angle: The angle turned anti-clockwise in radians around the vector's direction. Shape is (N,).
- axis: The axis of rotation. Shape is (N, 3).
-
- Returns:
- The quaternion in (w, x, y, z). Shape is (N, 4).
- """
- theta = (angle / 2).unsqueeze(-1)
- xyz = normalize(axis) * theta.sin()
- w = theta.cos()
- return normalize(torch.cat([w, xyz], dim=-1))
-
-
-@torch.jit.script
-def axis_angle_from_quat(quat: torch.Tensor, eps: float = 1.0e-6) -> torch.Tensor:
- """Convert rotations given as quaternions to axis/angle.
-
- Args:
- quat: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
- eps: The tolerance for Taylor approximation. Defaults to 1.0e-6.
-
- Returns:
- Rotations given as a vector in axis angle form. Shape is (..., 3).
- The vector's magnitude is the angle turned anti-clockwise in radians around the vector's direction.
-
- Reference:
- https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L526-L554
- """
- # Modified to take in quat as [q_w, q_x, q_y, q_z]
- # Quaternion is [q_w, q_x, q_y, q_z] = [cos(theta/2), n_x * sin(theta/2), n_y * sin(theta/2), n_z * sin(theta/2)]
- # Axis-angle is [a_x, a_y, a_z] = [theta * n_x, theta * n_y, theta * n_z]
- # Thus, axis-angle is [q_x, q_y, q_z] / (sin(theta/2) / theta)
- # When theta = 0, (sin(theta/2) / theta) is undefined
- # However, as theta --> 0, we can use the Taylor approximation 1/2 - theta^2 / 48
- quat = quat * (1.0 - 2.0 * (quat[..., 0:1] < 0.0))
- mag = torch.linalg.norm(quat[..., 1:], dim=-1)
- half_angle = torch.atan2(mag, quat[..., 0])
- angle = 2.0 * half_angle
- # check whether to apply Taylor approximation
- sin_half_angles_over_angles = torch.where(
- angle.abs() > eps, torch.sin(half_angle) / angle, 0.5 - angle * angle / 48
- )
- return quat[..., 1:4] / sin_half_angles_over_angles.unsqueeze(-1)
-
-
-@torch.jit.script
-def quat_error_magnitude(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
- """Computes the rotation difference between two quaternions.
-
- Args:
- q1: The first quaternion in (w, x, y, z). Shape is (..., 4).
- q2: The second quaternion in (w, x, y, z). Shape is (..., 4).
-
- Returns:
- Angular error between input quaternions in radians.
- """
- quat_diff = quat_mul(q1, quat_conjugate(q2))
- return torch.norm(axis_angle_from_quat(quat_diff), dim=-1)
-
-
-@torch.jit.script
-def skew_symmetric_matrix(vec: torch.Tensor) -> torch.Tensor:
- """Computes the skew-symmetric matrix of a vector.
-
- Args:
- vec: The input vector. Shape is (3,) or (N, 3).
-
- Returns:
- The skew-symmetric matrix. Shape is (1, 3, 3) or (N, 3, 3).
-
- Raises:
- ValueError: If input tensor is not of shape (..., 3).
- """
- # check input is correct
- if vec.shape[-1] != 3:
- raise ValueError(f"Expected input vector shape mismatch: {vec.shape} != (..., 3).")
- # unsqueeze the last dimension
- if vec.ndim == 1:
- vec = vec.unsqueeze(0)
- # create a skew-symmetric matrix
- skew_sym_mat = torch.zeros(vec.shape[0], 3, 3, device=vec.device, dtype=vec.dtype)
- skew_sym_mat[:, 0, 1] = -vec[:, 2]
- skew_sym_mat[:, 0, 2] = vec[:, 1]
- skew_sym_mat[:, 1, 2] = -vec[:, 0]
- skew_sym_mat[:, 1, 0] = vec[:, 2]
- skew_sym_mat[:, 2, 0] = -vec[:, 1]
- skew_sym_mat[:, 2, 1] = vec[:, 0]
-
- return skew_sym_mat
-
-
-"""
-Transformations
-"""
-
-
-[文档]def is_identity_pose(pos: torch.tensor, rot: torch.tensor) -> bool:
-
"""Checks if input poses are identity transforms.
-
-
The function checks if the input position and orientation are close to zero and
-
identity respectively using L2-norm. It does NOT check the error in the orientation.
-
-
Args:
-
pos: The cartesian position. Shape is (N, 3).
-
rot: The quaternion in (w, x, y, z). Shape is (N, 4).
-
-
Returns:
-
True if all the input poses result in identity transform. Otherwise, False.
-
"""
-
# create identity transformations
-
pos_identity = torch.zeros_like(pos)
-
rot_identity = torch.zeros_like(rot)
-
rot_identity[..., 0] = 1
-
# compare input to identity
-
return torch.allclose(pos, pos_identity) and torch.allclose(rot, rot_identity)
-
-
-# @torch.jit.script
-
-
-
-# @torch.jit.script
-
-
-
-# @torch.jit.script
-[文档]def compute_pose_error(
-
t01: torch.Tensor,
-
q01: torch.Tensor,
-
t02: torch.Tensor,
-
q02: torch.Tensor,
-
rot_error_type: Literal["quat", "axis_angle"] = "axis_angle",
-
) -> tuple[torch.Tensor, torch.Tensor]:
-
"""Compute the position and orientation error between source and target frames.
-
-
Args:
-
t01: Position of source frame. Shape is (N, 3).
-
q01: Quaternion orientation of source frame in (w, x, y, z). Shape is (N, 4).
-
t02: Position of target frame. Shape is (N, 3).
-
q02: Quaternion orientation of target frame in (w, x, y, z). Shape is (N, 4).
-
rot_error_type: The rotation error type to return: "quat", "axis_angle".
-
Defaults to "axis_angle".
-
-
Returns:
-
A tuple containing position and orientation error. Shape of position error is (N, 3).
-
Shape of orientation error depends on the value of :attr:`rot_error_type`:
-
-
- If :attr:`rot_error_type` is "quat", the orientation error is returned
-
as a quaternion. Shape is (N, 4).
-
- If :attr:`rot_error_type` is "axis_angle", the orientation error is
-
returned as an axis-angle vector. Shape is (N, 3).
-
-
Raises:
-
ValueError: Invalid rotation error type.
-
"""
-
# Compute quaternion error (i.e., difference quaternion)
-
# Reference: https://personal.utdallas.edu/~sxb027100/dock/quaternion.html
-
# q_current_norm = q_current * q_current_conj
-
source_quat_norm = quat_mul(q01, quat_conjugate(q01))[:, 0]
-
# q_current_inv = q_current_conj / q_current_norm
-
source_quat_inv = quat_conjugate(q01) / source_quat_norm.unsqueeze(-1)
-
# q_error = q_target * q_current_inv
-
quat_error = quat_mul(q02, source_quat_inv)
-
-
# Compute position error
-
pos_error = t02 - t01
-
-
# return error based on specified type
-
if rot_error_type == "quat":
-
return pos_error, quat_error
-
elif rot_error_type == "axis_angle":
-
# Convert to axis-angle error
-
axis_angle_error = axis_angle_from_quat(quat_error)
-
return pos_error, axis_angle_error
-
else:
-
raise ValueError(f"Unsupported orientation error type: {rot_error_type}. Valid: 'quat', 'axis_angle'.")
-
-
-@torch.jit.script
-def apply_delta_pose(
- source_pos: torch.Tensor, source_rot: torch.Tensor, delta_pose: torch.Tensor, eps: float = 1.0e-6
-) -> tuple[torch.Tensor, torch.Tensor]:
- """Applies delta pose transformation on source pose.
-
- The first three elements of `delta_pose` are interpreted as cartesian position displacement.
- The remaining three elements of `delta_pose` are interpreted as orientation displacement
- in the angle-axis format.
-
- Args:
- source_pos: Position of source frame. Shape is (N, 3).
- source_rot: Quaternion orientation of source frame in (w, x, y, z). Shape is (N, 4)..
- delta_pose: Position and orientation displacements. Shape is (N, 6).
- eps: The tolerance to consider orientation displacement as zero. Defaults to 1.0e-6.
-
- Returns:
- A tuple containing the displaced position and orientation frames.
- Shape of the tensors are (N, 3) and (N, 4) respectively.
- """
- # number of poses given
- num_poses = source_pos.shape[0]
- device = source_pos.device
-
- # interpret delta_pose[:, 0:3] as target position displacements
- target_pos = source_pos + delta_pose[:, 0:3]
- # interpret delta_pose[:, 3:6] as target rotation displacements
- rot_actions = delta_pose[:, 3:6]
- angle = torch.linalg.vector_norm(rot_actions, dim=1)
- axis = rot_actions / angle.unsqueeze(-1)
- # change from axis-angle to quat convention
- identity_quat = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device).repeat(num_poses, 1)
- rot_delta_quat = torch.where(
- angle.unsqueeze(-1).repeat(1, 4) > eps, quat_from_angle_axis(angle, axis), identity_quat
- )
- # TODO: Check if this is the correct order for this multiplication.
- target_rot = quat_mul(rot_delta_quat, source_rot)
-
- return target_pos, target_rot
-
-
-# @torch.jit.script
-
-
-
-"""
-Projection operations.
-"""
-
-
-@torch.jit.script
-def unproject_depth(depth: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
- r"""Unproject depth image into a pointcloud.
-
- This function converts depth images into points given the calibration matrix of the camera.
-
- .. math::
- p_{3D} = K^{-1} \times [u, v, 1]^T \times d
-
- where :math:`p_{3D}` is the 3D point, :math:`d` is the depth value, :math:`u` and :math:`v` are
- the pixel coordinates and :math:`K` is the intrinsic matrix.
-
- If `depth` is a batch of depth images and `intrinsics` is a single intrinsic matrix, the same
- calibration matrix is applied to all depth images in the batch.
-
- The function assumes that the width and height are both greater than 1. This makes the function
- deal with many possible shapes of depth images and intrinsics matrices.
-
- Args:
- depth: The depth measurement. Shape is (H, W) or or (H, W, 1) or (N, H, W) or (N, H, W, 1).
- intrinsics: A tensor providing camera's calibration matrix. Shape is (3, 3) or (N, 3, 3).
-
- Returns:
- The 3D coordinates of points. Shape is (P, 3) or (N, P, 3).
-
- Raises:
- ValueError: When depth is not of shape (H, W) or (H, W, 1) or (N, H, W) or (N, H, W, 1).
- ValueError: When intrinsics is not of shape (3, 3) or (N, 3, 3).
- """
- depth_batch = depth.clone()
- intrinsics_batch = intrinsics.clone()
- # check if inputs are batched
- is_batched = depth_batch.dim() == 4 or (depth_batch.dim() == 3 and depth_batch.shape[-1] != 1)
- # make sure inputs are batched
- if depth_batch.dim() == 3 and depth_batch.shape[-1] == 1:
- depth_batch = depth_batch.squeeze(dim=2) # (H, W, 1) -> (H, W)
- if depth_batch.dim() == 2:
- depth_batch = depth_batch[None] # (H, W) -> (1, H, W)
- if depth_batch.dim() == 4 and depth_batch.shape[-1] == 1:
- depth_batch = depth_batch.squeeze(dim=3) # (N, H, W, 1) -> (N, H, W)
- if intrinsics_batch.dim() == 2:
- intrinsics_batch = intrinsics_batch[None] # (3, 3) -> (1, 3, 3)
- # check shape of inputs
- if depth_batch.dim() != 3:
- raise ValueError(f"Expected depth images to have dim = 2 or 3 or 4: got shape {depth.shape}")
- if intrinsics_batch.dim() != 3:
- raise ValueError(f"Expected intrinsics to have shape (3, 3) or (N, 3, 3): got shape {intrinsics.shape}")
-
- # get image height and width
- im_height, im_width = depth_batch.shape[1:]
- # create image points in homogeneous coordinates (3, H x W)
- indices_u = torch.arange(im_width, device=depth.device, dtype=depth.dtype)
- indices_v = torch.arange(im_height, device=depth.device, dtype=depth.dtype)
- img_indices = torch.stack(torch.meshgrid([indices_u, indices_v], indexing="ij"), dim=0).reshape(2, -1)
- pixels = torch.nn.functional.pad(img_indices, (0, 0, 0, 1), mode="constant", value=1.0)
- pixels = pixels.unsqueeze(0) # (3, H x W) -> (1, 3, H x W)
-
- # unproject points into 3D space
- points = torch.matmul(torch.inverse(intrinsics_batch), pixels) # (N, 3, H x W)
- points = points / points[:, -1, :].unsqueeze(1) # normalize by last coordinate
- # flatten depth image (N, H, W) -> (N, H x W)
- depth_batch = depth_batch.transpose_(1, 2).reshape(depth_batch.shape[0], -1).unsqueeze(2)
- depth_batch = depth_batch.expand(-1, -1, 3)
- # scale points by depth
- points_xyz = points.transpose_(1, 2) * depth_batch # (N, H x W, 3)
-
- # return points in same shape as input
- if not is_batched:
- points_xyz = points_xyz.squeeze(0)
-
- return points_xyz
-
-
-@torch.jit.script
-def project_points(points: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
- r"""Projects 3D points into 2D image plane.
-
- This project 3D points into a 2D image plane. The transformation is defined by the intrinsic
- matrix of the camera.
-
- .. math::
-
- \begin{align}
- p &= K \times p_{3D} = \\
- p_{2D} &= \begin{pmatrix} u \\ v \\ d \end{pmatrix}
- = \begin{pmatrix} p[0] / p[2] \\ p[1] / p[2] \\ Z \end{pmatrix}
- \end{align}
-
- where :math:`p_{2D} = (u, v, d)` is the projected 3D point, :math:`p_{3D} = (X, Y, Z)` is the
- 3D point and :math:`K \in \mathbb{R}^{3 \times 3}` is the intrinsic matrix.
-
- If `points` is a batch of 3D points and `intrinsics` is a single intrinsic matrix, the same
- calibration matrix is applied to all points in the batch.
-
- Args:
- points: The 3D coordinates of points. Shape is (P, 3) or (N, P, 3).
- intrinsics: Camera's calibration matrix. Shape is (3, 3) or (N, 3, 3).
-
- Returns:
- Projected 3D coordinates of points. Shape is (P, 3) or (N, P, 3).
- """
- points_batch = points.clone()
- intrinsics_batch = intrinsics.clone()
- # check if inputs are batched
- is_batched = points_batch.dim() == 2
- # make sure inputs are batched
- if points_batch.dim() == 2:
- points_batch = points_batch[None] # (P, 3) -> (1, P, 3)
- if intrinsics_batch.dim() == 2:
- intrinsics_batch = intrinsics_batch[None] # (3, 3) -> (1, 3, 3)
- # check shape of inputs
- if points_batch.dim() != 3:
- raise ValueError(f"Expected points to have dim = 3: got shape {points.shape}.")
- if intrinsics_batch.dim() != 3:
- raise ValueError(f"Expected intrinsics to have shape (3, 3) or (N, 3, 3): got shape {intrinsics.shape}.")
- # project points into 2D image plane
- points_2d = torch.matmul(intrinsics_batch, points_batch.transpose(1, 2))
- points_2d = points_2d / points_2d[:, -1, :].unsqueeze(1) # normalize by last coordinate
- points_2d = points_2d.transpose_(1, 2) # (N, 3, P) -> (N, P, 3)
- # replace last coordinate with depth
- points_2d[:, :, -1] = points_batch[:, :, -1]
- # return points in same shape as input
- if not is_batched:
- points_2d = points_2d.squeeze(0) # (1, 3, P) -> (3, P)
-
- return points_2d
-
-
-"""
-Sampling
-"""
-
-
-@torch.jit.script
-def default_orientation(num: int, device: str) -> torch.Tensor:
- """Returns identity rotation transform.
-
- Args:
- num: The number of rotations to sample.
- device: Device to create tensor on.
-
- Returns:
- Identity quaternion in (w, x, y, z). Shape is (num, 4).
- """
- quat = torch.zeros((num, 4), dtype=torch.float, device=device)
- quat[..., 0] = 1.0
-
- return quat
-
-
-@torch.jit.script
-def random_orientation(num: int, device: str) -> torch.Tensor:
- """Returns sampled rotation in 3D as quaternion.
-
- Args:
- num: The number of rotations to sample.
- device: Device to create tensor on.
-
- Returns:
- Sampled quaternion in (w, x, y, z). Shape is (num, 4).
-
- Reference:
- https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.random.html
- """
- # sample random orientation from normal distribution
- quat = torch.randn((num, 4), dtype=torch.float, device=device)
- # normalize the quaternion
- return torch.nn.functional.normalize(quat, p=2.0, dim=-1, eps=1e-12)
-
-
-@torch.jit.script
-def random_yaw_orientation(num: int, device: str) -> torch.Tensor:
- """Returns sampled rotation around z-axis.
-
- Args:
- num: The number of rotations to sample.
- device: Device to create tensor on.
-
- Returns:
- Sampled quaternion in (w, x, y, z). Shape is (num, 4).
- """
- roll = torch.zeros(num, dtype=torch.float, device=device)
- pitch = torch.zeros(num, dtype=torch.float, device=device)
- yaw = 2 * torch.pi * torch.rand(num, dtype=torch.float, device=device)
-
- return quat_from_euler_xyz(roll, pitch, yaw)
-
-
-[文档]def sample_triangle(lower: float, upper: float, size: int | tuple[int, ...], device: str) -> torch.Tensor:
-
"""Randomly samples tensor from a triangular distribution.
-
-
Args:
-
lower: The lower range of the sampled tensor.
-
upper: The upper range of the sampled tensor.
-
size: The shape of the tensor.
-
device: Device to create tensor on.
-
-
Returns:
-
Sampled tensor. Shape is based on :attr:`size`.
-
"""
-
# convert to tuple
-
if isinstance(size, int):
-
size = (size,)
-
# create random tensor in the range [-1, 1]
-
r = 2 * torch.rand(*size, device=device) - 1
-
# convert to triangular distribution
-
r = torch.where(r < 0.0, -torch.sqrt(-r), torch.sqrt(r))
-
# rescale back to [0, 1]
-
r = (r + 1.0) / 2.0
-
# rescale to range [lower, upper]
-
return (upper - lower) * r + lower
-
-
-
-
-
-
-
-
-[文档]def sample_gaussian(
-
mean: torch.Tensor | float, std: torch.Tensor | float, size: int | tuple[int, ...], device: str
-
) -> torch.Tensor:
-
"""Sample using gaussian distribution.
-
-
Args:
-
mean: Mean of the gaussian.
-
std: Std of the gaussian.
-
size: The shape of the tensor.
-
device: Device to create tensor on.
-
-
Returns:
-
Sampled tensor.
-
"""
-
if isinstance(mean, float):
-
if isinstance(size, int):
-
size = (size,)
-
return torch.normal(mean=mean, std=std, size=size).to(device=device)
-
else:
-
return torch.normal(mean=mean, std=std).to(device=device)
-
-
-[文档]def sample_cylinder(
-
radius: float, h_range: tuple[float, float], size: int | tuple[int, ...], device: str
-
) -> torch.Tensor:
-
"""Sample 3D points uniformly on a cylinder's surface.
-
-
The cylinder is centered at the origin and aligned with the z-axis. The height of the cylinder is
-
sampled uniformly from the range :obj:`h_range`, while the radius is fixed to :obj:`radius`.
-
-
The sampled points are returned as a tensor of shape :obj:`(*size, 3)`, i.e. the last dimension
-
contains the x, y, and z coordinates of the sampled points.
-
-
Args:
-
radius: The radius of the cylinder.
-
h_range: The minimum and maximum height of the cylinder.
-
size: The shape of the tensor.
-
device: Device to create tensor on.
-
-
Returns:
-
Sampled tensor. Shape is :obj:`(*size, 3)`.
-
"""
-
# sample angles
-
angles = (torch.rand(size, device=device) * 2 - 1) * torch.pi
-
h_min, h_max = h_range
-
# add shape
-
if isinstance(size, int):
-
size = (size, 3)
-
else:
-
size += (3,)
-
# allocate a tensor
-
xyz = torch.zeros(size, device=device)
-
xyz[..., 0] = radius * torch.cos(angles)
-
xyz[..., 1] = radius * torch.sin(angles)
-
xyz[..., 2].uniform_(h_min, h_max)
-
# return positions
-
return xyz
-
-
-