Skip to content

Commit

Permalink
Add plot transect utility
Browse files Browse the repository at this point in the history
  • Loading branch information
ghiggi committed Aug 15, 2023
1 parent eb10855 commit 50d6756
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 18 deletions.
45 changes: 45 additions & 0 deletions gpm_api/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,24 @@ def _is_spatial_3d_datarray(da, strict):
return False


def _is_transect_datarray(da, strict):
"""Check if a DataArray is a spatial 3D array."""
spatial_dims = _get_available_spatial_dims(da)
if len(spatial_dims) != 1:
return False

vertical_dims = _get_available_vertical_dims(da)
if not vertical_dims:
return False
else:
if strict:
if len(da.dims) == 2:
return True
else:
return True
return False


def _is_spatial_2d_dataset(ds, strict):
"""Check if all DataArrays of a xr.Dataset are spatial 2D array."""
all_2d_spatial = np.all(
Expand All @@ -210,6 +228,17 @@ def _is_spatial_3d_dataset(ds, strict):
return False


def _is_transect_dataset(ds, strict):
"""Check if all DataArrays of a xr.Dataset are spatial profile array."""
all_profile_spatial = np.all(
[_is_transect_datarray(ds[var], strict=strict) for var in get_dataset_variables(ds)]
).item()
if all_profile_spatial:
return True
else:
return False


def is_spatial_2d(xr_obj, strict=True, squeeze=True):
"""Check if is spatial 2d xarray object.
Expand Down Expand Up @@ -237,6 +266,17 @@ def is_spatial_3d(xr_obj, strict=True, squeeze=True):
return _is_spatial_3d_datarray(xr_obj, strict=strict)


def is_transect(xr_obj, strict=True, squeeze=True):
"""Check if is spatial profile xarray object."""
check_is_xarray(xr_obj)
if squeeze:
xr_obj = xr_obj.squeeze() # remove dimensions of size 1
if isinstance(xr_obj, xr.Dataset):
return _is_transect_dataset(xr_obj, strict=strict)
else:
return _is_transect_datarray(xr_obj, strict=strict)


def check_is_spatial_2d(da, strict=True, squeeze=True):
if not is_spatial_2d(da, strict=strict, squeeze=squeeze):
raise ValueError("Expecting a 2D GPM field.")
Expand All @@ -247,6 +287,11 @@ def check_is_spatial_3d(da, strict=True, squeeze=True):
raise ValueError("Expecting a 3D GPM field.")


def check_is_transect(da, strict=True, squeeze=True):
if not is_transect(da, strict=strict, squeeze=squeeze):
raise ValueError("Expecting a transect of a 3D GPM field.")


def get_spatial_2d_variables(ds, strict=False, squeeze=True):
"""Get list of xr.Dataset 2D spatial variables."""
variables = [
Expand Down
2 changes: 1 addition & 1 deletion gpm_api/utils/utils_cmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@
# 'cmap_n': 10,
"cmap_type": "Colormap",
"vmin": 10,
"vmax": 40,
"vmax": 50,
"extend": "both",
"extendfrac": 0.05,
"label": "Reflectivity [$dBZ$]", # $Z_{e}$
Expand Down
27 changes: 27 additions & 0 deletions gpm_api/visualization/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,33 @@ def _plot_xr_imshow(
return p


def _plot_xr_pcolormesh(
ax,
da,
x,
y,
add_colorbar=True,
plot_kwargs={},
cbar_kwargs={},
):
"""Plot pcolormesh with xarray."""
ticklabels = cbar_kwargs.pop("ticklabels", None)
if not add_colorbar:
cbar_kwargs = {}
p = da.plot.pcolormesh(
x=x,
y=y,
ax=ax,
add_colorbar=add_colorbar,
cbar_kwargs=cbar_kwargs,
**plot_kwargs,
)
plt.title(da.name)
if add_colorbar and ticklabels is not None:
p.colorbar.ax.set_yticklabels(ticklabels)
return p


####--------------------------------------------------------------------------.
def plot_map(
da,
Expand Down
66 changes: 49 additions & 17 deletions gpm_api/visualization/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
@author: ghiggi
"""
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import pyproj
import xarray as xr

from gpm_api.checks import check_is_transect
from gpm_api.utils.slices import ensure_is_slice, get_slice_size
from gpm_api.utils.utils_cmap import get_colormap_setting
from gpm_api.utils.utils_cmap import get_colorbar_settings
from gpm_api.visualization.plot import (
_plot_xr_pcolormesh,
_preprocess_figure_args,
)


def optimize_transect_slices(
Expand Down Expand Up @@ -174,22 +180,6 @@ def get_transect_slices(
return transect_slices


def plot_profile(da_profile, colorscale=None, ylim=None, ax=None):
x_direction = da_profile["lon"].dims[0]
# Retrieve title
title = da_profile.gpm_api.title(time_idx=0, prefix_product=False, add_timestep=False)
# Retrieve colormap configs
plot_kwargs, cbar_kwargs, ticklabels = get_colormap_setting(colorscale)
# Plot
p = da_profile.plot.pcolormesh(
x=x_direction, y="height", ax=ax, cbar_kwargs=cbar_kwargs, **plot_kwargs
)
p.axes.set_title(title)
if ylim is not None:
p.axes.set_ylim(ylim)
return p


def plot_transect_line(ds, ax, color="black"):
# Check is a profile (lon and lat are 1D coords)
if len(ds["lon"].shape) != 1:
Expand All @@ -212,3 +202,45 @@ def plot_transect_line(ds, ax, color="black"):
lon_l, lat_l, _ = g.fwd(*end_lonlat, az=fwd_az, dist=dist + 50000) # dist in m
ax.text(lon_r, lat_r, "R")
ax.text(lon_l, lat_l, "L")


def plot_transect(
da,
ax=None,
add_colorbar=True,
zoom=True,
fig_kwargs={},
cbar_kwargs={},
**plot_kwargs,
):
"""Plot GPM transect."""
# - Check inputs
check_is_transect(da)
_preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs)

# - Initialize figure
if ax is None:
fig, ax = plt.subplots(**fig_kwargs)

# - If not specified, retrieve/update plot_kwargs and cbar_kwargs as function of product name
plot_kwargs, cbar_kwargs = get_colorbar_settings(
name=da.name, plot_kwargs=plot_kwargs, cbar_kwargs=cbar_kwargs
)

# - If zoom on height regions with data
if zoom:
da = da.gpm_api.slice_range_with_valid_data()

# - Plot with xarray
x_direction = da["lon"].dims[0]
p = _plot_xr_pcolormesh(
ax=ax,
da=da,
x=x_direction,
y="height",
add_colorbar=add_colorbar,
plot_kwargs=plot_kwargs,
cbar_kwargs=cbar_kwargs,
)
# - Return mappable
return p

0 comments on commit 50d6756

Please sign in to comment.