Skip to content

Commit

Permalink
clean up image parsing and tests (#210)
Browse files Browse the repository at this point in the history
  • Loading branch information
cshanahan1 authored Feb 10, 2024
1 parent cb5071f commit 85e1b20
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 20 deletions.
6 changes: 3 additions & 3 deletions specreduce/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from astropy.utils.decorators import deprecated_attribute
from specutils import Spectrum1D

from specreduce.core import _ImageParser, _get_data_from_image
from specreduce.core import _ImageParser
from specreduce.extract import _ap_weight_image
from specreduce.tracing import Trace, FlatTrace

Expand Down Expand Up @@ -183,7 +183,7 @@ def two_sided(cls, image, trace_object, separation, **kwargs):
crossdisp_axis : int
cross-dispersion axis
"""
image = _get_data_from_image(image) if image is not None else cls.image
image = _ImageParser._get_data_from_image(image) if image is not None else cls.image
kwargs['traces'] = [trace_object-separation, trace_object+separation]
return cls(image=image, **kwargs)

Expand Down Expand Up @@ -220,7 +220,7 @@ def one_sided(cls, image, trace_object, separation, **kwargs):
crossdisp_axis : int
cross-dispersion axis
"""
image = _get_data_from_image(image) if image is not None else cls.image
image = _ImageParser._get_data_from_image(image) if image is not None else cls.image
kwargs['traces'] = [trace_object+separation]
return cls(image=image, **kwargs)

Expand Down
28 changes: 14 additions & 14 deletions specreduce/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,6 @@
__all__ = ['SpecreduceOperation']


def _get_data_from_image(image):
"""Extract data array from various input types for `image`.
Retruns `np.ndarray` of image data."""

if isinstance(image, u.quantity.Quantity):
img = image.value
if isinstance(image, np.ndarray):
img = image
else: # NDData, including CCDData and Spectrum1D
img = image.data
return img


class _ImageParser:
"""
Coerces images from accepted formats to Spectrum1D objects for
Expand Down Expand Up @@ -64,7 +51,7 @@ def _parse_image(self, image, disp_axis=1):
# useful for Background's instance methods
return self.image

img = _get_data_from_image(image)
img = self._get_data_from_image(image)

# mask and uncertainty are set as None when they aren't specified upon
# creating a Spectrum1D object, so we must check whether these
Expand All @@ -87,6 +74,19 @@ def _parse_image(self, image, disp_axis=1):
return Spectrum1D(img * unit, spectral_axis=spectral_axis,
uncertainty=uncertainty, mask=mask)

@staticmethod
def _get_data_from_image(image):
"""Extract data array from various input types for `image`.
Retruns `np.ndarray` of image data."""

if isinstance(image, u.quantity.Quantity):
img = image.value
if isinstance(image, np.ndarray):
img = image
else: # NDData, including CCDData and Spectrum1D
img = image.data
return img


@dataclass
class SpecreduceOperation(_ImageParser):
Expand Down
8 changes: 5 additions & 3 deletions specreduce/tests/test_image_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from astropy import units as u
from specutils import Spectrum1D

from specreduce.core import _ImageParser
from specreduce.extract import HorneExtract
from specreduce.tracing import FlatTrace

Expand Down Expand Up @@ -48,7 +49,8 @@ def compare_images(all_images, key, collection, compare='s1d'):

# test consistency of general image parser results
def test_parse_general(all_images):
all_images_parsed = {k: FlatTrace._parse_image(object, im)

all_images_parsed = {k: _ImageParser()._parse_image(im)
for k, im in all_images.items()}
for key in all_images_parsed.keys():
compare_images(all_images, key, all_images_parsed)
Expand All @@ -66,7 +68,7 @@ def test_parse_horne(all_images):

for key, col in images_collection.items():
img = all_images[key]
col['general'] = FlatTrace._parse_image(object, img)
col['general'] = _ImageParser()._parse_image(img)

if hasattr(all_images[key], 'uncertainty'):
defaults = {}
Expand All @@ -79,6 +81,6 @@ def test_parse_horne(all_images):
'mask': ~np.isfinite(img),
'unit': getattr(img, 'unit', u.DN)}

col[key] = HorneExtract._parse_image(object, img, **defaults)
col[key] = HorneExtract(img, FlatTrace(img, 2))._parse_image(img, **defaults)

compare_images(all_images, key, col, compare='general')

0 comments on commit 85e1b20

Please sign in to comment.