Skip to content

Commit

Permalink
Refactor and improve visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
ghiggi committed Jun 2, 2024
1 parent 7f73aef commit 5c2d168
Show file tree
Hide file tree
Showing 14 changed files with 579 additions and 153 deletions.
2 changes: 2 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@
"xarray": ("https://docs.xarray.dev/en/stable/", None),
"pyvista": ("https://docs.pyvista.org/version/stable/", None),
"pyresample": ("https://pyresample.readthedocs.io/en/stable/", None),
"dask": ("https://docs.dask.org/en/stable/", None),
# "polars": ("https://docs.pola.rs/", None),
}
always_document_param_types = True

Expand Down
74 changes: 60 additions & 14 deletions gpm/accessor/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,26 @@ def get_slices_regular_time(self, tolerance=None, min_size=1):
return get_slices_regular_time(self._obj, tolerance=tolerance, min_size=min_size)

@auto_wrap_docstring
def get_slices_contiguous_scans(self, min_size=2, min_n_scans=3):
def get_slices_contiguous_scans(
    self,
    min_size=2,
    min_n_scans=3,
    x="lon",
    y="lat",
    along_track_dim="along_track",
    cross_track_dim="cross_track",
):
    # Deferred import, matching the accessor module's existing pattern.
    from gpm.utils.checks import get_slices_contiguous_scans

    # Thin delegation: forward the wrapped xarray object plus every option.
    options = {
        "min_size": min_size,
        "min_n_scans": min_n_scans,
        "x": x,
        "y": y,
        "along_track_dim": along_track_dim,
        "cross_track_dim": cross_track_dim,
    }
    return get_slices_contiguous_scans(self._obj, **options)

Check notice on line 414 in gpm/accessor/methods.py

View check run for this annotation

CodeScene Delta Analysis / CodeScene Cloud Delta Analysis (main)

ℹ New issue: Excess Number of Function Arguments

GPM_Base_Accessor.get_slices_contiguous_scans has 6 arguments, threshold = 4. This function has too many arguments, indicating a lack of encapsulation. Avoid adding more arguments.

@auto_wrap_docstring
def get_slices_contiguous_granules(self, min_size=2):
Expand All @@ -404,16 +420,46 @@ def get_slices_contiguous_granules(self, min_size=2):
return get_slices_contiguous_granules(self._obj, min_size=min_size)

@auto_wrap_docstring
def get_slices_valid_geolocation(self, min_size=2):
def get_slices_valid_geolocation(
    self,
    min_size=2,
    x="lon",
    y="lat",
    along_track_dim="along_track",
    cross_track_dim="cross_track",
):
    # Deferred import, matching the accessor module's existing pattern.
    from gpm.utils.checks import get_slices_valid_geolocation as _impl

    # Thin delegation to the functional implementation in gpm.utils.checks.
    return _impl(
        self._obj,
        min_size=min_size,
        x=x,
        y=y,
        along_track_dim=along_track_dim,
        cross_track_dim=cross_track_dim,
    )

Check notice on line 440 in gpm/accessor/methods.py

View check run for this annotation

CodeScene Delta Analysis / CodeScene Cloud Delta Analysis (main)

ℹ New issue: Excess Number of Function Arguments

GPM_Base_Accessor.get_slices_valid_geolocation has 5 arguments, threshold = 4. This function has too many arguments, indicating a lack of encapsulation. Avoid adding more arguments.

@auto_wrap_docstring
def get_slices_regular(self, min_size=None, min_n_scans=3):
def get_slices_regular(
    self,
    min_size=None,
    min_n_scans=3,
    x="lon",
    y="lat",
    along_track_dim="along_track",
    cross_track_dim="cross_track",
):
    # Deferred import, matching the accessor module's existing pattern.
    from gpm.utils.checks import get_slices_regular

    # Gather the keyword options once, then delegate in a single call.
    options = dict(
        min_size=min_size,
        min_n_scans=min_n_scans,
        x=x,
        y=y,
        along_track_dim=along_track_dim,
        cross_track_dim=cross_track_dim,
    )
    return get_slices_regular(self._obj, **options)

Check notice on line 462 in gpm/accessor/methods.py

View check run for this annotation

CodeScene Delta Analysis / CodeScene Cloud Delta Analysis (main)

ℹ New issue: Excess Number of Function Arguments

GPM_Base_Accessor.get_slices_regular has 6 arguments, threshold = 4. This function has too many arguments, indicating a lack of encapsulation. Avoid adding more arguments.

