The Mixture distribution does not seem to be compatible with tuple-valued distributions, raising an error:
TypeError: Dimension value must be integer or None or have an __index__ method, got value 'TensorShape([])' with type '<class 'tensorflow_probability.python.internal.backend.jax.gen.tensor_shape.TensorShape'>'
I used Python 3.11 and TFP 0.20.1. I then reproduced the behavior using the development version, 0.25.0.dev20240628.
Full error message
TypeError Traceback (most recent call last)
Cell In[5], line 14
10 dist3 = tfd.JointDistributionSequential([tfd.Bernoulli(probs=0.5), lambda x: tfd.Normal(2.0 + x, 0.2)])
12 probs = jnp.asarray([0.05, 0.1, 0.85])
---> 14 mixture = tfd.Mixture(
15 cat=tfd.Categorical(probs=probs),
16 components=[dist1, dist2, dist3],
17 )
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
230 if not kwsyntax:
231 args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
339 # Note: if we ever want to have things set in `self` before `__init__` is
340 # called, here is the place to do it.
341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
343 # Note: if we ever want to override things set in `self` by subclass
344 # `__init__`, here is the place to do it.
345 if self_._parameters is None:
346 # We prefer subclasses will set `parameters = dict(locals())` because
347 # this has nearly zero overhead. However, failing to do this, we will
348 # resolve the input arguments dynamically and only when needed.
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/mixture.py:155, in _Mixture.__init__(self, cat, components, validate_args, allow_nan_stats, name)
153 static_batch_shape = cat.batch_shape
154 for di, d in enumerate(components):
--> 155 if not tensorshape_util.is_compatible_with(static_batch_shape,
156 d.batch_shape):
157 raise ValueError(
158 'components[{}] batch shape must be compatible with cat '
159 'shape and other component batch shapes ({} vs {})'.format(
160 di, static_batch_shape, d.batch_shape))
161 if not tensorshape_util.is_compatible_with(static_event_shape,
162 d.event_shape):
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/internal/tensorshape_util.py:211, in is_compatible_with(x, other)
199 def is_compatible_with(x, other):
200 """Returns `True` iff `x` is compatible with `other`.
201
202 For more details, see `help(tf.TensorShape.is_compatible_with)`.
(...)
209 is_compatible: `bool` indicating of the shapes are compatible.
210 """
--> 211 return tf.TensorShape(x).is_compatible_with(other)
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:1437, in TensorShape.is_compatible_with(self, other)
1399 def is_compatible_with(self, other):
1400 """Returns True iff `self` is compatible with `other`.
1401
1402 Two possibly-partially-defined shapes are compatible if there
(...)
1435
1436 """
-> 1437 other = as_shape(other)
1438 if self.dims is not None and other.dims is not None:
1439 if self.rank != other.rank:
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:1624, in as_shape(shape)
1622 return shape
1623 else:
-> 1624 return TensorShape(shape)
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:905, in TensorShape.__init__(self, dims)
896 """Creates a new TensorShape with the given dimensions.
897
898 Args:
(...)
902 TypeError: If dims cannot be converted to a list of dimensions.
903 """
904 if isinstance(dims, (tuple, list)): # Most common case.
--> 905 self._dims = tuple(as_dimension(d).value for d in dims)
906 elif dims is None:
907 self._dims = None
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:905, in <genexpr>(.0)
896 """Creates a new TensorShape with the given dimensions.
897
898 Args:
(...)
902 TypeError: If dims cannot be converted to a list of dimensions.
903 """
904 if isinstance(dims, (tuple, list)): # Most common case.
--> 905 self._dims = tuple(as_dimension(d).value for d in dims)
906 elif dims is None:
907 self._dims = None
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:819, in as_dimension(value)
817 return value
818 else:
--> 819 return Dimension(value)
File ~/micromamba/envs/bmi/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py:295, in Dimension.__init__(self, value)
293 self._value = int(value.__index__())
294 except AttributeError:
--> 295 raise TypeError(
296 "Dimension value must be integer or None or have "
297 "an __index__ method, got value '{0!r}' with type '{1!r}'".format(
298 value, type(value))) from None
299 if self._value < 0:
300 raise ValueError("Dimension %d must be >= 0" % self._value)
TypeError: Dimension value must be integer or None or have an __index__ method, got value 'TensorShape([])' with type '<class 'tensorflow_probability.python.internal.backend.jax.gen.tensor_shape.TensorShape'>'
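The traceback points at the mechanism: `Mixture.__init__` passes each component's `batch_shape` to `tensorshape_util.is_compatible_with`, which builds a flat `TensorShape` from it. For a `JointDistributionSequential`, `batch_shape` appears to be a list holding one shape per part, so the constructor receives `TensorShape([])` where it expects an integer dimension. The pure-Python sketch below models that check with plain lists. It is not TFP code: `as_flat_shape`, `is_compatible_with` and `is_compatible_structured` are hypothetical stand-ins, and the structure-aware variant is only one possible way such a check could treat joint components.

```python
from typing import List, Optional

Shape = List[Optional[int]]  # a flat, possibly partially known shape such as [3] or [None]


def as_flat_shape(dims) -> Shape:
    """Mimic a flat shape constructor: every entry must be int-like or None."""
    out: Shape = []
    for d in dims:
        if d is not None and not hasattr(d, "__index__"):
            # Mirrors the TypeError in the traceback: a nested shape is not a dimension.
            raise TypeError(f"Dimension value must be integer or None, got {d!r}")
        out.append(None if d is None else int(d.__index__()))
    return out


def is_compatible_with(x, other) -> bool:
    """Flat check: same rank, and each pair of dims is equal or unknown."""
    a, b = as_flat_shape(x), as_flat_shape(other)
    if len(a) != len(b):
        return False
    return all(p is None or q is None or p == q for p, q in zip(a, b))


def is_compatible_structured(cat_batch_shape, component_batch_shape) -> bool:
    """Hypothetical structure-aware check: compare cat's batch shape with every part."""
    is_structured = any(isinstance(part, (list, tuple)) for part in component_batch_shape)
    if is_structured:
        return all(is_compatible_with(cat_batch_shape, part) for part in component_batch_shape)
    return is_compatible_with(cat_batch_shape, component_batch_shape)


cat_batch_shape: Shape = []   # Categorical(probs=[0.05, 0.1, 0.85]) has a scalar batch shape
joint_batch_shape = [[], []]  # assumed: one scalar batch shape per part of the joint distribution

try:
    is_compatible_with(cat_batch_shape, joint_batch_shape)
except TypeError as err:
    print("flat check fails:", err)
print("structured check:", is_compatible_structured(cat_batch_shape, joint_batch_shape))
```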
pawel-czyz changed the title from "Mixture for tuple-valued distributions for TFP on JAX" to "Mixture of tuple-valued distributions for TFP on JAX raises a TypeError" on Jun 28, 2024.
Short description
Tuple-valued distributions arise, for example, from `JointDistributionSequential` or the `Split` bijector. I'm not sure where the issue lies, but with some guidance I'd be more than happy to work on a fix! 🙂
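Until `Mixture` accepts structured components, one possible workaround is to evaluate the mixture density by hand as log p(x) = logsumexp over k of (log w_k + log p_k(x)). The sketch below does this in plain Python for components shaped like `dist3` in the traceback (a Bernoulli draw `b` followed by a Normal centered at `loc + b`). It does not use TFP; the helper names are hypothetical, and the parameters of the first two components are invented stand-ins because `dist1` and `dist2` are not shown in the report.

```python
import math
from typing import Callable, List, Sequence, Tuple

Value = Tuple[int, float]  # (Bernoulli draw, Normal draw), like dist3 in the report


def bernoulli_normal_log_prob(p: float, loc: float, scale: float) -> Callable[[Value], float]:
    """Return log p(b, y) for b ~ Bernoulli(p), y | b ~ Normal(loc + b, scale).

    Hypothetical helper; assumes 0 < p < 1 and b in {0, 1}.
    """
    def log_prob(value: Value) -> float:
        b, y = value
        log_pb = math.log(p) if b == 1 else math.log1p(-p)
        z = (y - (loc + b)) / scale
        log_py = -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)
        return log_pb + log_py
    return log_prob


def logsumexp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def mixture_log_prob(weights: Sequence[float],
                     components: Sequence[Callable[[Value], float]],
                     value: Value) -> float:
    """log sum_k w_k * p_k(value), evaluated in log space."""
    terms: List[float] = [math.log(w) + lp(value) for w, lp in zip(weights, components)]
    return logsumexp(terms)


weights = [0.05, 0.1, 0.85]  # mixture weights from the report
components = [
    bernoulli_normal_log_prob(0.5, -2.0, 0.2),  # hypothetical stand-in for dist1
    bernoulli_normal_log_prob(0.5, 0.0, 0.2),   # hypothetical stand-in for dist2
    bernoulli_normal_log_prob(0.5, 2.0, 0.2),   # mirrors dist3 from the traceback
]

print(mixture_log_prob(weights, components, (1, 3.0)))
```

Working in log space keeps the sum stable when some components assign very small densities to a value.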
Code example
Consider three structured distributions, each of which returns a tuple `(int, float)`.
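Drawing `(int, float)` samples from such a mixture can be sketched with ancestral sampling: pick a component index from the categorical weights, then draw `(b, y)` from that component. This is a framework-free illustration, not TFP's `Mixture.sample`; the helper names are hypothetical, and the parameters of the first two components are invented stand-ins for the unshown `dist1` and `dist2`.

```python
import random
from typing import List, Sequence, Tuple

Value = Tuple[int, float]


def sample_bernoulli_normal(rng: random.Random, p: float, loc: float, scale: float) -> Value:
    """Draw b ~ Bernoulli(p), then y ~ Normal(loc + b, scale)."""
    b = 1 if rng.random() < p else 0
    y = rng.gauss(loc + b, scale)
    return (b, y)


def sample_mixture(rng: random.Random,
                   weights: Sequence[float],
                   params: Sequence[Tuple[float, float, float]],
                   n: int) -> List[Value]:
    """Ancestral sampling: k ~ Categorical(weights), then (b, y) from component k."""
    indices = rng.choices(range(len(weights)), weights=weights, k=n)
    return [sample_bernoulli_normal(rng, *params[k]) for k in indices]


rng = random.Random(0)
weights = [0.05, 0.1, 0.85]  # mixture weights from the report
params = [(0.5, -2.0, 0.2), (0.5, 0.0, 0.2), (0.5, 2.0, 0.2)]  # first two are hypothetical
samples = sample_mixture(rng, weights, params, n=10_000)
print(samples[:3])
```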