diff --git a/custom/threestudio-dreammesh4d/utils/dual_quaternions.py b/custom/threestudio-dreammesh4d/utils/dual_quaternions.py index 27dc810..edd4d01 100644 --- a/custom/threestudio-dreammesh4d/utils/dual_quaternions.py +++ b/custom/threestudio-dreammesh4d/utils/dual_quaternions.py @@ -6,8 +6,10 @@ import json import torch import pypose as pp -from pypose.lietensor.lietensor import SO3Type as Quaternion +# from pypose.lietensor.lietensor import SO3Type as Quaternion from pypose.lietensor.lietensor import LieType, SO3Type +from pyquaternion import Quaternion + def quat_norm(quat): norm = quat.norm(dim=-1, keepdim=True) @@ -16,6 +18,9 @@ def quat_norm(quat): def quat_conjugate(quat: Quaternion): return quat.Inv() +def quat_add(quat_1: Quaternion, quat_2: Quaternion): + return pp.SO3(quat_1.tensor() + quat_2.tensor()) + class DualQuaternion(object): def __init__(self, q_r: Quaternion, q_d: Quaternion, normalize=False): @@ -32,7 +37,7 @@ def __init__(self, q_r: Quaternion, q_d: Quaternion, normalize=False): def __mul__(self, other): q_r_prod = self.q_r * other.q_r - q_d_prod = self.q_r * other.q_d + self.q_d * other.q_r + q_d_prod = quat_add(self.q_r * other.q_d, self.q_d * other.q_r) return DualQuaternion(q_r_prod, q_d_prod) def __imul__(self, other): @@ -51,7 +56,8 @@ def __truediv__(self, other): return DualQuaternion(prod_r, prod_d) def __add__(self, other): - return DualQuaternion(self.q_r + other.q_r, self.q_d + other.q_d) + # return DualQuaternion(self.q_r + other.q_r, self.q_d + other.q_d) + return DualQuaternion(quat_add(self.q_r, other.q_r), quat_add(self.q_d, other.q_d)) def __eq__(self, other): return (self.q_r == other.q_r or self.q_r == -other.q_r) \ @@ -78,8 +84,12 @@ def transform_point(self, point_xyz): ) dq_point = DualQuaternion.from_dq_array(dq_point_array) res_dq = self * dq_point * self.quaternion_conjugate() - - return res_dq.translation + p = res_dq.q_d.tensor()[..., :3] + + # following https://github.com/neka-nat/dq3d/blob/master/dq3d/DualQuaternion.h#L269 + # add translation to the transformed point + p += self.translation + return p def transform_point_simple(self, point_xyz): """ @@ -139,6 +149,10 @@ def identity(cls, dq_size): dual_part = pp.SO3(torch.zeros_like(real_part)) return cls(real_part, dual_part) + @property + def conjugate(self): + return self.quaternion_conjugate() + def quaternion_conjugate(self): return DualQuaternion(quat_conjugate(self.q_r), quat_conjugate(self.q_d)) @@ -154,7 +168,8 @@ def inverse(self): def is_normalized(self): return self.q_r.norm(dim=-1).allclose(torch.as_tensor(1.)) and \ - (self.q_r * quat_conjugate(self.q_d) + self.q_d * quat_conjugate(self.q_r)).allclose(torch.as_tensor(0.)) + quat_add(self.q_r * quat_conjugate(self.q_d), self.q_d * quat_conjugate(self.q_r)).allclose(torch.as_tensor(0.)) + # ((self.q_r * quat_conjugate(self.q_d)).tensor() + (self.q_d * quat_conjugate(self.q_r)).tensor()).allclose(torch.as_tensor(0.)) def normalize(self): """ @@ -242,13 +257,13 @@ def homogeneous_matrix(self): h_mat = self.q_r.matrix() # 3x3 h_mat = torch.cat([h_mat, self.translation.unsqueeze(-1)], dim=-1) h_mat = torch.cat([h_mat, torch.zeros_like(h_mat[..., :1, :])], dim=-2) - h_mat[..., -1] = 1 + h_mat[..., -1, -1] = 1 return h_mat def quat_pose_array(self): return torch.cat([self.q_r.tensor(), self.translation], dim=-1) def dq_array(self): - return torch.cat([self.q_r, self.q_d], dim=-1) + return torch.cat([self.q_r.tensor(), self.q_d.tensor()], dim=-1)