-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Avoid creating distributions when inferring value domain #420
Conversation
Test failures appear unrelated:
@fehiepsi any suggestions? |
|
@fehiepsi it looks like a different new JAX error related to hashing of |
@fritzo Could you allow me to rerun the build? Sometimes I would like to do it, for example to rerun this PR with the new jax release. |
@eb8680 Opps, I missed your comment. Will take a look soon... |
funsor/memoize.py
Outdated
@@ -17,7 +17,7 @@ def memoize(cache=None): | |||
|
|||
@interpreter.interpretation(interpreter._INTERPRETATION) # use base | |||
def memoize_interpretation(cls, *args): | |||
key = (cls,) + tuple(id(arg) if (type(arg).__name__ == "DeviceArray") or not isinstance(arg, Hashable) | |||
key = (cls,) + tuple(id(arg) if ("DeviceArray" in type(arg).__name__) or not isinstance(arg, Hashable) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be fixed in #405. But I didn't check carefully. It is probably that without changing this line, tests still pass.
funsor/tensor.py
Outdated
elif "Buffer" in type(data).__name__ and get_backend() == "jax": | ||
import jax | ||
|
||
data = jax.numpy.array(data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like we need to add some similar Buffer
-handling logic in Gaussian
, which handles raw backend tensors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@hawkinsp In jax-ml/jax#2839, you mentioned that xla_client.Buffer
is going to be deleted. But in the latest jax+jaxlib release, we are seeing that Buffer
starts appearing (it didn't appear in previous releases). I just wonder if it is will be supported in the long run?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, this is a different Buffer
. The one you are talking about was a Python class. The new one is a C++ class.
To improve dispatch times, we are moving a few parts of the JAX dispatch logic into C++, and in some cases (notably, uncomplicated jit
calls), you will get a pure C++ object out (Buffer
) without the additional Python wrapper class. It should duck type the same (e.g., have the same methods and act the same in most respects).
I believe both the Python and C++ variants report themselves as being instances of DeviceArray
. Eventually we will only have the C++ version.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the very clear explanation, @hawkinsp! I understand what's going on now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fehiepsi thanks for the help!
@fehiepsi thanks for the help! |
Addresses #412
pair coded with @eb8680
This refactors logic in
Distribution._infer_value_domain()
to avoid creating a temporary dummy distribution if backend distributions implement a.infer_shapes()
method.