Skip to content

Commit

Permalink
Merge pull request #626 from mraspaud/feature-remove-stacking-gradient
Browse files Browse the repository at this point in the history
Replace stacking gradient search with resample_blocks variant
  • Loading branch information
mraspaud authored Oct 24, 2024
2 parents 8da073e + 4dd2948 commit 69c3a65
Show file tree
Hide file tree
Showing 7 changed files with 248 additions and 713 deletions.
311 changes: 35 additions & 276 deletions pyresample/gradient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,256 +53,26 @@ def GradientSearchResampler(source_geo_def, target_geo_def):

def create_gradient_search_resampler(source_geo_def, target_geo_def):
    """Create a gradient search resampler.

    Args:
        source_geo_def: Geometry definition of the data to resample.
        target_geo_def: Geometry definition to resample the data onto.

    Returns:
        A ``ResampleBlocksGradientSearchResampler`` when the pair of
        definitions is area-to-area, swath-to-area or area-to-swath.

    Raises:
        NotImplementedError: For every other combination of definitions.
    """
    # Every supported combination is handled by the resample_blocks based
    # resampler, so a single dispatch is enough.
    supported = (is_area_to_area(source_geo_def, target_geo_def) or
                 is_swath_to_area(source_geo_def, target_geo_def) or
                 is_area_to_swath(source_geo_def, target_geo_def))
    if supported:
        return ResampleBlocksGradientSearchResampler(source_geo_def, target_geo_def)
    raise NotImplementedError(
        "Gradient search resampling is not implemented for this combination of geometry definitions.")


@da.as_gufunc(signature='(),()->(),()')
def transform(x_coords, y_coords, src_prj=None, dst_prj=None):
    """Convert coordinates from the *src_prj* CRS to the *dst_prj* CRS.

    Returns whatever ``pyproj.Transformer.transform`` returns for the given
    x and y coordinates.
    """
    crs_transformer = pyproj.Transformer.from_crs(src_prj, dst_prj)
    projected = crs_transformer.transform(x_coords, y_coords)
    return projected
def is_area_to_area(source_geo_def, target_geo_def):
    """Tell whether both source and target are ``AreaDefinition`` instances."""
    source_is_area = isinstance(source_geo_def, AreaDefinition)
    return source_is_area and isinstance(target_geo_def, AreaDefinition)


