diff --git a/specreduce/background.py b/specreduce/background.py index 88482695..3380975f 100644 --- a/specreduce/background.py +++ b/specreduce/background.py @@ -8,7 +8,7 @@ from astropy import units as u from specreduce.extract import _ap_weight_image, _to_spectrum1d_pixels -from specreduce.tracing import Trace, FlatTrace +from specreduce.tracing import FlatTrace, BaseTrace __all__ = ['Background'] @@ -77,7 +77,7 @@ def __post_init__(self): cross-dispersion axis """ def _to_trace(trace): - if not isinstance(trace, Trace): + if not isinstance(trace, BaseTrace): trace = FlatTrace(self.image, trace) # TODO: this check can be removed if/when implemented as a check in FlatTrace @@ -93,7 +93,7 @@ def _to_trace(trace): self.bkg_array = np.zeros(self.image.shape[self.disp_axis]) return - if isinstance(self.traces, Trace): + if isinstance(self.traces, BaseTrace): self.traces = [self.traces] bkg_wimage = np.zeros_like(self.image, dtype=np.float64) diff --git a/specreduce/extract.py b/specreduce/extract.py index 9965794f..91dea0d4 100644 --- a/specreduce/extract.py +++ b/specreduce/extract.py @@ -10,7 +10,7 @@ from astropy.nddata import NDData from specreduce.core import SpecreduceOperation -from specreduce.tracing import Trace, FlatTrace +from specreduce.tracing import FlatTrace, BaseTrace from specutils import Spectrum1D __all__ = ['BoxcarExtract', 'HorneExtract', 'OptimalExtract'] @@ -88,7 +88,7 @@ def _ap_weight_image(trace, width, disp_axis, crossdisp_axis, image_shape): Parameters ---------- - trace : `~specreduce.tracing.Trace`, required + trace : `~specreduce.tracing.BaseTrace`, required trace object width : float, required width of extraction aperture in pixels @@ -139,7 +139,7 @@ class BoxcarExtract(SpecreduceOperation): ---------- image : nddata-compatible image image with 2-D spectral image data - trace_object : Trace + trace_object : BaseTrace trace object width : float width of extraction aperture in pixels @@ -154,7 +154,7 @@ class BoxcarExtract(SpecreduceOperation): The extracted 1d spectrum expressed in DN and pixel units """ image: NDData - trace_object: Trace + trace_object: BaseTrace width: float = 5 disp_axis: int = 1 crossdisp_axis: int = 0 @@ -173,7 +173,7 @@ def __call__(self, image=None, trace_object=None, width=None, ---------- image : nddata-compatible image image with 2-D spectral image data - trace_object : Trace + trace_object : BaseTrace trace object width : float width of extraction aperture in pixels [default: 5] @@ -230,7 +230,7 @@ class HorneExtract(SpecreduceOperation): NDData object must specify uncertainty and a mask. An array requires use of the ``variance``, ``mask``, & ``unit`` arguments. - trace_object : `~specreduce.tracing.Trace`, required + trace_object : `~specreduce.tracing.BaseTrace`, required The associated 1D trace object created for the 2D image. disp_axis : int, optional @@ -264,7 +264,7 @@ class HorneExtract(SpecreduceOperation): """ image: NDData - trace_object: Trace + trace_object: BaseTrace bkgrd_prof: Model = field(default=models.Polynomial1D(2)) variance: np.ndarray = field(default=None) mask: np.ndarray = field(default=None) @@ -293,7 +293,7 @@ def __call__(self, image=None, trace_object=None, NDData object must specify uncertainty and a mask. An array requires use of the ``variance``, ``mask``, & ``unit`` arguments. - trace_object : `~specreduce.tracing.Trace`, required + trace_object : `~specreduce.tracing.BaseTrace`, required The associated 1D trace object created for the 2D image. disp_axis : int, optional diff --git a/specreduce/tracing.py b/specreduce/tracing.py index 6b2557ca..dda15dc1 100644 --- a/specreduce/tracing.py +++ b/specreduce/tracing.py @@ -1,7 +1,7 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst from copy import deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, field import warnings from astropy.modeling import fitting, models @@ -10,56 +10,32 @@ from scipy.interpolate import UnivariateSpline import numpy as np -__all__ = ['Trace', 'FlatTrace', 'ArrayTrace', 'KosmosTrace'] +__all__ = ['BaseTrace', 'Trace', 'FlatTrace', 'ArrayTrace', 'KosmosTrace'] -@dataclass -class Trace: +@dataclass(frozen=True) +class BaseTrace: """ - Basic tracing class that by default traces the middle of the image. - - Parameters - ---------- - image : `~astropy.nddata.CCDData` - Image to be traced - - Properties - ---------- - shape : tuple - Shape of the array describing the trace + A dataclass common to all Trace objects. """ image: CCDData + _trace_pos: (float, np.ndarray) = field(repr=False) + _trace: np.ndarray = field(repr=False) def __post_init__(self): - self.trace_pos = self.image.shape[0] / 2 - self.trace = np.ones_like(self.image[0]) * self.trace_pos + # this class only exists to catch __post_init__ calls in its + # subclasses, so that super().__post_init__ calls work correctly. + pass def __getitem__(self, i): return self.trace[i] - @property - def shape(self): - return self.trace.shape - - def shift(self, delta): - """ - Shift the trace by delta pixels perpendicular to the axis being traced - - Parameters - ---------- - delta : float - Shift to be applied to the trace - """ - # act on self.trace.data to ignore the mask and then re-mask when calling _bound_trace - self.trace = np.asarray(self.trace.data) + delta - self._bound_trace() - def _bound_trace(self): """ Mask trace positions that are outside the upper/lower bounds of the image. """ ny = self.image.shape[0] - self.trace = np.ma.masked_outside(self.trace, 0, ny-1) + object.__setattr__(self, '_trace', np.ma.masked_outside(self._trace, 0, ny - 1)) def __add__(self, delta): """ @@ -77,9 +53,60 @@ def __sub__(self, delta): """ return self.__add__(-delta) + def shift(self, delta): + """ + Shift the trace by delta pixels perpendicular to the axis being traced + + Parameters + ---------- + delta : float + Shift to be applied to the trace + """ + # act on self._trace.data to ignore the mask and then re-mask when calling _bound_trace + object.__setattr__(self, '_trace', np.asarray(self._trace.data) + delta) + object.__setattr__(self, '_trace_pos', self._trace_pos + delta) + self._bound_trace() + + @property + def shape(self): + return self._trace.shape + + @property + def trace(self): + return self._trace -@dataclass -class FlatTrace(Trace): + @property + def trace_pos(self): + return self._trace_pos + + @staticmethod + def _default_trace_attrs(image): + """ + Compute a default trace position and trace array using only + the image dimensions. + """ + trace_pos = image.shape[0] / 2 + trace = np.ones_like(image[0]) * trace_pos + return trace_pos, trace + + +@dataclass(init=False, frozen=True) +class Trace(BaseTrace): + """ + Basic tracing class that by default traces the middle of the image. + + Parameters + ---------- + image : `~astropy.nddata.CCDData` + Image to be traced + """ + def __init__(self, image): + trace_pos, trace = self._default_trace_attrs(image) + super().__init__(image, trace_pos, trace) + + +@dataclass(init=False, frozen=True) +class FlatTrace(BaseTrace): """ Trace that is constant along the axis being traced @@ -92,10 +119,11 @@ class FlatTrace(Trace): trace_pos : float Position of the trace """ - trace_pos: float - def __post_init__(self): - self.set_position(self.trace_pos) + def __init__(self, image, trace_pos): + _, trace = self._default_trace_attrs(image) + super().__init__(image, trace_pos, trace) + self.set_position(trace_pos) def set_position(self, trace_pos): """ @@ -106,13 +134,13 @@ def set_position(self, trace_pos): trace_pos : float Position of the trace """ - self.trace_pos = trace_pos - self.trace = np.ones_like(self.image[0]) * self.trace_pos + object.__setattr__(self, '_trace_pos', trace_pos) + object.__setattr__(self, '_trace', np.ones_like(self.image[0]) * trace_pos) self._bound_trace() -@dataclass -class ArrayTrace(Trace): +@dataclass(init=False, frozen=True) +class ArrayTrace(BaseTrace): """ Define a trace given an array of trace positions @@ -121,25 +149,27 @@ class ArrayTrace(Trace): trace : `numpy.ndarray` Array containing trace positions """ - trace: np.ndarray + def __init__(self, image, trace): + trace_pos, _ = self._default_trace_attrs(image) + super().__init__(image, trace_pos, trace) - def __post_init__(self): nx = self.image.shape[1] - nt = len(self.trace) + nt = len(trace) if nt != nx: if nt > nx: # truncate trace to fit image - self.trace = self.trace[0:nx] + trace = trace[0:nx] else: # assume trace starts at beginning of image and pad out trace to fit. # padding will be the last value of the trace, but will be masked out. - padding = np.ma.MaskedArray(np.ones(nx - nt) * self.trace[-1], mask=True) - self.trace = np.ma.hstack([self.trace, padding]) + padding = np.ma.MaskedArray(np.ones(nx - nt) * trace[-1], mask=True) + trace = np.ma.hstack([trace, padding]) + object.__setattr__(self, '_trace', trace) self._bound_trace() -@dataclass -class KosmosTrace(Trace): +@dataclass(init=False, frozen=True) +class KosmosTrace(BaseTrace): """ Trace the spectrum aperture in an image. @@ -192,14 +222,25 @@ class KosmosTrace(Trace): 4) add other interpolation modes besides spline, maybe via specutils.manipulation methods? """ - bins: int = 20 - guess: float = None - window: int = None - peak_method: str = 'gaussian' + bins: int + guess: float + window: int + peak_method: str _crossdisp_axis = 0 _disp_axis = 1 - def __post_init__(self): + def _process_init_kwargs(self, **kwargs): + for attr, value in kwargs.items(): + object.__setattr__(self, attr, value) + + def __init__(self, image, bins=20, guess=None, window=None, peak_method='gaussian'): + # This method will assign the user supplied value (or default) to the attrs: + self._process_init_kwargs( + bins=bins, guess=guess, window=window, peak_method=peak_method + ) + trace_pos, trace = self._default_trace_attrs(image) + super().__init__(image, trace_pos, trace) + # handle multiple image types and mask uncaught invalid values if isinstance(self.image, NDData): img = np.ma.masked_invalid(np.ma.masked_array(self.image.data, @@ -223,7 +264,7 @@ def __post_init__(self): if not isinstance(self.bins, int): warnings.warn('TRACE: Converting bins to int') - self.bins = int(self.bins) + object.__setattr__(self, 'bins', int(self.bins)) if self.bins < 4: raise ValueError('bins must be >= 4') @@ -240,7 +281,7 @@ def __post_init__(self): "length of the image's spatial direction") elif self.window is not None and not isinstance(self.window, int): warnings.warn('TRACE: Converting window to int') - self.window = int(self.window) + object.__setattr__(self, 'window', int(self.window)) # set max peak location by user choice or wavelength with max avg flux ztot = img.sum(axis=self._disp_axis) / img.shape[self._disp_axis] @@ -343,4 +384,4 @@ def __post_init__(self): warnings.warn("TRACE ERROR: No valid points found in trace") trace_y = np.tile(np.nan, len(x_bins)) - self.trace = np.ma.masked_invalid(trace_y) + object.__setattr__(self, '_trace', np.ma.masked_invalid(trace_y))