diff --git a/.gitignore b/.gitignore index 62371214..c4fc59e8 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,8 @@ coverage.xml .hypothesis/ .pytest_cache/ cover/ +coverage/ +cov.xml # Translations *.mo @@ -86,7 +88,7 @@ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: -# .python-version +.python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. @@ -143,4 +145,7 @@ src/stcal/_version.py # auto-generated API docs docs/source/api -.DS_Store \ No newline at end of file +.DS_Store + +# VSCode stuff +.vscode \ No newline at end of file diff --git a/CHANGES.rst b/CHANGES.rst index 3064e2b5..66ba6c1f 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,6 +1,8 @@ 1.4.5 (unreleased) ================== +- Added ``alignment`` sub-package. [#179] + Changes to API -------------- @@ -88,6 +90,17 @@ Bug Fixes jump ~~~~ +- Added setting of number_extended_events for non-multiprocessing + mode. This is the value that is put into the header keyword EXTNCRS. [#178] + +1.4.1 (2023-06-29) + +Bug Fixes +--------- + +jump +~~~~ + - Added statement to prevent the number of cores used in multiprocessing from being larger than the number of rows. This was causing some CI tests to fail. [#176] diff --git a/docs/Makefile b/docs/Makefile index bcca5213..1235f237 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -6,6 +6,7 @@ SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build BUILDDIR = _build +APIDIR = api # Internal variables ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . @@ -25,6 +26,7 @@ help: clean: -rm -rf $(BUILDDIR)/* + -rm -rf $(APIDIR)/* html: $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html diff --git a/docs/api.rst b/docs/api.rst index 95659c47..a23b1894 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1,4 +1,4 @@ -stcall API +stcal API ========== .. automodapi:: stcal diff --git a/docs/conf.py b/docs/conf.py index fbadfd5e..fe87bb19 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -4,11 +4,12 @@ from pathlib import Path import stsci_rtd_theme + if sys.version_info < (3, 11): import tomli as tomllib else: import tomllib - + def setup(app): try: @@ -27,7 +28,7 @@ def setup(app): # values here: with open(REPO_ROOT / "pyproject.toml", "rb") as configuration_file: conf = tomllib.load(configuration_file) -setup_metadata = conf['project'] +setup_metadata = conf["project"] project = setup_metadata["name"] primary_author = setup_metadata["authors"][0] @@ -38,9 +39,32 @@ def setup(app): version = package.__version__.split("-", 1)[0] release = package.__version__ +# Configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = { + "python": ("https://docs.python.org/3/", None), + "numpy": ("https://numpy.org/devdocs", None), + "scipy": ("http://scipy.github.io/devdocs", None), + "matplotlib": ("https://matplotlib.org/stable", None), + "gwcs": ("https://gwcs.readthedocs.io/en/latest/", None), + "astropy": ("https://docs.astropy.org/en/stable/", None), +} + extensions = [ + "pytest_doctestplus.sphinx.doctestplus", + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.inheritance_diagram", + "sphinx.ext.viewcode", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", "sphinx_automodapi.automodapi", - "numpydoc", + "sphinx_automodapi.automodsumm", + "sphinx_automodapi.autodoc_enhancements", + "sphinx_automodapi.smart_resolver", + "sphinx_asdf", + "sphinx.ext.mathjax", ] autosummary_generate = True @@ -48,9 +72,7 @@ def setup(app): autoclass_content = "both" html_theme = "stsci_rtd_theme" -html_theme_options = { - "collapse_navigation": True -} +html_theme_options = {"collapse_navigation": True} html_theme_path = [stsci_rtd_theme.get_html_theme_path()] html_domain_indices = True html_sidebars = {"**": ["globaltoc.html", "relations.html", "searchbox.html"]} diff --git a/docs/stcal/alignment/description.rst b/docs/stcal/alignment/description.rst new file mode 100644 index 00000000..a537e476 --- /dev/null +++ b/docs/stcal/alignment/description.rst @@ -0,0 +1,4 @@ +Description +============ + +This sub-package contains all the modules common to all missions. \ No newline at end of file diff --git a/docs/stcal/alignment/index.rst b/docs/stcal/alignment/index.rst new file mode 100644 index 00000000..e8d65068 --- /dev/null +++ b/docs/stcal/alignment/index.rst @@ -0,0 +1,12 @@ +.. _alignment: + +=============== +Alignment Utils +=============== + +.. toctree:: + :maxdepth: 2 + + description.rst + +.. automodapi:: stcal.alignment diff --git a/docs/stcal/package_index.rst b/docs/stcal/package_index.rst index 44295cd1..b68f11b5 100644 --- a/docs/stcal/package_index.rst +++ b/docs/stcal/package_index.rst @@ -6,3 +6,4 @@ Package Index jump/index.rst ramp_fitting/index.rst + alignment/index.rst \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cf6338ff..d07ed759 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,8 @@ dependencies = [ 'scipy >=1.6.0', 'numpy >=1.20', 'opencv-python-headless >=4.6.0.66', + 'asdf >=2.15.0', + 'gwcs >= 0.18.1', ] dynamic = ['version'] @@ -23,11 +25,12 @@ dynamic = ['version'] docs = [ 'numpydoc', 'packaging >=17', - 'sphinx', + 'sphinx<7.0.0', + 'sphinx-asdf', 'sphinx-astropy', 'sphinx-rtd-theme', 'stsci-rtd-theme', - 'tomli; python_version <"3.11"', + 'tomli; python_version <="3.11"', ] test = [ 'psutil', diff --git a/src/stcal/alignment/__init__.py b/src/stcal/alignment/__init__.py new file mode 100644 index 00000000..e870d4bd --- /dev/null +++ b/src/stcal/alignment/__init__.py @@ -0,0 +1 @@ +from .util import * # noqa: F403 diff --git a/src/stcal/alignment/resample_utils.py b/src/stcal/alignment/resample_utils.py new file mode 100644 index 00000000..7a0c04d3 --- /dev/null +++ b/src/stcal/alignment/resample_utils.py @@ -0,0 +1,40 @@ +import logging +import numpy as np +from stcal.alignment import util +from gwcs.wcstools import grid_from_bounding_box + +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + + +def calc_pixmap(in_wcs, out_wcs, shape=None): + """Return a pixel grid map from input frame to output frame + + Parameters + ---------- + in_wcs : `~astropy.wcs.WCS` + Input WCS objects or transforms. + out_wcs : `~astropy.wcs.WCS` or `~gwcs.wcs.WCS` + output WCS objects or transforms. + shape : tuple, optional + Shape of grid in pixels. The default is None. + + Returns + ------- + pixmap : ndarray of shape (xdim, ydim, 2) + Reprojected pixel grid map. `pixmap[xin, yin]` returns `xout, + yout` indices in the output image. + """ + if shape: + bb = util.wcs_bbox_from_shape(shape) + log.debug("Bounding box from data shape: {}".format(bb)) + else: + bb = util.wcs_bbox_from_shape(in_wcs.pixel_shape) + log.debug("Bounding box from WCS: {}".format(bb)) + + # creates 2 grids, one with rows of all x values * len(y) rows, + # and the reverse for all y columns + grid = grid_from_bounding_box(bb) + transform_function = util.reproject(in_wcs, out_wcs) + pixmap = np.dstack(transform_function(grid[0], grid[1])) + return pixmap diff --git a/src/stcal/alignment/util.py b/src/stcal/alignment/util.py new file mode 100644 index 00000000..762be702 --- /dev/null +++ b/src/stcal/alignment/util.py @@ -0,0 +1,882 @@ +""" +Common utility functions for datamodel alignment. + +""" +import logging +import functools +from typing import List, Protocol, Union + +import numpy as np + +from astropy.coordinates import SkyCoord +from astropy.utils.misc import isiterable +from astropy import units as u +from astropy.modeling import models as astmodels +from astropy import wcs as fitswcs + +from asdf import AsdfFile +import gwcs +from gwcs.wcstools import wcs_from_fiducial + + +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + + +__all__ = [ + "compute_scale", + "compute_fiducial", + "calc_rotation_matrix", + "wcs_from_footprints", + "reproject", +] + + +class SupportsDataWithWcs(Protocol): + _asdf: AsdfFile + + def to_flat_dict(): + ... + + +def _calculate_fiducial_from_spatial_footprint( + spatial_footprint: np.ndarray, +) -> np.ndarray: + """ + Calculates the fiducial coordinates from a given spatial footprint. + + Parameters + ---------- + spatial_footprint : numpy.ndarray + A 2xN array containing the world coordinates of the WCS footprint's + bounding box, where N is the number of bounding box positions. + + Returns + ------- + lon_fiducial, lat_fiducial : numpy.ndarray, numpy.ndarray + The world coordinates of the fiducial point in the output coordinate frame. + """ + lon, lat = spatial_footprint + lon, lat = np.deg2rad(lon), np.deg2rad(lat) + x = np.cos(lat) * np.cos(lon) + y = np.cos(lat) * np.sin(lon) + z = np.sin(lat) + + x_mid = (np.max(x) + np.min(x)) / 2.0 + y_mid = (np.max(y) + np.min(y)) / 2.0 + z_mid = (np.max(z) + np.min(z)) / 2.0 + lon_fiducial = np.rad2deg(np.arctan2(y_mid, x_mid)) % 360.0 + lat_fiducial = np.rad2deg( + np.arctan2(z_mid, np.sqrt(x_mid**2 + y_mid**2)) + ) + return lon_fiducial, lat_fiducial + + +def _generate_tranform( + refmodel: SupportsDataWithWcs, + ref_fiducial: np.array, + pscale_ratio: int = None, + pscale: float = None, + rotation: float = None, + transform=None, +): + """ + Creates a transform from pixel to world coordinates based on a + reference datamodel's WCS. + + Parameters + ---------- + refmodel : + The datamodel that should be used as reference for calculating the + transform parameters. + + pscale_ratio : int, None + Ratio of input to output pixel scale. This parameter is only used when + ``pscale=None`` and, in that case, it is passed on to ``compute_scale``. + + pscale : float, None + The plate scale. If `None`, the plate scale is calculated from the reference + datamodel. + + rotation : float, None + Position angle of output image's Y-axis relative to North. + A value of 0.0 would orient the final output image to be North up. + The default of `None` specifies that the images will not be rotated, + but will instead be resampled in the default orientation for the camera + with the x and y axes of the resampled image corresponding + approximately to the detector axes. Ignored when ``transform`` is + provided. If `None`, the rotation angle is extracted from the + reference model's ``meta.wcsinfo.roll_ref``. + + ref_fiducial : numpy.array + A two-elements array containing the world coordinates of the fiducial point. + + transform : ~astropy.modeling.Model + A transform between frames. + + Returns + ------- + transform : ~astropy.modeling.Model + An :py:mod:`~astropy` model containing the transform between frames. + """ + if transform is None: + sky_axes = refmodel.meta.wcs._get_axes_indices().tolist() + v3yangle = np.deg2rad(refmodel.meta.wcsinfo.v3yangle) + vparity = refmodel.meta.wcsinfo.vparity + if rotation is None: + roll_ref = np.deg2rad(refmodel.meta.wcsinfo.roll_ref) + else: + roll_ref = np.deg2rad(rotation) + (vparity * v3yangle) + + # reshape the rotation matrix returned from calc_rotation_matrix + # into the correct shape for constructing the transformation + pc = np.reshape( + calc_rotation_matrix(roll_ref, v3yangle, vparity=vparity), (2, 2) + ) + + rotation = astmodels.AffineTransformation2D( + pc, name="pc_rotation_matrix" + ) + transform = [rotation] + if sky_axes: + if not pscale: + pscale = compute_scale( + refmodel.meta.wcs, ref_fiducial, pscale_ratio=pscale_ratio + ) + transform.append( + astmodels.Scale(pscale, name="cdelt1") + & astmodels.Scale(pscale, name="cdelt2") + ) + + if transform: + transform = functools.reduce(lambda x, y: x | y, transform) + + return transform + + +def _get_axis_min_and_bounding_box(ref_model, wcs_list, ref_wcs): + """ + Calculates axis mininum values and bounding box. + + Parameters + ---------- + ref_model : + The reference datamodel for which to determine the minimum axis values and + bounding box. + + wcs_list : list + The list of WCS objects. + + ref_wcs : ~gwcs.wcs.WCS + The reference WCS object. + + Returns + ------- + tuple + A tuple containing two elements: + 1 - a :py:class:`numpy.ndarray` with the minimum value in each axis; + 2 - a tuple containing the bounding box region in the format + ((x0_lower, x0_upper), (x1_lower, x1_upper)). + """ + footprints = [w.footprint().T for w in wcs_list] + domain_bounds = np.hstack( + [ref_wcs.backward_transform(*f) for f in footprints] + ) + axis_min_values = np.min(domain_bounds, axis=1) + domain_bounds = (domain_bounds.T - axis_min_values).T + + output_bounding_box = [] + for axis in ref_model.meta.wcs.output_frame.axes_order: + axis_min, axis_max = ( + domain_bounds[axis].min(), + domain_bounds[axis].max(), + ) + # populate output_bounding_box + output_bounding_box.append((axis_min, axis_max)) + + output_bounding_box = tuple(output_bounding_box) + return (axis_min_values, output_bounding_box) + + +def _calculate_fiducial(wcs_list, bounding_box, crval=None): + """ + Calculates the coordinates of the fiducial point and, if necessary, updates it with + the values in CRVAL (the update is applied to spatial axes only). + + Parameters + ---------- + wcs_list : list + A list of WCS objects. + + bounding_box : tuple, or list, optional + The bounding box over which the WCS is valid. It can be a either tuple of tuples + or a list of lists of size 2 where each element represents a range of + (low, high) values. The bounding_box is in the order of the axes, axes_order. + For two inputs and axes_order(0, 1) the bounding box can be either + ((xlow, xhigh), (ylow, yhigh)) or [[xlow, xhigh], [ylow, yhigh]]. + + crval : list, optional + A reference world coordinate associated with the reference pixel. If not `None`, + then the fiducial coordinates of the spatial axes will be updated with the + values from ``crval``. + + Returns + ------- + fiducial : numpy.ndarray + A two-elements array containing the world coordinate of the fiducial point. + """ + fiducial = compute_fiducial(wcs_list, bounding_box=bounding_box) + if crval is not None: + i = 0 + for k, axt in enumerate(wcs_list[0].output_frame.axes_type): + if axt == "SPATIAL": + # overwrite only spatial axes with user-provided CRVAL + fiducial[k] = crval[i] + i += 1 + return fiducial + + +def _calculate_offsets(fiducial, wcs, axis_min_values, crpix): + """ + Calculates the offsets to the transform. + + Parameters + ---------- + fiducial : numpy.ndarray + A two-elements containing the world coordinates of the fiducial point. + + wcs : ~gwcs.wcs.WCS + A WCS object. It will be used to determine the + + axis_min_values : numpy.ndarray + A two-elements array containing the minimum pixel value for each axis. + + crpix : list or tuple + Pixel coordinates of the reference pixel. + + Returns + ------- + ~astropy.modeling.Model + A model with the offsets to be added to the WCS's transform. + + Notes + ----- + If ``crpix=None``, then ``fiducial``, ``wcs``, and ``axis_min_values`` must be + provided, in which case, the offsets will be calculated using the WCS object to + find the pixel coordinates of the fiducial point and then correct it by the minimum + pixel value for each axis. + """ + if ( + crpix is None + and fiducial is not None + and wcs is not None + and axis_min_values is not None + ): + offset1, offset2 = wcs.backward_transform(*fiducial) + offset1 -= axis_min_values[0] + offset2 -= axis_min_values[1] + else: + offset1, offset2 = crpix + + return astmodels.Shift(-offset1, name="crpix1") & astmodels.Shift( + -offset2, name="crpix2" + ) + + +def _calculate_new_wcs( + ref_model, shape, wcs_list, fiducial, crpix=None, transform=None +): + """ + Calculates a new WCS object based on the combined WCS objects provided. + + Parameters + ---------- + ref_model : + The reference model to be used when extracting metadata. + + shape : list + The shape of the new WCS's pixel grid. If `None`, then the output bounding box + will be used to determine it. + + wcs_list : list + A list containing WCS objects. + + fiducial : numpy.ndarray + A two-elements array containing the location on the sky in some standard + coordinate system. + + crpix : tuple, optional + The coordinates of the reference pixel. + + transform : ~astropy.modeling.Model + An optional tranform to be prepended to the transform constructed by the + fiducial point. The number of outputs of this transform must equal the number + of axes in the coordinate frame. + + Returns + ------- + wcs_new : ~gwcs.wcs.WCS + The new WCS object that corresponds to the combined WCS objects in `wcs_list`. + """ + wcs_new = wcs_from_fiducial( + fiducial, + coordinate_frame=ref_model.meta.wcs.output_frame, + projection=astmodels.Pix2Sky_TAN(), + transform=transform, + input_frame=ref_model.meta.wcs.input_frame, + ) + axis_min_values, output_bounding_box = _get_axis_min_and_bounding_box( + ref_model, wcs_list, wcs_new + ) + offsets = _calculate_offsets( + fiducial=fiducial, + wcs=wcs_new, + axis_min_values=axis_min_values, + crpix=crpix, + ) + + wcs_new.insert_transform("detector", offsets, after=True) + wcs_new.bounding_box = output_bounding_box + + if shape is None: + shape = [ + int(axs[1] - axs[0] + 0.5) for axs in output_bounding_box[::-1] + ] + + wcs_new.pixel_shape = shape[::-1] + wcs_new.array_shape = shape + return wcs_new + + +def _validate_wcs_list(wcs_list): + """ + Validates wcs_list. + + Parameters + ---------- + wcs_list : list + A list of WCS objects. + + Returns + ------- + bool or Exception + If wcs_list is valid, returns True. Otherwise, it will raise an error. + + Raises + ------ + ValueError + Raised whenever wcs_list is not an iterable. + TypeError + Raised whenever wcs_list is empty or any of its content is not an + instance of WCS. + """ + if not isiterable(wcs_list): + raise ValueError( + "Expected 'wcs_list' to be an iterable of WCS objects." + ) + elif len(wcs_list): + if not all(isinstance(w, gwcs.WCS) for w in wcs_list): + raise TypeError( + "All items in 'wcs_list' are to be instances of gwcs.wcs.WCS." + ) + else: + raise TypeError("'wcs_list' should not be empty.") + + return True + + +def wcsinfo_from_model(input_model: SupportsDataWithWcs): + """ + Creates a dict {wcs_keyword: array_of_values} pairs from a datamodel. + + Parameters + ---------- + input_model : + The input datamodel. + + Returns + ------- + wcsinfo : dict + A dict containing the WCS FITS keywords and corresponding values. + + """ + defaults = { + "CRPIX": 0, + "CRVAL": 0, + "CDELT": 1.0, + "CTYPE": "", + "CUNIT": u.Unit(""), + } + wcsaxes = input_model.meta.wcsinfo.wcsaxes + wcsinfo = {"WCSAXES": wcsaxes} + for key in ["CRPIX", "CRVAL", "CDELT", "CTYPE", "CUNIT"]: + val = [] + for ax in range(1, wcsaxes + 1): + k = (key + "{0}".format(ax)).lower() + v = getattr(input_model.meta.wcsinfo, k, defaults[key]) + val.append(v) + wcsinfo[key] = np.array(val) + + pc = np.zeros((wcsaxes, wcsaxes), dtype=np.float32) + for i in range(1, wcsaxes + 1): + for j in range(1, wcsaxes + 1): + pc[i - 1, j - 1] = getattr( + input_model.meta.wcsinfo, "pc{0}_{1}".format(i, j), 1 + ) + wcsinfo["PC"] = pc + wcsinfo["RADESYS"] = input_model.meta.coordinates.reference_frame + wcsinfo["has_cd"] = False + return wcsinfo + + +def compute_scale( + wcs: gwcs.WCS, + fiducial: Union[tuple, np.ndarray], + disp_axis: int = None, + pscale_ratio: float = None, +) -> float: + """Compute the scale at the fiducial point on the detector.. + + Parameters + ---------- + wcs : ~gwcs.wcs.WCS + Reference WCS object from which to compute a scaling factor. + + fiducial : tuple + Input fiducial of (RA, DEC) or (RA, DEC, Wavelength) used in calculating + reference points. + + disp_axis : int + Dispersion axis integer. Assumes the same convention as + ``wcsinfo.dispersion_direction`` + + pscale_ratio : int + Ratio of input to output pixel scale + + Returns + ------- + scale : float + Scaling factor for x and y or cross-dispersion direction. + + """ + spectral = "SPECTRAL" in wcs.output_frame.axes_type + + if spectral and disp_axis is None: + raise ValueError("If input WCS is spectral, a disp_axis must be given") + + crpix = np.array(wcs.invert(*fiducial)) + + delta = np.zeros_like(crpix) + spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == "SPATIAL")[0] + delta[spatial_idx[0]] = 1 + + crpix_with_offsets = np.vstack( + (crpix, crpix + delta, crpix + np.roll(delta, 1)) + ).T + crval_with_offsets = wcs(*crpix_with_offsets, with_bounding_box=False) + + coords = SkyCoord( + ra=crval_with_offsets[spatial_idx[0]], + dec=crval_with_offsets[spatial_idx[1]], + unit="deg", + ) + xscale = np.abs(coords[0].separation(coords[1]).value) + yscale = np.abs(coords[0].separation(coords[2]).value) + + if pscale_ratio is not None: + xscale *= pscale_ratio + yscale *= pscale_ratio + + if spectral: + # Assuming scale doesn't change with wavelength + # Assuming disp_axis is consistent with DataModel.meta.wcsinfo.dispersion.direction + return yscale if disp_axis == 1 else xscale + + return np.sqrt(xscale * yscale) + + +def compute_fiducial(wcslist: list, bounding_box=None) -> np.ndarray: + """ + Calculates the world coordinates of the fiducial point of a list of WCS objects. + For a celestial footprint this is the center. For a spectral footprint, it is the + beginning of its range. + + Parameters + ---------- + wcslist : list + A list containing all the WCS objects for which the fiducial is to be + calculated. + + bounding_box : tuple, list, None + The bounding box over which the WCS is valid. It can be a either tuple of tuples + or a list of lists of size 2 where each element represents a range of + (low, high) values. The bounding_box is in the order of the axes, axes_order. + For two inputs and axes_order(0, 1) the bounding box can be either + ((xlow, xhigh), (ylow, yhigh)) or [[xlow, xhigh], [ylow, yhigh]]. + + Returns + ------- + fiducial : numpy.ndarray + A two-elements array containing the world coordinates of the fiducial point + in the combined output coordinate frame. + + Notes + ----- + This function assumes all WCSs have the same output coordinate frame. + """ + + axes_types = wcslist[0].output_frame.axes_type + spatial_axes = np.array(axes_types) == "SPATIAL" + spectral_axes = np.array(axes_types) == "SPECTRAL" + footprints = np.hstack( + [w.footprint(bounding_box=bounding_box).T for w in wcslist] + ) + spatial_footprint = footprints[spatial_axes] + spectral_footprint = footprints[spectral_axes] + + fiducial = np.empty(len(axes_types)) + if spatial_footprint.any(): + fiducial[spatial_axes] = _calculate_fiducial_from_spatial_footprint( + spatial_footprint + ) + if spectral_footprint.any(): + fiducial[spectral_axes] = spectral_footprint.min() + return fiducial + + +def calc_rotation_matrix( + roll_ref: float, v3i_yangle: float, vparity: int = 1 +) -> List[float]: + """Calculate the rotation matrix. + + Parameters + ---------- + roll_ref : float + Telescope roll angle of V3 North over East at the ref. point in radians + + v3i_yangle : float + The angle between ideal Y-axis and V3 in radians. + + vparity : int + The x-axis parity, usually taken from the JWST SIAF parameter VIdlParity. + Value should be "1" or "-1". + + Returns + ------- + matrix: list + A list containing the rotation matrix elements in column order. + + Notes + ----- + The rotation matrix is + + .. math:: + PC = \\begin{bmatrix} + pc_{1,1} & pc_{2,1} \\\\ + pc_{1,2} & pc_{2,2} + \\end{bmatrix} + """ + if vparity not in (1, -1): + raise ValueError(f"vparity should be 1 or -1. Input was: {vparity}") + + rel_angle = roll_ref - (vparity * v3i_yangle) + + pc1_1 = vparity * np.cos(rel_angle) + pc1_2 = np.sin(rel_angle) + pc2_1 = vparity * -np.sin(rel_angle) + pc2_2 = np.cos(rel_angle) + + return [pc1_1, pc1_2, pc2_1, pc2_2] + + +def wcs_from_footprints( + dmodels, + refmodel=None, + transform=None, + bounding_box=None, + pscale_ratio=None, + pscale=None, + rotation=None, + shape=None, + crpix=None, + crval=None, +): + """ + Create a WCS from a list of input datamodels. + + A fiducial point in the output coordinate frame is created from the + footprints of all WCS objects. For a spatial frame this is the center + of the union of the footprints. For a spectral frame the fiducial is in + the beginning of the footprint range. + If ``refmodel`` is None, the first WCS object in the list is considered + a reference. The output coordinate frame and projection (for celestial frames) + is taken from ``refmodel``. + If ``transform`` is not supplied, a compound transform is created using + CDELTs and PC. + If ``bounding_box`` is not supplied, the `bounding_box` of the new WCS is computed + from `bounding_box` of all input WCSs. + + Parameters + ---------- + dmodels : list + A list of valid datamodels. + + refmodel : + A valid datamodel whose WCS is used as reference for the creation of the output + coordinate frame, projection, and scaling and rotation transforms. + If not supplied the first model in the list is used as ``refmodel``. + + transform : ~astropy.modeling.Model + A transform, passed to :py:func:`gwcs.wcstools.wcs_from_fiducial` + If not supplied `Scaling | Rotation` is computed from ``refmodel``. + + bounding_box : tuple + Bounding_box of the new WCS. + If not supplied it is computed from the bounding_box of all inputs. + + pscale_ratio : float, None + Ratio of input to output pixel scale. Ignored when either + ``transform`` or ``pscale`` are provided. + + pscale : float, None + Absolute pixel scale in degrees. When provided, overrides + ``pscale_ratio``. Ignored when ``transform`` is provided. + + rotation : float, None + Position angle of output image's Y-axis relative to North. + A value of 0.0 would orient the final output image to be North up. + The default of `None` specifies that the images will not be rotated, + but will instead be resampled in the default orientation for the camera + with the x and y axes of the resampled image corresponding + approximately to the detector axes. Ignored when ``transform`` is + provided. + + shape : tuple of int, None + Shape of the image (data array) using ``numpy.ndarray`` convention + (``ny`` first and ``nx`` second). This value will be assigned to + ``pixel_shape`` and ``array_shape`` properties of the returned + WCS object. + + crpix : tuple of float, None + Position of the reference pixel in the image array. If ``crpix`` is not + specified, it will be set to the center of the bounding box of the + returned WCS object. + + crval : tuple of float, None + Right ascension and declination of the reference pixel. Automatically + computed if not provided. + + Returns + ------- + wcs_new : ~gwcs.wcs.WCS + The WCS object corresponding to the combined input footprints. + + """ + + wcs_list = [im.meta.wcs for im in dmodels] + + _validate_wcs_list(wcs_list) + + fiducial = _calculate_fiducial( + wcs_list=wcs_list, bounding_box=bounding_box, crval=crval + ) + + refmodel = dmodels[0] if refmodel is None else refmodel + + transform = _generate_tranform( + refmodel=refmodel, + pscale_ratio=pscale_ratio, + pscale=pscale, + rotation=rotation, + ref_fiducial=np.array( + [refmodel.meta.wcsinfo.ra_ref, refmodel.meta.wcsinfo.dec_ref] + ), + transform=transform, + ) + + return _calculate_new_wcs( + ref_model=refmodel, + shape=shape, + crpix=crpix, + wcs_list=wcs_list, + fiducial=fiducial, + transform=transform, + ) + + +def update_s_region_imaging(model, center=True): + """ + Update the ``S_REGION`` keyword using ``WCS.footprint``. + + Parameters + ---------- + model : + The input datamodel. + center : bool, optional + Whether or not to use the center of the pixel as reference for the + coordinates, by default True + """ + + bbox = model.meta.wcs.bounding_box + + if bbox is None: + bbox = wcs_bbox_from_shape(model.data.shape) + + # footprint is an array of shape (2, 4) as we + # are interested only in the footprint on the sky + ### TODO: we shouldn't use center=True in the call below because we want to + ### calculate the coordinates of the footprint based on the *bounding box*, + ### which means we are interested in each pixel's vertice, not its center. + ### By using center=True, a difference of 0.5 pixel should be accounted for + ### when comparing the world coordinates of the bounding box and the footprint. + footprint = model.meta.wcs.footprint( + bbox, center=center, axis_type="spatial" + ).T + # take only imaging footprint + footprint = footprint[:2, :] + + # Make sure RA values are all positive + negative_ind = footprint[0] < 0 + if negative_ind.any(): + footprint[0][negative_ind] = 360 + footprint[0][negative_ind] + + footprint = footprint.T + update_s_region_keyword(model, footprint) + + +def wcs_bbox_from_shape(shape): + """Create a bounding box from the shape of the data. + + This is appropriate to attach to a wcs object + Parameters + ---------- + shape : tuple + The shape attribute from a `numpy.ndarray` array + + Returns + ------- + bbox : tuple + Bounding box in x, y order. + """ + return (-0.5, shape[-1] - 0.5), (-0.5, shape[-2] - 0.5) + + +def update_s_region_keyword(model, footprint): + """Update the S_REGION keyword. + + Parameters + ---------- + model : + The input model + footprint : numpy.array + A 4x2 numpy array containing the coordinates of the vertices of the footprint. + + Returns + ------- + s_region : str + String containing the S_REGION object. + """ + s_region = ( + "POLYGON ICRS " + " {0:.9f} {1:.9f}" + " {2:.9f} {3:.9f}" + " {4:.9f} {5:.9f}" + " {6:.9f} {7:.9f}".format(*footprint.flatten()) + ) + if "nan" in s_region: + # do not update s_region if there are NaNs. + log.info("There are NaNs in s_region, S_REGION not updated.") + else: + model.meta.wcsinfo.s_region = s_region + log.info(f"Update S_REGION to {model.meta.wcsinfo.s_region}") + + +def reproject(wcs1, wcs2): + """ + Given two WCSs or transforms return a function which takes pixel + coordinates in the first WCS or transform and computes them in pixel coordinates + in the second one. It performs the forward transformation of ``wcs1`` followed by the + inverse of ``wcs2``. + + Parameters + ---------- + wcs1 : astropy.wcs.WCS or gwcs.wcs.WCS + Input WCS objects or transforms. + wcs2 : astropy.wcs.WCS or gwcs.wcs.WCS + Output WCS objects or transforms. + + Returns + ------- + Function to compute the transformations. It takes x, y + positions in ``wcs1`` and returns x, y positions in ``wcs2``. + """ + + def _get_forward_transform_func(wcs1): + """Get the forward transform function from the input WCS. If the wcs is a + fitswcs.WCS object all_pix2world requres three inputs, the x (str, ndarrray), + y (str, ndarray), and origin (int). The origin should be between 0, and 1 + https://docs.astropy.org/en/latest/wcs/index.html#loading-wcs-information-from-a-fits-file + ) + """ # noqa : E501 + if isinstance(wcs1, fitswcs.WCS): + forward_transform = wcs1.all_pix2world + elif isinstance(wcs1, gwcs.WCS): + forward_transform = wcs1.forward_transform + else: + raise TypeError( + "Expected input to be astropy.wcs.WCS or gwcs.WCS " "object" + ) + return forward_transform + + def _get_backward_transform_func(wcs2): + if isinstance(wcs2, fitswcs.WCS): + backward_transform = wcs2.all_world2pix + elif isinstance(wcs2, gwcs.WCS): + backward_transform = wcs2.backward_transform + else: + raise TypeError( + "Expected input to be astropy.wcs.WCS or gwcs.WCS " "object" + ) + return backward_transform + + def _reproject( + x: Union[float, np.ndarray], y: Union[float, np.ndarray] + ) -> tuple: + """ + Reprojects the input coordinates from one WCS to another. + + Parameters: + ----------- + x : float or np.ndarray + x-coordinate(s) to be reprojected. + y : float or np.ndarray + y-coordinate(s) to be reprojected. + + Returns: + -------- + tuple + Tuple of np.ndarrays including reprojected x and y coordinates. + """ + # example inputs to resulting function (12, 13, 0) # third number is origin + # uses np.arrays for shape functionality + if not isinstance(x, (np.ndarray)): + x = np.array(x) + if not isinstance(y, (np.ndarray)): + y = np.array(y) + if x.shape != y.shape: + raise ValueError("x and y must be the same length") + sky = _get_forward_transform_func(wcs1)(x, y, 0) + + # rearrange into array including flattened x and y vaues + flat_sky = [] + for axis in sky: + flat_sky.append(axis.flatten()) + det = np.array( + _get_backward_transform_func(wcs2)(flat_sky[0], flat_sky[1], 0) + ) + det_reshaped = [] + for axis in det: + det_reshaped.append(axis.reshape(x.shape)) + return tuple(det_reshaped) + + return _reproject diff --git a/tests/test_alignment.py b/tests/test_alignment.py new file mode 100644 index 00000000..80af4d7e --- /dev/null +++ b/tests/test_alignment.py @@ -0,0 +1,381 @@ +import numpy as np + +from astropy.modeling import models +from astropy import coordinates as coord +from astropy import units as u +from astropy.io import fits + +from astropy import wcs as fitswcs +import gwcs +from gwcs import coordinate_frames as cf + +import pytest +from stcal.alignment import resample_utils +from stcal.alignment.util import ( + compute_fiducial, + compute_scale, + wcs_from_footprints, + _validate_wcs_list, + update_s_region_keyword, + wcs_bbox_from_shape, + update_s_region_imaging, + reproject, +) + + +def _create_wcs_object_without_distortion( + fiducial_world, + pscale, + shape, +): + # subtract 1 to account for pixel indexing starting at 0 + shift = models.Shift() & models.Shift() + + scale = models.Scale(pscale[0]) & models.Scale(pscale[1]) + + tan = models.Pix2Sky_TAN() + celestial_rotation = models.RotateNative2Celestial( + fiducial_world[0], + fiducial_world[1], + 180, + ) + + det2sky = shift | scale | tan | celestial_rotation + det2sky.name = "linear_transform" + + detector_frame = cf.Frame2D( + name="detector", axes_names=("x", "y"), unit=(u.pix, u.pix) + ) + sky_frame = cf.CelestialFrame( + reference_frame=coord.FK5(), name="fk5", unit=(u.deg, u.deg) + ) + + pipeline = [(detector_frame, det2sky), (sky_frame, None)] + + wcs_obj = gwcs.WCS(pipeline) + + wcs_obj.bounding_box = ( + (-0.5, shape[-1] - 0.5), + (-0.5, shape[-2] - 0.5), + ) + + return wcs_obj + + +def _create_wcs_and_datamodel(fiducial_world, shape, pscale): + wcs = _create_wcs_object_without_distortion( + fiducial_world=fiducial_world, shape=shape, pscale=pscale + ) + ra_ref, dec_ref = fiducial_world[0], fiducial_world[1] + return DataModel( + ra_ref=ra_ref, + dec_ref=dec_ref, + roll_ref=0, + v2_ref=0, + v3_ref=0, + v3yangle=0, + wcs=wcs, + ) + + +class WcsInfo: + def __init__(self, ra_ref, dec_ref, roll_ref, v2_ref, v3_ref, v3yangle): + self.ra_ref = ra_ref + self.dec_ref = dec_ref + self.ctype1 = "RA---TAN" + self.ctype2 = "DEC--TAN" + self.v2_ref = v2_ref + self.v3_ref = v3_ref + self.v3yangle = v3yangle + self.roll_ref = roll_ref + self.vparity = -1 + self.wcsaxes = 2 + self.s_region = "" + + +class Coordinates: + def __init__(self): + self.reference_frame = "ICRS" + + +class MetaData: + def __init__(self, ra_ref, dec_ref, roll_ref, v2_ref, v3_ref, v3yangle, wcs=None): + self.wcsinfo = WcsInfo(ra_ref, dec_ref, roll_ref, v2_ref, v3_ref, v3yangle) + self.wcs = wcs + self.coordinates = Coordinates() + + +class DataModel: + def __init__(self, ra_ref, dec_ref, roll_ref, v2_ref, v3_ref, v3yangle, wcs=None): + self.meta = MetaData( + ra_ref, dec_ref, roll_ref, v2_ref, v3_ref, v3yangle, wcs=wcs + ) + + +def test_compute_fiducial(): + """Test that util.compute_fiducial can properly determine the center of the + WCS's footprint. + """ + + shape = (3, 3) # in pixels + fiducial_world = (0, 0) # in deg + pscale = (0.000014, 0.000014) # in deg/pixel + + wcs = _create_wcs_object_without_distortion( + fiducial_world=fiducial_world, shape=shape, pscale=pscale + ) + + computed_fiducial = compute_fiducial([wcs]) + + assert all(np.isclose(wcs(1, 1), computed_fiducial)) + + +@pytest.mark.parametrize("pscales", [(0.000014, 0.000014), (0.000028, 0.000014)]) +def test_compute_scale(pscales): + """Test that util.compute_scale can properly determine the pixel scale of a + WCS object. + """ + shape = (3, 3) # in pixels + fiducial_world = (0, 0) # in deg + pscale = (pscales[0], pscales[1]) # in deg/pixel + + wcs = _create_wcs_object_without_distortion( + fiducial_world=fiducial_world, shape=shape, pscale=pscale + ) + expected_scale = np.sqrt(pscale[0] * pscale[1]) + + computed_scale = compute_scale(wcs=wcs, fiducial=fiducial_world) + + assert np.isclose(expected_scale, computed_scale) + + +def test_wcs_from_footprints(): + """ + Test that the WCS created from wcs_from_footprints has correct vertice coordinates. + + N.B.: this test will create two 3x3 arrays shifted by 0.000028 deg in + both directions, which means that the combined WCS generated by wcs_from_footprints + should be a 4x4 array with its fiducial point coordinates equal to the + first element of its footprint. + """ + shape = (3, 3) # in pixels + fiducial_world = (10, 0) # in deg + pscale = (0.000028, 0.000028) # in deg/pixel + dm_1 = _create_wcs_and_datamodel(fiducial_world, shape, pscale) + wcs_1 = dm_1.meta.wcs + + # shift fiducial by the size of a pixel projected onto the sky in both directions + # and create a new WCS + fiducial_world = ( + fiducial_world[0] - 0.000028, + fiducial_world[1] - 0.000028, + ) + dm_2 = _create_wcs_and_datamodel(fiducial_world, shape, pscale) + wcs_2 = dm_2.meta.wcs + + wcs = wcs_from_footprints([dm_1, dm_2]) + + # check that all elements of footprint match the *vertices* of the new combined WCS + assert all(np.isclose(wcs.footprint()[0], wcs(0, 0))) + assert all(np.isclose(wcs.footprint()[1], wcs(0, 4))) + assert all(np.isclose(wcs.footprint()[2], wcs(4, 4))) + assert all(np.isclose(wcs.footprint()[3], wcs(4, 0))) + + # check that fiducials match their expected coords in the new combined WCS + assert all(np.isclose(wcs_1(0, 0), wcs(2.5, 1.5))) + assert all(np.isclose(wcs_2(0, 0), wcs(3.5, 0.5))) + + +def test_validate_wcs_list(): + shape = (3, 3) # in pixels + fiducial_world = (10, 0) # in deg + pscale = (0.000028, 0.000028) # in deg/pixel + + dm_1 = _create_wcs_and_datamodel(fiducial_world, shape, pscale) + wcs_1 = dm_1.meta.wcs + + # shift fiducial by one pixel in both directions and create a new WCS + fiducial_world = ( + fiducial_world[0] - 0.000028, + fiducial_world[1] - 0.000028, + ) + dm_2 = _create_wcs_and_datamodel(fiducial_world, shape, pscale) + wcs_2 = dm_2.meta.wcs + + wcs_list = [wcs_1, wcs_2] + + assert _validate_wcs_list(wcs_list) + + +@pytest.mark.parametrize( + "wcs_list, expected_error", + [ + ([], TypeError), + ([1, 2, 3], TypeError), + (["1", "2", "3"], TypeError), + (["1", None, []], TypeError), + ("1", TypeError), + (1, ValueError), + (None, ValueError), + ], +) +def test_validate_wcs_list_invalid(wcs_list, expected_error): + with pytest.raises(Exception) as exec_info: + _validate_wcs_list(wcs_list) + + assert type(exec_info.value) == expected_error + + +def get_fake_wcs(): + fake_wcs1 = fitswcs.WCS( + fits.Header( + { + "NAXIS": 2, + "NAXIS1": 4, + "NAXIS2": 4, + "CTYPE1": "RA---TAN", + "CTYPE2": "DEC--TAN", + "CRVAL1": 0, + "CRVAL2": 0, + "CRPIX1": 1, + "CRPIX2": 1, + "CDELT1": -0.1, + "CDELT2": 0.1, + } + ) + ) + fake_wcs2 = fitswcs.WCS( + fits.Header( + { + "NAXIS": 2, + "NAXIS1": 5, + "NAXIS2": 5, + "CTYPE1": "RA---TAN", + "CTYPE2": "DEC--TAN", + "CRVAL1": 0, + "CRVAL2": 0, + "CRPIX1": 1, + "CRPIX2": 1, + "CDELT1": -0.05, + "CDELT2": 0.05, + } + ) + ) + return fake_wcs1, fake_wcs2 + + +@pytest.mark.parametrize( + "x_inp, y_inp, x_expected, y_expected", + [ + (1000, 2000, np.array(2000), np.array(4000)), # string input test + ([1000], [2000], np.array(2000), np.array(4000)), # array input test + pytest.param(1, 2, 3, 4, marks=pytest.mark.xfail), # expected failure test + ], +) +def test_reproject(x_inp, y_inp, x_expected, y_expected): + wcs1, wcs2 = get_fake_wcs() + f = reproject(wcs1, wcs2) + x_out, y_out = f(x_inp, y_inp) + assert np.allclose(x_out, x_expected, rtol=1e-05) + assert np.allclose(y_out, y_expected, rtol=1e-05) + + +def test_wcs_bbox_from_shape_2d(): + bb = wcs_bbox_from_shape((512, 2048)) + assert bb == ((-0.5, 2047.5), (-0.5, 511.5)) + + +@pytest.mark.parametrize( + "shape, pixmap_expected_shape", + [ + (None,(4, 4, 2)), + ((100, 200), (100, 200, 2)), + ], +) +def test_calc_pixmap_shape(shape, pixmap_expected_shape): + # TODO: add test for gwcs.WCS + wcs1, wcs2 = get_fake_wcs() + pixmap = resample_utils.calc_pixmap(wcs1, wcs2, shape=shape) + assert pixmap.shape==pixmap_expected_shape + + +@pytest.mark.parametrize( + "model, footprint, expected_s_region, expected_log_info", + [ + ( + _create_wcs_and_datamodel((10, 0), (3, 3), (0.000028, 0.000028)), + np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), + "POLYGON ICRS 1.000000000 2.000000000 3.000000000 4.000000000 5.000000000 6.000000000 7.000000000 8.000000000", # noqa: E501 + "Update S_REGION to POLYGON ICRS 1.000000000 2.000000000 3.000000000 4.000000000 5.000000000 6.000000000 7.000000000 8.000000000", # noqa: E501 + ), + ( + _create_wcs_and_datamodel((10, 0), (3, 3), (0.000028, 0.000028)), + np.array([[1.0, 2.0], [3.0, np.nan], [5.0, 6.0], [7.0, 8.0]]), + "", + "There are NaNs in s_region, S_REGION not updated.", + ), + ], +) +def test_update_s_region_keyword( + model, footprint, expected_s_region, expected_log_info, caplog +): + """ + Test that S_REGION keyword is being properly populated with the coordinate values. + """ + update_s_region_keyword(model, footprint) + assert model.meta.wcsinfo.s_region == expected_s_region + assert expected_log_info in caplog.text + + +@pytest.mark.parametrize( + "shape, expected_bbox", + [ + ((100, 200), ((-0.5, 199.5), (-0.5, 99.5))), + ((1, 1), ((-0.5, 0.5), (-0.5, 0.5))), + ((0, 0), ((-0.5, -0.5), (-0.5, -0.5))), + ], +) +def test_wcs_bbox_from_shape(shape, expected_bbox): + """ + Test that the bounding box generated by wcs_bbox_from_shape is correct. + """ + bbox = wcs_bbox_from_shape(shape) + assert bbox == expected_bbox + + +@pytest.mark.parametrize( + "model, bounding_box, data", + [ + ( + _create_wcs_and_datamodel((10, 0), (3, 3), (0.000028, 0.000028)), + ((-0.5, 2.5), (-0.5, 2.5)), + None, + ), + ( + _create_wcs_and_datamodel((10, 0), (3, 3), (0.000028, 0.000028)), + None, + np.zeros((3, 3)), + ), + ], +) +def test_update_s_region_imaging(model, bounding_box, data): + """ + Test that S_REGION keyword is being properly updated with the coordinates + corresponding to the footprint (same as WCS(bounding box)). + """ + model.data = data + model.meta.wcs.bounding_box = bounding_box + expected_s_region_coords = [ + *model.meta.wcs(-0.5, -0.5), + *model.meta.wcs(-0.5, 2.5), + *model.meta.wcs(2.5, 2.5), + *model.meta.wcs(2.5, -0.5), + ] + update_s_region_imaging(model, center=False) + updated_s_region_coords = [ + float(x) for x in model.meta.wcsinfo.s_region.split(" ")[3:] + ] + assert all( + np.isclose(x, y) + for x, y in zip(updated_s_region_coords, expected_s_region_coords) + )