Skip to content

Commit

Permalink
fix input for background.traces, raise error in FlatTrace for negativ…
Browse files Browse the repository at this point in the history
…e trace
  • Loading branch information
cshanahan1 committed Feb 20, 2024
1 parent cb5071f commit ca99d3e
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 15 deletions.
47 changes: 34 additions & 13 deletions specreduce/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,6 @@ def __post_init__(self):
crossdisp_axis : int
cross-dispersion axis
"""
def _to_trace(trace):
if not isinstance(trace, Trace):
trace = FlatTrace(self.image, trace)

# TODO: this check can be removed if/when implemented as a check in FlatTrace
if isinstance(trace, FlatTrace):
if trace.trace_pos < 1:
raise ValueError('trace_object.trace_pos must be >= 1')
return trace

self.image = self._parse_image(self.image)

if self.width < 0:
Expand All @@ -100,12 +90,10 @@ def _to_trace(trace):
self._bkg_array = np.zeros(self.image.shape[self.disp_axis])
return

if isinstance(self.traces, Trace):
self.traces = [self.traces]
self._set_traces()

bkg_wimage = np.zeros_like(self.image.data, dtype=np.float64)
for trace in self.traces:
trace = _to_trace(trace)
windows_max = trace.trace.data.max() + self.width/2
windows_min = trace.trace.data.min() - self.width/2
if windows_max >= self.image.shape[self.crossdisp_axis]:
Expand Down Expand Up @@ -150,6 +138,39 @@ def _to_trace(trace):
else:
raise ValueError("statistic must be 'average' or 'median'")

def _set_traces(self):
"""Determine `traces` from input. If an integer/float or list if int/float
is passed in, use these to construct FlatTrace objects. These values
must be positive. If None (which is initialized to an empty list),
construct a FlatTrace using the center of image (according to disp.
axis). Otherwise, any Trace object or list of Trace objects can be
passed in."""

if self.traces == []:
# assume a flat trace at the image center if nothing is passed in.
trace_pos = self.image.shape[self.disp_axis] / 2.
self.traces = [FlatTrace(self.image, trace_pos)]

if isinstance(self.traces, Trace):
# if just one trace, turn it into iterable.
self.traces = [self.traces]
return

# finally, if float/int is passed in convert to FlatTrace(s)
if isinstance(self.traces, (float, int)): # for a single number
self.traces = [self.traces]

if np.all([isinstance(x, (float, int)) for x in self.traces]):
self.traces = [FlatTrace(self.image, trace_pos) for trace_pos in self.traces]
return

else:
if not np.all([isinstance(x, Trace) for x in self.traces]):
raise ValueError('`traces` must be a `Trace` object or list of '

Check warning on line 169 in specreduce/background.py

View check run for this annotation

Codecov / codecov/patch

specreduce/background.py#L169

Added line #L169 was not covered by tests
'`Trace` objects, a number or list of numbers to '
'define FlatTraces, or None to use a FlatTrace in '
'the middle of the image.')

@classmethod
def two_sided(cls, image, trace_object, separation, **kwargs):
"""
Expand Down
20 changes: 20 additions & 0 deletions specreduce/tests/test_background.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,23 @@ def test_warnings_errors(mk_test_spec_no_spectral_axis):

with pytest.raises(ValueError, match="width must be positive"):
Background.two_sided(image, 25, 2, width=-1)


def test_trace_inputs(mk_test_img_raw):

image = mk_test_img_raw

# When `Background` object is created with no Trace object passed in it should
# create a FlatTrace in the middle of the image (according to disp. axis)
background = Background(image, width=5)
assert np.all(background.traces[0].trace.data == image.shape[1] / 2.)

# FlatTrace(s) should be created if number or list of numbers is passed in for `traces`
background = Background(image, 10., width=5)
assert isinstance(background.traces[0], FlatTrace)
assert background.traces[0].trace_pos == 10.

traces = [10., 12]
background = Background(image, traces, width=5)
for i, trace_pos in enumerate(traces):
assert background.traces[i].trace_pos == trace_pos
8 changes: 6 additions & 2 deletions specreduce/tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ def test_flat_trace():
t.set_position(400.)
assert t[0] == 400.

t.set_position(-100)
assert np.ma.is_masked(t[0])

def test_negative_flat_trace_err():
# make sure correct error is raised when trying to create FlatTrace with
# negative trace_pos
with pytest.raises(ValueError, match='must be positive.'):
FlatTrace(IM, trace_pos=-1)


# test array traces
Expand Down
2 changes: 2 additions & 0 deletions specreduce/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def set_position(self, trace_pos):
trace_pos : float
Position of the trace
"""
if trace_pos < 1:
raise ValueError('`trace_pos` must be positive.')
self.trace_pos = trace_pos
self.trace = np.ones_like(self.image.data[0]) * self.trace_pos
self._bound_trace()
Expand Down

0 comments on commit ca99d3e

Please sign in to comment.