Skip to content

Commit

Permalink
fix resample's reproject
Browse files Browse the repository at this point in the history
  • Loading branch information
nden committed Dec 10, 2024
1 parent cdd0e65 commit 99d927e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 29 deletions.
40 changes: 14 additions & 26 deletions jwst/resample/resample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,40 +164,28 @@ def reproject(wcs1, wcs2):
positions in ``wcs1`` and returns x, y positions in ``wcs2``.
"""

# try:
# # forward_transform = wcs1.pixel_to_world_values
# # backward_transform = wcs2.world_to_pixel_values
# forward_transform = wcs1.forward_transform
# backward_transform = wcs2.backward_transform
# except AttributeError as err:
# raise TypeError("Input should be a WCS") from err
def _get_forward(wcs):
if isinstance(wcs, gwcs.WCS):
return wcs.forward_transform
elif isinstance(wcs, fitswcs.WCS):
return wcs.pixel_to_world_values
elif isinstance(wcs, Model):
return wcs

def _get_backward(wcs):
if isinstance(wcs, gwcs.WCS):
return wcs.backward_transform
elif isinstance(wcs, fitswcs.WCS):
return wcs.world_to_pixel_values
elif isinstance(wcs, Model):
return wcs
try:
# Here we want to use the WCS API functions so that a Sliced WCS
# will work as well. However, the API functions do not accept
# keyword arguments and `with_bounding_box=False` cannot be passsed.
# We delete the bounding box on a copy of the WCS - yes, inefficient.
forward_transform = wcs1.pixel_to_world_values
wcs_no_bbox = deepcopy(wcs2)
wcs_no_bbox.bounding_box = None
backward_transform = wcs_no_bbox.world_to_pixel_values
except AttributeError as err:
raise TypeError("Input should be a WCS") from err


def _reproject(x, y):
#sky = forward_transform(x, y)
sky = _get_forward(wcs1)(x, y)
sky = forward_transform(x, y)
flat_sky = []
for axis in sky:
flat_sky.append(axis.flatten())
# Filter out RuntimeWarnings due to computed NaNs in the WCS
with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
#det = backward_transform(*tuple(flat_sky))
det = _get_backward(wcs2)(*tuple(flat_sky))
det = backward_transform(*tuple(flat_sky))
det_reshaped = []
for axis in det:
det_reshaped.append(axis.reshape(x.shape))
Expand Down
6 changes: 3 additions & 3 deletions jwst/resample/tests/test_resample_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,9 +831,9 @@ def test_resample_undefined_variance(nircam_rate, shape):

@pytest.mark.parametrize('ratio', [0.7, 1.2])
@pytest.mark.parametrize('rotation', [0, 15, 135])
@pytest.mark.parametrize('crpix', [(256, 488), (700, 124)])
@pytest.mark.parametrize('crval', [(50, 77), (20, -30)])
@pytest.mark.parametrize('shape', [(1205, 1100)])
@pytest.mark.parametrize('crpix', [(100, 101), (101, 101)])
@pytest.mark.parametrize('crval', [(22.01, 12), (22.15, 12.01)])
@pytest.mark.parametrize('shape', [(10205, 10100)])
def test_custom_wcs_resample_imaging(nircam_rate, ratio, rotation, crpix, crval, shape):
im = AssignWcsStep.call(nircam_rate, sip_approx=False)
im.data += 5
Expand Down

0 comments on commit 99d927e

Please sign in to comment.