Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid creating distributions when inferring value domain #420

Merged
merged 10 commits into from
Jan 13, 2021

Conversation

fritzo
Copy link
Member

@fritzo fritzo commented Jan 11, 2021

Addresses #412
pair coded with @eb8680

This refactors logic in Distribution._infer_value_domain() to avoid creating a temporary dummy distribution if backend distributions implement a .infer_shapes() method.

@fritzo fritzo requested a review from eb8680 January 11, 2021 23:25
@fritzo
Copy link
Member Author

fritzo commented Jan 12, 2021

Test failures appear unrelated:

...
from jaxlib import cusolver
ImportError: cannot import name 'cusolver'
Makefile:46: recipe for target 'test' failed

@fehiepsi any suggestions?

@eb8680
Copy link
Member

eb8680 commented Jan 12, 2021

Looks like a new JAX bug? jax-ml/jax#5374 edit: this was fixed upstream

@eb8680
Copy link
Member

eb8680 commented Jan 12, 2021

@fehiepsi it looks like a different new JAX error related to hashing of DeviceArrays is causing the build to fail, any idea what's going on?

@fehiepsi
Copy link
Member

fehiepsi commented Jan 12, 2021

@fritzo Could you allow me to rerun the build? Sometimes I would like to do it, for example to rerun this PR with the new jax release.

@fehiepsi
Copy link
Member

fehiepsi commented Jan 12, 2021

@eb8680 Opps, I missed your comment. Will take a look soon...

@fritzo
Copy link
Member Author

fritzo commented Jan 13, 2021

@fritzo Could you allow me to rerun the build?

@fehiepsi sure, feel free to rerun the builds. Do I need to authorize you on travis? Also feel free to push to this branch.

@@ -17,7 +17,7 @@ def memoize(cache=None):

@interpreter.interpretation(interpreter._INTERPRETATION) # use base
def memoize_interpretation(cls, *args):
key = (cls,) + tuple(id(arg) if (type(arg).__name__ == "DeviceArray") or not isinstance(arg, Hashable)
key = (cls,) + tuple(id(arg) if ("DeviceArray" in type(arg).__name__) or not isinstance(arg, Hashable)
Copy link
Member

@fehiepsi fehiepsi Jan 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be fixed in #405. But I didn't check carefully. It is probably that without changing this line, tests still pass.

funsor/tensor.py Outdated
elif "Buffer" in type(data).__name__ and get_backend() == "jax":
import jax

data = jax.numpy.array(data)
Copy link
Member

@eb8680 eb8680 Jan 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we need to add some similar Buffer-handling logic in Gaussian, which handles raw backend tensors.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hawkinsp In jax-ml/jax#2839, you mentioned that xla_client.Buffer is going to be deleted. But in the latest jax+jaxlib release, we are seeing that Buffer starts appearing (it didn't appear in previous releases). I just wonder if it is will be supported in the long run?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, this is a different Buffer. The one you are talking about was a Python class. The new one is a C++ class.

To improve dispatch times, we are moving a few parts of the JAX dispatch logic into C++, and in some cases (notably, uncomplicated jit calls), you will get a pure C++ object out (Buffer) without the additional Python wrapper class. It should duck type the same (e.g., have the same methods and act the same in most respects).

I believe both the Python and C++ variants report themselves as being instances of DeviceArray. Eventually we will only have the C++ version.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the very clear explanation, @hawkinsp! I understand what's going on now.

Copy link
Member

@eb8680 eb8680 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fehiepsi thanks for the help!

@eb8680 eb8680 merged commit 8fd6a32 into master Jan 13, 2021
@eb8680 eb8680 deleted the infer-value-dtype branch January 13, 2021 22:59
@fritzo
Copy link
Member Author

fritzo commented Jan 13, 2021

@fehiepsi thanks for the help!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants