diff --git a/jwst/assign_wcs/util.py b/jwst/assign_wcs/util.py index d8e492587f..4a624543c3 100644 --- a/jwst/assign_wcs/util.py +++ b/jwst/assign_wcs/util.py @@ -127,7 +127,7 @@ def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray], if spectral and disp_axis is None: raise ValueError('If input WCS is spectral, a disp_axis must be given') - crpix = np.array(wcs.invert(*fiducial)) + crpix = np.array(wcs.invert(*fiducial, with_bounding_box=False)) delta = np.zeros_like(crpix) spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == 'SPATIAL')[0] diff --git a/jwst/resample/resample_utils.py b/jwst/resample/resample_utils.py index c527b450f1..39cdeba53e 100644 --- a/jwst/resample/resample_utils.py +++ b/jwst/resample/resample_utils.py @@ -4,6 +4,7 @@ import numpy as np from astropy import units as u +from astropy import wcs as fitswcs import gwcs from stdatamodels.dqflags import interpret_bit_flags @@ -157,21 +158,40 @@ def reproject(wcs1, wcs2): positions in ``wcs1`` and returns x, y positions in ``wcs2``. """ - try: - forward_transform = wcs1.pixel_to_world_values - backward_transform = wcs2.world_to_pixel_values - except AttributeError as err: - raise TypeError("Input should be a WCS") from err + # try: + # # forward_transform = wcs1.pixel_to_world_values + # # backward_transform = wcs2.world_to_pixel_values + # forward_transform = wcs1.forward_transform + # backward_transform = wcs2.backward_transform + # except AttributeError as err: + # raise TypeError("Input should be a WCS") from err + def _get_forward(wcs): + if isinstance(wcs, gwcs.WCS): + return wcs.forward_transform + elif isinstance(wcs, fitswcs.WCS): + return wcs.pixel_to_world + elif isinstance(wcs, Model): + return wcs + + def _get_backward(wcs): + if isinstance(wcs, gwcs.WCS): + return wcs.backward_transform + elif isinstance(wcs, fitswcs.WCS): + return wcs.world_to_pixel + elif isinstance(wcs, Model): + return wcs def _reproject(x, y): - sky = forward_transform(x, y) + #sky = forward_transform(x, y) + sky = _get_forward(wcs1)(x, y) flat_sky = [] for axis in sky: flat_sky.append(axis.flatten()) # Filter out RuntimeWarnings due to computed NaNs in the WCS with warnings.catch_warnings(): warnings.simplefilter("ignore", RuntimeWarning) - det = backward_transform(*tuple(flat_sky)) + #det = backward_transform(*tuple(flat_sky)) + det = _get_backward(wcs2)(*tuple(flat_sky)) det_reshaped = [] for axis in det: det_reshaped.append(axis.reshape(x.shape)) diff --git a/jwst/resample/tests/test_utils.py b/jwst/resample/tests/test_utils.py index 7e7deb3070..518fc61998 100644 --- a/jwst/resample/tests/test_utils.py +++ b/jwst/resample/tests/test_utils.py @@ -38,7 +38,7 @@ def wcs_gwcs(): crpix = (500.0, 500.0) shape = (1000, 1000) pscale = 0.06 / 3600 - + prj = astmodels.Pix2Sky_TAN() fiducial = np.array(crval) @@ -193,7 +193,7 @@ def test_reproject(wcs1, wcs2, offset, request): wcs1 = request.getfixturevalue(wcs1) wcs2 = request.getfixturevalue(wcs2) x = np.arange(150, 200) - + f = reproject(wcs1, wcs2) res = f(x, x) assert_allclose(x, res[0] + offset) diff --git a/jwst/tweakreg/tests/test_multichip_jwst.py b/jwst/tweakreg/tests/test_multichip_jwst.py index 4402f7b897..b8e2f1a08a 100644 --- a/jwst/tweakreg/tests/test_multichip_jwst.py +++ b/jwst/tweakreg/tests/test_multichip_jwst.py @@ -81,7 +81,7 @@ def _make_gwcs_wcs(fits_hdr): Mapping((1, 2), name='xtyt')) c2tan.name = 'Cartesian 3D to TAN' - tan2c = (Mapping((0, 0, 1), n_inputs=2, name='xtyt2xyz') | + tan2c = (Mapping((0, 0, 1), name='xtyt2xyz') | (Const1D(1, name='one') & Identity(2, name='I(2D)'))) tan2c.name = 'TAN to cartesian 3D' @@ -376,7 +376,8 @@ def test_multichip_alignment_step_rel(monkeypatch): format='ascii.ecsv', delimiter=' ', names=['RA', 'DEC'] ) - x, y = wr.world_to_pixel(refcat['RA'], refcat['DEC']) + #x, y = wr.world_to_pixel(refcat['RA'].value, refcat['DEC'].value) + x, y = wr.invert(refcat['RA'].value, refcat['DEC'].value, with_bounding_box=False) refcat['x'] = x refcat['y'] = y mr.tweakreg_catalog = refcat @@ -459,7 +460,8 @@ def test_multichip_alignment_step_abs(monkeypatch): format='ascii.ecsv', delimiter=' ', names=['RA', 'DEC'] ) - x, y = wr.world_to_pixel(refcat['RA'], refcat['DEC']) + #x, y = wr.world_to_pixel(refcat['RA'], refcat['DEC']) + x, y = wr.invert(refcat['RA'].value, refcat['DEC'].value, with_bounding_box=False) refcat['x'] = x refcat['y'] = y mr.tweakreg_catalog = refcat