Skip to content

Commit

Permalink
Add area to swath resampling using gradient search
Browse files Browse the repository at this point in the history
  • Loading branch information
mraspaud committed Oct 22, 2024
1 parent 7c971b1 commit c174f88
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 35 deletions.
13 changes: 8 additions & 5 deletions pyresample/gradient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,10 @@ def _concatenate_chunks(chunks):


def _fill_in_coords(target_geo_def, data_coords, data_dims):
x_coord, y_coord = target_geo_def.get_proj_vectors()
try:
x_coord, y_coord = target_geo_def.get_proj_vectors()
except AttributeError:
return None
coords = []
for key in data_dims:
if key == 'x':
Expand Down Expand Up @@ -219,8 +222,6 @@ class ResampleBlocksGradientSearchResampler(BaseResampler):

def __init__(self, source_geo_def, target_geo_def):
"""Init GradientResampler."""
if isinstance(target_geo_def, SwathDefinition):
raise NotImplementedError("Cannot resample to a SwathDefinition.")
if isinstance(source_geo_def, SwathDefinition):
source_geo_def.lons = source_geo_def.lons.persist()
source_geo_def.lats = source_geo_def.lats.persist()
Expand Down Expand Up @@ -325,11 +326,13 @@ def _get_coordinates_in_same_projection(source_area, target_area):
except AttributeError as err:
lons, lats = source_area.get_lonlats()
src_x, src_y = da.compute(lons, lats)
transformer = pyproj.Transformer.from_crs(target_area.crs, source_area.crs, always_xy=True)
try:
transformer = pyproj.Transformer.from_crs(target_area.crs, source_area.crs, always_xy=True)
dst_x, dst_y = transformer.transform(*target_area.get_proj_coords())
except AttributeError as err:
raise NotImplementedError("Cannot resample to Swath for now.") from err
# target is a swath definition
lons, lats = target_area.get_lonlats()
dst_x, dst_y = transformer.transform(*da.compute(lons, lats))
src_gradient_xl, src_gradient_xp = np.gradient(src_x, axis=[0, 1])
src_gradient_yl, src_gradient_yp = np.gradient(src_y, axis=[0, 1])
return (dst_x, dst_y), (src_gradient_xl, src_gradient_xp, src_gradient_yl, src_gradient_yp), (src_x, src_y)
Expand Down
5 changes: 4 additions & 1 deletion pyresample/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,10 @@ class AreaSlicer(Slicer):
def get_polygon_to_contain(self):
"""Get the shapely Polygon corresponding to *area_to_contain* in projection coordinates of *area_to_crop*."""
from shapely.geometry import Polygon
x, y = self.area_to_contain.get_edge_bbox_in_projection_coordinates(frequency=10)
try:
x, y = self.area_to_contain.get_edge_bbox_in_projection_coordinates(frequency=10)
except AttributeError:
x, y = self.area_to_contain.get_edge_lonlats(vertices_per_side=10)
if self.area_to_crop.is_geostationary:
x_geos, y_geos = get_geostationary_bounding_box_in_proj_coords(self.area_to_crop, 360)
x_geos, y_geos = self._transformer.transform(x_geos, y_geos, direction=TransformDirection.INVERSE)
Expand Down
78 changes: 49 additions & 29 deletions pyresample/test/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,12 @@ class TestRBGradientSearchResamplerArea2Swath:

def setup_method(self):
"""Set up the test case."""
chunks = 20
lons, lats = np.meshgrid(np.linspace(0, 20, 100), np.linspace(45, 66, 100))
self.dst_swath = SwathDefinition(lons, lats, crs="WGS84")
lons, lats = self.dst_swath.get_lonlats(chunks=10)
lons = xr.DataArray(lons, dims=["y", "x"])
lats = xr.DataArray(lats, dims=["y", "x"])
self.dst_swath_dask = SwathDefinition(lons, lats)

self.src_area = AreaDefinition('euro40', 'euro40', None,
{'proj': 'stere', 'lon_0': 14.0,
Expand All @@ -335,34 +340,49 @@ def setup_method(self):
(-2717181.7304994687, -5571048.14031214,
1378818.2695005313, -1475048.1403121399))

self.dst_area = AreaDefinition(
'omerc_otf',
'On-the-fly omerc area',
None,
{'alpha': '8.99811271718795',
'ellps': 'sphere',
'gamma': '0',
'k': '1',
'lat_0': '0',
'lonc': '13.8096029486222',
'proj': 'omerc',
'units': 'm'},
50, 100,
(-1461111.3603, 3440088.0459, 1534864.0322, 9598335.0457)
)

self.lons, self.lats = self.dst_area.get_lonlats(chunks=chunks)
xrlons = xr.DataArray(self.lons.persist())
xrlats = xr.DataArray(self.lats.persist())
self.dst_swath = SwathDefinition(xrlons, xrlats)

def test_resampling_to_swath_is_not_implemented(self):
"""Test that resampling to swath is not working yet."""
from pyresample.gradient import ResampleBlocksGradientSearchResampler

with pytest.raises(NotImplementedError):
ResampleBlocksGradientSearchResampler(self.src_area,
self.dst_swath)
@pytest.mark.parametrize("input_dtype", (np.float32, np.float64))
def test_resample_area_to_swath_2d(self, input_dtype):
"""Resample swath to area, 2d."""
swath_resampler = ResampleBlocksGradientSearchResampler(self.src_area, self.dst_swath_dask)

data = xr.DataArray(da.ones(self.src_area.shape, dtype=input_dtype),
dims=['y', 'x'])
with np.errstate(invalid="ignore"): # 'inf' space pixels cause runtime warnings
swath_resampler.precompute()
res_xr = swath_resampler.compute(data, method='bilinear')
res_np = res_xr.compute(scheduler='single-threaded')

assert res_xr.dtype == data.dtype
assert res_np.dtype == data.dtype
assert res_xr.shape == self.dst_swath.shape
assert res_np.shape == self.dst_swath.shape
assert type(res_xr) is type(data)
assert type(res_xr.data) is type(data.data)
assert not np.all(np.isnan(res_np))

@pytest.mark.parametrize("input_dtype", (np.float32, np.float64))
def test_resample_area_to_swath_3d(self, input_dtype):
"""Resample area to area, 3d."""
swath_resampler = ResampleBlocksGradientSearchResampler(self.src_area, self.dst_swath_dask)

data = xr.DataArray(da.ones((3, ) + self.src_area.shape,
dtype=input_dtype) *
np.array([1, 2, 3])[:, np.newaxis, np.newaxis],
dims=['bands', 'y', 'x'])
with np.errstate(invalid="ignore"): # 'inf' space pixels cause runtime warnings
swath_resampler.precompute()
res_xr = swath_resampler.compute(data, method='bilinear')
res_np = res_xr.compute(scheduler='single-threaded')

assert res_xr.dtype == data.dtype
assert res_np.dtype == data.dtype
assert res_xr.shape == (3, ) + self.dst_swath.shape
assert res_np.shape == (3, ) + self.dst_swath.shape
assert type(res_xr) is type(data)
assert type(res_xr.data) is type(data.data)
for i in range(res_np.shape[0]):
arr = np.ravel(res_np[i, :, :])
assert np.allclose(arr[np.isfinite(arr)], float(i + 1))


class TestEnsureDataArray(unittest.TestCase):
Expand Down

0 comments on commit c174f88

Please sign in to comment.