Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Nadia committed Jun 12, 2024
1 parent f2127e7 commit aea7ab9
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 57 deletions.
48 changes: 24 additions & 24 deletions jwst/assign_wcs/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,30 +74,30 @@ def _domain_to_bounding_box(domain):
return bb


def reproject(wcs1, wcs2):
"""
Given two WCSs return a function which takes pixel coordinates in
the first WCS and computes their location in the second one.
It performs the forward transformation of ``wcs1`` followed by the
inverse of ``wcs2``.
Parameters
----------
wcs1, wcs2 : `~gwcs.wcs.WCS`
WCS objects.
Returns
-------
_reproject : func
Function to compute the transformations. It takes x, y
positions in ``wcs1`` and returns x, y positions in ``wcs2``.
"""

def _reproject(x, y):
sky = wcs1.forward_transform(x, y)
return wcs2.backward_transform(*sky)
return _reproject
# def reproject(wcs1, wcs2):
# """
# Given two WCSs return a function which takes pixel coordinates in
# the first WCS and computes their location in the second one.

# It performs the forward transformation of ``wcs1`` followed by the
# inverse of ``wcs2``.

# Parameters
# ----------
# wcs1, wcs2 : `~gwcs.wcs.WCS`
# WCS objects.

# Returns
# -------
# _reproject : func
# Function to compute the transformations. It takes x, y
# positions in ``wcs1`` and returns x, y positions in ``wcs2``.
# """

# def _reproject(x, y):
# sky = wcs1.forward_transform(x, y)
# return wcs2.backward_transform(*sky)
# return _reproject


def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray],
Expand Down
7 changes: 0 additions & 7 deletions jwst/regtest/test_niriss_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def test_niriss_tweakreg_no_sources(rtdata, fitsdiff_default_kwargs):
rtdata.input = "niriss/imaging/jw01537-o003_20240406t164421_image3_00004_asn.json"
rtdata.get_asn("niriss/imaging/jw01537-o003_20240406t164421_image3_00004_asn.json")

<<<<<<< HEAD
args = [
"jwst.tweakreg.TweakRegStep",
rtdata.input,
Expand All @@ -64,12 +63,6 @@ def test_niriss_tweakreg_no_sources(rtdata, fitsdiff_default_kwargs):

# run the test from the command line:
result = Step.from_cmdline(args)
=======
args = ["jwst.tweakreg.TweakRegStep", rtdata.input, "--abs_refcat='GAIADR3'"]
result = Step.from_cmdline(args)
# Check that the step is skipped
assert result.skip
>>>>>>> edae07e81 (fix assert statement; run the test with abs_refcat GAIADR3)

# Check the status of the step is set correctly in the files.
mc = datamodels.ModelContainer(rtdata.input)
Expand Down
34 changes: 8 additions & 26 deletions jwst/resample/resample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
from astropy import units as u
from astropy import wcs as fitswcs
from astropy.wcs.wcsapi.wrappers import SlicedLowLevelWCS
from astropy.modeling import Model
import gwcs

Expand Down Expand Up @@ -145,40 +146,21 @@ def reproject(wcs1, wcs2):
positions in ``wcs1`` and returns x, y positions in ``wcs2``.
"""

# try:
# # forward_transform = wcs1.pixel_to_world_values
# # backward_transform = wcs2.world_to_pixel_values
# forward_transform = wcs1.forward_transform
# backward_transform = wcs2.backward_transform
# except AttributeError as err:
# raise TypeError("Input should be a WCS") from err
def _get_forward(wcs):
if isinstance(wcs, gwcs.WCS):
return wcs.forward_transform
elif isinstance(wcs, fitswcs.WCS):
return wcs.pixel_to_world
elif isinstance(wcs, Model):
return wcs

def _get_backward(wcs):
if isinstance(wcs, gwcs.WCS):
return wcs.backward_transform
elif isinstance(wcs, fitswcs.WCS):
return wcs.world_to_pixel
elif isinstance(wcs, Model):
return wcs
try:
forward_transform = wcs1.pixel_to_world_values
backward_transform = wcs2.world_to_pixel_values
except AttributeError as err:
raise TypeError("Input should be a WCS") from err

def _reproject(x, y):
#sky = forward_transform(x, y)
sky = _get_forward(wcs1)(x, y)
sky = forward_transform(x, y)
flat_sky = []
for axis in sky:
flat_sky.append(axis.flatten())
# Filter out RuntimeWarnings due to computed NaNs in the WCS
with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
#det = backward_transform(*tuple(flat_sky))
det = _get_backward(wcs2)(*tuple(flat_sky))
det = backward_transform(*tuple(flat_sky))
det_reshaped = []
for axis in det:
det_reshaped.append(axis.reshape(x.shape))
Expand Down

0 comments on commit aea7ab9

Please sign in to comment.