diff --git a/tensorflow_probability/python/distributions/BUILD b/tensorflow_probability/python/distributions/BUILD index 90cb00d01a..c42e793ebb 100644 --- a/tensorflow_probability/python/distributions/BUILD +++ b/tensorflow_probability/python/distributions/BUILD @@ -4540,6 +4540,7 @@ multi_substrate_py_test( tags = ["colab-smoke"], deps = [ ":beta", + ":dirichlet", ":exponential", ":independent", ":joint_distribution_auto_batched", @@ -4552,6 +4553,7 @@ multi_substrate_py_test( ":normal", ":sample", ":transformed_distribution", + ":uniform", # numpy dep, # scipy dep, # tensorflow dep, diff --git a/tensorflow_probability/python/distributions/transformed_distribution.py b/tensorflow_probability/python/distributions/transformed_distribution.py index 9e77d68506..8d05d607bc 100644 --- a/tensorflow_probability/python/distributions/transformed_distribution.py +++ b/tensorflow_probability/python/distributions/transformed_distribution.py @@ -388,7 +388,7 @@ def _log_prob(self, y, **kwargs): return tf.reduce_logsumexp(tf.stack(lp_on_fibers), axis=0) def _prob(self, y, **kwargs): - if not hasattr(self.distribution, '_prob'): + if not hasattr(self.distribution, '_prob') or self.bijector._is_injective: # pylint: disable=protected-access return tf.exp(self._log_prob(y, **kwargs)) distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs) @@ -400,9 +400,6 @@ def _prob(self, y, **kwargs): ) ildj = self.bijector.inverse_log_det_jacobian( y, event_ndims=event_ndims, **bijector_kwargs) - if self.bijector._is_injective: # pylint: disable=protected-access - base_prob = self.distribution.prob(x, **distribution_kwargs) - return base_prob * tf.exp(tf.cast(ildj, base_prob.dtype)) # Compute prob on each element of the inverse image. prob_on_fibers = [] diff --git a/tensorflow_probability/python/distributions/transformed_distribution_test.py b/tensorflow_probability/python/distributions/transformed_distribution_test.py index 53416a55ac..64e3294f3f 100644 --- a/tensorflow_probability/python/distributions/transformed_distribution_test.py +++ b/tensorflow_probability/python/distributions/transformed_distribution_test.py @@ -42,6 +42,7 @@ from tensorflow_probability.python.bijectors import split from tensorflow_probability.python.bijectors import tanh from tensorflow_probability.python.distributions import beta +from tensorflow_probability.python.distributions import dirichlet from tensorflow_probability.python.distributions import exponential from tensorflow_probability.python.distributions import independent from tensorflow_probability.python.distributions import joint_distribution_auto_batched as jdab @@ -54,6 +55,7 @@ from tensorflow_probability.python.distributions import normal as normal_lib from tensorflow_probability.python.distributions import sample as sample_lib from tensorflow_probability.python.distributions import transformed_distribution +from tensorflow_probability.python.distributions import uniform from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensorshape_util @@ -650,6 +652,26 @@ def testLogProbRatio(self): # oracle_64, d0.log_prob(x0) - d1.log_prob(x1), # rtol=0., atol=0.007) + @test_util.numpy_disable_test_missing_functionality('b/306384754') + def testLogProbMatchesProbDirichlet(self): + # This was https://github.com/tensorflow/probability/issues/1761 + scaled_dir = transformed_distribution.TransformedDistribution( + distribution=dirichlet.Dirichlet([2.0, 3.0]), + bijector=scale_lib.Scale(2.0)) + x = np.array([0.2, 1.8], dtype=np.float32) + self.assertAllClose(scaled_dir.prob(x), + tf.exp(scaled_dir.log_prob(x))) + + @test_util.numpy_disable_test_missing_functionality('b/306384754') + def testLogProbMatchesProbUniform(self): + # Uniform does not define _log_prob + scaled_uniform = transformed_distribution.TransformedDistribution( + distribution=uniform.Uniform(), + bijector=scale_lib.Scale(2.0)) + x = np.array([0.2], dtype=np.float32) + self.assertAllClose(scaled_uniform.prob(x), + tf.exp(scaled_uniform.log_prob(x))) + @test_util.test_all_tf_execution_regimes class ScalarToMultiTest(test_util.TestCase):