From 99d927e3fb52f7faf42b70e986de41e05e8e628d Mon Sep 17 00:00:00 2001 From: Nadia Dencheva Date: Thu, 31 Oct 2024 08:12:37 -0400 Subject: [PATCH] fix resample's reproject --- jwst/resample/resample_utils.py | 40 ++++++++--------------- jwst/resample/tests/test_resample_step.py | 6 ++-- 2 files changed, 17 insertions(+), 29 deletions(-) diff --git a/jwst/resample/resample_utils.py b/jwst/resample/resample_utils.py index 26bc8e85de..8678b7335a 100644 --- a/jwst/resample/resample_utils.py +++ b/jwst/resample/resample_utils.py @@ -164,40 +164,28 @@ def reproject(wcs1, wcs2): positions in ``wcs1`` and returns x, y positions in ``wcs2``. """ - # try: - # # forward_transform = wcs1.pixel_to_world_values - # # backward_transform = wcs2.world_to_pixel_values - # forward_transform = wcs1.forward_transform - # backward_transform = wcs2.backward_transform - # except AttributeError as err: - # raise TypeError("Input should be a WCS") from err - def _get_forward(wcs): - if isinstance(wcs, gwcs.WCS): - return wcs.forward_transform - elif isinstance(wcs, fitswcs.WCS): - return wcs.pixel_to_world_values - elif isinstance(wcs, Model): - return wcs - - def _get_backward(wcs): - if isinstance(wcs, gwcs.WCS): - return wcs.backward_transform - elif isinstance(wcs, fitswcs.WCS): - return wcs.world_to_pixel_values - elif isinstance(wcs, Model): - return wcs + try: + # Here we want to use the WCS API functions so that a Sliced WCS + # will work as well. However, the API functions do not accept + # keyword arguments and `with_bounding_box=False` cannot be passsed. + # We delete the bounding box on a copy of the WCS - yes, inefficient. + forward_transform = wcs1.pixel_to_world_values + wcs_no_bbox = deepcopy(wcs2) + wcs_no_bbox.bounding_box = None + backward_transform = wcs_no_bbox.world_to_pixel_values + except AttributeError as err: + raise TypeError("Input should be a WCS") from err + def _reproject(x, y): - #sky = forward_transform(x, y) - sky = _get_forward(wcs1)(x, y) + sky = forward_transform(x, y) flat_sky = [] for axis in sky: flat_sky.append(axis.flatten()) # Filter out RuntimeWarnings due to computed NaNs in the WCS with warnings.catch_warnings(): warnings.simplefilter("ignore", RuntimeWarning) - #det = backward_transform(*tuple(flat_sky)) - det = _get_backward(wcs2)(*tuple(flat_sky)) + det = backward_transform(*tuple(flat_sky)) det_reshaped = [] for axis in det: det_reshaped.append(axis.reshape(x.shape)) diff --git a/jwst/resample/tests/test_resample_step.py b/jwst/resample/tests/test_resample_step.py index 75eddb8530..178340ccb4 100644 --- a/jwst/resample/tests/test_resample_step.py +++ b/jwst/resample/tests/test_resample_step.py @@ -831,9 +831,9 @@ def test_resample_undefined_variance(nircam_rate, shape): @pytest.mark.parametrize('ratio', [0.7, 1.2]) @pytest.mark.parametrize('rotation', [0, 15, 135]) -@pytest.mark.parametrize('crpix', [(256, 488), (700, 124)]) -@pytest.mark.parametrize('crval', [(50, 77), (20, -30)]) -@pytest.mark.parametrize('shape', [(1205, 1100)]) +@pytest.mark.parametrize('crpix', [(100, 101), (101, 101)]) +@pytest.mark.parametrize('crval', [(22.01, 12), (22.15, 12.01)]) +@pytest.mark.parametrize('shape', [(10205, 10100)]) def test_custom_wcs_resample_imaging(nircam_rate, ratio, rotation, crpix, crval, shape): im = AssignWcsStep.call(nircam_rate, sip_approx=False) im.data += 5