Skip to content

Commit

Permalink
make it backwards compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
nden committed Oct 29, 2024
1 parent 44173f3 commit 76af105
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 13 deletions.
2 changes: 1 addition & 1 deletion jwst/assign_wcs/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray],
if spectral and disp_axis is None:
raise ValueError('If input WCS is spectral, a disp_axis must be given')

crpix = np.array(wcs.invert(*fiducial))
crpix = np.array(wcs.invert(*fiducial, with_bounding_box=False))

delta = np.zeros_like(crpix)
spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == 'SPATIAL')[0]
Expand Down
34 changes: 27 additions & 7 deletions jwst/resample/resample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from astropy import units as u
from astropy import wcs as fitswcs
import gwcs

from stdatamodels.dqflags import interpret_bit_flags
Expand Down Expand Up @@ -157,21 +158,40 @@ def reproject(wcs1, wcs2):
positions in ``wcs1`` and returns x, y positions in ``wcs2``.
"""

try:
forward_transform = wcs1.pixel_to_world_values
backward_transform = wcs2.world_to_pixel_values
except AttributeError as err:
raise TypeError("Input should be a WCS") from err
# try:
# # forward_transform = wcs1.pixel_to_world_values
# # backward_transform = wcs2.world_to_pixel_values
# forward_transform = wcs1.forward_transform
# backward_transform = wcs2.backward_transform
# except AttributeError as err:
# raise TypeError("Input should be a WCS") from err
def _get_forward(wcs):
if isinstance(wcs, gwcs.WCS):
return wcs.forward_transform
elif isinstance(wcs, fitswcs.WCS):
return wcs.pixel_to_world
elif isinstance(wcs, Model):
return wcs

def _get_backward(wcs):
if isinstance(wcs, gwcs.WCS):
return wcs.backward_transform
elif isinstance(wcs, fitswcs.WCS):
return wcs.world_to_pixel
elif isinstance(wcs, Model):
return wcs

def _reproject(x, y):
sky = forward_transform(x, y)
#sky = forward_transform(x, y)
sky = _get_forward(wcs1)(x, y)
flat_sky = []
for axis in sky:
flat_sky.append(axis.flatten())
# Filter out RuntimeWarnings due to computed NaNs in the WCS
with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
det = backward_transform(*tuple(flat_sky))
#det = backward_transform(*tuple(flat_sky))
det = _get_backward(wcs2)(*tuple(flat_sky))
det_reshaped = []
for axis in det:
det_reshaped.append(axis.reshape(x.shape))
Expand Down
4 changes: 2 additions & 2 deletions jwst/resample/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def wcs_gwcs():
crpix = (500.0, 500.0)
shape = (1000, 1000)
pscale = 0.06 / 3600

prj = astmodels.Pix2Sky_TAN()
fiducial = np.array(crval)

Expand Down Expand Up @@ -193,7 +193,7 @@ def test_reproject(wcs1, wcs2, offset, request):
wcs1 = request.getfixturevalue(wcs1)
wcs2 = request.getfixturevalue(wcs2)
x = np.arange(150, 200)

f = reproject(wcs1, wcs2)
res = f(x, x)
assert_allclose(x, res[0] + offset)
Expand Down
8 changes: 5 additions & 3 deletions jwst/tweakreg/tests/test_multichip_jwst.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _make_gwcs_wcs(fits_hdr):
Mapping((1, 2), name='xtyt'))
c2tan.name = 'Cartesian 3D to TAN'

tan2c = (Mapping((0, 0, 1), n_inputs=2, name='xtyt2xyz') |
tan2c = (Mapping((0, 0, 1), name='xtyt2xyz') |
(Const1D(1, name='one') & Identity(2, name='I(2D)')))
tan2c.name = 'TAN to cartesian 3D'

Expand Down Expand Up @@ -376,7 +376,8 @@ def test_multichip_alignment_step_rel(monkeypatch):
format='ascii.ecsv', delimiter=' ',
names=['RA', 'DEC']
)
x, y = wr.world_to_pixel(refcat['RA'], refcat['DEC'])
#x, y = wr.world_to_pixel(refcat['RA'].value, refcat['DEC'].value)
x, y = wr.invert(refcat['RA'].value, refcat['DEC'].value, with_bounding_box=False)
refcat['x'] = x
refcat['y'] = y
mr.tweakreg_catalog = refcat
Expand Down Expand Up @@ -459,7 +460,8 @@ def test_multichip_alignment_step_abs(monkeypatch):
format='ascii.ecsv', delimiter=' ',
names=['RA', 'DEC']
)
x, y = wr.world_to_pixel(refcat['RA'], refcat['DEC'])
#x, y = wr.world_to_pixel(refcat['RA'], refcat['DEC'])
x, y = wr.invert(refcat['RA'].value, refcat['DEC'].value, with_bounding_box=False)
refcat['x'] = x
refcat['y'] = y
mr.tweakreg_catalog = refcat
Expand Down

0 comments on commit 76af105

Please sign in to comment.