Use tangent space in TransformedDistribution.prob.
PiperOrigin-RevId: 576156013
ColCarroll authored and tensorflower-gardener committed Oct 24, 2023
1 parent 346cb6f commit c3586af
Showing 3 changed files with 25 additions and 4 deletions.
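For context: the diff below changes TransformedDistribution._prob so that, whenever the bijector is injective, it returns tf.exp(self._log_prob(y)) instead of multiplying the base distribution's prob by exp(ildj). That way the tangent-space correction applied in _log_prob (relevant for distributions with embedded supports, such as Dirichlet on the simplex) also shows up in prob. The following is a minimal public-API sketch of the check in question, assuming a TFP build that already includes this commit; it mirrors the new testLogProbMatchesProbDirichlet test further down, rewritten against tfp.distributions / tfp.bijectors:

# Sketch only: mirrors the scenario from tensorflow/probability issue #1761.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd, tfb = tfp.distributions, tfp.bijectors

# Dirichlet is supported on the simplex, so its log_prob carries a
# tangent-space correction; Scale(2.) moves the support to the scaled simplex.
scaled_dir = tfd.TransformedDistribution(
    distribution=tfd.Dirichlet([2.0, 3.0]),
    bijector=tfb.Scale(2.0))

x = np.array([0.2, 1.8], dtype=np.float32)
# With this commit, the two quantities agree.
np.testing.assert_allclose(scaled_dir.prob(x),
                           tf.exp(scaled_dir.log_prob(x)), rtol=1e-5)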
2 changes: 2 additions & 0 deletions tensorflow_probability/python/distributions/BUILD
@@ -4540,6 +4540,7 @@ multi_substrate_py_test(
    tags = ["colab-smoke"],
    deps = [
        ":beta",
+        ":dirichlet",
        ":exponential",
        ":independent",
        ":joint_distribution_auto_batched",
@@ -4552,6 +4553,7 @@ multi_substrate_py_test(
        ":normal",
        ":sample",
        ":transformed_distribution",
+        ":uniform",
        # numpy dep,
        # scipy dep,
        # tensorflow dep,
5 changes: 1 addition & 4 deletions tensorflow_probability/python/distributions/transformed_distribution.py
@@ -388,7 +388,7 @@ def _log_prob(self, y, **kwargs):
    return tf.reduce_logsumexp(tf.stack(lp_on_fibers), axis=0)

  def _prob(self, y, **kwargs):
-    if not hasattr(self.distribution, '_prob'):
+    if not hasattr(self.distribution, '_prob') or self.bijector._is_injective:  # pylint: disable=protected-access
      return tf.exp(self._log_prob(y, **kwargs))
    distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)

@@ -400,9 +400,6 @@ def _prob(self, y, **kwargs):
    )
    ildj = self.bijector.inverse_log_det_jacobian(
        y, event_ndims=event_ndims, **bijector_kwargs)
-    if self.bijector._is_injective:  # pylint: disable=protected-access
-      base_prob = self.distribution.prob(x, **distribution_kwargs)
-      return base_prob * tf.exp(tf.cast(ildj, base_prob.dtype))

    # Compute prob on each element of the inverse image.
    prob_on_fibers = []
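To make the dispatch easier to follow, here is a deliberately simplified, self-contained model of _prob after this change. None of the names below (transformed_prob, base_log_prob, ildj, and so on) are TFP internals; the point is only the control flow: injective bijectors now always go through log space, so prob(y) equals exp(log_prob(y)) by construction, and the direct base_prob * exp(ildj) product survives only on the non-injective, per-fiber path.

# Illustrative sketch of the new dispatch; not the actual TFP implementation.
import math


def transformed_prob(y, base_log_prob, base_prob, inverse, ildj, is_injective):
  # Stand-in for the full _log_prob: base log_prob at the preimage plus the
  # inverse log-det-Jacobian (plus, in real TFP, any tangent-space correction).
  log_prob_via_pullback = lambda t: base_log_prob(inverse(t)) + ildj(t)
  if base_prob is None or is_injective:
    # New behavior: injective bijectors always go through log space, so
    # prob(y) == exp(log_prob(y)) by construction.
    return math.exp(log_prob_via_pullback(y))
  # Non-injective case (unchanged in the commit): sum base_prob * exp(ildj)
  # over every element of the inverse image; here inverse(y) and ildj(y)
  # would return one value per preimage ("fiber").
  return sum(base_prob(x_k) * math.exp(j_k)
             for x_k, j_k in zip(inverse(y), ildj(y)))


# Toy check: Uniform(0, 1) pushed through y = 2 * x has density 1/2 on (0, 2).
prob = transformed_prob(
    y=0.2,
    base_log_prob=lambda x: 0.0 if 0.0 < x < 1.0 else -math.inf,
    base_prob=lambda x: 1.0 if 0.0 < x < 1.0 else 0.0,
    inverse=lambda t: t / 2.0,
    ildj=lambda t: -math.log(2.0),
    is_injective=True)
assert abs(prob - 0.5) < 1e-12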
22 changes: 22 additions & 0 deletions tensorflow_probability/python/distributions/transformed_distribution_test.py
@@ -42,6 +42,7 @@
from tensorflow_probability.python.bijectors import split
from tensorflow_probability.python.bijectors import tanh
from tensorflow_probability.python.distributions import beta
+from tensorflow_probability.python.distributions import dirichlet
from tensorflow_probability.python.distributions import exponential
from tensorflow_probability.python.distributions import independent
from tensorflow_probability.python.distributions import joint_distribution_auto_batched as jdab
@@ -54,6 +55,7 @@
from tensorflow_probability.python.distributions import normal as normal_lib
from tensorflow_probability.python.distributions import sample as sample_lib
from tensorflow_probability.python.distributions import transformed_distribution
+from tensorflow_probability.python.distributions import uniform
from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import tensorshape_util
@@ -650,6 +652,26 @@ def testLogProbRatio(self):
    # oracle_64, d0.log_prob(x0) - d1.log_prob(x1),
    # rtol=0., atol=0.007)

+  @test_util.numpy_disable_test_missing_functionality('b/306384754')
+  def testLogProbMatchesProbDirichlet(self):
+    # This was https://github.com/tensorflow/probability/issues/1761
+    scaled_dir = transformed_distribution.TransformedDistribution(
+        distribution=dirichlet.Dirichlet([2.0, 3.0]),
+        bijector=scale_lib.Scale(2.0))
+    x = np.array([0.2, 1.8], dtype=np.float32)
+    self.assertAllClose(scaled_dir.prob(x),
+                        tf.exp(scaled_dir.log_prob(x)))
+
+  @test_util.numpy_disable_test_missing_functionality('b/306384754')
+  def testLogProbMatchesProbUniform(self):
+    # Uniform does not define _log_prob
+    scaled_uniform = transformed_distribution.TransformedDistribution(
+        distribution=uniform.Uniform(),
+        bijector=scale_lib.Scale(2.0))
+    x = np.array([0.2], dtype=np.float32)
+    self.assertAllClose(scaled_uniform.prob(x),
+                        tf.exp(scaled_uniform.log_prob(x)))
+

@test_util.test_all_tf_execution_regimes
class ScalarToMultiTest(test_util.TestCase):
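The two new tests above cover both triggers of the early return: the Dirichlet case exercises the injective-bijector branch where the tangent-space correction in _log_prob must flow into prob, and the Uniform case (whose base distribution, per the comment, does not define _log_prob) checks that the round trip stays consistent when only one of the raw density methods is implemented. For completeness, an equivalent check written against the public API; the values mirror the test, while the rtol is an arbitrary choice of mine:

# Public-API sketch of testLogProbMatchesProbUniform above; illustrative only.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd, tfb = tfp.distributions, tfp.bijectors

scaled_uniform = tfd.TransformedDistribution(
    distribution=tfd.Uniform(),   # Uniform(0, 1) by default.
    bijector=tfb.Scale(2.0))      # Support becomes (0, 2), density 1/2.

x = np.array([0.2], dtype=np.float32)
np.testing.assert_allclose(scaled_uniform.prob(x),
                           tf.exp(scaled_uniform.log_prob(x)), rtol=1e-6)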