class StackingGradientSearchResampler(BaseResampler):
    """Gradient search based bilinear resampler, stacking chunk pairs for dask processing.

    Source and target coordinates are expressed in one common projection,
    every (source chunk, target chunk) pair is checked for overlap, and only
    the overlapping pairs are handed to ``parallel_gradient_search``.
    """

    def __init__(self, source_geo_def, target_geo_def):
        """Initialise the resampler with empty caches and warn that it is experimental."""
        super().__init__(source_geo_def, target_geo_def)
        import warnings
        warnings.warn("You are using the Gradient Search Resampler, which is still EXPERIMENTAL.", stacklevel=2)
        # Projection state: None until the coordinates have been computed.
        self.use_input_coords = None
        self.prj = None
        self._src_dst_filtered = False
        # Coordinates and per-pair slices of the source and target.
        self.src_x = None
        self.src_y = None
        self.src_slices = None
        self.dst_x = None
        self.dst_y = None
        self.dst_slices = None
        # Gradients of the source coordinates along both array axes.
        self.src_gradient_xl = None
        self.src_gradient_xp = None
        self.src_gradient_yl = None
        self.src_gradient_yp = None
        # Chunk mapping results.
        self.dst_polys = {}
        self.dst_mosaic_locations = None
        self.coverage_status = None

    def _get_projection_coordinates(self, datachunks):
        """Compute source and target coordinates in one common projection.

        Does nothing when the coordinates were already computed by an
        earlier call.

        Raises:
            NotImplementedError: If neither definition provides projection
                coordinates.
        """
        if self.use_input_coords is not None:
            return

        source_def = self.source_geo_def
        target_def = self.target_geo_def

        try:
            self.src_x, self.src_y = source_def.get_proj_coords(chunks=datachunks)
            src_crs = source_def.crs
            self.use_input_coords = True
        except AttributeError:
            # No projection coordinates on the source: fall back to lon/lats.
            self.src_x, self.src_y = source_def.get_lonlats(chunks=datachunks)
            src_crs = pyproj.CRS.from_string("+proj=longlat")
            self.use_input_coords = False

        try:
            self.dst_x, self.dst_y = target_def.get_proj_coords(chunks=CHUNK_SIZE)
            dst_crs = target_def.crs
        except AttributeError as err:
            if self.use_input_coords is False:
                raise NotImplementedError('Cannot resample lon/lat to lon/lat with gradient search.') from err
            self.dst_x, self.dst_y = target_def.get_lonlats(chunks=CHUNK_SIZE)
            dst_crs = pyproj.CRS.from_string("+proj=longlat")

        if self.use_input_coords:
            # Work in the source projection: move the target coordinates.
            self.dst_x, self.dst_y = transform(self.dst_x, self.dst_y,
                                               src_prj=dst_crs, dst_prj=src_crs)
            self.prj = pyproj.Proj(source_def.crs)
        else:
            # Work in the target projection: move the source coordinates.
            self.src_x, self.src_y = transform(self.src_x, self.src_y,
                                               src_prj=src_crs, dst_prj=dst_crs)
            self.prj = pyproj.Proj(target_def.crs)

    def _get_prj_poly(self, geo_def):
        """Get the boundary polygon of *geo_def* in the working projection.

        Returns ``False`` for a ``SwathDefinition``, ``None`` when no polygon
        could be built, and the polygon otherwise.
        """
        if isinstance(geo_def, SwathDefinition):
            return False
        try:
            return get_polygon(self.prj, geo_def)
        except (NotImplementedError, ValueError):
            # Out-of-earth-disk area, or no valid projected boundary coordinates.
            return None

    def _get_src_poly(self, src_y_start, src_y_end, src_x_start, src_x_end):
        """Get the bounding polygon of one source chunk."""
        chunk_def = self.source_geo_def[src_y_start:src_y_end, src_x_start:src_x_end]
        return self._get_prj_poly(chunk_def)

    def _get_dst_poly(self, idx,
                      dst_x_start, dst_x_end,
                      dst_y_start, dst_y_end):
        """Get the polygon of one target chunk, reusing the cached one if any."""
        cached = self.dst_polys.get(idx, None)
        if cached is not None:
            return cached
        chunk_def = self.target_geo_def[dst_y_start:dst_y_end, dst_x_start:dst_x_end]
        computed = self._get_prj_poly(chunk_def)
        self.dst_polys[idx] = computed
        return computed

    @staticmethod
    def _chunk_bounds(chunk_sizes):
        """Turn a sequence of chunk sizes into a list of (start, end) offsets."""
        bounds = []
        start = 0
        for size in chunk_sizes:
            end = start + size
            bounds.append((start, end))
            start = end
        return bounds

    def get_chunk_mappings(self):
        """Pair every source chunk with every target chunk and record whether they overlap."""
        src_y_chunks, src_x_chunks = self.src_x.chunks
        dst_y_chunks, dst_x_chunks = self.dst_x.chunks

        src_x_bounds = self._chunk_bounds(src_x_chunks)
        src_y_bounds = self._chunk_bounds(src_y_chunks)
        dst_x_bounds = self._chunk_bounds(dst_x_chunks)
        dst_y_bounds = self._chunk_bounds(dst_y_chunks)

        coverage_status = []
        src_slices = []
        dst_slices = []
        dst_mosaic_locations = []

        for src_x_start, src_x_end in src_x_bounds:
            for src_y_start, src_y_end in src_y_bounds:
                # One polygon per source chunk, shared by all target chunks.
                src_poly = self._get_src_poly(src_y_start, src_y_end,
                                              src_x_start, src_x_end)
                for x_step_number, (dst_x_start, dst_x_end) in enumerate(dst_x_bounds):
                    for y_step_number, (dst_y_start, dst_y_end) in enumerate(dst_y_bounds):
                        location = (x_step_number, y_step_number)
                        dst_poly = self._get_dst_poly(location,
                                                      dst_x_start, dst_x_end,
                                                      dst_y_start, dst_y_end)

                        coverage_status.append(check_overlap(src_poly, dst_poly))
                        src_slices.append((src_y_start, src_y_end,
                                           src_x_start, src_x_end))
                        dst_slices.append((dst_y_start, dst_y_end,
                                           dst_x_start, dst_x_end))
                        dst_mosaic_locations.append(location)

        self.src_slices = src_slices
        self.dst_slices = dst_slices
        self.dst_mosaic_locations = dst_mosaic_locations
        self.coverage_status = coverage_status

    def _filter_data(self, data, is_src=True, add_dim=False):
        """Slice *data* once per chunk pair, using ``None`` for pairs that do not overlap.

        Raises:
            NotImplementedError: If *add_dim* is set and *data* is neither
                2D nor 3D.
        """
        if add_dim:
            if data.ndim not in [2, 3]:
                raise NotImplementedError('Gradient search resampling only '
                                          'supports 2D or 3D arrays.')
            if data.ndim == 2:
                data = data[np.newaxis, :, :]

        data_out = []
        for pair_index, covers in enumerate(self.coverage_status):
            if not covers:
                data_out.append(None)
                continue
            slices = self.src_slices if is_src else self.dst_slices
            y_start, y_end, x_start, x_end = slices[pair_index]
            try:
                piece = data[:, y_start:y_end, x_start:x_end]
            except IndexError:
                # Array without a leading dimension.
                piece = data[y_start:y_end, x_start:x_end]
            data_out.append(piece)

        return data_out

    def _get_gradients(self):
        """Compute the gradients of the source coordinates along both array axes."""
        x_gradients = np.gradient(self.src_x, axis=[0, 1])
        self.src_gradient_xl, self.src_gradient_xp = x_gradients
        y_gradients = np.gradient(self.src_y, axis=[0, 1])
        self.src_gradient_yl, self.src_gradient_yp = y_gradients

    def _filter_src_dst(self):
        """Replace the coordinate and gradient arrays by their per-pair slices."""
        source_attrs = ("src_x", "src_y",
                        "src_gradient_yl", "src_gradient_yp",
                        "src_gradient_xl", "src_gradient_xp")
        for attr in source_attrs:
            setattr(self, attr, self._filter_data(getattr(self, attr)))
        for attr in ("dst_x", "dst_y"):
            setattr(self, attr, self._filter_data(getattr(self, attr), is_src=False))
        self._src_dst_filtered = True

    def compute(self, data, fill_value=None, **kwargs):
        """Resample *data* with the gradient search algorithm.

        Args:
            data: Array to resample; its ``dims``, ``coords``, ``chunks`` and
                ``data`` attributes are used.
            fill_value: Replacement for NaN in the result, if not ``None``.
            **kwargs: Passed on to ``parallel_gradient_search``.

        Returns:
            An ``xr.DataArray`` with the dimensions of *data*.
        """
        data_dims = data.dims
        data_coords = data.coords
        if 'bands' in data_dims:
            first_band = data_coords['bands'][0]
            datachunks = data.sel(bands=first_band).chunks
        else:
            datachunks = data.chunks

        self._get_projection_coordinates(datachunks)

        # Each preparation step runs once and is reused by later calls.
        if self.src_gradient_xl is None:
            self._get_gradients()
        if self.coverage_status is None:
            self.get_chunk_mappings()
        if not self._src_dst_filtered:
            self._filter_src_dst()

        stacked = self._filter_data(data.data, add_dim=True)

        res = parallel_gradient_search(stacked,
                                       self.src_x, self.src_y,
                                       self.dst_x, self.dst_y,
                                       self.src_gradient_xl,
                                       self.src_gradient_xp,
                                       self.src_gradient_yl,
                                       self.src_gradient_yp,
                                       self.dst_mosaic_locations,
                                       self.dst_slices,
                                       **kwargs)

        coords = _fill_in_coords(self.target_geo_def, data_coords, data_dims)

        if fill_value is not None:
            res = da.where(np.isnan(res), fill_value, res)
        if res.ndim > len(data_dims):
            res = res.squeeze()

        return xr.DataArray(res, dims=data_dims, coords=coords)
