Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid creating distributions when inferring value domain #420

Merged
merged 10 commits into from
Jan 13, 2021
9 changes: 9 additions & 0 deletions funsor/cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,15 @@ def _eager_contract_tensors(reduced_vars, terms, backend):
if dim in symbols))
equation = ",".join(einsum_inputs) + "->" + einsum_output
data = opt_einsum.contract(equation, *operands, backend=backend)
# XXX: in jaxlib 0.1.58 + jax 0.2.8, opt_einsum.contract returns
# a jaxlib.xla_client.Buffer, which is a type not supported in Tensor.
# It is unclear whether this is an issue. The good thing is: under jit,
# data will be a ShapedArray, so we won't go to this branch to do
# the extra job `jax.numpy.array(data)`.
if "Buffer" in type(data).__name__:
import jax

data = jax.numpy.array(data)
data = data.reshape(batch_shape + event_shape)
return Tensor(data, inputs)

Expand Down
2 changes: 1 addition & 1 deletion funsor/memoize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def memoize(cache=None):

@interpreter.interpretation(interpreter._INTERPRETATION) # use base
def memoize_interpretation(cls, *args):
key = (cls,) + tuple(id(arg) if (type(arg).__name__ == "DeviceArray") or not isinstance(arg, Hashable)
key = (cls,) + tuple(id(arg) if ("DeviceArray" in type(arg).__name__) or not isinstance(arg, Hashable)
Copy link
Member

@fehiepsi fehiepsi Jan 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be fixed in #405. But I didn't check carefully. It is probably that without changing this line, tests still pass.

else arg for arg in args)
if key not in cache:
cache[key] = cls(*args)
Expand Down
2 changes: 1 addition & 1 deletion funsor/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def reflect(cls, *args, **kwargs):
_, args = args, new_args

# JAX DeviceArray has .__hash__ method but raise the unhashable error there.
cache_key = tuple(id(arg) if type(arg).__name__ == "DeviceArray" or not isinstance(arg, Hashable)
cache_key = tuple(id(arg) if ("DeviceArray" in type(arg).__name__) or not isinstance(arg, Hashable)
else arg for arg in args)
if cache_key in cls._cons_cache:
return cls._cons_cache[cache_key]
Expand Down