From b01d68d33fda9593d7d39736d98056c6c7b352b3 Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Thu, 24 Oct 2024 13:35:37 +0300 Subject: [PATCH 1/7] Add dtype checks to more tests --- .../enhancement_tests/test_enhancements.py | 7 ++-- satpy/tests/test_composites.py | 33 +++++++++++++------ 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/satpy/tests/enhancement_tests/test_enhancements.py b/satpy/tests/enhancement_tests/test_enhancements.py index 96176fda34..3d20677237 100644 --- a/satpy/tests/enhancement_tests/test_enhancements.py +++ b/satpy/tests/enhancement_tests/test_enhancements.py @@ -109,14 +109,15 @@ def _calc_func(data): exp_data = exp_data[np.newaxis, :, :] run_and_check_enhancement(_enh_func, in_data, exp_data) - def test_cira_stretch(self): + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) + def test_cira_stretch(self, dtype): """Test applying the cira_stretch.""" from satpy.enhancements import cira_stretch expected = np.array([[ [np.nan, -7.04045974, -7.04045974, 0.79630132, 0.95947296], - [1.05181359, 1.11651012, 1.16635571, 1.20691137, 1.24110186]]]) - run_and_check_enhancement(cira_stretch, self.ch1, expected) + [1.05181359, 1.11651012, 1.16635571, 1.20691137, 1.24110186]]], dtype=dtype) + run_and_check_enhancement(cira_stretch, self.ch1.astype(dtype), expected) def test_reinhard(self): """Test the reinhard algorithm.""" diff --git a/satpy/tests/test_composites.py b/satpy/tests/test_composites.py index 1b60161a52..50b9b88b74 100644 --- a/satpy/tests/test_composites.py +++ b/satpy/tests/test_composites.py @@ -250,27 +250,34 @@ def test_more_than_three_datasets(self): with pytest.raises(ValueError, match="Expected 3 datasets, got 4"): comp((self.ds1, self.ds2, self.ds3, self.ds1), optional_datasets=(self.ds4_big,)) - def test_self_sharpened_no_high_res(self): + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) + def test_self_sharpened_no_high_res(self, dtype): """Test for exception when no high_res band is specified.""" from satpy.composites import SelfSharpenedRGB comp = SelfSharpenedRGB(name="true_color", high_resolution_band=None) with pytest.raises(ValueError, match="SelfSharpenedRGB requires at least one high resolution band, not 'None'"): comp((self.ds1, self.ds2, self.ds3)) - def test_basic_no_high_res(self): + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) + def test_basic_no_high_res(self, dtype): """Test that three datasets can be passed without optional high res.""" from satpy.composites import RatioSharpenedRGB comp = RatioSharpenedRGB(name="true_color") - res = comp((self.ds1, self.ds2, self.ds3)) + res = comp((self.ds1.astype(dtype), self.ds2.astype(dtype), self.ds3.astype(dtype))) assert res.shape == (3, 2, 2) + assert res.dtype == dtype - def test_basic_no_sharpen(self): + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) + def test_basic_no_sharpen(self, dtype): """Test that color None does no sharpening.""" from satpy.composites import RatioSharpenedRGB comp = RatioSharpenedRGB(name="true_color", high_resolution_band=None) - res = comp((self.ds1, self.ds2, self.ds3), optional_datasets=(self.ds4,)) + res = comp((self.ds1.astype(dtype), self.ds2.astype(dtype), self.ds3.astype(dtype)), + optional_datasets=(self.ds4.astype(dtype),)) assert res.shape == (3, 2, 2) + assert res.dtype == dtype + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize( ("high_resolution_band", "neutral_resolution_band", "exp_r", "exp_g", "exp_b"), [ @@ -300,22 +307,26 @@ def test_basic_no_sharpen(self): np.array([[1.0, 1.0], [np.nan, 1.0]], dtype=np.float64)) ] ) - def test_ratio_sharpening(self, high_resolution_band, neutral_resolution_band, exp_r, exp_g, exp_b): + def test_ratio_sharpening(self, high_resolution_band, neutral_resolution_band, exp_r, exp_g, exp_b, dtype): """Test RatioSharpenedRGB by different groups of high_resolution_band and neutral_resolution_band.""" from satpy.composites import RatioSharpenedRGB comp = RatioSharpenedRGB(name="true_color", high_resolution_band=high_resolution_band, neutral_resolution_band=neutral_resolution_band) - res = comp((self.ds1, self.ds2, self.ds3), optional_datasets=(self.ds4,)) + res = comp((self.ds1.astype(dtype), self.ds2.astype(dtype), self.ds3.astype(dtype)), + optional_datasets=(self.ds4.astype(dtype),)) assert "units" not in res.attrs assert isinstance(res, xr.DataArray) assert isinstance(res.data, da.Array) + assert res.dtype == dtype data = res.values np.testing.assert_allclose(data[0], exp_r, rtol=1e-5) np.testing.assert_allclose(data[1], exp_g, rtol=1e-5) np.testing.assert_allclose(data[2], exp_b, rtol=1e-5) + assert res.dtype == dtype + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize( ("exp_shape", "exp_r", "exp_g", "exp_b"), [ @@ -325,17 +336,19 @@ def test_ratio_sharpening(self, high_resolution_band, neutral_resolution_band, e np.array([[16 / 3, 16 / 3], [16 / 3, 0]], dtype=np.float64)) ] ) - def test_self_sharpened_basic(self, exp_shape, exp_r, exp_g, exp_b): + def test_self_sharpened_basic(self, exp_shape, exp_r, exp_g, exp_b, dtype): """Test that three datasets can be passed without optional high res.""" from satpy.composites import SelfSharpenedRGB comp = SelfSharpenedRGB(name="true_color") - res = comp((self.ds1, self.ds2, self.ds3)) - data = res.values + res = comp((self.ds1.astype(dtype), self.ds2.astype(dtype), self.ds3.astype(dtype))) + assert res.dtype == dtype + data = res.values assert data.shape == exp_shape np.testing.assert_allclose(data[0], exp_r, rtol=1e-5) np.testing.assert_allclose(data[1], exp_g, rtol=1e-5) np.testing.assert_allclose(data[2], exp_b, rtol=1e-5) + assert data.dtype == dtype class TestDifferenceCompositor(unittest.TestCase): From 0e4a3f56772d46844a08e756bf099e87615eea5c Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Thu, 24 Oct 2024 13:36:44 +0300 Subject: [PATCH 2/7] Add dtype checks for modifiers --- satpy/tests/test_modifiers.py | 58 +++++++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/satpy/tests/test_modifiers.py b/satpy/tests/test_modifiers.py index 7e28a7456b..fa360f2a8b 100644 --- a/satpy/tests/test_modifiers.py +++ b/satpy/tests/test_modifiers.py @@ -135,29 +135,46 @@ def test_basic_default_not_provided(self, sunz_ds1, as_32bit): assert res.dtype == res_np.dtype assert "y" not in res.coords assert "x" not in res.coords + if as_32bit: + assert res.dtype == np.float32 - def test_basic_lims_not_provided(self, sunz_ds1): + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) + def test_basic_lims_not_provided(self, sunz_ds1, dtype): """Test custom limits when SZA isn't provided.""" from satpy.modifiers.geometry import SunZenithCorrector comp = SunZenithCorrector(name="sza_test", modifiers=tuple(), correction_limit=90) - res = comp((sunz_ds1,), test_attr="test") - np.testing.assert_allclose(res.values, np.array([[66.853262, 68.168939], [66.30742, 67.601493]])) - + res = comp((sunz_ds1.astype(dtype),), test_attr="test") + expected = np.array([[66.853262, 68.168939], [66.30742, 67.601493]], dtype=dtype) + values = res.values + np.testing.assert_allclose(values, expected, rtol=1e-5) + assert res.dtype == dtype + assert values.dtype == dtype + + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize("data_arr", [lazy_fixture("sunz_ds1"), lazy_fixture("sunz_ds1_stacked")]) - def test_basic_default_provided(self, data_arr, sunz_sza): + def test_basic_default_provided(self, data_arr, sunz_sza, dtype): """Test default limits when SZA is provided.""" from satpy.modifiers.geometry import SunZenithCorrector comp = SunZenithCorrector(name="sza_test", modifiers=tuple()) - res = comp((data_arr, sunz_sza), test_attr="test") - np.testing.assert_allclose(res.values, np.array([[22.401667, 22.31777], [22.437503, 22.353533]])) - + res = comp((data_arr.astype(dtype), sunz_sza.astype(dtype)), test_attr="test") + expected = np.array([[22.401667, 22.31777], [22.437503, 22.353533]], dtype=dtype) + values = res.values + np.testing.assert_allclose(values, expected) + assert res.dtype == dtype + assert values.dtype == dtype + + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize("data_arr", [lazy_fixture("sunz_ds1"), lazy_fixture("sunz_ds1_stacked")]) - def test_basic_lims_provided(self, data_arr, sunz_sza): + def test_basic_lims_provided(self, data_arr, sunz_sza, dtype): """Test custom limits when SZA is provided.""" from satpy.modifiers.geometry import SunZenithCorrector comp = SunZenithCorrector(name="sza_test", modifiers=tuple(), correction_limit=90) - res = comp((data_arr, sunz_sza), test_attr="test") - np.testing.assert_allclose(res.values, np.array([[66.853262, 68.168939], [66.30742, 67.601493]])) + res = comp((data_arr.astype(dtype), sunz_sza.astype(dtype)), test_attr="test") + expected = np.array([[66.853262, 68.168939], [66.30742, 67.601493]], dtype=dtype) + values = res.values + np.testing.assert_allclose(values, expected, rtol=1e-5) + assert res.dtype == dtype + assert values.dtype == dtype def test_imcompatible_areas(self, sunz_ds2, sunz_sza): """Test sunz correction on incompatible areas.""" @@ -502,6 +519,7 @@ def _create_test_data(self, name, wavelength, resolution): }) return input_band, red_band, angle1, angle1, angle1, angle1 + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize( ("name", "wavelength", "resolution", "aerosol_type", "reduce_lim_low", "reduce_lim_high", "reduce_strength", "exp_mean", "exp_unique"), @@ -521,7 +539,7 @@ def _create_test_data(self, name, wavelength, resolution): ] ) def test_rayleigh_corrector(self, name, wavelength, resolution, aerosol_type, reduce_lim_low, reduce_lim_high, - reduce_strength, exp_mean, exp_unique): + reduce_strength, exp_mean, exp_unique, dtype): """Test PSPRayleighReflectance with fake data.""" from satpy.modifiers.atmosphere import PSPRayleighReflectance ray_cor = PSPRayleighReflectance(name=name, atmosphere="us-standard", aerosol_types=aerosol_type, @@ -535,42 +553,48 @@ def test_rayleigh_corrector(self, name, wavelength, resolution, aerosol_type, re assert ray_cor.attrs["reduce_strength"] == reduce_strength input_band, red_band, *_ = self._create_test_data(name, wavelength, resolution) - res = ray_cor([input_band, red_band]) + res = ray_cor([input_band.astype(dtype), red_band.astype(dtype)]) assert isinstance(res, xr.DataArray) assert isinstance(res.data, da.Array) + assert res.dtype == dtype data = res.values unique = np.unique(data[~np.isnan(data)]) np.testing.assert_allclose(np.nanmean(data), exp_mean, rtol=1e-5) assert data.shape == (3, 5) np.testing.assert_allclose(unique, exp_unique, rtol=1e-5) + assert data.dtype == dtype + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize("as_optionals", [False, True]) - def test_rayleigh_with_angles(self, as_optionals): + def test_rayleigh_with_angles(self, as_optionals, dtype): """Test PSPRayleighReflectance with angles provided.""" from satpy.modifiers.atmosphere import PSPRayleighReflectance aerosol_type = "rayleigh_only" ray_cor = PSPRayleighReflectance(name="B01", atmosphere="us-standard", aerosol_types=aerosol_type) - prereqs, opt_prereqs = self._get_angles_prereqs_and_opts(as_optionals) + prereqs, opt_prereqs = self._get_angles_prereqs_and_opts(as_optionals, dtype) with mock.patch("satpy.modifiers.atmosphere.get_angles") as get_angles: res = ray_cor(prereqs, opt_prereqs) get_angles.assert_not_called() assert isinstance(res, xr.DataArray) assert isinstance(res.data, da.Array) + assert res.dtype == dtype data = res.values unique = np.unique(data[~np.isnan(data)]) np.testing.assert_allclose(unique, np.array([-75.0, -37.71298492, 31.14350754]), rtol=1e-5) assert data.shape == (3, 5) + assert data.dtype == dtype - def _get_angles_prereqs_and_opts(self, as_optionals): + def _get_angles_prereqs_and_opts(self, as_optionals, dtype): wavelength = (0.45, 0.47, 0.49) resolution = 1000 input_band, red_band, *angles = self._create_test_data("B01", wavelength, resolution) - prereqs = [input_band, red_band] + prereqs = [input_band.astype(dtype), red_band.astype(dtype)] opt_prereqs = [] + angles = [a.astype(dtype) for a in angles] if as_optionals: opt_prereqs = angles else: From 965c1b448666bfd4d931f1505c79b185e2e2b6b9 Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Thu, 24 Oct 2024 13:57:41 +0300 Subject: [PATCH 3/7] Compute Rayleigh correction with original data precision --- satpy/modifiers/atmosphere.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/satpy/modifiers/atmosphere.py b/satpy/modifiers/atmosphere.py index 1c6225f42a..a190510a83 100644 --- a/satpy/modifiers/atmosphere.py +++ b/satpy/modifiers/atmosphere.py @@ -77,7 +77,9 @@ def __call__(self, projectables, optional_datasets=None, **info): projectables = projectables + (optional_datasets or []) if len(projectables) != 6: vis, red = self.match_data_arrays(projectables) - sata, satz, suna, sunz = get_angles(vis) + # Adjust the angle data precision to match the data + # This does not affect the accuracy visibly + sata, satz, suna, sunz = [d.astype(vis.dtype) for d in get_angles(vis)] else: vis, red, sata, satz, suna, sunz = self.match_data_arrays(projectables) # First make sure the two azimuth angles are in the range 0-360: @@ -116,14 +118,14 @@ def __call__(self, projectables, optional_datasets=None, **info): refl_cor_band = corrector.get_reflectance(sunz, satz, ssadiff, vis.attrs["wavelength"][1], red.data) - if reduce_strength > 0: if reduce_lim_low > reduce_lim_high: reduce_lim_low = reduce_lim_high refl_cor_band = corrector.reduce_rayleigh_highzenith(sunz, refl_cor_band, reduce_lim_low, reduce_lim_high, reduce_strength) - proj = vis - refl_cor_band + # Need to convert again to data precision, Rayleigh calculations always promote datatype to float64 + proj = vis - refl_cor_band.astype(vis.dtype) proj.attrs = vis.attrs self.apply_modifier_info(vis, proj) return proj From 392fb773a617e40fc2324290516f80ec0db0da0f Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Fri, 25 Oct 2024 13:03:30 +0300 Subject: [PATCH 4/7] Fix clipping of Rayleigh strength reduction --- satpy/modifiers/atmosphere.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/satpy/modifiers/atmosphere.py b/satpy/modifiers/atmosphere.py index a190510a83..8d869b2d53 100644 --- a/satpy/modifiers/atmosphere.py +++ b/satpy/modifiers/atmosphere.py @@ -99,7 +99,7 @@ def __call__(self, projectables, optional_datasets=None, **info): aerosol_type = self.attrs.get("aerosol_type", "marine_clean_aerosol") reduce_lim_low = abs(self.attrs.get("reduce_lim_low", 70)) reduce_lim_high = abs(self.attrs.get("reduce_lim_high", 105)) - reduce_strength = np.clip(self.attrs.get("reduce_strength", 0), 0, 1) + reduce_strength = np.clip(self.attrs.get("reduce_strength", 0), 0, 1).astype(vis.dtype) logger.info("Removing Rayleigh scattering with atmosphere '%s' and " "aerosol type '%s' for '%s'", @@ -118,6 +118,7 @@ def __call__(self, projectables, optional_datasets=None, **info): refl_cor_band = corrector.get_reflectance(sunz, satz, ssadiff, vis.attrs["wavelength"][1], red.data) + if reduce_strength > 0: if reduce_lim_low > reduce_lim_high: reduce_lim_low = reduce_lim_high @@ -125,7 +126,7 @@ def __call__(self, projectables, optional_datasets=None, **info): reduce_lim_low, reduce_lim_high, reduce_strength) # Need to convert again to data precision, Rayleigh calculations always promote datatype to float64 - proj = vis - refl_cor_band.astype(vis.dtype) + proj = vis - refl_cor_band proj.attrs = vis.attrs self.apply_modifier_info(vis, proj) return proj From 118fc93b407ceb4a8688057d1c43233cddfcebf6 Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Fri, 25 Oct 2024 13:14:15 +0300 Subject: [PATCH 5/7] Remove obsolete comment --- satpy/modifiers/atmosphere.py | 1 - 1 file changed, 1 deletion(-) diff --git a/satpy/modifiers/atmosphere.py b/satpy/modifiers/atmosphere.py index 8d869b2d53..c7144c27ca 100644 --- a/satpy/modifiers/atmosphere.py +++ b/satpy/modifiers/atmosphere.py @@ -125,7 +125,6 @@ def __call__(self, projectables, optional_datasets=None, **info): refl_cor_band = corrector.reduce_rayleigh_highzenith(sunz, refl_cor_band, reduce_lim_low, reduce_lim_high, reduce_strength) - # Need to convert again to data precision, Rayleigh calculations always promote datatype to float64 proj = vis - refl_cor_band proj.attrs = vis.attrs self.apply_modifier_info(vis, proj) From de80552b37282dcc5cd5be7fe410d777a5ad341d Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Fri, 25 Oct 2024 14:50:10 +0300 Subject: [PATCH 6/7] Test also computed dtypes --- satpy/tests/test_composites.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/satpy/tests/test_composites.py b/satpy/tests/test_composites.py index 50b9b88b74..91c383d1b3 100644 --- a/satpy/tests/test_composites.py +++ b/satpy/tests/test_composites.py @@ -266,6 +266,7 @@ def test_basic_no_high_res(self, dtype): res = comp((self.ds1.astype(dtype), self.ds2.astype(dtype), self.ds3.astype(dtype))) assert res.shape == (3, 2, 2) assert res.dtype == dtype + assert res.values.dtype == dtype @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_basic_no_sharpen(self, dtype): @@ -276,6 +277,7 @@ def test_basic_no_sharpen(self, dtype): optional_datasets=(self.ds4.astype(dtype),)) assert res.shape == (3, 2, 2) assert res.dtype == dtype + assert res.values.dtype == dtype @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize( From 3f1076ae94af61c375be206d260b3b96d86520b9 Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Fri, 25 Oct 2024 14:50:50 +0300 Subject: [PATCH 7/7] Remove unnecessary dtype parametrization --- satpy/tests/test_composites.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/satpy/tests/test_composites.py b/satpy/tests/test_composites.py index 91c383d1b3..eb3f90b715 100644 --- a/satpy/tests/test_composites.py +++ b/satpy/tests/test_composites.py @@ -250,8 +250,7 @@ def test_more_than_three_datasets(self): with pytest.raises(ValueError, match="Expected 3 datasets, got 4"): comp((self.ds1, self.ds2, self.ds3, self.ds1), optional_datasets=(self.ds4_big,)) - @pytest.mark.parametrize("dtype", [np.float32, np.float64]) - def test_self_sharpened_no_high_res(self, dtype): + def test_self_sharpened_no_high_res(self): """Test for exception when no high_res band is specified.""" from satpy.composites import SelfSharpenedRGB comp = SelfSharpenedRGB(name="true_color", high_resolution_band=None)