
Improved duck array wrapping #9798

Open · wants to merge 27 commits into base: main
Conversation

slevang
Contributor

@slevang slevang commented Nov 18, 2024

Companion to #9776.

My attempts to use xarray-wrapped jax arrays turned up a bunch of limitations in our duck array wrapping. Jax is probably the worst case of the major duck array types out there, because it implements neither __array_function__ nor __array_ufunc__ to intercept numpy calls. Cupy, sparse, and probably others do, so calling np.func(cp.ndarray) generally works fine, but with jax this usually converts your data to numpy.

A lot of this was just grepping around for hard-coded np. cases that we can easily dispatch to duck_array_ops versions. I imagine some of these changes could be controversial, because a number of them (notably nanmean/std/var, pad, quantile, einsum, cross) aren't technically part of the standard. See #8834 (comment) for discussion about the nan-skipping aggregations.

It feels like a much better user experience, though, to try our best to dispatch to the correct backend, and to error if the function isn't implemented, rather than blindly calling numpy. And practically, all major array backends that are feasible to wrap today (cupy, sparse, jax, cubed, arkouda, ...?) implement all of these functions.
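As a minimal sketch of this dispatch pattern (hypothetical helper names; in xarray this goes through duck_array_ops), with numpy as the fallback namespace:

```python
import numpy as np

def get_namespace(x):
    # Hypothetical helper: prefer the array's own namespace via the
    # array API protocol, falling back to numpy for plain arrays.
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np

def nanmean(x, axis=None):
    # Dispatch to the wrapped array's namespace instead of calling
    # np.nanmean directly, so jax/cupy data stays in its backend and
    # raises an AttributeError if the backend lacks the function.
    xp = get_namespace(x)
    return xp.nanmean(x, axis=axis)
```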

To test, I just went down the API list and ran most functions to see if we maintain proper wrapping. Prior to the changes here I had 28 jax failures and 9 cupy failures, while all (non-xfailed) ones now pass.

Basically everything works except the interp/missing methods which have a lot of specialized code. Also a few odds and ends like polyfit and rank.

if hasattr(x, "__array_namespace__"):
return x.__array_namespace__()
elif isinstance(x, array_type("cupy")):
# special case cupy for now
Contributor Author

cupy seems to have full compliance with the standard, but doesn't yet actually have __array_namespace__ on the core API. Others may be the same?

xarray/core/computation.py — outdated review thread, resolved
)

def sliding_window_view(array, window_shape, axis=None, **kwargs):
# TODO: some libraries (e.g. jax) don't have this, implement an alternative?
Contributor Author

This is one of the biggest outstanding bummers of wrapping jax arrays. There is apparently openness to adding this as an API (even though it would not offer any performance benefit in XLA). But given that this is way outside the API standard, it's unclear whether it makes sense to implement a general version within xarray that doesn't rely on stride tricks.

Contributor

You could implement a version using "summed area tables" (basically run a single accumulator and then compute differences between the window edges); or convolutions I guess.
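A minimal 1-D sketch of the summed-area-table idea: this produces moving window sums (from which means etc. follow), not a general sliding_window_view replacement, and avoids stride tricks entirely.

```python
import numpy as np

def sliding_sum_1d(x, window):
    # "Summed area table" in 1-D: run a single cumulative sum, then
    # the difference between the two window edges gives each window's
    # sum in O(n), with no stride tricks.
    csum = np.concatenate([[0], np.cumsum(x)])
    return csum[window:] - csum[:-window]
```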

Contributor Author

I have something that works pretty well with this style of gather operation, but only in a jit context where XLA can work its magic. So I guess this is better left to the specific library, or the user, to implement.


import xarray as xr

# TODO: how to test these in CI?
Contributor Author

note: I just noticed xarray-array-testing

Contributor

we do install pint in CI if that helps.

Member

Cupy requires a GPU, but we could add jax to Xarray's CI suite in GitHub actions, at least.

Collaborator
@keewis keewis Nov 22, 2024

the only thing we'll want to avoid is testing compat with too many array types in xarray's CI. I guess I'll need to make progress on xarray-array-testing to avoid that.

Contributor Author

jax would be pretty easy to add. The CPU jaxlib package is only like 90MB. And it's a nice test case since it very closely matches the numpy API, but only implements __array_namespace__, and not __array_function__ or __array_ufunc__.
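For illustration, the protocol differences described here can be probed with a small hypothetical helper:

```python
import numpy as np

def duck_protocols(x):
    # Report which numpy-interception protocols an array exposes.
    # numpy arrays expose __array_function__ and __array_ufunc__;
    # jax arrays would show only __array_namespace__.
    names = ("__array_namespace__", "__array_function__", "__array_ufunc__")
    return {name: hasattr(x, name) for name in names}
```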

Member

I agree, we should add JAX CPU. See https://github.com/pydata/xarray/tree/main/ci/requirements for the config files. Feel free to do this in this PR or a follow-up.

@lucascolley

see also data-apis/array-api#621 for the higher level discussion of nan reductions

Comment on lines 49 to 59
def get_array_namespace(*values):
def _get_single_namespace(x):
if hasattr(x, "__array_namespace__"):
return x.__array_namespace__()
elif isinstance(x, array_type("cupy")):
# special case cupy for now
import cupy as cp

return cp
else:
return np


you probably want to wrap array-api-compat's array_namespace here, see https://github.com/scipy/scipy/blob/ec30b43e143ac0cb0e30129d4da0dbaa29e74c34/scipy/_lib/_array_api.py#L118-L152 for what we do in SciPy.

Contributor Author

Happy to go that route, I did actually try that but array-api-compat doesn't handle a bunch of things we end up passing through this (scalars, index wrappers, etc) so it would require some careful prefiltering.

The only things this package effectively wraps that don't have __array_namespace__ are cupy, dask, and torch. I tried the torch wrapper but it doesn't pass our as_compatible_data check because the wrapper object itself doesn't have __array_namespace__ or __array_function__ 😕


Support for scalars was just merged half an hour ago! data-apis/array-api-compat#147

Contributor Author

Looks like this would handle the array-like check well, but would require adding this as a core xarray dependency to use it in as_compatible_data. Not sure if there is any appetite for that.


In SciPy we vendor array-api-compat via a Git submodule. There's a little bit of build system bookkeeping needed, but otherwise it works well without introducing a dependency.

Contributor Author

I played around with this some more.

Getting an object that array_api_compat.is_array_api_obj considers compliant to work through the hierarchy of xarray objects basically requires swapping a bunch of the more restrictive hasattr(x, "__array_namespace__") checks for this function. Not too bad.

I was able to get things to the point that xarray can wrap torch.Tensor, which is pretty cool. But the reality is that xarray relies on a lot of functionality beyond the Array API, so this just doesn't work in any practical sense. Similarly, swapping in the more restrictive cupy.array_api for the main cupy namespace causes all kinds of things to break.

It seems torch has had little movement on supporting the standard since 2021. From xarray's perspective, cupy implements everything we need, except for the actual __array_namespace__ attribute declaring compatibility, which seems to be planned for v14.

So while this compat module is nice in theory, I don't think it's very useful for xarray. cupy, jax, and sparse are good to go without it, and we only need a single special case for cupy to fetch its __array_namespace__.

@slevang
Contributor Author

slevang commented Nov 20, 2024

I looked back at cupy-xarray. Now that this duck array stuff all works pretty well, I'm wondering how people feel about adding official DataArray/Dataset methods analogous to this, but in a generalized way. Something like:

def as_array_type(self, asarray: callable, **kwargs) -> Self:
# e.g. ds.as_array_type(jnp.asarray, device="gpu")

def is_array_type(self, array: type) -> bool:
# e.g. ds.is_array_type(cp.ndarray)
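A standalone sketch of what these hypothetical methods would do, operating on raw arrays rather than Dataset/DataArray:

```python
import numpy as np

def as_array_type(data, asarray, **kwargs):
    # Convert the wrapped data with a user-supplied constructor,
    # e.g. as_array_type(data, jnp.asarray, device="gpu").
    return asarray(data, **kwargs)

def is_array_type(data, array_type):
    # Check whether the wrapped data is an instance of the given
    # duck array type, e.g. is_array_type(data, cp.ndarray).
    return isinstance(data, array_type)
```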

@slevang slevang marked this pull request as ready for review November 20, 2024 15:47
@dcherian
Contributor

methods analogous to this, but in a generalized way.

👍 I thought I'd seen an array API method for converting between compliant array types, but I can't find it now

@slevang
Contributor Author

slevang commented Nov 20, 2024

Compliant namespaces should now implement from_dlpack which is generally the recommended conversion protocol. So I suppose we could instead pass the namespace and hard code it to use that:

def to_namespace(self, xp: ModuleType, **kwargs) -> Self:
    return xp.from_dlpack(self.data)

# e.g. ds.to_namespace(cp)

But this actually doesn't work for cupy, since they're quite stringent about implicit device transfers:

TypeError: CPU arrays cannot be directly imported to CuPy. Use `cupy.array(numpy.from_dlpack(input))` instead.

Also sparse doesn't have this at all, and it wouldn't be clear whether you want a COO, DOK, etc.
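For backends that do support it, the DLPack route can be sketched like this (hypothetical helper; numpy stands in for both producer and consumer here):

```python
import numpy as np

def to_namespace(data, xp):
    # Convert via the DLPack protocol; zero-copy where the producer
    # and consumer namespaces share a device.
    return xp.from_dlpack(data)
```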

Resolved review threads: xarray/core/computation.py, xarray/core/dataset.py, xarray/core/variable.py
@slevang
Contributor Author

slevang commented Nov 21, 2024

I think this is in pretty good shape now, except the question of whether to attempt any of this integration testing in CI. That could also be punted to xarray-array-testing (@keewis @TomNicholas)

@shoyer
Member

shoyer commented Nov 22, 2024

The new is_array_type and as_array_type methods seem worthy of broader discussion. I would leave them out of this PR for now, so they don't hold up merging it.

@slevang
Contributor Author

slevang commented Nov 22, 2024

Sure, I'll split that off. Agreed, the naming and API could use some discussion.

@dcherian
Contributor

Basically everything works except the interp/missing methods which have a lot of specialized code.

interp will need scipy to support array-api objects which I think they're working on.

Also a few odds and ends like polyfit and rank.

yeah funny that the linalg extension does not include lstsq.


More generally, this is great work @slevang. thanks!

@lucascolley

yeah funny that the linalg extension does not include lstsq.

From 2021:

Initial discussion on linalg.lstsq API design can be found in data-apis/array-api#227. There, it was held that the current API design posed undue burden on specification implementers due to API complexity (returning solutions, ranks, residuals, and singular values) and violated various linear algebra API design principles (particularly orthogonality).

Based on ensuing discussion in consortium meetings, consensus was found in simply removing the linalg.lstsq API altogether. This API finds low usage among downstream libraries (see data) and has several variants due to specialized use cases (see SciPy).

Should this or similar functionality be found to be desirable in the future, specification can happen in a future revision of the standard.

You aren't the first to hit it, reading that issue. Perhaps you could raise it again :)

@dcherian
Contributor

aha, thanks @lucascolley. Indeed, we should probably rewrite it to use solve.

@slevang
Contributor Author

slevang commented Nov 22, 2024

fwiw, jax and cupy both have linalg.lstsq, although it sounds like solve would be the right way to go.

I just didn't tackle this function because it is defined on Dataset, and has probably 15 calls to different numpy functions, followed by a loop over variables. So it seemed tricky to handle in case we have data vars with mixed namespaces or other edge cases.
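One way the rewrite could look: a least-squares sketch using only linalg.solve from the array API linalg extension, via the normal equations. This is less numerically stable than a QR/SVD-based lstsq, so it's a sketch rather than a drop-in replacement.

```python
import numpy as np

def lstsq_via_solve(A, b):
    # Solve the normal equations A^T A x = A^T b. Only linalg.solve
    # and matmul are needed, both of which are in the array API
    # standard, unlike lstsq.
    return np.linalg.solve(A.T @ A, A.T @ b)
```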

@@ -172,7 +166,8 @@ def isnull(data):
)
):
# these types cannot represent missing values
return full_like(data, dtype=bool, fill_value=False)
dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool
Member

I believe xp.bool_ is only needed for compatibility with NumPy 1.x releases. Can we make that clearer here, e.g., by using np.bool_ and adding a comment about NumPy versions?

Contributor Author

Will do, but cupy also only has bool_ at the moment. And if I passed built-in bool to this, I think I got a cupy error. There was a weird stretch from numpy 1.2x until 2.0 where they deprecated bool and only had bool_, then brought it back in 2.0 🤷
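The compatibility shim from the diff above, factored into a tiny helper (hypothetical name):

```python
import numpy as np

def bool_dtype(xp):
    # numpy 1.x (and current cupy) only expose `bool_`; numpy 2.x
    # restored `bool` in the main namespace. Prefer `bool_` when
    # present for compatibility across both.
    return xp.bool_ if hasattr(xp, "bool_") else xp.bool
```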



},
"xfails": {"quantile": "no nanquantile"},
},
"pint": {
Contributor Author

Added pint and sparse, and more consolidated xfail lists. Some of these are probably desired behavior, like all the methods that return bool or int not returning a pint.Quantity, but it was easiest to just xfail them.

@@ -49,3 +49,5 @@ dependencies:
- toolz
- typing_extensions
- zarr
- pip:
- jax # no way to get cpu-only jaxlib from conda if gpu is present
Contributor Author

Trying to install via conda on my local machine with a GPU automatically pulled in all the CUDA libs, which seems like unnecessary bloat for dev environments. If anyone knows a way around this, I'm happy to move it back to conda. I also got a failure on Windows, I think from trying to pip install anything, so I left it off there.
