Implement a detach() helper that works on distributions #2599

Merged
merged 2 commits into from
Aug 14, 2020

fritzo
Member

@fritzo fritzo commented Aug 13, 2020

Addresses #2598, pytorch/pytorch#25783

This implements a function detach(-) that creates a deep copy of an arbitrary Python object with all tensors detached, without copying the underlying tensor data. The main use case is in inference algorithms that require mixed detachment of gradients, e.g. d = MyDist(param) followed by d.rsample() but detach(d).log_prob(x).
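
To make "mixed detachment" concrete, here is a minimal sketch using plain `torch.distributions.Normal`. It does not call the helper from this PR: the detached distribution is rebuilt by hand (`Normal(loc.detach(), 1.0)`) as a stand-in for `detach(d)`, and the names `loc`, `d_detached`, `mixed_grad`, and `full_grad` are illustrative assumptions.

```python
import torch
from torch.distributions import Normal

# Hypothetical learnable parameter, standing in for `param` above.
loc = torch.tensor(0.5, requires_grad=True)

d = Normal(loc, 1.0)
x = d.rsample()  # reparameterized sample: x depends on loc

# Manual stand-in for detach(d): same parameter values, no gradient path.
d_detached = Normal(loc.detach(), 1.0)

# Mixed detachment: the gradient reaches loc only through the sample x.
(mixed_grad,) = torch.autograd.grad(d_detached.log_prob(x), loc, retain_graph=True)

# Fully attached: the path through x and the path through d.loc cancel
# for a Normal with fixed scale, so this gradient is zero.
(full_grad,) = torch.autograd.grad(d.log_prob(x), loc)
```

Rebuilding the distribution by hand only works when the constructor arguments are known, which is presumably why a generic helper that walks an arbitrary object is useful here.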

This works by overriding deepcopy's behavior to use a modified memoizer dict. This seems like the cleanest way to apply arbitrary transforms to the leaves of Python objects (similar to JAX's PyTree). We could in principle use this trick for other leaf transforms.
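
The memo trick can be sketched without PyTorch. The example below is an illustration under stated assumptions, not the code from this PR: `Leaf` is a hypothetical stand-in for a tensor, `_DetachMemo` and `detach_leaves` are made-up names, and recovering an object from its `id()` via `ctypes` relies on CPython, where `id()` is the object's address.

```python
import copy
import ctypes


class Leaf:
    """Hypothetical stand-in for a tensor: a payload plus a 'tracked' flag."""

    def __init__(self, data, tracked=True):
        self.data = data
        self.tracked = tracked

    def detach(self):
        # New wrapper around the same payload object (no data copy).
        return Leaf(self.data, tracked=False)


class _DetachMemo(dict):
    """Memo dict that answers lookups for leaves with a transformed leaf."""

    def get(self, key, default=None):
        result = super().get(key, default)
        if result is default:
            # deepcopy keys its memo by id(obj); recover obj from that id.
            # Assumption: CPython, where id() is the object's address.
            old = ctypes.cast(key, ctypes.py_object).value
            if isinstance(old, Leaf):
                self[key] = result = old.detach()
        return result


def detach_leaves(obj):
    # deepcopy consults memo.get(id(x)) before copying x, so every Leaf is
    # replaced by its detached version and everything else is copied normally.
    return copy.deepcopy(obj, _DetachMemo())


class Dist:
    """Toy container holding two references to the same leaf."""

    def __init__(self, loc, scale):
        self.params = {"loc": loc, "scale": scale}
        self.loc = loc


d = Dist(Leaf([0.0, 1.0]), Leaf([1.0, 1.0]))
d2 = detach_leaves(d)
```

Because the transformed leaf is stored in the memo, repeated references to one leaf in the original map to one shared leaf in the copy. Swapping `old.detach()` for another function would give a different leaf transform with the same traversal.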

Tested

  • added unit tests
  • refactored TraceGraph_ELBO to use the new helper; should be covered by existing tests.

cc @iffsid


[isort]
line_length = 120
not_skip = __init__.py
Member Author


fixes a new error emitted by isort 5.0

@@ -1,11 +1,10 @@
[flake8]
max-line-length = 120
exclude = docs/src, build, dist, .ipynb_checkpoints
extend-ignore = E741
Member Author


ignore errors about test assertions of the form assert type(x) is type(y)

@fritzo fritzo requested a review from eb8680 August 13, 2020 20:44
Member

@eb8680 eb8680 left a comment


Neat! Is this exercised by any of the existing JitTraceGraph_ELBO tests?

@fritzo
Member Author

fritzo commented Aug 14, 2020

Is this exercised by any of the existing JitTraceGraph_ELBO tests?

No, the JitTraceGraph_ELBO tests seem low value relative to the planned follow-up tests that use detach(-) in RenyiELBO and a new IWAE_ELBO (#2598).

@fritzo
Member Author

fritzo commented Aug 14, 2020

OK, I've added a low-level jit test to check that detach(-) can be called in jitted code.

Member

@eb8680 eb8680 left a comment


LGTM
