Skip to content

Commit

Permalink
add testing.assert_close
Browse files Browse the repository at this point in the history
* fix bugs in test_spline.py

Signed-off-by: Xiangyu Chen <[email protected]>

* add testing

Signed-off-by: Xiangyu Chen <[email protected]>

* add testing.rst

Signed-off-by: Xiangyu Chen <[email protected]>

* add testing

Signed-off-by: Xiangyu Chen <[email protected]>

* Update __init__.py

Signed-off-by: Xiangyu Chen <[email protected]>

* Update test_spline.py

Signed-off-by: Xiangyu Chen <[email protected]>

* Update test_spline.py

Signed-off-by: Xiangyu Chen <[email protected]>

* rename the assert_clase

* Update test_spline.py

---------

Signed-off-by: Xiangyu Chen <[email protected]>
Co-authored-by: Chen Wang <[email protected]>
Co-authored-by: Chen Wang <[email protected]>
  • Loading branch information
3 people authored Aug 19, 2023
1 parent 6421ae2 commit c71c6fb
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 36 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ PyPose Documentation
modules
optim
utils
testing


Indices and tables
Expand Down
11 changes: 11 additions & 0 deletions docs/source/testing.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Testing
=========


.. currentmodule:: pypose
.. autosummary::
:toctree: generated
:nosignatures:

:template: autosummary/class-no-inherit.rst
testing.assert_close
1 change: 1 addition & 0 deletions pypose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .basics import *
from . import module
from . import optim
from . import testing