#### Plotting utility
@auto_wrap_docstring
Expand Down Expand Up @@ -499,8 +545,8 @@ def plot_swath_lines(
@auto_wrap_docstring
def plot_map_mesh(
self,
x="lon",
y="lat",
x=None,
y=None,
ax=None,
edgecolors="k",
linewidth=0.1,
Expand All @@ -527,8 +573,8 @@ def plot_map_mesh(
@auto_wrap_docstring
def plot_map_mesh_centroids(
self,
x="lon",
y="lat",
x=None,
y=None,
ax=None,
c="r",
s=1,
Expand Down Expand Up @@ -653,8 +699,8 @@ def plot_map(
self,
variable,
ax=None,
x="lon",
y="lat",
x=None,
y=None,
add_colorbar=True,
add_swath_lines=True,
add_background=True,
Expand Down Expand Up @@ -854,8 +900,8 @@ def title(
def plot_map(
self,
ax=None,
x="lon",
y="lat",
x=None,
y=None,
add_colorbar=True,
add_swath_lines=True,
add_background=True,
Expand Down
17 changes: 8 additions & 9 deletions gpm/bucket/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def quadmesh(self, origin="bottom"):
Return
--------
np.ndarray
numpy.ndarray
Quadmesh array of shape (M+1, N+1, 2)
"""
x_corners, y_corners = np.meshgrid(self.x_bounds, self.y_bounds)
Expand All @@ -440,7 +440,7 @@ def vertices(self, origin="bottom", ccw=True):
"""Return the partitions vertices in an array of shape (N, M, 4, 2).
The output vertices, once the first 2 dimension are flattened,
can be passed directly to a matplotlib.PolyCollection.
can be passed directly to a `matplotlib.PolyCollection`.
For plotting with cartopy, the polygon order must be "counterclockwise".
Parameters
Expand Down Expand Up @@ -563,7 +563,7 @@ def add_labels(self, df, x, y, remove_invalid_rows=True):
Parameters
----------
df : pandas.DataFrame, dask.DataFrame, polars.DataFrame or polars.LazyFrame
df : `pandas.DataFrame`, `dask.DataFrame`, `polars.DataFrame`, `pyarrow.Table` or `polars.LazyFrame`
Dataframe to which add partitions centroids.
x : str
Column name with the x coordinate.
Expand All @@ -575,7 +575,7 @@ def add_labels(self, df, x, y, remove_invalid_rows=True):
Returns
-------
df : pandas.DataFrame, dask.DataFrame, polars.DataFrame or polars.LazyFrame
df : `pandas.DataFrame`, `dask.DataFrame`, `polars.DataFrame`, `pyarrow.Table` or `polars.LazyFrame`
Dataframe with the partitions label(s) column(s).
"""
Expand Down Expand Up @@ -607,7 +607,7 @@ def add_centroids(self, df, x, y, x_coord=None, y_coord=None, remove_invalid_row
Parameters
----------
df : pandas.DataFrame, dask.DataFrame, polars.DataFrame or polars.LazyFrame
df : `pandas.DataFrame`, `dask.DataFrame`, `polars.DataFrame`, `pyarrow.Table` or `polars.LazyFrame`
Dataframe to which add partitions centroids.
x : str
Column name with the x coordinate.
Expand All @@ -625,7 +625,7 @@ def add_centroids(self, df, x, y, x_coord=None, y_coord=None, remove_invalid_row
Returns
-------
df : pandas.DataFrame, dask.DataFrame, polars.DataFrame or polars.LazyFrame
df : `pandas.DataFrame`, `dask.DataFrame`, `polars.DataFrame`, `pyarrow.Table` or `polars.LazyFrame`
Dataframe with the partitions centroids x and y coordinates columns.
"""
Expand Down Expand Up @@ -979,10 +979,9 @@ def to_dict(self):


class LonLatPartitioning(XYPartitioning):
"""
Handles geographic partitioning of data based on longitude and latitude bin sizes within a defined extent.
"""Handles geographic partitioning of data based on longitude and latitude bin sizes within a defined extent.
The last bin size (in lon and lat direction) might not be of size ``size` !
The last bin size (in the lon and lat directions) might not be of size ``size``!
Parameters
----------
Expand Down
27 changes: 15 additions & 12 deletions gpm/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,19 +241,21 @@ def _is_expected_spatial_dims(spatial_dims):
def is_orbit(xr_obj):
"""Check whether the xarray object is a GPM ORBIT.
An orbit transect or nadir view is considered ORBIT.
An ORBIT transect or nadir view is considered ORBIT.
An ORBIT object must have its coordinates available!
"""
from gpm.dataset.crs import _get_proj_dim_coords
from gpm.dataset.crs import _get_swath_dim_coords

# Check dimension names
spatial_dims = get_spatial_dimensions(xr_obj)
if not _is_orbit_expected_spatial_dims(spatial_dims):
return False

# Check that no 1D coords exists
# - Swath objects are determined by 2D coordinates only
x_coord, y_coord = _get_proj_dim_coords(xr_obj)
if x_coord is None and y_coord is None:
# Check that swath coords exists
# - Swath objects are determined by 1D (nadir looking) and 2D coordinates
x_coord, y_coord = _get_swath_dim_coords(xr_obj)
if x_coord is not None and y_coord is not None:
return True
return False

Expand All @@ -262,6 +264,7 @@ def is_grid(xr_obj):
"""Check whether the xarray object is a GPM GRID.
A GRID slice is not considered a GRID object !
A GRID object must have its coordinates available!
"""
from gpm.dataset.crs import _get_proj_dim_coords

Expand Down Expand Up @@ -431,14 +434,14 @@ def check_is_gpm_object(xr_obj):
raise ValueError("Unrecognized GPM xarray object.")


def check_has_cross_track_dim(xr_obj):
if "cross_track" not in xr_obj.dims:
raise ValueError("The 'cross-track' dimension is not available.")
def check_has_cross_track_dim(xr_obj, dim="cross_track"):
    """Raise an error if the cross-track dimension is missing.

    Parameters
    ----------
    xr_obj : xarray object
        Object whose ``dims`` are checked.
    dim : str, optional
        Name of the cross-track dimension. The default is ``"cross_track"``.

    Raises
    ------
    ValueError
        If ``dim`` is not among ``xr_obj.dims``.
    """
    if dim not in xr_obj.dims:
        # Report the dimension name that was actually looked up, instead of
        # the previously hard-coded (hyphenated) 'cross-track' spelling.
        raise ValueError(f"The cross-track dimension {dim!r} is not available.")


def check_has_along_track_dim(xr_obj):
if "along_track" not in xr_obj.dims:
raise ValueError("The 'along_track' dimension is not available.")
def check_has_along_track_dim(xr_obj, dim="along_track"):
    """Raise an error if the along-track dimension is missing.

    Parameters
    ----------
    xr_obj : xarray object
        Object whose ``dims`` are checked.
    dim : str, optional
        Name of the along-track dimension. The default is ``"along_track"``.

    Raises
    ------
    ValueError
        If ``dim`` is not among ``xr_obj.dims``.
    """
    if dim not in xr_obj.dims:
        # Report the dimension name that was actually looked up; phrased to
        # mirror check_has_cross_track_dim for consistent error messages.
        raise ValueError(f"The along-track dimension {dim!r} is not available.")


def check_is_spatial_2d(xr_obj, strict=True, squeeze=True):
Expand Down
4 changes: 2 additions & 2 deletions gpm/retrievals/retrieval_1b_radar.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,8 @@ def open_dataset_1b_ka_fs(
"""Open 1B-Ka dataset in FS scan_mode format in either L1B or L2 format.
It expects start_time after the GPM DPR scan pattern change occurred the 8 May 2018.
(Over)Resample HS on MS (using LUT/range distance).
The L2 FS format has 176 bins.
Notes
-----
Expand All @@ -539,8 +541,6 @@ def open_dataset_1b_ka_fs(
- 1B-Ka MS have range resolution of 125 m (260 bins)
- 2A-Ka HS have range resolution of 125 m (88 bins)
- 2A-Ka MS have range resolution of 125 m (176 bins)
--> (Over)Resample HS on MS (using LUT/range distance)
--> L2 FS format has 176 bins
"""
from gpm.io.checks import check_time
Expand Down
26 changes: 26 additions & 0 deletions gpm/tests/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,10 +369,23 @@ def test_is_orbit(
assert is_orbit(orbit_dataarray.isel(along_track=0))
assert is_orbit(orbit_dataarray.isel(cross_track=0)) # nadir-view

# Check with other dimensions names
assert is_orbit(orbit_dataarray.rename({"lon": "longitude", "lat": "latitude"}))
assert is_orbit(orbit_dataarray.rename({"cross_track": "y", "along_track": "x"}))
assert is_orbit(orbit_dataarray.isel(along_track=0).rename({"cross_track": "y"}))
assert is_orbit(orbit_dataarray.isel(cross_track=0).rename({"along_track": "x"}))

# Check grid is not confused with orbit
assert not is_orbit(grid_dataarray.isel(lon=0))
assert not is_orbit(grid_dataarray.isel(lat=0))
assert not is_orbit(xr.DataArray())

# Check also strange edge cases
assert not is_orbit(grid_dataarray.isel(lat=0).rename({"lon": "x"}))
assert not is_orbit(grid_dataarray.isel(lon=0).rename({"lat": "y"}))
assert not is_orbit(grid_dataarray.isel(lon=0).rename({"lat": "cross_track"}))
assert not is_orbit(grid_dataarray.isel(lon=0).rename({"lat": "along_track"}))

# With one dimensional longitude
n_x = 10
n_y = 20
Expand All @@ -382,6 +395,10 @@ def test_is_orbit(
invalid_da = xr.DataArray(data, coords={"x": x, "y": y})
assert not is_orbit(invalid_da)

# Assert without coordinates
assert not is_orbit(grid_dataarray.drop_vars(["lon", "lat"]))
assert not is_orbit(orbit_dataarray.drop_vars(["lon", "lat"]))


def test_is_grid(
orbit_dataarray: xr.DataArray,
Expand All @@ -392,11 +409,20 @@ def test_is_grid(
assert not is_grid(grid_dataarray.isel(lon=0))
assert not is_grid(grid_dataarray.isel(lat=0))

# Check with other dimensions names
assert is_grid(grid_dataarray.rename({"lon": "longitude", "lat": "latitude"}))
assert is_grid(grid_dataarray.rename({"lon": "x", "lat": "y"}))

# Check orbit is not confused with grid
assert not is_grid(orbit_dataarray)
assert not is_grid(orbit_dataarray.isel(along_track=0))
assert not is_grid(orbit_dataarray.isel(cross_track=0))
assert not is_grid(xr.DataArray())

# Assert without coordinates
assert not is_grid(grid_dataarray.drop_vars(["lon", "lat"]))
assert not is_grid(orbit_dataarray.drop_vars(["lon", "lat"]))


def test_check_is_orbit(
orbit_dataarray: xr.DataArray,
Expand Down
18 changes: 11 additions & 7 deletions gpm/tests/test_utils/test_utils_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,8 @@ def ds_contiguous(self) -> xr.Dataset:
ds["lat"] = (("cross_track", "along_track"), lat)
ds["lon"] = (("cross_track", "along_track"), lon)
ds["gpm_granule_id"] = np.ones(self.n_along_track)
ds["time"] = time

ds = ds.set_coords(["lon", "lat"])
ds = ds.assign_coords({"time": ("along_track", time)})
return ds

@pytest.fixture()
Expand Down Expand Up @@ -596,7 +596,8 @@ def ds_orbit_valid(self) -> xr.Dataset:
ds = xr.Dataset()
ds["lon"] = (("cross_track", "along_track"), lon)
ds["lat"] = (("cross_track", "along_track"), lat)
ds["time"] = (("along_track"), time)
ds = ds.set_coords(["lon", "lat"])
ds = ds.assign_coords({"time": ("along_track", time)})
return ds

@pytest.fixture()
Expand Down Expand Up @@ -631,11 +632,11 @@ def test_is_valid_geolocation(
) -> None:
"""Test _is_valid_geolocation."""
# Valid
valid = checks._is_valid_geolocation(ds_orbit_valid)
valid = checks._is_valid_geolocation(ds_orbit_valid, coord="lon")
assert np.all(valid)

# Invalid
valid = checks._is_valid_geolocation(ds_orbit_invalid)
valid = checks._is_valid_geolocation(ds_orbit_invalid, coord="lon")
assert np.sum(valid.all(dim="cross_track")) == self.n_along_track - 1

@pytest.mark.usefixtures("_set_is_orbit_to_true")
Expand Down Expand Up @@ -708,7 +709,8 @@ class TestWobblingSwath:
ds["lat"] = (("cross_track", "along_track"), lat)
ds["lon"] = (("cross_track", "along_track"), lon)
ds["gpm_granule_id"] = np.ones(n_along_track)
ds["time"] = time
ds = ds.set_coords(["lon", "lat"])
ds = ds.assign_coords({"time": ("along_track", time)})

# Threshold must be at least 3 to remove wobbling slices
threshold = 3
Expand Down Expand Up @@ -773,8 +775,10 @@ def ds_orbit(self) -> xr.Dataset:
ds = xr.Dataset()
ds["lon"] = (("cross_track", "along_track"), lon)
ds["lat"] = (("cross_track", "along_track"), lon)
ds = ds.set_coords(["lon", "lat"])
granule_ids = np.array([0, 0, 0, 1, 1, 1, 2, 2, 7, 8])
return ds.assign_coords({"gpm_granule_id": ("along_track", granule_ids)})
ds = ds.assign_coords({"gpm_granule_id": ("along_track", granule_ids)})
return ds

@pytest.fixture()
def ds_grid(self) -> xr.Dataset:
Expand Down
Loading

0 comments on commit 5c2d168

Please sign in to comment.