Merge pull request #86 from kthyng/mask_xoak
Mask xoak
kthyng authored Jan 20, 2023
2 parents 5d5f504 + c89824a commit 4486d31
Showing 13 changed files with 180 additions and 69 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/test.yaml
Expand Up @@ -9,7 +9,7 @@ jobs:
fail-fast: false
matrix:
os: ["macos-latest", "ubuntu-latest", "windows-latest"]
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.8", "3.9", "3.10"]
steps:
- name: Checkout source
uses: actions/checkout@v3
Expand Down Expand Up @@ -40,7 +40,8 @@ jobs:
if: ${{ runner.os != 'Windows' }}
uses: conda-incubator/setup-miniconda@v2
with:
mamba-version: "*" # activate this to build with mamba.
miniforge-variant: Mambaforge
python-version: ${{ matrix.python-version }}
channels: conda-forge, defaults # These need to be specified to use mamba
channel-priority: true
environment-file: ci/environment-py${{ matrix.python-version }}.yml
Expand All @@ -52,7 +53,8 @@ jobs:
if: ${{ runner.os == 'Windows' }}
uses: conda-incubator/setup-miniconda@v2
with:
mamba-version: "*" # activate this to build with mamba.
miniforge-variant: Mambaforge
python-version: ${{ matrix.python-version }}
channels: conda-forge, defaults # These need to be specified to use mamba
channel-priority: true
environment-file: ci/environment-py${{ matrix.python-version }}-win.yml
Expand Up @@ -2,7 +2,7 @@ name: test-env-win
channels:
- conda-forge
dependencies:
- python=3.7
- python=3.10
- cf_xarray>=0.6
- dask
- netcdf4
Expand All @@ -11,10 +11,11 @@ dependencies:
- requests
- scikit-learn # used by xoak for tree
- xarray
- xoak
# - xoak
- pytest
- pytest-benchmark
- pip:
- codecov
- pytest-cov
- coverage[toml]
- git+https://github.com/kthyng/xoak@include_distances
5 changes: 3 additions & 2 deletions ci/environment-py3.7.yml → ci/environment-py3.10.yml
Expand Up @@ -2,7 +2,7 @@ name: test-env-mac-unix
channels:
- conda-forge
dependencies:
- python=3.7
- python=3.10
- cf_xarray>=0.6
- dask <=2022.05.0 # for xESMF, https://github.com/axiom-data-science/extract_model/issues/49
- netcdf4
Expand All @@ -12,10 +12,11 @@ dependencies:
- scikit-learn # used by xoak for tree
- xarray
- xesmf
- xoak
# - xoak
- pytest
- pytest-benchmark
- pip:
- codecov
- pytest-cov
- coverage[toml]
- git+https://github.com/kthyng/xoak@include_distances
3 changes: 2 additions & 1 deletion ci/environment-py3.8-win.yml
Expand Up @@ -11,10 +11,11 @@ dependencies:
- requests
- scikit-learn # used by xoak for tree
- xarray
- xoak
# - xoak
- pytest
- pytest-benchmark
- pip:
- codecov
- pytest-cov
- coverage[toml]
- git+https://github.com/kthyng/xoak@include_distances
3 changes: 2 additions & 1 deletion ci/environment-py3.8.yml
Expand Up @@ -12,10 +12,11 @@ dependencies:
- scikit-learn # used by xoak for tree
- xarray
- xesmf
- xoak
# - xoak
- pytest
- pytest-benchmark
- pip:
- codecov
- pytest-cov
- coverage[toml]
- git+https://github.com/kthyng/xoak@include_distances
3 changes: 2 additions & 1 deletion ci/environment-py3.9-win.yml
Expand Up @@ -11,10 +11,11 @@ dependencies:
- requests
- scikit-learn # used by xoak for tree
- xarray
- xoak
# - xoak
- pytest
- pytest-benchmark
- pip:
- codecov
- pytest-cov
- coverage[toml]
- git+https://github.com/kthyng/xoak@include_distances
3 changes: 2 additions & 1 deletion ci/environment-py3.9.yml
Expand Up @@ -12,10 +12,11 @@ dependencies:
- scikit-learn # used by xoak for tree
- xarray
- xesmf
- xoak
# - xoak
- pytest
- pytest-benchmark
- pip:
- codecov
- pytest-cov
- coverage[toml]
- git+https://github.com/kthyng/xoak@include_distances
8 changes: 8 additions & 0 deletions docs/whats_new.rst
@@ -1,6 +1,14 @@
:mod:`What's New`
----------------------------

v1.1.0 (January 19, 2023)
=========================

Two main changes in `sel2d` / `sel2dcf`:

* A mask can be input to limit which lons/lats from the DataArray/Dataset are searched for the nearest point, so that a valid model point is returned even when the nearest model point is on land.
* Changes from xoak are incorporated that optionally return the distance of the model point(s) from the requested point(s).
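
The two changes listed above can be illustrated with a minimal, self-contained sketch. It is not the package's implementation (which wraps `xoak` and a ball tree); it uses a brute-force haversine search in NumPy, and the function names, the Earth radius, and the tiny grid are all assumptions made for illustration.

```python
import numpy as np

EARTH_RADIUS_KM = 6371.0  # assumed mean Earth radius


def haversine_km(lon1, lat1, lon2, lat2):
    """Great-circle distance in km between points given in degrees."""
    lon1, lat1, lon2, lat2 = map(np.radians, (lon1, lat1, lon2, lat2))
    a = (
        np.sin((lat2 - lat1) / 2) ** 2
        + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2
    )
    return 2 * EARTH_RADIUS_KM * np.arcsin(np.sqrt(a))


def nearest_valid_point(lons, lats, mask, lon0, lat0):
    """Return (eta, xi, distance_km) of the nearest point where mask is True."""
    eta, xi = np.where(mask)  # keep only valid (e.g. wet) grid points
    dist = haversine_km(lons[eta, xi], lats[eta, xi], lon0, lat0)
    k = int(np.argmin(dist))
    return int(eta[k]), int(xi[k]), float(dist[k])


# Hypothetical 2x3 grid; the grid point closest to the request is masked as land.
lons, lats = np.meshgrid([-95.0, -94.0, -93.0], [28.0, 29.0])
mask = np.array([[True, True, True], [True, False, True]])
eta, xi, dist = nearest_valid_point(lons, lats, mask, lon0=-94.3, lat0=28.9)
```

Without the mask the search would land on the land point at index (1, 1); with it, the nearest valid neighbor is returned together with its distance from the requested location.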

v1.0.0 (December 9, 2022)
=========================
* Simplified dependencies
51 changes: 45 additions & 6 deletions extract_model/extract_model.py
Expand Up @@ -5,12 +5,15 @@
import warnings

from numbers import Number
from typing import Optional

import cf_xarray # noqa: F401
import numpy as np
import xarray as xr
import xoak # noqa: F401

from xarray import DataArray


try:
import xesmf as xe
Expand Down Expand Up @@ -367,8 +370,10 @@ def select(
return da.squeeze(), weights


def sel2d(var, **kwargs):
"""Find the value of the var at closest location to inputs.
def sel2d(
var, mask: Optional[DataArray] = None, distances_name: str = "distance", **kwargs
):
"""Find the value of the var at closest location to inputs, optionally respecting mask.
This is meant to mimic `xarray` `.sel()` in API and idea, except that the horizontal selection is done for 2D coordinates instead of 1D coordinates, since `xarray` cannot yet handle 2D coordinates. This wraps `xoak`.
Expand All @@ -381,6 +386,8 @@ def sel2d(var, **kwargs):
Can also pass through `xarray.sel()` information for other dimension selections.
Optionally input mask so that if requested lon/lat is on land, the nearest valid model point will be returned. Otherwise nan's will be returned. If requested lon/lat is outside domain but not on land, the nearest model output will be returned regardless.
Parameters
----------
var: DataArray, Dataset
Expand All @@ -389,10 +396,14 @@ def sel2d(var, **kwargs):
>>> em.sel2d(da, ...)
instead of `ds.variable` directly. Then subsequent calls will be faster. See `xoak` for more information.
A Dataset will "remember" the index calculated for whichever grid coordinates were first requested and subsequently run faster for requests on that grid (and not run for other grids).
mask : DataArray, optional
If input, mask is applied to lon/lat so that if requested lon/lat is on land, the nearest valid model point will be returned. Otherwise nan's will be returned. If requested lon/lat is outside domain but not on land, the nearest model output will be returned regardless.
distances_name : str, optional
Name of the variable under which to save the distances from xoak; there will be one distance per lon/lat location found. If None, distances won't be returned in the object.
Returns
-------
An xarray object of the same type as input as var which is selected in horizontal coordinates to input locations and, in input, to time and vertical selections. If not selected, other dimensions are brought along.
An xarray object of the same type as the input var, selected in horizontal coordinates to the input locations and, if input, to time and vertical selections. Dimensions that are not selected are brought along. If distances_name is not None, a Dataset is returned.
Notes
-----
Expand Down Expand Up @@ -445,23 +456,51 @@ def sel2d(var, **kwargs):
# create Dataset
ds_to_find = xr.Dataset({"lat_to_find": (dims, lats), "lon_to_find": (dims, lons)})

if mask is not None:

# Assume mask is 2D — but not true for wetting/drying

# find indices representing mask
eta, xi = np.where(mask.values)

# make advanced indexer to flatten arrays
var_flat = var.cf.isel(
X=xr.DataArray(xi, dims="loc"), Y=xr.DataArray(eta, dims="loc")
)

var = var_flat.copy()

if var.xoak.index is None:
var.xoak.set_index([latname, lonname], "sklearn_geo_balltree")
elif (latname, lonname) != var.xoak._index_coords:
raise ValueError(
f"Index has been built for grid with coords {var.xoak._index_coords} but coord names input are ({latname}, {lonname})."
)
elif var.xoak.index is not None:
pass
else:
warnings.warn(
"Maybe a mask is not present or being properly identified in var. You could use `use_mask=False`.",
RuntimeWarning,
)

# perform selection
output = var.xoak.sel(
{latname: ds_to_find.lat_to_find, lonname: ds_to_find.lon_to_find}
{latname: ds_to_find.lat_to_find, lonname: ds_to_find.lon_to_find},
distances_name=distances_name,
)

# distances between input points and nearest points
# distances = var.xoak._index.query(np.array([*zip(lats,lons)]))['distances'][:,0]
# import pdb; pdb.set_trace()

with xr.set_options(keep_attrs=True):
return output.sel(**kwargs)
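
The masking step in `sel2d` relies on xarray's pointwise (vectorized) indexing: when the indexers are DataArrays sharing a new dimension, `isel` picks points pairwise instead of taking an outer product. A small sketch of that mechanism follows. It uses plain `isel` with explicit dimension names rather than the `cf` accessor used in the commit, and the variable, coordinate, and dimension names are illustrative assumptions.

```python
import numpy as np
import xarray as xr

# Hypothetical 2x3 field with 2D lon/lat coordinates.
lon2d, lat2d = np.meshgrid([-95.0, -94.0, -93.0], [28.0, 29.0])
da = xr.DataArray(
    np.arange(6.0).reshape(2, 3),
    dims=("eta", "xi"),
    coords={"lon": (("eta", "xi"), lon2d), "lat": (("eta", "xi"), lat2d)},
    name="temp",
)
mask = xr.DataArray([[True, True, True], [True, False, True]], dims=("eta", "xi"))

# Row and column indices of the valid points.
eta, xi = np.where(mask.values)

# Indexers that share the new dimension "loc" select points pairwise,
# so the result is 1D and contains only the valid points.
da_flat = da.isel(eta=xr.DataArray(eta, dims="loc"), xi=xr.DataArray(xi, dims="loc"))
```

The 2D `lon` and `lat` coordinates are indexed along with the data, so they also end up on the `loc` dimension, which is what lets a nearest-neighbor index be built over valid points only.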


def sel2dcf(var, **kwargs):
def sel2dcf(
var, mask: Optional[DataArray] = None, distances_name: str = "distance", **kwargs
):
"""Find nearest value(s) on 2D horizontal grid using cf-xarray names.
Use "longitude" and "latitude" for those coordinate names.
Expand Down Expand Up @@ -511,7 +550,7 @@ def sel2dcf(var, **kwargs):

new_kwargs.update(kwargs)

return sel2d(var, **new_kwargs)
return sel2d(var, mask=mask, distances_name=distances_name, **new_kwargs)


def selZ(var, depths):
79 changes: 43 additions & 36 deletions extract_model/utils.py
Expand Up @@ -482,7 +482,7 @@ def order(da):
)


def preprocess_roms(ds):
def preprocess_roms(ds, interp_vertical: bool = True):
"""Preprocess ROMS model output for use with cf-xarray.
Also fixes any other known issues with model output.
Expand All @@ -491,6 +491,8 @@ def preprocess_roms(ds):
----------
ds: xarray Dataset
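interp_vertical: bool, optional
If True (the default), decode the vertical coordinates (z_rho, z_w) and use them as the vertical coordinates in place of s_rho, s_w. If False, skip this step.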
Returns
-------
Same Dataset but with some metadata added and/or altered.
Expand Down Expand Up @@ -546,39 +548,42 @@ def preprocess_roms(ds):
ds["s_w"].attrs["standard_name"] = "ocean_s_coordinate_g2"

# calculate vertical coord
name_dict = {}
if "s_rho" in ds.dims:
name_dict["s_rho"] = "z_rho"
if "positive" in ds.s_rho.attrs:
ds.s_rho.attrs.pop("positive")
if "s_w" in ds.dims:
name_dict["s_w"] = "z_w"
if "positive" in ds.s_w.attrs:
ds.s_w.attrs.pop("positive")
ds.cf.decode_vertical_coords(outnames=name_dict)

# fix attrs
for zname in ["z_rho", "z_w"]: # name_dict.values():
if zname in ds:
ds[
zname
].attrs = {} # coord inherits from one of the vars going into calculation
ds[zname].attrs["positive"] = "up"
ds[zname].attrs["units"] = "m"
ds[zname] = order(ds[zname])

# replace s_rho with z_rho, etc, to make z_rho the vertical coord
for sname, zname in name_dict.items():
for var in ds.data_vars:
if ds[var].ndim == 4:
if "coordinates" in ds[var].encoding:
coords = ds[var].encoding["coordinates"]
if sname in coords: # replace if present
coords = coords.replace(sname, zname)
else: # still add z_rho or z_w
if zname in ds.coords and ds[zname].shape == ds[var].shape:
coords += f" {zname}"
ds[var].encoding["coordinates"] = coords
if interp_vertical:
name_dict = {}
if "s_rho" in ds.dims:
name_dict["s_rho"] = "z_rho"
if "positive" in ds.s_rho.attrs:
ds.s_rho.attrs.pop("positive")
if "s_w" in ds.dims:
name_dict["s_w"] = "z_w"
if "positive" in ds.s_w.attrs:
ds.s_w.attrs.pop("positive")
ds.cf.decode_vertical_coords(outnames=name_dict)

# fix attrs
for zname in ["z_rho", "z_w"]: # name_dict.values():
if zname in ds:
ds[
zname
].attrs = (
{}
) # coord inherits from one of the vars going into calculation
ds[zname].attrs["positive"] = "up"
ds[zname].attrs["units"] = "m"
ds[zname] = order(ds[zname])

# replace s_rho with z_rho, etc, to make z_rho the vertical coord
for sname, zname in name_dict.items():
for var in ds.data_vars:
if ds[var].ndim == 4:
if "coordinates" in ds[var].encoding:
coords = ds[var].encoding["coordinates"]
if sname in coords: # replace if present
coords = coords.replace(sname, zname)
else: # still add z_rho or z_w
if zname in ds.coords and ds[zname].shape == ds[var].shape:
coords += f" {zname}"
ds[var].encoding["coordinates"] = coords
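
The coordinate-swap logic in the loop above can be isolated as a small string helper. The sketch below is a variation, not the commit's code: it matches whole names rather than substrings, and the shape check from the original is replaced by a hypothetical `add_if_missing` flag. The function name is also an assumption.

```python
def swap_vertical_coord(coords, sname, zname, add_if_missing=True):
    """Replace sname with zname in a space-separated "coordinates" string.

    If sname is absent, optionally append zname instead.
    """
    names = coords.split()
    if sname in names:
        # replace if present
        names = [zname if name == sname else name for name in names]
    elif add_if_missing and zname not in names:
        # still add the z coordinate
        names.append(zname)
    return " ".join(names)


replaced = swap_vertical_coord("lon_rho lat_rho s_rho ocean_time", "s_rho", "z_rho")
appended = swap_vertical_coord("lon_rho lat_rho ocean_time", "s_w", "z_w")
```

Matching on whole names avoids touching a coordinate whose name merely contains `s_rho` as a substring, which a plain `str.replace` would also rewrite.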

# # easier to remove "coordinates" attribute from any variables than add it to all
# for var in ds.data_vars:
Expand Down Expand Up @@ -706,13 +711,15 @@ def preprocess_rtofs(ds):
raise NotImplementedError


def preprocess(ds, model_type=None):
def preprocess(ds, model_type=None, kwargs=None):
"""A preprocess function for reading in with xarray.
This tries to address known model shortcomings in a generic way so that
`cf-xarray` will work generally, including decoding vertical coordinates.
"""

kwargs = kwargs or {}

# This is an internal attribute used by netCDF which xarray doesn't know or care about, but can
# be returned from THREDDS.
if "_NCProperties" in ds.attrs:
Expand All @@ -739,7 +746,7 @@ def preprocess(ds, model_type=None):
model_type = guess_model_type(ds)

if model_type in preprocess_map:
return preprocess_map[model_type](ds)
return preprocess_map[model_type](ds, **kwargs)

return ds
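
The `kwargs` pass-through added to `preprocess` can be sketched in isolation. The stand-in functions and plain dictionaries below are assumptions used in place of the real model-specific functions and xarray Datasets; only the dispatch pattern mirrors the diff.

```python
def preprocess_roms(ds, interp_vertical=True):
    """Stand-in for the real function: records which option was used."""
    return {**ds, "interp_vertical": interp_vertical}


def preprocess_other(ds):
    """Stand-in for a model-specific function that takes no extra options."""
    return dict(ds)


preprocess_map = {"ROMS": preprocess_roms, "OTHER": preprocess_other}


def preprocess(ds, model_type=None, kwargs=None):
    kwargs = kwargs or {}  # None default avoids a shared mutable default
    if model_type in preprocess_map:
        return preprocess_map[model_type](ds, **kwargs)
    return ds


with_option = preprocess({}, "ROMS", kwargs={"interp_vertical": False})
default = preprocess({}, "ROMS")
```

One consequence of this design is that options are forwarded unchanged, so a key that the selected model function does not accept raises a `TypeError`. Because `xarray.open_mfdataset` calls its `preprocess` callable with a single dataset argument, one way to supply options there might be `functools.partial(preprocess, kwargs={...})`.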

6 changes: 4 additions & 2 deletions tests/test_accessor.py
Expand Up @@ -19,6 +19,7 @@ def test_2dsel():
model = models[3]
da = model["da"]
i, j = model["i"], model["j"]
varname = da.name

if da.cf["longitude"].ndim == 1:
longitude = float(da.cf["X"][i])
Expand All @@ -37,7 +38,8 @@ def test_2dsel():
da_check = da.cf.isel(X=i, Y=j)

# checks that the resultant model output is the same
assert np.allclose(da_sel2d.squeeze(), da_check)
assert np.allclose(da_sel2d[varname].squeeze(), da_check)

da_test = da.em.sel2dcf(longitude=lon_comp, latitude=lat_comp)
assert np.allclose(da_sel2d, da_test)
assert np.allclose(da_sel2d[varname], da_test[varname])
assert np.allclose(da_sel2d["distance"], da_test["distance"])