From b87327977e60c0344caf0ec14dbf25a6aba78b7b Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Tue, 15 Oct 2024 16:52:57 -0400 Subject: [PATCH] AL-837: make wcs_from_footprints accept footprint vertices instead of a list of WCS objects --- src/stcal/alignment/util.py | 152 ++++++++---------------- src/stcal/tweakreg/astrometric_utils.py | 2 +- src/stcal/tweakreg/tweakreg.py | 2 +- tests/test_alignment.py | 47 +------- 4 files changed, 54 insertions(+), 149 deletions(-) diff --git a/src/stcal/alignment/util.py b/src/stcal/alignment/util.py index 49ef0c200..bc328be5e 100644 --- a/src/stcal/alignment/util.py +++ b/src/stcal/alignment/util.py @@ -3,6 +3,7 @@ import functools import logging +import re from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -15,7 +16,6 @@ from astropy import wcs as fitswcs from astropy.coordinates import SkyCoord from astropy.modeling import models as astmodels -from astropy.utils.misc import isiterable from gwcs.wcstools import wcs_from_fiducial log = logging.getLogger(__name__) @@ -28,6 +28,7 @@ "compute_s_region_imaging", "compute_s_region_keyword", "wcs_from_footprints", + "wcs_bbox_from_shape", "reproject", ] @@ -41,7 +42,7 @@ def _calculate_fiducial_from_spatial_footprint( Parameters ---------- spatial_footprint : np.ndarray - A 2xN array containing the world coordinates of the WCS footprint's + An Nx2 array containing the world coordinates of the WCS footprint's bounding box, where N is the number of bounding box positions. Returns @@ -49,7 +50,7 @@ def _calculate_fiducial_from_spatial_footprint( lon_fiducial, lat_fiducial : np.ndarray, np.ndarray The world coordinates of the fiducial point in the output coordinate frame. """ - lon, lat = spatial_footprint + lon, lat = spatial_footprint.T lon, lat = np.deg2rad(lon), np.deg2rad(lat) x = np.cos(lat) * np.cos(lon) y = np.cos(lat) * np.sin(lon) @@ -139,15 +140,16 @@ def _generate_tranform( return transform -def _get_axis_min_and_bounding_box(wcs_list: list[gwcs.wcs.WCS], +def _get_axis_min_and_bounding_box(footprints: list[np.ndarray], ref_wcs: gwcs.wcs.WCS) -> tuple: """ Calculates axis minimum values and bounding box. Parameters ---------- - wcs_list : list - The list of WCS objects. + footprints : list + A list of numpy arrays each of shape (N, 2) containing the + (RA, Dec) vertices demarcating the footprint of the input WCSs. ref_wcs : ~gwcs.wcs.WCS The reference WCS object. @@ -160,8 +162,7 @@ def _get_axis_min_and_bounding_box(wcs_list: list[gwcs.wcs.WCS], 2 - a tuple containing the bounding box region in the format ((x0_lower, x0_upper), (x1_lower, x1_upper)). """ - footprints = [w.footprint().T for w in wcs_list] - domain_bounds = np.hstack([ref_wcs.backward_transform(*f) for f in footprints]) + domain_bounds = np.hstack([ref_wcs.backward_transform(*f.T) for f in footprints]) axis_min_values = np.min(domain_bounds, axis=1) domain_bounds = (domain_bounds.T - axis_min_values).T @@ -177,8 +178,7 @@ def _get_axis_min_and_bounding_box(wcs_list: list[gwcs.wcs.WCS], return (axis_min_values, output_bounding_box) -def _calculate_fiducial(wcs_list: list[gwcs.wcs.WCS], - bounding_box: Sequence | None, +def _calculate_fiducial(footprints: list[np.ndarray], crval: Sequence | None = None) -> np.ndarray: """ Calculates the coordinates of the fiducial point and, if necessary, updates it with @@ -186,15 +186,9 @@ def _calculate_fiducial(wcs_list: list[gwcs.wcs.WCS], Parameters ---------- - wcs_list : list - A list of WCS objects. - - bounding_box : tuple, or list - The bounding box over which the WCS is valid. It can be a either tuple of tuples - or a list of lists of size 2 where each element represents a range of - (low, high) values. The bounding_box is in the order of the axes, axes_order. - For two inputs and axes_order(0, 1) the bounding box can be either - ((xlow, xhigh), (ylow, yhigh)) or [[xlow, xhigh], [ylow, yhigh]]. + footprints : list + A list of numpy arrays each of shape (N, 2) containing the + (RA, Dec) vertices demarcating the footprint of the input WCSs. crval : list, optional A reference world coordinate associated with the reference pixel. If not `None`, @@ -206,15 +200,9 @@ def _calculate_fiducial(wcs_list: list[gwcs.wcs.WCS], fiducial : np.ndarray A two-elements array containing the world coordinate of the fiducial point. """ - fiducial = compute_fiducial(wcs_list, bounding_box=bounding_box) if crval is not None: - i = 0 - for k, axt in enumerate(wcs_list[0].output_frame.axes_type): - if axt == "SPATIAL": - # overwrite only spatial axes with user-provided CRVAL - fiducial[k] = crval[i] - i += 1 - return fiducial + return crval + return compute_fiducial(footprints) def _calculate_offsets(fiducial: np.ndarray, @@ -265,7 +253,7 @@ def _calculate_offsets(fiducial: np.ndarray, def _calculate_new_wcs(wcs: gwcs.wcs.WCS, shape: Sequence | None, - wcs_list: list[gwcs.wcs.WCS], + footprints: list[np.ndarray], fiducial: np.ndarray, crpix: Sequence | None = None, transform: astmodels.Model | None = None, @@ -282,8 +270,9 @@ def _calculate_new_wcs(wcs: gwcs.wcs.WCS, The shape of the new WCS's pixel grid. If `None`, then the output bounding box will be used to determine it. - wcs_list : list - A list containing WCS objects. + footprints : list + A list of numpy arrays each of shape (N, 2) containing the + (RA, Dec) vertices demarcating the footprint of the input WCSs. fiducial : np.ndarray A two-elements array containing the location on the sky in some standard @@ -309,7 +298,7 @@ def _calculate_new_wcs(wcs: gwcs.wcs.WCS, transform=transform, input_frame=wcs.input_frame, ) - axis_min_values, output_bounding_box = _get_axis_min_and_bounding_box(wcs_list, wcs_new) + axis_min_values, output_bounding_box = _get_axis_min_and_bounding_box(footprints, wcs_new) offsets = _calculate_offsets( fiducial=fiducial, wcs=wcs_new, @@ -328,43 +317,6 @@ def _calculate_new_wcs(wcs: gwcs.wcs.WCS, return wcs_new -def _validate_wcs_list(wcs_list: list[gwcs.wcs.WCS]) -> bool: - """ - Validates wcs_list. - - Parameters - ---------- - wcs_list : list - A list of WCS objects. - - Returns - ------- - bool or Exception - If wcs_list is valid, returns True. Otherwise, it will raise an error. - - Raises - ------ - ValueError - Raised whenever wcs_list is not an iterable. - TypeError - Raised whenever wcs_list is empty or any of its content is not an - instance of WCS. - """ - if not isiterable(wcs_list): - msg = "Expected 'wcs_list' to be an iterable of WCS objects." - raise ValueError(msg) - - if len(wcs_list): - if not all(isinstance(w, gwcs.WCS) for w in wcs_list): - msg = "All items in 'wcs_list' are to be instances of gwcs.wcs.WCS." - raise TypeError(msg) - else: - msg = "'wcs_list' should not be empty." - raise TypeError(msg) - - return True - - def compute_scale( wcs: gwcs.wcs.WCS, fiducial: tuple | np.ndarray, @@ -430,8 +382,7 @@ def compute_scale( return float(np.sqrt(xscale * yscale)) -def compute_fiducial(wcslist: list, - bounding_box: Sequence | None = None) -> np.ndarray: +def compute_fiducial(footprints: list[np.ndarray]) -> np.ndarray: """ Calculates the world coordinates of the fiducial point of a list of WCS objects. For a celestial footprint this is the center. For a spectral footprint, it is the @@ -439,9 +390,9 @@ def compute_fiducial(wcslist: list, Parameters ---------- - wcslist : list - A list containing all the WCS objects for which the fiducial is to be - calculated. + footprints : list + A list of numpy arrays each of shape (N, 2) containing the + (RA, Dec) vertices demarcating the footprint of the input WCSs. bounding_box : tuple, list, None The bounding box over which the WCS is valid. It can be a either tuple of tuples @@ -460,19 +411,8 @@ def compute_fiducial(wcslist: list, ----- This function assumes all WCSs have the same output coordinate frame. """ - axes_types = wcslist[0].output_frame.axes_type - spatial_axes = np.array(axes_types) == "SPATIAL" - spectral_axes = np.array(axes_types) == "SPECTRAL" - footprints = np.hstack([w.footprint(bounding_box=bounding_box).T for w in wcslist]) - spatial_footprint = footprints[spatial_axes] - spectral_footprint = footprints[spectral_axes] - - fiducial = np.empty(len(axes_types)) - if spatial_footprint.any(): - fiducial[spatial_axes] = _calculate_fiducial_from_spatial_footprint(spatial_footprint) - if spectral_footprint.any(): - fiducial[spectral_axes] = spectral_footprint.min() - return fiducial + spatial_footprint = np.vstack(footprints) + return _calculate_fiducial_from_spatial_footprint(spatial_footprint) def calc_rotation_matrix(roll_ref: float, v3i_yangle: float, vparity: int = 1) -> list[float]: @@ -519,12 +459,27 @@ def calc_rotation_matrix(roll_ref: float, v3i_yangle: float, vparity: int = 1) - return [pc1_1, pc1_2, pc2_1, pc2_2] +def sregion_to_footprint(s_region): + """ + Parameters + ---------- + s_region : str + The S_REGION header keyword + + Returns + ------- + footprint : np.array + A 2D array of the footprint of the region, shape (N, 2) + """ + no_prefix = re.sub(r"[a-zA-Z]", "", s_region) + return np.array(no_prefix.split(), dtype=float).reshape(-1, 2) + + def wcs_from_footprints( - wcs_list: list[gwcs.wcs.WCS], + footprints: list[np.ndarray], ref_wcs: gwcs.wcs.WCS, ref_wcsinfo: dict, transform: astropy.modeling.models.Model | None = None, - bounding_box: Sequence | None = None, pscale_ratio: float | None = None, pscale: float | None = None, rotation: float | None = None, @@ -549,8 +504,9 @@ def wcs_from_footprints( Parameters ---------- - wcs_list : list - A list of valid datamodels. + footprints : list + A list of numpy arrays each of shape (N, 2) containing the + (RA, Dec) vertices demarcating the footprint of the input WCSs. ref_wcs : A valid datamodel whose WCS is used as reference for the creation of the output @@ -564,10 +520,6 @@ def wcs_from_footprints( A transform, passed to :py:func:`gwcs.wcstools.wcs_from_fiducial` If not supplied `Scaling | Rotation` is computed from ``refmodel``. - bounding_box : tuple - Bounding_box of the new WCS. - If not supplied it is computed from the bounding_box of all inputs. - pscale_ratio : float, None Ratio of input to output pixel scale. Ignored when either ``transform`` or ``pscale`` are provided. @@ -600,21 +552,13 @@ def wcs_from_footprints( Right ascension and declination of the reference pixel. Automatically computed if not provided. - wcs_list : list - A list of WCS objects. If not supplied, the WCS objects are extracted - from the input datamodels. - Returns ------- wcs_new : ~gwcs.wcs.WCS The WCS object corresponding to the combined input footprints. """ - _validate_wcs_list(wcs_list) - - fiducial = _calculate_fiducial(wcs_list=wcs_list, bounding_box=bounding_box, crval=crval) - - ref_wcs = wcs_list[0] if ref_wcs is None else ref_wcs + fiducial = _calculate_fiducial(footprints, crval=crval) transform = _generate_tranform( ref_wcs, @@ -630,7 +574,7 @@ def wcs_from_footprints( wcs=ref_wcs, shape=shape, crpix=crpix, - wcs_list=wcs_list, + footprints=footprints, fiducial=fiducial, transform=transform, ) diff --git a/src/stcal/tweakreg/astrometric_utils.py b/src/stcal/tweakreg/astrometric_utils.py index 09864cc18..a7c54faac 100644 --- a/src/stcal/tweakreg/astrometric_utils.py +++ b/src/stcal/tweakreg/astrometric_utils.py @@ -139,7 +139,7 @@ def create_astrometric_catalog( def compute_radius(wcs): """Compute the radius from the center to the furthest edge of the WCS.""" - fiducial = compute_fiducial([wcs], wcs.bounding_box) + fiducial = compute_fiducial([wcs.footprint(),]) img_center = SkyCoord( ra=fiducial[0] * u.degree, dec=fiducial[1] * u.degree) diff --git a/src/stcal/tweakreg/tweakreg.py b/src/stcal/tweakreg/tweakreg.py index 0a11dc54d..bb9c1957a 100644 --- a/src/stcal/tweakreg/tweakreg.py +++ b/src/stcal/tweakreg/tweakreg.py @@ -241,7 +241,7 @@ def _parse_refcat(abs_refcat: str | Path, # combine all aligned wcs to compute a new footprint to # filter the absolute catalog sources combined_wcs = wcs_from_footprints( - [corrector.wcs for corrector in correctors], + [corrector.wcs.footprint() for corrector in correctors], ref_wcs=wcs, ref_wcsinfo=wcsinfo, ) diff --git a/tests/test_alignment.py b/tests/test_alignment.py index ecf223037..9a4aef403 100644 --- a/tests/test_alignment.py +++ b/tests/test_alignment.py @@ -10,7 +10,6 @@ from stcal.alignment import resample_utils from stcal.alignment.util import ( - _validate_wcs_list, compute_fiducial, compute_s_region_imaging, compute_s_region_keyword, @@ -132,8 +131,8 @@ def test_compute_fiducial(): pscale = (0.000014, 0.000014) # in deg/pixel wcs = _create_wcs_object_without_distortion(fiducial_world=fiducial_world, shape=shape, pscale=pscale) - - computed_fiducial = compute_fiducial([wcs]) + footprint = wcs.footprint() + computed_fiducial = compute_fiducial([footprint]) assert all(np.isclose(wcs(1, 1), computed_fiducial)) @@ -178,8 +177,8 @@ def test_wcs_from_footprints(): ) dm_2 = _create_wcs_and_datamodel(fiducial_world, shape, pscale) wcs_2 = dm_2.meta.wcs - - wcs = wcs_from_footprints([wcs_1, wcs_2], wcs_1, dm_1.meta.wcsinfo.instance) + footprints = [wcs_1.footprint(), wcs_2.footprint()] + wcs = wcs_from_footprints(footprints, wcs_1, dm_1.meta.wcsinfo.instance) # check that all elements of footprint match the *vertices* of the new combined WCS assert all(np.isclose(wcs.footprint()[0], wcs(0, 0))) @@ -192,44 +191,6 @@ def test_wcs_from_footprints(): assert all(np.isclose(wcs_2(0, 0), wcs(3.5, 0.5))) -def test_validate_wcs_list(): - shape = (3, 3) # in pixels - fiducial_world = (10, 0) # in deg - pscale = (0.000028, 0.000028) # in deg/pixel - - dm_1 = _create_wcs_and_datamodel(fiducial_world, shape, pscale) - wcs_1 = dm_1.meta.wcs - - # shift fiducial by one pixel in both directions and create a new WCS - fiducial_world = ( - fiducial_world[0] - 0.000028, - fiducial_world[1] - 0.000028, - ) - dm_2 = _create_wcs_and_datamodel(fiducial_world, shape, pscale) - wcs_2 = dm_2.meta.wcs - - wcs_list = [wcs_1, wcs_2] - - assert _validate_wcs_list(wcs_list) - - -@pytest.mark.parametrize( - ("wcs_list", "expected_error"), - [ - ([], TypeError), - ([1, 2, 3], TypeError), - (["1", "2", "3"], TypeError), - (["1", None, []], TypeError), - ("1", TypeError), - (1, ValueError), - (None, ValueError), - ], -) -def test_validate_wcs_list_invalid(wcs_list, expected_error): - with pytest.raises(expected_error, match=r".*"): - _validate_wcs_list(wcs_list) - - def get_fake_wcs(): fake_wcs1 = fitswcs.WCS( fits.Header(