Skip to content

Commit

Permalink
fix input for background.traces, raise error in FlatTrace for negativ…
Browse files Browse the repository at this point in the history
…e trace
  • Loading branch information
cshanahan1 committed Feb 27, 2024
1 parent 85e1b20 commit ddc7aab
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 18 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ Bug Fixes
were set to the number of rows in the image, biasing the final fit to all bin
peaks. Previously for Gaussian, the entire fit failed. [#205, #206]

- Fixed input of `traces` in `Background`. Added a condition to 'FlatTrace' that
trace position must be a positive number. [#211]

Other changes
^^^^^^^^^^^^^

Expand Down
54 changes: 38 additions & 16 deletions specreduce/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ class Background(_ImageParser):
----------
image : `~astropy.nddata.NDData`-like or array-like
image with 2-D spectral image data
traces : List
list of trace objects (or integers to define FlatTraces) to
extract the background
traces : trace, int, float (single or list)
Individual or list of trace object(s) (or integers/floats to define
FlatTraces) to extract the background. If None, a FlatTrace at the
center of the image (according to `disp_axis`) will be used.
width : float
width of extraction aperture in pixels
statistic: string
Expand Down Expand Up @@ -82,16 +83,6 @@ def __post_init__(self):
crossdisp_axis : int
cross-dispersion axis
"""
def _to_trace(trace):
if not isinstance(trace, Trace):
trace = FlatTrace(self.image, trace)

# TODO: this check can be removed if/when implemented as a check in FlatTrace
if isinstance(trace, FlatTrace):
if trace.trace_pos < 1:
raise ValueError('trace_object.trace_pos must be >= 1')
return trace

self.image = self._parse_image(self.image)

if self.width < 0:
Expand All @@ -100,12 +91,10 @@ def _to_trace(trace):
self._bkg_array = np.zeros(self.image.shape[self.disp_axis])
return

if isinstance(self.traces, Trace):
self.traces = [self.traces]
self._set_traces()

bkg_wimage = np.zeros_like(self.image.data, dtype=np.float64)
for trace in self.traces:
trace = _to_trace(trace)
windows_max = trace.trace.data.max() + self.width/2
windows_min = trace.trace.data.min() - self.width/2
if windows_max >= self.image.shape[self.crossdisp_axis]:
Expand Down Expand Up @@ -150,6 +139,39 @@ def _to_trace(trace):
else:
raise ValueError("statistic must be 'average' or 'median'")

def _set_traces(self):
"""Determine `traces` from input. If an integer/float or list if int/float
is passed in, use these to construct FlatTrace objects. These values
must be positive. If None (which is initialized to an empty list),
construct a FlatTrace using the center of image (according to disp.
axis). Otherwise, any Trace object or list of Trace objects can be
passed in."""

if self.traces == []:
# assume a flat trace at the image center if nothing is passed in.
trace_pos = self.image.shape[self.disp_axis] / 2.
self.traces = [FlatTrace(self.image, trace_pos)]

if isinstance(self.traces, Trace):
# if just one trace, turn it into iterable.
self.traces = [self.traces]
return

# finally, if float/int is passed in convert to FlatTrace(s)
if isinstance(self.traces, (float, int)): # for a single number
self.traces = [self.traces]

if np.all([isinstance(x, (float, int)) for x in self.traces]):
self.traces = [FlatTrace(self.image, trace_pos) for trace_pos in self.traces]
return

else:
if not np.all([isinstance(x, Trace) for x in self.traces]):
raise ValueError('`traces` must be a `Trace` object or list of '
'`Trace` objects, a number or list of numbers to '
'define FlatTraces, or None to use a FlatTrace in '
'the middle of the image.')

@classmethod
def two_sided(cls, image, trace_object, separation, **kwargs):
"""
Expand Down
26 changes: 26 additions & 0 deletions specreduce/tests/test_background.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,29 @@ def test_warnings_errors(mk_test_spec_no_spectral_axis):

with pytest.raises(ValueError, match="width must be positive"):
Background.two_sided(image, 25, 2, width=-1)


def test_trace_inputs(mk_test_img_raw):

image = mk_test_img_raw

# When `Background` object is created with no Trace object passed in it should
# create a FlatTrace in the middle of the image (according to disp. axis)
background = Background(image, width=5)
assert np.all(background.traces[0].trace.data == image.shape[1] / 2.)

# FlatTrace(s) should be created if number or list of numbers is passed in for `traces`
background = Background(image, 10., width=5)
assert isinstance(background.traces[0], FlatTrace)
assert background.traces[0].trace_pos == 10.

traces = [10., 15]
background = Background(image, traces, width=5)
for i, trace_pos in enumerate(traces):
assert background.traces[i].trace_pos == trace_pos

# make sure error is raised if input for `traces` is invalid
match_str = 'objects, a number or list of numbers to define FlatTraces, ' +\
'or None to use a FlatTrace in the middle of the image.'
with pytest.raises(ValueError, match=match_str):
Background(image, 'non_valid_trace_pos')
8 changes: 6 additions & 2 deletions specreduce/tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ def test_flat_trace():
t.set_position(400.)
assert t[0] == 400.

t.set_position(-100)
assert np.ma.is_masked(t[0])

def test_negative_flat_trace_err():
# make sure correct error is raised when trying to create FlatTrace with
# negative trace_pos
with pytest.raises(ValueError, match='must be positive.'):
FlatTrace(IM, trace_pos=-1)


# test array traces
Expand Down
2 changes: 2 additions & 0 deletions specreduce/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def set_position(self, trace_pos):
trace_pos : float
Position of the trace
"""
if trace_pos < 1:
raise ValueError('`trace_pos` must be positive.')
self.trace_pos = trace_pos
self.trace = np.ones_like(self.image.data[0]) * self.trace_pos
self._bound_trace()
Expand Down

0 comments on commit ddc7aab

Please sign in to comment.