From 0de7840bd58d49097ba6e54a3b633aa7974accb7 Mon Sep 17 00:00:00 2001 From: Nadia Dencheva Date: Thu, 30 May 2024 17:17:12 -0400 Subject: [PATCH] make it backwards compatible --- jwst/assign_wcs/util.py | 2 +- jwst/resample/resample_utils.py | 34 +++++++++++++++++----- jwst/resample/tests/test_utils.py | 4 +-- jwst/tweakreg/tests/test_multichip_jwst.py | 8 +++-- 4 files changed, 35 insertions(+), 13 deletions(-) diff --git a/jwst/assign_wcs/util.py b/jwst/assign_wcs/util.py index 2767182fe0..c8e5c77465 100644 --- a/jwst/assign_wcs/util.py +++ b/jwst/assign_wcs/util.py @@ -129,7 +129,7 @@ def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray], if spectral and disp_axis is None: raise ValueError('If input WCS is spectral, a disp_axis must be given') - crpix = np.array(wcs.invert(*fiducial)) + crpix = np.array(wcs.invert(*fiducial, with_bounding_box=False)) delta = np.zeros_like(crpix) spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == 'SPATIAL')[0] diff --git a/jwst/resample/resample_utils.py b/jwst/resample/resample_utils.py index ed4447d8b5..3b7ce13f89 100644 --- a/jwst/resample/resample_utils.py +++ b/jwst/resample/resample_utils.py @@ -4,6 +4,7 @@ import numpy as np from astropy import units as u +from astropy import wcs as fitswcs import gwcs from stdatamodels.dqflags import interpret_bit_flags @@ -143,21 +144,40 @@ def reproject(wcs1, wcs2): positions in ``wcs1`` and returns x, y positions in ``wcs2``. """ - try: - forward_transform = wcs1.pixel_to_world_values - backward_transform = wcs2.world_to_pixel_values - except AttributeError as err: - raise TypeError("Input should be a WCS") from err + # try: + # # forward_transform = wcs1.pixel_to_world_values + # # backward_transform = wcs2.world_to_pixel_values + # forward_transform = wcs1.forward_transform + # backward_transform = wcs2.backward_transform + # except AttributeError as err: + # raise TypeError("Input should be a WCS") from err + def _get_forward(wcs): + if isinstance(wcs, gwcs.WCS): + return wcs.forward_transform + elif isinstance(wcs, fitswcs.WCS): + return wcs.pixel_to_world + elif isinstance(wcs, Model): + return wcs + + def _get_backward(wcs): + if isinstance(wcs, gwcs.WCS): + return wcs.backward_transform + elif isinstance(wcs, fitswcs.WCS): + return wcs.world_to_pixel + elif isinstance(wcs, Model): + return wcs def _reproject(x, y): - sky = forward_transform(x, y) + #sky = forward_transform(x, y) + sky = _get_forward(wcs1)(x, y) flat_sky = [] for axis in sky: flat_sky.append(axis.flatten()) # Filter out RuntimeWarnings due to computed NaNs in the WCS with warnings.catch_warnings(): warnings.simplefilter("ignore", RuntimeWarning) - det = backward_transform(*tuple(flat_sky)) + #det = backward_transform(*tuple(flat_sky)) + det = _get_backward(wcs2)(*tuple(flat_sky)) det_reshaped = [] for axis in det: det_reshaped.append(axis.reshape(x.shape)) diff --git a/jwst/resample/tests/test_utils.py b/jwst/resample/tests/test_utils.py index af41f9730b..f48dce95cf 100644 --- a/jwst/resample/tests/test_utils.py +++ b/jwst/resample/tests/test_utils.py @@ -37,7 +37,7 @@ def wcs_gwcs(): crpix = (500.0, 500.0) shape = (1000, 1000) pscale = 0.06 / 3600 - + prj = astmodels.Pix2Sky_TAN() fiducial = np.array(crval) @@ -192,7 +192,7 @@ def test_reproject(wcs1, wcs2, offset, request): wcs1 = request.getfixturevalue(wcs1) wcs2 = request.getfixturevalue(wcs2) x = np.arange(150, 200) - + f = reproject(wcs1, wcs2) res = f(x, x) assert_allclose(x, res[0] + offset) diff --git a/jwst/tweakreg/tests/test_multichip_jwst.py b/jwst/tweakreg/tests/test_multichip_jwst.py index 6daa383c33..9766aea85d 100644 --- a/jwst/tweakreg/tests/test_multichip_jwst.py +++ b/jwst/tweakreg/tests/test_multichip_jwst.py @@ -83,7 +83,7 @@ def _make_gwcs_wcs(fits_hdr): Mapping((1, 2), name='xtyt')) c2tan.name = 'Cartesian 3D to TAN' - tan2c = (Mapping((0, 0, 1), n_inputs=2, name='xtyt2xyz') | + tan2c = (Mapping((0, 0, 1), name='xtyt2xyz') | (Const1D(1, name='one') & Identity(2, name='I(2D)'))) tan2c.name = 'TAN to cartesian 3D' @@ -377,7 +377,8 @@ def test_multichip_alignment_step(monkeypatch): format='ascii.ecsv', delimiter=' ', names=['RA', 'DEC'] ) - x, y = wr.world_to_pixel(refcat['RA'], refcat['DEC']) + #x, y = wr.world_to_pixel(refcat['RA'].value, refcat['DEC'].value) + x, y = wr.invert(refcat['RA'].value, refcat['DEC'].value, with_bounding_box=False) refcat['x'] = x refcat['y'] = y mr.tweakreg_catalog = refcat @@ -450,7 +451,8 @@ def test_multichip_alignment_step_abs(monkeypatch): format='ascii.ecsv', delimiter=' ', names=['RA', 'DEC'] ) - x, y = wr.world_to_pixel(refcat['RA'], refcat['DEC']) + #x, y = wr.world_to_pixel(refcat['RA'], refcat['DEC']) + x, y = wr.invert(refcat['RA'].value, refcat['DEC'].value, with_bounding_box=False) refcat['x'] = x refcat['y'] = y mr.tweakreg_catalog = refcat