Skip to content

Commit

Permalink
Merge pull request #102 from kthyng/small_updates
Browse files Browse the repository at this point in the history
Small updates plus a number of test fixes
  • Loading branch information
kthyng authored Oct 29, 2024
2 parents a2c14a1 + 033f56a commit c25ad82
Show file tree
Hide file tree
Showing 19 changed files with 186 additions and 162 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
fail-fast: false
matrix:
os: ["macos-latest", "ubuntu-latest", "windows-latest"]
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12"]
steps:
- name: Checkout source
uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion ci/environment-py3.10-win.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dependencies:
- cf_xarray>=0.6
- dask
- netcdf4
- numpy <1.24 # https://github.com/numba/numba/issues/8615#issuecomment-1360792615
- numpy #<1.24 # https://github.com/numba/numba/issues/8615#issuecomment-1360792615
- pip
- requests
- scikit-learn # used by xoak for tree
Expand Down
2 changes: 1 addition & 1 deletion ci/environment-py3.10.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dependencies:
- cf_xarray>=0.6
- dask
- netcdf4
- numpy <1.24 # https://github.com/numba/numba/issues/8615#issuecomment-1360792615
- numpy #<1.24 # https://github.com/numba/numba/issues/8615#issuecomment-1360792615
- pip
- requests
- scikit-learn # used by xoak for tree
Expand Down
2 changes: 1 addition & 1 deletion ci/environment-py3.11-win.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dependencies:
- cf_xarray>=0.6
- dask
- netcdf4
- numpy <1.24 # https://github.com/numba/numba/issues/8615#issuecomment-1360792615
- numpy #<1.24 # https://github.com/numba/numba/issues/8615#issuecomment-1360792615
- pip
- requests
- scikit-learn # used by xoak for tree
Expand Down
2 changes: 1 addition & 1 deletion ci/environment-py3.11.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dependencies:
- cf_xarray>=0.6
- dask
- netcdf4
- numpy <1.24 # https://github.com/numba/numba/issues/8615#issuecomment-1360792615
- numpy #<1.24 # https://github.com/numba/numba/issues/8615#issuecomment-1360792615
- pip
- requests
- scikit-learn # used by xoak for tree
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ name: test-env-win
channels:
- conda-forge
dependencies:
- python=3.9
- python=3.12
- cf_xarray>=0.6
- dask
- netcdf4
- numpy <1.24 # https://github.com/numba/numba/issues/8615#issuecomment-1360792615
- numpy #<1.24 # https://github.com/numba/numba/issues/8615#issuecomment-1360792615
- pip
- requests
- scikit-learn # used by xoak for tree
Expand Down
4 changes: 2 additions & 2 deletions ci/environment-py3.9.yml → ci/environment-py3.12.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ name: test-env-mac-unix
channels:
- conda-forge
dependencies:
- python=3.9
- python=3.12
- cf_xarray>=0.6
- dask
- netcdf4
- numpy <1.24 # https://github.com/numba/numba/issues/8615#issuecomment-1360792615
- numpy #<1.24 # https://github.com/numba/numba/issues/8615#issuecomment-1360792615
- pip
- requests
- scikit-learn # used by xoak for tree
Expand Down
4 changes: 4 additions & 0 deletions docs/whats_new.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
:mod:`What's New`
-----------------

v1.4.1 (October 25, 2024)
=========================
* Small changes so horizontal coordinate search works more consistently with past versions.

v1.4.0 (November 6, 2023)
=========================
* small changes so that using xESMF as horizontal interpolator works
Expand Down
6 changes: 3 additions & 3 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@ name: extract_model
channels:
- conda-forge
dependencies:
- python>=3.8,<3.11
- python=3.10
# Required for full project functionality (dont remove)
- pytest
- pytest-benchmark
# Examples (remove and add as needed)
- cf_xarray
- cmocean
- dask <=2022.05.0 # for xESMF, https://github.com/axiom-data-science/extract_model/issues/49
- dask #<=2022.05.0 # for xESMF, https://github.com/axiom-data-science/extract_model/issues/49
- extract_model
- matplotlib
- netcdf4
- numpy <1.24 # https://github.com/numba/numba/issues/8615#issuecomment-1360792615
- numpy #<1.24 # https://github.com/numba/numba/issues/8615#issuecomment-1360792615
- numba # required by xesmf
- pip
- pooch
Expand Down
2 changes: 2 additions & 0 deletions extract_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import cf_xarray as cfxr # noqa: F401

