diff --git a/specreduce/tests/test_align_along_trace.py b/specreduce/tests/test_align_along_trace.py new file mode 100644 index 0000000..f2a79a3 --- /dev/null +++ b/specreduce/tests/test_align_along_trace.py @@ -0,0 +1,59 @@ +import pytest +import numpy as np +import astropy.units as u + +from specreduce.tracing import ArrayTrace +from specreduce.utils import align_spectrum_along_trace + + +def mk_test_image(): + height, width = 9, 2 + centers = np.array([5.5, 3.0]) + image = np.zeros((height, width)) + image[5, 0] = 1 + image[2:4, 1] = 0.5 + return image, ArrayTrace(image, centers) + + +def test_align_spectrum_along_trace_bad_input(): + image, trace = mk_test_image() + with pytest.raises(ValueError, match='Unre'): + im = align_spectrum_along_trace(image, None) + + with pytest.raises(ValueError, match='method must be'): + im = align_spectrum_along_trace(image, trace, method='int') + + with pytest.raises(ValueError, match='Spectral axis length'): + im = align_spectrum_along_trace(image.T, trace, method='interpolate', disp_axis=0) + + with pytest.raises(ValueError, match='Displacement axis must be'): + im = align_spectrum_along_trace(image, trace, disp_axis=2) + + with pytest.raises(ValueError, match='The number of image dimensions must be'): + im = align_spectrum_along_trace(np.zeros((3,6,9)), trace) + + +@pytest.mark.parametrize("method, truth_data, truth_mask, truth_ucty", + [('interpolate', + np.array([[0, 0, 0, 0.00, 1.0, 0.00, 0, 0, 0], + [0, 0, 0, 0.25, 0.5, 0.25, 0, 0, 0]]).T, + np.array([[0, 0, 0, 0, 0, 0, 0, 1, 1], + [1, 1, 0, 0, 0, 0, 0, 0, 0]]).astype(bool).T, + np.array([[1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ], + [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]]).T), + ('shift', + np.array([[0. , 0. , 0. , 0. , 1. , 0. , 0. , 0. , 0. ], + [0. , 0. , 0. , 0.5, 0.5, 0. , 0. , 0. , 0. ]]).T, + np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1], + [1, 0, 0, 0, 0, 0, 0, 0, 0]]).astype(bool).T, + np.ones((9,2)))], + ids=('method=interpolate', 'method=shift')) +def test_align_spectrum_along_trace(method, truth_data, truth_mask, truth_ucty): + image, trace = mk_test_image() + im = align_spectrum_along_trace(image, trace, method=method) + assert im.shape == image.shape + assert im.unit == u.DN + assert im.uncertainty.uncertainty_type == 'var' + np.testing.assert_allclose(im.data, truth_data) + np.testing.assert_allclose(im.uncertainty.array, truth_ucty) + np.testing.assert_array_equal(im.mask, truth_mask) diff --git a/specreduce/utils/utils.py b/specreduce/utils/utils.py index 8cdee45..54e2f41 100644 --- a/specreduce/utils/utils.py +++ b/specreduce/utils/utils.py @@ -1,10 +1,128 @@ +from numbers import Number +from typing import Literal + import numpy as np +from astropy.nddata import VarianceUncertainty, NDData +from specutils import Spectrum1D from specreduce.core import _ImageParser from specreduce.tracing import Trace, FlatTrace from specreduce.extract import _ap_weight_image, _align_along_trace -__all__ = ['measure_cross_dispersion_profile', '_align_along_trace'] +__all__ = ['measure_cross_dispersion_profile', '_align_along_trace', 'align_spectrum_along_trace'] + + +def _get_image_ndim(image): + if isinstance(image, np.ndarray): + return image.ndim + elif isinstance(image, NDData): + return image.data.ndim + else: + raise ValueError('Unrecognized image data format.') + + +def align_spectrum_along_trace(image: NDData | np.ndarray, + trace: Trace | np.ndarray | Number, + method: Literal['interpolate', 'shift'] = 'interpolate', + disp_axis: int = 1) -> Spectrum1D: + """ + Align a 2D spectrum image along a trace either with an integer or sub-pixel precision. + + Parameters + ---------- + image + The 2D image to align. + trace + A Trace object that defines the center of the cross-dispersion profile. + method + The method used to align the image columns: ``interpolate`` aligns the + image columns with a sub-pixel precision while ``shift`` does this using + integer shifts. + disp_axis + The index of the image's dispersion axis. [default: 1] + + Returns + ------- + Spectrum1D + A 1D spectral representation of the input image, aligned along the specified + trace and corrected for displacements. The output includes adjusted mask + and uncertainty information. + + Raises + ------ + ValueError + If the number of dimensions of the image is not equal to 2, or + if the displacement axis is not 0 or 1. + """ + if _get_image_ndim(image) > 2: + raise ValueError('The number of image dimensions must be 2.') + if not (0 <= disp_axis <= 1): + raise ValueError('Displacement axis must be either 0 or 1.') + + if isinstance(trace, Trace): + trace = trace.trace.data + elif isinstance(trace, (np.ndarray, Number)): + pass + else: + raise ValueError('Unrecognized trace format.') + + image = _ImageParser()._parse_image(image, disp_axis=disp_axis) + data = image.data + mask = image.mask | ~np.isfinite(data) + ucty = image.uncertainty.represent_as(VarianceUncertainty).array + + if disp_axis == 0: + data = data.T + mask = mask.T + ucty = ucty.T + + n_rows = data.shape[0] + n_cols = data.shape[1] + + rows = np.broadcast_to(np.arange(n_rows)[:, None], data.shape) + cols = np.broadcast_to(np.arange(n_cols), data.shape) + + if method == 'interpolate': + # Calculate the real and integer shifts + # and the interpolation weights. + shifts = trace - n_rows / 2.0 + k = np.floor(shifts).astype(int) + a = shifts - k + + # Calculate the shifted indices and mask the + # edge pixels without information. + ix1 = rows + k + ix2 = ix1 + 1 + shift_mask = (ix1 < 0) | (ix1 > n_rows - 2) + ix1 = np.clip(ix1, 0, n_rows - 1) + ix2 = np.clip(ix2, 0, n_rows - 1) + + # Shift the data, uncertainties, and the mask using linear + # interpolation. + data_r = (1.0-a)*data[ix1, cols] + a*data[ix2, cols] + ucty_r = (1.0-a)**2*ucty[ix1, cols] + a**2*ucty[ix2, cols] + mask_r = mask[ix1, cols] | mask[ix2, cols] | shift_mask + + elif method == 'shift': + shifts = trace.astype(int) - n_rows // 2 + ix = rows + shifts + shift_mask = (ix < 0) | (ix > n_rows - 1) + ix = np.clip(ix, 0, n_rows - 1) + + data_r = data[ix, cols] + ucty_r = ucty[ix, cols] + mask_r = mask[ix, cols] | shift_mask + + else: + raise ValueError("method must be either 'interpolate' or 'shift'.") + + if disp_axis == 0: + data_r = data_r.T + mask_r = mask_r.T + ucty_r = ucty_r.T + + return Spectrum1D(data_r * image.unit, mask=mask_r, meta=image.meta, + uncertainty=VarianceUncertainty(ucty_r).represent_as(image.uncertainty)) def measure_cross_dispersion_profile(image, trace=None, crossdisp_axis=0, @@ -157,7 +275,6 @@ def measure_cross_dispersion_profile(image, trace=None, crossdisp_axis=0, else: raise ValueError('`width` must be an integer, or None to use all ' 'cross-dispersion pixels.') - width = int(width) # rectify trace, if _align_along_trace is True and trace is not flat aligned_trace = None