From 9e6101a52c4230e330a2708ebe2487780c1e59a7 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Wed, 31 May 2023 11:01:57 -0500 Subject: [PATCH] Fix performance regression in base resampler class when comparing geometries --- pyresample/resampler.py | 38 ++++++++++- .../test/test_resamplers/test_resampler.py | 67 +++++++++++++++++-- 2 files changed, 99 insertions(+), 6 deletions(-) diff --git a/pyresample/resampler.py b/pyresample/resampler.py index 781ee3919..84f860157 100644 --- a/pyresample/resampler.py +++ b/pyresample/resampler.py @@ -121,7 +121,7 @@ def resample(self, data, cache_dir=None, mask_area=None, **kwargs): Returns (xarray.DataArray): Data resampled to the target area """ - if self.source_geo_def == self.target_geo_def: + if self._geometries_are_the_same(): return data # default is to mask areas for SwathDefinitions if mask_area is None and isinstance( @@ -143,6 +143,42 @@ def resample(self, data, cache_dir=None, mask_area=None, **kwargs): cache_id = self.precompute(cache_dir=cache_dir, **kwargs) return self.compute(data, cache_id=cache_id, **kwargs) + def _geometries_are_the_same(self): + """Check if two geometries are the same object and resampling isn't needed. + + For area definitions this is a simple comparison using the ``==``. + When swaths are involved care is taken to not check coordinate equality + to avoid the expensive computation. A swath and an area are never + considered equal in this case even if they describe the same geographic + region. + + Two swaths are only considered equal if the underlying arrays are the + exact same objects. Otherwise, they are considered not equal and + coordinate values are never checked. This has + the downside that if two SwathDefinitions have equal coordinates but + are loaded or created separately they will be considered not equal. + + """ + if self.source_geo_def is self.target_geo_def: + return True + if type(self.source_geo_def) is not type(self.target_geo_def): # noqa + # these aren't the exact same class + return False + if isinstance(self.source_geo_def, AreaDefinition): + return self.source_geo_def == self.target_geo_def + # swath or coordinate definitions + src_lons, src_lats = self.source_geo_def.get_lonlats() + dst_lons, dst_lats = self.target_geo_def.get_lonlats() + if (src_lons is dst_lons) and (src_lats is dst_lats): + return True + + if not all(isinstance(arr, da.Array) for arr in (src_lons, src_lats, dst_lons, dst_lats)): + # they aren't the same object and they aren't dask arrays so not equal + return False + # if dask task names are the same then they are the same even if the + # dask Array instance itself is different + return src_lons.name == dst_lons.name and src_lats.name == dst_lats.name + def _create_cache_filename(self, cache_dir=None, prefix='', fmt='.zarr', **kwargs): """Create filename for the cached resampling parameters.""" diff --git a/pyresample/test/test_resamplers/test_resampler.py b/pyresample/test/test_resamplers/test_resampler.py index cbf9e03f9..b5e97511e 100644 --- a/pyresample/test/test_resamplers/test_resampler.py +++ b/pyresample/test/test_resamplers/test_resampler.py @@ -20,12 +20,14 @@ from unittest import mock +import dask.array as da import numpy as np import pytest +import xarray as xr from pytest_lazyfixture import lazy_fixture from pyresample.future.resamplers.resampler import Resampler -from pyresample.geometry import AreaDefinition +from pyresample.geometry import AreaDefinition, SwathDefinition from pyresample.resampler import BaseResampler @@ -76,13 +78,68 @@ def test_resampler(src, dst): assert resample_results.shape == dst.shape -def test_base_resampler_does_nothing_when_src_and_dst_areas_are_equal(): +@pytest.mark.parametrize( + ("use_swaths", "copy_dst_swath"), + [ + (False, None), + (True, None), # same objects are equal + (True, "dask"), # same dask tasks are equal + (True, "swath_def"), # same underlying arrays are equal + ]) +def test_base_resampler_does_nothing_when_src_and_dst_areas_are_equal(_geos_area, use_swaths, copy_dst_swath): """Test that the BaseResampler does nothing when the source and target areas are the same.""" + src_geom = _geos_area if not use_swaths else _xarray_swath_def_from_area(_geos_area) + dst_geom = src_geom + if copy_dst_swath == "dask": + dst_geom = _xarray_swath_def_from_area(_geos_area) + elif copy_dst_swath == "swath_def": + dst_geom = SwathDefinition(dst_geom.lons, dst_geom.lats) + + resampler = BaseResampler(src_geom, dst_geom) + some_data = xr.DataArray(da.zeros(src_geom.shape, dtype=np.float64), dims=('y', 'x')) + assert resampler.resample(some_data) is some_data + + +@pytest.mark.parametrize( + ("src_area", "numpy_swath"), + [ + (False, False), + (False, True), + (True, False), + ]) +@pytest.mark.parametrize("dst_area", [False, True]) +def test_base_resampler_unequal_geometries(_geos_area, _geos_area2, src_area, numpy_swath, dst_area): + """Test cases where BaseResampler geometries are not considered equal.""" + src_geom = _geos_area if src_area else _xarray_swath_def_from_area(_geos_area, numpy_swath) + dst_geom = _geos_area2 if dst_area else _xarray_swath_def_from_area(_geos_area2) + resampler = BaseResampler(src_geom, dst_geom) + some_data = xr.DataArray(da.zeros(src_geom.shape, dtype=np.float64), dims=('y', 'x')) + with pytest.raises(NotImplementedError): + resampler.resample(some_data) + + +def _xarray_swath_def_from_area(area_def, use_numpy=False): + chunks = None if use_numpy else -1 + lons_da, lats_da = area_def.get_lonlats(chunks=chunks) + lons = xr.DataArray(lons_da, dims=('y', 'x')) + lats = xr.DataArray(lats_da, dims=('y', 'x')) + swath_def = SwathDefinition(lons, lats) + return swath_def + + +@pytest.fixture +def _geos_area(): src_area = AreaDefinition('src', 'src area', None, {'ellps': 'WGS84', 'h': '35785831', 'proj': 'geos'}, 100, 100, (5550000.0, 5550000.0, -5550000.0, -5550000.0)) + return src_area - resampler = BaseResampler(src_area, src_area) - some_data = np.zeros(src_area.shape, dtype=np.float64) - assert resampler.resample(some_data) is some_data + +@pytest.fixture +def _geos_area2(): + src_area = AreaDefinition('src', 'src area', None, + {'ellps': 'WGS84', 'h': '35785831', 'proj': 'geos'}, + 200, 200, + (5550000.0, 5550000.0, -5550000.0, -5550000.0)) + return src_area