From 6ccdb1e2e2c8373f817e6a240af376f831eb8cd1 Mon Sep 17 00:00:00 2001
From: Srinivas Vasudevan
Date: Thu, 2 Nov 2023 16:43:32 -0700
Subject: [PATCH] Enable experimental/marginalize for the JAX and Numpy
 backends.

PiperOrigin-RevId: 579016812
---
 .../python/experimental/marginalize/BUILD     | 23 +++++++++++------
 .../experimental/marginalize/logeinsumexp.py  |  7 +++---
 .../marginalize/logeinsumexp_test.py          |  3 +--
 .../marginalize/marginalizable.py             | 25 +++++++++++--------
 .../marginalize/marginalizable_test.py        | 10 +++-----
 5 files changed, 39 insertions(+), 29 deletions(-)

diff --git a/tensorflow_probability/python/experimental/marginalize/BUILD b/tensorflow_probability/python/experimental/marginalize/BUILD
index e92a16f96c..d2d9c1865e 100644
--- a/tensorflow_probability/python/experimental/marginalize/BUILD
+++ b/tensorflow_probability/python/experimental/marginalize/BUILD
@@ -15,8 +15,11 @@
 # Description:
 #   Automatic marginalization of latent variables.
 
-# Placeholder: py_library
-# Placeholder: py_test
+load(
+    "//tensorflow_probability/python:build_defs.bzl",
+    "multi_substrate_py_library",
+    "multi_substrate_py_test",
+)
 
 package(
     # default_applicable_licenses
@@ -27,17 +30,18 @@
 
 licenses(["notice"])
 
-py_library(
+multi_substrate_py_library(
     name = "logeinsumexp",
     srcs = ["logeinsumexp.py"],
     deps = [
         # numpy dep,
         # opt_einsum dep,
         # tensorflow dep,
+        "//tensorflow_probability/python/internal:prefer_static",
     ],
 )
 
-py_test(
+multi_substrate_py_test(
     name = "logeinsumexp_test",
     size = "medium",
     srcs = [
@@ -53,7 +57,7 @@
     ],
 )
 
-py_library(
+multi_substrate_py_library(
     name = "marginalize",
     srcs = ["__init__.py"],
     deps = [
@@ -62,7 +66,7 @@
     ],
 )
 
-py_library(
+multi_substrate_py_library(
     name = "marginalizable",
     srcs = ["marginalizable.py"],
     deps = [
@@ -72,13 +76,17 @@
         "//tensorflow_probability/python/distributions:categorical",
         "//tensorflow_probability/python/distributions:joint_distribution_coroutine",
         "//tensorflow_probability/python/distributions:sample",
+        "//tensorflow_probability/python/internal:prefer_static",
+        "//tensorflow_probability/python/internal:samplers",
     ],
 )
 
-py_test(
+multi_substrate_py_test(
     name = "marginalizable_test",
     size = "medium",
     srcs = ["marginalizable_test.py"],
+    jax_tags = ["notap"],
+    numpy_tags = ["notap"],
     deps = [
         ":marginalize",
         # absl/testing:parameterized dep,
@@ -92,6 +100,7 @@
         "//tensorflow_probability/python/distributions:normal",
         "//tensorflow_probability/python/distributions:poisson",
         "//tensorflow_probability/python/distributions:sample",
+        "//tensorflow_probability/python/internal:prefer_static",
         "//tensorflow_probability/python/internal:test_util",
     ],
 )
diff --git a/tensorflow_probability/python/experimental/marginalize/logeinsumexp.py b/tensorflow_probability/python/experimental/marginalize/logeinsumexp.py
index 7d8794c0b5..1923f79a27 100644
--- a/tensorflow_probability/python/experimental/marginalize/logeinsumexp.py
+++ b/tensorflow_probability/python/experimental/marginalize/logeinsumexp.py
@@ -15,7 +15,8 @@
 """Compute einsums in log space."""
 
 import opt_einsum as oe
-import tensorflow.compat.v1 as tf
+import tensorflow.compat.v2 as tf
+from tensorflow_probability.python.internal import prefer_static as ps
 
 
 # pylint: disable=no-member
@@ -72,8 +73,8 @@ def rearrange(src, dst, t):
     if i not in src:
       new_indices += i
   new_src = src + new_indices
-  new_t = tf.reshape(t, tf.concat(
-      [tf.shape(t), tf.ones(len(new_indices), dtype=tf.int32)], axis=0))
+  new_t = tf.reshape(t, ps.concat(
+      [ps.shape(t), ps.ones(len(new_indices), dtype=tf.int32)], axis=0))
   formula = '{}->{}'.format(new_src, dst)
   # It is safe to use ordinary `einsum` here as no summations
   # are performed.
diff --git a/tensorflow_probability/python/experimental/marginalize/logeinsumexp_test.py b/tensorflow_probability/python/experimental/marginalize/logeinsumexp_test.py
index 182a6d42dd..016284f24f 100644
--- a/tensorflow_probability/python/experimental/marginalize/logeinsumexp_test.py
+++ b/tensorflow_probability/python/experimental/marginalize/logeinsumexp_test.py
@@ -18,7 +18,7 @@
 from hypothesis.extra import numpy as hpnp
 import hypothesis.strategies as hps
 import numpy as np
-import tensorflow.compat.v1 as tf
+import tensorflow.compat.v2 as tf
 from tensorflow_probability.python.experimental.marginalize.logeinsumexp import _binary_einslogsumexp
 from tensorflow_probability.python.experimental.marginalize.logeinsumexp import logeinsumexp
 from tensorflow_probability.python.internal import test_util
@@ -179,7 +179,6 @@ def test_compare_einsum(self):
     formula = 'abcdcfg,edfcbaa->bd'
     u = tf.math.log(tf.einsum(formula, a, b))
     v = logeinsumexp(formula, tf.math.log(a), tf.math.log(b))
-
     self.assertAllClose(u, v)
 
   def test_zero_zero_multiplication(self):
diff --git a/tensorflow_probability/python/experimental/marginalize/marginalizable.py b/tensorflow_probability/python/experimental/marginalize/marginalizable.py
index d9f327f720..e3ae6fb97c 100644
--- a/tensorflow_probability/python/experimental/marginalize/marginalizable.py
+++ b/tensorflow_probability/python/experimental/marginalize/marginalizable.py
@@ -24,6 +24,8 @@
 from tensorflow_probability.python.distributions import joint_distribution_coroutine as jdc_lib
 from tensorflow_probability.python.distributions import sample as sample_lib
 from tensorflow_probability.python.experimental.marginalize.logeinsumexp import logeinsumexp
+from tensorflow_probability.python.internal import prefer_static as ps
+from tensorflow_probability.python.internal import samplers
 
 
 __all__ = [
@@ -117,10 +119,9 @@ def _support(dist):
         dist.sample_shape, 'expand_sample_shape')
     p, rank = _support(dist.distribution)
     product = _power(p, n)
-    new_shape = tf.concat([tf.shape(product)[:-1], sample_shape], axis=-1)
+    new_shape = ps.concat([ps.shape(product)[:-1], sample_shape], axis=-1)
 
-    new_rank = rank + tf.compat.v2.compat.dimension_value(
-        sample_shape.shape[0])
+    new_rank = rank + tf.compat.dimension_value(sample_shape.shape[0])
     return tf.reshape(product, new_shape), new_rank
   else:
     raise ValueError('Unable to find support for distribution ' +
@@ -141,11 +142,11 @@ def _expand_right(a, n, pos):
     Tensor with inserted dimensions.
   """
 
-  axis = tf.rank(a) + pos + 1
-  return tf.reshape(a, tf.concat([
-      tf.shape(a)[:axis],
-      tf.ones([n], dtype=tf.int32),
-      tf.shape(a)[axis:]], axis=0))
+  axis = ps.rank(a) + pos + 1
+  return tf.reshape(a, ps.concat([
+      ps.shape(a)[:axis],
+      ps.ones([n], dtype=tf.int32),
+      ps.shape(a)[axis:]], axis=0))
 
 
 def _letter(i):
@@ -216,7 +217,9 @@ def marginalized_log_prob(self, values, name='marginalized_log_prob',
 
     with tf.name_scope(name):
      ds = self._call_execute_model(
-          sample_and_trace_fn=jd_lib.trace_distributions_only)
+          sample_and_trace_fn=jd_lib.trace_distributions_only,
+          # Only used for tracing so can be fixed.
+          seed=samplers.zeros_seed())
 
       # Both 'marginalize' and 'tabulate' indicate that
       # instead of using samples provided by the user, this method
@@ -229,7 +232,7 @@ def marginalized_log_prob(self, values, name='marginalized_log_prob',
       for value, dist in zip(values, ds):
         if value == 'marginalize':
           supp, rank = _support(dist)
-          r = supp.shape.rank
+          r = ps.rank(supp)
           num_new_variables = r - rank
           # We can think of supp as being a tensor containing tensors,
           # each of which is a draw from the distribution.
@@ -251,7 +254,7 @@
           formula.append(indices)
         elif value == 'tabulate':
           supp, rank = _support(dist)
-          r = supp.shape.rank
+          r = ps.rank(supp)
           if r is None:
             raise ValueError('Need to be able to statically find rank of'
                              'support of random variable: {}'.format(str(dist)))
diff --git a/tensorflow_probability/python/experimental/marginalize/marginalizable_test.py b/tensorflow_probability/python/experimental/marginalize/marginalizable_test.py
index 1211246c5e..b0da46d476 100644
--- a/tensorflow_probability/python/experimental/marginalize/marginalizable_test.py
+++ b/tensorflow_probability/python/experimental/marginalize/marginalizable_test.py
@@ -34,6 +34,7 @@
 from tensorflow_probability.python.distributions import poisson
 from tensorflow_probability.python.distributions import sample as sample_dist_lib
 import tensorflow_probability.python.experimental.marginalize as marginalize
+from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import test_util
 
 
@@ -48,10 +49,6 @@ def _conform(ts):
   return [tf.broadcast_to(a, shape) for a in ts]
 
 
-def _cat(*ts):
-  return tf.concat(ts, axis=0)
-
-
 def _stack(*ts):
   return tf.stack(_conform(ts), axis=-1)
 
@@ -209,7 +206,7 @@ def test_hmm(self):
     n_steps = 4
     infer_step = 2
 
-    observations = [-1.0, 0.0, 1.0, 2.0]
+    observations = np.array([-1.0, 0.0, 1.0, 2.0], np.float32)
     initial_prob = tf.constant([0.6, 0.4], dtype=tf.float32)
     transition_matrix = tf.constant([[0.6, 0.4],
@@ -309,7 +306,7 @@ def model():
              0.4 * tf.roll(o, shift=[1, 0], axis=[-2, -1]))
 
         # Reshape just last two dimensions.
-        p = tf.reshape(p, _cat(p.shape[:-2], [-1]))
+        p = tf.reshape(p, ps.concat([ps.shape(p)[:-2], [-1]], axis=0))
         xy = yield categorical.Categorical(probs=p, dtype=tf.int32)
         x[i] = xy // n
         y[i] = xy % n
@@ -342,6 +339,7 @@ def model():
     # order chosen by `tf.einsum` closer matches `_tree_example` above.
     self.assertAllClose(p, q)
 
+  @test_util.numpy_disable_gradient_test
   def test_marginalized_gradient(self):
     n = 10
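
For context, `logeinsumexp` computes log(einsum(...)) while staying in log
space, so that very small probabilities survive the contraction without
underflowing. A minimal sketch of that invariant, modeled on the
`test_compare_einsum` case touched above (the shapes, seed, and tolerance are
illustrative choices, not taken from the patch):

import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.experimental.marginalize.logeinsumexp import logeinsumexp

# Two rank-7 operands matching the formula's index pattern; every axis has
# size 2 so the repeated indices ('c' in the first operand, 'a' in the
# second, which select diagonals) are consistent.
rng = np.random.RandomState(42)
a = tf.constant(rng.rand(2, 2, 2, 2, 2, 2, 2))
b = tf.constant(rng.rand(2, 2, 2, 2, 2, 2, 2))
formula = 'abcdcfg,edfcbaa->bd'

# Reference value: contract in probability space, then take the log.
u = tf.math.log(tf.einsum(formula, a, b))
# Same contraction carried out entirely in log space.
v = logeinsumexp(formula, tf.math.log(a), tf.math.log(b))
np.testing.assert_allclose(u.numpy(), v.numpy(), rtol=1e-6)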
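
And a sketch of what the new multi_substrate targets enable: the same
automatic-marginalization API under JAX. `MarginalizableJointDistributionCoroutine`
and `marginalized_log_prob(['marginalize', ...])` are the entry points exercised
by `marginalizable_test.py`; the substrate import path and the toy mixture model
below are assumptions on my part, not code from this change:

import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp  # assumed rewrite path

tfd = tfp.distributions
marginalize = tfp.experimental.marginalize
Root = tfd.JointDistributionCoroutine.Root

def model():
  # Latent mixture component. Passing 'marginalize' below sums it out over
  # its full support instead of sampling it.
  i = yield Root(tfd.Categorical(probs=[0.25, 0.25, 0.5], name='i'))
  # Observation whose location depends on the latent index.
  yield tfd.Normal(loc=jnp.take(jnp.array([-2., 0., 2.]), i),
                   scale=1., name='x')

dist = marginalize.MarginalizableJointDistributionCoroutine(model)

# log p(x=0.5) = log sum_i p(i) p(x=0.5 | i), with no Monte Carlo over `i`.
lp = dist.marginalized_log_prob(['marginalize', 0.5])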