Commit
Implement a detach() helper that works on distributions (#2599)
* Implement a detach() helper

* Add low-level jit test
fritzo authored Aug 14, 2020
1 parent 2b4a401 commit 4ed936a
Showing 5 changed files with 118 additions and 13 deletions.
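A minimal usage sketch of the new helper (illustration only, not part of the commit), assuming pyro and torch are importable; it mirrors what the tests added below verify:

import torch
import pyro.distributions as dist
from pyro.distributions.util import detach

loc = torch.tensor(0., requires_grad=True)
scale = torch.tensor(1., requires_grad=True)
d = dist.Normal(loc, scale)

d2 = detach(d)  # deep copy of the distribution with every tensor detached
assert type(d2) is type(d)
assert not d2.loc.requires_grad and not d2.scale.requires_grad
assert d2.loc.data_ptr() == d.loc.data_ptr()  # storage is shared; no tensor data is copied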
30 changes: 30 additions & 0 deletions pyro/distributions/util.py
@@ -1,6 +1,8 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import copy
import ctypes
import functools
import numbers
import weakref
@@ -56,6 +58,7 @@ def decorator(destin_class):
return decorator


# TODO replace with weakref.WeakMethod?
def weakmethod(fn):
"""
Decorator to enforce weak binding of a method, so as to avoid reference
@@ -92,6 +95,33 @@ def weak_binder(self, new):
return weak_binder


# This helper intervenes in copy.deepcopy's use of a memo dict to
# use .detach() to copy tensors.
class _DetachMemo(dict):
def get(self, key, default=None):
result = super().get(key, default)

if result is default:
# Assume key is the id of another object, and look up that object.
old = ctypes.cast(key, ctypes.py_object).value
if isinstance(old, torch.Tensor):
self[key] = result = old.detach()

return result


def detach(obj):
"""
Create a deep copy of an object, detaching all :class:`torch.Tensor` s in
the object. No tensor data is actually copied.
:param obj: Any python object.
:returns: A deep copy of ``obj`` with all :class:`torch.Tensor` s detached.
"""
memo = _DetachMemo()
return copy.deepcopy(obj, memo)


def is_identically_zero(x):
"""
Check if argument is exactly the number zero. True for the number zero;
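The comment above _DetachMemo points at the trick: copy.deepcopy looks up each object in a memo dict keyed by id(obj), so mapping a tensor's id to a detached view makes deepcopy reuse that view instead of copying the underlying data. A minimal sketch of the same idea with a plain, pre-seeded memo (the commit's _DetachMemo instead overrides dict.get so tensors are detached lazily, whatever object they are reached from):

import copy
import torch

x = torch.tensor([1., 2.], requires_grad=True)
obj = {"weights": x, "name": "layer1"}

# Pre-seed deepcopy's memo: map id(x) to a detached view of x.
memo = {id(x): x.detach()}
copied = copy.deepcopy(obj, memo)

assert not copied["weights"].requires_grad
assert copied["weights"].data_ptr() == x.data_ptr()  # same storage, no data copy
assert copied is not obj  # the surrounding container is still deep-copied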
6 changes: 3 additions & 3 deletions pyro/infer/tracegraph_elbo.py
@@ -8,10 +8,10 @@

import pyro
import pyro.ops.jit
from pyro.distributions.util import is_identically_zero
from pyro.distributions.util import detach, is_identically_zero
from pyro.infer import ELBO
from pyro.infer.enum import get_importance_trace
from pyro.infer.util import (MultiFrameTensor, detach_iterable, get_plate_stacks,
from pyro.infer.util import (MultiFrameTensor, get_plate_stacks,
is_validation_enabled, torch_backward, torch_item)
from pyro.util import check_if_enumerated, warn_if_nan

@@ -61,7 +61,7 @@ def _construct_baseline(node, guide_site, downstream_cost):
baseline += avg_downstream_cost_old
if use_nn_baseline:
# block nn_baseline_input gradients except in baseline loss
baseline += nn_baseline(detach_iterable(nn_baseline_input))
baseline += nn_baseline(detach(nn_baseline_input))
elif use_baseline_value:
# it's on the user to make sure baseline_value tape only points to baseline params
baseline += baseline_value
7 changes: 0 additions & 7 deletions pyro/infer/util.py
@@ -76,13 +76,6 @@ def torch_sum(tensor, dims):
return tensor.sum(dims) if dims else tensor


def detach_iterable(iterable):
if torch.is_tensor(iterable):
return iterable.detach()
else:
return [var.detach() for var in iterable]


def zero_grads(tensors):
"""
Sets gradients of list of Tensors to zero in place
3 changes: 1 addition & 2 deletions setup.cfg
@@ -1,11 +1,10 @@
[flake8]
max-line-length = 120
exclude = docs/src, build, dist, .ipynb_checkpoints
extend-ignore = E741
extend-ignore = E721,E741

[isort]
line_length = 120
not_skip = __init__.py
skip_glob = .ipynb_checkpoints
known_first_party = pyro, tests
known_third_party = opt_einsum, six, torch, torchvision
85 changes: 84 additions & 1 deletion tests/distributions/test_util.py
@@ -7,7 +7,9 @@
import pytest
import torch

from pyro.distributions.util import broadcast_shape, sum_leftmost, sum_rightmost, weakmethod
import pyro.distributions as dist
from pyro.distributions.util import broadcast_shape, detach, sum_leftmost, sum_rightmost, weakmethod
from tests.common import assert_equal

INF = float('inf')

@@ -118,3 +120,84 @@ def _method(self, *args, **kwargs):
assert foo_ref() is foo
del foo
assert foo_ref() is None


@pytest.mark.parametrize("shape", [None, (), (4,), (3, 2)], ids=str)
def test_detach_normal(shape):
loc = torch.tensor(0., requires_grad=True)
scale = torch.tensor(1., requires_grad=True)
d1 = dist.Normal(loc, scale)
if shape is not None:
d1 = d1.expand(shape)

d2 = detach(d1)
assert type(d1) is type(d2)
assert_equal(d1.loc, d2.loc)
assert_equal(d1.scale, d2.scale)
assert not d2.loc.requires_grad
assert not d2.scale.requires_grad


@pytest.mark.parametrize("shape", [None, (), (4,), (3, 2)], ids=str)
def test_detach_beta(shape):
concentration1 = torch.tensor(0.5, requires_grad=True)
concentration0 = torch.tensor(2.0, requires_grad=True)
d1 = dist.Beta(concentration1, concentration0)
if shape is not None:
d1 = d1.expand(shape)

d2 = detach(d1)
assert type(d1) is type(d2)
assert d2.batch_shape == d1.batch_shape
assert_equal(d1.concentration1, d2.concentration1)
assert_equal(d1.concentration0, d2.concentration0)
assert not d2.concentration1.requires_grad
assert not d2.concentration0.requires_grad


@pytest.mark.parametrize("shape", [None, (), (4,), (3, 2)], ids=str)
def test_detach_transformed(shape):
loc = torch.tensor(0., requires_grad=True)
scale = torch.tensor(1., requires_grad=True)
a = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)
d1 = dist.TransformedDistribution(dist.Normal(loc, scale),
dist.transforms.AffineTransform(a, b))
if shape is not None:
d1 = d1.expand(shape)

d2 = detach(d1)
assert type(d1) is type(d2)
assert d2.event_shape == d1.event_shape
assert d2.batch_shape == d1.batch_shape
assert type(d1.base_dist) is type(d2.base_dist)
assert len(d1.transforms) == len(d2.transforms)
assert_equal(d1.base_dist.loc, d2.base_dist.loc)
assert_equal(d1.base_dist.scale, d2.base_dist.scale)
assert_equal(d1.transforms[0].loc, d2.transforms[0].loc)
assert_equal(d1.transforms[0].scale, d2.transforms[0].scale)
assert not d2.base_dist.loc.requires_grad
assert not d2.base_dist.scale.requires_grad
assert not d2.transforms[0].loc.requires_grad
assert not d2.transforms[0].scale.requires_grad


@pytest.mark.parametrize("shape", [None, (), (4,), (3, 2)], ids=str)
def test_detach_jit(shape):
loc = torch.tensor(0., requires_grad=True)
scale = torch.tensor(1., requires_grad=True)
data = torch.randn(5, 1, 1)

def fn(loc, scale, data):
d = dist.Normal(loc, scale, validate_args=False)
if shape is not None:
d = d.expand(shape)
return detach(d).log_prob(data)

jit_fn = torch.jit.trace(fn, (loc, scale, data))

expected = fn(loc, scale, data)
actual = jit_fn(loc, scale, data)
assert not expected.requires_grad
assert not actual.requires_grad
assert_equal(actual, expected)