min_torch = '2.0'
Expand Down
10 changes: 5 additions & 5 deletions pypose/basics/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def pm(input, *, out=None):
Args:
input (:obj:`Tensor`): the input tensor.
Return:
:obj:`Tensor`: the output tensor contains only :math:`-1` or :math:`+1`.
Expand Down Expand Up @@ -103,7 +103,7 @@ def cummul(input, dim, left = True):
.. math::
y_i = x_1 * x_2 * \cdots * x_i,
where :math:`x_i,~y_i` are the :math:`i`-th LieType item along the :obj:`dim`
dimension of input and output, respectively.
Expand All @@ -122,7 +122,7 @@ def cummul(input, dim, left = True):
:math:`N` is the LieTensor size along the :obj:`dim` dimension.
Example:
* Left multiplication with :math:`\text{input} \in` :obj:`SE3`
>>> input = pp.randn_SE3(2)
Expand All @@ -141,7 +141,7 @@ def cummul(input, dim, left = True):
"""
if left:
return cumops(input, dim, lambda a, b : a * b)
else:
else:
return cumops(input, dim, lambda a, b : b * a)


Expand All @@ -152,7 +152,7 @@ def cumprod(input, dim, left = True):
.. math::
y_i = x_i ~\times~ x_{i-1} ~\times~ \cdots ~\times~ x_1,
* Right product:
.. math::
Expand Down
1 change: 1 addition & 0 deletions pypose/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .comparison import *
42 changes: 42 additions & 0 deletions pypose/testing/comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch
from ..function.checking import is_lietensor


def assert_close(actual, expected, *args, **kwargs):
'''
Asserts that ``actual`` and ``expected`` are close. This function is exactly the same
with `torch.testing.assert_close <https://tinyurl.com/3bm33ps7>`_ except for that it
also accepts ``pypose.LieTensor``.
Args:
actual(:obj:`Tensor` or :obj:`LieTensor`): Actual input.
expected(:obj:`Tensor` or :obj:`LieTensor`): Expected input.
rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be
specified. If omitted, default values based on the :attr:`~torch.Tensor.dtype`
are selected with the below table.
atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be
specified. If omitted, default values based on the :attr:`~torch.Tensor.dtype`
are selected with the below table.
If :math:`T_e` and :math:`T_a` are Lietensor, they are considered close if
:math:`T_e*T_a^{-1} = \mathbf{O}`, where :math:`\mathbf{O}` is close to zero tensor in
the sense of ``torch.testing.assert_close`` is ``True.``
Warning:
The prerequisites for the other arguments align precisely with
`torch.testing.assert_close <https://tinyurl.com/3bm33ps7>`_. Kindly consult it
for further details.
Examples:
>>> import pypose as pp
>>> actual = pp.randn_SE3(3)
>>> expected = actual.Log().Exp()
>>> pp.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
'''
if is_lietensor(actual) and is_lietensor(expected):
source = (actual.Inv() @ expected).Log().tensor()
target = torch.zeros_like(source)
else:
source, target = actual, expected
torch.testing.assert_close(source, target, *args, **kwargs)
59 changes: 28 additions & 31 deletions tests/function/test_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,51 +6,45 @@ class TestSpline:

def test_bsplilne(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# test two poses
data = pp.randn_SE3(2, device=device)
data = pp.randn_SE3(2,device=device)
poses = pp.bspline(data, 0.10, True)
torch.testing.assert_close(poses[...,[0,-1],:].translation(),
data[...,[0,-1],:].translation())
torch.testing.assert_close(poses[...,[0,-1],:].rotation(),
data[...,[0,-1],:].rotation())
pp.testing.assert_close(poses[...,[0,-1],:], data[...,[0,-1],:])

# test for multi batch
data = pp.randn_SE3(2,5, device=device)
poses = pp.bspline(data, 0.5)
assert poses.lshape[-1] == 2 * (data.lshape[-1]-3)+1
poses = pp.bspline(data, 0.5, True)
torch.testing.assert_close(poses[...,[0,-1],:].translation(),
data[...,[0,-1],:].translation())
torch.testing.assert_close(poses[...,[0,-1],:].rotation(),
data[...,[0,-1],:].rotation())
pp.testing.assert_close(poses[...,[0,-1],:], data[...,[0,-1],:])

# test for high dimension
data = pp.randn_SE3(2, 3, 4, device=device)
poses = pp.bspline(data, 0.20)
data = pp.randn_SE3(2,3,4, device=device)
poses = pp.bspline(data, 0.2)
assert poses.lshape[-1] == 5 * (data.lshape[-1]-3)+1
poses = pp.bspline(data, 0.5, True)
torch.testing.assert_close(poses[...,[0,-1],:].translation(),
data[...,[0,-1],:].translation())
torch.testing.assert_close(poses[...,[0,-1],:].rotation(),
data[...,[0,-1],:].rotation())
poses = pp.bspline(data, 0.2, True)
pp.testing.assert_close(poses[...,[0,-1],:], data[...,[0,-1],:])

data = pp.randn_SE3(2, 3, 4, device=device)
data = pp.randn_SE3(2,3,4, device=device)
poses = pp.bspline(data, 0.3)
assert poses.lshape[-1] == 4 * (data.lshape[-1]-3)+1
poses = pp.bspline(data, 0.5, True)
torch.testing.assert_close(poses[...,[0,-1],:].translation(),
data[...,[0,-1],:].translation())
torch.testing.assert_close(poses[...,[0,-1],:].rotation(),
data[...,[0,-1],:].rotation())
poses = pp.bspline(data, 0.3, True)
pp.testing.assert_close(poses[...,[0,-1],:], data[...,[0,-1],:])

def test_chspline(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# test for different point dimension
points = torch.randn(1, 2, 3, device=device)
points = torch.randn(1,2,3, device=device)
interval = .2
interpoints = pp.chspline(points, interval=interval)
num = points.shape[-2]
k = math.ceil(1.0 / interval)
index = k*torch.arange(0, num, device=device, dtype=torch.int64)
po = interpoints.index_select(-2, index)
assert (points - po).sum() < 1e-5
pp.testing.assert_close(points, po)

# test multi points
points = torch.randn(20, 3, device=device)
interval = 0.5
Expand All @@ -59,33 +53,36 @@ def test_chspline(self):
k = math.ceil(1.0 / interval)
index = k*torch.arange(0, num, device=device, dtype=torch.int64)
po = interpoints.index_select(-2, index)
assert (points - po).sum() < 1e-5
pp.testing.assert_close(points, po)

# test multi batches
points = torch.randn(3, 20, 3, device=device)
points = torch.randn(3,20,3, device=device)
interval = 0.4
interpoints = pp.chspline(points, interval=interval)
num = points.shape[-2]
k = math.ceil(1.0/interval)
index = k*torch.arange(0, num, device=device, dtype=torch.int64)
po = interpoints.index_select(-2, index)
assert (points - po).sum() < 1e-5
pp.testing.assert_close(points, po)

# test multi dim of points
points = torch.randn(2, 3, 50, 4, device=device)
points = torch.randn(2,3,50,4, device=device)
interval = 0.1
interpoints = pp.chspline(points, interval=interval)
num = points.shape[-2]
k = math.ceil(1.0 / interval)
index = k*torch.arange(0, num, device=device, dtype=torch.int64)
po = interpoints.index_select(-2, index)
assert (points - po).sum() < 1e-5
points = torch.randn(10, 2, 3, 50, 4, device=device)
pp.testing.assert_close(points, po)

points = torch.randn(10,2,3,50,4, device=device)
interval = 0.1
interpoints = pp.chspline(points, interval=interval)
num = points.shape[-2]
k = math.ceil(1.0 / interval)
index = k*torch.arange(0, num, device=device, dtype=torch.int64)
po = interpoints.index_select(-2, index)
assert (points - po).sum() < 1e-5
pp.testing.assert_close(points, po)


if __name__=="__main__":
Expand Down

0 comments on commit c71c6fb

Please sign in to comment.