diff --git a/CHANGES.rst b/CHANGES.rst index 9c8d4ad..8bc659e 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -16,6 +16,13 @@ Other changes 1.4.2 ------ +New Features +^^^^^^^^^^^^ +- Added a ``specreduce.utils.align_2d_spectrum_along_trace`` utility function that aligns a + rectilinear 2D spectrum image along a spectrum trace. The rectification can be done using + either linear interpolation, giving a sub-pixel shift resolution, or using integer shifts. + The function also updates the image mask and propagates the uncertainties. + Bug Fixes ^^^^^^^^^ - Fixed Astropy v7.0 incompatibility bug [#229] in ``tracing.FitTrace``: changed to use diff --git a/specreduce/tests/test_align_along_trace.py b/specreduce/tests/test_align_along_trace.py new file mode 100644 index 0000000..c96ff97 --- /dev/null +++ b/specreduce/tests/test_align_along_trace.py @@ -0,0 +1,59 @@ +import pytest +import numpy as np +import astropy.units as u + +from specreduce.tracing import ArrayTrace +from specreduce.utils import align_2d_spectrum_along_trace + + +def mk_test_image(): + height, width = 9, 2 + centers = np.array([5.5, 3.0]) + image = np.zeros((height, width)) + image[5, 0] = 1 + image[2:4, 1] = 0.5 + return image, ArrayTrace(image, centers) + + +def test_align_spectrum_along_trace_bad_input(): + image, trace = mk_test_image() + with pytest.raises(ValueError, match='Unre'): + im = align_2d_spectrum_along_trace(image, None) # noqa + + with pytest.raises(ValueError, match='method must be'): + im = align_2d_spectrum_along_trace(image, trace, method='int') # noqa + + with pytest.raises(ValueError, match='Spectral axis length'): + im = align_2d_spectrum_along_trace(image.T, trace, method='interpolate', disp_axis=0) # noqa + + with pytest.raises(ValueError, match='Displacement axis must be'): + im = align_2d_spectrum_along_trace(image, trace, disp_axis=2) # noqa + + with pytest.raises(ValueError, match='The number of image dimensions must be'): + im = align_2d_spectrum_along_trace(np.zeros((3, 6, 9)), trace) # noqa + + +@pytest.mark.parametrize("method, truth_data, truth_mask, truth_ucty", + [('interpolate', + np.array([[0, 0, 0, 0.00, 1.0, 0.00, 0, 0, 0], + [0, 0, 0, 0.25, 0.5, 0.25, 0, 0, 0]]).T, + np.array([[0, 0, 0, 0, 0, 0, 0, 1, 1], + [1, 1, 0, 0, 0, 0, 0, 0, 0]]).astype(bool).T, + np.array([[1., 1., 1., 1., 1., 1., 1., 1., 1.], + [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]]).T), + ('shift', + np.array([[0., 0., 0., 0., 1., 0., 0., 0., 0.], + [0., 0., 0., 0.5, 0.5, 0., 0., 0., 0.]]).T, + np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1], + [1, 0, 0, 0, 0, 0, 0, 0, 0]]).astype(bool).T, + np.ones((9, 2)))], + ids=('method=interpolate', 'method=shift')) +def test_align_spectrum_along_trace(method, truth_data, truth_mask, truth_ucty): + image, trace = mk_test_image() + im = align_2d_spectrum_along_trace(image, trace, method=method) + assert im.shape == image.shape + assert im.unit == u.DN + assert im.uncertainty.uncertainty_type == 'var' + np.testing.assert_allclose(im.data, truth_data) + np.testing.assert_allclose(im.uncertainty.array, truth_ucty) + np.testing.assert_array_equal(im.mask, truth_mask) diff --git a/specreduce/utils/utils.py b/specreduce/utils/utils.py index 8cdee45..d18cff4 100644 --- a/specreduce/utils/utils.py +++ b/specreduce/utils/utils.py @@ -1,10 +1,136 @@ +from numbers import Number +from typing import Literal + import numpy as np +from astropy.nddata import VarianceUncertainty, NDData +from specutils import Spectrum1D from specreduce.core import _ImageParser from specreduce.tracing import Trace, FlatTrace from specreduce.extract import _ap_weight_image, _align_along_trace -__all__ = ['measure_cross_dispersion_profile', '_align_along_trace'] +__all__ = ['measure_cross_dispersion_profile', '_align_along_trace', + 'align_2d_spectrum_along_trace'] + + +def _get_image_ndim(image): + if isinstance(image, np.ndarray): + return image.ndim + elif isinstance(image, NDData): + return image.data.ndim + else: + raise ValueError('Unrecognized image data format.') + + +def align_2d_spectrum_along_trace(image: NDData | np.ndarray, + trace: Trace | np.ndarray | Number, + method: Literal['interpolate', 'shift'] = 'interpolate', + disp_axis: int = 1) -> Spectrum1D: + """ + Align a 2D spectrum image along a trace either with an integer or sub-pixel precision. + + This function rectifies a 2D spectrum by aligning its cross-dispersion profile along a given + trace. The function also updates the mask to reflect alignment operations and propagates + uncertainties when using the 'interpolate' method. The rectification process can use either + sub-pixel precision through interpolation or integer shifts for simplicity. The method assumes + the input spectrum is rectilinear, meaning the dispersion direction and spatial direction are + aligned with the pixel grid. + + Parameters + ---------- + image + The 2D image to align. + trace + Either a ``Trace`` object, a 1D ndarray, or a single value that defines the center + of the cross-dispersion profile. + method + The method used to align the image: ``interpolate`` aligns the image + with a sub-pixel precision using linear interpolation while ``shift`` + aligns the image using integer shifts. + disp_axis + The index of the image's dispersion axis. [default: 1] + + Returns + ------- + Spectrum1D + A rectified version of the image aligned along the specified trace. + + Notes + ----- + - This function is intended only for rectilinear spectra, where the dispersion + and spatial axes are already aligned with the image grid. Non-rectilinear spectra + require additional pre-processing (e.g., geometric rectification) before using + this function. + """ + if _get_image_ndim(image) > 2: + raise ValueError('The number of image dimensions must be 2.') + if not (0 <= disp_axis <= 1): + raise ValueError('Displacement axis must be either 0 or 1.') + + if isinstance(trace, Trace): + trace = trace.trace.data + elif isinstance(trace, (np.ndarray, Number)): + pass + else: + raise ValueError('Unrecognized trace format.') + + image = _ImageParser()._parse_image(image, disp_axis=disp_axis) + data = image.data + mask = image.mask | ~np.isfinite(data) + ucty = image.uncertainty.represent_as(VarianceUncertainty).array + + if disp_axis == 0: + data = data.T + mask = mask.T + ucty = ucty.T + + n_rows = data.shape[0] + n_cols = data.shape[1] + + rows = np.broadcast_to(np.arange(n_rows)[:, None], data.shape) + cols = np.broadcast_to(np.arange(n_cols), data.shape) + + if method == 'interpolate': + # Calculate the real and integer shifts + # and the interpolation weights. + shifts = trace - n_rows / 2.0 + k = np.floor(shifts).astype(int) + a = shifts - k + + # Calculate the shifted indices and mask the + # edge pixels without information. + ix1 = rows + k + ix2 = ix1 + 1 + shift_mask = (ix1 < 0) | (ix1 > n_rows - 2) + ix1 = np.clip(ix1, 0, n_rows - 1) + ix2 = np.clip(ix2, 0, n_rows - 1) + + # Shift the data, uncertainties, and the mask using linear + # interpolation. + data_r = (1.0-a)*data[ix1, cols] + a*data[ix2, cols] + ucty_r = (1.0-a)**2*ucty[ix1, cols] + a**2*ucty[ix2, cols] + mask_r = mask[ix1, cols] | mask[ix2, cols] | shift_mask + + elif method == 'shift': + shifts = trace.astype(int) - n_rows // 2 + ix = rows + shifts + shift_mask = (ix < 0) | (ix > n_rows - 1) + ix = np.clip(ix, 0, n_rows - 1) + + data_r = data[ix, cols] + ucty_r = ucty[ix, cols] + mask_r = mask[ix, cols] | shift_mask + + else: + raise ValueError("method must be either 'interpolate' or 'shift'.") + + if disp_axis == 0: + data_r = data_r.T + mask_r = mask_r.T + ucty_r = ucty_r.T + + return Spectrum1D(data_r * image.unit, mask=mask_r, meta=image.meta, + uncertainty=VarianceUncertainty(ucty_r).represent_as(image.uncertainty)) def measure_cross_dispersion_profile(image, trace=None, crossdisp_axis=0, @@ -157,7 +283,6 @@ def measure_cross_dispersion_profile(image, trace=None, crossdisp_axis=0, else: raise ValueError('`width` must be an integer, or None to use all ' 'cross-dispersion pixels.') - width = int(width) # rectify trace, if _align_along_trace is True and trace is not flat aligned_trace = None @@ -183,7 +308,7 @@ def measure_cross_dispersion_profile(image, trace=None, crossdisp_axis=0, # now that we have figured out the mask for the window in cross-disp. axis, # select only the pixel(s) we want to include in measuring the avg. profile - pixel_mask = np.ones((image.shape)) + pixel_mask = np.ones(image.shape) pixel_mask[:, pixels] = 0 # combine these masks to isolate the rows and cols used to measure profile