
Add pyro.util.deep_to() to recursively call .to() on Python data structures #2918

Merged
4 commits merged into dev on Sep 7, 2021

Conversation

fritzo (Member) commented Aug 24, 2021

Addresses pytorch/pytorch#7795
Addresses pytorch/pytorch#17234

Many users of torch.distributions claim to want a .to() method on Distribution objects. What they really want is a general deep_to() function that operates on arbitrary Python data structures. This benefits users, who can apply it to any container of tensors, and it benefits developers and maintainers of PyTorch and Pyro, who avoid interface bloat. See jax.device_put() for similar functionality in JAX.

This PR implements a deep_to() helper based on copy.deepcopy(), the same trick used in #2599.

Tested

  • added unit tests
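The deepcopy-based trick can be sketched roughly as follows. This is a hedged illustration, not the merged implementation: the `_ToMemo` class, its constructor signature, and the tensor-only dispatch are assumptions made for the sketch (the actual PR also handles `torch.nn.Module`), and the `ctypes` id-lookup is CPython-specific.

```python
import copy
import ctypes

import torch


class _ToMemo(dict):
    # Hypothetical memo for copy.deepcopy(): when deepcopy looks up an
    # object it has not copied yet, recover that object from its id and,
    # if it is a tensor, return tensor.to(...) in place of a plain copy.
    def __init__(self, args, kwargs):
        super().__init__()
        self._args, self._kwargs = args, kwargs

    def get(self, key, default=None):
        result = super().get(key, default)
        if result is default:
            # CPython-specific: `key` is the id() of the object being copied.
            old = ctypes.cast(key, ctypes.py_object).value
            if isinstance(old, torch.Tensor):
                result = old.to(*self._args, **self._kwargs)
                self[key] = result  # reuse if the tensor appears again
        return result


def deep_to(obj, *args, **kwargs):
    """Out-of-place: returns a moved copy, leaving `obj` untouched."""
    return copy.deepcopy(obj, _ToMemo(args, kwargs))
```

Because ordinary deepcopy recursion does the traversal, this works on nested dicts, lists, and (via their pickled `__dict__` state) Distribution objects, without any per-type registration.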

eb8680 (Member) commented Aug 26, 2021

Will this work on funsor terms?

fritzo (Member, Author) commented Aug 26, 2021

> Will this work on funsor terms?

I doubt this will work on Funsors, due to Funsor's use of hash consing. But I'm pretty sure we don't need this in Funsor, because Funsor instead provides op machinery to correctly map operations over leaves: whereas Pyro needs a hacky pyro.util.detach() function (using the same trick as in this PR), Funsor can provide a cleaner, backend-extensible ops.detach() function.

fritzo requested a review from eb8680 on September 1, 2021.
eb8680 (Member) left a comment

This will be really useful!

if result is default:
    # Assume key is the id of another object, and look up that object.
    old = ctypes.cast(key, ctypes.py_object).value
    if isinstance(old, (torch.Tensor, torch.nn.Module)):
eb8680 (Member) commented:

Can you add a test exercising the nn.Module path? Also, will this work as expected on PyroModules with PyroSamples and PyroParams?

eb8680 (Member) commented:

One common use case has been distributions inside module objects, where calling .cuda() on the module doesn't move the distribution objects. deep_to(nn, device='cuda') should work great, but it would be nice to add a test that exercises this path by changing the dtype.
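To make that failure mode concrete, here is a minimal sketch (the `Model` class is hypothetical): `nn.Module.to()` converts parameters and buffers, but silently skips a Distribution stored as a plain attribute, which is exactly the gap `deep_to()` is meant to close.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    # Hypothetical module holding a Distribution as a plain attribute.
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2))
        self.prior = torch.distributions.Normal(torch.zeros(2), torch.ones(2))


m = Model().to(torch.float64)
assert m.weight.dtype == torch.float64      # parameters are converted...
assert m.prior.loc.dtype == torch.float32   # ...but the distribution is not
```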

fritzo (Member, Author) commented:

Hmm, I believe distribution attributes of modules will be bypassed, since we assume Module objects handle themselves via torch.nn.Module.to(). I'll have to think more about this...

fritzo (Member, Author) commented:

OK, I've added tests for nn.Module, PyroModule, PyroParam, and PyroSample. The only way I could get this to work was to perform the conversion out-of-place, by deep-copying the nn.Module. IMO this is the best we can do, given PyTorch's early design choices about the behavior of .to() on tensors versus modules 🤷

fritzo requested a review from eb8680 on September 7, 2021.
eb8680 (Member) left a comment

LGTM with new tests

eb8680 merged commit beddc1f into dev on Sep 7, 2021.
eb8680 deleted the deep-to branch on September 7, 2021.