Skip to content

Commit

Permalink
Merge pull request pypose#284 from pypose/lie_op_vmap
Browse files Browse the repository at this point in the history
Allow some LieTensor operation workable under `func.vmap`
  • Loading branch information
zitongzhan authored Oct 6, 2023
2 parents bcc523c + 6d0edf4 commit c2c5fbb
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
6 changes: 2 additions & 4 deletions pypose/basics/ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math, torch


def pm(input, *, out=None):
def pm(input):
r'''
Returns plus or minus (:math:`\pm`) states for tensor.
Expand All @@ -21,9 +21,7 @@ def pm(input, *, out=None):
>>> pp.pm(torch.tensor([0.1, 0, -0.2], dtype=torch.float64))
tensor([ 1., 1., -1.], dtype=torch.float64)
'''
out = torch.sign(input, out=None)
out[out==0] = 1
return out
return torch.sign(torch.sign(input) * 2 + 1)


def cumops_(input, dim, ops):
Expand Down
10 changes: 5 additions & 5 deletions pypose/lietensor/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ def SE3_Adj(X):


def SE3_Matrix(X):
T = torch.eye(4, device=X.device, dtype=X.dtype, requires_grad=False).repeat(X.shape[:-1]+(1, 1))
T[..., :3, :3] = SO3_Matrix(X[..., 3:])
T[..., :3, 3] = X[..., :3]
T = torch.cat([SO3_Matrix(X[..., 3:]), X[..., :3, None]], dim=-1)
E = torch.tensor([0, 0, 0, 1], dtype=T.dtype, device=T.device)
T = torch.cat([T, E.repeat(X.shape[:-1]+(1, 1))], dim=-2)
return T


Expand Down Expand Up @@ -470,7 +470,7 @@ def backward(ctx, grad_output):


class SO3_Act(torch.autograd.Function):

generate_vmap_rule = True
@staticmethod
def forward(X, p):
Xv, Xw = X[..., :3], X[..., 3:]
Expand Down Expand Up @@ -498,7 +498,7 @@ def backward(ctx, grad_output):


class SE3_Act(torch.autograd.Function):

generate_vmap_rule = True
@staticmethod
def forward(X, p):
out = X[..., :3] + SO3_Act.apply(X[..., 3:], p)
Expand Down

0 comments on commit c2c5fbb

Please sign in to comment.