
Mixture of tuple-valued distributions for TFP on JAX raises a TypeError #1818

Open
pawel-czyz opened this issue Jun 28, 2024 · 0 comments

pawel-czyz commented Jun 28, 2024

Short description

The Mixture distribution does not seem to be compatible with tuple-valued (structured) component distributions such as JointDistributionSequential: constructing the mixture raises an error:

TypeError: Dimension value must be integer or None or have an __index__ method, got value 'TensorShape([])' with type '<class 'tensorflow_probability.python.internal.backend.jax.gen.tensor_shape.TensorShape'>'

Such distributions appear, e.g., in:

I'm not sure where the issue lies, but with some guidance I'd be more than happy to work on a fix! 🙂

Code example

Consider three structured distributions. Each of them samples a pair (int, float): a Bernoulli draw, followed by a Normal draw whose mean is shifted by that Bernoulli value:

from tensorflow_probability.substrates import jax as tfp

import jax
import jax.numpy as jnp

tfd = tfp.distributions

dist1 = tfd.JointDistributionSequential([tfd.Bernoulli(probs=0.1), lambda x: tfd.Normal(0.0 + x, 0.2)])
dist2 = tfd.JointDistributionSequential([tfd.Bernoulli(probs=0.3), lambda x: tfd.Normal(1.0 + x, 0.2)])
dist3 = tfd.JointDistributionSequential([tfd.Bernoulli(probs=0.5), lambda x: tfd.Normal(2.0 + x, 0.2)])

probs = jnp.asarray([0.05, 0.1, 0.85])

mixture = tfd.Mixture(  # Raises the TypeError quoted above.
    cat=tfd.Categorical(probs=probs),
    components=[dist1, dist2, dist3],
)
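Until Mixture accepts structured components, one possible workaround is to marginalize the component index by hand. Writing π_k for the mixture weights and p_k for the k-th component, the mixture density is p(b, y) = Σ_k π_k p_k(b, y), so its log-density is a logsumexp over log π_k + log p_k(b, y), and sampling can proceed ancestrally (draw k, then draw (b, y) from component k). This is a sketch, not a TFP API: it uses only `jax.numpy`, `jax.scipy` and `jax.random`, hard-codes the parameters of dist1, dist2 and dist3 from the example above, and the helper names (`component_log_probs`, `mixture_log_prob`, `mixture_sample`) are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Parameters of dist1, dist2, dist3 from the example (hard-coded for this sketch).
mix_probs = jnp.asarray([0.05, 0.1, 0.85])  # Categorical mixture weights
bern_probs = jnp.asarray([0.1, 0.3, 0.5])   # Bernoulli prob of each component
shifts = jnp.asarray([0.0, 1.0, 2.0])       # Normal mean offset of each component
scale = 0.2                                 # Normal scale shared by all components


def component_log_probs(b, y):
    """Hypothetical helper: log p_k(b, y) for every component k, shape [..., 3].

    b is expected to be 0 or 1; y is a real number.
    """
    b = jnp.asarray(b)[..., None]
    y = jnp.asarray(y)[..., None]
    log_bern = jnp.where(b == 1, jnp.log(bern_probs), jnp.log1p(-bern_probs))
    log_norm = norm.logpdf(y, loc=shifts + b, scale=scale)
    return log_bern + log_norm


def mixture_log_prob(b, y):
    """log sum_k pi_k p_k(b, y), computed stably with logsumexp."""
    return logsumexp(jnp.log(mix_probs) + component_log_probs(b, y), axis=-1)


def mixture_sample(key, n):
    """Ancestral sampling: k ~ Categorical, b ~ Bernoulli, y ~ Normal."""
    key_k, key_b, key_y = jax.random.split(key, 3)
    k = jax.random.categorical(key_k, jnp.log(mix_probs), shape=(n,))
    b = jax.random.bernoulli(key_b, bern_probs[k]).astype(jnp.int32)
    y = shifts[k] + b + scale * jax.random.normal(key_y, (n,))
    return b, y


b, y = mixture_sample(jax.random.PRNGKey(0), 5)
print(b, y, mixture_log_prob(b, y))
```

With TFP available, `component_log_probs` could instead stack `d.log_prob([b, y])` for each of the three joint distributions; the logsumexp step stays the same.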

Version

I used Python 3.11 and TFP 0.20.1. I then reproduced the behavior using the development version, 0.25.0.dev20240628.

Full error message

TypeError                                 Traceback (most recent call last)
Cell In[5], line 14
     10 dist3 = tfd.JointDistributionSequential([tfd.Bernoulli(probs=0.5), lambda x: tfd.Normal(2.0 + x, 0.2)])
     12 probs = jnp.asarray([0.05, 0.1, 0.85])
---> 14 mixture = tfd.Mixture(
     15     cat=tfd.Categorical(probs=probs),
     16     components=[dist1, dist2, dist3],
     17 )

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/mixture.py:155, in _Mixture.__init__(self, cat, components, validate_args, allow_nan_stats, name)
    153 static_batch_shape = cat.batch_shape
    154 for di, d in enumerate(components):
--> 155   if not tensorshape_util.is_compatible_with(static_batch_shape,
    156                                              d.batch_shape):
    157     raise ValueError(
    158         'components[{}] batch shape must be compatible with cat '
    159         'shape and other component batch shapes ({} vs {})'.format(
    160             di, static_batch_shape, d.batch_shape))
    161   if not tensorshape_util.is_compatible_with(static_event_shape,
    162                                              d.event_shape):

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/internal/tensorshape_util.py:211, in is_compatible_with(x, other)
    199 def is_compatible_with(x, other):
    200   """Returns `True` iff `x` is compatible with `other`.
    201 
    202   For more details, see `help(tf.TensorShape.is_compatible_with)`.
   (...)
    209     is_compatible: `bool` indicating of the shapes are compatible.
    210   """
--> 211   return tf.TensorShape(x).is_compatible_with(other)

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:1437, in TensorShape.is_compatible_with(self, other)
   1399 def is_compatible_with(self, other):
   1400   """Returns True iff `self` is compatible with `other`.
   1401 
   1402   Two possibly-partially-defined shapes are compatible if there
   (...)
   1435 
   1436   """
-> 1437   other = as_shape(other)
   1438   if self.dims is not None and other.dims is not None:
   1439     if self.rank != other.rank:

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:1624, in as_shape(shape)
   1622   return shape
   1623 else:
-> 1624   return TensorShape(shape)

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:905, in TensorShape.__init__(self, dims)
    896 """Creates a new TensorShape with the given dimensions.
    897 
    898 Args:
   (...)
    902   TypeError: If dims cannot be converted to a list of dimensions.
    903 """
    904 if isinstance(dims, (tuple, list)):  # Most common case.
--> 905   self._dims = tuple(as_dimension(d).value for d in dims)
    906 elif dims is None:
    907   self._dims = None

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:905, in <genexpr>(.0)
    896 """Creates a new TensorShape with the given dimensions.
    897 
    898 Args:
   (...)
    902   TypeError: If dims cannot be converted to a list of dimensions.
    903 """
    904 if isinstance(dims, (tuple, list)):  # Most common case.
--> 905   self._dims = tuple(as_dimension(d).value for d in dims)
    906 elif dims is None:
    907   self._dims = None

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:819, in as_dimension(value)
    817   return value
    818 else:
--> 819   return Dimension(value)

File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:295, in Dimension.__init__(self, value)
    293   self._value = int(value.__index__())
    294 except AttributeError:
--> 295   raise TypeError(
    296       "Dimension value must be integer or None or have "
    297       "an __index__ method, got value '{0!r}' with type '{1!r}'".format(
    298           value, type(value))) from None
    299 if self._value < 0:
    300   raise ValueError("Dimension %d must be >= 0" % self._value)

TypeError: Dimension value must be integer or None or have an __index__ method, got value 'TensorShape([])' with type '<class 'tensorflow_probability.python.internal.backend.jax.gen.tensor_shape.TensorShape'>'
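Reading the traceback bottom-up: Mixture.__init__ passes cat.batch_shape and each d.batch_shape to tensorshape_util.is_compatible_with, which ends up calling TensorShape(d.batch_shape). Judging from the frames (the tuple/list branch is taken and the offending element is TensorShape([])), the batch shape of a JointDistributionSequential is a list with one shape per part, so TensorShape treats that list as a list of dimensions and fails when converting each part's shape to a Dimension. The pure-Python sketch below mimics that conversion with simplified, hypothetical stand-ins (`Shape`, `as_dimension`, `structured_is_compatible`; these are not the TFP internals) and contrasts it with a leaf-wise check of the kind a structure-aware fix might use. It only addresses the shape check in this traceback; whether the rest of Mixture handles structured samples is not established here.

```python
class Shape:
    """Simplified stand-in for TensorShape (hypothetical, for illustration only)."""

    def __init__(self, dims):
        # Like TensorShape.__init__: every entry must be convertible to a dimension.
        self.dims = tuple(as_dimension(d) for d in dims)

    def is_compatible_with(self, other):
        other = other if isinstance(other, Shape) else Shape(other)  # like as_shape()
        if len(self.dims) != len(other.dims):
            return False
        return all(a is None or b is None or a == b for a, b in zip(self.dims, other.dims))

    def __repr__(self):
        return f"Shape({list(self.dims)})"


def as_dimension(value):
    """Like Dimension.__init__: accept None or anything with __index__."""
    if value is None:
        return None
    try:
        return int(value.__index__())
    except AttributeError:
        raise TypeError(f"Dimension value must be integer or None, got {value!r}") from None


cat_batch_shape = Shape([])                     # scalar batch shape of the Categorical
component_batch_shape = [Shape([]), Shape([])]  # one scalar batch shape per JD part

# Passing the structured shape straight through reproduces the failure.
try:
    cat_batch_shape.is_compatible_with(component_batch_shape)
except TypeError as err:
    print("TypeError:", err)


def structured_is_compatible(shape, structured):
    """Hypothetical leaf-wise check: compare `shape` with each part's shape."""
    return all(shape.is_compatible_with(part) for part in structured)


print(structured_is_compatible(cat_batch_shape, component_batch_shape))
```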
pawel-czyz changed the title from "Mixture for tuple-valued distributions for TFP on JAX" to "Mixture of tuple-valued distributions for TFP on JAX raises a TypeError" on Jun 28, 2024.