Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for lazy-loading chunked xarray.Datasets with dask #247

Draft
wants to merge 5 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions eomaps/_data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,15 +243,15 @@
"EOmaps: provided dataset has more than 2 dimensions..."
f"({data[parameter].dims})."
)
z_data = data[parameter].values
z_data = data[parameter] # .values

Check warning on line 246 in eomaps/_data_manager.py

View check run for this annotation

Codecov / codecov/patch

eomaps/_data_manager.py#L246

Added line #L246 was not covered by tests
data_dims = data[parameter].dims
else:
assert len(data.dims) <= 2, (
"EOmaps: provided dataset has more than 2 dimensions..."
f"({data.dims})."
)

z_data = data.values
z_data = data # .values

Check warning on line 254 in eomaps/_data_manager.py

View check run for this annotation

Codecov / codecov/patch

eomaps/_data_manager.py#L254

Added line #L254 was not covered by tests
data_dims = data.dims
parameter = data.name

Expand Down Expand Up @@ -430,7 +430,9 @@
z_data, xorig, yorig, ids, parameter = self._identify_data()

# check if Fill-value is provided, and mask the data accordingly
if self.m.data_specs.encoding:
# TODO currently this is only applied to numpy datasets
# (e.g. not to lazy xarray.Datasets handled with dask)
if self.m.data_specs.encoding and isinstance(z_data, np.ndarray):
fill_value = self.m.data_specs.encoding.get("_FillValue", None)
if fill_value:
z_data = np.ma.MaskedArray(
Expand Down Expand Up @@ -549,7 +551,7 @@
props["xorig"] = np.asanyarray(xorig)
props["yorig"] = np.asanyarray(yorig)
props["ids"] = ids
props["z_data"] = np.asanyarray(z_data)
props["z_data"] = z_data
props["x0"] = np.asanyarray(x0)
props["y0"] = np.asanyarray(y0)

Expand Down Expand Up @@ -1450,7 +1452,13 @@
# (to pick the correct value, we need to pick the transposed one!)

if self.m.shape.name == "shade_raster" and self.x0_1D is not None:
val = self.z_data.T.flat[ind]
if isinstance(self.z_data, np.ndarray):
val = self.z_data.T.flat[ind]

Check warning on line 1456 in eomaps/_data_manager.py

View check run for this annotation

Codecov / codecov/patch

eomaps/_data_manager.py#L1455-L1456

Added lines #L1455 - L1456 were not covered by tests
else:
# TODO
# for xarray.Datasets we use 2d indexing to support lazy dask arrays
val = self.z_data.T[np.unravel_index(ind, self.z_data.shape)].values

Check warning on line 1460 in eomaps/_data_manager.py

View check run for this annotation

Codecov / codecov/patch

eomaps/_data_manager.py#L1460

Added line #L1460 was not covered by tests

else:
val = self.z_data.flat[ind]

Expand Down
2 changes: 1 addition & 1 deletion eomaps/colorbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ def _get_data(self):
if self._dynamic_shade_indicator is True:
data = self._m.coll.get_ds_data().values
else:
data = self._m._data_manager.z_data
data = np.asanyarray(self._m._data_manager.z_data)

return data

Expand Down
287 changes: 164 additions & 123 deletions eomaps/eomaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2906,10 +2906,6 @@
kwargs passed to `datashader.mpl_ext.dsshow`

"""
_log.info(
"EOmaps: Plotting "
f"{self._data_manager.z_data.size} datapoints ({self.shape.name})"
)

ds, mpl_ext, pd, xar = register_modules(
"datashader", "datashader.mpl_ext", "pandas", "xarray"
Expand All @@ -2936,137 +2932,183 @@
return

plot_width, plot_height = self._get_shade_axis_size()
_log.info(
"EOmaps: Plotting "
f"{self._data_manager.z_data.size} datapoints ({self.shape.name})"
)

# get rid of unnecessary dimensions in the numpy arrays
zdata = zdata.squeeze()
x0 = self._data_manager.x0.squeeze()
y0 = self._data_manager.y0.squeeze()

# the shape is always set after _prepare data!
if self.shape.name == "shade_points" and self._data_manager.x0_1D is None:
# fill masked-values with None to avoid issues with numba not being
# able to deal with numpy-arrays
# TODO report this to datashader to get it fixed properly?
if isinstance(zdata, np.ma.masked_array):
zdata = zdata.filled(None)

df = pd.DataFrame(
dict(
x=x0.ravel(),
y=y0.ravel(),
val=zdata.ravel(),
),
copy=False,
)

else:
if len(zdata.shape) == 2:
if (zdata.shape == x0.shape) and (zdata.shape == y0.shape):
# 2D coordinates and 2D raster

# use a curvilinear QuadMesh
if self.shape.name == "shade_raster":
self.shape.glyph = ds.glyphs.QuadMeshCurvilinear(
"x", "y", "val"
)
dataset_lazy = False
if (
isinstance(self.data, xar.Dataset)
and self.data_specs.parameter is not None
and self.data.chunks
and self.get_crs("in") == self.get_crs("out")
):
# pass chunked 2D datasets directly to allow lazy-loading
_log.info("EOmaps: Chunked xarray dataset is handled lazily with dask!")

Check warning on line 2948 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L2948

Added line #L2948 was not covered by tests

df = xar.Dataset(
data_vars=dict(val=(["xx", "yy"], zdata)),
# dims=["x", "y"],
coords=dict(
x=(["xx", "yy"], x0),
y=(["xx", "yy"], y0),
),
)
df = self.data

Check warning on line 2950 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L2950

Added line #L2950 was not covered by tests

elif (
((zdata.shape[1],) == x0.shape)
and ((zdata.shape[0],) == y0.shape)
and (x0.shape != y0.shape)
if self.shape.name == "shade_raster":
if (

Check warning on line 2953 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L2952-L2953

Added lines #L2952 - L2953 were not covered by tests
len(zdata.shape) == 2
and len(self._data_manager.x0.shape) == 1
and len(self._data_manager.y0.shape) == 1
):
raise AssertionError(
"EOmaps: it seems like you need to transpose your data! \n"
+ f"the dataset has a shape of {zdata.shape}, but the "
+ f"coordinates suggest ({x0.shape}, {y0.shape})"

self.shape.glyph = ds.glyphs.QuadMeshRectilinear(

Check warning on line 2959 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L2959

Added line #L2959 was not covered by tests
self.data_specs.x, self.data_specs.y, self.data_specs.parameter
)
elif (zdata.T.shape == x0.shape) and (zdata.T.shape == y0.shape):
raise AssertionError(
"EOmaps: it seems like you need to transpose your data! \n"
+ f"the dataset has a shape of {zdata.shape}, but the "
+ f"coordinates suggest {x0.shape}"
else:
self.shape.glyph = ds.glyphs.QuadMeshCurvilinear(

Check warning on line 2963 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L2963

Added line #L2963 was not covered by tests
self.data_specs.x, self.data_specs.y, self.data_specs.parameter
)

elif ((zdata.shape[0],) == x0.shape) and (
(zdata.shape[1],) == y0.shape
):
# 1D coordinates and 2D data
elif self.shape.name == "shade_points":
df = self.data.to_dask_dataframe()

Check warning on line 2968 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L2967-L2968

Added lines #L2967 - L2968 were not covered by tests

# use a rectangular QuadMesh
if self.shape.name == "shade_raster":
self.shape.glyph = ds.glyphs.QuadMeshRectilinear(
"x", "y", "val"
)
self.shape.aggregator.column = self.data_specs.parameter

Check warning on line 2970 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L2970

Added line #L2970 was not covered by tests

dataset_lazy = True

Check warning on line 2972 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L2972

Added line #L2972 was not covered by tests
else:
# get rid of unnecessary dimensions in the numpy arrays
zdata = zdata.squeeze()
x0 = self._data_manager.x0.squeeze()
y0 = self._data_manager.y0.squeeze()

# required to avoid ambiguities
if isinstance(zdata, xar.DataArray):
zdata = zdata.data

Check warning on line 2981 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L2981

Added line #L2981 was not covered by tests

# the shape is always set after _prepare data!
if self.shape.name == "shade_points" and self._data_manager.x0_1D is None:
# fill masked-values with None to avoid issues with numba not being
# able to deal with numpy-arrays
# TODO report this to datashader to get it fixed properly?
if isinstance(zdata, np.ma.masked_array):
zdata = zdata.filled(None)

Check warning on line 2989 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L2989

Added line #L2989 was not covered by tests

df = pd.DataFrame(
dict(
x=x0.ravel(),
y=y0.ravel(),
val=zdata.ravel(),
),
copy=False,
)

df = xar.DataArray(
data=zdata,
dims=["x", "y"],
coords=dict(x=x0, y=y0),
)
df = xar.Dataset(dict(val=df))
else:
try:
# try if reprojected coordinates can be used as 2d grid and if yes,
# directly use a curvilinear QuadMesh based on the reprojected
# coordinates to display the data
idx = pd.MultiIndex.from_arrays(
[x0.ravel(), y0.ravel()], names=["x", "y"]
)
if len(zdata.shape) == 2:
if (zdata.shape == x0.shape) and (zdata.shape == y0.shape):
# 2D coordinates and 2D raster

# use a curvilinear QuadMesh
if self.shape.name == "shade_raster":
self.shape.glyph = ds.glyphs.QuadMeshCurvilinear(

Check warning on line 3007 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3007

Added line #L3007 was not covered by tests
"x", "y", "val"
)

df = pd.DataFrame(
data=dict(val=zdata.ravel()), index=idx, copy=False
)
df = df.to_xarray()
xg, yg = np.meshgrid(df.x, df.y)
except Exception:
# first convert original coordinates of the 1D inputs to 2D,
# then reproject the grid and use a curvilinear QuadMesh to display
# the data
_log.warning(
"EOmaps: 1D data is converted to 2D prior to reprojection... "
"Consider using 'shade_points' as plot-shape instead!"
)
xorig = self._data_manager.xorig.ravel()
yorig = self._data_manager.yorig.ravel()
df = xar.Dataset(
data_vars=dict(val=(["xx", "yy"], zdata)),
# dims=["x", "y"],
coords=dict(
x=(["xx", "yy"], x0),
y=(["xx", "yy"], y0),
),
)

idx = pd.MultiIndex.from_arrays([xorig, yorig], names=["x", "y"])
elif (

Check warning on line 3020 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3020

Added line #L3020 was not covered by tests
((zdata.shape[1],) == x0.shape)
and ((zdata.shape[0],) == y0.shape)
and (x0.shape != y0.shape)
):
raise AssertionError(

Check warning on line 3025 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3025

Added line #L3025 was not covered by tests
"EOmaps: it seems like you need to transpose your data! \n"
+ f"the dataset has a shape of {zdata.shape}, but the "
+ f"coordinates suggest ({x0.shape}, {y0.shape})"
)
elif (zdata.T.shape == x0.shape) and (zdata.T.shape == y0.shape):
raise AssertionError(

Check warning on line 3031 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3030-L3031

Added lines #L3030 - L3031 were not covered by tests
"EOmaps: it seems like you need to transpose your data! \n"
+ f"the dataset has a shape of {zdata.shape}, but the "
+ f"coordinates suggest {x0.shape}"
)

df = pd.DataFrame(
data=dict(val=zdata.ravel()), index=idx, copy=False
)
df = df.to_xarray()
xg, yg = np.meshgrid(df.x, df.y)

# transform the grid from input-coordinates to the plot-coordinates
crs1 = CRS.from_user_input(self.data_specs.crs)
crs2 = CRS.from_user_input(self._crs_plot)
if crs1 != crs2:
transformer = self._get_transformer(
crs1,
crs2,
elif ((zdata.shape[0],) == x0.shape) and (

Check warning on line 3037 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3037

Added line #L3037 was not covered by tests
(zdata.shape[1],) == y0.shape
):
# 1D coordinates and 2D data

# use a rectangular QuadMesh
if self.shape.name == "shade_raster":
self.shape.glyph = ds.glyphs.QuadMeshRectilinear(

Check warning on line 3044 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3043-L3044

Added lines #L3043 - L3044 were not covered by tests
"x", "y", "val"
)

df = xar.DataArray(

Check warning on line 3048 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3048

Added line #L3048 was not covered by tests
data=zdata,
dims=["x", "y"],
coords=dict(x=x0, y=y0),
)
df = xar.Dataset(dict(val=df))

Check warning on line 3053 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3053

Added line #L3053 was not covered by tests
else:
try:

Check warning on line 3055 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3055

Added line #L3055 was not covered by tests
# try if reprojected coordinates can be used as 2d grid and if yes,
# directly use a curvilinear QuadMesh based on the reprojected
# coordinates to display the data
idx = pd.MultiIndex.from_arrays(

Check warning on line 3059 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3059

Added line #L3059 was not covered by tests
[x0.ravel(), y0.ravel()], names=["x", "y"]
)

df = pd.DataFrame(

Check warning on line 3063 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3063

Added line #L3063 was not covered by tests
data=dict(val=zdata.ravel()), index=idx, copy=False
)
xg, yg = transformer.transform(xg, yg)
df = df.to_xarray()
xg, yg = np.meshgrid(df.x, df.y)
except Exception:

Check warning on line 3068 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3066-L3068

Added lines #L3066 - L3068 were not covered by tests
# first convert original coordinates of the 1D inputs to 2D,
# then reproject the grid and use a curvilinear QuadMesh to display
# the data
_log.warning(

Check warning on line 3072 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3072

Added line #L3072 was not covered by tests
"EOmaps: 1D data is converted to 2D prior to reprojection... "
"Consider using 'shade_points' as plot-shape instead!"
)
xorig = self._data_manager.xorig.ravel()
yorig = self._data_manager.yorig.ravel()

Check warning on line 3077 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3076-L3077

Added lines #L3076 - L3077 were not covered by tests

# use a curvilinear QuadMesh
if self.shape.name == "shade_raster":
self.shape.glyph = ds.glyphs.QuadMeshCurvilinear("x", "y", "val")
idx = pd.MultiIndex.from_arrays(

Check warning on line 3079 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3079

Added line #L3079 was not covered by tests
[xorig, yorig], names=["x", "y"]
)

df = xar.Dataset(
data_vars=dict(val=(["xx", "yy"], df.val.values.T)),
coords=dict(x=(["xx", "yy"], xg), y=(["xx", "yy"], yg)),
)
df = pd.DataFrame(

Check warning on line 3083 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3083

Added line #L3083 was not covered by tests
data=dict(val=zdata.ravel()), index=idx, copy=False
)
df = df.to_xarray()
xg, yg = np.meshgrid(df.x, df.y)

Check warning on line 3087 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3086-L3087

Added lines #L3086 - L3087 were not covered by tests

# transform the grid from input-coordinates to the plot-coordinates
crs1 = CRS.from_user_input(self.data_specs.crs)
crs2 = CRS.from_user_input(self._crs_plot)
if crs1 != crs2:
transformer = self._get_transformer(

Check warning on line 3093 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3090-L3093

Added lines #L3090 - L3093 were not covered by tests
crs1,
crs2,
)
xg, yg = transformer.transform(xg, yg)

Check warning on line 3097 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3097

Added line #L3097 was not covered by tests

# use a curvilinear QuadMesh
if self.shape.name == "shade_raster":
self.shape.glyph = ds.glyphs.QuadMeshCurvilinear(

Check warning on line 3101 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3100-L3101

Added lines #L3100 - L3101 were not covered by tests
"x", "y", "val"
)

if self.shape.name == "shade_points":
df = df.to_dataframe().reset_index()
df = xar.Dataset(

Check warning on line 3105 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3105

Added line #L3105 was not covered by tests
data_vars=dict(val=(["xx", "yy"], df.val.values.T)),
coords=dict(x=(["xx", "yy"], xg), y=(["xx", "yy"], yg)),
)

if self.shape.name == "shade_points":
df = df.to_dataframe().reset_index()

if set_extent is True and self._set_extent_on_plot is True:
# convert to a numpy-array to support 2D indexing with boolean arrays
Expand All @@ -3093,18 +3135,17 @@
ax=self.ax,
plot_width=plot_width,
plot_height=plot_height,
# x_range=(x0, x1),
# y_range=(y0, y1),
# x_range=(df.x.min(), df.x.max()),
# y_range=(df.y.min(), df.y.max()),
x_range=x_range,
y_range=y_range,
vmin=self._vmin,
vmax=self._vmax,
**kwargs,
)

coll.set_label("Dataset " f"({self.shape.name} | {zdata.shape})")
if dataset_lazy:
coll.set_label("(lazy) Dataset " f"({self.shape.name} | {zdata.shape})")

Check warning on line 3146 in eomaps/eomaps.py

View check run for this annotation

Codecov / codecov/patch

eomaps/eomaps.py#L3146

Added line #L3146 was not covered by tests
else:
coll.set_label("Dataset " f"({self.shape.name} | {zdata.shape})")

self._coll = coll

Expand Down
Loading
Loading