diff --git a/CHANGES.rst b/CHANGES.rst index f9ef142..dd944aa 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -34,6 +34,9 @@ Bug Fixes were set to the number of rows in the image, biasing the final fit to all bin peaks. Previously for Gaussian, the entire fit failed. [#205, #206] +- Fixed input of `traces` in `Background`. Added a condition to 'FlatTrace' that + trace position must be a positive number. [#211] + Other changes ^^^^^^^^^^^^^ diff --git a/specreduce/background.py b/specreduce/background.py index 0797b07..59cf59c 100644 --- a/specreduce/background.py +++ b/specreduce/background.py @@ -82,16 +82,6 @@ def __post_init__(self): crossdisp_axis : int cross-dispersion axis """ - def _to_trace(trace): - if not isinstance(trace, Trace): - trace = FlatTrace(self.image, trace) - - # TODO: this check can be removed if/when implemented as a check in FlatTrace - if isinstance(trace, FlatTrace): - if trace.trace_pos < 1: - raise ValueError('trace_object.trace_pos must be >= 1') - return trace - self.image = self._parse_image(self.image) if self.width < 0: @@ -100,12 +90,10 @@ def _to_trace(trace): self._bkg_array = np.zeros(self.image.shape[self.disp_axis]) return - if isinstance(self.traces, Trace): - self.traces = [self.traces] + self._set_traces() bkg_wimage = np.zeros_like(self.image.data, dtype=np.float64) for trace in self.traces: - trace = _to_trace(trace) windows_max = trace.trace.data.max() + self.width/2 windows_min = trace.trace.data.min() - self.width/2 if windows_max >= self.image.shape[self.crossdisp_axis]: @@ -150,6 +138,39 @@ def _to_trace(trace): else: raise ValueError("statistic must be 'average' or 'median'") + def _set_traces(self): + """Determine `traces` from input. If an integer/float or list if int/float + is passed in, use these to construct FlatTrace objects. These values + must be positive. If None (which is initialized to an empty list), + construct a FlatTrace using the center of image (according to disp. + axis). Otherwise, any Trace object or list of Trace objects can be + passed in.""" + + if self.traces == []: + # assume a flat trace at the image center if nothing is passed in. + trace_pos = self.image.shape[self.disp_axis] / 2. + self.traces = [FlatTrace(self.image, trace_pos)] + + if isinstance(self.traces, Trace): + # if just one trace, turn it into iterable. + self.traces = [self.traces] + return + + # finally, if float/int is passed in convert to FlatTrace(s) + if isinstance(self.traces, (float, int)): # for a single number + self.traces = [self.traces] + + if np.all([isinstance(x, (float, int)) for x in self.traces]): + self.traces = [FlatTrace(self.image, trace_pos) for trace_pos in self.traces] + return + + else: + if not np.all([isinstance(x, Trace) for x in self.traces]): + raise ValueError('`traces` must be a `Trace` object or list of ' + '`Trace` objects, a number or list of numbers to ' + 'define FlatTraces, or None to use a FlatTrace in ' + 'the middle of the image.') + @classmethod def two_sided(cls, image, trace_object, separation, **kwargs): """ diff --git a/specreduce/tests/test_background.py b/specreduce/tests/test_background.py index 11095d5..09364c3 100644 --- a/specreduce/tests/test_background.py +++ b/specreduce/tests/test_background.py @@ -109,3 +109,29 @@ def test_warnings_errors(mk_test_spec_no_spectral_axis): with pytest.raises(ValueError, match="width must be positive"): Background.two_sided(image, 25, 2, width=-1) + + +def test_trace_inputs(mk_test_img_raw): + + image = mk_test_img_raw + + # When `Background` object is created with no Trace object passed in it should + # create a FlatTrace in the middle of the image (according to disp. axis) + background = Background(image, width=5) + assert np.all(background.traces[0].trace.data == image.shape[1] / 2.) + + # FlatTrace(s) should be created if number or list of numbers is passed in for `traces` + background = Background(image, 10., width=5) + assert isinstance(background.traces[0], FlatTrace) + assert background.traces[0].trace_pos == 10. + + traces = [10., 15] + background = Background(image, traces, width=5) + for i, trace_pos in enumerate(traces): + assert background.traces[i].trace_pos == trace_pos + + # make sure error is raised if input for `traces` is invalid + match_str = 'objects, a number or list of numbers to define FlatTraces, ' +\ + 'or None to use a FlatTrace in the middle of the image.' + with pytest.raises(ValueError, match=match_str): + Background(image, 'non_valid_trace_pos') diff --git a/specreduce/tests/test_tracing.py b/specreduce/tests/test_tracing.py index a9aa115..867ef2c 100644 --- a/specreduce/tests/test_tracing.py +++ b/specreduce/tests/test_tracing.py @@ -36,8 +36,12 @@ def test_flat_trace(): t.set_position(400.) assert t[0] == 400. - t.set_position(-100) - assert np.ma.is_masked(t[0]) + +def test_negative_flat_trace_err(): + # make sure correct error is raised when trying to create FlatTrace with + # negative trace_pos + with pytest.raises(ValueError, match='must be positive.'): + FlatTrace(IM, trace_pos=-1) # test array traces diff --git a/specreduce/tracing.py b/specreduce/tracing.py index e57231a..7964480 100644 --- a/specreduce/tracing.py +++ b/specreduce/tracing.py @@ -110,6 +110,8 @@ def set_position(self, trace_pos): trace_pos : float Position of the trace """ + if trace_pos < 1: + raise ValueError('`trace_pos` must be positive.') self.trace_pos = trace_pos self.trace = np.ones_like(self.image.data[0]) * self.trace_pos self._bound_trace()