Skip to content

Commit

Permalink
add docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
zitongzhan committed Aug 7, 2023
1 parent d6663c2 commit faebd69
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 24 deletions.
10 changes: 10 additions & 0 deletions docs/source/func.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Func
=======

.. currentmodule:: pypose

.. autosummary::
:toctree: generated
:nosignatures:

func.jacrev
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ PyPose Documentation
:caption: Contents:

lietensor
func
functions
convert
modules
Expand Down
25 changes: 1 addition & 24 deletions pypose/func/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1 @@
from typing import Callable, Union, Tuple, List, Any, Optional
import torch
from functools import partial, wraps
from torch._functorch.vmap import vmap, doesnt_support_saved_tensors_hooks, get_chunk_sizes
from torch._functorch.eager_transforms import error_if_complex, _slice_argnums, \
_chunked_standard_basis_for_, _safe_zero_index, _vjp_with_argnums
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map, tree_map_only

from pypose.lietensor.lietensor import retain_ltype


def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
chunk_size: Optional[int] = None,
_preallocate_and_copy=False):
"""
This function provides the exact same functionality as `torch.func.jacrev`, except
that it allows LieTensor to be used as input.
"""
jac_func = torch.func.jacrev(func, argnums, has_aux=has_aux, chunk_size=chunk_size,
_preallocate_and_copy=_preallocate_and_copy)
@retain_ltype()
def wrapper_fn(*args, **kwargs):
return jac_func(*args, **kwargs)
return wrapper_fn
from .jac import jacrev
58 changes: 58 additions & 0 deletions pypose/func/jac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import torch
from typing import Callable, Union, Tuple, Optional
from pypose.lietensor.lietensor import retain_ltype


def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
chunk_size: Optional[int] = None,
_preallocate_and_copy=False):
r"""
This function provides the exact same functionality as `torch.func.jacrev()
<https://pytorch.org/docs/stable/generated/torch.func.jacrev.html#torch.func.jacrev>`_,
except that it allows LieTensor to be used as input when calculating the jacobian.
Args:
func (function): A Python function that takes one or more arguments,
one of which must be a Tensor, and returns one or more Tensors
argnums (int or Tuple[int]): Optional, integer or tuple of integers,
saying which arguments to get the Jacobian with respect to.
Default: 0.
has_aux (bool): Flag indicating that ``func`` returns a
``(output, aux)`` tuple where the first element is the output of
the function to be differentiated and the second element is
auxiliary objects that will not be differentiated.
Default: False.
chunk_size (None or int): If None (default), use the maximum chunk size
(equivalent to doing a single vmap over vjp to compute the jacobian).
If 1, then compute the jacobian row-by-row with a for-loop.
If not None, then compute the jacobian :attr:`chunk_size` rows at a time
(equivalent to doing multiple vmap over vjp). If you run into memory issues computing
the jacobian, please try to specify a non-None chunk_size.
Returns:
Returns a function that takes in the same inputs as ``func`` and
returns the Jacobian of ``func`` with respect to the arg(s) at
``argnums``. If ``has_aux is True``, then the returned function
instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.
A basic usage with our LieTensor type would be the transformation function.
>>> import pypose as pp
>>> import torch
>>> def func(pose, points):
... return pose @ points
>>> pose = pp.randn_SE3(1)
>>> points = torch.randn(1, 3)
>>> jacobian = pp.func.jacrev(func)(pose, points)
>>> jacobian
tensor([[[[ 1.0000, 0.0000, 0.0000, 0.0000, 1.5874, -0.2061, 0.0000]],
[[ 0.0000, 1.0000, 0.0000, -1.5874, 0.0000, -1.4273, 0.0000]],
[[ 0.0000, 0.0000, 1.0000, 0.2061, 1.4273, 0.0000, 0.0000]]]])
"""
jac_func = torch.func.jacrev(func, argnums, has_aux=has_aux, chunk_size=chunk_size,
_preallocate_and_copy=_preallocate_and_copy)
@retain_ltype()
def wrapper_fn(*args, **kwargs):
return jac_func(*args, **kwargs)
return wrapper_fn

0 comments on commit faebd69

Please sign in to comment.