diff --git a/cordex/__init__.py b/cordex/__init__.py index 532d739..903ba10 100644 --- a/cordex/__init__.py +++ b/cordex/__init__.py @@ -2,7 +2,14 @@ from . import regions, tables, tutorial from .accessor import CordexDataArrayAccessor, CordexDatasetAccessor # noqa -from .domain import cordex_domain, create_dataset, domain, domain_info, vertices +from .domain import ( + cordex_domain, + create_dataset, + domain, + domain_info, + vertices, + rewrite_coords, +) from .tables import domains, ecmwf from .transform import ( map_crs, @@ -41,4 +48,5 @@ "domains", "ecmwf", "cell_area", + "rewrite_coords", ] diff --git a/cordex/cf.py b/cordex/cf.py index 0ee3ea0..512fe67 100644 --- a/cordex/cf.py +++ b/cordex/cf.py @@ -186,20 +186,20 @@ "long_name": "latitude in rotated pole grid", "units": "degrees", }, - "longitude": { + "lon": { "standard_name": "longitude", "long_name": "longitude", "units": "degrees_east", }, - "latitude": { + "lat": { "standard_name": "latitude", "long_name": "latitude", "units": "degrees_north", }, - "lon_vertices": { + "vertices_lon": { "units": "degrees_east", }, - "lat_vertices": { + "vertices_lat": { "units": "degrees_north", }, }, diff --git a/cordex/domain.py b/cordex/domain.py index f990fd8..3edb906 100644 --- a/cordex/domain.py +++ b/cordex/domain.py @@ -9,7 +9,7 @@ from . import cf from .config import nround from .tables import domains -from .transform import grid_mapping, transform, transform_bounds +from .transform import grid_mapping, transform, transform_coords, transform_bounds from .utils import cell_area, get_tempfile @@ -566,6 +566,90 @@ def vertices(rlon, rlat, src_crs, trg_crs=None): return xr.merge([lat_vertices, lon_vertices]) +def rewrite_coords(ds, coords="xy", domain_id=None, mip_era="CMIP5", method="nearest"): + """ + Rewrite coordinates in a dataset to correct rounding errors. + + This function is useful for ensuring that the coordinates in a dataset are consistent and + can be compared to other datasets. 
It can reindex the dataset based on specified coordinates + or domain information while trying to keep the original coordinate attributes. + + Parameters + ---------- + ds : xr.Dataset + The dataset containing the grid to be rewritten. + coords : str, optional + Specifies which coordinates to rewrite. Options are: + - "xy": Rewrite only the X and Y coordinates. + - "lonlat": Rewrite only the longitude and latitude coordinates. + - "all": Rewrite the X, Y, longitude, and latitude coordinates. + Default is "xy". + domain_id : str, optional + The domain identifier used to obtain grid information. If not provided and the grid mapping is rotated_latitude_longitude, it is read from ``ds.cx.domain_id``; without a domain identifier, the grid is derived from the X and Y coordinates of the dataset itself. + mip_era : str, optional + The MIP era (e.g., "CMIP5", "CMIP6") used to determine coordinate attributes. Default is "CMIP5". + method : str, optional + The method used for reindexing the X and Y coordinates (only used if coords is "xy" or "all"), e.g., "nearest", "pad"/"ffill" or "backfill"/"bfill". Default is "nearest". + + Returns + ------- + ds : xr.Dataset + The dataset with rewritten coordinates. 
+ """ + if ( + domain_id is None + and ds.cf["grid_mapping"].grid_mapping_name == "rotated_latitude_longitude" + ): + domain_id = ds.cx.domain_id + if domain_id: + # we use "official" grid information + grid_info = domain_info(domain_id) + dx = grid_info["dlon"] + dy = grid_info["dlat"] + x0 = grid_info["ll_lon"] + y0 = grid_info["ll_lat"] + nx = grid_info["nlon"] + ny = grid_info["nlat"] + else: + # we use the grid information from the dataset + x = ds.cf["X"].data + y = ds.cf["Y"].data + nx = x.size + ny = y.size + dx = (x[1] - x[0]).round(5) + dy = (y[1] - y[0]).round(5) + x0 = x[0].round(nround) + y0 = y[0].round(nround) + + xn = _lin_coord(nx, dx, x0) + yn = _lin_coord(ny, dy, y0) + + if coords == "xy" or coords == "all": + ds = ds.cf.reindex(X=xn, Y=yn, method=method) + + if coords == "lonlat" or coords == "all": + # check if the dataset already has longitude and latitude coordinates + # if so, overwrite them (take care to keep attributes though) + try: + trg_dims = (ds.cf["longitude"].name, ds.cf["latitude"].name) + overwrite = True + except KeyError: + trg_dims = ("lon", "lat") + overwrite = False + dst = transform_coords(ds, trg_dims=trg_dims) + if overwrite is False: + ds = ds.assign_coords( + {trg_dims[0]: dst[trg_dims[0]], trg_dims[1]: dst[trg_dims[1]]} + ) + ds[trg_dims[0]].attrs = cf.vocabulary[mip_era]["coords"][trg_dims[0]] + ds[trg_dims[1]].attrs = cf.vocabulary[mip_era]["coords"][trg_dims[1]] + else: + ds[trg_dims[0]][:] = dst[trg_dims[0]] + ds[trg_dims[1]][:] = dst[trg_dims[1]] + + return ds + + def _crop_to_domain(ds, domain_id, drop=True): domain = cordex_domain(domain_id) x_mask = ds.cf["X"].round(8).isin(domain.cf["X"]) diff --git a/cordex/preprocessing/preprocessing.py b/cordex/preprocessing/preprocessing.py index 125878f..f1835a6 100644 --- a/cordex/preprocessing/preprocessing.py +++ b/cordex/preprocessing/preprocessing.py @@ -271,7 +271,7 @@ def replace_rlon_rlat(ds, domain=None): """ ds = ds.copy() if domain is None: - domain = 
ds.attrs.get("CORDEX_domain", None) + domain = ds.cx.domain_id dm = cordex_domain(domain) for coord in ["rlon", "rlat"]: if coord in ds.coords: @@ -327,7 +327,7 @@ def replace_lon_lat(ds, domain=None): """ ds = ds.copy() if domain is None: - domain = ds.attrs.get("CORDEX_domain", None) + domain = ds.cx.domain_id dm = cordex_domain(domain) for coord in ["lon", "lat"]: if coord in ds.coords: diff --git a/docs/api.rst b/docs/api.rst index 6516a4e..7d41ca0 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -19,6 +19,7 @@ Top-level functions transform transform_coords transform_bounds + rewrite_coords derotate_vector cell_area diff --git a/docs/whats_new.rst b/docs/whats_new.rst index b5e8480..f092ed5 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -3,6 +3,22 @@ What's new ========== +v0.10.0 (Unreleased) +-------------------- + +New function :py:meth:`cordex.rewrite_coords` that rewrites coordinates (X and Y axes and transformed coordinates) in a dataset to correct +rounding errors. This version drops python3.8 support. + +New Features +~~~~~~~~~~~~ + +- New function :py:meth:`cordex.rewrite_coords` (:pull:`306`). + +Breaking Changes +~~~~~~~~~~~~~~~~ + +- Drop python3.8 support (:pull:`306`). 
+ v0.9.0 (18 November 2024) ------------------------- diff --git a/tests/test_domain.py b/tests/test_domain.py index 63786f6..6fb310a 100644 --- a/tests/test_domain.py +++ b/tests/test_domain.py @@ -1,6 +1,7 @@ import numpy as np import pytest import xarray as xr +import pandas as pd import cordex as cx @@ -78,7 +79,6 @@ def test_constructor(): @pytest.mark.parametrize("bounds", [False, True]) # @pytest.mark.parametrize("", [2, 3]) def test_domain_info(bounds): - import pandas as pd info = { "short_name": "EUR-11", @@ -113,3 +113,60 @@ def test_vertices(): eur11.rotated_latitude_longitude.grid_north_pole_latitude, ) cx.vertices(eur11.rlon, eur11.rlat, src_crs=ccrs.RotatedPole(*pole)) + + +@pytest.mark.parametrize("domain_id", ["EUR-11", "EUR-44", "SAM-44", "AFR-22"]) +def test_rewrite_coords(domain_id): + """ + Test the rewrite_coords function for different domains. + + This function tests the rewrite_coords function by creating a sample dataset + with typical coordinate precision issues (random noise) and verifying that + the coordinates are correctly rewritten. + + Parameters + ---------- + domain_id : str + The domain identifier used to obtain grid information for testing. 
+ """ + # Create a sample dataset + grid = cx.domain(domain_id) + + # Create typical coordinate precision issue by adding random noise + rlon_noise = np.random.randn(*grid.rlon.shape) * np.finfo("float32").eps + rlat_noise = np.random.randn(*grid.rlat.shape) * np.finfo("float32").eps + lon_noise = np.random.randn(*grid.lon.shape) * np.finfo("float32").eps + lat_noise = np.random.randn(*grid.lat.shape) * np.finfo("float32").eps + + # Call the rewrite_coords function for "xy" coordinates + grid_noise = grid.assign_coords( + rlon=grid.rlon + rlon_noise, rlat=grid.rlat + rlat_noise + ) + rewritten_data = cx.rewrite_coords(grid_noise, coords="xy") + + np.testing.assert_array_equal(rewritten_data.rlon, grid.rlon) + np.testing.assert_array_equal(rewritten_data.rlat, grid.rlat) + xr.testing.assert_identical(rewritten_data, grid) + + # Call the rewrite_coords function for "lonlat" coordinates + grid_noise = grid.assign_coords(lon=grid.lon + lon_noise, lat=grid.lat + lat_noise) + rewritten_data = cx.rewrite_coords(grid_noise, coords="lonlat") + + np.testing.assert_array_equal(rewritten_data.lon, grid.lon) + np.testing.assert_array_equal(rewritten_data.lat, grid.lat) + xr.testing.assert_identical(rewritten_data, grid) + + # Call the rewrite_coords function for "all" coordinates + grid_noise = grid.assign_coords( + rlon=grid.rlon + rlon_noise, + rlat=grid.rlat + rlat_noise, + lon=grid.lon + lon_noise, + lat=grid.lat + lat_noise, + ) + rewritten_data = cx.rewrite_coords(grid_noise, coords="all") + + np.testing.assert_array_equal(rewritten_data.rlon, grid.rlon) + np.testing.assert_array_equal(rewritten_data.rlat, grid.rlat) + np.testing.assert_array_equal(rewritten_data.lon, grid.lon) + np.testing.assert_array_equal(rewritten_data.lat, grid.lat) + xr.testing.assert_identical(rewritten_data, grid)