Skip to content

Commit

Permalink
Enable experimental/marginalize for the JAX and Numpy backends.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 579016812
  • Loading branch information
srvasude authored and tensorflower-gardener committed Nov 2, 2023
1 parent a446360 commit 6ccdb1e
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 29 deletions.
23 changes: 16 additions & 7 deletions tensorflow_probability/python/experimental/marginalize/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
# Description:
# Automatic marginalization of latent variables.

# Placeholder: py_library
# Placeholder: py_test
load(
"//tensorflow_probability/python:build_defs.bzl",
"multi_substrate_py_library",
"multi_substrate_py_test",
)

package(
# default_applicable_licenses
Expand All @@ -27,17 +30,18 @@ package(

licenses(["notice"])

py_library(
multi_substrate_py_library(
name = "logeinsumexp",
srcs = ["logeinsumexp.py"],
deps = [
# numpy dep,
# opt_einsum dep,
# tensorflow dep,
"//tensorflow_probability/python/internal:prefer_static",
],
)

py_test(
multi_substrate_py_test(
name = "logeinsumexp_test",
size = "medium",
srcs = [
Expand All @@ -53,7 +57,7 @@ py_test(
],
)

py_library(
multi_substrate_py_library(
name = "marginalize",
srcs = ["__init__.py"],
deps = [
Expand All @@ -62,7 +66,7 @@ py_library(
],
)

py_library(
multi_substrate_py_library(
name = "marginalizable",
srcs = ["marginalizable.py"],
deps = [
Expand All @@ -72,13 +76,17 @@ py_library(
"//tensorflow_probability/python/distributions:categorical",
"//tensorflow_probability/python/distributions:joint_distribution_coroutine",
"//tensorflow_probability/python/distributions:sample",
"//tensorflow_probability/python/internal:prefer_static",
"//tensorflow_probability/python/internal:samplers",
],
)

py_test(
multi_substrate_py_test(
name = "marginalizable_test",
size = "medium",
srcs = ["marginalizable_test.py"],
jax_tags = ["notap"],
numpy_tags = ["notap"],
deps = [
":marginalize",
# absl/testing:parameterized dep,
Expand All @@ -92,6 +100,7 @@ py_test(
"//tensorflow_probability/python/distributions:normal",
"//tensorflow_probability/python/distributions:poisson",
"//tensorflow_probability/python/distributions:sample",
"//tensorflow_probability/python/internal:prefer_static",
"//tensorflow_probability/python/internal:test_util",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"""Compute einsums in log space."""

import opt_einsum as oe
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.internal import prefer_static as ps


# pylint: disable=no-member
Expand Down Expand Up @@ -72,8 +73,8 @@ def rearrange(src, dst, t):
if i not in src:
new_indices += i
new_src = src + new_indices
new_t = tf.reshape(t, tf.concat(
[tf.shape(t), tf.ones(len(new_indices), dtype=tf.int32)], axis=0))
new_t = tf.reshape(t, ps.concat(
[ps.shape(t), ps.ones(len(new_indices), dtype=tf.int32)], axis=0))
formula = '{}->{}'.format(new_src, dst)
# It is safe to use ordinary `einsum` here as no summations
# are performed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from hypothesis.extra import numpy as hpnp
import hypothesis.strategies as hps
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.experimental.marginalize.logeinsumexp import _binary_einslogsumexp
from tensorflow_probability.python.experimental.marginalize.logeinsumexp import logeinsumexp
from tensorflow_probability.python.internal import test_util
Expand Down Expand Up @@ -179,7 +179,6 @@ def test_compare_einsum(self):
formula = 'abcdcfg,edfcbaa->bd'
u = tf.math.log(tf.einsum(formula, a, b))
v = logeinsumexp(formula, tf.math.log(a), tf.math.log(b))

self.assertAllClose(u, v)

def test_zero_zero_multiplication(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from tensorflow_probability.python.distributions import joint_distribution_coroutine as jdc_lib
from tensorflow_probability.python.distributions import sample as sample_lib
from tensorflow_probability.python.experimental.marginalize.logeinsumexp import logeinsumexp
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import samplers


__all__ = [
Expand Down Expand Up @@ -117,10 +119,9 @@ def _support(dist):
dist.sample_shape, 'expand_sample_shape')
p, rank = _support(dist.distribution)
product = _power(p, n)
new_shape = tf.concat([tf.shape(product)[:-1], sample_shape], axis=-1)
new_shape = ps.concat([ps.shape(product)[:-1], sample_shape], axis=-1)

new_rank = rank + tf.compat.v2.compat.dimension_value(
sample_shape.shape[0])
new_rank = rank + tf.compat.dimension_value(sample_shape.shape[0])
return tf.reshape(product, new_shape), new_rank
else:
raise ValueError('Unable to find support for distribution ' +
Expand All @@ -141,11 +142,11 @@ def _expand_right(a, n, pos):
Tensor with inserted dimensions.
"""

axis = tf.rank(a) + pos + 1
return tf.reshape(a, tf.concat([
tf.shape(a)[:axis],
tf.ones([n], dtype=tf.int32),
tf.shape(a)[axis:]], axis=0))
axis = ps.rank(a) + pos + 1
return tf.reshape(a, ps.concat([
ps.shape(a)[:axis],
ps.ones([n], dtype=tf.int32),
ps.shape(a)[axis:]], axis=0))


def _letter(i):
Expand Down Expand Up @@ -216,7 +217,9 @@ def marginalized_log_prob(self, values, name='marginalized_log_prob',

with tf.name_scope(name):
ds = self._call_execute_model(
sample_and_trace_fn=jd_lib.trace_distributions_only)
sample_and_trace_fn=jd_lib.trace_distributions_only,
# Only used for tracing so can be fixed.
seed=samplers.zeros_seed())

# Both 'marginalize' and 'tabulate' indicate that
# instead of using samples provided by the user, this method
Expand All @@ -229,7 +232,7 @@ def marginalized_log_prob(self, values, name='marginalized_log_prob',
for value, dist in zip(values, ds):
if value == 'marginalize':
supp, rank = _support(dist)
r = supp.shape.rank
r = ps.rank(supp)
num_new_variables = r - rank
# We can think of supp as being a tensor containing tensors,
# each of which is a draw from the distribution.
Expand All @@ -251,7 +254,7 @@ def marginalized_log_prob(self, values, name='marginalized_log_prob',
formula.append(indices)
elif value == 'tabulate':
supp, rank = _support(dist)
r = supp.shape.rank
r = ps.rank(supp)
if r is None:
raise ValueError('Need to be able to statically find rank of'
'support of random variable: {}'.format(str(dist)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from tensorflow_probability.python.distributions import poisson
from tensorflow_probability.python.distributions import sample as sample_dist_lib
import tensorflow_probability.python.experimental.marginalize as marginalize
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import test_util


Expand All @@ -48,10 +49,6 @@ def _conform(ts):
return [tf.broadcast_to(a, shape) for a in ts]


def _cat(*ts):
return tf.concat(ts, axis=0)


def _stack(*ts):
return tf.stack(_conform(ts), axis=-1)

Expand Down Expand Up @@ -209,7 +206,7 @@ def test_hmm(self):
n_steps = 4
infer_step = 2

observations = [-1.0, 0.0, 1.0, 2.0]
observations = np.array([-1.0, 0.0, 1.0, 2.0], np.float32)

initial_prob = tf.constant([0.6, 0.4], dtype=tf.float32)
transition_matrix = tf.constant([[0.6, 0.4],
Expand Down Expand Up @@ -309,7 +306,7 @@ def model():
0.4 * tf.roll(o, shift=[1, 0], axis=[-2, -1]))

# Reshape just last two dimensions.
p = tf.reshape(p, _cat(p.shape[:-2], [-1]))
p = tf.reshape(p, ps.concat([ps.shape(p)[:-2], [-1]], axis=0))
xy = yield categorical.Categorical(probs=p, dtype=tf.int32)
x[i] = xy // n
y[i] = xy % n
Expand Down Expand Up @@ -342,6 +339,7 @@ def model():
# order chosen by `tf.einsum` closer matches `_tree_example` above.
self.assertAllClose(p, q)

@test_util.numpy_disable_gradient_test
def test_marginalized_gradient(self):
n = 10

Expand Down

0 comments on commit 6ccdb1e

Please sign in to comment.