-
-
Notifications
You must be signed in to change notification settings - Fork 987
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement a detach() helper that works on distributions #2599
Conversation
|
||
[isort] | ||
line_length = 120 | ||
not_skip = __init__.py |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixes a new error emitted by isort 5.0
@@ -1,11 +1,10 @@ | |||
[flake8] | |||
max-line-length = 120 | |||
exclude = docs/src, build, dist, .ipynb_checkpoints | |||
extend-ignore = E741 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ignore errors about test assertions assert type(x) is type(y)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat! Is this exercised by any of the existing JitTraceGraph_ELBO
tests?
No, the |
OK, I've added a low-level jit test to check that |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Addresses #2598, pytorch/pytorch#25783
This implements a function
detach(-)
that creates a deep copy of an arbitrary Python object but with all tensors detached (and without copying underlying tensor data). Then main use case is in inference algorithms that require mixed detachment of gradients, e.g.d = MyDist(param)
followed byd.rsample()
butdetach(d).log_prob(x)
.This works by overriding
deepcopy
's behavior to use a modified memoizer dict. This seems like the cleanest way to apply arbitrary transforms on leaves of python objects (similar to JAX's PyTree). We could in principle use this trick for other leaf transforms.Tested
TraceGraph_ELBO
to use the new helper; should be covered by existing tests.cc @iffsid