From 85e1b20f3e8e51d1322916a00adf4caf68a49899 Mon Sep 17 00:00:00 2001 From: Clare Shanahan Date: Fri, 9 Feb 2024 22:32:14 -0500 Subject: [PATCH] clean up image parsing and tests (#210) --- specreduce/background.py | 6 +++--- specreduce/core.py | 28 +++++++++++++------------- specreduce/tests/test_image_parsing.py | 8 +++++--- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/specreduce/background.py b/specreduce/background.py index aa0c037..0797b07 100644 --- a/specreduce/background.py +++ b/specreduce/background.py @@ -9,7 +9,7 @@ from astropy.utils.decorators import deprecated_attribute from specutils import Spectrum1D -from specreduce.core import _ImageParser, _get_data_from_image +from specreduce.core import _ImageParser from specreduce.extract import _ap_weight_image from specreduce.tracing import Trace, FlatTrace @@ -183,7 +183,7 @@ def two_sided(cls, image, trace_object, separation, **kwargs): crossdisp_axis : int cross-dispersion axis """ - image = _get_data_from_image(image) if image is not None else cls.image + image = _ImageParser._get_data_from_image(image) if image is not None else cls.image kwargs['traces'] = [trace_object-separation, trace_object+separation] return cls(image=image, **kwargs) @@ -220,7 +220,7 @@ def one_sided(cls, image, trace_object, separation, **kwargs): crossdisp_axis : int cross-dispersion axis """ - image = _get_data_from_image(image) if image is not None else cls.image + image = _ImageParser._get_data_from_image(image) if image is not None else cls.image kwargs['traces'] = [trace_object+separation] return cls(image=image, **kwargs) diff --git a/specreduce/core.py b/specreduce/core.py index 41d0eb2..a997147 100644 --- a/specreduce/core.py +++ b/specreduce/core.py @@ -11,19 +11,6 @@ __all__ = ['SpecreduceOperation'] -def _get_data_from_image(image): - """Extract data array from various input types for `image`. - Retruns `np.ndarray` of image data.""" - - if isinstance(image, u.quantity.Quantity): - img = image.value - if isinstance(image, np.ndarray): - img = image - else: # NDData, including CCDData and Spectrum1D - img = image.data - return img - - class _ImageParser: """ Coerces images from accepted formats to Spectrum1D objects for @@ -64,7 +51,7 @@ def _parse_image(self, image, disp_axis=1): # useful for Background's instance methods return self.image - img = _get_data_from_image(image) + img = self._get_data_from_image(image) # mask and uncertainty are set as None when they aren't specified upon # creating a Spectrum1D object, so we must check whether these @@ -87,6 +74,19 @@ def _parse_image(self, image, disp_axis=1): return Spectrum1D(img * unit, spectral_axis=spectral_axis, uncertainty=uncertainty, mask=mask) + @staticmethod + def _get_data_from_image(image): + """Extract data array from various input types for `image`. + Retruns `np.ndarray` of image data.""" + + if isinstance(image, u.quantity.Quantity): + img = image.value + if isinstance(image, np.ndarray): + img = image + else: # NDData, including CCDData and Spectrum1D + img = image.data + return img + @dataclass class SpecreduceOperation(_ImageParser): diff --git a/specreduce/tests/test_image_parsing.py b/specreduce/tests/test_image_parsing.py index 39765f3..1e50be2 100644 --- a/specreduce/tests/test_image_parsing.py +++ b/specreduce/tests/test_image_parsing.py @@ -2,6 +2,7 @@ from astropy import units as u from specutils import Spectrum1D +from specreduce.core import _ImageParser from specreduce.extract import HorneExtract from specreduce.tracing import FlatTrace @@ -48,7 +49,8 @@ def compare_images(all_images, key, collection, compare='s1d'): # test consistency of general image parser results def test_parse_general(all_images): - all_images_parsed = {k: FlatTrace._parse_image(object, im) + + all_images_parsed = {k: _ImageParser()._parse_image(im) for k, im in all_images.items()} for key in all_images_parsed.keys(): compare_images(all_images, key, all_images_parsed) @@ -66,7 +68,7 @@ def test_parse_horne(all_images): for key, col in images_collection.items(): img = all_images[key] - col['general'] = FlatTrace._parse_image(object, img) + col['general'] = _ImageParser()._parse_image(img) if hasattr(all_images[key], 'uncertainty'): defaults = {} @@ -79,6 +81,6 @@ def test_parse_horne(all_images): 'mask': ~np.isfinite(img), 'unit': getattr(img, 'unit', u.DN)} - col[key] = HorneExtract._parse_image(object, img, **defaults) + col[key] = HorneExtract(img, FlatTrace(img, 2))._parse_image(img, **defaults) compare_images(all_images, key, col, compare='general')