-
-
Notifications
You must be signed in to change notification settings - Fork 987
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add pyro.util.deep_to() to recursively call .to() on Python data structures #2918
Conversation
Will this work on funsor terms? |
I doubt this will work on Funsors due to Funsor's use of hash consing. But I'm pretty sure don't need this in Funsor because Funsor instead provides |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will be really useful!
pyro/distributions/util.py
Outdated
if result is default: | ||
# Assume key is the id of another object, and look up that object. | ||
old = ctypes.cast(key, ctypes.py_object).value | ||
if isinstance(old, (torch.Tensor, torch.nn.Module)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a test exercising the nn.Module
path? Also, will this work as expected on PyroModule
s with PyroSample
s and PyroParam
s?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One common use case has been distributions inside of module objects where doing .cuda()
doesn't work on the distributions objects. deep_to(nn, device='cuda')
should work great, but will be nice to add a test that exercises this path by changing the dtype.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm I believe distribution attributes of modules will be bypassed, since we assume Module
objects handle themselves via torch.nn.Module.to()
. I'll have to think more about this...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I've added tests for nn.Module
, PyroModule
, PyroParam
, and PyroSample
. The only way I could get this to work was to perform the behavior out-of-place, by deep-copying the nn.Module
. IMO this is the best we can do, given PyTorch's early design choices about behavior of .to()
on tensors versus modules 🤷
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with new tests
Addresses pytorch/pytorch#7795
Addresses pytorch/pytorch#17234
Many users of torch.distributions claim to want a
.to()
method onDistribution
objects. What they really want is a generaldeep_to()
function that operates on arbitrary Python data structures. This is advantageous to users because they can use it on any Python data structure. This is advantageous to developers and maintainers of PyTorch and Pyro because they can avoid interface bloat. See jax.device_put() for similar functionality in JAX.This PR implements a
deep_to()
helper based oncopy.deepcopy()
, the same trick used in #2599.Tested