diff --git a/pvxarray/accessor.py b/pvxarray/accessor.py index 833afdc..c836ed3 100644 --- a/pvxarray/accessor.py +++ b/pvxarray/accessor.py @@ -4,6 +4,7 @@ import xarray as xr from pvxarray import rectilinear, structured +from pvxarray.cf import get_cf_names class _LocIndexer: @@ -51,6 +52,15 @@ def mesh( order: Optional[str] = None, component: Optional[str] = None, ) -> pv.DataSet: + if (3 - (x, y, z).count(None)) < 1: + try: + x, y, z, _ = get_cf_names(self._obj) + except ImportError: # pragma: no cover + pass + if (3 - (x, y, z).count(None)) < 1: + raise ValueError( + "You must specify at least one dimension as X, Y, or Z or install `cf_xarray`." + ) ndim = 0 if x is not None: _x = self._get_array(x) diff --git a/pvxarray/cf.py b/pvxarray/cf.py new file mode 100644 index 0000000..9f829cf --- /dev/null +++ b/pvxarray/cf.py @@ -0,0 +1,13 @@ +def get_cf_names(da): + """Use `cf_xarray` to get the names of the X, Y, and Z arrays.""" + try: + import cf_xarray # noqa + + axes = da.cf.axes + except (AttributeError, ImportError): # pragma: no cover + raise ImportError("Please install `cf_xarray` to use CF conventions.") + x = axes.get("X", [None])[0] + y = axes.get("Y", [None])[0] + z = axes.get("Z", [None])[0] + t = axes.get("T", [None])[0] + return x, y, z, t diff --git a/requirements.txt b/requirements.txt index f07ff70..e56a8dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ netcdf4 pytest pytest-cov rioxarray +cf_xarray diff --git a/tests/test_base.py b/tests/test_base.py index 13a1bdd..c5a0952 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -36,6 +36,11 @@ def test_report(): assert pvxarray.Report() +def test_no_cf_and_no_names(sample): + with pytest.raises(ValueError): + sample[dict(t=0)].pyvista.mesh() + + def test_bad_key(sample): with pytest.raises(KeyError): sample[dict(t=0)].pyvista.mesh(x="foo") diff --git a/tests/test_cf.py b/tests/test_cf.py new file mode 100644 index 0000000..4b5be6d --- /dev/null +++ b/tests/test_cf.py @@ -0,0 +1,17 @@ +import numpy as np +import xarray as xr + + +def test_air_temperature(): + ds = xr.tutorial.load_dataset("air_temperature") + da = ds.air[dict(time=0)] + + mesh = da.pyvista.mesh() # No X,Y,Z specified, so should try cf_xarray + assert mesh + assert mesh.n_points == 1325 + assert "air" in mesh.point_data + + assert np.array_equal(mesh["air"], da.values.ravel()) + assert np.may_share_memory(mesh["air"], da.values.ravel()) + assert np.array_equal(mesh.x, da.lon) + assert np.array_equal(mesh.y, da.lat)