Skip to content

Commit

Permalink
Added a function to align a 2D spectrum along a trace to utils and te…
Browse files Browse the repository at this point in the history
…sts for it.
  • Loading branch information
hpparvi committed Dec 10, 2024
1 parent 140e817 commit 984fefe
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 2 deletions.
59 changes: 59 additions & 0 deletions specreduce/tests/test_align_along_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest
import numpy as np
import astropy.units as u

from specreduce.tracing import ArrayTrace
from specreduce.utils import align_spectrum_along_trace


def mk_test_image():
height, width = 9, 2
centers = np.array([5.5, 3.0])
image = np.zeros((height, width))
image[5, 0] = 1
image[2:4, 1] = 0.5
return image, ArrayTrace(image, centers)


def test_align_spectrum_along_trace_bad_input():
image, trace = mk_test_image()
with pytest.raises(ValueError, match='Unre'):
im = align_spectrum_along_trace(image, None)

with pytest.raises(ValueError, match='method must be'):
im = align_spectrum_along_trace(image, trace, method='int')

with pytest.raises(ValueError, match='Spectral axis length'):
im = align_spectrum_along_trace(image.T, trace, method='interpolate', disp_axis=0)

with pytest.raises(ValueError, match='Displacement axis must be'):
im = align_spectrum_along_trace(image, trace, disp_axis=2)

with pytest.raises(ValueError, match='The number of image dimensions must be'):
im = align_spectrum_along_trace(np.zeros((3,6,9)), trace)


@pytest.mark.parametrize("method, truth_data, truth_mask, truth_ucty",
[('interpolate',
np.array([[0, 0, 0, 0.00, 1.0, 0.00, 0, 0, 0],
[0, 0, 0, 0.25, 0.5, 0.25, 0, 0, 0]]).T,
np.array([[0, 0, 0, 0, 0, 0, 0, 1, 1],
[1, 1, 0, 0, 0, 0, 0, 0, 0]]).astype(bool).T,
np.array([[1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ],
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]]).T),
('shift',
np.array([[0. , 0. , 0. , 0. , 1. , 0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0.5, 0.5, 0. , 0. , 0. , 0. ]]).T,
np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1],
[1, 0, 0, 0, 0, 0, 0, 0, 0]]).astype(bool).T,
np.ones((9,2)))],
ids=('method=interpolate', 'method=shift'))
def test_align_spectrum_along_trace(method, truth_data, truth_mask, truth_ucty):
image, trace = mk_test_image()
im = align_spectrum_along_trace(image, trace, method=method)
assert im.shape == image.shape
assert im.unit == u.DN
assert im.uncertainty.uncertainty_type == 'var'
np.testing.assert_allclose(im.data, truth_data)
np.testing.assert_allclose(im.uncertainty.array, truth_ucty)
np.testing.assert_array_equal(im.mask, truth_mask)
121 changes: 119 additions & 2 deletions specreduce/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,128 @@
from numbers import Number
from typing import Literal

import numpy as np
from astropy.nddata import VarianceUncertainty, NDData
from specutils import Spectrum1D

from specreduce.core import _ImageParser
from specreduce.tracing import Trace, FlatTrace
from specreduce.extract import _ap_weight_image, _align_along_trace

__all__ = ['measure_cross_dispersion_profile', '_align_along_trace']
__all__ = ['measure_cross_dispersion_profile', '_align_along_trace', 'align_spectrum_along_trace']


def _get_image_ndim(image):
if isinstance(image, np.ndarray):
return image.ndim
elif isinstance(image, NDData):
return image.data.ndim
else:
raise ValueError('Unrecognized image data format.')


def align_spectrum_along_trace(image: NDData | np.ndarray,
trace: Trace | np.ndarray | Number,
method: Literal['interpolate', 'shift'] = 'interpolate',
disp_axis: int = 1) -> Spectrum1D:
"""
Align a 2D spectrum image along a trace either with an integer or sub-pixel precision.
Parameters
----------
image
The 2D image to align.
trace
A Trace object that defines the center of the cross-dispersion profile.
method
The method used to align the image columns: ``interpolate`` aligns the
image columns with a sub-pixel precision while ``shift`` does this using
integer shifts.
disp_axis
The index of the image's dispersion axis. [default: 1]
Returns
-------
Spectrum1D
A 1D spectral representation of the input image, aligned along the specified
trace and corrected for displacements. The output includes adjusted mask
and uncertainty information.
Raises
------
ValueError
If the number of dimensions of the image is not equal to 2, or
if the displacement axis is not 0 or 1.
"""
if _get_image_ndim(image) > 2:
raise ValueError('The number of image dimensions must be 2.')
if not (0 <= disp_axis <= 1):
raise ValueError('Displacement axis must be either 0 or 1.')

if isinstance(trace, Trace):
trace = trace.trace.data
elif isinstance(trace, (np.ndarray, Number)):
pass
else:
raise ValueError('Unrecognized trace format.')

image = _ImageParser()._parse_image(image, disp_axis=disp_axis)
data = image.data
mask = image.mask | ~np.isfinite(data)
ucty = image.uncertainty.represent_as(VarianceUncertainty).array

if disp_axis == 0:
data = data.T
mask = mask.T
ucty = ucty.T

n_rows = data.shape[0]
n_cols = data.shape[1]

rows = np.broadcast_to(np.arange(n_rows)[:, None], data.shape)
cols = np.broadcast_to(np.arange(n_cols), data.shape)

if method == 'interpolate':
# Calculate the real and integer shifts
# and the interpolation weights.
shifts = trace - n_rows / 2.0
k = np.floor(shifts).astype(int)
a = shifts - k

# Calculate the shifted indices and mask the
# edge pixels without information.
ix1 = rows + k
ix2 = ix1 + 1
shift_mask = (ix1 < 0) | (ix1 > n_rows - 2)
ix1 = np.clip(ix1, 0, n_rows - 1)
ix2 = np.clip(ix2, 0, n_rows - 1)

# Shift the data, uncertainties, and the mask using linear
# interpolation.
data_r = (1.0-a)*data[ix1, cols] + a*data[ix2, cols]
ucty_r = (1.0-a)**2*ucty[ix1, cols] + a**2*ucty[ix2, cols]
mask_r = mask[ix1, cols] | mask[ix2, cols] | shift_mask

elif method == 'shift':
shifts = trace.astype(int) - n_rows // 2
ix = rows + shifts
shift_mask = (ix < 0) | (ix > n_rows - 1)
ix = np.clip(ix, 0, n_rows - 1)

data_r = data[ix, cols]
ucty_r = ucty[ix, cols]
mask_r = mask[ix, cols] | shift_mask

else:
raise ValueError("method must be either 'interpolate' or 'shift'.")

if disp_axis == 0:
data_r = data_r.T
mask_r = mask_r.T
ucty_r = ucty_r.T

return Spectrum1D(data_r * image.unit, mask=mask_r, meta=image.meta,
uncertainty=VarianceUncertainty(ucty_r).represent_as(image.uncertainty))


def measure_cross_dispersion_profile(image, trace=None, crossdisp_axis=0,
Expand Down Expand Up @@ -157,7 +275,6 @@ def measure_cross_dispersion_profile(image, trace=None, crossdisp_axis=0,
else:
raise ValueError('`width` must be an integer, or None to use all '
'cross-dispersion pixels.')
width = int(width)

# rectify trace, if _align_along_trace is True and trace is not flat
aligned_trace = None
Expand Down

0 comments on commit 984fefe

Please sign in to comment.