Skip to content

Commit

Permalink
AL-837: make wcs_from_footprints accept footprint vertices instead of…
Browse files Browse the repository at this point in the history
… a list of WCS objects
  • Loading branch information
emolter committed Oct 15, 2024
1 parent dfe1d6d commit b873279
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 149 deletions.
152 changes: 48 additions & 104 deletions src/stcal/alignment/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import functools
import logging
import re
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand All @@ -15,7 +16,6 @@
from astropy import wcs as fitswcs
from astropy.coordinates import SkyCoord
from astropy.modeling import models as astmodels
from astropy.utils.misc import isiterable
from gwcs.wcstools import wcs_from_fiducial

log = logging.getLogger(__name__)
Expand All @@ -28,6 +28,7 @@
"compute_s_region_imaging",
"compute_s_region_keyword",
"wcs_from_footprints",
"wcs_bbox_from_shape",
"reproject",
]

Expand All @@ -41,15 +42,15 @@ def _calculate_fiducial_from_spatial_footprint(
Parameters
----------
spatial_footprint : np.ndarray
A 2xN array containing the world coordinates of the WCS footprint's
An Nx2 array containing the world coordinates of the WCS footprint's
bounding box, where N is the number of bounding box positions.
Returns
-------
lon_fiducial, lat_fiducial : np.ndarray, np.ndarray
The world coordinates of the fiducial point in the output coordinate frame.
"""
lon, lat = spatial_footprint
lon, lat = spatial_footprint.T
lon, lat = np.deg2rad(lon), np.deg2rad(lat)
x = np.cos(lat) * np.cos(lon)
y = np.cos(lat) * np.sin(lon)
Expand Down Expand Up @@ -139,15 +140,16 @@ def _generate_tranform(
return transform


def _get_axis_min_and_bounding_box(wcs_list: list[gwcs.wcs.WCS],
def _get_axis_min_and_bounding_box(footprints: list[np.ndarray],
ref_wcs: gwcs.wcs.WCS) -> tuple:
"""
Calculates axis minimum values and bounding box.
Parameters
----------
wcs_list : list
The list of WCS objects.
footprints : list
A list of numpy arrays each of shape (N, 2) containing the
(RA, Dec) vertices demarcating the footprint of the input WCSs.
ref_wcs : ~gwcs.wcs.WCS
The reference WCS object.
Expand All @@ -160,8 +162,7 @@ def _get_axis_min_and_bounding_box(wcs_list: list[gwcs.wcs.WCS],
2 - a tuple containing the bounding box region in the format
((x0_lower, x0_upper), (x1_lower, x1_upper)).
"""
footprints = [w.footprint().T for w in wcs_list]
domain_bounds = np.hstack([ref_wcs.backward_transform(*f) for f in footprints])
domain_bounds = np.hstack([ref_wcs.backward_transform(*f.T) for f in footprints])
axis_min_values = np.min(domain_bounds, axis=1)
domain_bounds = (domain_bounds.T - axis_min_values).T

Expand All @@ -177,24 +178,17 @@ def _get_axis_min_and_bounding_box(wcs_list: list[gwcs.wcs.WCS],
return (axis_min_values, output_bounding_box)


def _calculate_fiducial(wcs_list: list[gwcs.wcs.WCS],
bounding_box: Sequence | None,
def _calculate_fiducial(footprints: list[np.ndarray],
crval: Sequence | None = None) -> np.ndarray:
"""
Calculates the coordinates of the fiducial point and, if necessary, updates it with
the values in CRVAL (the update is applied to spatial axes only).
Parameters
----------
wcs_list : list
A list of WCS objects.
bounding_box : tuple, or list
The bounding box over which the WCS is valid. It can be a either tuple of tuples
or a list of lists of size 2 where each element represents a range of
(low, high) values. The bounding_box is in the order of the axes, axes_order.
For two inputs and axes_order(0, 1) the bounding box can be either
((xlow, xhigh), (ylow, yhigh)) or [[xlow, xhigh], [ylow, yhigh]].
footprints : list
A list of numpy arrays each of shape (N, 2) containing the
(RA, Dec) vertices demarcating the footprint of the input WCSs.
crval : list, optional
A reference world coordinate associated with the reference pixel. If not `None`,
Expand All @@ -206,15 +200,9 @@ def _calculate_fiducial(wcs_list: list[gwcs.wcs.WCS],
fiducial : np.ndarray
A two-elements array containing the world coordinate of the fiducial point.
"""
fiducial = compute_fiducial(wcs_list, bounding_box=bounding_box)
if crval is not None:
i = 0
for k, axt in enumerate(wcs_list[0].output_frame.axes_type):
if axt == "SPATIAL":
# overwrite only spatial axes with user-provided CRVAL
fiducial[k] = crval[i]
i += 1
return fiducial
return crval
return compute_fiducial(footprints)


def _calculate_offsets(fiducial: np.ndarray,
Expand Down Expand Up @@ -265,7 +253,7 @@ def _calculate_offsets(fiducial: np.ndarray,

def _calculate_new_wcs(wcs: gwcs.wcs.WCS,
shape: Sequence | None,
wcs_list: list[gwcs.wcs.WCS],
footprints: list[np.ndarray],
fiducial: np.ndarray,
crpix: Sequence | None = None,
transform: astmodels.Model | None = None,
Expand All @@ -282,8 +270,9 @@ def _calculate_new_wcs(wcs: gwcs.wcs.WCS,
The shape of the new WCS's pixel grid. If `None`, then the output bounding box
will be used to determine it.
wcs_list : list
A list containing WCS objects.
footprints : list
A list of numpy arrays each of shape (N, 2) containing the
(RA, Dec) vertices demarcating the footprint of the input WCSs.
fiducial : np.ndarray
A two-elements array containing the location on the sky in some standard
Expand All @@ -309,7 +298,7 @@ def _calculate_new_wcs(wcs: gwcs.wcs.WCS,
transform=transform,
input_frame=wcs.input_frame,
)
axis_min_values, output_bounding_box = _get_axis_min_and_bounding_box(wcs_list, wcs_new)
axis_min_values, output_bounding_box = _get_axis_min_and_bounding_box(footprints, wcs_new)
offsets = _calculate_offsets(
fiducial=fiducial,
wcs=wcs_new,
Expand All @@ -328,43 +317,6 @@ def _calculate_new_wcs(wcs: gwcs.wcs.WCS,
return wcs_new


def _validate_wcs_list(wcs_list: list[gwcs.wcs.WCS]) -> bool:
"""
Validates wcs_list.
Parameters
----------
wcs_list : list
A list of WCS objects.
Returns
-------
bool or Exception
If wcs_list is valid, returns True. Otherwise, it will raise an error.
Raises
------
ValueError
Raised whenever wcs_list is not an iterable.
TypeError
Raised whenever wcs_list is empty or any of its content is not an
instance of WCS.
"""
if not isiterable(wcs_list):
msg = "Expected 'wcs_list' to be an iterable of WCS objects."
raise ValueError(msg)

if len(wcs_list):
if not all(isinstance(w, gwcs.WCS) for w in wcs_list):
msg = "All items in 'wcs_list' are to be instances of gwcs.wcs.WCS."
raise TypeError(msg)
else:
msg = "'wcs_list' should not be empty."
raise TypeError(msg)

return True


def compute_scale(
wcs: gwcs.wcs.WCS,
fiducial: tuple | np.ndarray,
Expand Down Expand Up @@ -430,18 +382,17 @@ def compute_scale(
return float(np.sqrt(xscale * yscale))


def compute_fiducial(wcslist: list,
bounding_box: Sequence | None = None) -> np.ndarray:
def compute_fiducial(footprints: list[np.ndarray]) -> np.ndarray:
"""
Calculates the world coordinates of the fiducial point of a list of WCS objects.
For a celestial footprint this is the center. For a spectral footprint, it is the
beginning of its range.
Parameters
----------
wcslist : list
A list containing all the WCS objects for which the fiducial is to be
calculated.
footprints : list
A list of numpy arrays each of shape (N, 2) containing the
(RA, Dec) vertices demarcating the footprint of the input WCSs.
bounding_box : tuple, list, None
The bounding box over which the WCS is valid. It can be a either tuple of tuples
Expand All @@ -460,19 +411,8 @@ def compute_fiducial(wcslist: list,
-----
This function assumes all WCSs have the same output coordinate frame.
"""
axes_types = wcslist[0].output_frame.axes_type
spatial_axes = np.array(axes_types) == "SPATIAL"
spectral_axes = np.array(axes_types) == "SPECTRAL"
footprints = np.hstack([w.footprint(bounding_box=bounding_box).T for w in wcslist])
spatial_footprint = footprints[spatial_axes]
spectral_footprint = footprints[spectral_axes]

fiducial = np.empty(len(axes_types))
if spatial_footprint.any():
fiducial[spatial_axes] = _calculate_fiducial_from_spatial_footprint(spatial_footprint)
if spectral_footprint.any():
fiducial[spectral_axes] = spectral_footprint.min()
return fiducial
spatial_footprint = np.vstack(footprints)
return _calculate_fiducial_from_spatial_footprint(spatial_footprint)


def calc_rotation_matrix(roll_ref: float, v3i_yangle: float, vparity: int = 1) -> list[float]:
Expand Down Expand Up @@ -519,12 +459,27 @@ def calc_rotation_matrix(roll_ref: float, v3i_yangle: float, vparity: int = 1) -
return [pc1_1, pc1_2, pc2_1, pc2_2]


def sregion_to_footprint(s_region):
"""
Parameters
----------
s_region : str
The S_REGION header keyword
Returns
-------
footprint : np.array
A 2D array of the footprint of the region, shape (N, 2)
"""
no_prefix = re.sub(r"[a-zA-Z]", "", s_region)
return np.array(no_prefix.split(), dtype=float).reshape(-1, 2)


def wcs_from_footprints(
wcs_list: list[gwcs.wcs.WCS],
footprints: list[np.ndarray],
ref_wcs: gwcs.wcs.WCS,
ref_wcsinfo: dict,
transform: astropy.modeling.models.Model | None = None,
bounding_box: Sequence | None = None,
pscale_ratio: float | None = None,
pscale: float | None = None,
rotation: float | None = None,
Expand All @@ -549,8 +504,9 @@ def wcs_from_footprints(
Parameters
----------
wcs_list : list
A list of valid datamodels.
footprints : list
A list of numpy arrays each of shape (N, 2) containing the
(RA, Dec) vertices demarcating the footprint of the input WCSs.
ref_wcs :
A valid datamodel whose WCS is used as reference for the creation of the output
Expand All @@ -564,10 +520,6 @@ def wcs_from_footprints(
A transform, passed to :py:func:`gwcs.wcstools.wcs_from_fiducial`
If not supplied `Scaling | Rotation` is computed from ``refmodel``.
bounding_box : tuple
Bounding_box of the new WCS.
If not supplied it is computed from the bounding_box of all inputs.
pscale_ratio : float, None
Ratio of input to output pixel scale. Ignored when either
``transform`` or ``pscale`` are provided.
Expand Down Expand Up @@ -600,21 +552,13 @@ def wcs_from_footprints(
Right ascension and declination of the reference pixel. Automatically
computed if not provided.
wcs_list : list
A list of WCS objects. If not supplied, the WCS objects are extracted
from the input datamodels.
Returns
-------
wcs_new : ~gwcs.wcs.WCS
The WCS object corresponding to the combined input footprints.
"""
_validate_wcs_list(wcs_list)

fiducial = _calculate_fiducial(wcs_list=wcs_list, bounding_box=bounding_box, crval=crval)

ref_wcs = wcs_list[0] if ref_wcs is None else ref_wcs
fiducial = _calculate_fiducial(footprints, crval=crval)

transform = _generate_tranform(
ref_wcs,
Expand All @@ -630,7 +574,7 @@ def wcs_from_footprints(
wcs=ref_wcs,
shape=shape,
crpix=crpix,
wcs_list=wcs_list,
footprints=footprints,
fiducial=fiducial,
transform=transform,
)
Expand Down
2 changes: 1 addition & 1 deletion src/stcal/tweakreg/astrometric_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def create_astrometric_catalog(

def compute_radius(wcs):
"""Compute the radius from the center to the furthest edge of the WCS."""
fiducial = compute_fiducial([wcs], wcs.bounding_box)
fiducial = compute_fiducial([wcs.footprint(),])
img_center = SkyCoord(
ra=fiducial[0] * u.degree,
dec=fiducial[1] * u.degree)
Expand Down
2 changes: 1 addition & 1 deletion src/stcal/tweakreg/tweakreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _parse_refcat(abs_refcat: str | Path,
# combine all aligned wcs to compute a new footprint to
# filter the absolute catalog sources
combined_wcs = wcs_from_footprints(
[corrector.wcs for corrector in correctors],
[corrector.wcs.footprint() for corrector in correctors],
ref_wcs=wcs,
ref_wcsinfo=wcsinfo,
)
Expand Down
Loading

0 comments on commit b873279

Please sign in to comment.