Compatibility with the Array API standard #7848

Open
TomNicholas opened this issue May 18, 2023 · 30 comments · Fixed by #9776
Labels
API design array API standard Support for the Python array API standard topic-arrays related to flexible array support topic-chunked-arrays Managing different chunked backends, e.g. dask upstream issue

Comments

@TomNicholas
Member

TomNicholas commented May 18, 2023

What is your issue?

Meta-issue to track all the smaller issues around making xarray and the array API standard compatible with each other.

We've already had

and there will likely be many others.


I suspect this might require changes to the standard as well as to xarray - in particular see this list of common numpy functions which are not currently in the array API standard. Of these xarray currently uses (FYI @ralfgommers ):

  • np.clip
  • np.diff
  • np.pad
  • np.repeat
  • np.take
  • np.tile
@TomNicholas TomNicholas added API design upstream issue topic-arrays related to flexible array support topic-chunked-arrays Managing different chunked backends, e.g. dask labels May 18, 2023
@TomNicholas
Member Author

TomNicholas commented May 18, 2023

np.pad is an interesting example in the context of chunked arrays (xref #6807) - dask implements it (parallelized using various approaches, including map_blocks internally), but cubed currently doesn't implement it because it's not part of the array API standard. (cc @tomwhite)

@shoyer
Member

shoyer commented Aug 14, 2023

Probably the most important missing feature for Xarray in the Array API standard is support for indexing with integer arrays: data-apis/array-api#669

@kgryte

kgryte commented Oct 19, 2023

Quick update from the Array API standard: both take and tile have been merged into the standard. take is available as of the 2022.12 revision and tile will be included in the 2023 revision.

@TomNicholas
Member Author

Thanks @kgryte! We also just bumped our minimum-supported numpy to 1.22 (#8346), which means we could potentially start using numpy.array_api internally for various operations...

@TomNicholas TomNicholas added the array API standard Support for the Python array API standard label Jan 25, 2024
@allen-adastra

Hi! I'm wondering if this is still a concerted effort? Def would love to see tight Xarray <-> Jax integration.

@keewis
Collaborator

keewis commented May 19, 2024

I don't have any knowledge about jax specifically, but packages like cubed have been wrapped successfully by xarray using the existing Array API support. So I'd say that as long as jax fully supports the Array API, you should be able to do this already. If not, could you please open a dedicated issue?

There might be a few operations that are needlessly verbose, though, so someone would have to create an integration package (jax-xarray or xarray-jax, not sure).

@TomNicholas
Member Author

Hey @allen-adastra, great to see you here.

Wrapping JAX with Xarray is a great idea in need of a champion / beta-tester. As Justus says, it should just work, but most of the Xarray devs concentrate on making the high-level wrappers compliant, rather than exposing all the edge cases of the specific libraries.

If you or one of your colleagues wants to just have a go and report back what happens that would be a great start!

@allen-adastra

Gotcha, so it sounds like Xarray should have compliance with the Array API?

Probably too much for me to take on, but going to talk around to see if we can drum up sufficient support somehow...

@keewis
Collaborator

keewis commented May 19, 2024

so it sounds like Xarray should have compliance with the Array API?

at least as far as dispatching to wrapped arrays goes, yes. The full Array API is not a good fit for xarray's objects and adopting just a subset is not allowed, so we don't implement it for those.

@TomNicholas
Member Author

TomNicholas commented May 20, 2024

To clarify, xarray aims to be able to wrap any array-API compliant object. This should already work, and is already being used to wrap some array types successfully (e.g. cubed), but it's likely that trying this with JAX will turn up some corner cases.

I think @keewis is referring to the separate idea of xarray objects (e.g. xr.Variable) following the array API (as wrappable objects rather than wrappers), which we decided against.

So all I'm suggesting you do, @allen-adastra, is try wrapping a JAX array with xarray, see what happens, and report back any failures. I'm not suggesting you take on the task of making xarray work for any array-API compliant object (as we're already trying to do that for the general case).

@mjwillson

mjwillson commented Aug 30, 2024

Hiya, since jax now supports the array-API I've been playing around with it a bit. Mostly things work well, but one thing that seems missing is a nice way to use ufuncs specific to the array API implementation of the wrapped array.

I don't think this is necessarily specific to JAX, but I'll use JAX to illustrate it:

import xarray 
import jax.numpy as jnp
import numpy as np
v = xarray.Variable(('foo', 'bar'), jnp.ones((2,3)))

All goes well so far, but what if we want to apply a ufunc like np.exp? Applying np.exp directly will convert the underlying jax array to numpy:

type(np.exp(v).data)
=> numpy.ndarray

Could this be avoided? If xarray is intercepting these ufunc calls from numpy, perhaps it could dispatch them via data.__array_namespace__() ?

Applying the jax equivalent jnp.exp to the xarray just doesn't work:

jnp.exp(v)
=> TypeError: Argument '<xarray.Variable (foo: 2, bar: 3)> ... ' of type <class 'xarray.core.variable.Variable'> is not a valid JAX type.

The following works:

xarray.apply_ufunc(jnp.exp, v)

but is more verbose and requires one to know the implementation of the underlying array in order to pick the right namespace (np vs jnp vs ...). We can work around the 'need to know the implementation' problem with something like:

xarray.apply_ufunc(lambda v: v.__array_namespace__().exp(v), v)

But this is crazy verbose, and doesn't work for numpy < 2. A helper could make it slightly less bad, e.g.:

xarray.apply_api_ufunc(lambda np: np.exp, v)
or
xarray.apply_api_ufunc(lambda x, np: np.exp(x), v)

although still not great. Perhaps a better alternative could be something like

import xarray.ufuncs as xnp
xnp.exp(v)

? This avoids either the need for boilerplate, or any confusion that might arise from e.g. a np.exp(data_array) call being dispatched to a different array API implementation under the hood. But I'm not sure how much work is involved.

@shoyer
Member

shoyer commented Aug 30, 2024 via email

@TomNicholas
Member Author

TomNicholas commented Aug 30, 2024

If xarray is intercepting these ufunc calls from numpy, perhaps it could dispatch them via data.__array_namespace__() ?

I think so. I might be ignoring history here but couldn't we make np.exp(v) dispatch to jnp.exp just by changing v.__array_function__ to look in v.data.__array_namespace__?

Not sure if that would be confusing though, to use np.exp to call jnp.exp...

@mjwillson

mjwillson commented Aug 30, 2024

Now I think of it: this is not really xarray-specific, but I would have expected there to be a canonical 'generic' array API namespace that dispatches based on __array_namespace__ of the arguments, e.g.

import array_api.generic_ufuncs as xnp

xnp.exp(jax_array) -> jax_array
xnp.exp(numpy_array) -> numpy_array

Does anything like this exist, and if not, should it exist? And could xarray piggy-back off it without implementing the array API for xarray datatypes directly?

[EDIT: to answer my own question, I guess polymorphic dispatch for multiple-argument ufuncs isn't entirely trivial, although still feels like some reasonable heuristics could be implemented or worst-case errors raised if arguments are incompatible]
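A minimal sketch of what such a generic namespace could look like, assuming the simplest possible heuristic for multiple-argument calls: every argument must report the same namespace, otherwise an error is raised. The names `GenericNamespace`, `array_namespace`, `ToyArray` and `ToyNamespace` are hypothetical and exist only for this illustration; the toy classes stand in for a real array library so that the sketch runs without any dependencies.

```python
import math


def array_namespace(*arrays):
    """Return the single namespace shared by all arguments, or raise TypeError."""
    if not arrays:
        raise TypeError("expected at least one array argument")
    namespaces = []
    for arr in arrays:
        if not hasattr(arr, "__array_namespace__"):
            raise TypeError(
                f"{type(arr).__name__} does not implement __array_namespace__"
            )
        ns = arr.__array_namespace__()
        if ns not in namespaces:
            namespaces.append(ns)
    if len(namespaces) != 1:
        raise TypeError("arguments come from incompatible array namespaces")
    return namespaces[0]


class GenericNamespace:
    """Hypothetical stand-in for the imagined `array_api.generic_ufuncs`."""

    def __getattr__(self, name):
        # Look the function up lazily in whichever namespace the inputs report.
        def dispatched(*arrays, **kwargs):
            xp = array_namespace(*arrays)
            return getattr(xp, name)(*arrays, **kwargs)

        return dispatched


class ToyNamespace:
    """Toy array library exposing two functions."""

    def exp(self, x):
        return ToyArray([math.exp(v) for v in x.values], self)

    def add(self, x, y):
        return ToyArray([a + b for a, b in zip(x.values, y.values)], self)


class ToyArray:
    """Toy array that reports the namespace it belongs to."""

    def __init__(self, values, namespace):
        self.values = list(values)
        self._namespace = namespace

    def __array_namespace__(self, api_version=None):
        return self._namespace


xnp = GenericNamespace()
lib_a, lib_b = ToyNamespace(), ToyNamespace()
x = ToyArray([0.0, 1.0], lib_a)
y = xnp.exp(x)  # dispatched to lib_a.exp, so the result stays in lib_a
```

Calling `xnp.add(x, ToyArray([1.0], lib_b))` raises `TypeError` under this heuristic, which is the "worst-case errors raised if arguments are incompatible" option. A real implementation would also need a policy for Python scalars and for mixing arrays from libraries that can interoperate.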

@mjwillson

mjwillson commented Aug 30, 2024 via email

@TomNicholas
Member Author

TomNicholas commented Aug 30, 2024

FYI @mjwillson you might also be interested in the discussion here about how xarray could call jax's vmap within apply_ufunc.

#9286 (comment)

@allen-adastra

@mjwillson Hello! I'm curious how you're avoiding conversion to a numpy array? Running this:

import xarray 
import jax.numpy as jnp
import numpy as np
v = xarray.Variable(('foo', 'bar'), jnp.ones((2,3)))
print(type(v.data))

Gives <class 'numpy.ndarray'>

@allen-adastra

(on the latest versions of xarray + jax btw)

pip list | grep -E "xarray|jax"
jax               0.4.31
jaxlib            0.4.31
xarray            2024.7.0

@TomNicholas
Member Author

TomNicholas commented Sep 7, 2024

Hi @allen-adastra ! I suspect that for some reason the JAX array is not making it through the as_compatible_data function that gets called by Variable.__init__, and is being coerced to np.ndarray by a np.asarray call (maybe this one).

This is not supposed to happen if JAX correctly implements __array_namespace__ (see this check), so this might indicate some non-compliance on JAX's part...

EDIT: It looks like __array_namespace__ is defined on the JAX Array class (see https://github.com/google/jax/blob/02b7a767683dd92c7cd3503fbb0d60b2a7440bb9/jax/_src/numpy/array_methods.py#L956), so maybe there is something else going on? But regardless I bet that coercion is happening inside as_compatible_data.

EDIT2: Looks like the Variable wrapping was working okay for @mjwillson above?
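The behaviour described above can be sketched roughly as follows. This is a simplified illustration, not xarray's actual as_compatible_data (which has more branches): the function name `as_compatible_data_sketch` and the toy classes are hypothetical, and the `coerce` parameter stands in for the np.asarray call so that the sketch has no dependencies.

```python
def as_compatible_data_sketch(data, coerce=list):
    """Pass array-API objects through untouched; coerce everything else.

    `coerce` is a stand-in for np.asarray and defaults to `list` only so
    that this sketch runs without numpy.
    """
    if hasattr(data, "__array_namespace__"):
        return data
    return coerce(data)


class CompliantArray:
    """Toy array that advertises an array API namespace."""

    def __init__(self, values):
        self.values = list(values)

    def __iter__(self):
        return iter(self.values)

    def __array_namespace__(self, api_version=None):
        return None  # a real library would return its namespace module here


class LegacyArray:
    """Toy array that does not define __array_namespace__."""

    def __init__(self, values):
        self.values = list(values)

    def __iter__(self):
        return iter(self.values)


kept = as_compatible_data_sketch(CompliantArray([1, 2]))    # same object back
coerced = as_compatible_data_sketch(LegacyArray([1, 2]))    # becomes a plain list
```

An array type without the method falls through to the coercion branch, which would match the symptom of `type(v.data)` reporting a numpy array.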

@TomNicholas
Member Author

TomNicholas commented Sep 7, 2024

Going back to @mjwillson 's question in #7848 (comment):

a canonical 'generic' array API namespace that dispatches based on __array_namespace__ of the arguments,

Does anything like this exist, and if not, should it exist? And could xarray piggy-back off it without implementing the array API for xarray datatypes directly?

xarray.ufuncs used to be a thing. We should probably bring it back!

I just spoke to @lucascolley at the NumFOCUS summit about this, and we decided we agree with @shoyer that we should bring back xarray.ufuncs. We think the syntax convention should be:

import xarray.ufuncs as xu
xu.exp(v)

The reasoning is:

  • Using np for this would imply a breaking change to the existing behaviour of np.exp(v). It's possible some users are relying on this returning a np.ndarray, so it would be better to have a separate namespace for this.
  • Using xp for this would get confusing with the array API standard encouraging using xp to mean the entire array API namespace, including creation functions and so on. (It's a little unfortunate that both xarray and the array API standard like to use x-prefixes).
  • Using np or xnp (or anything with an n) for this would be confusing because numpy is not actually involved at all! We are using an agnostic interface to call a jax function on a jax array.
  • Using the actual jax.numpy namespace jnp would technically be a violation of the array API standard, which doesn't allow supporting some operations (i.e. ufuncs) but not others on the xarray classes (see also "Should NamedArray be interchangeable with other array types? or Should we support the axis kwarg?" #8333 and the linked discussion "partial adoption of Array API for Xarray's NamedArray" data-apis/array-api#698).
  • Literally using the actual array API namespace convention xp would imply having user code that does import jax.numpy as xp, which we don't want because it's confusing to users. It also would be violating the standard too.
  • xu is nice because it doesn't have any of those problems, and it still indicates both that it's an xarray thing and that it's specifically for ufuncs.

The original xarray.ufuncs module from yesteryear actually had an explicit set of numpy-dispatching functions that were defined, but due to the array API standard we can now just point to that list in our documentation instead. EDIT: Actually there is no list of "ufuncs" in the array API standard, so maybe we do just have to explicitly list which functions we do and don't support.

If anyone wants to take a stab at adding this xarray.ufuncs module it would just have to be defined at the top-level and then examine the wrapped array type .data to decide which namespace to use on it for the particular operation.
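As a rough sketch of that idea, assuming an explicit list of supported functions (per the EDIT above): each function unwraps `.data`, asks it for its namespace, applies the operation there, and rewraps the result with the original dimension names. `Labeled`, `ToyArray`, `ToyNamespace` and `SUPPORTED` are hypothetical stand-ins (for xarray.Variable, a wrapped array type, its array API namespace, and the documented function list) so that the sketch runs without xarray or jax.

```python
import math
import types

# Explicit list of functions the module would support (illustrative only).
SUPPORTED = ("exp", "sqrt")


class Labeled:
    """Toy stand-in for xarray.Variable: dimension names plus wrapped data."""

    def __init__(self, dims, data):
        self.dims = tuple(dims)
        self.data = data


def _make_unary(name):
    """Build a function that dispatches `name` to the wrapped array's namespace."""

    def func(obj):
        xp = obj.data.__array_namespace__()
        return Labeled(obj.dims, getattr(xp, name)(obj.data))

    func.__name__ = name
    return func


# Plays the role of `import xarray.ufuncs as xu`.
xu = types.SimpleNamespace(**{name: _make_unary(name) for name in SUPPORTED})


class ToyNamespace:
    """Toy array library namespace."""

    @staticmethod
    def exp(x):
        return ToyArray(math.exp(v) for v in x.values)

    @staticmethod
    def sqrt(x):
        return ToyArray(math.sqrt(v) for v in x.values)


class ToyArray:
    """Toy wrapped array type that reports its own namespace."""

    def __init__(self, values):
        self.values = list(values)

    def __array_namespace__(self, api_version=None):
        return ToyNamespace


v = Labeled(("foo",), ToyArray([0.0, 4.0]))
out = xu.sqrt(v)  # out.data is still a ToyArray; out.dims is unchanged
```

Because the namespace is looked up from the data at call time, the user never has to name the underlying library, and anything outside `SUPPORTED` is simply absent from `xu`.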

@allen-adastra

allen-adastra commented Sep 7, 2024

Hi @allen-adastra ! I suspect that for some reason the JAX array is not making it through the as_compatible_data function that gets called by Variable.__init__, and is being coerced to np.ndarray by a np.asarray call (maybe this one).

This is not supposed to happen if JAX correctly implements __array_namespace__ (see this check), so this might indicate some non-compliance on JAX's part...

EDIT: It looks like __array_namespace__ is defined on the JAX Array class (see https://github.com/google/jax/blob/02b7a767683dd92c7cd3503fbb0d60b2a7440bb9/jax/_src/numpy/array_methods.py#L956), so maybe there is something else going on? But regardless I bet that coercion is happening inside as_compatible_data.

EDIT2: Looks like the Variable wrapping was working okay for @mjwillson above?

Ah, found the issue. __array_namespace__ was introduced in 0.4.32 (https://jax.readthedocs.io/en/latest/jax.numpy.html#python-array-api-standard), which is not yet on PyPI, so this will work as intended once 0.4.32 is out: https://github.com/google/jax/releases

@allen-adastra

allen-adastra commented Sep 7, 2024

For now, I implemented some small hacks in a fork, and am getting somewhere experimenting with a Xarray + JAX integration (experimental repo).

Probably best to continue a more in depth convo in the existing thread, but @mjwillson I'd be curious to hear your thoughts on the "dummy pytrees" issue I mention here

@lucascolley

lucascolley commented Sep 7, 2024

In case anyone is interested in using JAX<0.4.32 for some reason, it would be possible to modify the xarray logic to use (or just copy from) array-api-compat's array_namespace: https://github.com/data-apis/array-api-compat/blob/main/array_api_compat/common/_helpers.py#L529-L543, which basically amounts to using jax.experimental.array_api instead.

@shoyer
Member

shoyer commented Sep 10, 2024

Now I think of it: this is not really xarray-specific, but I would have expected there to be a canonical 'generic' array API namespace that dispatches based on __array_namespace__ of the arguments, e.g.

import array_api.generic_ufuncs as xnp

xnp.exp(jax_array) -> jax_array
xnp.exp(numpy_array) -> numpy_array

Does anything like this exist, and if not, should it exist? And could xarray piggy-back off it without implementing the array API for xarray datatypes directly?

The problem is that Xarray objects shouldn't implement __array_namespace__, because they are not fully array API compliant. We talked about this pretty extensively in data-apis/array-api#698

@mjwillson

If anyone wants to take a stab at adding this xarray.ufuncs module it would just have to be defined at the top-level and then examine the wrapped array type .data to decide which namespace to use on it for the particular operation.

I would have a go at this, although it looks like it would be easier for someone with more context about the old ufuncs.py. The old version contains some multi-argument dispatch logic which I'm not sure if you'd want to preserve or not, for example. And it goes via _unary_op and _binary_op, so not clear to me if the dispatch based on underlying array type should live there, or in ufuncs.py. If it lives in _unary_op/_binary_op, could that be disentangled from the dispatch logic in ufuncs.py, etc.

@mjwillson

Alternatively would ufuncs.py now be better implemented just as thin wrappers around apply_ufunc?

@shoyer
Member

shoyer commented Sep 17, 2024

Alternatively would ufuncs.py now be better implemented just as thin wrappers around apply_ufunc?

Yes, I would lean towards this. If I recall correctly, xarray.ufuncs pre-dated apply_ufunc.

agriyakhetarpal pushed a commit to HIPS/autograd that referenced this issue Oct 15, 2024
Small change that adds the `__array_namespace__` method from the
[Array API standard](https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.array.__array_namespace__.html)
to autograd's `ArrayBox`.

This makes autograd (partially) compatible with libraries that check for
array API compatibility, such as
[xarray](https://github.com/pydata/xarray), because xarray will coerce
types that do not implement the array API into regular numpy arrays, see
relevant discussion
[here](pydata/xarray#7848).

The following code snippet:
```python
import xarray as xr

import autograd.numpy as np
from autograd import grad


def f(x):
    data = xr.DataArray(x, dims=range(x.ndim)).data
    print(data)
    return np.sum(data)


x = np.random.uniform(-1, 1, 10)
grad(f)(x)
```
prints
```
[<autograd.numpy.numpy_boxes.ArrayBox object at 0x7f9465347640>
 <autograd.numpy.numpy_boxes.ArrayBox object at 0x7f94652f7000>
 <autograd.numpy.numpy_boxes.ArrayBox object at 0x7f94652f7c40>]
```
currently, because the traced `ArrayBox` is being coerced into a regular
numpy array of `dtype=object` with `ArrayBox` elements. Differentiating
through this works (with some limitations), but is very inefficient.

With the addition suggested in this PR, the snippet instead prints:
```
Autograd ArrayBox with value [1. 1. 1.]
```
i.e. `data` is now a "proper" `ArrayBox`.

I might also be interested in working on a full xarray wrapper in the
future, but for now just this change would already help a lot.

---------

Co-authored-by: Fabian Joswig <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@slevang
Contributor

slevang commented Nov 13, 2024

I ran into the ufuncs issue trying to do basic ops on objects containing jax arrays. Opened #9776 for discussion.

@lucascolley

might be good to reopen this for further discussion

@slevang
Contributor

slevang commented Nov 18, 2024

Opened #9798 as another follow up, curious to get feedback.

8 participants