Skip to content

Commit

Permalink
Add missing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
robbibt committed Oct 9, 2024
1 parent da5202b commit 50ac8f7
Show file tree
Hide file tree
Showing 3 changed files with 405 additions and 32 deletions.
280 changes: 249 additions & 31 deletions docs/notebooks/Satellite_data.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion eo_tides/eo.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def pixel_tides(

# If only one tidal model exists, squeeze out "tide_model" dim
if len(tides_lowres.tide_model) == 1:
tides_lowres = tides_lowres.squeeze("tide_model", drop=True)
tides_lowres = tides_lowres.squeeze("tide_model")

# Ensure CRS is present before we apply any resampling
tides_lowres = tides_lowres.odc.assign_crs(ds.odc.geobox.crs)
Expand Down
155 changes: 155 additions & 0 deletions tests/test_eo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dask
import geopandas as gpd
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -191,3 +192,157 @@ def test_pixel_tides_times(satellite_ds, measured_tides_ds):
# Verify passing a dataset without time and custom times
measured_tides_ds = pixel_tides(satellite_ds_notime, times=custom_times)
assert len(measured_tides_ds.time) == len(custom_times)


def test_pixel_tides_quantile(satellite_ds):
# Model tides using `pixel_tides` and `calculate_quantiles`
quantiles = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
modelled_tides_ds = pixel_tides(satellite_ds, calculate_quantiles=quantiles)
modelled_tides_lowres = pixel_tides(
satellite_ds,
resample=False,
calculate_quantiles=quantiles,
)

# Verify that outputs contain quantile dim and values match inputs
assert modelled_tides_ds.dims == modelled_tides_lowres.dims
assert "quantile" in modelled_tides_ds.dims
assert "quantile" in modelled_tides_lowres.dims
assert modelled_tides_ds["quantile"].values.tolist() == quantiles
assert modelled_tides_lowres["quantile"].values.tolist() == quantiles

# Verify tides are monotonically increasing along quantile dim
# (in this case, axis=0)
assert np.all(np.diff(modelled_tides_ds, axis=0) > 0)

# Test results match expected results for a set of points across array

# Create test points, reproject to dataset CRS, and extract coords
# as xr.DataArrays so we can select data from our array
points = gpd.points_from_xy(
x=[122.14438, 122.30304, 122.12964, 122.29235],
y=[-17.91625, -17.92713, -18.07656, -18.08751],
crs="EPSG:4326",
).to_crs(satellite_ds.odc.geobox.crs)
x_coords = xr.DataArray(points.x, dims=["point"])
y_coords = xr.DataArray(points.y, dims=["point"])

# Extract modelled tides for each point
try:
extracted_tides = modelled_tides_ds.sel(x=x_coords, y=y_coords, method="nearest")
except KeyError:
extracted_tides = modelled_tides_ds.sel(longitude=x_coords, latitude=y_coords, method="nearest")

# Test if extracted tides match expected results (to within ~2 cm)
expected_tides = np.array([
[-1.89, -2.17, -2.1, -2.21],
[-1.20, -1.28, -1.26, -1.30],
[-0.71, -0.8, -0.77, -0.82],
[-0.33, -0.32, -0.34, -0.32],
[0.5, 0.42, 0.45, 0.41],
[1.59, 1.69, 1.66, 1.70],
])
assert np.allclose(extracted_tides.values, expected_tides, atol=0.02)


# Run test against multiple models
@pytest.mark.parametrize("quantiles", [None, [0.0, 0.5, 1.0]])
def test_pixel_tides_multiplemodels(satellite_ds, quantiles):
# Model tides using `pixel_tides` and multiple models
models = ["EOT20", "HAMTIDE11"]
modelled_tides_ds = pixel_tides(satellite_ds, model=models, calculate_quantiles=quantiles)
modelled_tides_lowres = pixel_tides(
satellite_ds,
model=models,
resample=False,
calculate_quantiles=quantiles,
)

# Verify that outputs contain quantile dim and values match inputs
assert modelled_tides_ds.dims == modelled_tides_lowres.dims
assert "tide_model" in modelled_tides_ds.dims
assert "tide_model" in modelled_tides_lowres.dims
assert modelled_tides_ds["tide_model"].values.tolist() == models
assert modelled_tides_lowres["tide_model"].values.tolist() == models

# Verify that both model outputs are correlated
assert (
xr.corr(
modelled_tides_ds.sel(tide_model="EOT20"),
modelled_tides_ds.sel(tide_model="HAMTIDE11"),
)
> 0.98
)


# Run test for different combinations of Dask chunking
@pytest.mark.parametrize(
"dask_chunks",
["auto", (300, 300), (200, 300)],
)
def test_pixel_tides_dask(satellite_ds, dask_chunks):
# Model tides with Dask compute turned off to return Dask arrays
modelled_tides_ds = pixel_tides(satellite_ds, dask_compute=False, dask_chunks=dask_chunks)

# Verify output is Dask-enabled
assert dask.is_dask_collection(modelled_tides_ds)

# If chunks set to "auto", check output matches `satellite_ds` chunks
if dask_chunks == "auto":
assert modelled_tides_ds.chunks == satellite_ds.nbart_red.chunks

# Otherwise, check output chunks match requested chunks
else:
output_chunks = tuple([i[0] for i in modelled_tides_ds.chunks[1:]])
assert output_chunks == dask_chunks


# Run test pixel tides and ensemble modelling
def test_pixel_tides_ensemble(satellite_ds):
# Model tides using `pixel_tides` and default ensemble model
modelled_tides_ds = pixel_tides(
satellite_ds,
model="ensemble",
ensemble_models=ENSEMBLE_MODELS,
)

assert modelled_tides_ds.tide_model == "ensemble"

# Model tides using `pixel_tides` and multiple models including
# ensemble and custom IDW params
models = ["EOT20", "HAMTIDE11", "ensemble"]
modelled_tides_ds = pixel_tides(
satellite_ds,
model=models,
ensemble_models=ENSEMBLE_MODELS,
k=10,
max_dist=20000,
)

assert "tide_model" in modelled_tides_ds.dims
assert set(modelled_tides_ds.tide_model.values) == set(models)

# Model tides using `pixel_tides` and custom ensemble funcs
ensemble_funcs = {
"ensemble-best": lambda x: x["rank"] == 1,
"ensemble-worst": lambda x: x["rank"] == 2,
"ensemble-mean-top2": lambda x: x["rank"].isin([1, 2]),
"ensemble-mean-weighted": lambda x: 3 - x["rank"],
"ensemble-mean": lambda x: x["rank"] <= 2,
}
modelled_tides_ds = pixel_tides(
satellite_ds,
model=models,
ensemble_func=ensemble_funcs,
ensemble_models=ENSEMBLE_MODELS,
)

assert set(modelled_tides_ds.tide_model.values) == set([
"EOT20",
"HAMTIDE11",
"ensemble-best",
"ensemble-worst",
"ensemble-mean-top2",
"ensemble-mean-weighted",
"ensemble-mean",
])

0 comments on commit 50ac8f7

Please sign in to comment.