From 5da615d3dfdce4765d18225600c878b163e27f59 Mon Sep 17 00:00:00 2001 From: Mihai Cara Date: Tue, 1 Oct 2024 03:30:05 -0400 Subject: [PATCH] Refactor previous code to work with arrays only --- src/stcal/resample/__init__.py | 7 +- src/stcal/resample/resample.py | 1731 ++++++++++++++------------------ src/stcal/resample/utils.py | 64 +- 3 files changed, 832 insertions(+), 970 deletions(-) diff --git a/src/stcal/resample/__init__.py b/src/stcal/resample/__init__.py index 1ae898af0..2baa9834a 100644 --- a/src/stcal/resample/__init__.py +++ b/src/stcal/resample/__init__.py @@ -1,9 +1,8 @@ from .resample import * __all__ = [ + "LibModelAccess", "OutputTooLargeError", - "ResampleModelIO", - "ResampleBase", - "ResampleCoAdd", - "ResampleSingle" + "Resample", + "resampled_wcs_from_models", ] diff --git a/src/stcal/resample/resample.py b/src/stcal/resample/resample.py index d5e10eb91..bffb13020 100644 --- a/src/stcal/resample/resample.py +++ b/src/stcal/resample/resample.py @@ -1,10 +1,9 @@ +import abc +from copy import deepcopy import logging import os -import warnings -from copy import deepcopy import sys -import abc -from pathlib import Path, PurePath +import warnings import numpy as np from scipy.ndimage import median_filter @@ -14,24 +13,29 @@ import psutil from spherical_geometry.polygon import SphericalPolygon +from astropy.nddata.bitmask import ( + bitfield_to_boolean_mask, + interpret_bit_flags, +) -from astropy.nddata.bitmask import interpret_bit_flags - -from .utils import get_tmeasure, build_mask +from .utils import bytes2human, get_tmeasure +from ..alignment.util import ( + compute_scale, + wcs_bbox_from_shape, + wcs_from_footprints, +) log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) __all__ = [ + "LibModelAccess", "OutputTooLargeError", - "ResampleModelIO", - "ResampleBase", - "ResampleCoAdd", - "ResampleSingle" + "Resample", + "resampled_wcs_from_models", ] - _SUPPORTED_CUSTOM_WCS_PARS = [ 'pixel_scale_ratio', 'pixel_scale', @@ -42,43 +46,6 @@ ] -# FIXME: temporarily copied here to avoid this import: -# from stdatamodels.jwst.library.basic_utils import bytes2human -def bytes2human(n): - """Convert bytes to human-readable format - - Taken from the `psutil` library which references - http://code.activestate.com/recipes/578019 - - Parameters - ---------- - n : int - Number to convert - - Returns - ------- - readable : str - A string with units attached. - - Examples - -------- - >>> bytes2human(10000) - '9.8K' - - >>> bytes2human(100001221) - '95.4M' - """ - symbols = ('K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y') - prefix = {} - for i, s in enumerate(symbols): - prefix[s] = 1 << (i + 1) * 10 - for s in reversed(symbols): - if n >= prefix[s]: - value = float(n) / prefix[s] - return '%.1f%s' % (value, s) - return "%sB" % n - - def _resample_range(data_shape, bbox=None): # Find range of input pixels to resample: if bbox is None: @@ -95,58 +62,186 @@ def _resample_range(data_shape, bbox=None): return xmin, xmax, ymin, ymax -class ResampleModelIO(abc.ABC): - @abc.abstractmethod - def open_model(self, file_name): - ... +class LibModelAccess(abc.ABC): + # list of model attributes needed by this module. While this is not + # required, it is helpful for subclasses to check they know how to + # access these attributes. + min_supported_attributes = [ + # arrays: + "data", + "dq", + "var_rnoise", + "var_poisson", + "var_flat", + + # meta: + "filename", + "group_id", + "s_region", + "wcsinfo", + "wcs", + + "exposure_time", + "start_time", + "end_time", + "duration", + "measurement_time", + "effective_exposure_time", + "elapsed_exposure_time", + + "pixelarea_steradians", +# "pixelarea_arcsecsq", + + "level", # sky level + "subtracted", + + "weight_type", + "pointings", + "n_coadds", + ] @abc.abstractmethod - def get_model_attr_value(self, model, attribute_name): + def iter_model(self, attributes=None): ... + @property @abc.abstractmethod - def set_model_attr_value(self, model, attribute_name, value): + def n_models(self): ... + @property @abc.abstractmethod - def get_model_meta(self, model, attributes): + def n_groups(self): ... - @abc.abstractmethod - def set_model_meta(self, model, attributes): - ... - @abc.abstractmethod - def get_model_array(self, model, array_name): - ... +def resampled_wcs_from_models( + input_models, + pixel_scale_ratio=1.0, + pixel_scale=None, + output_shape=None, + rotation=None, + crpix=None, + crval=None, +): + """ + Computes the WCS of the resampled image from input models and + specified WCS parameters. - @abc.abstractmethod - def set_model_array(self, model, array_name, data): - ... + Parameters + ---------- - @abc.abstractmethod - def close_model(self, model): - ... + input_models : LibModelAccess + An object of `LibModelAccess`-derived type. + + pixel_scale_ratio : float, optional + Desired pixel scale ratio defined as the ratio of the first model's + pixel scale computed from this model's WCS at the fiducial point + (taken as the ``ref_ra`` and ``ref_dec`` from the ``wcsinfo`` meta + attribute of the first input image) to the desired output pixel + scale. Ignored when ``pixel_scale`` is specified. + + pixel_scale : float, None, optional + Desired pixel scale (in degrees) of the output WCS. When provided, + overrides ``pixel_scale_ratio``. + + output_shape : tuple of two integers (int, int), None, optional + Shape of the image (data array) using ``np.ndarray`` convention + (``ny`` first and ``nx`` second). This value will be assigned to + ``pixel_shape`` and ``array_shape`` properties of the returned + WCS object. + + rotation : float, None, optional + Position angle of output image's Y-axis relative to North. + A value of 0.0 would orient the final output image to be North up. + The default of `None` specifies that the images will not be rotated, + but will instead be resampled in the default orientation for the + camera with the x and y axes of the resampled image corresponding + approximately to the detector axes. Ignored when ``transform`` is + provided. + + crpix : tuple of float, None, optional + Position of the reference pixel in the resampled image array. + If ``crpix`` is not specified, it will be set to the center of the + bounding box of the returned WCS object. + + crval : tuple of float, None, optional + Right ascension and declination of the reference pixel. + Automatically computed if not provided. - @abc.abstractmethod - def save_model(self, model): - ... + Returns + ------- + wcs : ~gwcs.wcs.WCS + The WCS object corresponding to the combined input footprints. - @abc.abstractmethod - def write_model(self, model, file_name, **kwargs): - ... + pscale_in : float + Computed pixel scale (in degrees) of the first input image. - @abc.abstractmethod - def new_model(self, image_shape=None, file_name=None, copy_meta_from=None): - """ Return a new model for the resampled output """ - ... + pscale_out : float + Computed pixel scale (in degrees) of the output image. + + """ + # build a list of WCS of all input models: + wcs_list = [] + ref_wcsinfo = None + for model_info, _ in input_models.iter_model( + attributes=["data", "wcs", "wcsinfo"] + ): + # TODO: is deepcopy necessary? Is ModelLibrary read-only by default? + w = deepcopy(model_info["wcs"]) + if ref_wcsinfo is None: + ref_wcsinfo = model_info["wcsinfo"] + # make sure all WCS objects have the bounding_box defined: + if w.bounding_box is None: + bbox = wcs_bbox_from_shape(model_info["data"].shape) + w.bounding_box = bbox + wcs_list.append(w) + + if output_shape is None: + bounding_box = None + else: + bounding_box = wcs_bbox_from_shape(output_shape) + + pscale_in0 = compute_scale( + wcs_list[0], + fiducial=np.array([ref_wcsinfo["ra_ref"], ref_wcsinfo["dec_ref"]]) + ) + + if pixel_scale is None: + pixel_scale = pscale_in0 / pixel_scale_ratio + log.info( + f"Pixel scale ratio (pscale_in / pscale_out): {pixel_scale_ratio}" + ) + log.info(f"Computed output pixel scale: {3600 * pixel_scale} arcsec.") + else: + pixel_scale_ratio = pscale_in0 / pixel_scale + log.info(f"Output pixel scale: {3600 * pixel_scale} arcsec.") + log.info( + "Computed pixel scale ratio (pscale_in / pscale_out): " + f"{pixel_scale_ratio}." + ) + + wcs = wcs_from_footprints( + wcs_list=wcs_list, + ref_wcs=wcs_list[0], + ref_wcsinfo=ref_wcsinfo, + pscale_ratio=pixel_scale_ratio, + pscale=pixel_scale, + rotation=rotation, + bounding_box=bounding_box, + shape=output_shape, + crpix=crpix, + crval=crval, + ) + + return wcs, pscale_in0, pixel_scale, pixel_scale_ratio class OutputTooLargeError(RuntimeError): """Raised when the output is too large for in-memory instantiation""" -class ResampleBase(abc.ABC): +class Resample: """ This is the controlling routine for the resampling process. @@ -165,21 +260,31 @@ class ResampleBase(abc.ABC): resample_suffix = 'i2d' resample_file_ext = '.fits' n_arrays_per_output = 2 # #flt-point arrays in the output (data, weight, var, err, etc.) + + # supported output arrays (subclasses can add more): + output_array_types = { + "data": np.float32, + "wht": np.float32, + "con": np.int32, + "var_rnoise": np.float32, + "var_flat": np.float32, + "var_poisson": np.float32, + "err": np.float32, + } + dq_flag_name_map = {} - def __init__(self, input_models, - pixfrac=1.0, kernel="square", fillval=0.0, wht_type="ivm", - good_bits=0, output_wcs=None, wcs_pars=None, - enable_ctx=True, - in_memory=True, allowed_memory=None, **kwargs): + def __init__(self, input_models, pixfrac=1.0, kernel="square", + fillval=0.0, wht_type="ivm", good_bits=0, + output_wcs=None, wcs_pars=None, output_model=None, + accumulate=False, enable_ctx=True, enable_var=True, + allowed_memory=None): """ Parameters ---------- - input_models : list of objects - list of data models, one for each input image - - output : str - filename for output + input_models : LibModelAccess + A `LibModelAccess` object allowing iterating over all contained + models of interest. kwargs : dict Other parameters. @@ -187,28 +292,21 @@ def __init__(self, input_models, .. note:: ``output_shape`` is in the ``x, y`` order. - .. note:: - ``in_memory`` controls whether or not the resampled - array from ``resample_many_to_many()`` - should be kept in memory or written out to disk and - deleted from memory. Default value is `True` to keep - all products in memory. """ - self._enable_ctx = enable_ctx + # input models + self._input_models = input_models self._output_model = None - self._output_filename = None self._output_wcs = None - self._output_array_shape = None - self._close_output = False - self._output_pixel_scale = None - self._template_output_model = None + self._enable_ctx = enable_ctx + self._enable_var = enable_var + self._accumulate = accumulate - # input models - self._input_models = input_models - # a lightweight data model with meta from first input model but no data. - # it will be updated by 'prload_input_meta()' below - self._first_model_meta = None + # these are attributes that are used only for information purpose + # and are added to created the output_model only if they are not already + # present there: + self._pixel_scale_ratio = None + self._output_pixel_scale = None # resample parameters self.pixfrac = pixfrac @@ -216,9 +314,10 @@ def __init__(self, input_models, self.fillval = fillval self.weight_type = wht_type self.good_bits = good_bits - self.in_memory = in_memory - self._user_output_wcs = output_wcs + self._output_wcs = output_wcs + + self.input_file_names = [] # check wcs_pars has supported keywords: if wcs_pars is None: @@ -231,161 +330,308 @@ def __init__(self, input_models, "Unsupported custom WCS parameters: " f"{','.join(map(repr, unsup))}." ) - # WCS parameters (should be deleted once not needed; - # once an output WCS was created) - self._wcs_pars = wcs_pars - - # process additional kwags specific to subclasses and store - # unprocessed/unrecognized kwargs in ukwargs and warn about these - # unrecognized kwargs - ukwargs = self.process_kwargs(kwargs) - self._warn_extra_kwargs(ukwargs) - - # load meta necessary for output WCS (and other) computations: - self.preload_input_meta( - wcs1=True, - filename=self._output_model is None, - s_region=output_wcs is None, - ) - # computed average pixel scale of the first input image: - input_pscale0 = np.rad2deg( - np.sqrt(_compute_image_pixel_area(self._input_img1_wcs)) - ) + # determine output WCS and set-up output model if needed: + if output_model is None: + if output_wcs is None: + output_wcs, _, ps, ps_ratio = resampled_wcs_from_models( + input_models, + pixel_scale_ratio=wcs_pars.get("pixel_scale_ratio", 1.0), + pixel_scale=wcs_pars.get("pixel_scale"), + output_shape=wcs_pars.get("output_shape"), + rotation=wcs_pars.get("rotation"), + crpix=wcs_pars.get("crpix"), + crval=wcs_pars.get("crval"), + ) + self._output_pixel_scale = ps # degrees + self._pixel_scale_ratio = ps_ratio + else: + self.check_output_wcs(output_wcs, wcs_pars) + self._output_pixel_scale = np.rad2deg( + np.sqrt(_compute_image_pixel_area(output_wcs)) + ) + log.info( + "Computed output pixel scale: " + f"{3600 * self._output_pixel_scale} arcsec." + ) + + self._output_wcs = output_wcs - # compute output pixel scale, WCS, set-up output model - if self._output_model: - self._output_wcs = deepcopy( - self.get_model_attr_value(self._output_model, "wcs") + else: + self.validate_output_model( + output_model=output_model, + output_wcs=output_wcs, + accumulate=accumulate, + enable_ctx=enable_ctx, + enable_var=enable_var, ) - self._output_array_shape = self.get_model_array( - self._output_model, - "data" - ).shape - # TODO: extract any useful info from the output image before we close it: - # if meta has pixel scale, populate it from there, if not: + self._output_model = output_model + self._output_wcs = output_model["wcs"] + if output_wcs: + log.warning( + "'output_wcs' will be ignored. Using the 'wcs' supplied " + "by the 'output_model' instead." + ) self._output_pixel_scale = np.rad2deg( - np.sqrt(_compute_image_pixel_area(self._output_wcs)) + np.sqrt(_compute_image_pixel_area(output_wcs)) ) - self._pixel_scale_ratio = self._output_pixel_scale / input_pscale0 - log.info(f'Computed output pixel scale: {self._output_pixel_scale} arcsec.') - - self._create_output_template_model() # create template before possibly closing output - if self._close_output and not self.in_memory: - self.close_model(self._output_model) - self._output_model = None - - elif output_wcs: - naxes = output_wcs.output_frame.naxes - if naxes != 2: - raise RuntimeError( - "Output WCS needs 2 spatial axes but the " - f"supplied WCS has {naxes} axes." - ) - self._output_wcs = deepcopy(output_wcs) + self._pixel_scale_ratio = output_model.get("wcs", None) + log.info( + "Computed output pixel scale: " + f"{3600 * self._output_pixel_scale} arcsec." + ) + + self._output_array_shape = self._output_wcs.array_shape + + # Check that the output data shape has no zero length dimensions + npix = np.prod(self._output_array_shape) + if not npix: + raise ValueError( + f"Invalid output frame shape: {tuple(self._output_array_shape)}" + ) + + # set up output model (arrays, etc.) + if self._output_model is None: + self._output_model = self.create_output_model( + allowed_memory=allowed_memory + ) + + self._group_ids = [] + + log.info(f"Driz parameter kernel: {self.kernel}") + log.info(f"Driz parameter pixfrac: {self.pixfrac}") + log.info(f"Driz parameter fillval: {self.fillval}") + log.info(f"Driz parameter weight_type: {self.weight_type}") + + log.debug(f"Output mosaic size: {self._output_wcs.pixel_shape}") + + def check_output_wcs(self, output_wcs, wcs_pars, + estimate_output_shape=True): + """ + Check that provided WCS has expected properties and that its + ``array_shape`` property is defined. + + """ + naxes = output_wcs.output_frame.naxes + if naxes != 2: + raise RuntimeError( + "Output WCS needs 2 spatial axes but the " + f"supplied WCS has {naxes} axes." + ) + + # make sure array_shape and pixel_shape are set: + if output_wcs.array_shape is None and estimate_output_shape: if wcs_pars and "output_shape" in wcs_pars: - self._output_array_shape = wcs_pars["output_shape"] + output_wcs.array_shape = wcs_pars["output_shape"] else: - self._output_array_shape = self._output_wcs.array_shape - if not self._output_array_shape and output_wcs.bounding_box: + if output_wcs.bounding_box: halfpix = 0.5 + sys.float_info.epsilon - self._output_array_shape = ( + output_wcs.array_shape = ( int(output_wcs.bounding_box[1][1] + halfpix), int(output_wcs.bounding_box[0][1] + halfpix), ) else: + # TODO: In principle, we could compute footprints of all + # input models, convert them to image coordinates using + # `output_wcs`, and then take max(x_i), max(y_i) as + # output image size. raise ValueError( - "Unable to infer output image size from provided inputs." + "Unable to infer output image size from provided " + "inputs." ) - self._output_wcs.array_shape = self._output_array_shape - self._output_pixel_scale = np.rad2deg( - np.sqrt(_compute_image_pixel_area(self._output_wcs)) + @classmethod + def output_model_attributes(cls, accumulate, enable_ctx, enable_var): + """ + Returns a set of string keywords that must be present in an + 'output_model' that is provided as input at the class initialization. + + """ + # always required: + attributes = { + "data", + "wcs", + "wht", + } + + if enable_ctx: + attributes.add("con") + if enable_var: + attributes.update( + ["var_rnoise", "var_poisson", "var_flat", "err"] + ) + if accumulate: + if enable_ctx: + attributes.add("n_coadds") + + # additional attributes required for input parameter 'output_model' + # when data and weight arrays are not None: + attributes.update( + { + "pixfrac", + "kernel", + "fillval", + "weight_type", + "pointings", + "exposure_time", + "measurement_time", + "start_time", + "end_time", + "duration", + } ) - self._pixel_scale_ratio = self._output_pixel_scale / input_pscale0 - log.info(f'Computed output pixel scale: {self._output_pixel_scale} arcsec.') - self._create_output_template_model() - else: - # build output WCS and calculate ouput image shape - if "pixel_scale" in wcs_pars and wcs_pars['pixel_scale'] is not None: - self._pixel_scale_ratio = wcs_pars["pixel_scale"] / input_pscale0 - log.info(f'Output pixel scale: {wcs_pars["pixel_scale"]} arcsec.') - log.info(f'Computed output pixel scale ratio: {self._pixel_scale_ratio}.') - else: - self._pixel_scale_ratio = wcs_pars.get("pixel_scale_ratio", 1.0) - log.info(f'Output pixel scale ratio: {self._pixel_scale_ratio}') - self._output_pixel_scale = input_pscale0 * self._pixel_scale_ratio - wcs_pars = wcs_pars.copy() - wcs_pars["pixel_scale"] = self._output_pixel_scale - log.info(f'Computed output pixel scale: {self._output_pixel_scale} arcsec.') - - w, ps = self.compute_output_wcs(**wcs_pars) - self._output_wcs = w - self._output_pixel_scale = ps - self._output_array_shape = self._output_wcs.array_shape - self._create_output_template_model() + return attributes - # Check that the output data shape has no zero length dimensions - npix = np.prod(self._output_array_shape) - if not npix: + def validate_output_model(self, output_model, accumulate, + enable_ctx, enable_var): + if output_model is None: + if accumulate: + raise ValueError( + "'output_model' must be defined when 'accumulate' is True." + ) + return + + required_attributes = self.output_model_attributes( + accumulate=accumulate, + enable_ctx=enable_ctx, + enable_var=enable_var, + ) + + for attr in required_attributes: + if attr not in output_model: + raise ValueError( + f"'output_model' dictionary must have '{attr}' set." + ) + + model_wcs = output_model["wcs"] + self.check_output_wcs(model_wcs, estimate_output_shape=False) + wcs_shape = model_wcs.array_shape + ref_shape = output_model["data"].shape + if accumulate and wcs_shape is None: raise ValueError( - f"Invalid output frame shape: {tuple(self._output_array_shape)}" + "Output model's 'wcs' must have 'array_shape' attribute " + "set when 'accumulate' parameter is True." ) - assert self._pixel_scale_ratio - log.info(f"Driz parameter kernel: {self.kernel}") - log.info(f"Driz parameter pixfrac: {self.pixfrac}") - log.info(f"Driz parameter fillval: {self.fillval}") - log.info(f"Driz parameter weight_type: {self.weight_type}") + if not np.array_equiv(wcs_shape, ref_shape): + raise ValueError( + "Output model's 'wcs.array_shape' value is not consistent " + "with the shape of the data array." + ) + + for attr in required_attributes.difference(["data", "wcs"]): + if (isinstance(output_model[attr], np.ndarray) and + not np.array_equiv(output_model[attr].shape, ref_shape)): + raise ValueError( + "'output_wcs.array_shape' value is not consistent " + f"with the shape of the '{attr}' array." + ) + + # TODO: also check "pixfrac", "kernel", "fillval", "weight_type" + # with initializer parameters. log a warning if different. - self.check_memory_requirements(allowed_memory) + def create_output_model(self, allowed_memory): + """ Create a new "output model": a dictionary of data and meta fields. + Check that there is enough memory to hold all arrays. + """ + assert self._output_wcs is not None + assert np.array_equiv( + self._output_wcs.array_shape, + self._output_array_shape + ) + assert self._output_pixel_scale - log.debug('Output mosaic size: {}'.format(self._output_wcs.pixel_shape)) + pix_area = self._output_pixel_scale**2 + + output_model = { + # WCS: + "wcs": deepcopy(self._output_wcs), + + # main arrays: + "data": None, + "wht": None, + "con": None, + + # resample parameters: + "pixfrac": self.pixfrac, + "kernel": self.kernel, + "fillval": self.fillval, + "weight_type": self.weight_type, + + # accumulate-specific: + "n_coadds": 0, + + # pixel scale: + "pixelarea_steradians": pix_area, + "pixelarea_arcsecsq": pix_area * np.rad2deg(3600)**2, + "pixel_scale_ratio": self._pixel_scale_ratio, + + # drizzle info: + "pointings": 0, + + # exposure time: + "exposure_time": 0.0, + "measurement_time": None, + "start_time": None, + "end_time": None, + "duration": 0.0, + } + + if self._enable_var: + output_model.update( + { + "var_rnoise": None, + "var_flat": None, + "var_poisson": None, + "err": None, + } + ) + + if allowed_memory: + self.check_memory_requirements(list(output_model), allowed_memory) + + return output_model @property def output_model(self): return self._output_model - def process_kwargs(self, kwargs): - """ A method called by ``__init__`` to process input kwargs before - output WCS is created and before output model template is created. - - Returns - ------- - kwargs : dict - Unrecognized/not processed ``kwargs``. + @property + def output_array_shape(self): + return self._output_array_shape - """ - return {k : v for k, v in kwargs.items()} + @property + def group_ids(self): + return self._group_ids - def _warn_extra_kwargs(self, kwargs): - for k in kwargs: - log.warning(f"Unrecognized argument '{k}' will be ignored.") + def check_memory_requirements(self, arrays, allowed_memory): + """ Called just before `create_output_model` returns to verify + that there is enough memory to hold the output. - def check_memory_requirements(self, allowed_memory): - """ Called just before '_pre_run_callback()' is called to verify - that there is enough memory to hold the output. """ + """ if allowed_memory is None and "DMODEL_ALLOWED_MEMORY" not in os.environ: return allowed_memory = float(allowed_memory) # get the available memory - available_memory = psutil.virtual_memory().available + psutil.swap_memory().total - - # determine data type of the output model: - out_model = self.new_model((2, 2)) - data = self.get_model_array(out_model) - data_type = data.dtype - del data, out_model + available_memory = ( + psutil.virtual_memory().available + psutil.swap_memory().total + ) # compute the output array size npix = npix = np.prod(self._output_array_shape) nmodels = len(self._input_models) - nconpl = nmodels // 32 + (1 if nmodels % 32 else 0) - n_arr = self.n_arrays_per_output + 2 # 2 comes from pixmap - required_memory = npix * (n_arr * data_type.itemsize + nconpl * 4) + nconpl = nmodels // 32 + (1 if nmodels % 32 else 0) # #context planes + required_memory = 0 + for arr in arrays: + if arr in self.output_array_types: + f = nconpl if arr == "con" else 1 + required_memory += f * self.output_array_types[arr].itemsize + # add pixmap itemsize: + required_memory += 2 * np.dtype(float).itemsize + required_memory *= npix # compare used to available used_fraction = required_memory / available_memory @@ -396,62 +642,18 @@ def check_memory_requirements(self, allowed_memory): f'Model cannot be instantiated.' ) - def compute_output_wcs(self, **wcs_pars): - """ returns a tuple of distortion-free WCS object and its pixel scale """ - ... - - def preload_input_meta(self, wcs1, filename, s_region): - # set-up lists for WCS and file names - self._input_img1_wcs = None - self._input_s_region_list = [] - self._input_filename_list = [] - - # loop over only science exposures in the ModelLibrary - # sci_indices = self._input_models.ind_asn_type("science") - with self._input_models: - for model in self._input_models: - # model = self._input_models.borrow(idx) - - try: - if self.get_model_attr_value(model, "exptype").upper() != "SCIENCE": - self._input_models.shelve(model, modify=False) - continue - except AttributeError: - pass - - if self._input_img1_wcs is None and wcs1: - # extract all info needed from *this* model: - self._input_img1_wcs = deepcopy( - self.get_model_attr_value(model, "wcs") - ) - self._input_img1_wcs.array_shape = self.get_model_array( - model, - "data" - ).shape - - if filename: - self._input_filename_list.append( - self.get_model_attr_value(model, "filename") - ) - - if s_region: - self._input_s_region_list.append( - self.get_model_attr_value(model, "s_region") - ) - - self._input_models.shelve(model, modify=False) - - # store first model's entire meta (except for WCS and data): - if self._first_model_meta is None: - self._first_model_meta = self.new_model(copy_meta_from=model) - - def build_driz_weight(self, model, weight_type=None, good_bits=None): - """Create a weight map for use by drizzle - """ - data = self.get_model_array(model, "data") - dq = self.get_model_array(model, "dq") - - dqmask = build_mask(dq, good_bits, flag_name_map=self.dq_flag_name_map) + def build_driz_weight(self, model_info, weight_type=None, good_bits=None): + """Create a weight map for use by drizzle. """ + data = model_info["data"] + dq = model_info["dq"] + + dqmask = bitfield_to_boolean_mask( + dq, + good_bits, + good_mask_value=1, + dtype=np.uint8, + flag_name_map=self.dq_flag_name_map, + ) if weight_type and weight_type.startswith('ivm'): weight_type = weight_type.strip() @@ -466,14 +668,14 @@ def build_driz_weight(self, model, weight_type=None, good_bits=None): # disable selective median if SATURATED flag is included # in good_bits: try: - saturation = self.dq_flag_name_map['SATURATED'] + saturation = self.dq_flag_name_map["SATURATED"] if selective_median and not (bitvalue & saturation): selective_median = False weight_type = 'ivm' except AttributeError: pass - var_rnoise = self.get_model_array(model, "var_rnoise", default=None) + var_rnoise = model_info["var_rnoise"] if (var_rnoise is not None and var_rnoise.shape == data.shape): with np.errstate(divide="ignore", invalid="ignore"): inv_variance = var_rnoise**-1 @@ -535,8 +737,8 @@ def build_driz_weight(self, model, weight_type=None, good_bits=None): result = inv_variance * dqmask - elif weight_type == 'exptime': - exptime = self.get_model_attr_value(model, "exposure_time") + elif weight_type == "exptime": + exptime = model_info["exposure_time"] result = exptime * dqmask else: @@ -544,291 +746,276 @@ def build_driz_weight(self, model, weight_type=None, good_bits=None): return result.astype(np.float32) - @abc.abstractmethod - def run(self): - ... + def init_time_info(self): + """ Initialize variables/arrays needed to process exposure time. """ + self._t_used_group_id = [] - def _create_output_template_model(self): - pass - - def update_exposure_times(self): - """Modify exposure time metadata in-place""" - total_exposure_time = 0. - exptime_start = [] - exptime_end = [] - duration = 0.0 - total_exptime = 0.0 - measurement_time_success = [] - - for exposure in self._input_models.group_indices.values(): - with self._input_models: - model0 = self._input_models.borrow(exposure[0]) - attrs = self.get_model_meta( - model0, - ["exposure_time", "start_time", "end_time", "duration"] - ) + self._total_exposure_time = self.output_model["exposure_time"] + self._duration = self.output_model["duration"] + self._total_measurement_time = self.output_model["measurement_time"] + if self._total_measurement_time is None: + self._total_measurement_time = 0.0 - t, success = get_tmeasure(model0) - self._input_models.shelve(model0, exposure[0]) + if (start_time := self.output_model.get("start_time", None)) is None: + self._exptime_start = [] + else: + self._exptime_start[start_time] - total_exposure_time += attrs["exposure_time"] - measurement_time_success.append(success) - total_exptime += t - exptime_start.append(attrs["start_time"]) - exptime_end.append(attrs["end_time"]) - duration += attrs["duration"] + if (end_time := self.output_model.get("end_time", None)) is None: + self._exptime_end = [] + else: + self._exptime_end[end_time] - attrs = { - # basic exposure time attributes: - "exposure_time": total_exposure_time, - "start_time": min(exptime_start), - "end_time": max(exptime_end), - # Update other exposure time keywords: - # XPOSURE (identical to the total effective exposure time, EFFEXPTM) - "effective_exposure_time": total_exptime, - # DURATION (identical to TELAPSE, elapsed time) - "duration": duration, - "elapsed_exposure_time": duration, - } + self._measurement_time_success = [] - if not all(measurement_time_success): - attrs["measurement_time"] = total_exptime + def update_total_time(self, model_info): + """ A method called by the `~ResampleBase.run` method to process each + image's time attributes. - self.set_model_meta(self._output_model, attrs) + """ + if (group_id := model_info["group_id"]) in self._t_used_group_id: + return + self._t_used_group_id.append(group_id) + self._exptime_start.append(model_info["start_time"]) + self._exptime_end.append(model_info["end_time"]) -class ResampleCoAdd(ResampleBase): - """ - This is the controlling routine for the resampling process. + t, success = get_tmeasure(model_info) + self._total_exposure_time += model_info["exposure_time"] + self._measurement_time_success.append(success) + self._total_measurement_time += t - Notes - ----- - This routine performs the following operations:: + self._duration += model_info["duration"] - 1. Extracts parameter settings from input model, such as pixfrac, - weight type, exposure time (if relevant), and kernel, and merges - them with any user-provided values. - 2. Creates output WCS based on input images and define mapping function - between all input arrays and the output array. - 3. Updates output data model with output arrays from drizzle, including - a record of metadata from all input models. - """ + def finalize_time_info(self): + """ Perform final computations for the total time and update relevant + fileds of the output model. - def __init__(self, input_models, output, accum=False, - pixfrac=1.0, kernel="square", fillval=0.0, wht_type="ivm", - good_bits=0, output_wcs=None, wcs_pars=None, - enable_ctx=True, enable_err=True, - in_memory=True, allowed_memory=None): """ - Parameters - ---------- - input_models : list of objects - list of data models, one for each input image + attrs = { + # basic exposure time attributes: + "exposure_time": self._total_exposure_time, + "start_time": min(self._exptime_start), + "end_time": max(self._exptime_end), + # Update other exposure time keywords: + # XPOSURE (identical to the total effective exposure time, EFFEXPTM) + "effective_exposure_time": self._total_exposure_time, + # DURATION (identical to TELAPSE, elapsed time) + "duration": self._duration, + "elapsed_exposure_time": self._duration, + } - output : DataModel, str - filename for output + if all(self._measurement_time_success): + attrs["measurement_time"] = self._total_measurement_time - kwargs : dict - Other parameters. + self._output_model.update(attrs) - .. note:: - ``output_shape`` is in the ``x, y`` order. + def init_resample_data(self): + """ Create a `Drizzle` object to process image data. """ + om = self._output_model - .. note:: - ``in_memory`` controls whether or not the resampled - array from ``resample_many_to_many()`` - should be kept in memory or written out to disk and - deleted from memory. Default value is `True` to keep - all products in memory. - """ - self._accum = accum - - super().__init__( - input_models=input_models, - pixfrac=pixfrac, - kernel=kernel, - fillval=fillval, - wht_type=wht_type, - good_bits=good_bits, - output_wcs=output_wcs, - wcs_pars=wcs_pars, - enable_ctx=enable_ctx, - in_memory=in_memory, - allowed_memory=allowed_memory, - output=output, + self.driz_data = Drizzle( + kernel=self.kernel, + fillval=self.fillval, + out_shape=self._output_array_shape, + out_img=om["data"], + out_wht=om["wht"], + out_ctx=om["con"], + exptime=om["exposure_time"], + begin_ctx_id=om["n_coadds"], + max_ctx_id=om["n_coadds"] + self._input_models.n_models, ) - self._enable_err = enable_err - def process_kwargs(self, kwargs): - """ A method called by ``__init__`` to process input kwargs before - output WCS is created and before output model template is created. + def init_resample_variance(self): + """ Create a `Drizzle` objects to process image variance arrays. """ + self._var_rnoise_sum = np.full(self._output_array_shape, np.nan) + self._var_poisson_sum = np.full(self._output_array_shape, np.nan) + self._var_flat_sum = np.full(self._output_array_shape, np.nan) + # self._total_weight_var_rnoise = np.zeros(self._output_array_shape) + self._total_weight_var_poisson = np.zeros(self._output_array_shape) + self._total_weight_var_flat = np.zeros(self._output_array_shape) + + # create resample objects for the three variance arrays: + driz_init_kwargs = { + 'kernel': self.kernel, + 'fillval': np.nan, + 'out_shape': self._output_array_shape, + # 'exptime': 1.0, + 'no_ctx': True, + } + self.driz_rnoise = Drizzle(**driz_init_kwargs) + self.driz_poisson = Drizzle(**driz_init_kwargs) + self.driz_flat = Drizzle(**driz_init_kwargs) + + def _check_var_array(self, model_info, array_name): + """ Check that a variance array has the same shape as the model's + data array. + """ - kwargs = super().process_kwargs(kwargs) - output = kwargs.pop("output", None) - accum = kwargs.pop("accum", False) - - # Load the model if accum is True - if isinstance(output, str): - self._output_filename = output - if accum: - try: - self._output_model = self.open_model(output) - self._close_output = True - log.info( - "Output model has been loaded and it will be used to " - "accumulate new data." - ) - if self._user_output_wcs: - log.info( - "'output_wcs' will be ignored when 'output' is " - "provided and accum=True" - ) - if self._wcs_pars: - log.info( - "'wcs_pars' will be ignored when 'output' is " - "provided and accum=True" - ) - except FileNotFoundError: - pass + array_data = model_info.get(array_name, None) + sci_data = model_info["data"] + model_name = _get_model_name(model_info) - elif output is not None: - self._output_filename = self.get_model_attr_value( - output, - "filename" + if array_data is None or array_data.size == 0: + log.debug( + f"No data for '{array_name}' for model " + f"{repr(model_name)}. Skipping ..." ) - self._output_model = output - self._close_output = False + return False - return kwargs + elif array_data.shape != sci_data.shape: + log.warning( + f"Data shape mismatch for '{array_name}' for model " + f"{repr(model_name)}. Skipping ..." + ) + return False - def _create_new_output_model(self): - # this probably needs to be an abstract class. - # also this is mostly needed for "single" drizzle. - output_model = self.new_model( - None, - copy_meta_from=self._first_model_meta - ) + return True - # update meta data and wcs - pix_area = self._output_pixel_scale**2 - self.set_model_meta( - output_model, - { - "wcs": deepcopy(self._output_wcs), - "pixelarea_steradians": pix_area, - "pixelarea_arcsecsq": pix_area * np.rad2deg(3600)**2, - } - ) + def add_model(self, model_info, image_model): + """ Resample and add data (variance, etc.) arrays to the output arrays. - return output_model + Parameters + ---------- - def build_output_model_name(self): - fnames = {f for f in self._input_filename_list if f is not None} + model_info : dict + A dictionary with data extracted from an image model needed for + `Resample` to successfully process this model. - if not fnames: - return "resampled_data_{resample_suffix}{resample_file_ext}" + image_model : object + The original data model from which ``model`` data was extracted. + It is not used by this method in this class but can be used + by pipeline-specific subclasses to perform additional processing + such as blend headers. - # TODO: maybe remove ending suffix for single file names? - prefix = os.path.commonprefix( - [PurePath(f).stem.strip('_- ') for f in fnames] - ) + """ + in_data = model_info["data"] - return prefix + "{resample_suffix}{resample_file_ext}" + if (group_id := model_info["group_id"]) not in self.group_ids: + self.group_ids.append(group_id) + self.output_model["pointings"] += 1 - def create_output_model(self, resample_results): - # this probably needs to be an abstract class (different telescopes - # may want to save different arrays and ignore others). + self.input_file_names.append(model_info["filename"]) - if not self._output_model and self._output_filename: - if self._accum and Path(self._output_filename).is_file(): - self._output_model = self.open_model(self._output_filename) - else: - self._output_model = self._create_new_output_model() - self._close_output = not self.in_memory + # Check that input models are 2D images + if in_data.ndim != 2: + raise RuntimeError( + f"Input model {_get_model_name(model_info)} " + "is not a 2D image." + ) - if self._output_filename is None: - self._output_filename = self.build_output_model_name() + input_pixflux_area = model_info["pixelarea_steradians"] + imwcs = model_info["wcs"] + if (input_pixflux_area and + 'SPECTRAL' not in imwcs.output_frame.axes_type): + if not np.array_equiv(imwcs.array_shape, in_data.shape): + imwcs.array_shape = in_data.shape + input_pixel_area = _compute_image_pixel_area(imwcs) + if input_pixel_area is None: + model_name = model_info["filename"] + if not model_name: + model_name = "Unknown" + raise ValueError( + "Unable to compute input pixel area from WCS of input " + f"image {repr(model_name)}." + ) + iscale = np.sqrt(input_pixflux_area / input_pixel_area) + else: + iscale = 1.0 - self.set_model_array(self._output_model, "data", resample_results.out_img) + # TODO: should weight_type=None here? + in_wht = self.build_driz_weight( + model_info, + weight_type=self.weight_type, + good_bits=self.good_bits + ) - self.update_exposure_times() - if self._enable_err: - self._finish_variance_processing() + # apply sky subtraction + blevel = model_info["level"] + if not model_info["subtracted"] and blevel is not None: + in_data = in_data - blevel - self.set_model_meta( - self._output_model, - { - "weight_type": self.weight_type, - "pointings": len(self._input_models.group_names), - } + xmin, xmax, ymin, ymax = _resample_range( + in_data.shape, + imwcs.bounding_box ) - # TODO: also store the number of images added in total: ncoadds? - self.final_post_processing() + pixmap = calc_pixmap(wcs_from=imwcs, wcs_to=self._output_wcs) + + add_image_kwargs = { + 'exptime': model_info["exposure_time"], + 'pixmap': pixmap, + 'scale': iscale, + 'weight_map': in_wht, + 'wht_scale': 1.0, + 'pixfrac': self.pixfrac, + 'in_units': 'cps', # TODO: get units from data model + 'xmin': xmin, + 'xmax': xmax, + 'ymin': ymin, + 'ymax': ymax, + } - self.write_model( - self._output_model, - self._output_filename, - overwrite=True - ) + self.driz_data.add_image(in_data, **add_image_kwargs) - if self._close_output and not self.in_memory: - self.close_model(self._output_model) - self._output_model = None - return self._output_filename + if self._enable_var: + self.resample_variance_data(model_info, add_image_kwargs) - return self._output_model + def run(self): + """ Resample and coadd many inputs to a single output. - def _setup_variance_data(self): - self._var_rnoise_sum = np.full(self._output_array_shape, np.nan) - self._var_poisson_sum = np.full(self._output_array_shape, np.nan) - self._var_flat_sum = np.full(self._output_array_shape, np.nan) - # self._total_weight_var_rnoise = np.zeros(self._output_array_shape) - self._total_weight_var_poisson = np.zeros(self._output_array_shape) - self._total_weight_var_flat = np.zeros(self._output_array_shape) + 1. Call methods that initialize data, variance, and time computations. + 2. Add input images (data, variances, etc) to output arrays. + 3. Perform final computations to compute variance and error + arrays and total expose time information for the resampled image. - def _check_var_array(self, data_model, array_name): - array_data = self.get_model_array(data_model, array_name, default=None) - sci_data = self.get_model_array(data_model, "data", default=None) - filename = self.get_model_attr_value(data_model, "filename") + """ + self.init_time_info() + self.init_resample_data() + if self._enable_var: + self.init_resample_variance() + + for model_info, image_model in self._input_models.iter_model(): + self.add_model(model_info, image_model) + self.update_total_time(model_info) + + # assign resampled arrays to the output model dictionary: + self._output_model["data"] = self.driz_data.out_img.astype( + dtype=self.output_array_types["data"] + ) + self._output_model["wht"] = self.driz_data.out_wht.astype( + dtype=self.output_array_types["wht"] + ) - if array_data is None or array_data.size == 0: - log.debug( - f"No data for '{array_name}' for model " - f"{repr(filename)}. Skipping ..." + if self._enable_ctx: + self._output_model["con"] = self.driz_data.out_ctx.astype( + dtype=self.output_array_types["con"] ) - return False - elif array_data.shape != sci_data.shape: - log.warning( - f"Data shape mismatch for '{array_name}' for model " - f"{repr(filename)}. Skipping ..." - ) - return False + if self._enable_var: + self.finalize_variance_processing() + self.compute_errors() - return True + self.finalize_time_info() - def _resample_variance_data(self, data_model, driz_init_kwargs, add_image_kwargs): - log.info("Resampling variance components") + def resample_variance_data(self, data_model, add_image_kwargs): + """ Resample and add input model's variance arrays to the output + vararrays. - # create resample objects for the three variance arrays: - driz_init_kwargs = { - 'kernel': self.kernel, - 'fillval': np.nan, - 'out_shape': self._output_array_shape, - # 'exptime': 1.0, - 'no_ctx': True, - } - driz_rnoise = Drizzle(**driz_init_kwargs) - driz_poisson = Drizzle(**driz_init_kwargs) - driz_flat = Drizzle(**driz_init_kwargs) + """ + log.info("Resampling variance components") # Resample read-out noise and compute weight map for variance arrays if self._check_var_array(data_model, 'var_rnoise'): - data = self.get_model_array(data_model, "var_rnoise") + data = data_model["var_rnoise"] data = np.sqrt(data) - driz_rnoise.add_image(data, **add_image_kwargs) - var = driz_rnoise.out_img + + # reset driz output arrays: + self.driz_rnoise.out_img[:, :] = self.driz_rnoise.fillval + self.driz_rnoise.out_wht[:, :] = 0.0 + + self.driz_rnoise.add_image(data, **add_image_kwargs) + var = self.driz_rnoise.out_img np.square(var, out=var) weight_mask = var > 0 @@ -869,50 +1056,38 @@ def _resample_variance_data(self, data_model, driz_init_kwargs, add_image_kwargs weight = np.ones(self._output_array_shape) weight_mask = np.ones(self._output_array_shape, dtype=bool) - if self._check_var_array(data_model, 'var_poisson'): - data = self.get_model_array(data_model, "var_poisson") + for var_name in ["var_flat", "var_poisson"]: + if not self._check_var_array(data_model, var_name): + continue + data = data_model[var_name] data = np.sqrt(data) - driz_poisson.add_image(data, **add_image_kwargs) - var = driz_poisson.out_img - np.square(var, out=var) - mask = (var > 0) & weight_mask + driz = getattr(self, var_name.replace("var", "driz")) + var_sum = getattr(self, f"_{var_name}_sum") + t_var_weight = getattr(self, f"_total_weight_{var_name}") - # Add the inverse of the resampled variance to a running sum. - # Update only pixels (in the running sum) with valid new values: - self._var_poisson_sum[mask] = np.nansum( - [ - self._var_poisson_sum[mask], - var[mask] * weight[mask] * weight[mask] - ], - axis=0 - ) - self._total_weight_var_poisson[mask] += weight[mask] + # reset driz output arrays: + driz.out_img[:, :] = driz.fillval + driz.out_wht[:, :] = 0.0 - if self._check_var_array(data_model, 'var_flat'): - data = self.get_model_array(data_model, "var_flat") - data = np.sqrt(data) - driz_flat.add_image(data, **add_image_kwargs) - var = driz_flat.out_img + driz.add_image(data, **add_image_kwargs) + var = driz.out_img np.square(var, out=var) mask = (var > 0) & weight_mask # Add the inverse of the resampled variance to a running sum. # Update only pixels (in the running sum) with valid new values: - self._var_flat_sum[mask] = np.nansum( + var_sum[mask] = np.nansum( [ - self._var_flat_sum[mask], + var_sum[mask], var[mask] * weight[mask] * weight[mask] ], axis=0 ) - self._total_weight_var_flat[mask] += weight[mask] - - def final_post_processing(self): - pass + t_var_weight[mask] += weight[mask] - def _finish_variance_processing(self): + def finalize_variance_processing(self): # We now have a sum of the weighted resampled variances. # Divide by the total weights, squared, and set in the output model. # Zero weight and missing values are NaN in the output. @@ -920,15 +1095,16 @@ def _finish_variance_processing(self): warnings.filterwarnings("ignore", "invalid value*", RuntimeWarning) warnings.filterwarnings("ignore", "divide by zero*", RuntimeWarning) - odt = self.get_model_array(self._output_model, "data").dtype - # readout noise np.reciprocal(self._var_rnoise_sum, out=self._var_rnoise_sum) - self.set_model_array( - self._output_model, - "var_rnoise", - self._var_rnoise_sum.astype(dtype=odt) - ) + if self._accumulate and self._output_model["var_rnoise"]: + self._output_model["var_rnoise"] += self._var_rnoise_sum.astype( + dtype=self.output_array_types["var_rnoise"] + ) + else: + self._output_model["var_rnoise"] = self._var_rnoise_sum.astype( + dtype=self.output_array_types["var_rnoise"] + ) # Poisson noise for _ in range(2): @@ -937,11 +1113,15 @@ def _finish_variance_processing(self): self._total_weight_var_poisson, out=self._var_poisson_sum ) - self.set_model_array( - self._output_model, - "var_poisson", - self._var_poisson_sum.astype(dtype=odt) - ) + + if self._accumulate and self._output_model["var_poisson"]: + self._output_model["var_poisson"] += self._var_rnoise_sum.astype( + dtype=self.output_array_types["var_poisson"] + ) + else: + self._output_model["var_poisson"] = self._var_rnoise_sum.astype( + dtype=self.output_array_types["var_poisson"] + ) # flat's noise for _ in range(2): @@ -950,420 +1130,42 @@ def _finish_variance_processing(self): self._total_weight_var_flat, out=self._var_flat_sum ) - self.set_model_array( - self._output_model, - "var_flat", - self._var_flat_sum.astype(dtype=odt) - ) - - # compute total error: - vars = np.array( - [ - self._var_rnoise_sum, - self._var_poisson_sum, - self._var_flat_sum, - ] - ) - all_nan_mask = np.any(np.isnan(vars), axis=0) - - err = np.sqrt(np.nansum(vars, axis=0)).astype(dtype=odt) - err[all_nan_mask] = np.nan - self.set_model_array(self._output_model, "err", err) - self.set_model_array(self._output_model, "var_rnoise", self._var_rnoise_sum) - self.set_model_array(self._output_model, "var_poisson", self._var_poisson_sum) - self.set_model_array(self._output_model, "var_flat", self._var_flat_sum) - - del vars - del self._var_rnoise_sum - del self._var_poisson_sum - del self._var_flat_sum - # del self._total_weight_var_rnoise - del self._total_weight_var_poisson - del self._total_weight_var_flat - - def run(self): - """Resample and coadd many inputs to a single output. - - Used for stage 3 resampling - """ - - # TODO: repetiveness of code below should be compactified via using - # getattr as in orig code and maybe making an alternative method to - # the original resample_variance_array - ninputs = len(self._input_models) - - do_accum = ( - self._accum and - ( - self._output_model or - (self._output_filename and Path(self._output_filename).is_file()) - ) - ) - - if do_accum and self._output_model is None: - self._output_model = self.open_model(self._output_filename) - - # get old data: - data = self.get_model_array(self._output_model, "data") - wht = self.get_model_array(self._output_model, "wht") - if self._enable_ctx: - ctx = self.get_model_array(self._output_model, "con") - else: - ctx = None - - t_exptime = self.get_model_attr_value( - self._output_model, - "exptime" - ) - # TODO: we need something to store total number of images that - # have been used to create the resampled output, something - # similar to output_model.meta.resample.pointings. - # For now I will call it "ncoadds" - ncoadds = self.get_model_attr_value( - self._output_model, - "ncoadds" - ) - self.accum_output_arrays = True - - else: - ncoadds = 0 - data = None - wht = None - ctx = None - t_exptime = 0.0 - self.accum_output_arrays = False - - driz_data = Drizzle( - kernel=self.kernel, - fillval=self.fillval, - out_shape=self._output_array_shape, - out_img=data, - out_wht=wht, - out_ctx=ctx, - exptime=t_exptime, - begin_ctx_id=ncoadds, - max_ctx_id=ncoadds + ninputs, - ) - - if self._enable_err: - self._setup_variance_data() - - log.info("Resampling science data") - - # loop over only science exposures in the ModelLibrary - # sci_indices = self._input_models.ind_asn_type("science") - with self._input_models: - for model in self._input_models: - # model = self._input_models.borrow(idx) - - try: - if self.get_model_attr_value(model, "exptype").upper() != "SCIENCE": - self._input_models.shelve(model, modify=False) - continue - except AttributeError: - pass - - in_data = self.get_model_array(model, "data") - - attrs = self.get_model_meta( - model, - [ - "wcs", - "pixelarea_steradians", - "filename", - "level", - "subtracted", - "exposure_time", - ] - ) - - # Check that input models are 2D images - if in_data.ndim != 2: - raise RuntimeError( - f"Input {attrs['filename']} is not a 2D image." - ) - - input_pixflux_area = attrs["pixelarea_steradians"] - imwcs = attrs["wcs"] - if (input_pixflux_area and - 'SPECTRAL' not in imwcs.output_frame.axes_type): - imwcs.array_shape = in_data.shape - input_pixel_area = _compute_image_pixel_area(imwcs) - if input_pixel_area is None: - raise ValueError( - "Unable to compute input pixel area from WCS of input " - f"image {repr(attrs['filename'])}." - ) - iscale = np.sqrt(input_pixflux_area / input_pixel_area) - else: - iscale = 1.0 - # TODO: should weight_type=None here? - in_wht = self.build_driz_weight( - model, - weight_type=self.weight_type, - good_bits=self.good_bits + if self._accumulate and self._output_model["var_flat"]: + self._output_model["var_flat"] += self._var_rnoise_sum.astype( + dtype=self.output_array_types["var_flat"] ) - - # apply sky subtraction - blevel = attrs["level"] - if not attrs["subtracted"] and blevel is not None: - in_data = in_data - blevel - - xmin, xmax, ymin, ymax = _resample_range( - in_data.shape, - imwcs.bounding_box + else: + self._output_model["var_flat"] = self._var_rnoise_sum.astype( + dtype=self.output_array_types["var_flat"] ) - pixmap = calc_pixmap(wcs_from=imwcs, wcs_to=self._output_wcs) - - add_image_kwargs = { - 'exptime': attrs["exposure_time"], - 'pixmap': pixmap, - 'scale': iscale, - 'weight_map': in_wht, - 'wht_scale': 1.0, - 'pixfrac': self.pixfrac, - 'in_units': 'cps', # TODO: get units from data model - 'xmin': xmin, - 'xmax': xmax, - 'ymin': ymin, - 'ymax': ymax, - } - - driz_data.add_image(in_data, **add_image_kwargs) - - if self._enable_err: - self._resample_variance_data(model, None, add_image_kwargs) - - self._input_models.shelve(model, modify=False) - - # TODO: see what to do about original update_exposure_times() - - return self.create_output_model(driz_data) - - -class ResampleSingle(ResampleBase): - """ - This is the controlling routine for the resampling process. - - Notes - ----- - This routine performs the following operations:: - - 1. Extracts parameter settings from input model, such as pixfrac, - weight type, exposure time (if relevant), and kernel, and merges - them with any user-provided values. - 2. Creates output WCS based on input images and define mapping function - between all input arrays and the output array. - 3. Updates output data model with output arrays from drizzle, including - a record of metadata from all input models. - """ - - def __init__(self, input_models, - pixfrac=1.0, kernel="square", fillval=0.0, wht_type="ivm", - good_bits=0, output_wcs=None, wcs_pars=None, - in_memory=True, allowed_memory=None): - """ - Parameters - ---------- - input_models : list of objects - list of data models, one for each input image - - output : DataModel, str - filename for output - - kwargs : dict - Other parameters. - - .. note:: - ``output_shape`` is in the ``x, y`` order. - - .. note:: - ``in_memory`` controls whether or not the resampled - array from ``resample_many_to_many()`` - should be kept in memory or written out to disk and - deleted from memory. Default value is `True` to keep - all products in memory. - - """ - super().__init__( - input_models, - pixfrac=pixfrac, - kernel=kernel, - fillval=fillval, - wht_type=wht_type, - good_bits=good_bits, - output_wcs=output_wcs, - wcs_pars=wcs_pars, - in_memory=in_memory, - allowed_memory=allowed_memory, + # free arrays: + del self._var_rnoise_sum + del self._var_poisson_sum + del self._var_flat_sum + # del self._total_weight_var_rnoise + del self._total_weight_var_poisson + del self._total_weight_var_flat + + def compute_errors(self): + """ Computes total error of the resampled image. """ + vars = np.array( + [ + self._output_model["var_rnoise"], + self._output_model["var_poisson"], + self._output_model["var_flat"], + ] ) + all_nan_mask = np.any(np.isnan(vars), axis=0) - def build_output_name_from_input_name(self, input_file_name): - """ Form output file name from input image name """ - indx = input_file_name.rfind('.') - output_type = input_file_name[indx:] - output_root = '_'.join( - input_file_name.replace(output_type, '').split('_')[:-1] - ) - output_file_name = f'{output_root}_outlier_i2d{output_type}' - return output_file_name - - def _create_output_template_model(self): - # this probably needs to be an abstract class. - # also this is mostly needed for "single" drizzle. - self._template_output_model = self.new_model( - copy_meta_from=self._first_model_meta, - ) - pix_area = self._output_pixel_scale**2 - self.set_model_meta( - self._template_output_model, - { - "wcs": deepcopy(self._output_wcs), - "pixelarea_steradians": pix_area, - "pixelarea_arcsecsq": pix_area * np.rad2deg(3600)**2, - } + err = np.sqrt(np.nansum(vars, axis=0)).astype( + dtype=self.output_array_types["err"] ) + del vars + err[all_nan_mask] = np.nan - def create_output_model_single(self, file_name, resample_results): - # this probably needs to be an abstract class - output_model = deepcopy(self._template_output_model) - self.set_model_array(output_model, "data", resample_results.out_img) - if self.in_memory: - return output_model - else: - self.write_model(output_model, file_name, overwrite=True) - self.close_model(output_model) - log.info(f"Saved resampled model to {file_name}") - return file_name - - def run(self): - """Resample many inputs to many outputs where outputs have a common frame. - - Coadd only different detectors of the same exposure, i.e. map NRCA5 and - NRCB5 onto the same output image, as they image different areas of the - sky. - - Used for outlier detection - """ - output_models = [] - - for exposure_indices in self._input_models.group_indices.values(): - - driz = Drizzle( - kernel=self.kernel, - fillval=self.fillval, - out_shape=self._output_array_shape, - max_ctx_id=0 - ) - - log.info(f"{len(exposure_indices)} exposures to drizzle together") - - exptime = None - - meta_fields = [ - "wcs", - "pixelarea_steradians", - "filename", - "level", - "subtracted", - ] - - for idx in exposure_indices: - - with self._input_models: - model = self._input_models.borrow(idx) - - in_data = self.get_model_array(model, "data") - - if exptime is None: - attrs = self.get_model_meta( - model, - meta_fields + ["exposure_time", "filename"] - ) - else: - attrs = self.get_model_meta(model, meta_fields) - - # Check that input models are 2D images - if in_data.ndim != 2: - raise RuntimeError( - f"Input {attrs['filename']} is not a 2D image." - ) - - input_pixflux_area = attrs["pixelarea_steradians"] - imwcs = attrs["wcs"] - - if exptime is None: - exptime = attrs["exposure_time"] - # Determine output file type from input exposure filenames - # Use this for defining the output filename - output_filename = self.build_output_name_from_input_name( - attrs["filename"] - ) - - # compute image intensity correction due to the difference - # between where in the input image - # img.meta.photometry.pixelarea_steradians was computed and - # the average input pixel area. - if (input_pixflux_area and - 'SPECTRAL' not in imwcs.output_frame.axes_type): - imwcs.array_shape = in_data.shape - input_pixel_area = _compute_image_pixel_area(imwcs) - if input_pixel_area is None: - raise ValueError( - "Unable to compute input pixel area from WCS " - f"of input image {repr(attrs['filename'])}." - ) - iscale = np.sqrt(input_pixflux_area / input_pixel_area) - else: - iscale = 1.0 - - # TODO: should weight_type=None here? - in_wht = self.build_driz_weight( - model, - weight_type=self.weight_type, - good_bits=self.good_bits - ) - - # apply sky subtraction - blevel = attrs["level"] - if not attrs["subtracted"] and blevel is not None: - in_data = in_data - blevel - - xmin, xmax, ymin, ymax = _resample_range( - in_data.shape, - imwcs.bounding_box - ) - - pixmap = calc_pixmap(wcs_from=imwcs, wcs_to=self._output_wcs) - - driz.add_image( - in_data, - exptime=exptime, - pixmap=pixmap, - scale=iscale, - weight_map=in_wht, - wht_scale=1.0, - pixfrac=self.pixfrac, - in_units='cps', # TODO: get units from data model - xmin=xmin, - xmax=xmax, - ymin=ymin, - ymax=ymax, - ) - - self._input_models.shelve(model, idx, modify=False) - del in_data - - output_models.append( - self.create_output_model_single( - output_filename, - driz - ) - ) - - return output_models + self._output_model["err"] = err def _get_boundary_points(xmin, xmax, ymin, ymax, dx=None, dy=None, shrink=0): @@ -1492,3 +1294,10 @@ def _compute_image_pixel_area(wcs): pix_area = sky_area / image_area return pix_area + + +def _get_model_name(model_info): + model_name = model_info["filename"] + if model_name is None or not model_name.strip(): + model_name = "Unknown" + return model_name diff --git a/src/stcal/resample/utils.py b/src/stcal/resample/utils.py index c9623b1b0..71d32b123 100644 --- a/src/stcal/resample/utils.py +++ b/src/stcal/resample/utils.py @@ -1,11 +1,28 @@ +import os +from pathlib import Path, PurePath + import numpy as np from astropy.nddata.bitmask import interpret_bit_flags __all__ = [ - "build_mask", "get_tmeasure", + "build_mask", "build_output_model_name", "get_tmeasure", "bytes2human" ] +def build_output_model_name(input_filename_list): + fnames = {f for f in input_filename_list if f is not None} + + if not fnames: + return "resampled_data_{resample_suffix}{resample_file_ext}" + + # TODO: maybe remove ending suffix for single file names? + prefix = os.path.commonprefix( + [PurePath(f).stem.strip('_- ') for f in fnames] + ) + + return prefix + "{resample_suffix}{resample_file_ext}" + + def build_mask(dqarr, bitvalue, flag_name_map=None): """Build a bit mask from an input DQ array and a bitvalue flag @@ -26,10 +43,47 @@ def get_tmeasure(model): Returns a tuple of (exptime, is_measurement_time) """ try: - tmeasure = model.meta.exposure.measurement_time - except AttributeError: - return model.meta.exposure.exposure_time, False + tmeasure = model["measurement_time"] + except KeyError: + return model["exposure_time"], False if tmeasure is None: - return model.meta.exposure.exposure_time, False + return model["exposure_time"], False else: return tmeasure, True + + +# FIXME: temporarily copied here to avoid this import: +# from stdatamodels.jwst.library.basic_utils import bytes2human +def bytes2human(n): + """Convert bytes to human-readable format + + Taken from the `psutil` library which references + http://code.activestate.com/recipes/578019 + + Parameters + ---------- + n : int + Number to convert + + Returns + ------- + readable : str + A string with units attached. + + Examples + -------- + >>> bytes2human(10000) + '9.8K' + + >>> bytes2human(100001221) + '95.4M' + """ + symbols = ('K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y') + prefix = {} + for i, s in enumerate(symbols): + prefix[s] = 1 << (i + 1) * 10 + for s in reversed(symbols): + if n >= prefix[s]: + value = float(n) / prefix[s] + return '%.1f%s' % (value, s) + return "%sB" % n \ No newline at end of file