Compatibility with the Array API standard #7848
Comments
Probably the most important missing feature for Xarray in the Array API standard is support for indexing with integer arrays: data-apis/array-api#669 |
Quick update from the Array API standard, both |
Hi! I'm wondering if this is still a concerted effort? Def would love to see tight Xarray <-> Jax integration. |
I don't have any knowledge about There might be a few operations that are needlessly verbose, though, so someone would have to create an integration package ( |
Hey @allen-adastra, great to see you here. Wrapping JAX with Xarray is a great idea in need of a champion / beta-tester. As Justus says, it should just work, but most of the Xarray devs concentrate on making the high-level wrappers compliant, rather than exposing all the edge cases of the specific libraries. If you or one of your colleagues wants to just have a go and report back what happens, that would be a great start! |
Gotcha, so it sounds like Xarray should have compliance with the Array API? Probably too much for me to take on, but going to talk around to see if we can drum up sufficient support somehow... |
at least as far as dispatching to wrapped arrays goes, yes. The full Array API is not a good fit for |
To clarify, xarray aims to be able to wrap any array-API compliant object. This should already work, and is already being used to wrap some array types successfully (e.g. cubed), but it's likely that trying this with JAX will turn up some corner cases. I think @keewis is referring to the separate idea of xarray objects (e.g. So all I'm suggesting you do, @allen-adastra, is try wrapping a JAX array with xarray, see what happens, and report back any failures. I'm not suggesting you take on the task of making xarray work for any array-API compliant object (as we're already trying to do that for the general case). |
Hiya, since JAX now supports the array API I've been playing around with it a bit. Mostly things work well, but one thing that seems missing is a nice way to use ufuncs specific to the array API implementation of the wrapped array. I don't think this is necessarily specific to JAX, but I'll use JAX to illustrate it:
import xarray
import jax.numpy as jnp
import numpy as np
v = xarray.Variable(('foo', 'bar'), jnp.ones((2,3)))
All goes well so far, but what if we want to apply a ufunc like np.exp? Applying np.exp directly will convert the underlying JAX array to numpy:
type(np.exp(v).data)
=> numpy.ndarray
Could this be avoided? If xarray is intercepting these ufunc calls from numpy, perhaps it could dispatch them via …
Applying the JAX equivalent, jnp.exp, to the xarray object just doesn't work:
jnp.exp(v)
=> TypeError: Argument '<xarray.Variable (foo: 2, bar: 3)> ... ' of type <class 'xarray.core.variable.Variable'> is not a valid JAX type.
The following works:
xarray.apply_ufunc(jnp.exp, v)
but is more verbose and requires one to know the implementation of the underlying array in order to pick the right namespace (np vs jnp vs ...). We can work around the 'need to know the implementation' problem with something like:
xarray.apply_ufunc(lambda v: v.__array_namespace__().exp(v), v)
But this is crazy verbose, and doesn't work for numpy < 2. A helper could make it slightly less bad, e.g.:
xarray.apply_api_ufunc(lambda np: np.exp, v)
or
xarray.apply_api_ufunc(lambda x, np: np.exp(x), v)
although still not great. Perhaps a better alternative could be something like
import xarray.ufuncs as xnp
xnp.exp(v)
? This avoids either the need for boilerplate, or any confusion that might arise from e.g. a |
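The dispatch idea in the comment above can be sketched without xarray or JAX installed. This is a minimal, hypothetical illustration only: `ToyArray`, `toy_namespace` and `Wrapped` are stand-ins invented here (for a JAX array, `jax.numpy` and `xarray.Variable` respectively), and the `exp` function is not xarray's actual implementation. It just shows the function being looked up on whatever namespace the wrapped array reports through `__array_namespace__`.

```python
import math
from types import SimpleNamespace


class ToyArray:
    """Hypothetical stand-in for an array API array such as a JAX array."""

    def __init__(self, values):
        self.values = list(values)

    def __array_namespace__(self, api_version=None):
        # Array API objects return the namespace implementing their functions.
        return toy_namespace


# Hypothetical stand-in for a namespace like jax.numpy.
toy_namespace = SimpleNamespace(
    exp=lambda x: ToyArray(math.exp(v) for v in x.values),
)


class Wrapped:
    """Hypothetical stand-in for xarray.Variable: dims plus a wrapped array."""

    def __init__(self, dims, data):
        self.dims = dims
        self.data = data


def exp(obj):
    """Apply exp via the wrapped array's own namespace, preserving its type."""
    xp = obj.data.__array_namespace__()
    return Wrapped(obj.dims, xp.exp(obj.data))


v = Wrapped(("foo",), ToyArray([0.0, 1.0]))
result = exp(v)
print(type(result.data).__name__)  # the wrapped array type is not converted
```

The caller never names the underlying library, which is the property the `xnp.exp(v)` proposal is after.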
xarray.ufuncs used to be a thing. We should probably bring it back!
|
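I think so. I might be ignoring history here but couldn't we make np.exp(v) dispatch to jnp.exp by changing v.__array_function__ to look in v.data.? Not sure if that would be confusing though, to use np.exp to call jnp.exp ... |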
Now I think of it: this is not really xarray-specific, but I would have expected there to be a canonical 'generic' array API namespace that dispatches based on array_namespace of the arguments, e.g.
Does anything like this exist, and if not, should it exist? And could xarray piggy-back off it without implementing the array API for xarray datatypes directly? [EDIT: to answer my own question, I guess polymorphic dispatch for multiple-argument ufuncs isn't entirely trivial, although it still feels like some reasonable heuristics could be implemented, or worst-case errors raised if arguments are incompatible] |
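The multi-argument heuristic mentioned in the EDIT (pick one common namespace, raise if the arguments are incompatible) might look roughly like the sketch below. All names here (`common_namespace`, `add`, the toy classes `A` and `B` and their namespaces) are hypothetical and made up for illustration; as far as I know, array-api-compat's `array_namespace` helper performs a similar check for real libraries.

```python
from types import SimpleNamespace


def _make_array_type(name):
    """Build a toy array class plus the namespace it reports (both hypothetical)."""
    ns = SimpleNamespace(name=name)

    class Array:
        def __init__(self, value):
            self.value = value

        def __array_namespace__(self, api_version=None):
            return ns

    ns.add = lambda x, y: Array(x.value + y.value)
    return Array, ns


A, ns_a = _make_array_type("a")
B, ns_b = _make_array_type("b")


def common_namespace(*arrays):
    """Return the single namespace shared by all arrays, or raise TypeError."""
    namespaces = []
    for arr in arrays:
        ns = arr.__array_namespace__()
        if not any(ns is seen for seen in namespaces):
            namespaces.append(ns)
    if len(namespaces) != 1:
        names = [ns.name for ns in namespaces]
        raise TypeError(f"expected exactly one array namespace, got {names}")
    return namespaces[0]


def add(x, y):
    """Generic binary function that dispatches on the arguments' namespace."""
    return common_namespace(x, y).add(x, y)


print(add(A(1), A(2)).value)
try:
    add(A(1), B(2))
except TypeError as err:
    print("incompatible:", err)
```

Raising on mixed namespaces is the conservative choice; a promotion rule between libraries would need a policy that this sketch does not attempt.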
Yeah I'd be happy with this solution too.
In that case, when dispatching you might have to special-case versions of
numpy that don't directly implement the array API (including
__array_namespace__).
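A minimal sketch of that special case, assuming a hypothetical helper `get_namespace` and a caller-supplied `fallback` namespace (which would play the role of the `numpy` module on versions that lack `__array_namespace__`). `ModernArray` and `LegacyArray` are toy classes, not real library types.

```python
from types import SimpleNamespace


def get_namespace(array, fallback):
    """Return array.__array_namespace__() if available, else the fallback."""
    method = getattr(array, "__array_namespace__", None)
    if method is None:
        # e.g. arrays from a library version that predates the protocol
        return fallback
    return method()


modern_ns = SimpleNamespace(name="modern")
legacy_ns = SimpleNamespace(name="legacy")  # plays the role of the numpy module


class ModernArray:
    def __array_namespace__(self, api_version=None):
        return modern_ns


class LegacyArray:
    """Toy array without __array_namespace__."""


print(get_namespace(ModernArray(), legacy_ns).name)
print(get_namespace(LegacyArray(), legacy_ns).name)
```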
|
FYI @mjwillson you might also be interested in the discussion here about how xarray could call jax's vmap within apply_ufunc. |
@mjwillson Hello! I'm curious how you're avoiding conversion to a numpy array? Running this:
Gives |
(on the latest versions of xarray + jax btw)
|
Hi @allen-adastra ! I suspect that for some reason the JAX array is not making it through the This is not supposed to happen if JAX correctly implements EDIT: It looks like EDIT2: Looks like the |
Going back to @mjwillson 's question in #7848 (comment):
I just spoke to @lucascolley at the NumFOCUS summit about this, and we decided we agree with @shoyer that we should bring back xarray.ufuncs:
import xarray.ufuncs as xu
xu.exp(v)
The reasoning is:
The original If anyone wants to take a stab at adding this |
Ah, found the issue. |
For now, I implemented some small hacks in a fork, and am getting somewhere experimenting with an Xarray + JAX integration (experimental repo). Probably best to continue a more in-depth convo in the existing thread, but @mjwillson I'd be curious to hear your thoughts on the "dummy pytrees" issue I mention here |
In case anyone is interested in using JAX<0.4.32 for some reason, it would be possible to modify the xarray logic to use (or just copy from) array-api-compat's |
The problem is that Xarray objects shouldn't implement |
I would have a go at this, although it looks like it would be easier for someone with more context about the old ufuncs.py. The old version contains some multi-argument dispatch logic which I'm not sure if you'd want to preserve or not, for example. And it goes via _unary_op and _binary_op, so not clear to me if the dispatch based on underlying array type should live there, or in ufuncs.py. If it lives in _unary_op/_binary_op, could that be disentangled from the dispatch logic in ufuncs.py, etc. |
Alternatively would ufuncs.py now be better implemented just as thin wrappers around apply_ufunc? |
Yes, I would lean towards this. If I recall correctly, |
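As a rough illustration of the "thin wrappers" idea, the sketch below generates one wrapper per function name from a list. Everything here is hypothetical: `apply_ufunc` is a toy stand-in with a much simpler signature than `xarray.apply_ufunc`, and `Wrapped`, `ToyArray`, `toy_ns` and `_make_unary` are names invented for this example rather than anything from the old or new `ufuncs.py`.

```python
import math
from types import SimpleNamespace


class Wrapped:
    """Hypothetical stand-in for an xarray object holding `.data`."""

    def __init__(self, data):
        self.data = data


class ToyArray:
    """Hypothetical array type that reports its own namespace."""

    def __init__(self, values):
        self.values = list(values)

    def __array_namespace__(self, api_version=None):
        return toy_ns


def _elementwise(fn):
    return lambda x: ToyArray(fn(v) for v in x.values)


toy_ns = SimpleNamespace(exp=_elementwise(math.exp), sqrt=_elementwise(math.sqrt))


def apply_ufunc(func, obj):
    """Toy stand-in for xarray.apply_ufunc: unwrap, apply, rewrap."""
    return Wrapped(func(obj.data))


def _make_unary(name):
    """Create a wrapper that looks `name` up on the wrapped array's namespace."""

    def wrapper(obj):
        def on_data(data):
            return getattr(data.__array_namespace__(), name)(data)

        return apply_ufunc(on_data, obj)

    wrapper.__name__ = name
    return wrapper


# One thin wrapper per supported function name.
ufuncs = SimpleNamespace(**{name: _make_unary(name) for name in ("exp", "sqrt")})

out = ufuncs.sqrt(Wrapped(ToyArray([4.0, 9.0])))
print(out.data.values)
```

Keeping the namespace lookup inside the function passed to `apply_ufunc` means the wrappers themselves carry no per-library logic.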
Small change that adds the `__array_namespace__` method from the [Array API standard](https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.array.__array_namespace__.html) to autograd's `ArrayBox`. This makes autograd (partially) compatible with libraries that check for array API compatibility, such as [xarray](https://github.com/pydata/xarray), because xarray will coerce types that do not implement the array API into regular numpy arrays; see the relevant discussion [here](pydata/xarray#7848).

The following code snippet:

```python
import xarray as xr
import autograd.numpy as np
from autograd import grad

def f(x):
    data = xr.DataArray(x, dims=range(x.ndim)).data
    print(data)
    return np.sum(data)

x = np.random.uniform(-1, 1, 10)
grad(f)(x)
```

currently prints

```
[<autograd.numpy.numpy_boxes.ArrayBox object at 0x7f9465347640> <autograd.numpy.numpy_boxes.ArrayBox object at 0x7f94652f7000> <autograd.numpy.numpy_boxes.ArrayBox object at 0x7f94652f7c40>]
```

because the traced `ArrayBox` is being coerced into a regular numpy array of `dtype=object` with `ArrayBox` elements. Differentiating through this works (with some limitations), but is very inefficient. With the addition suggested in this PR, the snippet instead prints:

```
Autograd ArrayBox with value [1. 1. 1.]
```

i.e. `data` is now a "proper" `ArrayBox`. I might also be interested in working on a full xarray wrapper in the future, but for now just this change would already help a lot.

---------

Co-authored-by: Fabian Joswig <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
I ran into the ufuncs issue trying to do basic ops on objects containing |
might be good to reopen this for further discussion |
Opened #9798 as another follow up, curious to get feedback. |
What is your issue?
Meta-issue to track all the smaller issues around making xarray and the array API standard compatible with each other.
We've already had
and there will likely be many others.
I suspect this might require changes to the standard as well as to xarray - in particular see this list of common numpy functions which are not currently in the array API standard. Of these xarray currently uses (FYI @ralfgommers ):
np.clip
np.diff
np.pad
np.repeat
np.take
np.tile