forked from pypose/pypose
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d6663c2
commit faebd69
Showing
4 changed files
with
70 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
Func | ||
======= | ||
|
||
.. currentmodule:: pypose | ||
|
||
.. autosummary:: | ||
:toctree: generated | ||
:nosignatures: | ||
|
||
func.jacrev |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ PyPose Documentation | |
:caption: Contents: | ||
|
||
lietensor | ||
func | ||
functions | ||
convert | ||
modules | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1 @@ | ||
from typing import Callable, Union, Tuple, List, Any, Optional | ||
import torch | ||
from functools import partial, wraps | ||
from torch._functorch.vmap import vmap, doesnt_support_saved_tensors_hooks, get_chunk_sizes | ||
from torch._functorch.eager_transforms import error_if_complex, _slice_argnums, \ | ||
_chunked_standard_basis_for_, _safe_zero_index, _vjp_with_argnums | ||
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map, tree_map_only | ||
|
||
from pypose.lietensor.lietensor import retain_ltype | ||
|
||
|
||
def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
           chunk_size: Optional[int] = None,
           _preallocate_and_copy=False):
    """
    Drop-in replacement for `torch.func.jacrev` that additionally accepts
    LieTensor arguments. All parameters are forwarded unchanged.
    """
    # Build the underlying reverse-mode jacobian transform once, up front.
    inner = torch.func.jacrev(func, argnums, has_aux=has_aux, chunk_size=chunk_size,
                              _preallocate_and_copy=_preallocate_and_copy)

    # retain_ltype() wraps the call so LieTensor inputs keep their type
    # through the functorch transform (presumably — confirm in lietensor).
    @retain_ltype()
    def lietensor_jacobian(*args, **kwargs):
        return inner(*args, **kwargs)

    return lietensor_jacobian
from .jac import jacrev |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import torch | ||
from typing import Callable, Union, Tuple, Optional | ||
from pypose.lietensor.lietensor import retain_ltype | ||
|
||
|
||
def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
           chunk_size: Optional[int] = None,
           _preallocate_and_copy=False):
    r"""
    This function provides the exact same functionality as `torch.func.jacrev()
    <https://pytorch.org/docs/stable/generated/torch.func.jacrev.html#torch.func.jacrev>`_,
    except that it allows LieTensor to be used as input when calculating the jacobian.

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors.
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Jacobian with respect to.
            Default: 0.
        has_aux (bool): Flag indicating that ``func`` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            auxiliary objects that will not be differentiated.
            Default: False.
        chunk_size (None or int): If None (default), use the maximum chunk size
            (equivalent to doing a single vmap over vjp to compute the jacobian).
            If 1, then compute the jacobian row-by-row with a for-loop.
            If not None, then compute the jacobian :attr:`chunk_size` rows at a
            time (equivalent to doing multiple vmap over vjp). If you run into
            memory issues computing the jacobian, please try to specify a
            non-None chunk_size.

    Returns:
        Returns a function that takes in the same inputs as ``func`` and
        returns the Jacobian of ``func`` with respect to the arg(s) at
        ``argnums``. If ``has_aux is True``, then the returned function
        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
        is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.

    Example:
        A basic usage with our LieTensor type would be the transformation
        function.

        >>> import pypose as pp
        >>> import torch
        >>> def func(pose, points):
        ...     return pose @ points
        >>> pose = pp.randn_SE3(1)
        >>> points = torch.randn(1, 3)
        >>> jacobian = pp.func.jacrev(func)(pose, points)
        >>> jacobian
        tensor([[[[ 1.0000,  0.0000,  0.0000,  0.0000,  1.5874, -0.2061,  0.0000]],
                 [[ 0.0000,  1.0000,  0.0000, -1.5874,  0.0000, -1.4273,  0.0000]],
                 [[ 0.0000,  0.0000,  1.0000,  0.2061,  1.4273,  0.0000,  0.0000]]]])
    """
    # Build the stock functorch reverse-mode jacobian transform once; all
    # arguments are forwarded verbatim so behavior matches torch.func.jacrev.
    jac_func = torch.func.jacrev(func, argnums, has_aux=has_aux, chunk_size=chunk_size,
                                 _preallocate_and_copy=_preallocate_and_copy)

    # retain_ltype() appears to preserve the LieTensor subtype of the inputs
    # across the functorch transform — TODO confirm against lietensor module.
    @retain_ltype()
    def wrapper_fn(*args, **kwargs):
        return jac_func(*args, **kwargs)
    return wrapper_fn