diff --git a/pypose/__init__.py b/pypose/__init__.py index bb6957912..121f6dbbd 100644 --- a/pypose/__init__.py +++ b/pypose/__init__.py @@ -12,7 +12,8 @@ from .lietensor import Sim3_type, sim3_type, RxSO3_type, rxso3_type from .lietensor import tensor, translation, rotation, scale, matrix, euler from .lietensor import mat2SO3, mat2SE3, mat2Sim3, mat2RxSO3, from_matrix, matrix, euler2SO3, vec2skew -from .func import * +from .lietensor.lietensor import retain_ltype +from . import func from .function import * from .basics import * from . import module diff --git a/pypose/func/jac.py b/pypose/func/jac.py index 4a5ca2ea4..29e95312c 100644 --- a/pypose/func/jac.py +++ b/pypose/func/jac.py @@ -1,6 +1,6 @@ import torch +from .. import retain_ltype from typing import Callable, Union, Tuple, Optional -from pypose.lietensor.lietensor import retain_ltype def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False, @@ -36,7 +36,7 @@ def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False instead returns a ``(jacobian, aux)`` tuple where ``jacobian`` is the Jacobian and ``aux`` is auxiliary objects returned by ``func``. - A basic usage with our LieTensor type would be the transformation function. + A basic usage with our LieTensor type would be the transformation function. >>> import pypose as pp >>> import torch diff --git a/pypose/lietensor/lietensor.py b/pypose/lietensor/lietensor.py index 44bc02c64..42a818841 100644 --- a/pypose/lietensor/lietensor.py +++ b/pypose/lietensor/lietensor.py @@ -1,9 +1,9 @@ import torch from torch import nn -from contextlib import contextmanager from .basics import vec2skew -import collections, numbers, warnings, importlib +from contextlib import contextmanager from .operation import broadcast_inputs +import collections, numbers, warnings, importlib from torch.utils._pytree import tree_map, tree_flatten from .operation import SO3_Log, SE3_Log, RxSO3_Log, Sim3_Log from .operation import so3_Exp, se3_Exp, rxso3_Exp, sim3_Exp diff --git a/tests/optim/test_jacobian.py b/tests/optim/test_jacobian.py index e85852ca1..6fdd751c3 100644 --- a/tests/optim/test_jacobian.py +++ b/tests/optim/test_jacobian.py @@ -4,12 +4,11 @@ import pypose as pp from torch import nn from contextlib import contextmanager +from typing import Collection, Callable from torch.utils._pytree import tree_map from torch.autograd.functional import jacobian from torch.func import functional_call, jacfwd, jacrev -from typing import Collection, Callable -from pypose.lietensor.lietensor import retain_ltype device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -169,7 +168,7 @@ def assert_fn_equal(func1, func2): } with check_fn_equal(TO_BE_CHECKED): - with retain_ltype(): + with pp.retain_ltype(): jac_func = jacrev(func) jac = jac_func(pose, points) assert not pp.hasnan(jac)