import extract_model.accessor # noqa: F401
import extract_model.preprocessing # noqa: F401
import extract_model.utils # noqa: F401

from .extract_model import sel2d, sel2dcf, select, selZ # noqa: F401
from .preprocessing import preprocess
Expand Down
45 changes: 28 additions & 17 deletions extract_model/extract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,9 @@ def select(
* False: 2D array of points with 1 dimension the lons and the other dimension the lats.
* True: lons/lats as unstructured coordinate pairs (in xESMF language, LocStream).
locstreamT: boolean, optional
If False, interpolate in time dimension independently of horizontal points. If True, use advanced indexing/interpolation in xarray to interpolate times to each horizontal locstream point. If this is True, locstream must be True.
If False, interpolate in time dimension independently of horizontal points. If True, use advanced indexing/interpolation in xarray to interpolate times to each horizontal locstream point.
locstreamZ: boolean, optional
If False, interpolate in depth dimension independently of horizontal points. If True, use advanced indexing after depth interpolation select depths to match each horizontal locstream point. If this is True, locstream must be True and locstreamT must be True.
If False, interpolate in depth dimension independently of horizontal points. If True, use advanced indexing after depth interpolation select depths to match each horizontal locstream point.
new_dim : str
This is the name of the new dimension created if we are interpolating to a new set of points that are not a grid.
weights: xESMF netCDF file path, DataArray, optional
Expand Down Expand Up @@ -360,14 +360,15 @@ def select(
"Use extrap=True to extrapolate."
)

if locstreamT:
if not locstream:
raise ValueError("if `locstreamT` is True, `locstream` must also be True.")
if locstreamZ:
if not locstream or not locstreamT:
raise ValueError(
"if `locstreamZ` is True, `locstream` and `locstreamT` must also be True."
)
# these are only true if interpolating in those directions too — need to fix them
# if locstreamT:
# if not locstream:
# raise ValueError("if `locstreamT` is True, `locstream` must also be True.")
# if locstreamZ:
# if not locstream or not locstreamT:
# raise ValueError(
# "if `locstreamZ` is True, `locstream` and `locstreamT` must also be True."
# )

# Perform interpolation
if horizontal_interp:
Expand Down Expand Up @@ -443,13 +444,12 @@ def select(
xs, ys = proj(xs, ys)
x, y = proj(longitude, latitude)

# import pdb; pdb.set_trace()
# lam = calc_barycentric(x, y, xs.reshape((10,9,3)), ys.reshape((10,9,3)))
lam = calc_barycentric(x.flatten(), y.flatten(), xs, ys)
# lam = calc_barycentric(x, y, xs, ys)
# interp_coords are the coords and indices that went into the interpolation
da, interp_coords = interp_with_barycentric(da, ixs, iys, lam)
# import pdb; pdb.set_trace()

# if not locstream:
# FIGURE OUT HOW TO RECONSTITUTE INTO GRID HERE
kwargs_out["interp_coords"] = interp_coords
Expand Down Expand Up @@ -665,6 +665,7 @@ def pt_in_itriangle_proj(ix, iy):

# advanced indexing to select all assuming coherent time series
# make sure len of each dimension matches

if locstreamZ:

dims_to_index = [da.cf["T"].name]
Expand Down Expand Up @@ -697,10 +698,13 @@ def sel2d(
mask: Optional[DataArray] = None,
use_xoak: bool = True,
return_info: bool = False,
k: Optional[int] = None,
**kwargs,
):
"""Find the value of the var at closest location to inputs, optionally respecting mask.
Note: I don't think this function selects for time or depth, only for horizontal coordinates. If you need to select for time or depth, use `select` instead.
This is meant to mimic `xarray` `.sel()` in API and idea, except that the horizontal selection is done for 2D coordinates instead of 1D coordinates, since `xarray` cannot yet handle 2D coordinates. This wraps `xoak`.
Order of inputs is important:
Expand Down Expand Up @@ -728,10 +732,12 @@ def sel2d(
If True, use xoak to find nearest 1 point. If False, use BallTree directly to find distances and nearest 4 points.
return_info: bool
If True, return a dict of extra information that depends on what processes were run.
k: int, optional
For not xoak — number of nearest neighbors to find. Default is either 1 or 50 depending on if a mask is input, but can be overridden by user with this input.
Returns
-------
An xarray object of the same type as input as var which is selected in horizontal coordinates to input locations and, in input, to time and vertical selections. If not selected, other dimensions are brought along. Other items returned in kwargs include:
An xarray object of the same type as input as var which is selected in horizontal coordinates to input locations and, if input, to time and vertical selections. If not selected, other dimensions are brought along. Other items returned in kwargs include:
* distances: the distances from the requested points to the returned nearest points
Expand Down Expand Up @@ -809,7 +815,6 @@ def sel2d(
mask = mask.load()

# Assume mask is 2D — but not true for wetting/drying
# import pdb; pdb.set_trace()
# find indices representing mask
eta, xi = np.where(mask.values)

Expand Down Expand Up @@ -898,16 +903,22 @@ def sel2d(

else:

# make sure the mask matches
if mask is not None:
# import pdb; pdb.set_trace()
msg = f"Mask {mask.name} dimensions do not match horizontal var {var.name} dimensions. mask dims: {mask.dims}, var dims: {var.dims}"
assert len(set(mask.dims) - set(var.dims)) == 0, msg

# currently lons, lats 1D only

# if no mask, assume user just wants 1 nearest point to each input lons/lats pair
# probably should expand this later to be more generic
if mask is None:
k = 1
k = k or 1
# if user inputs mask, use it to only return the nearest point that is active
# so, find nearest 30 points to have options
else:
k = 30
k = k or 50

distances, (iys, ixs) = tree_query(var[lonname], var[latname], lons, lats, k=k)

Expand All @@ -916,7 +927,7 @@ def sel2d(
raise ValueError("all found values are masked!")

if mask is not None:
isorted_mask = np.argsort(-mask.values[iys, ixs], axis=-1)
isorted_mask = np.argsort(-mask.values[iys, ixs], axis=-1, kind="mergesort")
# sort the ixs and iys according to this sorting so that if there are unmasked indices,
# they are leftmost also, and we will use the leftmost values.
ixs_brought_along = np.take_along_axis(ixs, isorted_mask, axis=1)
Expand Down
4 changes: 2 additions & 2 deletions extract_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,15 +586,15 @@ def interp_with_barycentric(da, ixs, iys, lam):
Y=xr.DataArray(iys, dims=("npts", "triangle")),
)
with xr.set_options(keep_attrs=True):
da = xr.dot(vector, lam, dims=("triangle"))
da = xr.dot(vector, lam, dim=("triangle"))

# get z coordinates to go with interpolated output if not available
if "vertical" in vector.cf.coords:
zkey = vector.cf["vertical"].name

# only need to interpolate z coordinates if they are not 1D
if vector[zkey].ndim > 1:
da_vert = xr.dot(vector[zkey], lam, dims=("triangle"))
da_vert = xr.dot(vector[zkey], lam, dim=("triangle"))

# add vertical coords into da
da = da.assign_coords({zkey: da_vert})
Expand Down
40 changes: 20 additions & 20 deletions tests/grids/test_triangular_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def test_fvcom_subset(real_fvcom, preload):
subsetter = UnstructuredGridSubset()
ds = subsetter.subset(real_fvcom, bbox, "fvcom", preload=preload)
assert ds is not None
assert ds.dims["node"] == 1833
assert ds.dims["nele"] == 3392
assert ds.sizes["node"] == 1833
assert ds.sizes["nele"] == 3392
# Check a node variable
np.testing.assert_allclose(
ds["x"][:10],
Expand Down Expand Up @@ -95,8 +95,8 @@ def test_fvcom_subset_accessor(real_fvcom):
bbox = (276.4, 41.5, 277.4, 42.1)
ds = real_fvcom.em.sub_bbox(bbox)
assert ds is not None
assert ds.dims["node"] == 1833
assert ds.dims["nele"] == 3392
assert ds.sizes["node"] == 1833
assert ds.sizes["nele"] == 3392
# Check a node variable
np.testing.assert_allclose(
ds["x"][:10],
Expand All @@ -121,17 +121,17 @@ def test_fvcom_subset_accessor(real_fvcom):

ds = real_fvcom.em.sub_bbox(bbox, model_type="FVCOM")
assert ds is not None
assert ds.dims["node"] == 1833
assert ds.dims["nele"] == 3392
assert ds.sizes["node"] == 1833
assert ds.sizes["nele"] == 3392


@pytest.mark.parametrize("preload", [False, True], ids=lambda x: f"preload={x}")
def test_fvcom_sub_grid_accessor(real_fvcom, preload):
bbox = (276.4, 41.5, 277.4, 42.1)
ds = real_fvcom.em.sub_grid(bbox=bbox, preload=preload)
assert ds is not None
assert ds.dims["node"] == 1833
assert ds.dims["nele"] == 3392
assert ds.sizes["node"] == 1833
assert ds.sizes["nele"] == 3392
# Check a node variable
np.testing.assert_allclose(
ds["x"][:10],
Expand All @@ -156,8 +156,8 @@ def test_fvcom_sub_grid_accessor(real_fvcom, preload):

ds = real_fvcom.em.sub_grid(bbox=bbox, model_type="FVCOM", preload=preload)
assert ds is not None
assert ds.dims["node"] == 1833
assert ds.dims["nele"] == 3392
assert ds.sizes["node"] == 1833
assert ds.sizes["nele"] == 3392


def test_fvcom_filter(real_fvcom):
Expand Down Expand Up @@ -196,10 +196,10 @@ def test_fvcom_subset_scalars(real_fvcom, preload):
ds = real_fvcom.assign(variables={"example": xvar})
ds_ss = ds.em.sub_grid(bbox=bbox, preload=preload)
assert ds_ss is not None
assert ds_ss.dims["node"] == 1833
assert ds_ss.dims["nele"] == 3392
assert ds_ss.sizes["node"] == 1833
assert ds_ss.sizes["nele"] == 3392
assert "example" in ds_ss.variables
assert len(ds_ss["example"].dims) < 1
assert len(ds_ss["example"].sizes) < 1


@pytest.mark.parametrize("preload", [False, True], ids=lambda x: f"preload={x}")
Expand Down Expand Up @@ -230,8 +230,8 @@ def test_selfe_sub_bbox_accessor(selfe_data):
bbox = (-123.8, 46.2, -123.6, 46.3)
ds_ss = selfe_data.em.sub_bbox(bbox=bbox)
assert ds_ss is not None
assert ds_ss.dims["node"] == 4273
assert ds_ss.dims["nele"] == 8178
assert ds_ss.sizes["node"] == 4273
assert ds_ss.sizes["nele"] == 8178
np.testing.assert_allclose(
ds_ss["x"][:10],
np.array(
Expand All @@ -257,8 +257,8 @@ def test_selfe_sub_grid_accessor(selfe_data, preload):
bbox = (-123.8, 46.2, -123.6, 46.3)
ds_ss = selfe_data.em.sub_grid(bbox=bbox, preload=preload)
assert ds_ss is not None
assert ds_ss.dims["node"] == 4273
assert ds_ss.dims["nele"] == 8178
assert ds_ss.sizes["node"] == 4273
assert ds_ss.sizes["nele"] == 8178
np.testing.assert_allclose(
ds_ss["x"][:10],
np.array(
Expand Down Expand Up @@ -286,10 +286,10 @@ def test_selfe_subset_scalars(selfe_data, preload):
bbox = (-123.8, 46.2, -123.6, 46.3)
ds_ss = ds.em.sub_grid(bbox=bbox, preload=preload)
assert ds_ss is not None
assert ds_ss.dims["node"] == 4273
assert ds_ss.dims["nele"] == 8178
assert ds_ss.sizes["node"] == 4273
assert ds_ss.sizes["nele"] == 8178
assert "example" in ds_ss.variables
assert len(ds_ss["example"].dims) < 1
assert len(ds_ss["example"].sizes) < 1


def test_selfe_preload(selfe_data: xr.Dataset):
Expand Down
Loading

0 comments on commit c25ad82

Please sign in to comment.