+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Sub-module containing utilities for various math operations."""
+
+# needed to import for allowing type-hinting: torch.Tensor | np.ndarray
+from __future__ import annotations
+
+import math
+import numpy as np
+import torch
+import torch.nn.functional
+from typing import Literal
+
+"""
+General
+"""
+
+
+@torch.jit.script
+def scale_transform(x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
+ """Normalizes a given input tensor to a range of [-1, 1].
+
+ .. note::
+ It uses pytorch broadcasting functionality to deal with batched input.
+
+ Args:
+ x: Input tensor of shape (N, dims).
+ lower: The minimum value of the tensor. Shape is (N, dims) or (dims,).
+ upper: The maximum value of the tensor. Shape is (N, dims) or (dims,).
+
+ Returns:
+ Normalized transform of the tensor. Shape is (N, dims).
+ """
+ # default value of center
+ offset = (lower + upper) * 0.5
+ # return normalized tensor
+ return 2 * (x - offset) / (upper - lower)
+
+
+@torch.jit.script
+def unscale_transform(x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
+ """De-normalizes a given input tensor from range of [-1, 1] to (lower, upper).
+
+ .. note::
+ It uses pytorch broadcasting functionality to deal with batched input.
+
+ Args:
+ x: Input tensor of shape (N, dims).
+ lower: The minimum value of the tensor. Shape is (N, dims) or (dims,).
+ upper: The maximum value of the tensor. Shape is (N, dims) or (dims,).
+
+ Returns:
+ De-normalized transform of the tensor. Shape is (N, dims).
+ """
+ # default value of center
+ offset = (lower + upper) * 0.5
+ # return normalized tensor
+ return x * (upper - lower) * 0.5 + offset
+
+
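+# NOTE: The following usage sketch is an editorial illustration and is not part of the
+# upstream module. It shows that `scale_transform` and `unscale_transform` are inverses of
+# each other on their common range; the helper name `_example_scale_roundtrip` is hypothetical.
+def _example_scale_roundtrip() -> None:
+    x = torch.tensor([[0.0, 2.0], [1.0, 4.0]])
+    lower = torch.tensor([0.0, 2.0])
+    upper = torch.tensor([1.0, 4.0])
+    # normalize to [-1, 1] and map back to [lower, upper]
+    x_norm = scale_transform(x, lower, upper)
+    x_back = unscale_transform(x_norm, lower, upper)
+    assert torch.allclose(x, x_back)
+
+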
+@torch.jit.script
+def saturate(x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
+ """Clamps a given input tensor to (lower, upper).
+
+ It uses pytorch broadcasting functionality to deal with batched input.
+
+ Args:
+ x: Input tensor of shape (N, dims).
+ lower: The minimum value of the tensor. Shape is (N, dims) or (dims,).
+ upper: The maximum value of the tensor. Shape is (N, dims) or (dims,).
+
+ Returns:
+ Clamped transform of the tensor. Shape is (N, dims).
+ """
+ return torch.max(torch.min(x, upper), lower)
+
+
+@torch.jit.script
+def normalize(x: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
+ """Normalizes a given input tensor to unit length.
+
+ Args:
+ x: Input tensor of shape (N, dims).
+ eps: A small value to avoid division by zero. Defaults to 1e-9.
+
+ Returns:
+ Normalized tensor of shape (N, dims).
+ """
+ return x / x.norm(p=2, dim=-1).clamp(min=eps, max=None).unsqueeze(-1)
+
+
+@torch.jit.script
+def wrap_to_pi(angles: torch.Tensor) -> torch.Tensor:
+ r"""Wraps input angles (in radians) to the range :math:`[-\pi, \pi]`.
+
+ This function wraps angles in radians to the range :math:`[-\pi, \pi]`, such that
+ :math:`\pi` maps to :math:`\pi`, and :math:`-\pi` maps to :math:`-\pi`. In general,
+ odd positive multiples of :math:`\pi` are mapped to :math:`\pi`, and odd negative
+ multiples of :math:`\pi` are mapped to :math:`-\pi`.
+
+ The function behaves similar to MATLAB's `wrapToPi <https://www.mathworks.com/help/map/ref/wraptopi.html>`_
+ function.
+
+ Args:
+ angles: Input angles of any shape.
+
+ Returns:
+ Angles in the range :math:`[-\pi, \pi]`.
+ """
+ # wrap to [0, 2*pi)
+ wrapped_angle = (angles + torch.pi) % (2 * torch.pi)
+ # map to [-pi, pi]
+ # we check for zero in wrapped angle to make it go to pi when input angle is odd multiple of pi
+ return torch.where((wrapped_angle == 0) & (angles > 0), torch.pi, wrapped_angle - torch.pi)
+
+
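+# NOTE: Editorial usage sketch, not part of the upstream module. It illustrates that angles
+# offset by full turns are mapped back into [-pi, pi]; the boundary cases (odd multiples of
+# pi) behave as documented above. The helper name `_example_wrap_to_pi` is hypothetical.
+def _example_wrap_to_pi() -> None:
+    angles = torch.tensor([0.5, -0.5]) + 4 * torch.pi  # offset by two full turns
+    wrapped = wrap_to_pi(angles)
+    assert torch.allclose(wrapped, torch.tensor([0.5, -0.5]), atol=1e-5)
+
+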
+@torch.jit.script
+def copysign(mag: float, other: torch.Tensor) -> torch.Tensor:
+ """Create a new floating-point tensor with the magnitude of input and the sign of other, element-wise.
+
+ Note:
+ The implementation follows from `torch.copysign`. The function allows a scalar magnitude.
+
+ Args:
+ mag: The magnitude scalar.
+ other: The tensor containing values whose signbits are applied to magnitude.
+
+ Returns:
+ The output tensor.
+ """
+ mag_torch = torch.tensor(mag, device=other.device, dtype=torch.float).repeat(other.shape[0])
+ return torch.abs(mag_torch) * torch.sign(other)
+
+
+"""
+Rotation
+"""
+
+
+@torch.jit.script
+def matrix_from_quat(quaternions: torch.Tensor) -> torch.Tensor:
+ """Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
+
+ Returns:
+ Rotation matrices. The shape is (..., 3, 3).
+
+ Reference:
+ https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L41-L70
+ """
+ r, i, j, k = torch.unbind(quaternions, -1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def convert_quat(quat: torch.Tensor | np.ndarray, to: Literal["xyzw", "wxyz"] = "xyzw") -> torch.Tensor | np.ndarray:
+    """Converts quaternion from one convention to another.
+
+    The convention to convert TO is specified as an optional argument. If to == 'xyzw',
+    then the input is in 'wxyz' format, and vice-versa.
+
+    Args:
+        quat: The quaternion of shape (..., 4).
+        to: Convention to convert quaternion to. Defaults to "xyzw".
+
+    Returns:
+        The converted quaternion in specified convention.
+
+    Raises:
+        ValueError: Invalid input argument `to`, i.e. not "xyzw" or "wxyz".
+        ValueError: Invalid shape of input `quat`, i.e. not (..., 4,).
+    """
+    # check input is correct
+    if quat.shape[-1] != 4:
+        msg = f"Expected input quaternion shape mismatch: {quat.shape} != (..., 4)."
+        raise ValueError(msg)
+    if to not in ["xyzw", "wxyz"]:
+        msg = f"Expected input argument `to` to be 'xyzw' or 'wxyz'. Received: {to}."
+        raise ValueError(msg)
+    # check if input is numpy array (we support this backend since some classes use numpy)
+    if isinstance(quat, np.ndarray):
+        # use numpy functions
+        if to == "xyzw":
+            # wxyz -> xyzw
+            return np.roll(quat, -1, axis=-1)
+        else:
+            # xyzw -> wxyz
+            return np.roll(quat, 1, axis=-1)
+    else:
+        # convert to torch (sanity check)
+        if not isinstance(quat, torch.Tensor):
+            quat = torch.tensor(quat, dtype=float)
+        # convert to specified quaternion type
+        if to == "xyzw":
+            # wxyz -> xyzw
+            return quat.roll(-1, dims=-1)
+        else:
+            # xyzw -> wxyz
+            return quat.roll(1, dims=-1)
+
+
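+# NOTE: Editorial usage sketch, not part of the upstream module. It shows the effect of
+# `convert_quat` on both supported backends; `_example_convert_quat` is a hypothetical name.
+def _example_convert_quat() -> None:
+    q_wxyz = torch.tensor([1.0, 2.0, 3.0, 4.0])  # (w, x, y, z)
+    q_xyzw = convert_quat(q_wxyz, to="xyzw")
+    assert torch.equal(q_xyzw, torch.tensor([2.0, 3.0, 4.0, 1.0]))
+    # the same call works on numpy arrays
+    q_np = convert_quat(np.array([1.0, 2.0, 3.0, 4.0]), to="xyzw")
+    assert np.allclose(q_np, [2.0, 3.0, 4.0, 1.0])
+
+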
+@torch.jit.script
+def quat_conjugate(q: torch.Tensor) -> torch.Tensor:
+ """Computes the conjugate of a quaternion.
+
+ Args:
+ q: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
+
+ Returns:
+ The conjugate quaternion in (w, x, y, z). Shape is (..., 4).
+ """
+ shape = q.shape
+ q = q.reshape(-1, 4)
+ return torch.cat((q[:, 0:1], -q[:, 1:]), dim=-1).view(shape)
+
+
+@torch.jit.script
+def quat_inv(q: torch.Tensor) -> torch.Tensor:
+ """Compute the inverse of a quaternion.
+
+ Args:
+ q: The quaternion orientation in (w, x, y, z). Shape is (N, 4).
+
+ Returns:
+ The inverse quaternion in (w, x, y, z). Shape is (N, 4).
+ """
+ return normalize(quat_conjugate(q))
+
+
+@torch.jit.script
+def quat_from_euler_xyz(roll: torch.Tensor, pitch: torch.Tensor, yaw: torch.Tensor) -> torch.Tensor:
+ """Convert rotations given as Euler angles in radians to Quaternions.
+
+ Note:
+ The euler angles are assumed in XYZ convention.
+
+ Args:
+ roll: Rotation around x-axis (in radians). Shape is (N,).
+ pitch: Rotation around y-axis (in radians). Shape is (N,).
+ yaw: Rotation around z-axis (in radians). Shape is (N,).
+
+ Returns:
+ The quaternion in (w, x, y, z). Shape is (N, 4).
+ """
+ cy = torch.cos(yaw * 0.5)
+ sy = torch.sin(yaw * 0.5)
+ cr = torch.cos(roll * 0.5)
+ sr = torch.sin(roll * 0.5)
+ cp = torch.cos(pitch * 0.5)
+ sp = torch.sin(pitch * 0.5)
+ # compute quaternion
+ qw = cy * cr * cp + sy * sr * sp
+ qx = cy * sr * cp - sy * cr * sp
+ qy = cy * cr * sp + sy * sr * cp
+ qz = sy * cr * cp - cy * sr * sp
+
+ return torch.stack([qw, qx, qy, qz], dim=-1)
+
+
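+# NOTE: Editorial usage sketch, not part of the upstream module. A pure yaw rotation should
+# produce a quaternion with only (w, z) components; `_example_quat_from_euler_xyz` is a
+# hypothetical name.
+def _example_quat_from_euler_xyz() -> None:
+    zero = torch.zeros(1)
+    yaw = torch.tensor([math.pi / 2])
+    q = quat_from_euler_xyz(zero, zero, yaw)
+    expected = torch.tensor([[math.cos(math.pi / 4), 0.0, 0.0, math.sin(math.pi / 4)]])
+    assert torch.allclose(q, expected, atol=1e-5)
+
+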
+@torch.jit.script
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """Returns torch.sqrt(torch.max(0, x)) but with a zero sub-gradient where x is 0.
+
+ Reference:
+ https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L91-L99
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ return ret
+
+
+@torch.jit.script
+def quat_from_matrix(matrix: torch.Tensor) -> torch.Tensor:
+ """Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: The rotation matrices. Shape is (..., 3, 3).
+
+ Returns:
+ The quaternion in (w, x, y, z). Shape is (..., 4).
+
+ Reference:
+ https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L102-L161
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ return quat_candidates[torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
+ batch_dim + (4,)
+ )
+
+
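+# NOTE: Editorial usage sketch, not part of the upstream module. Converting a quaternion to a
+# matrix and back recovers the input only up to sign, so the comparison fixes the sign first;
+# `_example_quat_matrix_roundtrip` is a hypothetical name.
+def _example_quat_matrix_roundtrip() -> None:
+    q = normalize(torch.tensor([[0.5, -0.5, 0.5, 0.5]]))
+    q_back = quat_from_matrix(matrix_from_quat(q))
+    # fix the sign ambiguity (q and -q encode the same rotation)
+    q_back = torch.where(q_back[..., 0:1] < 0, -q_back, q_back)
+    assert torch.allclose(q, q_back, atol=1e-5)
+
+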
+def _axis_angle_rotation(axis: Literal["X", "Y", "Z"], angle: torch.Tensor) -> torch.Tensor:
+ """Return the rotation matrices for one of the rotations about an axis of which Euler angles describe,
+ for each value of the angle given.
+
+ Args:
+ axis: Axis label "X" or "Y or "Z".
+ angle: Euler angles in radians of any shape.
+
+ Returns:
+ Rotation matrices. Shape is (..., 3, 3).
+
+ Reference:
+ https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L164-L191
+ """
+ cos = torch.cos(angle)
+ sin = torch.sin(angle)
+ one = torch.ones_like(angle)
+ zero = torch.zeros_like(angle)
+
+ if axis == "X":
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
+ elif axis == "Y":
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
+ elif axis == "Z":
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
+ else:
+ raise ValueError("letter must be either X, Y or Z.")
+
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
+
+
+def matrix_from_euler(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
+    """Convert rotations given as Euler angles in radians to rotation matrices.
+
+    Args:
+        euler_angles: Euler angles in radians. Shape is (..., 3).
+        convention: Convention string of three uppercase letters from {"X", "Y", and "Z"}.
+            For example, "XYZ" means that the rotations should be applied first about x,
+            then y, then z.
+
+    Returns:
+        Rotation matrices. Shape is (..., 3, 3).
+
+    Reference:
+        https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L194-L220
+    """
+    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
+        raise ValueError("Invalid input euler angles.")
+    if len(convention) != 3:
+        raise ValueError("Convention must have 3 letters.")
+    if convention[1] in (convention[0], convention[2]):
+        raise ValueError(f"Invalid convention {convention}.")
+    for letter in convention:
+        if letter not in ("X", "Y", "Z"):
+            raise ValueError(f"Invalid letter {letter} in convention string.")
+    matrices = [_axis_angle_rotation(c, e) for c, e in zip(convention, torch.unbind(euler_angles, -1))]
+    # return functools.reduce(torch.matmul, matrices)
+    return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
+
+
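+# NOTE: Editorial consistency sketch, not part of the upstream module. For the roll-pitch-yaw
+# convention used by `quat_from_euler_xyz` (rotations applied about x, then y, then z of the
+# fixed frame), the equivalent call here is `matrix_from_euler` with the "ZYX" convention and
+# the angles ordered as (yaw, pitch, roll); `_example_euler_conventions` is a hypothetical name.
+def _example_euler_conventions() -> None:
+    roll, pitch, yaw = torch.tensor([0.1]), torch.tensor([-0.2]), torch.tensor([0.3])
+    mat_a = matrix_from_quat(quat_from_euler_xyz(roll, pitch, yaw))
+    mat_b = matrix_from_euler(torch.stack([yaw, pitch, roll], dim=-1), "ZYX")
+    assert torch.allclose(mat_a, mat_b, atol=1e-5)
+
+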
+@torch.jit.script
+def euler_xyz_from_quat(quat: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Convert rotations given as quaternions to Euler angles in radians.
+
+ Note:
+ The euler angles are assumed in XYZ convention.
+
+ Args:
+ quat: The quaternion orientation in (w, x, y, z). Shape is (N, 4).
+
+ Returns:
+ A tuple containing roll-pitch-yaw. Each element is a tensor of shape (N,).
+
+ Reference:
+ https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles
+ """
+ q_w, q_x, q_y, q_z = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3]
+ # roll (x-axis rotation)
+ sin_roll = 2.0 * (q_w * q_x + q_y * q_z)
+ cos_roll = 1 - 2 * (q_x * q_x + q_y * q_y)
+ roll = torch.atan2(sin_roll, cos_roll)
+
+ # pitch (y-axis rotation)
+ sin_pitch = 2.0 * (q_w * q_y - q_z * q_x)
+ pitch = torch.where(torch.abs(sin_pitch) >= 1, copysign(torch.pi / 2.0, sin_pitch), torch.asin(sin_pitch))
+
+ # yaw (z-axis rotation)
+ sin_yaw = 2.0 * (q_w * q_z + q_x * q_y)
+ cos_yaw = 1 - 2 * (q_y * q_y + q_z * q_z)
+ yaw = torch.atan2(sin_yaw, cos_yaw)
+
+    return roll % (2 * torch.pi), pitch % (2 * torch.pi), yaw % (2 * torch.pi)  # TODO: why not wrap_to_pi here?
+
+
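+# NOTE: Editorial usage sketch, not part of the upstream module. As the TODO above hints, the
+# returned angles live in [0, 2*pi); `wrap_to_pi` maps them back to [-pi, pi] when a signed
+# representation is needed. `_example_euler_from_quat` is a hypothetical name.
+def _example_euler_from_quat() -> None:
+    q = quat_from_euler_xyz(torch.tensor([-0.1]), torch.tensor([0.2]), torch.tensor([-0.3]))
+    roll, pitch, yaw = euler_xyz_from_quat(q)
+    # negative inputs come back shifted by 2*pi; wrap_to_pi recovers the signed angles
+    assert torch.allclose(wrap_to_pi(roll), torch.tensor([-0.1]), atol=1e-5)
+    assert torch.allclose(wrap_to_pi(yaw), torch.tensor([-0.3]), atol=1e-5)
+
+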
+@torch.jit.script
+def quat_unique(q: torch.Tensor) -> torch.Tensor:
+ """Convert a unit quaternion to a standard form where the real part is non-negative.
+
+ Quaternion representations have a singularity since ``q`` and ``-q`` represent the same
+ rotation. This function ensures the real part of the quaternion is non-negative.
+
+ Args:
+ q: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
+
+ Returns:
+ Standardized quaternions. Shape is (..., 4).
+ """
+ return torch.where(q[..., 0:1] < 0, -q, q)
+
+
+@torch.jit.script
+def quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
+ """Multiply two quaternions together.
+
+ Args:
+ q1: The first quaternion in (w, x, y, z). Shape is (..., 4).
+ q2: The second quaternion in (w, x, y, z). Shape is (..., 4).
+
+ Returns:
+ The product of the two quaternions in (w, x, y, z). Shape is (..., 4).
+
+ Raises:
+ ValueError: Input shapes of ``q1`` and ``q2`` are not matching.
+ """
+ # check input is correct
+ if q1.shape != q2.shape:
+ msg = f"Expected input quaternion shape mismatch: {q1.shape} != {q2.shape}."
+ raise ValueError(msg)
+ # reshape to (N, 4) for multiplication
+ shape = q1.shape
+ q1 = q1.reshape(-1, 4)
+ q2 = q2.reshape(-1, 4)
+ # extract components from quaternions
+ w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3]
+ w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3]
+ # perform multiplication
+ ww = (z1 + x1) * (x2 + y2)
+ yy = (w1 - y1) * (w2 + z2)
+ zz = (w1 + y1) * (w2 - z2)
+ xx = ww + yy + zz
+ qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
+ w = qq - ww + (z1 - y1) * (y2 - z2)
+ x = qq - xx + (x1 + w1) * (x2 + w2)
+ y = qq - yy + (w1 - x1) * (y2 + z2)
+ z = qq - zz + (z1 + y1) * (w2 - x2)
+
+ return torch.stack([w, x, y, z], dim=-1).view(shape)
+
+
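+# NOTE: Editorial consistency sketch, not part of the upstream module. Quaternion
+# multiplication composes rotations: rotating by `quat_mul(q1, q2)` equals rotating by q2
+# first and then by q1. `_example_quat_mul_composition` is a hypothetical name.
+def _example_quat_mul_composition() -> None:
+    q1 = random_orientation(1, "cpu")
+    q2 = random_orientation(1, "cpu")
+    v = torch.tensor([[1.0, 2.0, 3.0]])
+    lhs = quat_apply(quat_mul(q1, q2), v)
+    rhs = quat_apply(q1, quat_apply(q2, v))
+    assert torch.allclose(lhs, rhs, atol=1e-5)
+
+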
+@torch.jit.script
+def quat_box_minus(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
+ """The box-minus operator (quaternion difference) between two quaternions.
+
+ Args:
+ q1: The first quaternion in (w, x, y, z). Shape is (N, 4).
+ q2: The second quaternion in (w, x, y, z). Shape is (N, 4).
+
+ Returns:
+ The difference between the two quaternions. Shape is (N, 3).
+ """
+ quat_diff = quat_mul(q1, quat_conjugate(q2)) # q1 * q2^-1
+ re = quat_diff[:, 0] # real part, q = [w, x, y, z] = [re, im]
+ im = quat_diff[:, 1:] # imaginary part
+ norm_im = torch.norm(im, dim=1)
+ scale = 2.0 * torch.where(norm_im > 1.0e-7, torch.atan2(norm_im, re) / norm_im, torch.sign(re))
+ return scale.unsqueeze(-1) * im
+
+
+@torch.jit.script
+def yaw_quat(quat: torch.Tensor) -> torch.Tensor:
+ """Extract the yaw component of a quaternion.
+
+ Args:
+ quat: The orientation in (w, x, y, z). Shape is (..., 4)
+
+ Returns:
+ A quaternion with only yaw component.
+ """
+ shape = quat.shape
+ quat_yaw = quat.clone().view(-1, 4)
+ qw = quat_yaw[:, 0]
+ qx = quat_yaw[:, 1]
+ qy = quat_yaw[:, 2]
+ qz = quat_yaw[:, 3]
+ yaw = torch.atan2(2 * (qw * qz + qx * qy), 1 - 2 * (qy * qy + qz * qz))
+ quat_yaw[:] = 0.0
+ quat_yaw[:, 3] = torch.sin(yaw / 2)
+ quat_yaw[:, 0] = torch.cos(yaw / 2)
+ quat_yaw = normalize(quat_yaw)
+ return quat_yaw.view(shape)
+
+
+@torch.jit.script
+def quat_apply(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
+ """Apply a quaternion rotation to a vector.
+
+ Args:
+ quat: The quaternion in (w, x, y, z). Shape is (..., 4).
+ vec: The vector in (x, y, z). Shape is (..., 3).
+
+ Returns:
+ The rotated vector in (x, y, z). Shape is (..., 3).
+ """
+ # store shape
+ shape = vec.shape
+ # reshape to (N, 3) for multiplication
+ quat = quat.reshape(-1, 4)
+ vec = vec.reshape(-1, 3)
+ # extract components from quaternions
+ xyz = quat[:, 1:]
+ t = xyz.cross(vec, dim=-1) * 2
+ return (vec + quat[:, 0:1] * t + xyz.cross(t, dim=-1)).view(shape)
+
+
+@torch.jit.script
+def quat_apply_yaw(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
+ """Rotate a vector only around the yaw-direction.
+
+ Args:
+ quat: The orientation in (w, x, y, z). Shape is (N, 4).
+ vec: The vector in (x, y, z). Shape is (N, 3).
+
+ Returns:
+ The rotated vector in (x, y, z). Shape is (N, 3).
+ """
+ quat_yaw = yaw_quat(quat)
+ return quat_apply(quat_yaw, vec)
+
+
+@torch.jit.script
+def quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+ """Rotate a vector by a quaternion along the last dimension of q and v.
+
+ Args:
+ q: The quaternion in (w, x, y, z). Shape is (..., 4).
+ v: The vector in (x, y, z). Shape is (..., 3).
+
+ Returns:
+ The rotated vector in (x, y, z). Shape is (..., 3).
+ """
+ q_w = q[..., 0]
+ q_vec = q[..., 1:]
+ a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
+ b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
+ # for two-dimensional tensors, bmm is faster than einsum
+ if q_vec.dim() == 2:
+ c = q_vec * torch.bmm(q_vec.view(q.shape[0], 1, 3), v.view(q.shape[0], 3, 1)).squeeze(-1) * 2.0
+ else:
+ c = q_vec * torch.einsum("...i,...i->...", q_vec, v).unsqueeze(-1) * 2.0
+ return a + b + c
+
+
+@torch.jit.script
+def quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+ """Rotate a vector by the inverse of a quaternion along the last dimension of q and v.
+
+ Args:
+ q: The quaternion in (w, x, y, z). Shape is (..., 4).
+ v: The vector in (x, y, z). Shape is (..., 3).
+
+ Returns:
+ The rotated vector in (x, y, z). Shape is (..., 3).
+ """
+ q_w = q[..., 0]
+ q_vec = q[..., 1:]
+ a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
+ b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
+ # for two-dimensional tensors, bmm is faster than einsum
+ if q_vec.dim() == 2:
+ c = q_vec * torch.bmm(q_vec.view(q.shape[0], 1, 3), v.view(q.shape[0], 3, 1)).squeeze(-1) * 2.0
+ else:
+ c = q_vec * torch.einsum("...i,...i->...", q_vec, v).unsqueeze(-1) * 2.0
+ return a - b + c
+
+
+@torch.jit.script
+def quat_from_angle_axis(angle: torch.Tensor, axis: torch.Tensor) -> torch.Tensor:
+ """Convert rotations given as angle-axis to quaternions.
+
+ Args:
+ angle: The angle turned anti-clockwise in radians around the vector's direction. Shape is (N,).
+ axis: The axis of rotation. Shape is (N, 3).
+
+ Returns:
+ The quaternion in (w, x, y, z). Shape is (N, 4).
+ """
+ theta = (angle / 2).unsqueeze(-1)
+ xyz = normalize(axis) * theta.sin()
+ w = theta.cos()
+ return normalize(torch.cat([w, xyz], dim=-1))
+
+
+@torch.jit.script
+def axis_angle_from_quat(quat: torch.Tensor, eps: float = 1.0e-6) -> torch.Tensor:
+ """Convert rotations given as quaternions to axis/angle.
+
+ Args:
+ quat: The quaternion orientation in (w, x, y, z). Shape is (..., 4).
+ eps: The tolerance for Taylor approximation. Defaults to 1.0e-6.
+
+ Returns:
+ Rotations given as a vector in axis angle form. Shape is (..., 3).
+ The vector's magnitude is the angle turned anti-clockwise in radians around the vector's direction.
+
+ Reference:
+ https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L526-L554
+ """
+ # Modified to take in quat as [q_w, q_x, q_y, q_z]
+ # Quaternion is [q_w, q_x, q_y, q_z] = [cos(theta/2), n_x * sin(theta/2), n_y * sin(theta/2), n_z * sin(theta/2)]
+ # Axis-angle is [a_x, a_y, a_z] = [theta * n_x, theta * n_y, theta * n_z]
+ # Thus, axis-angle is [q_x, q_y, q_z] / (sin(theta/2) / theta)
+ # When theta = 0, (sin(theta/2) / theta) is undefined
+ # However, as theta --> 0, we can use the Taylor approximation 1/2 - theta^2 / 48
+ quat = quat * (1.0 - 2.0 * (quat[..., 0:1] < 0.0))
+ mag = torch.linalg.norm(quat[..., 1:], dim=-1)
+ half_angle = torch.atan2(mag, quat[..., 0])
+ angle = 2.0 * half_angle
+ # check whether to apply Taylor approximation
+ sin_half_angles_over_angles = torch.where(
+ angle.abs() > eps, torch.sin(half_angle) / angle, 0.5 - angle * angle / 48
+ )
+ return quat[..., 1:4] / sin_half_angles_over_angles.unsqueeze(-1)
+
+
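+# NOTE: Editorial usage sketch, not part of the upstream module. `quat_from_angle_axis` and
+# `axis_angle_from_quat` are inverse conversions for angles in (-pi, pi);
+# `_example_axis_angle_roundtrip` is a hypothetical name.
+def _example_axis_angle_roundtrip() -> None:
+    angle = torch.tensor([0.7])
+    axis = torch.tensor([[0.0, 0.0, 1.0]])
+    q = quat_from_angle_axis(angle, axis)
+    rot_vec = axis_angle_from_quat(q)  # magnitude is the angle, direction is the axis
+    assert torch.allclose(rot_vec, angle.unsqueeze(-1) * axis, atol=1e-5)
+
+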
+@torch.jit.script
+def quat_error_magnitude(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
+ """Computes the rotation difference between two quaternions.
+
+ Args:
+ q1: The first quaternion in (w, x, y, z). Shape is (..., 4).
+ q2: The second quaternion in (w, x, y, z). Shape is (..., 4).
+
+ Returns:
+ Angular error between input quaternions in radians.
+ """
+ quat_diff = quat_mul(q1, quat_conjugate(q2))
+ return torch.norm(axis_angle_from_quat(quat_diff), dim=-1)
+
+
+@torch.jit.script
+def skew_symmetric_matrix(vec: torch.Tensor) -> torch.Tensor:
+ """Computes the skew-symmetric matrix of a vector.
+
+ Args:
+ vec: The input vector. Shape is (3,) or (N, 3).
+
+ Returns:
+ The skew-symmetric matrix. Shape is (1, 3, 3) or (N, 3, 3).
+
+ Raises:
+ ValueError: If input tensor is not of shape (..., 3).
+ """
+ # check input is correct
+ if vec.shape[-1] != 3:
+ raise ValueError(f"Expected input vector shape mismatch: {vec.shape} != (..., 3).")
+ # unsqueeze the last dimension
+ if vec.ndim == 1:
+ vec = vec.unsqueeze(0)
+ # create a skew-symmetric matrix
+ skew_sym_mat = torch.zeros(vec.shape[0], 3, 3, device=vec.device, dtype=vec.dtype)
+ skew_sym_mat[:, 0, 1] = -vec[:, 2]
+ skew_sym_mat[:, 0, 2] = vec[:, 1]
+ skew_sym_mat[:, 1, 2] = -vec[:, 0]
+ skew_sym_mat[:, 1, 0] = vec[:, 2]
+ skew_sym_mat[:, 2, 0] = -vec[:, 1]
+ skew_sym_mat[:, 2, 1] = vec[:, 0]
+
+ return skew_sym_mat
+
+
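+# NOTE: Editorial consistency sketch, not part of the upstream module. The skew-symmetric
+# matrix of `a` implements the cross product: skew(a) @ b == a x b;
+# `_example_skew_symmetric` is a hypothetical name.
+def _example_skew_symmetric() -> None:
+    a = torch.tensor([[1.0, 2.0, 3.0]])
+    b = torch.tensor([[4.0, 5.0, 6.0]])
+    lhs = torch.bmm(skew_symmetric_matrix(a), b.unsqueeze(-1)).squeeze(-1)
+    rhs = torch.cross(a, b, dim=-1)
+    assert torch.allclose(lhs, rhs, atol=1e-6)
+
+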
+"""
+Transformations
+"""
+
+
+def is_identity_pose(pos: torch.Tensor, rot: torch.Tensor) -> bool:
+    """Checks if input poses are identity transforms.
+
+    The function checks if the input position and orientation are close to zero and
+    identity respectively using L2-norm. It does NOT check the error in the orientation.
+
+    Args:
+        pos: The cartesian position. Shape is (N, 3).
+        rot: The quaternion in (w, x, y, z). Shape is (N, 4).
+
+    Returns:
+        True if all the input poses result in identity transform. Otherwise, False.
+    """
+    # create identity transformations
+    pos_identity = torch.zeros_like(pos)
+    rot_identity = torch.zeros_like(rot)
+    rot_identity[..., 0] = 1
+    # compare input to identity
+    return torch.allclose(pos, pos_identity) and torch.allclose(rot, rot_identity)
+
+
+# @torch.jit.script
+def compute_pose_error(
+    t01: torch.Tensor,
+    q01: torch.Tensor,
+    t02: torch.Tensor,
+    q02: torch.Tensor,
+    rot_error_type: Literal["quat", "axis_angle"] = "axis_angle",
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Compute the position and orientation error between source and target frames.
+
+    Args:
+        t01: Position of source frame. Shape is (N, 3).
+        q01: Quaternion orientation of source frame in (w, x, y, z). Shape is (N, 4).
+        t02: Position of target frame. Shape is (N, 3).
+        q02: Quaternion orientation of target frame in (w, x, y, z). Shape is (N, 4).
+        rot_error_type: The rotation error type to return: "quat", "axis_angle".
+            Defaults to "axis_angle".
+
+    Returns:
+        A tuple containing position and orientation error. Shape of position error is (N, 3).
+        Shape of orientation error depends on the value of :attr:`rot_error_type`:
+
+        - If :attr:`rot_error_type` is "quat", the orientation error is returned
+          as a quaternion. Shape is (N, 4).
+        - If :attr:`rot_error_type` is "axis_angle", the orientation error is
+          returned as an axis-angle vector. Shape is (N, 3).
+
+    Raises:
+        ValueError: Invalid rotation error type.
+    """
+    # Compute quaternion error (i.e., difference quaternion)
+    # Reference: https://personal.utdallas.edu/~sxb027100/dock/quaternion.html
+    # q_current_norm = q_current * q_current_conj
+    source_quat_norm = quat_mul(q01, quat_conjugate(q01))[:, 0]
+    # q_current_inv = q_current_conj / q_current_norm
+    source_quat_inv = quat_conjugate(q01) / source_quat_norm.unsqueeze(-1)
+    # q_error = q_target * q_current_inv
+    quat_error = quat_mul(q02, source_quat_inv)
+
+    # Compute position error
+    pos_error = t02 - t01
+
+    # return error based on specified type
+    if rot_error_type == "quat":
+        return pos_error, quat_error
+    elif rot_error_type == "axis_angle":
+        # Convert to axis-angle error
+        axis_angle_error = axis_angle_from_quat(quat_error)
+        return pos_error, axis_angle_error
+    else:
+        raise ValueError(f"Unsupported orientation error type: {rot_error_type}. Valid: 'quat', 'axis_angle'.")
+
+
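+# NOTE: Editorial usage sketch, not part of the upstream module. For identical frames the
+# pose error is zero in both position and orientation; `_example_compute_pose_error` is a
+# hypothetical name.
+def _example_compute_pose_error() -> None:
+    t = torch.tensor([[0.5, 0.0, 1.0]])
+    q = random_orientation(1, "cpu")
+    pos_err, rot_err = compute_pose_error(t, q, t, q, rot_error_type="axis_angle")
+    assert torch.allclose(pos_err, torch.zeros(1, 3), atol=1e-6)
+    assert torch.allclose(rot_err, torch.zeros(1, 3), atol=1e-5)
+
+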
+@torch.jit.script
+def apply_delta_pose(
+ source_pos: torch.Tensor, source_rot: torch.Tensor, delta_pose: torch.Tensor, eps: float = 1.0e-6
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Applies delta pose transformation on source pose.
+
+ The first three elements of `delta_pose` are interpreted as cartesian position displacement.
+ The remaining three elements of `delta_pose` are interpreted as orientation displacement
+ in the angle-axis format.
+
+ Args:
+ source_pos: Position of source frame. Shape is (N, 3).
+        source_rot: Quaternion orientation of source frame in (w, x, y, z). Shape is (N, 4).
+ delta_pose: Position and orientation displacements. Shape is (N, 6).
+ eps: The tolerance to consider orientation displacement as zero. Defaults to 1.0e-6.
+
+ Returns:
+ A tuple containing the displaced position and orientation frames.
+ Shape of the tensors are (N, 3) and (N, 4) respectively.
+ """
+ # number of poses given
+ num_poses = source_pos.shape[0]
+ device = source_pos.device
+
+ # interpret delta_pose[:, 0:3] as target position displacements
+ target_pos = source_pos + delta_pose[:, 0:3]
+ # interpret delta_pose[:, 3:6] as target rotation displacements
+ rot_actions = delta_pose[:, 3:6]
+ angle = torch.linalg.vector_norm(rot_actions, dim=1)
+ axis = rot_actions / angle.unsqueeze(-1)
+ # change from axis-angle to quat convention
+ identity_quat = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device).repeat(num_poses, 1)
+ rot_delta_quat = torch.where(
+ angle.unsqueeze(-1).repeat(1, 4) > eps, quat_from_angle_axis(angle, axis), identity_quat
+ )
+ # TODO: Check if this is the correct order for this multiplication.
+ target_rot = quat_mul(rot_delta_quat, source_rot)
+
+ return target_pos, target_rot
+
+
+"""
+Projection operations.
+"""
+
+
+@torch.jit.script
+def orthogonalize_perspective_depth(depth: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
+ """Converts perspective depth image to orthogonal depth image.
+
+ Perspective depth images contain distances measured from the camera's optical center.
+ Meanwhile, orthogonal depth images provide the distance from the camera's image plane.
+ This method uses the camera geometry to convert perspective depth to orthogonal depth image.
+
+ The function assumes that the width and height are both greater than 1.
+
+ Args:
+        depth: The perspective depth images. Shape is (H, W) or (H, W, 1) or (N, H, W) or (N, H, W, 1).
+ intrinsics: The camera's calibration matrix. If a single matrix is provided, the same
+ calibration matrix is used across all the depth images in the batch.
+ Shape is (3, 3) or (N, 3, 3).
+
+ Returns:
+ The orthogonal depth images. Shape matches the input shape of depth images.
+
+ Raises:
+ ValueError: When depth is not of shape (H, W) or (H, W, 1) or (N, H, W) or (N, H, W, 1).
+ ValueError: When intrinsics is not of shape (3, 3) or (N, 3, 3).
+ """
+ # Clone inputs to avoid in-place modifications
+ perspective_depth_batch = depth.clone()
+ intrinsics_batch = intrinsics.clone()
+
+ # Check if inputs are batched
+ is_batched = perspective_depth_batch.dim() == 4 or (
+ perspective_depth_batch.dim() == 3 and perspective_depth_batch.shape[-1] != 1
+ )
+
+ # Track whether the last dimension was singleton
+ add_last_dim = False
+ if perspective_depth_batch.dim() == 4 and perspective_depth_batch.shape[-1] == 1:
+ add_last_dim = True
+ perspective_depth_batch = perspective_depth_batch.squeeze(dim=3) # (N, H, W, 1) -> (N, H, W)
+ if perspective_depth_batch.dim() == 3 and perspective_depth_batch.shape[-1] == 1:
+ add_last_dim = True
+ perspective_depth_batch = perspective_depth_batch.squeeze(dim=2) # (H, W, 1) -> (H, W)
+
+ if perspective_depth_batch.dim() == 2:
+ perspective_depth_batch = perspective_depth_batch[None] # (H, W) -> (1, H, W)
+
+ if intrinsics_batch.dim() == 2:
+ intrinsics_batch = intrinsics_batch[None] # (3, 3) -> (1, 3, 3)
+
+ if is_batched and intrinsics_batch.shape[0] == 1:
+ intrinsics_batch = intrinsics_batch.expand(perspective_depth_batch.shape[0], -1, -1) # (1, 3, 3) -> (N, 3, 3)
+
+ # Validate input shapes
+ if perspective_depth_batch.dim() != 3:
+ raise ValueError(f"Expected depth images to have 2, 3, or 4 dimensions; got {depth.shape}.")
+ if intrinsics_batch.dim() != 3:
+ raise ValueError(f"Expected intrinsics to have shape (3, 3) or (N, 3, 3); got {intrinsics.shape}.")
+
+ # Image dimensions
+ im_height, im_width = perspective_depth_batch.shape[1:]
+
+ # Get the intrinsics parameters
+ fx = intrinsics_batch[:, 0, 0].view(-1, 1, 1)
+ fy = intrinsics_batch[:, 1, 1].view(-1, 1, 1)
+ cx = intrinsics_batch[:, 0, 2].view(-1, 1, 1)
+ cy = intrinsics_batch[:, 1, 2].view(-1, 1, 1)
+
+ # Create meshgrid of pixel coordinates
+ u_grid = torch.arange(im_width, device=depth.device, dtype=depth.dtype)
+ v_grid = torch.arange(im_height, device=depth.device, dtype=depth.dtype)
+ u_grid, v_grid = torch.meshgrid(u_grid, v_grid, indexing="xy")
+
+ # Expand the grids for batch processing
+ u_grid = u_grid.unsqueeze(0).expand(perspective_depth_batch.shape[0], -1, -1)
+ v_grid = v_grid.unsqueeze(0).expand(perspective_depth_batch.shape[0], -1, -1)
+
+ # Compute the squared terms for efficiency
+ x_term = ((u_grid - cx) / fx) ** 2
+ y_term = ((v_grid - cy) / fy) ** 2
+
+ # Calculate the orthogonal (normal) depth
+ orthogonal_depth = perspective_depth_batch / torch.sqrt(1 + x_term + y_term)
+
+ # Restore the last dimension if it was present in the input
+ if add_last_dim:
+ orthogonal_depth = orthogonal_depth.unsqueeze(-1)
+
+ # Return to original shape if input was not batched
+ if not is_batched:
+ orthogonal_depth = orthogonal_depth.squeeze(0)
+
+ return orthogonal_depth
+
+
+@torch.jit.script
+def unproject_depth(depth: torch.Tensor, intrinsics: torch.Tensor, is_ortho: bool = True) -> torch.Tensor:
+ r"""Un-project depth image into a pointcloud.
+
+ This function converts orthogonal or perspective depth images into points given the calibration matrix
+ of the camera. It uses the following transformation based on camera geometry:
+
+ .. math::
+ p_{3D} = K^{-1} \times [u, v, 1]^T \times d
+
+ where :math:`p_{3D}` is the 3D point, :math:`d` is the depth value (measured from the image plane),
+ :math:`u` and :math:`v` are the pixel coordinates and :math:`K` is the intrinsic matrix.
+
+    The function assumes that the width and height are both greater than 1. This allows the function
+    to handle many possible shapes of depth images and intrinsics matrices.
+
+ .. note::
+ If :attr:`is_ortho` is False, the input depth images are transformed to orthogonal depth images
+ by using the :meth:`orthogonalize_perspective_depth` method.
+
+ Args:
+        depth: The depth measurement. Shape is (H, W) or (H, W, 1) or (N, H, W) or (N, H, W, 1).
+ intrinsics: The camera's calibration matrix. If a single matrix is provided, the same
+ calibration matrix is used across all the depth images in the batch.
+ Shape is (3, 3) or (N, 3, 3).
+ is_ortho: Whether the input depth image is orthogonal or perspective depth image. If True, the input
+ depth image is considered as the *orthogonal* type, where the measurements are from the camera's
+ image plane. If False, the depth image is considered as the *perspective* type, where the
+ measurements are from the camera's optical center. Defaults to True.
+
+ Returns:
+ The 3D coordinates of points. Shape is (P, 3) or (N, P, 3).
+
+ Raises:
+ ValueError: When depth is not of shape (H, W) or (H, W, 1) or (N, H, W) or (N, H, W, 1).
+ ValueError: When intrinsics is not of shape (3, 3) or (N, 3, 3).
+ """
+ # clone inputs to avoid in-place modifications
+ intrinsics_batch = intrinsics.clone()
+ # convert depth image to orthogonal if needed
+ if not is_ortho:
+ depth_batch = orthogonalize_perspective_depth(depth, intrinsics)
+ else:
+ depth_batch = depth.clone()
+
+ # check if inputs are batched
+ is_batched = depth_batch.dim() == 4 or (depth_batch.dim() == 3 and depth_batch.shape[-1] != 1)
+ # make sure inputs are batched
+ if depth_batch.dim() == 3 and depth_batch.shape[-1] == 1:
+ depth_batch = depth_batch.squeeze(dim=2) # (H, W, 1) -> (H, W)
+ if depth_batch.dim() == 2:
+ depth_batch = depth_batch[None] # (H, W) -> (1, H, W)
+ if depth_batch.dim() == 4 and depth_batch.shape[-1] == 1:
+ depth_batch = depth_batch.squeeze(dim=3) # (N, H, W, 1) -> (N, H, W)
+ if intrinsics_batch.dim() == 2:
+ intrinsics_batch = intrinsics_batch[None] # (3, 3) -> (1, 3, 3)
+ # check shape of inputs
+ if depth_batch.dim() != 3:
+ raise ValueError(f"Expected depth images to have dim = 2 or 3 or 4: got shape {depth.shape}")
+ if intrinsics_batch.dim() != 3:
+ raise ValueError(f"Expected intrinsics to have shape (3, 3) or (N, 3, 3): got shape {intrinsics.shape}")
+
+ # get image height and width
+ im_height, im_width = depth_batch.shape[1:]
+ # create image points in homogeneous coordinates (3, H x W)
+ indices_u = torch.arange(im_width, device=depth.device, dtype=depth.dtype)
+ indices_v = torch.arange(im_height, device=depth.device, dtype=depth.dtype)
+ img_indices = torch.stack(torch.meshgrid([indices_u, indices_v], indexing="ij"), dim=0).reshape(2, -1)
+ pixels = torch.nn.functional.pad(img_indices, (0, 0, 0, 1), mode="constant", value=1.0)
+ pixels = pixels.unsqueeze(0) # (3, H x W) -> (1, 3, H x W)
+
+ # unproject points into 3D space
+ points = torch.matmul(torch.inverse(intrinsics_batch), pixels) # (N, 3, H x W)
+ points = points / points[:, -1, :].unsqueeze(1) # normalize by last coordinate
+ # flatten depth image (N, H, W) -> (N, H x W)
+ depth_batch = depth_batch.transpose_(1, 2).reshape(depth_batch.shape[0], -1).unsqueeze(2)
+ depth_batch = depth_batch.expand(-1, -1, 3)
+ # scale points by depth
+ points_xyz = points.transpose_(1, 2) * depth_batch # (N, H x W, 3)
+
+ # return points in same shape as input
+ if not is_batched:
+ points_xyz = points_xyz.squeeze(0)
+
+ return points_xyz
+
+
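+# NOTE: Editorial usage sketch, not part of the upstream module. Unprojecting an orthogonal
+# depth image yields one 3D point per pixel whose z-coordinate equals the image-plane depth;
+# `_example_unproject_depth` is a hypothetical name and the intrinsics values are arbitrary.
+def _example_unproject_depth() -> None:
+    intrinsics = torch.tensor([[100.0, 0.0, 32.0], [0.0, 100.0, 24.0], [0.0, 0.0, 1.0]])
+    depth = torch.full((48, 64), 2.0)  # (H, W) orthogonal depth image
+    points = unproject_depth(depth, intrinsics, is_ortho=True)  # (H * W, 3)
+    assert points.shape == (48 * 64, 3)
+    # all unprojected points lie at z = 2.0, the image-plane depth
+    assert torch.allclose(points[:, 2], torch.full((48 * 64,), 2.0), atol=1e-5)
+
+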
+@torch.jit.script
+def project_points(points: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
+ r"""Projects 3D points into 2D image plane.
+
+    This projects 3D points into a 2D image plane. The transformation is defined by the intrinsic
+    matrix of the camera.
+
+ .. math::
+
+ \begin{align}
+ p &= K \times p_{3D} = \\
+ p_{2D} &= \begin{pmatrix} u \\ v \\ d \end{pmatrix}
+ = \begin{pmatrix} p[0] / p[2] \\ p[1] / p[2] \\ Z \end{pmatrix}
+ \end{align}
+
+ where :math:`p_{2D} = (u, v, d)` is the projected 3D point, :math:`p_{3D} = (X, Y, Z)` is the
+ 3D point and :math:`K \in \mathbb{R}^{3 \times 3}` is the intrinsic matrix.
+
+ If `points` is a batch of 3D points and `intrinsics` is a single intrinsic matrix, the same
+ calibration matrix is applied to all points in the batch.
+
+ Args:
+ points: The 3D coordinates of points. Shape is (P, 3) or (N, P, 3).
+ intrinsics: Camera's calibration matrix. Shape is (3, 3) or (N, 3, 3).
+
+ Returns:
+ Projected 3D coordinates of points. Shape is (P, 3) or (N, P, 3).
+ """
+ # clone the inputs to avoid in-place operations modifying the original data
+ points_batch = points.clone()
+ intrinsics_batch = intrinsics.clone()
+
+ # check if inputs are batched
+ is_batched = points_batch.dim() == 2
+ # make sure inputs are batched
+ if points_batch.dim() == 2:
+ points_batch = points_batch[None] # (P, 3) -> (1, P, 3)
+ if intrinsics_batch.dim() == 2:
+ intrinsics_batch = intrinsics_batch[None] # (3, 3) -> (1, 3, 3)
+ # check shape of inputs
+ if points_batch.dim() != 3:
+ raise ValueError(f"Expected points to have dim = 3: got shape {points.shape}.")
+ if intrinsics_batch.dim() != 3:
+ raise ValueError(f"Expected intrinsics to have shape (3, 3) or (N, 3, 3): got shape {intrinsics.shape}.")
+
+ # project points into 2D image plane
+ points_2d = torch.matmul(intrinsics_batch, points_batch.transpose(1, 2))
+ points_2d = points_2d / points_2d[:, -1, :].unsqueeze(1) # normalize by last coordinate
+ points_2d = points_2d.transpose_(1, 2) # (N, 3, P) -> (N, P, 3)
+ # replace last coordinate with depth
+ points_2d[:, :, -1] = points_batch[:, :, -1]
+
+ # return points in same shape as input
+ if not is_batched:
+        points_2d = points_2d.squeeze(0)  # (1, P, 3) -> (P, 3)
+
+ return points_2d
+
+
+"""
+Sampling
+"""
+
+
+@torch.jit.script
+def default_orientation(num: int, device: str) -> torch.Tensor:
+ """Returns identity rotation transform.
+
+ Args:
+ num: The number of rotations to sample.
+ device: Device to create tensor on.
+
+ Returns:
+ Identity quaternion in (w, x, y, z). Shape is (num, 4).
+ """
+ quat = torch.zeros((num, 4), dtype=torch.float, device=device)
+ quat[..., 0] = 1.0
+
+ return quat
+
+
+@torch.jit.script
+def random_orientation(num: int, device: str) -> torch.Tensor:
+ """Returns sampled rotation in 3D as quaternion.
+
+ Args:
+ num: The number of rotations to sample.
+ device: Device to create tensor on.
+
+ Returns:
+ Sampled quaternion in (w, x, y, z). Shape is (num, 4).
+
+ Reference:
+ https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.random.html
+ """
+ # sample random orientation from normal distribution
+ quat = torch.randn((num, 4), dtype=torch.float, device=device)
+ # normalize the quaternion
+ return torch.nn.functional.normalize(quat, p=2.0, dim=-1, eps=1e-12)
+
+
+@torch.jit.script
+def random_yaw_orientation(num: int, device: str) -> torch.Tensor:
+ """Returns sampled rotation around z-axis.
+
+ Args:
+ num: The number of rotations to sample.
+ device: Device to create tensor on.
+
+ Returns:
+ Sampled quaternion in (w, x, y, z). Shape is (num, 4).
+ """
+ roll = torch.zeros(num, dtype=torch.float, device=device)
+ pitch = torch.zeros(num, dtype=torch.float, device=device)
+ yaw = 2 * torch.pi * torch.rand(num, dtype=torch.float, device=device)
+
+ return quat_from_euler_xyz(roll, pitch, yaw)
+
+
+def sample_triangle(lower: float, upper: float, size: int | tuple[int, ...], device: str) -> torch.Tensor:
+    """Randomly samples tensor from a triangular distribution.
+
+    Args:
+        lower: The lower range of the sampled tensor.
+        upper: The upper range of the sampled tensor.
+        size: The shape of the tensor.
+        device: Device to create tensor on.
+
+    Returns:
+        Sampled tensor. Shape is based on :attr:`size`.
+    """
+    # convert to tuple
+    if isinstance(size, int):
+        size = (size,)
+    # create random tensor in the range [-1, 1]
+    r = 2 * torch.rand(*size, device=device) - 1
+    # convert to triangular distribution
+    r = torch.where(r < 0.0, -torch.sqrt(-r), torch.sqrt(r))
+    # rescale back to [0, 1]
+    r = (r + 1.0) / 2.0
+    # rescale to range [lower, upper]
+    return (upper - lower) * r + lower
+
+
+def sample_gaussian(
+    mean: torch.Tensor | float, std: torch.Tensor | float, size: int | tuple[int, ...], device: str
+) -> torch.Tensor:
+    """Sample using gaussian distribution.
+
+    Args:
+        mean: Mean of the gaussian.
+        std: Std of the gaussian.
+        size: The shape of the tensor.
+        device: Device to create tensor on.
+
+    Returns:
+        Sampled tensor.
+    """
+    if isinstance(mean, float):
+        if isinstance(size, int):
+            size = (size,)
+        return torch.normal(mean=mean, std=std, size=size).to(device=device)
+    else:
+        return torch.normal(mean=mean, std=std).to(device=device)
+
+
+def sample_cylinder(
+    radius: float, h_range: tuple[float, float], size: int | tuple[int, ...], device: str
+) -> torch.Tensor:
+    """Sample 3D points uniformly on a cylinder's surface.
+
+    The cylinder is centered at the origin and aligned with the z-axis. The height of the cylinder is
+    sampled uniformly from the range :obj:`h_range`, while the radius is fixed to :obj:`radius`.
+
+    The sampled points are returned as a tensor of shape :obj:`(*size, 3)`, i.e. the last dimension
+    contains the x, y, and z coordinates of the sampled points.
+
+    Args:
+        radius: The radius of the cylinder.
+        h_range: The minimum and maximum height of the cylinder.
+        size: The shape of the tensor.
+        device: Device to create tensor on.
+
+    Returns:
+        Sampled tensor. Shape is :obj:`(*size, 3)`.
+    """
+    # sample angles
+    angles = (torch.rand(size, device=device) * 2 - 1) * torch.pi
+    h_min, h_max = h_range
+    # add shape
+    if isinstance(size, int):
+        size = (size, 3)
+    else:
+        size += (3,)
+    # allocate a tensor
+    xyz = torch.zeros(size, device=device)
+    xyz[..., 0] = radius * torch.cos(angles)
+    xyz[..., 1] = radius * torch.sin(angles)
+    xyz[..., 2].uniform_(h_min, h_max)
+    # return positions
+    return xyz
+
+
+"""
+Orientation Conversions
+"""
+
+
+def convert_camera_frame_orientation_convention(
+    orientation: torch.Tensor,
+    origin: Literal["opengl", "ros", "world"] = "opengl",
+    target: Literal["opengl", "ros", "world"] = "ros",
+) -> torch.Tensor:
+    r"""Converts a quaternion representing a rotation from one convention to another.
+
+    In USD, the camera follows the ``"opengl"`` convention. Thus, it is always in **Y up** convention.
+    This means that the camera is looking down the -Z axis with the +Y axis pointing up, and +X axis pointing right.
+    However, in ROS, the camera is looking down the +Z axis with the +Y axis pointing down, and +X axis pointing right.
+    Thus, the camera needs to be rotated by :math:`180^{\circ}` around the X axis to follow the ROS convention.
+
+    .. math::
+
+        T_{ROS} = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & -1 & 0 & 0 \\ 0 & 0 & -1 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} T_{USD}
+
+    On the other hand, the typical world coordinate system is with +X pointing forward, +Y pointing left,
+    and +Z pointing up. The camera can also be set in this convention by rotating the camera by :math:`90^{\circ}`
+    around the X axis and :math:`-90^{\circ}` around the Y axis.
+
+    .. math::
+
+        T_{WORLD} = \begin{bmatrix} 0 & 0 & -1 & 0 \\ -1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} T_{USD}
+
+    Thus, based on their application, cameras follow different conventions for their orientation. This function
+    converts a quaternion from one convention to another.
+
+    Possible conventions are:
+
+    - :obj:`"opengl"` - forward axis: -Z - up axis: +Y - Offset is applied in the OpenGL (Usd.Camera) convention.
+    - :obj:`"ros"` - forward axis: +Z - up axis: -Y - Offset is applied in the ROS convention.
+    - :obj:`"world"` - forward axis: +X - up axis: +Z - Offset is applied in the World Frame convention.
+
+    Args:
+        orientation: Quaternion of form `(w, x, y, z)` with shape (..., 4) in source convention.
+        origin: Convention to convert from. Defaults to "opengl".
+        target: Convention to convert to. Defaults to "ros".
+
+    Returns:
+        Quaternion of form `(w, x, y, z)` with shape (..., 4) in target convention.
+    """
+    if target == origin:
+        return orientation.clone()
+
+    # -- unify input type
+    if origin == "ros":
+        # convert from ros to opengl convention
+        rotm = matrix_from_quat(orientation)
+        rotm[:, :, 2] = -rotm[:, :, 2]
+        rotm[:, :, 1] = -rotm[:, :, 1]
+        # convert to opengl convention
+        quat_gl = quat_from_matrix(rotm)
+    elif origin == "world":
+        # convert from world (x forward and z up) to opengl convention
+        rotm = matrix_from_quat(orientation)
+        rotm = torch.matmul(
+            rotm,
+            matrix_from_euler(torch.tensor([math.pi / 2, -math.pi / 2, 0], device=orientation.device), "XYZ"),
+        )
+        # convert to isaac-sim convention
+        quat_gl = quat_from_matrix(rotm)
+    else:
+        quat_gl = orientation
+
+    # -- convert to target convention
+    if target == "ros":
+        # convert from opengl to ros convention
+        rotm = matrix_from_quat(quat_gl)
+        rotm[:, :, 2] = -rotm[:, :, 2]
+        rotm[:, :, 1] = -rotm[:, :, 1]
+        return quat_from_matrix(rotm)
+    elif target == "world":
+        # convert from opengl to world (x forward and z up) convention
+        rotm = matrix_from_quat(quat_gl)
+        rotm = torch.matmul(
+            rotm,
+            matrix_from_euler(torch.tensor([math.pi / 2, -math.pi / 2, 0], device=orientation.device), "XYZ").T,
+        )
+        return quat_from_matrix(rotm)
+    else:
+        return quat_gl.clone()
+
+
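+# NOTE: Editorial usage sketch, not part of the upstream module. Converting a quaternion to
+# another convention and back should be a no-op (up to quaternion sign), which is checked
+# here via the angular error; `_example_camera_convention` is a hypothetical name.
+def _example_camera_convention() -> None:
+    # identity orientation in the "world" convention (+X forward, +Z up)
+    q_world = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
+    q_gl = convert_camera_frame_orientation_convention(q_world, origin="world", target="opengl")
+    # converting back is the inverse operation
+    q_back = convert_camera_frame_orientation_convention(q_gl, origin="opengl", target="world")
+    assert torch.allclose(quat_error_magnitude(q_back, q_world), torch.zeros(1), atol=1e-5)
+
+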
+def create_rotation_matrix_from_view(
+    eyes: torch.Tensor,
+    targets: torch.Tensor,
+    up_axis: Literal["Y", "Z"] = "Z",
+    device: str = "cpu",
+) -> torch.Tensor:
+    """Compute the rotation matrix from world to view coordinates.
+
+    This function takes a vector ``eyes`` which specifies the location
+    of the camera in world coordinates and the vector ``targets`` which
+    indicates the position of the object.
+    The output is a rotation matrix representing the transformation
+    from world coordinates -> view coordinates.
+
+    The inputs ``eyes`` and ``targets`` can each be a
+
+    - 3 element tuple/list
+    - torch tensor of shape (1, 3)
+    - torch tensor of shape (N, 3)
+
+    The vectors are broadcast against each other so they all have shape (N, 3).
+
+    Args:
+        eyes: Position of the camera in world coordinates.
+        targets: Position of the object in world coordinates.
+        up_axis: The up axis of the camera. Defaults to "Z".
+        device: The device to create torch tensors on. Defaults to "cpu".
+
+    Returns:
+        R: (N, 3, 3) batched rotation matrices.
+
+    Reference:
+        Based on PyTorch3D (https://github.com/facebookresearch/pytorch3d/blob/eaf0709d6af0025fe94d1ee7cec454bc3054826a/pytorch3d/renderer/cameras.py#L1635-L1685)
+    """
+    if up_axis == "Y":
+        up_axis_vec = torch.tensor((0, 1, 0), device=device, dtype=torch.float32).repeat(eyes.shape[0], 1)
+    elif up_axis == "Z":
+        up_axis_vec = torch.tensor((0, 0, 1), device=device, dtype=torch.float32).repeat(eyes.shape[0], 1)
+    else:
+        raise ValueError(f"Invalid up axis: {up_axis}. Valid options are 'Y' and 'Z'.")
+
+    # get rotation matrix in opengl format (-Z forward, +Y up)
+    z_axis = -torch.nn.functional.normalize(targets - eyes, eps=1e-5)
+    x_axis = torch.nn.functional.normalize(torch.cross(up_axis_vec, z_axis, dim=1), eps=1e-5)
+    y_axis = torch.nn.functional.normalize(torch.cross(z_axis, x_axis, dim=1), eps=1e-5)
+    is_close = torch.isclose(x_axis, torch.tensor(0.0), atol=5e-3).all(dim=1, keepdim=True)
+    if is_close.any():
+        replacement = torch.nn.functional.normalize(torch.cross(y_axis, z_axis, dim=1), eps=1e-5)
+        x_axis = torch.where(is_close, replacement, x_axis)
+    R = torch.cat((x_axis[:, None, :], y_axis[:, None, :], z_axis[:, None, :]), dim=1)
+    return R.transpose(1, 2)
+
+
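+# NOTE: Editorial usage sketch, not part of the upstream module. A camera at the origin
+# looking down +X with Z up yields a rotation whose third column (the OpenGL backward axis)
+# is the negated viewing direction; `_example_look_at` is a hypothetical name.
+def _example_look_at() -> None:
+    eyes = torch.tensor([[0.0, 0.0, 0.0]])
+    targets = torch.tensor([[1.0, 0.0, 0.0]])
+    R = create_rotation_matrix_from_view(eyes, targets, up_axis="Z")
+    # z_axis = -normalize(target - eye) = (-1, 0, 0)
+    assert torch.allclose(R[0, :, 2], torch.tensor([-1.0, 0.0, 0.0]), atol=1e-5)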
+