diff --git a/docs/source/func.rst b/docs/source/func.rst
new file mode 100644
index 000000000..ce38a0f03
--- /dev/null
+++ b/docs/source/func.rst
@@ -0,0 +1,10 @@
+Func
+=======
+
+.. currentmodule:: pypose
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    func.jacrev
diff --git a/docs/source/index.rst b/docs/source/index.rst
index ec5a1f7d8..3a482ce94 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -8,6 +8,7 @@ PyPose Documentation
    :caption: Contents:
 
    lietensor
+   func
    functions
    convert
    modules
diff --git a/pypose/func/__init__.py b/pypose/func/__init__.py
index 35608f939..cd6602cb2 100644
--- a/pypose/func/__init__.py
+++ b/pypose/func/__init__.py
@@ -1,24 +1 @@
-from typing import Callable, Union, Tuple, List, Any, Optional
-import torch
-from functools import partial, wraps
-from torch._functorch.vmap import vmap, doesnt_support_saved_tensors_hooks, get_chunk_sizes
-from torch._functorch.eager_transforms import error_if_complex, _slice_argnums, \
-    _chunked_standard_basis_for_, _safe_zero_index, _vjp_with_argnums
-from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map, tree_map_only
-
-from pypose.lietensor.lietensor import retain_ltype
-
-
-def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
-           chunk_size: Optional[int] = None,
-           _preallocate_and_copy=False):
-    """
-    This function provides the exact same functionality as `torch.func.jacrev`, except
-    that it allows LieTensor to be used as input.
-    """
-    jac_func = torch.func.jacrev(func, argnums, has_aux=has_aux, chunk_size=chunk_size,
-                                 _preallocate_and_copy=_preallocate_and_copy)
-    @retain_ltype()
-    def wrapper_fn(*args, **kwargs):
-        return jac_func(*args, **kwargs)
-    return wrapper_fn
+from .jac import jacrev
diff --git a/pypose/func/jac.py b/pypose/func/jac.py
new file mode 100644
index 000000000..4a5ca2ea4
--- /dev/null
+++ b/pypose/func/jac.py
@@ -0,0 +1,58 @@
+import torch
+from typing import Callable, Union, Tuple, Optional
+from pypose.lietensor.lietensor import retain_ltype
+
+
+def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
+           chunk_size: Optional[int] = None,
+           _preallocate_and_copy=False):
+    r"""
+    This function provides the exact same functionality as `torch.func.jacrev()
+    <https://pytorch.org/docs/stable/generated/torch.func.jacrev.html>`_,
+    except that it allows LieTensor to be used as input when calculating the Jacobian.
+
+    Args:
+        func (function): A Python function that takes one or more arguments,
+            one of which must be a Tensor, and returns one or more Tensors.
+        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
+            saying which arguments to get the Jacobian with respect to.
+            Default: 0.
+        has_aux (bool): Flag indicating that ``func`` returns a
+            ``(output, aux)`` tuple where the first element is the output of
+            the function to be differentiated and the second element is
+            auxiliary objects that will not be differentiated.
+            Default: False.
+        chunk_size (None or int): If None (default), use the maximum chunk size
+            (equivalent to doing a single vmap over vjp to compute the Jacobian).
+            If 1, then compute the Jacobian row-by-row with a for-loop.
+            If not None, then compute the Jacobian :attr:`chunk_size` rows at a time
+            (equivalent to doing multiple vmap over vjp). If you run into memory issues
+            computing the Jacobian, try specifying a non-None chunk_size.
+
+    Returns:
+        Returns a function that takes in the same inputs as ``func`` and
+        returns the Jacobian of ``func`` with respect to the arg(s) at
+        ``argnums``. If ``has_aux is True``, then the returned function
+        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
+        is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.
+
+    A basic usage with our LieTensor type is differentiating a transformation function:
+
+    >>> import pypose as pp
+    >>> import torch
+    >>> def func(pose, points):
+    ...     return pose @ points
+    >>> pose = pp.randn_SE3(1)
+    >>> points = torch.randn(1, 3)
+    >>> jacobian = pp.func.jacrev(func)(pose, points)
+    >>> jacobian
+    tensor([[[[ 1.0000,  0.0000,  0.0000,  0.0000,  1.5874, -0.2061,  0.0000]],
+             [[ 0.0000,  1.0000,  0.0000, -1.5874,  0.0000, -1.4273,  0.0000]],
+             [[ 0.0000,  0.0000,  1.0000,  0.2061,  1.4273,  0.0000,  0.0000]]]])
+    """
+    jac_func = torch.func.jacrev(func, argnums, has_aux=has_aux, chunk_size=chunk_size,
+                                 _preallocate_and_copy=_preallocate_and_copy)
+    @retain_ltype()
+    def wrapper_fn(*args, **kwargs):
+        return jac_func(*args, **kwargs)
+    return wrapper_fn