def is_swath_to_area(source_geo_def, target_geo_def):
    """Tell whether the source is a ``SwathDefinition`` and the target an ``AreaDefinition``."""
    source_is_swath = isinstance(source_geo_def, SwathDefinition)
    return source_is_swath and isinstance(target_geo_def, AreaDefinition)


def check_overlap(src_poly, dst_poly):
    """Tell whether a source chunk and a target chunk should be processed together.

    Either argument may be a polygon, ``False`` (swath definition) or
    ``None`` (no polygon available, e.g. out of the earth disk).
    """
    # A swath has no polygon to test against: always keep the pair.
    if src_poly is False or dst_poly is False:
        return True
    # A missing polygon on either side means there is nothing to cover.
    if src_poly is None or dst_poly is None:
        return False
    # Area / area case: rely on the polygon intersection test.
    return src_poly.intersects(dst_poly)
def is_area_to_swath(source_geo_def, target_geo_def):
    """Tell whether the source is an ``AreaDefinition`` and the target a ``SwathDefinition``."""
    source_is_area = isinstance(source_geo_def, AreaDefinition)
    return source_is_area and isinstance(target_geo_def, SwathDefinition)


def _gradient_resample_data(src_data, src_x, src_y,
Expand Down Expand Up @@ -367,30 +137,6 @@ def _check_input_coordinates(dst_x, dst_y,
raise ValueError("Target arrays should all have the same shape")


def get_border_lonlats(geo_def: AreaDefinition):
    """Get the longitudes and latitudes along the border of *geo_def*.

    Geostationary areas use ``get_geostationary_bounding_box_in_lonlats``;
    other areas concatenate the four sides of their boundary.
    """
    if geo_def.is_geostationary:
        lon_b, lat_b = get_geostationary_bounding_box_in_lonlats(geo_def, 3600)
        return lon_b, lat_b

    lons, lats = geo_def.get_boundary_lonlats()
    lon_sides = [lons.side1, lons.side2, lons.side3, lons.side4]
    lat_sides = [lats.side1, lats.side2, lats.side3, lats.side4]
    return np.concatenate(lon_sides), np.concatenate(lat_sides)


def get_polygon(prj, geo_def):
    """Build the border polygon of *geo_def* in projection *prj*.

    Returns the polygon when its area is finite and positive, ``None``
    otherwise.
    """
    lon_b, lat_b = get_border_lonlats(geo_def)
    x_borders, y_borders = prj(lon_b, lat_b)
    # Keep only the border points that project to finite coordinates.
    boundary = [(x_val, y_val)
                for x_val, y_val in zip(x_borders, y_borders)
                if np.isfinite(x_val) and np.isfinite(y_val)]
    poly = Polygon(boundary)
    if not (np.isfinite(poly.area) and poly.area > 0.0):
        return None
    return poly


def parallel_gradient_search(data, src_x, src_y, dst_x, dst_y,
src_gradient_xl, src_gradient_xp,
src_gradient_yl, src_gradient_yp,
Expand Down Expand Up @@ -456,7 +202,10 @@ def _concatenate_chunks(chunks):


def _fill_in_coords(target_geo_def, data_coords, data_dims):
x_coord, y_coord = target_geo_def.get_proj_vectors()
try:
x_coord, y_coord = target_geo_def.get_proj_vectors()
except AttributeError:
return None
coords = []
for key in data_dims:
if key == 'x':
Expand Down Expand Up @@ -489,10 +238,10 @@ class ResampleBlocksGradientSearchResampler(BaseResampler):

def __init__(self, source_geo_def, target_geo_def):
    """Init GradientResampler.

    Args:
        source_geo_def: Geometry definition of the data to resample.
        target_geo_def: Geometry definition to resample the data onto.
            Swath targets are no longer rejected here, so that the
            area-to-swath case dispatched by the factory can be handled.
    """
    if isinstance(source_geo_def, SwathDefinition):
        # NOTE: this replaces lons/lats on the definition passed by the
        # caller. Presumably persisting avoids recomputing the swath
        # coordinates on every use; assumes they expose `.persist()`
        # (e.g. dask-backed arrays) -- TODO confirm.
        source_geo_def.lons = source_geo_def.lons.persist()
        source_geo_def.lats = source_geo_def.lats.persist()
    super().__init__(source_geo_def, target_geo_def)
    logger.debug("/!\\ Instantiating an experimental GradientSearch resampler /!\\")
    # Starts empty; not populated in this method.
    self.indices_xy = None

def precompute(self, **kwargs):
Expand Down Expand Up @@ -590,14 +339,21 @@ def gradient_resampler_indices(source_area, target_area, block_info=None, **kwar
def _get_coordinates_in_same_projection(source_area, target_area):
    """Express source and target coordinates in one common working CRS.

    The working CRS is the source CRS when the source provides projection
    coordinates, and the target CRS otherwise.

    Args:
        source_area: Source geometry definition.
        target_area: Target geometry definition.

    Returns:
        A 3-tuple ``((dst_x, dst_y), gradients, (src_x, src_y))`` where
        ``gradients`` is ``(src_gradient_xl, src_gradient_xp,
        src_gradient_yl, src_gradient_yp)``, the gradients of the source
        coordinates along axes 0 and 1.
    """
    try:
        src_x, src_y = source_area.get_proj_coords()
        work_crs = source_area.crs
    except AttributeError:
        # source is a swath definition, use target crs instead
        lons, lats = source_area.get_lonlats()
        src_x, src_y = da.compute(lons, lats)
        to_target = pyproj.Transformer.from_crs(source_area.crs, target_area.crs, always_xy=True)
        src_x, src_y = to_target.transform(src_x, src_y)
        work_crs = target_area.crs

    # Built once the working CRS is known, so it is valid for both cases.
    transformer = pyproj.Transformer.from_crs(target_area.crs, work_crs, always_xy=True)
    try:
        dst_x, dst_y = transformer.transform(*target_area.get_proj_coords())
    except AttributeError:
        # target is a swath definition
        lons, lats = target_area.get_lonlats()
        dst_x, dst_y = transformer.transform(*da.compute(lons, lats))

    src_gradient_xl, src_gradient_xp = np.gradient(src_x, axis=[0, 1])
    src_gradient_yl, src_gradient_yp = np.gradient(src_y, axis=[0, 1])
    return (dst_x, dst_y), (src_gradient_xl, src_gradient_xp, src_gradient_yl, src_gradient_yp), (src_x, src_y)
Expand All @@ -610,6 +366,9 @@ def block_bilinear_interpolator(data, indices_xy, fill_value=np.nan, block_info=
weight_l, l_start = np.modf(y_indices.clip(0, data.shape[-2] - 1))
weight_p, p_start = np.modf(x_indices.clip(0, data.shape[-1] - 1))

weight_l = weight_l.astype(data.dtype)
weight_p = weight_p.astype(data.dtype)

l_start = l_start.astype(int)
p_start = p_start.astype(int)
l_end = np.clip(l_start + 1, 1, data.shape[-2] - 1)
Expand Down
8 changes: 4 additions & 4 deletions pyresample/gradient/_gradient_search.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ cdef inline void bil(const data_type[:, :, :] data, int l0, int p0, float_index
p_b = min(p0 + 1, pmax)
w_p = dp
for i in range(z_size):
res[i] = ((1 - w_l) * (1 - w_p) * data[i, l_a, p_a] +
(1 - w_l) * w_p * data[i, l_a, p_b] +
w_l * (1 - w_p) * data[i, l_b, p_a] +
w_l * w_p * data[i, l_b, p_b])
res[i] = <data_type>((1 - w_l) * (1 - w_p) * data[i, l_a, p_a] +
(1 - w_l) * w_p * data[i, l_a, p_b] +
w_l * (1 - w_p) * data[i, l_b, p_a] +
w_l * w_p * data[i, l_b, p_b])


@cython.boundscheck(False)
Expand Down
7 changes: 3 additions & 4 deletions pyresample/resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ def resample_blocks(func, src_area, src_arrays, dst_area,
fill_value: Desired value for any invalid values in the output array
kwargs: any other keyword arguments that will be passed on to func.
Returns:
A dask array, chunked as dst_area, containing the resampled data.
Principle of operations:
Resample_blocks works by iterating over chunks on the dst_area domain. For each chunk, the corresponding slice
Expand All @@ -235,10 +238,6 @@ def resample_blocks(func, src_area, src_arrays, dst_area,
"""
if dst_area == src_area:
raise ValueError("Source and destination areas are identical."
" Should you be running `map_blocks` instead of `resample_blocks`?")

name = _create_dask_name(name, func,
src_area, src_arrays,
dst_area, dst_arrays,
Expand Down
Loading

0 comments on commit 69c3a65

Please sign in to comment.