Improved duck array wrapping #9798
xarray/core/array_api_compat.py
Outdated
        if hasattr(x, "__array_namespace__"):
            return x.__array_namespace__()
        elif isinstance(x, array_type("cupy")):
            # special case cupy for now
cupy seems to have full compliance with the standard, but doesn't yet actually have `__array_namespace__` on the core API. Others may be the same?
)


def sliding_window_view(array, window_shape, axis=None, **kwargs):
    # TODO: some libraries (e.g. jax) don't have this, implement an alternative?
This is one of the biggest outstanding bummers of wrapping jax arrays. There is apparently openness to adding this as an API (even though it would not offer any performance benefit in XLA). But given this is way outside the API standard, I wonder whether it makes sense to implement a general version within xarray that doesn't rely on stride tricks.
You could implement a version using "summed area tables" (basically run a single accumulator and then compute differences between the window edges); or convolutions I guess.
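For reference, a minimal sketch of the accumulator idea for a rolling sum, assuming plain NumPy and a 1-D input. The name `rolling_sum_cumsum` is hypothetical and not part of xarray. Note that this gives a windowed reduction, not a general window view:

```python
import numpy as np


def rolling_sum_cumsum(x, window):
    """Rolling sum over a 1-D array using one running accumulator.

    Hypothetical helper for illustration only; not part of xarray.
    """
    x = np.asarray(x, dtype=float)
    # Prepend a zero so that csum[i] holds the sum of x[:i].
    csum = np.concatenate([np.zeros(1), np.cumsum(x)])
    # The sum over x[i:i + window] is the difference between the window edges.
    return csum[window:] - csum[:-window]


result = rolling_sum_cumsum([1, 2, 3, 4, 5], window=3)
```

Only a cumulative sum, concatenation, and slicing are involved, so no stride tricks are needed. The cost is accumulated floating-point error on long inputs.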
I have something that works pretty well with this style of gather operation. But only in a jit context where XLA can work its magic. So I guess this is better left to the specific library to implement, or the user.
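To make the gather idea concrete, here is a rough sketch that builds windows along the last axis with integer indexing instead of stride tricks. It is written against NumPy purely for illustration; `sliding_window_gather` is a hypothetical name, and this is not the jit-based version mentioned above, which is not shown in this thread:

```python
import numpy as np


def sliding_window_gather(x, window):
    """Windows along the last axis, built by gathering with an index grid.

    Hypothetical helper for illustration only; unlike a strided view,
    this materializes a copy of the data.
    """
    x = np.asarray(x)
    n_windows = x.shape[-1] - window + 1
    # idx[i, j] = i + j, i.e. row i lists the positions covered by window i.
    idx = np.arange(n_windows)[:, None] + np.arange(window)[None, :]
    # Result has shape (..., n_windows, window).
    return x[..., idx]


data = np.arange(10).reshape(2, 5)
windows = sliding_window_gather(data, window=3)
```

The output layout matches `numpy.lib.stride_tricks.sliding_window_view(data, 3, axis=-1)`, at the cost of memory proportional to the window size.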
import xarray as xr

# TODO: how to test these in CI?
note: I just noticed xarray-array-testing
we do install pint in CI if that helps.
Cupy requires a GPU, but we could add `jax` into Xarray's CI suite in GitHub Actions, at least.
the only thing we'll want to avoid is testing compat with too many array types in xarray's CI. I guess I'll need to make progress on xarray-array-testing to avoid that.
`jax` would be pretty easy to add. The CPU jaxlib package is only like 90MB. And it's a nice test case since it very closely matches the numpy API, but only implements `__array_namespace__`, and not `__array_function__` or `__array_ufunc__`.
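A quick way to see which of these dispatch protocols an array type exposes is to probe for the attributes directly. This is only an illustrative sketch using NumPy; `dispatch_protocols` is a hypothetical helper, and the result for `__array_namespace__` on a NumPy array depends on the installed NumPy version:

```python
import numpy as np

PROTOCOLS = ("__array_namespace__", "__array_function__", "__array_ufunc__")


def dispatch_protocols(x):
    """Report which array dispatch protocols an object exposes.

    Hypothetical helper for illustration only.
    """
    return {name: hasattr(x, name) for name in PROTOCOLS}


numpy_report = dispatch_protocols(np.ones(3))
list_report = dispatch_protocols([1.0, 2.0, 3.0])
```

An object that only reports `__array_namespace__` will not intercept `np.func(obj)` calls, which is why hard-coded `np.` calls are a problem for such types.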
I agree, we should add JAX CPU. See https://github.com/pydata/xarray/tree/main/ci/requirements for the config files. Feel free to do this in this PR or a follow-up.
see also data-apis/array-api#621 for the higher level discussion of |
xarray/core/array_api_compat.py
Outdated
def get_array_namespace(*values):
    def _get_single_namespace(x):
        if hasattr(x, "__array_namespace__"):
            return x.__array_namespace__()
        elif isinstance(x, array_type("cupy")):
            # special case cupy for now
            import cupy as cp

            return cp
        else:
            return np
you probably want to wrap array-api-compat's `array_namespace` here, see https://github.com/scipy/scipy/blob/ec30b43e143ac0cb0e30129d4da0dbaa29e74c34/scipy/_lib/_array_api.py#L118-L152 for what we do in SciPy.
Happy to go that route, I did actually try that but array-api-compat doesn't handle a bunch of things we end up passing through this (scalars, index wrappers, etc) so it would require some careful prefiltering.
The only things this package effectively wraps that don't have `__array_namespace__` are cupy, dask, and torch. I tried the torch wrapper but it doesn't pass our as_compatible_data check because the wrapper object itself doesn't have `__array_namespace__` or `__array_function__` 😕
Support for scalars was just merged half an hour ago! data-apis/array-api-compat#147
Looks like this would handle the array-like check well, but would require adding this as a core xarray dependency to use it in `as_compatible_data`. Not sure if there is any appetite for that.
In SciPy we vendor array-api-compat via a Git submodule. There's a little bit of build system bookkeeping needed, but otherwise it works well without introducing a dependency.
I played around with this some more.
Getting an object that is compliant from the perspective of `array_api_compat.is_array_api_obj` to work through the hierarchy of xarray objects requires basically swapping in a bunch of more restrictive `hasattr(__array_namespace__)` checks for this function. Not too bad.
I was able to get things to the point that xarray can wrap `torch.Tensor`, which is pretty cool. But the reality is that xarray relies on a lot of functionality beyond the Array API, so this just doesn't work in any practical sense. Similarly, swapping in the more restrictive `cupy.array_api` for the main `cupy` namespace causes all kinds of things to break.
It seems torch has had little movement on supporting the standard since 2021. From xarray's perspective, `cupy` implements everything we need, except for the actual `__array_namespace__` attribute declaring compatibility, which seems to be planned for v14.
So while this compat module is nice in theory, I don't think it's very useful for xarray. cupy, jax, and sparse are good to go without it, and we only need a single special case for cupy to fetch its `__array_namespace__`.
I looked back at
def as_array_type(self, asarray: callable, **kwargs) -> Self:
    # e.g. ds.as_array_type(jnp.asarray, device="gpu")
def is_array_type(self, array: type) -> bool:
    # e.g. ds.is_array_type(cp.ndarray)
👍 I thought I'd seen an array API method for converting between compliant array types, but I can't find it now
Compliant namespaces should now implement
def to_namespace(self, xp: ModuleType, **kwargs) -> Self:
    xp.from_dlpack(self.data)
    # e.g. ds.to_namespace(cp)
But this actually doesn't work for cupy, since they're quite stringent about implicit device transfers:
Also sparse doesn't have this at all, and it wouldn't be clear whether you want a |
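As a rough illustration of the `from_dlpack` route, here is a free-function version exercised with NumPy on both sides. This is a sketch only: the standalone `to_namespace` function is hypothetical, and whether a given namespace accepts extra keyword arguments such as a target device depends on the library and its version:

```python
import numpy as np


def to_namespace(data, xp, **kwargs):
    """Convert an array to the array type of namespace `xp` via DLPack.

    Hypothetical free-function sketch; assumes `data` supports the DLPack
    protocol and that `xp.from_dlpack` exists. Extra keyword arguments are
    passed through unchanged and may not be supported everywhere.
    """
    return xp.from_dlpack(data, **kwargs)


source = np.arange(4.0)
converted = to_namespace(source, np)
```

With NumPy on both sides this is a trivial round trip. The cupy and sparse caveats mentioned above still apply to real cross-library conversions.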
I think this is in pretty good shape now, except the question of whether to attempt any of this integration testing in CI. That could also be punted to xarray-array-testing (@keewis @TomNicholas) |
The new |
Sure I'll split that off. Agree the naming and API could use some discussion. |
interp will need scipy to support array-api objects which I think they're working on.
yeah funny that the linalg extension does not include More generally, this is great work @slevang. thanks! |
You aren't the first to hit it, reading that issue. Perhaps you could raise it again :) |
aha thanks @lucascolley . Indeed we should rewrite to use |
fwiw, I just didn't tackle this function because it is defined on Dataset, and has probably 15 calls to different numpy functions, followed by a loop over variables. So it seemed tricky to handle in case we have data vars with mixed namespaces or other edge cases. |
@@ -172,7 +166,8 @@ def isnull(data):
)
):
# these types cannot represent missing values
return full_like(data, dtype=bool, fill_value=False)
dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool
I believe `xp.bool_` is only needed for compatibility with NumPy 1.x releases. Can we make that clearer here, e.g., by using `np.bool_` and adding a comment about NumPy versions?
Will do, but also `cupy` only has `bool_` at the moment. But if I passed built-in `bool` to this I think I got a cupy error. It was a weird range from like numpy 1.2x until 2 that they deprecated `bool` and only had `bool_`, then brought it back in 2.0 🤷
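The lookup under discussion can be pulled into a small helper with the version note attached. This is a sketch of one way to write it; `bool_dtype` is a hypothetical name, and the comment reflects the reasoning in this thread, not a statement about any particular release:

```python
import types

import numpy as np


def bool_dtype(xp):
    """Return the boolean dtype object for an array namespace.

    Hypothetical helper. Prefers `bool_`, which per the discussion above is
    the only spelling available in some NumPy 1.x releases and in cupy, and
    falls back to the Array API spelling `bool` otherwise.
    """
    return xp.bool_ if hasattr(xp, "bool_") else xp.bool


mask = np.zeros(3, dtype=bool_dtype(np))

# A stand-in namespace that only has the Array API spelling.
strict_xp = types.SimpleNamespace(bool="array-api-bool")
```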
},
"xfails": {"quantile": "no nanquantile"},
},
"pint": {
Added pint and sparse, and more consolidated xfail lists. Some of these are probably desired behavior, like all the methods that return bool or int not returning a `pint.Quantity`, but it was easiest to just xfail them.
@@ -49,3 +49,5 @@ dependencies:
  - toolz
  - typing_extensions
  - zarr
  - pip:
    - jax # no way to get cpu-only jaxlib from conda if gpu is present
Trying to install via conda on my local machine with a GPU automatically gave all the CUDA libs, which seems like unnecessary bloat for dev environments. If anyone knows a way around this, happy to move it back to conda. I also got a failure on windows, I think for trying to pip install anything, so left it off of there.
whats-new.rst
api.rst
Companion to #9776.
My attempts to use xarray-wrapped jax arrays turned up a bunch of limitations of our duck array wrapping. Jax is probably the worst case of the major duck array types out there, because it doesn't implement either `__array_function__` or `__array_ufunc__` to intercept numpy calls. Cupy, sparse, and probably others do, so calling `np.func(cp.ndarray)` generally works fine, but this usually converts your data to numpy with jax.
A lot of this was just grepping around for hard-coded `np.` cases that we can easily dispatch to `duck_array_ops` versions. I imagine some of these changes could be controversial, because a number of them (notably `nanmean/std/var`, `pad`, `quantile`, `einsum`, `cross`) aren't technically part of the standard. See #8834 (comment) for discussion about the nan-skipping aggregations.
It feels like a much better user experience though to try our best to dispatch to the correct backend, and error if the function isn't implemented, rather than blindly calling numpy. And practically, all major array backends that are feasible to wrap today (`cupy`, `sparse`, `jax`, `cubed`, `arkouda`, ...?) implement all of these functions.
To test, I just ran down the API list and ran most functions to see if we maintain proper wrapping. Prior to the changes here, I had 28 jax failures and 9 cupy failures, while all (non-xfailed) ones now pass.
Basically everything works except the interp/missing methods which have a lot of specialized code. Also a few odds and ends like polyfit and rank.
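The dispatch-or-error behavior described above can be sketched roughly as follows. This is an illustration only, not the code in this PR: `dispatch` and `FakeArray` are hypothetical names, and the real `duck_array_ops` logic handles many more cases:

```python
import types

import numpy as np


def dispatch(name, x, *args, **kwargs):
    """Call `name` from the array's own namespace, falling back to numpy
    only when the array does not declare a namespace at all.

    Hypothetical sketch; not the implementation in xarray.
    """
    xp = x.__array_namespace__() if hasattr(x, "__array_namespace__") else np
    func = getattr(xp, name, None)
    if func is None:
        # Error out instead of silently converting the data to numpy.
        raise NotImplementedError(f"{name!r} is not implemented by the array's namespace")
    return func(x, *args, **kwargs)


class FakeArray:
    """Stand-in duck array whose namespace implements nothing."""

    def __array_namespace__(self):
        return types.SimpleNamespace()


mean = dispatch("nanmean", np.array([1.0, np.nan, 3.0]))
```

Calling `dispatch("nanmean", FakeArray())` raises `NotImplementedError`, so the failure surfaces at the call site without coercing the data to a NumPy array.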