From 4ed936a82122e76fac087a8cb3d577de6149f0d7 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Fri, 14 Aug 2020 10:28:06 -0700
Subject: [PATCH] Implement a detach() helper that works on distributions
 (#2599)

* Implement a detach() helper

* Add low-level jit test
---
 pyro/distributions/util.py       | 30 +++++++++++
 pyro/infer/tracegraph_elbo.py    |  6 +--
 pyro/infer/util.py               |  7 ---
 setup.cfg                        |  3 +-
 tests/distributions/test_util.py | 85 +++++++++++++++++++++++++++++++-
 5 files changed, 118 insertions(+), 13 deletions(-)

diff --git a/pyro/distributions/util.py b/pyro/distributions/util.py
index 046b16c749..f035a30290 100644
--- a/pyro/distributions/util.py
+++ b/pyro/distributions/util.py
@@ -1,6 +1,8 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
 
+import copy
+import ctypes
 import functools
 import numbers
 import weakref
@@ -56,6 +58,7 @@ def decorator(destin_class):
     return decorator
 
 
+# TODO replace with weakref.WeakMethod?
 def weakmethod(fn):
     """
     Decorator to enforce weak binding of a method, so as to avoid reference
@@ -92,6 +95,33 @@ def weak_binder(self, new):
     return weak_binder
 
 
+# This helper intervenes in copy.deepcopy's use of a memo dict to
+# use .detach() to copy tensors.
+class _DetachMemo(dict):
+    def get(self, key, default=None):
+        result = super().get(key, default)
+
+        if result is default:
+            # Assume key is the id of another object, and look up that object.
+            old = ctypes.cast(key, ctypes.py_object).value
+            if isinstance(old, torch.Tensor):
+                self[key] = result = old.detach()
+
+        return result
+
+
+def detach(obj):
+    """
+    Create a deep copy of an object, detaching all :class:`torch.Tensor` s in
+    the object. No tensor data is actually copied.
+
+    :param obj: Any Python object.
+    :returns: A deep copy of ``obj`` with all :class:`torch.Tensor` s detached.
+    """
+    memo = _DetachMemo()
+    return copy.deepcopy(obj, memo)
+
+
 def is_identically_zero(x):
     """
     Check if argument is exactly the number zero. True for the number zero;
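
The _DetachMemo class above works because copy.deepcopy consults its memo dict, keyed by id(), before copying each object; overriding get() lets it hand back detached tensors at that point. A minimal sketch of that underlying mechanism (not part of the patch; assumes only torch and the standard library):

    import copy
    import torch

    loc = torch.tensor(0., requires_grad=True)

    # deepcopy looks each object up in the memo by id() before copying it,
    # so pre-seeding the memo with a detached tensor makes deepcopy reuse
    # that tensor instead of cloning its storage.
    memo = {id(loc): loc.detach()}
    copy_of = copy.deepcopy({"loc": loc}, memo)

    assert not copy_of["loc"].requires_grad
    assert copy_of["loc"].data_ptr() == loc.data_ptr()  # no tensor data copied

_DetachMemo automates this pre-seeding lazily: on a memo miss, its get() recovers the original object from the id via ctypes and substitutes old.detach().
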
diff --git a/pyro/infer/tracegraph_elbo.py b/pyro/infer/tracegraph_elbo.py
index 9f77af7c3b..852ea6e658 100644
--- a/pyro/infer/tracegraph_elbo.py
+++ b/pyro/infer/tracegraph_elbo.py
@@ -8,10 +8,10 @@
 import pyro
 import pyro.ops.jit
-from pyro.distributions.util import is_identically_zero
+from pyro.distributions.util import detach, is_identically_zero
 from pyro.infer import ELBO
 from pyro.infer.enum import get_importance_trace
-from pyro.infer.util import (MultiFrameTensor, detach_iterable, get_plate_stacks,
+from pyro.infer.util import (MultiFrameTensor, get_plate_stacks,
                              is_validation_enabled, torch_backward, torch_item)
 from pyro.util import check_if_enumerated, warn_if_nan
@@ -61,7 +61,7 @@ def _construct_baseline(node, guide_site, downstream_cost):
         baseline += avg_downstream_cost_old
     if use_nn_baseline:
         # block nn_baseline_input gradients except in baseline loss
-        baseline += nn_baseline(detach_iterable(nn_baseline_input))
+        baseline += nn_baseline(detach(nn_baseline_input))
     elif use_baseline_value:
         # it's on the user to make sure baseline_value tape only points to baseline params
         baseline += baseline_value
diff --git a/pyro/infer/util.py b/pyro/infer/util.py
index 635d3f8255..2d524beb13 100644
--- a/pyro/infer/util.py
+++ b/pyro/infer/util.py
@@ -76,13 +76,6 @@ def torch_sum(tensor, dims):
     return tensor.sum(dims) if dims else tensor
 
 
-def detach_iterable(iterable):
-    if torch.is_tensor(iterable):
-        return iterable.detach()
-    else:
-        return [var.detach() for var in iterable]
-
-
 def zero_grads(tensors):
     """
     Sets gradients of list of Tensors to zero in place
diff --git a/setup.cfg b/setup.cfg
index c5e5d829de..fad8ae19dd 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,11 +1,10 @@
 [flake8]
 max-line-length = 120
 exclude = docs/src, build, dist, .ipynb_checkpoints
-extend-ignore = E741
+extend-ignore = E721,E741
 
 [isort]
 line_length = 120
-not_skip = __init__.py
 skip_glob = .ipynb_checkpoints
 known_first_party = pyro, tests
 known_third_party = opt_einsum, six, torch, torchvision
diff --git a/tests/distributions/test_util.py b/tests/distributions/test_util.py
index eefe2e462f..88ccb9bb4b 100644
--- a/tests/distributions/test_util.py
+++ b/tests/distributions/test_util.py
@@ -7,7 +7,9 @@
 import pytest
 import torch
 
-from pyro.distributions.util import broadcast_shape, sum_leftmost, sum_rightmost, weakmethod
+import pyro.distributions as dist
+from pyro.distributions.util import broadcast_shape, detach, sum_leftmost, sum_rightmost, weakmethod
+from tests.common import assert_equal
 
 INF = float('inf')
@@ -118,3 +120,84 @@ def _method(self, *args, **kwargs):
     assert foo_ref() is foo
     del foo
     assert foo_ref() is None
+
+
+@pytest.mark.parametrize("shape", [None, (), (4,), (3, 2)], ids=str)
+def test_detach_normal(shape):
+    loc = torch.tensor(0., requires_grad=True)
+    scale = torch.tensor(1., requires_grad=True)
+    d1 = dist.Normal(loc, scale)
+    if shape is not None:
+        d1 = d1.expand(shape)
+
+    d2 = detach(d1)
+    assert type(d1) is type(d2)
+    assert_equal(d1.loc, d2.loc)
+    assert_equal(d1.scale, d2.scale)
+    assert not d2.loc.requires_grad
+    assert not d2.scale.requires_grad
+
+
+@pytest.mark.parametrize("shape", [None, (), (4,), (3, 2)], ids=str)
+def test_detach_beta(shape):
+    concentration1 = torch.tensor(0.5, requires_grad=True)
+    concentration0 = torch.tensor(2.0, requires_grad=True)
+    d1 = dist.Beta(concentration1, concentration0)
+    if shape is not None:
+        d1 = d1.expand(shape)
+
+    d2 = detach(d1)
+    assert type(d1) is type(d2)
+    assert d2.batch_shape == d1.batch_shape
+    assert_equal(d1.concentration1, d2.concentration1)
+    assert_equal(d1.concentration0, d2.concentration0)
+    assert not d2.concentration1.requires_grad
+    assert not d2.concentration0.requires_grad
+
+
+@pytest.mark.parametrize("shape", [None, (), (4,), (3, 2)], ids=str)
+def test_detach_transformed(shape):
+    loc = torch.tensor(0., requires_grad=True)
+    scale = torch.tensor(1., requires_grad=True)
+    a = torch.tensor(2., requires_grad=True)
+    b = torch.tensor(3., requires_grad=True)
+    d1 = dist.TransformedDistribution(dist.Normal(loc, scale),
+                                      dist.transforms.AffineTransform(a, b))
+    if shape is not None:
+        d1 = d1.expand(shape)
+
+    d2 = detach(d1)
+    assert type(d1) is type(d2)
+    assert d2.event_shape == d1.event_shape
+    assert d2.batch_shape == d1.batch_shape
+    assert type(d1.base_dist) is type(d2.base_dist)
+    assert len(d1.transforms) == len(d2.transforms)
+    assert_equal(d1.base_dist.loc, d2.base_dist.loc)
+    assert_equal(d1.base_dist.scale, d2.base_dist.scale)
+    assert_equal(d1.transforms[0].loc, d2.transforms[0].loc)
+    assert_equal(d1.transforms[0].scale, d2.transforms[0].scale)
+    assert not d2.base_dist.loc.requires_grad
+    assert not d2.base_dist.scale.requires_grad
+    assert not d2.transforms[0].loc.requires_grad
+    assert not d2.transforms[0].scale.requires_grad
+
+
+@pytest.mark.parametrize("shape", [None, (), (4,), (3, 2)], ids=str)
+def test_detach_jit(shape):
+    loc = torch.tensor(0., requires_grad=True)
+    scale = torch.tensor(1., requires_grad=True)
+    data = torch.randn(5, 1, 1)
+
+    def fn(loc, scale, data):
+        d = dist.Normal(loc, scale, validate_args=False)
+        if shape is not None:
+            d = d.expand(shape)
+        return detach(d).log_prob(data)
+
+    jit_fn = torch.jit.trace(fn, (loc, scale, data))
+
+    expected = fn(loc, scale, data)
+    actual = jit_fn(loc, scale, data)
+    assert not expected.requires_grad
+    assert not actual.requires_grad
+    assert_equal(actual, expected)
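
A minimal usage sketch of the new helper, along the lines of the tests above (not part of the patch; assumes pyro and torch are installed):

    import torch
    import pyro.distributions as dist
    from pyro.distributions.util import detach

    loc = torch.tensor(0., requires_grad=True)
    scale = torch.tensor(1., requires_grad=True)
    d = dist.Normal(loc, scale)

    # The detached copy shares tensor storage but carries no autograd history,
    # so terms computed from it (e.g. the neural baseline input in
    # tracegraph_elbo.py) do not propagate gradients back to loc or scale.
    lp = detach(d).log_prob(torch.tensor(0.5))
    assert not lp.requires_grad

    # The original distribution is left untouched and still tracks gradients.
    assert d.loc.requires_grad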