Skip to content

Commit

Permalink
Merge pull request #45 from kthyng/main
Browse files Browse the repository at this point in the history
Added `filter` and improved `sub_*` functions
  • Loading branch information
kthyng authored Apr 7, 2022
2 parents ea98204 + 242395c commit 3040061
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 28 deletions.
6 changes: 3 additions & 3 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dependencies:
- pytest
# Examples (remove and add as needed)
- cartopy
# - cf_xarray
- cf_xarray
- cmocean
- dask
- jupyter
Expand All @@ -19,5 +19,5 @@ dependencies:
- xarray
- xcmocean
- xesmf
- pip: # install from github to get recent PRs I contributed
- [email protected]:xarray-contrib/cf-xarray.git
# - pip: # install from github to get recent PRs I contributed
# - [email protected]:xarray-contrib/cf-xarray.git
2 changes: 1 addition & 1 deletion extract_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import extract_model.accessor

from .extract_model import argsel2d, sel2d, select # noqa: F401
from .utils import order, preprocess, sub_bbox, sub_grid
from .utils import filter, order, preprocess, sub_bbox, sub_grid


try:
Expand Down
32 changes: 26 additions & 6 deletions extract_model/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,38 @@ def sub_bbox(self, bbox, drop=True, dask_array_chunks=True):
See full docs at `em.sub_bbox()`.
"""

attrs = self.ds.attrs

dss = []
Vars = [
Var for Var in self.ds.data_vars if "longitude" in self.ds[Var].cf.coords
]
for Var in Vars:
dss.append(
em.sub_bbox(
self.ds[Var], bbox, drop=drop, dask_array_chunks=dask_array_chunks
for Var in self.ds.data_vars:
if Var in Vars:
dss.append(
em.sub_bbox(
self.ds[Var],
bbox,
drop=drop,
dask_array_chunks=dask_array_chunks,
)
)
)
else:
dss.append(self.ds[Var])

ds_out = xr.merge(dss)

ds_out.attrs = attrs

return ds_out

def filter(self, standard_names):
"""Filter Dataset to standard_names, keep imp variables too.
See full docs at `em.utils.filter()`.
"""

return xr.merge(dss)
return em.filter(self.ds, standard_names)


@xr.register_dataarray_accessor("em")
Expand Down
106 changes: 95 additions & 11 deletions extract_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,60 @@
import xarray as xr


def filter(ds, standard_names):
"""Filter Dataset by variables
... but retain all necessary for decoding vertical coords.
Parameters
----------
ds: Dataset
xarray Dataset to select model output from.
standard_names: list
Standard names of variables to keep in Dataset.
Returns
-------
Dataset with variables from standard_names included as well as variables corresponding to formula_terms needed to decode vertical coordinates using `cf-xarray`.
"""

# Deal with vertical coord decoding

# standard_names associated with vertical coordinates
s_standard_names_list = [
"ocean_s_coordinate_g1",
"ocean_s_coordinate_g2",
"ocean_sigma_coordinate",
]

# want to find the vertical coord standard_names variables AND those with formula_terms
# which should be identical but seems like a good check
formula_terms = lambda v: v is not None
s_standard_names = lambda v: v in s_standard_names_list

# get a Dataset with these coordinates
ds1 = ds.filter_by_attrs(
formula_terms=formula_terms, standard_name=s_standard_names
)

# For the vertical related coords (e.g. for ROMS these will be `s_rho` and `s_w`
# gather all formula term variable names to bring along
formula_vars = []
for coord in ds1.coords:
formula_vars.extend(list(ds1[coord].cf.formula_terms.values()))
Vars = list(set(formula_vars))
ds2 = ds[Vars]

# Also get a Dataset for all the requested variables
f_standard_names = lambda v: v in standard_names
ds3 = ds.filter_by_attrs(standard_name=f_standard_names)

# Combine
ds = xr.merge([ds1, ds2, ds3])

return ds


def sub_grid(ds, bbox, dask_array_chunks=True):
"""Subset Dataset grids.
Expand Down Expand Up @@ -50,25 +104,53 @@ def sub_grid(ds, bbox, dask_array_chunks=True):
if "lon_rho" in lon_names:
# variables with 'lon_rho', just use first one
Var = [Var for Var in ds.data_vars if "lon_rho" in ds[Var].coords][0]
subsetted = sub_bbox(ds[Var], bbox, drop=True)
# import pdb; pdb.set_trace()

# get xi_rho and eta_rho slice values
xi_rho, eta_rho = subsetted.xi_rho.values, subsetted.eta_rho.values
# unfortunately the indices are reset when the array changes size
# IF the dimensions are dims only and not coords
if "xi_rho" not in ds.coords:
subs = sub_bbox(ds[Var], bbox, other=-500, drop=False)

