Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix input for background.traces, raise error in FlatTrace for negative trace #211

Merged
merged 1 commit into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ Bug Fixes
were set to the number of rows in the image, biasing the final fit to all bin
peaks. Previously for Gaussian, the entire fit failed. [#205, #206]

- Fixed input of `traces` in `Background`. Added a condition to 'FlatTrace' that
trace position must be a positive number. [#211]

Other changes
^^^^^^^^^^^^^

Expand Down
54 changes: 38 additions & 16 deletions specreduce/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ class Background(_ImageParser):
----------
image : `~astropy.nddata.NDData`-like or array-like
image with 2-D spectral image data
traces : List
list of trace objects (or integers to define FlatTraces) to
extract the background
traces : trace, int, float (single or list)
Individual or list of trace object(s) (or integers/floats to define
FlatTraces) to extract the background. If None, a FlatTrace at the
center of the image (according to `disp_axis`) will be used.
width : float
width of extraction aperture in pixels
statistic: string
Expand Down Expand Up @@ -82,16 +83,6 @@ def __post_init__(self):
crossdisp_axis : int
cross-dispersion axis
"""
def _to_trace(trace):
if not isinstance(trace, Trace):
trace = FlatTrace(self.image, trace)

# TODO: this check can be removed if/when implemented as a check in FlatTrace
if isinstance(trace, FlatTrace):
if trace.trace_pos < 1:
raise ValueError('trace_object.trace_pos must be >= 1')
return trace

self.image = self._parse_image(self.image)

if self.width < 0:
Expand All @@ -100,12 +91,10 @@ def _to_trace(trace):
self._bkg_array = np.zeros(self.image.shape[self.disp_axis])
return

if isinstance(self.traces, Trace):
self.traces = [self.traces]
self._set_traces()

bkg_wimage = np.zeros_like(self.image.data, dtype=np.float64)
for trace in self.traces:
trace = _to_trace(trace)
windows_max = trace.trace.data.max() + self.width/2
windows_min = trace.trace.data.min() - self.width/2
if windows_max >= self.image.shape[self.crossdisp_axis]:
Expand Down Expand Up @@ -150,6 +139,39 @@ def _to_trace(trace):
else:
raise ValueError("statistic must be 'average' or 'median'")

def _set_traces(self):
"""Determine `traces` from input. If an integer/float or list if int/float
is passed in, use these to construct FlatTrace objects. These values
must be positive. If None (which is initialized to an empty list),
construct a FlatTrace using the center of image (according to disp.
axis). Otherwise, any Trace object or list of Trace objects can be
passed in."""

if self.traces == []:
# assume a flat trace at the image center if nothing is passed in.
trace_pos = self.image.shape[self.disp_axis] / 2.
self.traces = [FlatTrace(self.image, trace_pos)]

if isinstance(self.traces, Trace):
# if just one trace, turn it into iterable.
self.traces = [self.traces]
return

# finally, if float/int is passed in convert to FlatTrace(s)
if isinstance(self.traces, (float, int)): # for a single number
self.traces = [self.traces]

if np.all([isinstance(x, (float, int)) for x in self.traces]):
self.traces = [FlatTrace(self.image, trace_pos) for trace_pos in self.traces]
return

else:
if not np.all([isinstance(x, Trace) for x in self.traces]):
raise ValueError('`traces` must be a `Trace` object or list of '
'`Trace` objects, a number or list of numbers to '
'define FlatTraces, or None to use a FlatTrace in '
'the middle of the image.')

@classmethod
def two_sided(cls, image, trace_object, separation, **kwargs):
"""
Expand Down
26 changes: 26 additions & 0 deletions specreduce/tests/test_background.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,29 @@ def test_warnings_errors(mk_test_spec_no_spectral_axis):

with pytest.raises(ValueError, match="width must be positive"):
Background.two_sided(image, 25, 2, width=-1)


def test_trace_inputs(mk_test_img_raw):

image = mk_test_img_raw

# When `Background` object is created with no Trace object passed in it should
# create a FlatTrace in the middle of the image (according to disp. axis)
background = Background(image, width=5)
assert np.all(background.traces[0].trace.data == image.shape[1] / 2.)

# FlatTrace(s) should be created if number or list of numbers is passed in for `traces`
background = Background(image, 10., width=5)
assert isinstance(background.traces[0], FlatTrace)
assert background.traces[0].trace_pos == 10.

traces = [10., 15]
background = Background(image, traces, width=5)
for i, trace_pos in enumerate(traces):
assert background.traces[i].trace_pos == trace_pos

# make sure error is raised if input for `traces` is invalid
match_str = 'objects, a number or list of numbers to define FlatTraces, ' +\
'or None to use a FlatTrace in the middle of the image.'
with pytest.raises(ValueError, match=match_str):
Background(image, 'non_valid_trace_pos')
8 changes: 6 additions & 2 deletions specreduce/tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ def test_flat_trace():
t.set_position(400.)
assert t[0] == 400.

t.set_position(-100)
assert np.ma.is_masked(t[0])

def test_negative_flat_trace_err():
# make sure correct error is raised when trying to create FlatTrace with
# negative trace_pos
with pytest.raises(ValueError, match='must be positive.'):
FlatTrace(IM, trace_pos=-1)


# test array traces
Expand Down
2 changes: 2 additions & 0 deletions specreduce/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def set_position(self, trace_pos):
trace_pos : float
Position of the trace
"""
if trace_pos < 1:
raise ValueError('`trace_pos` must be positive.')
self.trace_pos = trace_pos
self.trace = np.ones_like(self.image.data[0]) * self.trace_pos
self._bound_trace()
Expand Down