diff --git a/gpm/dataset/conventions.py b/gpm/dataset/conventions.py index 6c315db7..4ce5bac2 100644 --- a/gpm/dataset/conventions.py +++ b/gpm/dataset/conventions.py @@ -82,6 +82,29 @@ def _check_time_period_coverage(ds, start_time=None, end_time=None, raise_error= warnings.warn(msg, GPM_Warning, stacklevel=1) +def _get_chunks_encodings(ds): + dict_chunksizes = {} + dict_preferred_chunks = {} + for name in list(ds.data_vars) + list(ds.coords): + preferred_chunks = ds[name].encoding.get("preferred_chunks", None) + chunksizes = ds[name].encoding.get("chunksizes", None) + if preferred_chunks: + # Use values() to remove phony_dim_* keys + dict_preferred_chunks[name] = dict(zip(ds[name].dims, preferred_chunks.values())) + if preferred_chunks: + dict_chunksizes[name] = dict(zip(ds[name].dims, chunksizes)) + return dict_chunksizes, dict_preferred_chunks + + +def _update_chunks_encodings(ds, dict_chunksizes, dict_preferred_chunks): + for name in list(ds.data_vars) + list(ds.coords): + if name in dict_chunksizes: + ds[name].encoding["preferred_chunks"] = {dim: dict_preferred_chunks[name][dim] for dim in ds[name].dims} + if name in dict_chunksizes: + ds[name].encoding["chunksizes"] = tuple([dict_chunksizes[name][dim] for dim in ds[name].dims]) + return ds + + def reshape_dataset(ds): """Define the dataset dimension order. @@ -90,12 +113,14 @@ def reshape_dataset(ds): For GPM GRID objects: (..., time, lat, lon) For GPM ORBIT objects: (cross_track, along_track, ...) """ + dict_chunksizes, dict_preferred_chunks = _get_chunks_encodings(ds) if "lat" in ds.dims: ds = ds.transpose(..., "lat", "lon") elif "cross_track" in ds.dims: ds = ds.transpose("cross_track", "along_track", ...) else: ds = ds.transpose("along_track", ...) + ds = _update_chunks_encodings(ds, dict_chunksizes=dict_chunksizes, dict_preferred_chunks=dict_preferred_chunks) return ds