# index
i_xi_rho = int((subs != -500).sum(dim="xi_rho").argmax())
xi_rho_bool = subs.isel(eta_rho=i_xi_rho) != -500
# import pdb; pdb.set_trace()
if "T" in subs.cf.axes:
xi_rho_bool = xi_rho_bool.cf.isel(T=0)
if "Z" in subs.cf.axes:
xi_rho_bool = xi_rho_bool.cf.isel(Z=0)
xi_rho = np.arange(ds.lon_rho.shape[1])[xi_rho_bool]

i_eta_rho = int((subs != -500).sum(dim="eta_rho").argmax())
eta_rho_bool = subs.isel(xi_rho=i_eta_rho) != -500
if "T" in subs.cf.axes:
eta_rho_bool = eta_rho_bool.cf.isel(T=0)
if "Z" in subs.cf.axes:
eta_rho_bool = eta_rho_bool.cf.isel(Z=0)
eta_rho = np.arange(ds.lon_rho.shape[0])[eta_rho_bool]

else: # 'xi_rho' in ds.coords
# this works in this case because the dimensions as coords can
# "remember" their original values
subsetted = sub_bbox(ds[Var], bbox, drop=True)
# get xi_rho and eta_rho slice values
xi_rho, eta_rho = subsetted.xi_rho.values, subsetted.eta_rho.values

# This first part is to keep the dimensions consistent across
# the grids
# then know xi_u, eta_v
sel_dict = {"xi_rho": xi_rho, "eta_rho": eta_rho}
if "xi_u" in ds:
if "xi_u" in ds.dims:
sel_dict["xi_u"] = xi_rho[:-1]
if "eta_v" in ds:
if "eta_v" in ds.dims:
sel_dict["eta_v"] = eta_rho[:-1]
if "eta_u" in ds:
if "eta_u" in ds.dims:
sel_dict["eta_u"] = eta_rho
if "xi_v" in ds:
if "xi_v" in ds.dims:
sel_dict["xi_v"] = xi_rho
if "eta_psi" in ds:
if "eta_psi" in ds.dims:
sel_dict["eta_psi"] = eta_rho[:-1]
if "xi_psi" in ds:
if "xi_psi" in ds.dims:
sel_dict["xi_psi"] = xi_rho[:-1]
# adjust dimensions of full dataset
import dask
Expand All @@ -89,7 +171,7 @@ def sub_grid(ds, bbox, dask_array_chunks=True):
return ds_new


def sub_bbox(da, bbox, drop=True, dask_array_chunks=True):
def sub_bbox(da, bbox, other=xr.core.dtypes.NA, drop=True, dask_array_chunks=True):
"""Subset DataArray in space.
Can also be used on a Dataset if there is only one horizontal grid.
Expand All @@ -100,6 +182,8 @@ def sub_bbox(da, bbox, drop=True, dask_array_chunks=True):
Property to select model output from.
bbox: list
The bounding box for subsetting is defined as [min_lon, min_lat, max_lon, max_lat]
other: int, float, optional
Value to input in da where bbox is False. Either other or drop is used. By default is nan.
drop: bool, optional
This is passed onto xarray's `da.where()` function. If True, coordinates outside bbox
are dropped from the DataArray, otherwise they are kept but masked/nan'ed.
Expand All @@ -122,7 +206,7 @@ def sub_bbox(da, bbox, drop=True, dask_array_chunks=True):
import dask

with dask.config.set(**{"array.slicing.split_large_chunks": dask_array_chunks}):
da_smaller = da.where(box, drop=drop)
da_smaller = da.where(box, other=other, drop=drop)

return da_smaller

Expand Down
17 changes: 10 additions & 7 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,22 @@ def test_sub_bbox(self, model):
assert np.allclose(da_out, da_compare, equal_nan=True)
assert ds_out.equals(ds_compare)

def test_sub_grid_ds_roms(self, model):
def test_sub_grid_ds(self, model):
"""Test subset on Dataset."""

url, var_name, bbox = model["url"], model["var_name"], model["bbox"]

# Dataset
ds = xr.open_mfdataset([url], preprocess=em.preprocess)
# bbox = [-92, 27, -91, 29]
# if 'roms' in url.stem:
# import pdb; pdb.set_trace()
ds_out = ds.em.sub_grid(bbox=bbox)
da_compare = ds[var_name].em.sub_bbox(bbox=bbox)
if "roms" not in url.stem:

X, Y = da_compare.cf["X"].values, da_compare.cf["Y"].values
sel_dict = {"X": X, "Y": Y}
ds_new = ds.cf.sel(sel_dict)
da_compare = ds[var_name].em.sub_bbox(bbox=bbox)

assert ds_out.equals(ds_new)
X, Y = da_compare.cf["X"].values, da_compare.cf["Y"].values
sel_dict = {"X": X, "Y": Y}
ds_new = ds.cf.sel(sel_dict)

assert ds_out.equals(ds_new)

0 comments on commit 3040061

Please sign in to comment.