diff --git a/docs/source/conf.py b/docs/source/conf.py index 65289860..d8c5ade7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -77,6 +77,8 @@ "xarray": ("https://docs.xarray.dev/en/stable/", None), "pyvista": ("https://docs.pyvista.org/version/stable/", None), "pyresample": ("https://pyresample.readthedocs.io/en/stable/", None), + "dask": ("https://docs.dask.org/en/stable/", None), + # "polars": ("https://docs.pola.rs/", None), } always_document_param_types = True diff --git a/gpm/accessor/methods.py b/gpm/accessor/methods.py index fe6ad29a..f576f512 100644 --- a/gpm/accessor/methods.py +++ b/gpm/accessor/methods.py @@ -392,10 +392,26 @@ def get_slices_regular_time(self, tolerance=None, min_size=1): return get_slices_regular_time(self._obj, tolerance=tolerance, min_size=min_size) @auto_wrap_docstring - def get_slices_contiguous_scans(self, min_size=2, min_n_scans=3): + def get_slices_contiguous_scans( + self, + min_size=2, + min_n_scans=3, + x="lon", + y="lat", + along_track_dim="along_track", + cross_track_dim="cross_track", + ): from gpm.utils.checks import get_slices_contiguous_scans - return get_slices_contiguous_scans(self._obj, min_size=min_size, min_n_scans=min_n_scans) + return get_slices_contiguous_scans( + self._obj, + min_size=min_size, + min_n_scans=min_n_scans, + x=x, + y=y, + cross_track_dim=cross_track_dim, + along_track_dim=along_track_dim, + ) @auto_wrap_docstring def get_slices_contiguous_granules(self, min_size=2): @@ -404,16 +420,46 @@ def get_slices_contiguous_granules(self, min_size=2): return get_slices_contiguous_granules(self._obj, min_size=min_size) @auto_wrap_docstring - def get_slices_valid_geolocation(self, min_size=2): + def get_slices_valid_geolocation( + self, + min_size=2, + x="lon", + y="lat", + along_track_dim="along_track", + cross_track_dim="cross_track", + ): from gpm.utils.checks import get_slices_valid_geolocation - return get_slices_valid_geolocation(self._obj, min_size=min_size) + return get_slices_valid_geolocation( + self._obj, + min_size=min_size, + x=x, + y=y, + cross_track_dim=cross_track_dim, + along_track_dim=along_track_dim, + ) @auto_wrap_docstring - def get_slices_regular(self, min_size=None, min_n_scans=3): + def get_slices_regular( + self, + min_size=None, + min_n_scans=3, + x="lon", + y="lat", + along_track_dim="along_track", + cross_track_dim="cross_track", + ): from gpm.utils.checks import get_slices_regular - return get_slices_regular(self._obj, min_size=min_size, min_n_scans=min_n_scans) + return get_slices_regular( + self._obj, + min_size=min_size, + min_n_scans=min_n_scans, + x=x, + y=y, + cross_track_dim=cross_track_dim, + along_track_dim=along_track_dim, + ) #### Plotting utility @auto_wrap_docstring @@ -499,8 +545,8 @@ def plot_swath_lines( @auto_wrap_docstring def plot_map_mesh( self, - x="lon", - y="lat", + x=None, + y=None, ax=None, edgecolors="k", linewidth=0.1, @@ -527,8 +573,8 @@ def plot_map_mesh( @auto_wrap_docstring def plot_map_mesh_centroids( self, - x="lon", - y="lat", + x=None, + y=None, ax=None, c="r", s=1, @@ -653,8 +699,8 @@ def plot_map( self, variable, ax=None, - x="lon", - y="lat", + x=None, + y=None, add_colorbar=True, add_swath_lines=True, add_background=True, @@ -854,8 +900,8 @@ def title( def plot_map( self, ax=None, - x="lon", - y="lat", + x=None, + y=None, add_colorbar=True, add_swath_lines=True, add_background=True, diff --git a/gpm/bucket/partitioning.py b/gpm/bucket/partitioning.py index 829b1a8f..0250cf79 100644 --- a/gpm/bucket/partitioning.py +++ b/gpm/bucket/partitioning.py @@ -428,7 +428,7 @@ def quadmesh(self, origin="bottom"): Return -------- - np.ndarray + numpy.ndarray Quadmesh array of shape (M+1, N+1, 2) """ x_corners, y_corners = np.meshgrid(self.x_bounds, self.y_bounds) @@ -440,7 +440,7 @@ def vertices(self, origin="bottom", ccw=True): """Return the partitions vertices in an array of shape (N, M, 4, 2). The output vertices, once the first 2 dimension are flattened, - can be passed directly to a matplotlib.PolyCollection. + can be passed directly to a `matplotlib.PolyCollection`. For plotting with cartopy, the polygon order must be "counterclockwise". Parameters @@ -563,7 +563,7 @@ def add_labels(self, df, x, y, remove_invalid_rows=True): Parameters ---------- - df : pandas.DataFrame, dask.DataFrame, polars.DataFrame or polars.LazyFrame + df : `pandas.DataFrame`, `dask.DataFrame`, `polars.DataFrame`, `pyarrow.Table` or `polars.LazyFrame` Dataframe to which add partitions centroids. x : str Column name with the x coordinate. @@ -575,7 +575,7 @@ def add_labels(self, df, x, y, remove_invalid_rows=True): Returns ------- - df : pandas.DataFrame, dask.DataFrame, polars.DataFrame or polars.LazyFrame + df : `pandas.DataFrame`, `dask.DataFrame`, `polars.DataFrame`, `pyarrow.Table` or `polars.LazyFrame` Dataframe with the partitions label(s) column(s). """ @@ -607,7 +607,7 @@ def add_centroids(self, df, x, y, x_coord=None, y_coord=None, remove_invalid_row Parameters ---------- - df : pandas.DataFrame, dask.DataFrame, polars.DataFrame or polars.LazyFrame + df : `pandas.DataFrame`, `dask.DataFrame`, `polars.DataFrame`, `pyarrow.Table` or `polars.LazyFrame` Dataframe to which add partitions centroids. x : str Column name with the x coordinate. @@ -625,7 +625,7 @@ def add_centroids(self, df, x, y, x_coord=None, y_coord=None, remove_invalid_row Returns ------- - df : pandas.DataFrame, dask.DataFrame, polars.DataFrame or polars.LazyFrame + df : `pandas.DataFrame`, `dask.DataFrame`, `polars.DataFrame`, `pyarrow.Table` or `polars.LazyFrame` Dataframe with the partitions centroids x and y coordinates columns. """ @@ -979,10 +979,9 @@ def to_dict(self): class LonLatPartitioning(XYPartitioning): - """ - Handles geographic partitioning of data based on longitude and latitude bin sizes within a defined extent. + """Handles geographic partitioning of data based on longitude and latitude bin sizes within a defined extent. - The last bin size (in lon and lat direction) might not be of size ``size` ! + The last bin size (in lon and lat direction) might not be of size ``size`` ! Parameters ---------- diff --git a/gpm/checks.py b/gpm/checks.py index d2b31f16..3b5bc8f7 100644 --- a/gpm/checks.py +++ b/gpm/checks.py @@ -241,19 +241,21 @@ def _is_expected_spatial_dims(spatial_dims): def is_orbit(xr_obj): """Check whether the xarray object is a GPM ORBIT. - An orbit transect or nadir view is considered ORBIT. + An ORBIT transect or nadir view is considered ORBIT. + An ORBIT object must have the coordinates available ! + """ - from gpm.dataset.crs import _get_proj_dim_coords + from gpm.dataset.crs import _get_swath_dim_coords # Check dimension names spatial_dims = get_spatial_dimensions(xr_obj) if not _is_orbit_expected_spatial_dims(spatial_dims): return False - # Check that no 1D coords exists - # - Swath objects are determined by 2D coordinates only - x_coord, y_coord = _get_proj_dim_coords(xr_obj) - if x_coord is None and y_coord is None: + # Check that swath coords exists + # - Swath objects are determined by 1D (nadir looking) and 2D coordinates + x_coord, y_coord = _get_swath_dim_coords(xr_obj) + if x_coord is not None and y_coord is not None: return True return False @@ -262,6 +264,7 @@ def is_grid(xr_obj): """Check whether the xarray object is a GPM GRID. A GRID slice is not considered a GRID object ! + An GRID object must have the coordinates available ! """ from gpm.dataset.crs import _get_proj_dim_coords @@ -431,14 +434,14 @@ def check_is_gpm_object(xr_obj): raise ValueError("Unrecognized GPM xarray object.") -def check_has_cross_track_dim(xr_obj): - if "cross_track" not in xr_obj.dims: - raise ValueError("The 'cross-track' dimension is not available.") +def check_has_cross_track_dim(xr_obj, dim="cross_track"): + if dim not in xr_obj.dims: + raise ValueError(f"The 'cross-track' dimension {dim} is not available.") -def check_has_along_track_dim(xr_obj): - if "along_track" not in xr_obj.dims: - raise ValueError("The 'along_track' dimension is not available.") +def check_has_along_track_dim(xr_obj, dim="along_track"): + if dim not in xr_obj.dims: + raise ValueError(f"The 'along_track' dimension {dim} is not available.") def check_is_spatial_2d(xr_obj, strict=True, squeeze=True): diff --git a/gpm/retrievals/retrieval_1b_radar.py b/gpm/retrievals/retrieval_1b_radar.py index 2f9707a3..39d70ca0 100644 --- a/gpm/retrievals/retrieval_1b_radar.py +++ b/gpm/retrievals/retrieval_1b_radar.py @@ -530,6 +530,8 @@ def open_dataset_1b_ka_fs( """Open 1B-Ka dataset in FS scan_mode format in either L1B or L2 format. It expects start_time after the GPM DPR scan pattern change occurred the 8 May 2018. + (Over)Resample HS on MS (using LUT/range distance). + The L2 FS format has 176 bins. Notes ----- @@ -539,8 +541,6 @@ def open_dataset_1b_ka_fs( - 1B-Ka MS have range resolution of 125 m (260 bins) - 2A-Ka HS have range resolution of 125 m (88 bins) - 2A-Ka MS have range resolution of 125 m (176 bins) - --> (Over)Resample HS on MS (using LUT/range distance) - --> L2 FS format has 176 bins """ from gpm.io.checks import check_time diff --git a/gpm/tests/test_checks.py b/gpm/tests/test_checks.py index 871afa16..9882db89 100644 --- a/gpm/tests/test_checks.py +++ b/gpm/tests/test_checks.py @@ -369,10 +369,23 @@ def test_is_orbit( assert is_orbit(orbit_dataarray.isel(along_track=0)) assert is_orbit(orbit_dataarray.isel(cross_track=0)) # nadir-view + # Check with other dimensions names + assert is_orbit(orbit_dataarray.rename({"lon": "longitude", "lat": "latitude"})) + assert is_orbit(orbit_dataarray.rename({"cross_track": "y", "along_track": "x"})) + assert is_orbit(orbit_dataarray.isel(along_track=0).rename({"cross_track": "y"})) + assert is_orbit(orbit_dataarray.isel(cross_track=0).rename({"along_track": "x"})) + + # Check grid is not confound with orbit assert not is_orbit(grid_dataarray.isel(lon=0)) assert not is_orbit(grid_dataarray.isel(lat=0)) assert not is_orbit(xr.DataArray()) + # Check also strange edge cases + assert not is_orbit(grid_dataarray.isel(lat=0).rename({"lon": "x"})) + assert not is_orbit(grid_dataarray.isel(lon=0).rename({"lat": "y"})) + assert not is_orbit(grid_dataarray.isel(lon=0).rename({"lat": "cross_track"})) + assert not is_orbit(grid_dataarray.isel(lon=0).rename({"lat": "along_track"})) + # With one dimensional longitude n_x = 10 n_y = 20 @@ -382,6 +395,10 @@ def test_is_orbit( invalid_da = xr.DataArray(data, coords={"x": x, "y": y}) assert not is_orbit(invalid_da) + # Assert without coordinates + assert not is_orbit(grid_dataarray.drop_vars(["lon", "lat"])) + assert not is_orbit(orbit_dataarray.drop_vars(["lon", "lat"])) + def test_is_grid( orbit_dataarray: xr.DataArray, @@ -392,11 +409,20 @@ def test_is_grid( assert not is_grid(grid_dataarray.isel(lon=0)) assert not is_grid(grid_dataarray.isel(lat=0)) + # Check with other dimensions names + assert is_grid(grid_dataarray.rename({"lon": "longitude", "lat": "latitude"})) + assert is_grid(grid_dataarray.rename({"lon": "x", "lat": "y"})) + + # Check orbit is not confound with grid assert not is_grid(orbit_dataarray) assert not is_grid(orbit_dataarray.isel(along_track=0)) assert not is_grid(orbit_dataarray.isel(cross_track=0)) assert not is_grid(xr.DataArray()) + # Assert without coordinates + assert not is_grid(grid_dataarray.drop_vars(["lon", "lat"])) + assert not is_grid(orbit_dataarray.drop_vars(["lon", "lat"])) + def test_check_is_orbit( orbit_dataarray: xr.DataArray, diff --git a/gpm/tests/test_utils/test_utils_checks.py b/gpm/tests/test_utils/test_utils_checks.py index dfee1c4f..b51912e3 100644 --- a/gpm/tests/test_utils/test_utils_checks.py +++ b/gpm/tests/test_utils/test_utils_checks.py @@ -417,8 +417,8 @@ def ds_contiguous(self) -> xr.Dataset: ds["lat"] = (("cross_track", "along_track"), lat) ds["lon"] = (("cross_track", "along_track"), lon) ds["gpm_granule_id"] = np.ones(self.n_along_track) - ds["time"] = time - + ds = ds.set_coords(["lon", "lat"]) + ds = ds.assign_coords({"time": ("along_track", time)}) return ds @pytest.fixture() @@ -596,7 +596,8 @@ def ds_orbit_valid(self) -> xr.Dataset: ds = xr.Dataset() ds["lon"] = (("cross_track", "along_track"), lon) ds["lat"] = (("cross_track", "along_track"), lat) - ds["time"] = (("along_track"), time) + ds = ds.set_coords(["lon", "lat"]) + ds = ds.assign_coords({"time": ("along_track", time)}) return ds @pytest.fixture() @@ -631,11 +632,11 @@ def test_is_valid_geolocation( ) -> None: """Test _is_valid_geolocation.""" # Valid - valid = checks._is_valid_geolocation(ds_orbit_valid) + valid = checks._is_valid_geolocation(ds_orbit_valid, coord="lon") assert np.all(valid) # Invalid - valid = checks._is_valid_geolocation(ds_orbit_invalid) + valid = checks._is_valid_geolocation(ds_orbit_invalid, coord="lon") assert np.sum(valid.all(dim="cross_track")) == self.n_along_track - 1 @pytest.mark.usefixtures("_set_is_orbit_to_true") @@ -708,7 +709,8 @@ class TestWobblingSwath: ds["lat"] = (("cross_track", "along_track"), lat) ds["lon"] = (("cross_track", "along_track"), lon) ds["gpm_granule_id"] = np.ones(n_along_track) - ds["time"] = time + ds = ds.set_coords(["lon", "lat"]) + ds = ds.assign_coords({"time": ("along_track", time)}) # Threshold must be at least 3 to remove wobbling slices threshold = 3 @@ -773,8 +775,10 @@ def ds_orbit(self) -> xr.Dataset: ds = xr.Dataset() ds["lon"] = (("cross_track", "along_track"), lon) ds["lat"] = (("cross_track", "along_track"), lon) + ds = ds.set_coords(["lon", "lat"]) granule_ids = np.array([0, 0, 0, 1, 1, 1, 2, 2, 7, 8]) - return ds.assign_coords({"gpm_granule_id": ("along_track", granule_ids)}) + ds = ds.assign_coords({"gpm_granule_id": ("along_track", granule_ids)}) + return ds @pytest.fixture() def ds_grid(self) -> xr.Dataset: diff --git a/gpm/tests/test_visualization/test_facetgrid.py b/gpm/tests/test_visualization/test_facetgrid.py index e2b37ae2..94c13202 100644 --- a/gpm/tests/test_visualization/test_facetgrid.py +++ b/gpm/tests/test_visualization/test_facetgrid.py @@ -268,7 +268,9 @@ def test_grid_col_rgb( ) -> None: """Test plotting orbit data using row, col and rgb arguments.""" grid_dataarray_4_frames_rgb = expand_dims(grid_dataarray_4_frames, 3, dim="rgb", axis=-1) - p = plot.plot_map(grid_dataarray_4_frames_rgb, col=EXTRA_DIM, col_wrap=2, rgb="rgb") + # BUG in xarray if y,x not provided + # p = plot.plot_map(grid_dataarray_4_frames_rgb, col=EXTRA_DIM, col_wrap=2, rgb="rgb") + p = plot.plot_map(grid_dataarray_4_frames_rgb, y="lat", x="lon", col=EXTRA_DIM, col_wrap=2, rgb="rgb") save_and_check_figure(figure=p.fig, name=get_test_name()) @@ -283,6 +285,14 @@ def test_orbit( p = plot.plot_image(orbit_dataarray_4_frames, col=EXTRA_DIM, col_wrap=2) save_and_check_figure(figure=p.fig, name=get_test_name()) + def test_orbit_without_coords( + self, + orbit_dataarray_4_frames: xr.DataArray, + ) -> None: + """Test plotting orbit data without coordinates.""" + p = plot.plot_image(orbit_dataarray_4_frames.drop_vars(["lon", "lat"]), col=EXTRA_DIM, col_wrap=2) + save_and_check_figure(figure=p.fig, name=get_test_name()) + def test_orbit_col_row_rgb( self, orbit_dataarray_2x2_frames: xr.DataArray, diff --git a/gpm/tests/test_visualization/test_plot.py b/gpm/tests/test_visualization/test_plot.py index fa74d414..48fd0e3c 100644 --- a/gpm/tests/test_visualization/test_plot.py +++ b/gpm/tests/test_visualization/test_plot.py @@ -183,6 +183,22 @@ def test_orbit_pole_projection( p = plot.plot_map(orbit_pole_dataarray, subplot_kwargs={"projection": crs_proj}) save_and_check_figure(figure=p.figure, name=get_test_name()) + def test_orbit_xy_dim( + self, + orbit_dataarray: xr.DataArray, + ) -> None: + """Test plotting orbit data with x and y dimensions.""" + p = plot.plot_map(orbit_dataarray.rename({"cross_track": "y", "along_track": "x"})) + save_and_check_figure(figure=p.figure, name=get_test_name()) + + def test_orbit_longitude_latitude_coords( + self, + orbit_dataarray: xr.DataArray, + ) -> None: + """Test plotting orbit data with longitude and latitude coordinates.""" + p = plot.plot_map(orbit_dataarray.rename({"lon": "longitude", "lat": "latitude"})) + save_and_check_figure(figure=p.figure, name=get_test_name()) + ####------------------------------------------------------------------------ #### - Test with NaN in the data @@ -501,6 +517,22 @@ def test_grid_time_dim( with pytest.raises(ValueError): # Expecting a 2D GPM field plot.plot_map(grid_dataarray) + def test_grid_xy_dim( + self, + grid_dataarray: xr.DataArray, + ) -> None: + """Test plotting grid data with x and y dimensions.""" + p = plot.plot_map(grid_dataarray.rename({"lat": "y", "lon": "x"})) + save_and_check_figure(figure=p.figure, name=get_test_name()) + + def test_grid_longitude_latitude_coords( + self, + grid_dataarray: xr.DataArray, + ) -> None: + """Test plotting grid data with longitude and latitude coordinates.""" + p = plot.plot_map(grid_dataarray.rename({"lon": "longitude", "lat": "latitude"})) + save_and_check_figure(figure=p.figure, name=get_test_name()) + def test_invalid( self, ) -> None: @@ -518,7 +550,23 @@ def test_orbit( orbit_dataarray: xr.DataArray, ) -> None: """Test plotting orbit data.""" - p = plot.plot_image(orbit_dataarray) + p = plot.plot_image(orbit_dataarray.drop_vars(["lon", "lat"])) + save_and_check_figure(figure=p.figure, name=get_test_name()) + + def test_orbit_without_coords( + self, + orbit_dataarray: xr.DataArray, + ) -> None: + """Test plotting orbit data without coordinates.""" + p = plot.plot_image(orbit_dataarray.drop_vars(["lon", "lat"])) + save_and_check_figure(figure=p.figure, name=get_test_name()) + + def test_orbit_with_xy_dims( + self, + orbit_dataarray: xr.DataArray, + ) -> None: + """Test plotting orbit data with x and y dimensions.""" + p = plot.plot_image(orbit_dataarray.rename({"cross_track": "y", "along_track": "x"})) save_and_check_figure(figure=p.figure, name=get_test_name()) def test_orbit_alpha_array( @@ -581,6 +629,30 @@ def test_grid( p = plot.plot_image(grid_dataarray) save_and_check_figure(figure=p.figure, name=get_test_name()) + def test_grid_without_coords( + self, + grid_dataarray: xr.DataArray, + ) -> None: + """Test plotting grid data without coordinates.""" + p = plot.plot_image(grid_dataarray.drop_vars(["lon", "lat"])) + save_and_check_figure(figure=p.figure, name=get_test_name()) + + def test_grid_with_xy_dims( + self, + grid_dataarray: xr.DataArray, + ) -> None: + """Test plotting grid data with x and y dimensions.""" + p = plot.plot_image(grid_dataarray.rename({"lat": "y", "lon": "x"})) + save_and_check_figure(figure=p.figure, name=get_test_name()) + + def test_grid_with_longitude_latitude_coords( + self, + grid_dataarray: xr.DataArray, + ) -> None: + """Test plotting grid data with x and y dimensions.""" + p = plot.plot_image(grid_dataarray.rename({"lat": "latitude", "lon": "longitude"})) + save_and_check_figure(figure=p.figure, name=get_test_name()) + def test_invalid( self, ) -> None: diff --git a/gpm/utils/checks.py b/gpm/utils/checks.py index eda8c8c1..affe2036 100644 --- a/gpm/utils/checks.py +++ b/gpm/utils/checks.py @@ -30,9 +30,8 @@ import numpy as np import pandas as pd -from gpm.checks import is_grid, is_orbit +from gpm.checks import check_has_along_track_dim, is_grid, is_orbit from gpm.utils.decorators import ( - check_has_along_track_dimension, check_is_gpm_object, check_is_orbit, ) @@ -382,26 +381,26 @@ def has_regular_time(xr_obj): ######################## -def _select_lons_lats_centroids(xr_obj): - if "cross_track" not in xr_obj.dims: - lons = xr_obj["lon"].to_numpy() - lats = xr_obj["lat"].to_numpy() +def _select_lons_lats_centroids(xr_obj, x="lon", y="lat", cross_track_dim="cross_track"): + if cross_track_dim not in xr_obj.dims: + lons = xr_obj[x].to_numpy() + lats = xr_obj[y].to_numpy() else: # Select centroids coordinates in the middle of the cross_track scan - middle_idx = int(xr_obj["cross_track"].shape[0] / 2) - lons = xr_obj["lon"].isel(cross_track=middle_idx).to_numpy() - lats = xr_obj["lat"].isel(cross_track=middle_idx).to_numpy() + middle_idx = int(xr_obj[cross_track_dim].shape[0] / 2) + lons = xr_obj[x].isel({cross_track_dim: middle_idx}).to_numpy() + lats = xr_obj[y].isel({cross_track_dim: middle_idx}).to_numpy() return lons, lats -def _get_along_track_scan_distance(xr_obj): +def _get_along_track_scan_distance(xr_obj, x="lon", y="lat", cross_track_dim="cross_track"): """Compute the distance between along_track centroids.""" from pyproj import Geod # Select centroids coordinates # - If no cross-track, take the available lat/lon 1D array # - If cross-track dimension is present, takes the coordinates in the swath middle. - lons, lats = _select_lons_lats_centroids(xr_obj) + lons, lats = _select_lons_lats_centroids(xr_obj, x=x, y=y, cross_track_dim=cross_track_dim) # Define between-centroids line coordinates start_lons = lons[:-1] @@ -415,7 +414,7 @@ def _get_along_track_scan_distance(xr_obj): return dist -def _is_contiguous_scans(xr_obj): +def _is_contiguous_scans(xr_obj, x="lon", y="lat", cross_track_dim="cross_track"): """Return a boolean array indicating if the next scan is contiguous. It assumes at least 3 scans are provided. @@ -423,7 +422,7 @@ def _is_contiguous_scans(xr_obj): The last element is set to True since it can not be verified. """ # Compute along track scan distance - dist = _get_along_track_scan_distance(xr_obj) + dist = _get_along_track_scan_distance(xr_obj, x=x, y=y, cross_track_dim=cross_track_dim) # Convert to km and round dist_km = np.round(dist / 1000, 0) @@ -451,8 +450,15 @@ def _is_contiguous_scans(xr_obj): @check_is_orbit -@check_has_along_track_dimension -def get_slices_contiguous_scans(xr_obj, min_size=2, min_n_scans=3): +def get_slices_contiguous_scans( + xr_obj, + min_size=2, + min_n_scans=3, + x="lon", + y="lat", + along_track_dim="along_track", + cross_track_dim="cross_track", +): """Return a list of slices ensuring contiguous scans (and granules). It checks for contiguous scans only in the middle of the cross-track ! @@ -477,8 +483,10 @@ def get_slices_contiguous_scans(xr_obj, min_size=2, min_n_scans=3): Output format: ``[slice(start,stop), slice(start,stop),...]`` """ + check_has_along_track_dim(xr_obj, dim=along_track_dim) + # Get number of scans - n_scans = xr_obj["along_track"].shape[0] + n_scans = xr_obj[along_track_dim].shape[0] # Define behaviour if less than 2/3 scan along track # --> Contiguity can't be verified without at least 3 slices ! @@ -488,7 +496,7 @@ def get_slices_contiguous_scans(xr_obj, min_size=2, min_n_scans=3): return [] # Get boolean array indicating if the next scan is contiguous - is_contiguous = _is_contiguous_scans(xr_obj) + is_contiguous = _is_contiguous_scans(xr_obj, x=x, y=y, cross_track_dim=cross_track_dim) # If non-contiguous scans are present, get the slices with contiguous scans # - It discard consecutive non-contiguous scans @@ -511,8 +519,13 @@ def get_slices_contiguous_scans(xr_obj, min_size=2, min_n_scans=3): @check_is_orbit -@check_has_along_track_dimension -def get_slices_non_contiguous_scans(xr_obj): +def get_slices_non_contiguous_scans( + xr_obj, + x="lon", + y="lat", + along_track_dim="along_track", + cross_track_dim="cross_track", +): """Return a list of slices where the scans discontinuity occurs. An input with less than ``2`` scans (along-track) returns an empty list. @@ -529,8 +542,10 @@ def get_slices_non_contiguous_scans(xr_obj): Output format: ``[slice(start,stop), slice(start,stop),...]`` """ + check_has_along_track_dim(xr_obj, dim=along_track_dim) + # Get number of scans - n_scans = xr_obj["along_track"].shape[0] + n_scans = xr_obj[along_track_dim].shape[0] # Define behaviour if less than 3 scan along track # --> Contiguity can't be verified without at least 3 slices ! @@ -550,19 +565,32 @@ def get_slices_non_contiguous_scans(xr_obj): # list_slices = [] # return list_slices - list_slices_valid = get_slices_contiguous_scans(xr_obj, min_size=2) - list_slices_full = [slice(0, len(xr_obj["along_track"]))] + list_slices_valid = get_slices_contiguous_scans( + xr_obj, + min_size=2, + x=x, + y=y, + along_track_dim=along_track_dim, + cross_track_dim=cross_track_dim, + ) + list_slices_full = [slice(0, len(xr_obj[along_track_dim]))] return list_slices_difference(list_slices_full, list_slices_valid) -# @check_has_cross_track_dimension -@check_has_along_track_dimension -def check_contiguous_scans(xr_obj, verbose=True): +def check_contiguous_scans( + xr_obj, + verbose=True, + x="lon", + y="lat", + along_track_dim="along_track", + cross_track_dim="cross_track", +): """Check no missing scans across the along_track direction. Note: - This sometimes occurs between orbit granules. - This sometimes occurs within a orbit granule. + - This function also works for nadir-looking only orbits (no cross-track). Parameters ---------- @@ -572,7 +600,14 @@ def check_contiguous_scans(xr_obj, verbose=True): If ``True``, it prints the time interval when the non contiguous scans occurs. """ - list_discontinuous_slices = get_slices_non_contiguous_scans(xr_obj) + check_has_along_track_dim(xr_obj, dim=along_track_dim) + list_discontinuous_slices = get_slices_non_contiguous_scans( + xr_obj, + x=x, + y=y, + along_track_dim=along_track_dim, + cross_track_dim=cross_track_dim, + ) n_discontinuous = len(list_discontinuous_slices) if n_discontinuous > 0: # Retrieve discontinuous timesteps interval @@ -590,11 +625,19 @@ def check_contiguous_scans(xr_obj, verbose=True): raise ValueError(msg) -# @check_has_cross_track_dimension -@check_has_along_track_dimension -def has_contiguous_scans(xr_obj): - """Return ``True`` if all scans are contiguous. ``False`` otherwise.""" - list_discontinuous_slices = get_slices_non_contiguous_scans(xr_obj) +def has_contiguous_scans(xr_obj, x="lon", y="lat", along_track_dim="along_track", cross_track_dim="cross_track"): + """Return ``True`` if all scans are contiguous. ``False`` otherwise. + + This functions also works with nadir-only looking orbit. + """ + check_has_along_track_dim(xr_obj, dim=along_track_dim) + list_discontinuous_slices = get_slices_non_contiguous_scans( + xr_obj, + x=x, + y=y, + along_track_dim=along_track_dim, + cross_track_dim=cross_track_dim, + ) n_discontinuous = len(list_discontinuous_slices) if n_discontinuous > 0: return False @@ -607,24 +650,31 @@ def has_contiguous_scans(xr_obj): ############################# -def _is_non_valid_geolocation(xr_obj, x="lon"): +def _is_non_valid_geolocation(xr_obj, coord): """Return a boolean array indicating if the geolocation is invalid. `True = Invalid`, `False = Valid`. """ - return np.isnan(xr_obj[x]) + return np.isnan(xr_obj[coord]) -def _is_valid_geolocation(xr_obj, x="lon"): +def _is_valid_geolocation(xr_obj, coord): """Return a boolean array indicating if the geolocation is valid. `True = Valid`, `False = Invalid`. """ - return ~np.isnan(xr_obj[x]) + return ~np.isnan(xr_obj[coord]) @check_is_orbit -def get_slices_valid_geolocation(xr_obj, min_size=2): +def get_slices_valid_geolocation( + xr_obj, + min_size=2, + x="lon", + y="lat", + along_track_dim="along_track", + cross_track_dim="cross_track", +): """Return a list of GPM ORBIT along-track slices with valid geolocation. The minimum size of the output slices is ``2``. @@ -648,19 +698,19 @@ def get_slices_valid_geolocation(xr_obj, min_size=2): """ # - Get invalid coordinates - invalid_lon_coords = _is_non_valid_geolocation(xr_obj, x="lon") - invalid_lat_coords = _is_non_valid_geolocation(xr_obj, x="lat") + invalid_lon_coords = _is_non_valid_geolocation(xr_obj, coord=x) + invalid_lat_coords = _is_non_valid_geolocation(xr_obj, coord=y) invalid_coords = np.logical_or(invalid_lon_coords, invalid_lat_coords) # - Identify cross-track index that along-track are always invalid - idx_cross_track_not_all_invalid = np.where(~invalid_coords.all("along_track"))[0] + idx_cross_track_not_all_invalid = np.where(~invalid_coords.all(along_track_dim))[0] # - If all invalid, return empty list if len(idx_cross_track_not_all_invalid) == 0: return [] # - Select only cross-track index that are not all invalid along-track - invalid_coords = invalid_coords.isel(cross_track=idx_cross_track_not_all_invalid) + invalid_coords = invalid_coords.isel({cross_track_dim: idx_cross_track_not_all_invalid}) # - Now identify scans across which there are still invalid coordinates - invalid_scans = invalid_coords.any(dim="cross_track") + invalid_scans = invalid_coords.any(dim=cross_track_dim) valid_scans = ~invalid_scans # - Now identify valid along-track slices list_slices = get_list_slices_from_bool_arr( @@ -672,7 +722,13 @@ def get_slices_valid_geolocation(xr_obj, min_size=2): return list_slices_filter(list_slices, min_size=min_size) -def get_slices_non_valid_geolocation(xr_obj): +def get_slices_non_valid_geolocation( + xr_obj, + x="lon", + y="lat", + along_track_dim="along_track", + cross_track_dim="cross_track", +): """Return a list of GPM ORBIT along-track slices with non-valid geolocation. The minimum size of the output slices is 2. @@ -695,12 +751,26 @@ def get_slices_non_valid_geolocation(xr_obj): Output format: ``[slice(start,stop), slice(start,stop),...]`` """ - list_slices_valid = get_slices_valid_geolocation(xr_obj, min_size=1) - list_slices_full = [slice(0, len(xr_obj["along_track"]))] + list_slices_valid = get_slices_valid_geolocation( + xr_obj, + min_size=1, + x=x, + y=y, + along_track_dim=along_track_dim, + cross_track_dim=cross_track_dim, + ) + list_slices_full = [slice(0, len(xr_obj[along_track_dim]))] return list_slices_difference(list_slices_full, list_slices_valid) -def check_valid_geolocation(xr_obj, verbose=True): +def check_valid_geolocation( + xr_obj, + verbose=True, + x="lon", + y="lat", + along_track_dim="along_track", + cross_track_dim="cross_track", +): """Check no geolocation errors in the GPM Dataset. Parameters @@ -709,7 +779,13 @@ def check_valid_geolocation(xr_obj, verbose=True): xarray object. """ - list_invalid_slices = get_slices_non_valid_geolocation(xr_obj) + list_invalid_slices = get_slices_non_valid_geolocation( + xr_obj, + x=x, + y=y, + along_track_dim=along_track_dim, + cross_track_dim=cross_track_dim, + ) n_invalid_scan_slices = len(list_invalid_slices) if n_invalid_scan_slices > 0: # Retrieve timesteps interval with non valid geolocation @@ -727,10 +803,16 @@ def check_valid_geolocation(xr_obj, verbose=True): @check_is_gpm_object -def has_valid_geolocation(xr_obj): +def has_valid_geolocation(xr_obj, x="lon", y="lat", along_track_dim="along_track", cross_track_dim="cross_track"): """Checks GPM object has valid geolocation.""" if is_orbit(xr_obj): - list_invalid_slices = get_slices_non_valid_geolocation(xr_obj) + list_invalid_slices = get_slices_non_valid_geolocation( + xr_obj, + x=x, + y=y, + along_track_dim=along_track_dim, + cross_track_dim=cross_track_dim, + ) n_invalid_scan_slices = len(list_invalid_slices) return n_invalid_scan_slices == 0 if is_grid(xr_obj): @@ -753,7 +835,7 @@ def wrapper(*args, **kwargs): # Retrieve slice offset start_offset = slc.start # Retrieve dataset subset - subset_xr_obj = xr_obj.isel(along_track=slc) + subset_xr_obj = xr_obj.isel({"along_track": slc}) # Update args new_args[0] = subset_xr_obj # Apply function @@ -811,7 +893,13 @@ def _get_non_wobbling_lats(lats, threshold=100): @apply_on_valid_geolocation -def get_slices_non_wobbling_swath(xr_obj, threshold=100): +def get_slices_non_wobbling_swath( + xr_obj, + threshold=100, + y="lat", + along_track_dim="along_track", + cross_track_dim="cross_track", +): """Return the GPM ORBIT along-track slices along which the swath is not wobbling. For wobbling, we define the occurrence of changes in latitude directions @@ -819,8 +907,8 @@ def get_slices_non_wobbling_swath(xr_obj, threshold=100): The function extract the along-track boundary on both swath sides and identify where the change in orbit direction occurs. """ - xr_obj = xr_obj.transpose("cross_track", "along_track", ...) - lats = xr_obj["lat"].to_numpy() + xr_obj = xr_obj.transpose(cross_track_dim, along_track_dim, ...) + lats = xr_obj[y].to_numpy() lats_side0 = lats[0, :] lats_side2 = lats[-1, :] # Get valid slices @@ -830,7 +918,13 @@ def get_slices_non_wobbling_swath(xr_obj, threshold=100): @apply_on_valid_geolocation -def get_slices_wobbling_swath(xr_obj, threshold=100): +def get_slices_wobbling_swath( + xr_obj, + threshold=100, + y="lat", + cross_track_dim="cross_track", + along_track_dim="along_track", +): """Return the GPM ORBIT along-track slices along which the swath is wobbling. For wobbling, we define the occurrence of changes in latitude directions @@ -838,8 +932,14 @@ def get_slices_wobbling_swath(xr_obj, threshold=100): The function extract the along-track boundary on both swath sides and identify where the change in orbit direction occurs. """ - list_slices1 = get_slices_non_wobbling_swath(xr_obj, threshold=threshold) - list_slices_full = [slice(0, len(xr_obj["along_track"]))] + list_slices1 = get_slices_non_wobbling_swath( + xr_obj, + threshold=threshold, + y=y, + along_track_dim=along_track_dim, + cross_track_dim=cross_track_dim, + ) + list_slices_full = [slice(0, len(xr_obj[along_track_dim]))] return list_slices_difference(list_slices_full, list_slices1) @@ -864,7 +964,15 @@ def is_regular(xr_obj): @check_is_gpm_object -def get_slices_regular(xr_obj, min_size=None, min_n_scans=3): +def get_slices_regular( + xr_obj, + min_size=None, + min_n_scans=3, + x="lon", + y="lat", + along_track_dim="along_track", + cross_track_dim="cross_track", +): """Return a list of slices to select regular GPM objects. For GPM ORBITS, it returns slices to select contiguous scans with valid geolocation. @@ -900,9 +1008,13 @@ def get_slices_regular(xr_obj, min_size=None, min_n_scans=3): xr_obj, min_size=min_size, min_n_scans=min_n_scans, + x=x, + y=y, + along_track_dim=along_track_dim, + cross_track_dim=cross_track_dim, ) # Get swath portions where there are valid geolocation - list_slices_geolocation = get_slices_valid_geolocation(xr_obj, min_size=min_size) + list_slices_geolocation = get_slices_valid_geolocation(xr_obj, min_size=min_size, x=x, y=y) # Find swath portions meeting all the requirements return list_slices_intersection(list_slices_geolocation, list_slices_contiguous) diff --git a/gpm/utils/manipulations.py b/gpm/utils/manipulations.py index 469eafe9..f5ac1c4e 100644 --- a/gpm/utils/manipulations.py +++ b/gpm/utils/manipulations.py @@ -124,11 +124,13 @@ def conversion_factors_degree_to_meter(latitude): Parameters ---------- - latitude : Latitude in degrees where the conversion is needed + latitude : numpy.ndarray + Latitude in degrees where the conversion is needed Returns ------- - (cx, cy) : Tuple containing conversion factors for longitude and latitude + (cx, cy) : tuple + Tuple containing conversion factors for longitude and latitude """ # Earth's radius at the equator (in meters) R = 6378137 diff --git a/gpm/visualization/grid.py b/gpm/visualization/grid.py index 6ac333d2..19f51ae0 100644 --- a/gpm/visualization/grid.py +++ b/gpm/visualization/grid.py @@ -36,6 +36,7 @@ add_optimize_layout_method, check_object_format, create_grid_mesh_data_array, + infer_xy_labels, initialize_cartopy_plot, plot_cartopy_imshow, # plot_mpl_imshow, @@ -49,8 +50,8 @@ def _plot_grid_map_cartopy( da, - x="lon", - y="lat", + x=None, + y=None, ax=None, add_colorbar=True, interpolation="nearest", @@ -112,8 +113,8 @@ def _plot_grid_map_cartopy( def _plot_grid_map_facetgrid( da, - x="lon", - y="lat", + x=None, + y=None, ax=None, add_colorbar=True, interpolation="nearest", @@ -194,8 +195,8 @@ def _plot_grid_map_facetgrid( def plot_grid_map( da, - x="lon", - y="lat", + x=None, + y=None, ax=None, add_colorbar=True, interpolation="nearest", @@ -244,8 +245,8 @@ def plot_grid_map( def plot_grid_mesh( xr_obj, - x="lon", - y="lat", + x=None, + y=None, ax=None, edgecolors="k", linewidth=0.1, @@ -265,6 +266,9 @@ def plot_grid_mesh( add_background=add_background, ) + # Infer x and y + x, y = infer_xy_labels(xr_obj, x=x, y=y, rgb=plot_kwargs.get("rgb", None)) + # Create 2D mesh `xarray.DataArray` da = create_grid_mesh_data_array(xr_obj, x=x, y=y) diff --git a/gpm/visualization/orbit.py b/gpm/visualization/orbit.py index 8aadfeed..4bcfb1d2 100644 --- a/gpm/visualization/orbit.py +++ b/gpm/visualization/orbit.py @@ -55,15 +55,106 @@ #### ORBIT utilities -def remove_invalid_outer_cross_track(xr_obj, coord="lon", alpha=None): - """Remove outer crosstrack scans if geolocation is always missing.""" - if "cross_track" not in xr_obj.dims: +def infer_orbit_xy_coords(da, x=None, y=None): + """ + Infer possible x and y coordinates for the given DataArray. + + Parameters + ---------- + da : xarray.DataArray + The input DataArray. + x : str, optional + The name of the x (longitude) coordinate. If None, it will be inferred. + y : str, optional + The name of the y (latitude) coordinate. If None, it will be inferred. + + Returns + ------- + tuple + The inferred (x, y) coordinates. + """ + possible_x_coords = ["x", "lon", "longitude"] + possible_y_coords = ["y", "lat", "latitude"] + + if x is None: + for coord in possible_x_coords: + if coord in da.coords: + x = coord + break + else: + raise ValueError("Cannot infer x coordinate. Please provide the x coordinate.") + + if y is None: + for coord in possible_y_coords: + if coord in da.coords: + y = coord + break + else: + raise ValueError("Cannot infer y coordinate. Please provide the y coordinate.") + + return x, y + + +def infer_orbit_xy_dim(da, x, y): + """ + Infer possible along-track and cross-track dimensions for the given DataArray. + + Parameters + ---------- + da : xarray.DataArray + The input DataArray. + x : str + The name of the x coordinate. + y : str + The name of the y coordinate. + + Returns + ------- + tuple + The inferred (along_track_dim, cross_track_dim) dimensions. + """ + possible_along_track_dim = ["x", "along_track"] + possible_cross_track_dim = ["y", "cross_track"] + coordinates_dims = np.unique(list(da[x].dims) + list(da[y].dims)).tolist() + + # Retrieve n_dims + n_dims = len(coordinates_dims) # nadir-only vs 2D + # Check for cross_track_dim + cross_track_dim = None + for dim in coordinates_dims: + if dim in possible_cross_track_dim: + cross_track_dim = dim + break + + # Check for along_track_dim + along_track_dim = None + for dim in coordinates_dims: + if dim in possible_along_track_dim: + along_track_dim = dim + break + if n_dims > 1: + if cross_track_dim is None: + raise ValueError(f"Cross-track dimension could not be identified across {coordinates_dims}.") + if along_track_dim is None: + raise ValueError(f"Along-track dimension could not be identified across {coordinates_dims}.") + return along_track_dim, cross_track_dim + + +def remove_invalid_outer_cross_track( + xr_obj, + coord="lon", + cross_track_dim="cross_track", + along_track_dim="along_track", + alpha=None, +): + """Remove outer cross-track scans if geolocation is always missing.""" + if cross_track_dim not in xr_obj.dims: return xr_obj, alpha - if "along_track" not in xr_obj.dims: + if along_track_dim not in xr_obj.dims: coord_arr = np.asanyarray(xr_obj[coord]) isna = np.isnan(coord_arr) else: - coord_arr = np.asanyarray(xr_obj[coord].transpose("cross_track", "along_track")) + coord_arr = np.asanyarray(xr_obj[coord].transpose(cross_track_dim, along_track_dim)) isna = np.all(np.isnan(coord_arr), axis=1) if isna[0] or isna[-1]: is_valid = ~isna @@ -74,18 +165,26 @@ def remove_invalid_outer_cross_track(xr_obj, coord="lon", alpha=None): # Define slice slc = slice(start_index, end_index) # Subset object - xr_obj = xr_obj.isel({"cross_track": slc}) + xr_obj = xr_obj.isel({cross_track_dim: slc}) if alpha is not None: alpha = alpha[slc, :] return xr_obj, alpha -def _get_contiguous_slices(da): +def _get_contiguous_slices(da, x="lon", y="lat", along_track_dim="along_track", cross_track_dim="cross_track"): # NOTE: Using get_slices_regular would split when there is any NaN coordinate - if "along_track" not in da.dims: # noqa: SIM108 + if along_track_dim not in da.dims: list_slices = [None] # case: cross-track transect else: - list_slices = get_slices_contiguous_scans(da, min_size=2, min_n_scans=2) + list_slices = get_slices_contiguous_scans( + da, + min_size=2, + min_n_scans=2, + x=x, + y=y, + along_track_dim=along_track_dim, + cross_track_dim=cross_track_dim, + ) # Check there are scans to plot if len(list_slices) == 0: @@ -106,22 +205,33 @@ def wrapper(*args, **kwargs): # Get axis ax = args[1] if len(args) > 1 else kwargs.get("ax") - # Get slices with contiguous scans if along_track dimension is available - list_slices = _get_contiguous_slices(da) + # Define dimensions and coordinates + x, y = infer_orbit_xy_coords(da, x=kwargs.get("x", None), y=kwargs.get("y", None)) + along_track_dim, cross_track_dim = infer_orbit_xy_dim(da, x=x, y=y) - # - Define kwargs + # Define kwargs user_kwargs = kwargs.copy() + user_kwargs["x"] = x + user_kwargs["y"] = y p = None - x = user_kwargs.get("x", "lon") - y = user_kwargs.get("y", "lat") is_facetgrid = user_kwargs.get("_is_facetgrid", False) alpha = user_kwargs.get("alpha", None) alpha_2darray_provided = isinstance(alpha, np.ndarray) + + # Get slices with contiguous scans if along_track dimension is available + list_slices = _get_contiguous_slices( + da, + x=x, + y=y, + along_track_dim=along_track_dim, + cross_track_dim=cross_track_dim, + ) + # - Call the function over each slice for i, slc in enumerate(list_slices): # Retrieve contiguous data array # - slc=None when cross-track transect - tmp_da = da.isel({"along_track": slc}) if slc is not None else da + tmp_da = da.isel({along_track_dim: slc}) if slc is not None else da # Adapt for alpha tmp_alpha = alpha[:, slc].copy() if alpha_2darray_provided else None @@ -129,8 +239,20 @@ def wrapper(*args, **kwargs): # Remove outer cross-track indices if all without coordinates # - Infill of coordinates is done separately with infill_invalid_coordins # - If along_track transect, return as it is - tmp_da, tmp_alpha = remove_invalid_outer_cross_track(tmp_da, coord=x, alpha=tmp_alpha) - tmp_da, _ = remove_invalid_outer_cross_track(tmp_da, coord=y, alpha=None) + tmp_da, tmp_alpha = remove_invalid_outer_cross_track( + tmp_da, + alpha=tmp_alpha, + coord=x, + cross_track_dim=cross_track_dim, + along_track_dim=along_track_dim, + ) + tmp_da, _ = remove_invalid_outer_cross_track( + tmp_da, + alpha=None, + coord=y, + cross_track_dim=cross_track_dim, + along_track_dim=along_track_dim, + ) # Define temporary kwargs tmp_kwargs = user_kwargs.copy() @@ -274,8 +396,8 @@ def plot_swath( def _plot_orbit_map_cartopy( da, ax=None, - x="lon", - y="lat", + x=None, + y=None, add_colorbar=True, add_swath_lines=True, add_background=True, @@ -329,8 +451,8 @@ def _plot_orbit_map_cartopy( def _plot_orbit_map_facetgrid( da, - x="lon", - y="lat", + x=None, + y=None, ax=None, add_colorbar=True, add_swath_lines=True, @@ -379,6 +501,7 @@ def _plot_orbit_map_facetgrid( ) # Plot the maps + x, y = infer_orbit_xy_coords(da, x=x, y=y) fc = fc.map_dataarray( _plot_orbit_map_cartopy, x=x, @@ -411,8 +534,8 @@ def _plot_orbit_map_facetgrid( def plot_orbit_map( da, ax=None, - x="lon", - y="lat", + x=None, + y=None, add_colorbar=True, add_swath_lines=True, add_background=True, @@ -426,6 +549,7 @@ def plot_orbit_map( da = check_object_format(da, plot_kwargs=plot_kwargs, check_function=check_has_spatial_dim, strict=True) # Plot FacetGrid if "col" in plot_kwargs or "row" in plot_kwargs: + x, y = infer_orbit_xy_coords(da, x=x, y=y) p = _plot_orbit_map_facetgrid( da=da, x=x, @@ -462,8 +586,8 @@ def plot_orbit_map( def plot_orbit_mesh( da, ax=None, - x="lon", - y="lat", + x=None, + y=None, edgecolors="k", linewidth=0.1, add_background=True, diff --git a/gpm/visualization/plot.py b/gpm/visualization/plot.py index e39bee07..d470f39d 100644 --- a/gpm/visualization/plot.py +++ b/gpm/visualization/plot.py @@ -496,6 +496,14 @@ def preprocess_subplot_kwargs(subplot_kwargs): return subplot_kwargs +def infer_xy_labels(da, x=None, y=None, rgb=None): + from xarray.plot.utils import _infer_xy_labels + + # Infer dimensions + x, y = _infer_xy_labels(da, x=x, y=y, imshow=True, rgb=rgb) # dummy flag for rgb + return x, y + + def initialize_cartopy_plot( ax, fig_kwargs, @@ -700,6 +708,8 @@ def plot_cartopy_imshow( """Plot imshow with cartopy.""" plot_kwargs = {} if plot_kwargs is None else plot_kwargs + # Infer x and y + x, y = infer_xy_labels(da, x=x, y=y, rgb=plot_kwargs.get("rgb", None)) # - Ensure image with correct dimensions orders da = da.transpose(y, x, ...) arr = np.asanyarray(da.data) @@ -717,7 +727,7 @@ def plot_cartopy_imshow( origin = "lower" if y_coords[1] > y_coords[0] else "upper" # - Add variable field with cartopy - _ = plot_kwargs.pop("rgb", None) + rgb = plot_kwargs.pop("rgb", False) p = ax.imshow( arr, transform=ccrs.PlateCarree(), @@ -730,7 +740,7 @@ def plot_cartopy_imshow( ax.set_extent(extent) # - Add colorbar - if add_colorbar: + if add_colorbar and not rgb: _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs) return p @@ -1140,8 +1150,8 @@ def plot_image( def plot_map( da, - x="lon", - y="lat", + x=None, + y=None, ax=None, add_colorbar=True, add_swath_lines=True, # used only for GPM orbit objects @@ -1159,9 +1169,13 @@ def plot_map( da : `xr.DataArray` xarray DataArray. x : str, optional - Longitude coordinate name. The default is ``"lon"``. + Longitude coordinate name. + If ``None``, takes the second dimension. + The default is ``None``. y : str, optional - Latitude coordinate name. The default is ``"lat"``. + Latitude coordinate name. + If ``None``, takes the first dimension. + The default is ``None``. ax : `cartopy.GeoAxes`, optional The cartopy GeoAxes where to plot the map. If ``None``, a figure is initialized using the @@ -1240,8 +1254,8 @@ def plot_map( def plot_map_mesh( xr_obj, - x="lon", - y="lat", + x=None, + y=None, ax=None, edgecolors="k", linewidth=0.1, @@ -1251,12 +1265,12 @@ def plot_map_mesh( **plot_kwargs, ): from gpm.checks import is_grid, is_orbit - - from .grid import plot_grid_mesh - from .orbit import plot_orbit_mesh + from gpm.visualization.grid import plot_grid_mesh + from gpm.visualization.orbit import infer_orbit_xy_coords, plot_orbit_mesh # Plot orbit if is_orbit(xr_obj): + x, y = infer_orbit_xy_coords(xr_obj, x=x, y=y) p = plot_orbit_mesh( da=xr_obj[y], ax=ax, @@ -1290,8 +1304,8 @@ def plot_map_mesh( def plot_map_mesh_centroids( xr_obj, - x="lon", - y="lat", + x=None, + y=None, ax=None, c="r", s=1, @@ -1301,7 +1315,8 @@ def plot_map_mesh_centroids( **plot_kwargs, ): """Plot GPM orbit granule mesh centroids in a cartographic map.""" - from gpm.checks import is_grid + from gpm.checks import is_grid, is_orbit + from gpm.visualization.orbit import infer_orbit_xy_coords # Initialize figure if necessary ax = initialize_cartopy_plot( @@ -1311,9 +1326,16 @@ def plot_map_mesh_centroids( add_background=add_background, ) - # Retrieve centroids + # Retrieve orbits lon, lat coordinates + if is_orbit(xr_obj): + x, y = infer_orbit_xy_coords(xr_obj, x=x, y=y) + + # Retrieve grid centroids mesh if is_grid(xr_obj): + x, y = infer_xy_labels(xr_obj, x=x, y=y) xr_obj = create_grid_mesh_data_array(xr_obj, x=x, y=y) + + # Extract numpy arrays lon = xr_obj[x].to_numpy() lat = xr_obj[y].to_numpy()