diff --git a/changes/310.general.rst b/changes/310.general.rst new file mode 100644 index 00000000..e20c4358 --- /dev/null +++ b/changes/310.general.rst @@ -0,0 +1 @@ +Move common parts of skymatch shared by both jwst and romancal into stcal. diff --git a/pyproject.toml b/pyproject.toml index 7c46ff17..aff95026 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "gwcs >=0.22.0", "tweakwcs >=0.8.8", "requests >=2.22", + "spherical-geometry>=1.2.22" ] dynamic = [ "version", diff --git a/src/stcal/skymatch/__init__.py b/src/stcal/skymatch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/stcal/skymatch/skyimage.py b/src/stcal/skymatch/skyimage.py new file mode 100644 index 00000000..155d0e38 --- /dev/null +++ b/src/stcal/skymatch/skyimage.py @@ -0,0 +1,1020 @@ +""" +The ``skyimage`` module contains algorithms that are used by +``skymatch`` to manage all of the information for footprints (image outlines) +on the sky as well as perform useful operations on these outlines such as +computing intersections and statistics in the overlap regions. +""" + +import abc +import tempfile + +import numpy as np +from gwcs import region +from spherical_geometry.polygon import SphericalPolygon # type: ignore[import-untyped] + +from .skystatistics import SkyStats # type: ignore[import-untyped] + +__all__ = [ + "SkyImage", + "SkyGroup", + "DataAccessor", + "NDArrayInMemoryAccessor", + "NDArrayMappedAccessor", +] + + +class DataAccessor(abc.ABC): + """Base class for all data accessors. Provides a common interface to + access data. + """ + + @abc.abstractmethod + def get_data(self): + pass + + @abc.abstractmethod + def set_data(self, data): + """Sets data. + + Parameters + ---------- + data : numpy.ndarray + Data array to be set. 
+ + """ + pass + + @abc.abstractmethod + def get_data_shape(self): + pass + + +class NDArrayInMemoryAccessor(DataAccessor): + """Accessor for in-memory `numpy.ndarray` data.""" + + def __init__(self, data): + super().__init__() + self._data = data + + def get_data(self): + return self._data + + def set_data(self, data): + self._data = data + + def get_data_shape(self): + return np.shape(self._data) + + +class NDArrayMappedAccessor(DataAccessor): + """Data accessor for arrays stored in temporary files.""" + + def __init__( + self, data, tmpfile=None, prefix="tmp_skymatch_", suffix=".npy", tmpdir="" + ): + super().__init__() + if tmpfile is None: + self._close = True + self._tmp = tempfile.NamedTemporaryFile( + prefix=prefix, suffix=suffix, dir=tmpdir + ) + if not self._tmp: + raise RuntimeError("Unable to create temporary file.") + else: + # temp file managed by the caller + self._close = False + self._tmp = tmpfile + + self.set_data(data) + + def get_data(self): + self._tmp.seek(0) + return np.load(self._tmp) + + def set_data(self, data): + data = np.asanyarray(data) + self._data_shape = data.shape + self._tmp.seek(0) + np.save(self._tmp, data) + + def __del__(self): + if self._close: + self._tmp.close() + + def get_data_shape(self): + return self._data_shape + + +class SkyImage: + """ + Container that holds information about properties of a *single* + image such as: + + * image data; + * WCS of the chip image; + * bounding spherical polygon; + * id; + * pixel area; + * sky background value; + * sky statistics parameters; + * mask associated image data indicating "good" (1) data. + + """ + + def __init__( + self, + image, + wcs_fwd, + wcs_inv, + pix_area=1.0, + convf=1.0, + mask=None, + id=None, # noqa: A002 + skystat=None, + stepsize=None, + meta=None, + reduce_memory_usage=True, + ): + """Initializes the SkyImage object. + + Parameters + ---------- + image : numpy.ndarray, DataAccessor + A 2D array of image data or a `DataAccessor`. 
+ + wcs_fwd : function + "forward" pixel-to-world transformation function. + + wcs_inv : function + "inverse" world-to-pixel transformation function. + + pix_area : float, optional + Average pixel's sky area. + + convf : float, optional + Conversion factor that when multiplied to `image` data converts + the data to "uniform" (across multiple images) surface + brightness units. + + .. note:: + + The functionality to support this conversion is not yet + implemented and at this moment `convf` is ignored. + + mask : numpy.ndarray, DataAccessor + A 2D array or `DataAccessor` of a 2D array that indicates + which pixels in the input `image` should be used for sky + computations (``1``) and which pixels should **not** be used + for sky computations (``0``). + + id : anything + The value of this parameter is simply stored within the `SkyImage` + object. While it can be of any type, it is preferable that `id` be + of a type with nice string representation. + + skystat : callable, None, optional + A callable object that takes either a 2D image (2D + `numpy.ndarray`) or a list of pixel values (an Nx1 array) and + returns a tuple of two values: some statistics (e.g., mean, + median, etc.) and number of pixels/values from the input image + used in computing that statistics. + + When `skystat` is not set, `SkyImage` will use + :py:class:`~stcal.skymatch.skystatistics.SkyStats` object + to perform sky statistics on image data. + + stepsize : int, None, optional + Spacing between vertices of the image's bounding polygon. Default + value of `None` creates bounding polygons with four vertices + corresponding to the corners of the image. + + meta : dict, None, optional + A dictionary of various items to be stored within the `SkyImage` + object. + + reduce_memory_usage : bool, optional + Indicates whether to attempt to minimize memory usage by attaching + input ``image`` and/or ``mask`` `numpy.ndarray` arrays to + file-mapped accessor. 
This has no effect when input parameters + ``image`` and/or ``mask`` are already `DataAccessor` + objects. + + """ + self._image = None + self._mask = None + self._image_shape = None + self._mask_shape = None + self._reduce_memory_usage = reduce_memory_usage + + self.image = image + + self.convf = convf + self.meta = meta + self._id = id + self._pix_area = pix_area + + # WCS + self.wcs_fwd = wcs_fwd + self.wcs_inv = wcs_inv + + # initial sky value: + self._sky = 0.0 + self._sky_is_valid = False + + self.mask = mask + + # create spherical polygon bounding the image + if image is None or wcs_fwd is None or wcs_inv is None: + self._radec = [(np.array([]), np.array([]))] + self._polygon = SphericalPolygon([]) + self._poly_area = 0.0 + + else: + self.calc_bounding_polygon(stepsize) + + # set sky statistics function (NOTE: it must return statistics and + # the number of pixels used after clipping) + if skystat is None: + self.set_builtin_skystat() + else: + self.skystat = skystat + + @property + def mask(self): + """Set or get `SkyImage`'s ``mask`` data array or `None`.""" + if self._mask is None: + return None + else: + return self._mask.get_data() + + @mask.setter + def mask(self, mask): + if mask is None: + self._mask = None + self._mask_shape = None + + elif isinstance(mask, DataAccessor): + if self._image is None: + raise ValueError("'mask' must be None when 'image' is None") + + self._mask = mask + self._mask_shape = mask.get_data_shape() + + # check that mask has the same shape as image: + if self._mask_shape != self.image_shape: + raise ValueError("'mask' must have the same shape as 'image'.") + + else: + if self._image is None: + raise ValueError("'mask' must be None when 'image' is None") + + mask = np.asanyarray(mask, dtype=bool) + self._mask_shape = mask.shape + + # check that mask has the same shape as image: + if self._mask_shape != self.image_shape: + raise ValueError("'mask' must have the same shape as 'image'.") + + if self._mask is None: + if 
self._reduce_memory_usage: + self._mask = NDArrayMappedAccessor( + mask, prefix="tmp_skymatch_mask_" + ) + else: + self._mask = NDArrayInMemoryAccessor(mask) + else: + self._mask.set_data(mask) + + @property + def image(self): + """Set or get `SkyImage`'s ``image`` data array.""" + if self._image is None: + return None + else: + return self._image.get_data() + + @image.setter + def image(self, image): + if image is None: + self._image = None + self._image_shape = None + self.mask = None + + if isinstance(image, DataAccessor): + self._image = image + self._image_shape = image.get_data_shape() + + else: + image = np.asanyarray(image) + self._image_shape = image.shape + if self._image is None: + if self._reduce_memory_usage: + self._image = NDArrayMappedAccessor( + image, prefix="tmp_skymatch_image_" + ) + else: + self._image = NDArrayInMemoryAccessor(image) + else: + self._image.set_data(image) + + @property + def image_shape(self): + """Get `SkyImage`'s ``image`` data shape.""" + if self._image_shape is None and self._image is not None: + self._image_shape = self._image.get_data_shape() + return self._image_shape + + @property + def id(self): # noqa: A003 + """Set or get `SkyImage`'s `id`. + + While `id` can be of any type, it is preferable that `id` be + of a type with nice string representation. + + """ + return self._id + + @id.setter + def id(self, value): # noqa: A003 + self._id = value + + @property + def pix_area(self): + """Set or get mean pixel area.""" + return self._pix_area + + @pix_area.setter + def pix_area(self, pix_area): + self._pix_area = pix_area + + @property + def poly_area(self): + """Get bounding polygon area in srad units.""" + return self._poly_area + + @property + def sky(self): + """Sky background value. See `calc_sky` for more details.""" + return self._sky + + @sky.setter + def sky(self, sky): + self._sky = sky + + @property + def is_sky_valid(self): + """ + Indicates whether sky value was successfully computed. + Must be set externally. 
+ """ + return self._sky_is_valid + + @is_sky_valid.setter + def is_sky_valid(self, valid): + self._sky_is_valid = valid + + @property + def radec(self): + """ + Get RA and DEC of the vertices of the bounding polygon as a + `~numpy.ndarray` of shape (N, 2) where N is the number of vertices + 1. + """ + return self._radec + + @property + def polygon(self): + """Get image's bounding polygon.""" + return self._polygon + + def intersection(self, skyimage): + """ + Compute intersection of this `SkyImage` object and another + `SkyImage`, `SkyGroup`, or + :py:class:`~spherical_geometry.polygon.SphericalPolygon` + object. + + Parameters + ---------- + skyimage : SkyImage, SkyGroup, SphericalPolygon + Another object that should be intersected with this `SkyImage`. + + Returns + ------- + polygon : SphericalPolygon + A :py:class:`~spherical_geometry.polygon.SphericalPolygon` that is + the intersection of this `SkyImage` and `skyimage`. + + """ + if isinstance(skyimage, (SkyImage, SkyGroup)): + other = skyimage.polygon + else: + other = skyimage + + pts1 = np.sort(list(self._polygon.points)[0], axis=0) + pts2 = np.sort(list(other.points)[0], axis=0) + if np.allclose(pts1, pts2, rtol=0, atol=5e-9): + intersect_poly = self._polygon.copy() + else: + intersect_poly = self._polygon.intersection(other) + return intersect_poly + + def calc_bounding_polygon(self, stepsize=None): + """Compute image's bounding polygon. + + Parameters + ---------- + stepsize : int, None, optional + Indicates the maximum separation between two adjacent vertices + of the bounding polygon along each side of the image. Corners + of the image are included automatically. If `stepsize` is `None`, + bounding polygon will contain only vertices of the image. 
+ + """ + ny, nx = self.image_shape + + if stepsize is None: + nintx = 2 + ninty = 2 + else: + nintx = max(2, int(np.ceil((nx + 1.0) / stepsize))) + ninty = max(2, int(np.ceil((ny + 1.0) / stepsize))) + + xs = np.linspace(-0.5, nx - 0.5, nintx, dtype=float) + ys = np.linspace(-0.5, ny - 0.5, ninty, dtype=float)[1:-1] + nptx = xs.size + npty = ys.size + + npts = 2 * (nptx + npty) + + borderx = np.empty((npts + 1,), dtype=float) + bordery = np.empty((npts + 1,), dtype=float) + + # "bottom" points: + borderx[:nptx] = xs + bordery[:nptx] = -0.5 + # "right" + sl = np.s_[nptx : nptx + npty] + borderx[sl] = nx - 0.5 + bordery[sl] = ys + # "top" + sl = np.s_[nptx + npty : 2 * nptx + npty] + borderx[sl] = xs[::-1] + bordery[sl] = ny - 0.5 + # "left" + sl = np.s_[2 * nptx + npty : -1] + borderx[sl] = -0.5 + bordery[sl] = ys[::-1] + + # close polygon: + borderx[-1] = borderx[0] + bordery[-1] = bordery[0] + + ra, dec = self.wcs_fwd(borderx, bordery, with_bounding_box=False) + # TODO: for strange reasons, occasionally ra[0] != ra[-1] and/or + # dec[0] != dec[-1] (even though we close the polygon in the + # previous two lines). Then SphericalPolygon fails because + # points are not closed. Therefore we force it to be closed: + ra[-1] = ra[0] + dec[-1] = dec[0] + + self._radec = [(ra, dec)] + self._polygon = SphericalPolygon.from_radec(ra, dec) + self._poly_area = np.fabs(self._polygon.area()) + + @property + def skystat(self): + """Stores/retrieves a callable object that takes either a 2D image + (2D `numpy.ndarray`) or a list of pixel values (an Nx1 array) and + returns a tuple of two values: some statistics + (e.g., mean, median, etc.) and number of pixels/values from the input + image used in computing that statistics. + + When `skystat` is not set, `SkyImage` will use + :py:class:`~stcal.skymatch.skystatistics.SkyStats` object + to perform sky statistics on image data. 
+ + """ + return self._skystat + + @skystat.setter + def skystat(self, skystat): + self._skystat = skystat + + def set_builtin_skystat( + self, + skystat="median", + lower=None, + upper=None, + nclip=5, + lsigma=4.0, + usigma=4.0, + binwidth=0.1, + ): + """ + Replace already set `skystat` with a "built-in" version of a + statistics callable object used to measure sky background. + + See :py:class:`~stcal.skymatch.skystatistics.SkyStats` for the + parameter description. + + """ + self._skystat = SkyStats( + skystat=skystat, + lower=lower, + upper=upper, + nclip=nclip, + lsig=lsigma, + usig=usigma, + binwidth=binwidth, + ) + + def calc_sky(self, overlap=None, delta=True): + """ + Compute sky background value. + + Parameters + ---------- + overlap : SkyImage, SkyGroup, SphericalPolygon, list of tuples, \ +None, optional + Another `SkyImage`, `SkyGroup`, + :py:class:`spherical_geometry.polygon.SphericalPolygon`, or + a list of tuples of (RA, DEC) of vertices of a spherical + polygon. This parameter is used to indicate that sky statistics + should be computed only in the region of intersection of *this* + image with the polygon indicated by `overlap`. When `overlap` is + `None`, sky statistics will be computed over the entire image. + + delta : bool, optional + If `True`, return the difference between the computed sky value + and the value stored in the `sky` property; if `False`, return + the absolute sky value. + + Returns + ------- + skyval : float, None + Computed sky value (absolute or relative to the `sky` attribute). + If there are no valid data to perform these computations (e.g., + because this image does not overlap with the image indicated by + `overlap`), `skyval` will be set to `None`. + + npix : int + Number of pixels used to compute sky statistics. + + polyarea : float + Area (in srad) of the polygon that bounds data used to compute + sky statistics. 
+ + """ + if overlap is None: + if self._mask is None: + data = self.image + else: + data = self.image[self._mask.get_data()] + + polyarea = self.poly_area + + else: + fill_mask = np.zeros(self.image_shape, dtype=bool) + + if isinstance(overlap, SkyImage): + intersection = self.intersection(overlap) + polyarea = np.fabs(intersection.area()) + radec = list(intersection.to_radec()) + + elif isinstance(overlap, SkyGroup): + radec = [] + polyarea = 0.0 + for im in overlap: + intersection = self.intersection(im) + polyarea1 = np.fabs(intersection.area()) + if polyarea1 == 0.0: + continue + polyarea += polyarea1 + radec += list(intersection.to_radec()) + + elif isinstance(overlap, SphericalPolygon): + radec = [] + polyarea = 0.0 + for p in overlap._polygons: + intersection = self.intersection(SphericalPolygon([p])) + polyarea1 = np.fabs(intersection.area()) + if polyarea1 == 0.0: + continue + polyarea += polyarea1 + radec += list(intersection.to_radec()) + + else: # assume a list of (ra, dec) tuples: + radec = [] + polyarea = 0.0 + for r, d in overlap: + poly = SphericalPolygon.from_radec(r, d) + polyarea1 = np.fabs(poly.area()) + if polyarea1 == 0.0 or len(r) < 4: + continue + polyarea += polyarea1 + radec.append(self.intersection(poly).to_radec()) + + if polyarea == 0.0: + return None, 0, 0.0 + + for ra, dec in radec: + if len(ra) < 4: + continue + + # set pixels in 'fill_mask' that are inside a polygon to True: + x, y = self.wcs_inv(ra, dec) + poly_vert = list(zip(*[x, y])) + + polygon = region.Polygon(True, poly_vert) + fill_mask = polygon.scan(fill_mask) + + if self._mask is not None: + fill_mask &= self._mask.get_data() + + data = self.image[fill_mask] + + if data.size < 1: + return None, 0, 0.0 + + # Calculate sky + try: + skyval, npix = self._skystat(data) + except ValueError: + return None, 0, 0.0 + + if not np.isfinite(skyval): + return None, 0, 0.0 + + if delta: + skyval -= self._sky + + return skyval, npix, polyarea + + def _calc_sky_orig(self, overlap=None, 
delta=True): + """ + Compute sky background value. + + Parameters + ---------- + overlap : SkyImage, SkyGroup, SphericalPolygon, list of tuples, \ +None, optional + Another `SkyImage`, `SkyGroup`, + :py:class:`spherical_geometry.polygon.SphericalPolygon`, or + a list of tuples of (RA, DEC) of vertices of a spherical + polygon. This parameter is used to indicate that sky statistics + should be computed only in the region of intersection of *this* + image with the polygon indicated by `overlap`. When `overlap` is + `None`, sky statistics will be computed over the entire image. + + delta : bool, optional + If `True`, return the difference between the computed sky value + and the value stored in the `sky` property; if `False`, return + the absolute sky value. + + Returns + ------- + skyval : float, None + Computed sky value (absolute or relative to the `sky` attribute). + If there are no valid data to perform these computations (e.g., + because this image does not overlap with the image indicated by + `overlap`), `skyval` will be set to `None`. + + npix : int + Number of pixels used to compute sky statistics. + + polyarea : float + Area (in srad) of the polygon that bounds data used to compute + sky statistics. 
+ + """ + + if overlap is None: + if self._mask is None: + data = self.image + else: + data = self.image[self._mask.get_data()] + + polyarea = self.poly_area + + else: + fill_mask = np.zeros(self.image_shape, dtype=bool) + + if isinstance(overlap, (SkyImage, SkyGroup, SphericalPolygon)): + intersection = self.intersection(overlap) + polyarea = np.fabs(intersection.area()) + radec = intersection.to_radec() + + else: # assume a list of (ra, dec) tuples: + radec = [] + polyarea = 0.0 + for r, d in overlap: + poly = SphericalPolygon.from_radec(r, d) + polyarea1 = np.fabs(poly.area()) + if polyarea1 == 0.0 or len(r) < 4: + continue + polyarea += polyarea1 + radec.append(self.intersection(poly).to_radec()) + + if polyarea == 0.0: + return None, 0, 0.0 + + for ra, dec in radec: + if len(ra) < 4: + continue + + # set pixels in 'fill_mask' that are inside a polygon to True: + x, y = self.wcs_inv(ra, dec) + poly_vert = list(zip(*[x, y])) + + polygon = region.Polygon(True, poly_vert) + fill_mask = polygon.scan(fill_mask) + + if self._mask is not None: + fill_mask &= self._mask.get_data() + + data = self.image[fill_mask] + + if data.size < 1: + return None, 0, 0.0 + + # Calculate sky + try: + skyval, npix = self._skystat(data) + + except ValueError: + return None, 0, 0.0 + + if delta: + skyval -= self._sky + + return skyval, npix, polyarea + + def copy(self): + """ + Return a shallow copy of the `SkyImage` object. 
+ """ + si = SkyImage( + image=None, + wcs_fwd=self.wcs_fwd, + wcs_inv=self.wcs_inv, + pix_area=self.pix_area, + convf=self.convf, + mask=None, + id=self.id, + stepsize=None, + meta=self.meta, + ) + + si._image = self._image + si._mask = self._mask + si._image_shape = self._image_shape + si._mask_shape = self._mask_shape + si._reduce_memory_usage = self._reduce_memory_usage + + si._radec = self._radec + si._polygon = self._polygon + si._poly_area = self._poly_area + si.sky = self.sky + return si + + +class SkyGroup: + """ + Holds multiple :py:class:`SkyImage` objects whose sky background values + must be adjusted together. + + `SkyGroup` provides methods for obtaining bounding polygon of the group + of :py:class:`SkyImage` objects and to compute sky value of the group. + + """ + + def __init__(self, images, id=None, sky=0.0): # noqa: A002 + if isinstance(images, SkyImage): + self._images = [images] + + elif hasattr(images, "__iter__"): + self._images = [] + for im in images: + if not isinstance(im, SkyImage): + raise TypeError( + "Each element of the 'images' parameter " + "must be an 'SkyImage' object." + ) + self._images.append(im) + + else: + raise TypeError( + "Parameter 'images' must be either a single " + "'SkyImage' object or a list of 'SkyImage' objects" + ) + + self._id = id + self._update_bounding_polygon() + self._sky = sky + for im in self._images: + im.sky += sky + + @property + def id(self): # noqa: A003 + """Set or get `SkyImage`'s `id`. + + While `id` can be of any type, it is preferable that `id` be + of a type with nice string representation. + + """ + return self._id + + @id.setter + def id(self, value): # noqa: A003 + self._id = value + + @property + def sky(self): + """Sky background value. 
See `calc_sky` for more details.""" + return self._sky + + @sky.setter + def sky(self, sky): + delta_sky = sky - self._sky + self._sky = sky + for im in self._images: + im.sky += delta_sky + + @property + def radec(self): + """ + Get RA and DEC of the vertices of the bounding polygon as a + `~numpy.ndarray` of shape (N, 2) where N is the number of vertices + 1. + + """ + return self._radec + + @property + def polygon(self): + """Get the group's bounding polygon.""" + return self._polygon + + def intersection(self, skyimage): + """ + Compute intersection of this `SkyGroup` object and another + `SkyImage`, `SkyGroup`, or + :py:class:`~spherical_geometry.polygon.SphericalPolygon` + object. + + Parameters + ---------- + skyimage : SkyImage, SkyGroup, SphericalPolygon + Another object that should be intersected with this `SkyGroup`. + + Returns + ------- + intersect_poly : SphericalPolygon + A :py:class:`~spherical_geometry.polygon.SphericalPolygon` that is + the intersection of this `SkyGroup` and `skyimage`. 
+ + """ + if isinstance(skyimage, (SkyImage, SkyGroup)): + other = skyimage.polygon + else: + other = skyimage + + pts1 = np.sort(list(self._polygon.points)[0], axis=0) + pts2 = np.sort(list(other.points)[0], axis=0) + if np.allclose(pts1, pts2, rtol=0, atol=1e-8): + intersect_poly = self._polygon.copy() + else: + intersect_poly = self._polygon.intersection(other) + return intersect_poly + + def _update_bounding_polygon(self): + polygons = [im.polygon for im in self._images] + if len(polygons) == 0: + self._polygon = SphericalPolygon([]) + self._radec = [] + else: + self._polygon = SphericalPolygon.multi_union(polygons) + self._radec = list(self._polygon.to_radec()) + + def __len__(self): + return len(self._images) + + def __getitem__(self, idx): + return self._images[idx] + + def __setitem__(self, idx, value): + if not isinstance(value, SkyImage): + raise TypeError("Item must be of 'SkyImage' type") + value.sky += self._sky + self._images[idx] = value + self._update_bounding_polygon() + + def __delitem__(self, idx): + del self._images[idx] + if len(self._images) == 0: + self._sky = 0.0 + self._id = None + self._update_bounding_polygon() + + def __iter__(self): + yield from self._images + + def insert(self, idx, value): + """Inserts a `SkyImage` into the group.""" + if not isinstance(value, SkyImage): + raise TypeError("Item must be of 'SkyImage' type") + value.sky += self._sky + self._images.insert(idx, value) + self._update_bounding_polygon() + + def append(self, value): + """Appends a `SkyImage` to the group.""" + if not isinstance(value, SkyImage): + raise TypeError("Item must be of 'SkyImage' type") + value.sky += self._sky + self._images.append(value) + self._update_bounding_polygon() + + def calc_sky(self, overlap=None, delta=True): + """ + Compute sky background value. 
+ + Parameters + ---------- + overlap : SkyImage, SkyGroup, SphericalPolygon, list of tuples, \ +None, optional + Another `SkyImage`, `SkyGroup`, + :py:class:`spherical_geometry.polygon.SphericalPolygon`, or + a list of tuples of (RA, DEC) of vertices of a spherical + polygon. This parameter is used to indicate that sky statistics + should be computed only in the region of intersection of *this* + image with the polygon indicated by `overlap`. When `overlap` is + `None`, sky statistics will be computed over the entire image. + + delta : bool, optional + If `True`, return the difference between the computed sky value + and the value stored in the `sky` property; if `False`, return + the absolute sky value. + + Returns + ------- + skyval : float, None + Computed sky value (absolute or relative to the `sky` attribute). + If there are no valid data to perform these computations (e.g., + because this image does not overlap with the image indicated by + `overlap`), `skyval` will be set to `None`. + + npix : int + Number of pixels used to compute sky statistics. + + polyarea : float + Area (in srad) of the polygon that bounds data used to compute + sky statistics. 
+ + """ + + if len(self._images) == 0: + return None, 0, 0.0 + + wght = 0 + area = 0.0 + + if overlap is None: + # compute minimum sky across all images in the group: + wsky = None + + for image in self._images: + # make sure all images have the same background: + image.background = self._sky + + sky, npix, imarea = image.calc_sky(overlap=None, delta=delta) + + if sky is None: + continue + + if wsky is None or wsky > sky: + wsky = sky + wght = npix + area = imarea + + return wsky, wght, area + + # compute weighted sky in various overlaps: + wsky = 0.0 + + for image in self._images: + # make sure all images have the same background: + image.background = self._sky + + sky, npix, area1 = image.calc_sky(overlap=overlap, delta=delta) + + area += area1 + + if sky is not None and npix > 0: + pix_area = npix * image.pix_area + wsky += sky * pix_area + wght += pix_area + + if wght == 0.0 or area == 0.0: + return None, wght, area + else: + return wsky / wght, wght, area diff --git a/src/stcal/skymatch/skymatch.py b/src/stcal/skymatch/skymatch.py new file mode 100644 index 00000000..d428fdaa --- /dev/null +++ b/src/stcal/skymatch/skymatch.py @@ -0,0 +1,539 @@ +""" +A module that provides functions for matching sky in overlapping images. +""" + +import logging +from datetime import datetime + +import numpy as np + +from .skyimage import SkyGroup, SkyImage + +__all__ = ["match"] + +__local_debug__ = True + +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + + +def match(images, skymethod="global+match", match_down=True, subtract=False): + """ + A function to compute and/or "equalize" sky background in input images. + + .. note:: + Sky matching ("equalization") is possible only for **overlapping** + images. + + Parameters + ---------- + images : list of SkyImage or SkyGroup + A list of :py:class:`~stcal.skymatch.skyimage.SkyImage` or + :py:class:`~stcal.skymatch.skyimage.SkyGroup` objects. 
+ + skymethod : {'local', 'global+match', 'global', 'match'}, optional + Select the algorithm for sky computation: + + * **'local'** : compute sky background values of each input image or + group of images (members of the same "exposure"). A single sky value + is computed for each group of images. + + .. note:: + This setting is recommended when regions of overlap between images + are dominated by "pure" sky (as opposed to extended, diffuse + sources). + + * **'global'** : compute a common sky value for all input images and + groups of images. With this setting `local` will compute + sky values for each input image/group, find the minimum sky value, + and then it will set (and/or subtract) the sky value of each input image + to this minimum value. This method *may* be + useful when the input images have been already matched. + + * **'match'** : compute differences in sky values between images + and/or groups in (pair-wise) common sky regions. In this case + the computed sky values will be relative (delta) to the sky computed + in one of the input images whose sky value will be set to + (reported to be) 0. This setting will "equalize" sky values between + the images in large mosaics. However, this method is not recommended + when used in conjunction with + `astrodrizzle + `_ + because it computes relative sky values while `astrodrizzle` needs + "absolute" sky values for median image generation and CR rejection. + + * **'global+match'** : first use the **'match'** method to + equalize sky values between images and then find a minimum + "global" sky value amongst all input images. + + .. note:: + This is the *recommended* setting for images + containing diffuse sources (e.g., galaxies, nebulae) + covering significant parts of the image. 
+ + match_down : bool, optional + Specifies whether the sky *differences* should be subtracted from + images with higher sky values (`match_down` = `True`) to match the + image with the lowest sky or sky differences should be added to the + images with lower sky values to match the sky of the image with the + highest sky value (`match_down` = `False`). + + .. note:: + This setting applies *only* when the `skymethod` parameter is + either `'match'` or `'global+match'`. + + subtract : bool (Default = False) + Subtract computed sky value from image data. + + + Raises + ------ + + TypeError + The `images` argument must be a Python list of + :py:class:`~stcal.skymatch.skyimage.SkyImage` and/or + :py:class:`~stcal.skymatch.skyimage.SkyGroup` objects. + + + Notes + ----- + + :py:func:`match` provides new algorithms for sky value computations + and enhances previously available algorithms used by, e.g., + `astrodrizzle + `_. + + Two new methods of sky subtraction have been introduced (compared to the + standard ``'local'``): ``'global'`` and ``'match'``, as well as a + combination of the two -- ``'global+match'``. + + - The ``'global'`` method computes the minimum sky value across *all* + input images and/or groups. That sky value is then considered to be + the background in all input images. + + - The ``'match'`` algorithm is somewhat similar to the traditional sky + subtraction method (`skymethod` = `'local'`) in the sense that it + measures the sky independently in input images (or groups). The major + differences are that, unlike the traditional method, + + #. ``'match'`` algorithm computes *relative* (delta) sky values with + regard to the sky in a reference image chosen from the input list + of images; *and* + + #. Sky statistics are computed only in the part of the image + that intersects other images. 
+ + This makes the ``'match'`` sky computation algorithm particularly useful + for "equalizing" sky values in large mosaics in which one may have + only (at least) pair-wise intersection of images without having + a common intersection region (on the sky) in all images. + + The `'match'` method works in the following way: for each pair + of intersecting images, an equation is written that + requires that average surface brightness in the overlapping part of + the sky be equal in both images. The final system of equations is then + solved for unknown background levels. + + .. warning:: + + The current algorithm is not capable of detecting cases where some subsets + of intersecting images (from the input list of images) do not intersect + at all with other subsets of intersecting images (except for the simple + case when *single* images do not intersect any other images). In these + cases the algorithm will find equalizing sky values for each + intersecting subset of images and/or groups of images. + However since these subsets of images do not intersect each other, + sky will be matched only within each subset and the "inter-subset" + sky mismatch could be significant. + + Users are responsible for detecting such cases and adjusting processing + accordingly. + + - The ``'global+match'`` algorithm combines the ``'match'`` and + ``'global'`` methods in order to overcome the limitation of the + ``'match'`` method described in the note above: it uses the ``'global'`` + algorithm to find a baseline sky value common to all input images + and the ``'match'`` algorithm to "equalize" sky values in the mosaic. + Thus, the sky value of the "reference" image will be equal to the + baseline sky value (instead of 0 in ``'match'`` algorithm alone). + + **Remarks:** + * :py:func:`match` works directly on *geometrically distorted* + flat-fielded images thus avoiding the need to perform distortion + correction on the input images. 
+ + Initially, the footprint of a chip in an image is approximated by a + 2D planar rectangle representing the borders of chip's distorted + image. After applying distortion model to this rectangle and + projecting it onto the celestial sphere, it is approximated by + spherical polygons. Footprints of exposures and mosaics are + computed as unions of such spherical polygons while overlaps + of image pairs are found by intersecting these spherical polygons. + + **Limitations and Discussions:** + Primary reason for introducing "sky match" algorithm was to try to + equalize the sky in large mosaics in which computation of the + "absolute" sky is difficult due to the presence of large diffuse + sources in the image. As discussed above, :py:func:`match` + accomplishes this by comparing "sky values" in a pair of images in the + overlap region (that is common to both images). Quite obviously the + quality of sky "matching" will depend on how well these "sky values" + can be estimated. We use quotation marks around *sky values* because + for some image "true" background may not be present at all and the + measured sky may be the surface brightness of large galaxy, nebula, etc. + + In the discussion below we will refer to parameter names in + :py:class:`~stcal.skymatch.skystatistics.SkyStats` and these + parameter names may differ from the parameters of the actual `skystat` + object passed to initializer of the + :py:class:`~stcal.skymatch.skyimage.SkyImage`. + + Here is a brief list of possible limitations/factors that can affect + the outcome of the matching (sky subtraction in general) algorithm: + + * Since sky subtraction is performed on *flat-fielded* but + *not distortion corrected* images, it is important to keep in mind + that flat-fielding is performed to obtain uniform surface brightness + and not flux. This distinction is important for images that have + not been distortion corrected. 
As a consequence, it is advisable that + point-like sources be masked through the user-supplied mask files. + Values different from zero in user-supplied masks indicate "good" data + pixels. Alternatively, one can use `upper` parameter to limit the use + of bright objects in sky computations. + + * Normally, distorted flat-fielded images contain cosmic rays. This + algorithm does not perform CR cleaning. A possible way of minimizing + the effect of the cosmic rays on sky computations is to use + clipping (`nclip` > 0) and/or set `upper` parameter to a value + larger than most of the sky background (or extended source) but + lower than the values of most CR pixels. + + * In general, clipping is a good way of eliminating "bad" pixels: + pixels affected by CR, hot/dead pixels, etc. However, for + images with complicated backgrounds (extended galaxies, nebulae, + etc.), affected by CR and noise, clipping process may mask different + pixels in different images. If variations in the background are + too strong, clipping may converge to different sky values in + different images even when factoring in the "true" difference + in the sky background between the two images. + + * In general images can have different "true" background values + (we could measure it if images were not affected by large diffuse + sources). However, arguments such as `lower` and `upper` will + apply to all images regardless of the intrinsic differences + in sky levels. + + """ + function_name = match.__name__ + + # Time it + runtime_begin = datetime.now() + + log.info(" ") + log.info(f"***** {__name__:s}.{function_name:s}() started on {runtime_begin}") + log.info(" ") + + # check sky method: + skymethod = skymethod.lower() + if skymethod not in ["local", "global", "match", "global+match"]: + raise ValueError( + "Unsupported 'skymethod'. 
Valid values are: " + "'local', 'global', 'match', or 'global+match'" + ) + do_match = "match" in skymethod + do_global = "global" in skymethod + show_old = subtract + + log.info(f"Sky computation method: '{skymethod}'") + if do_match: + log.info("Sky matching direction: {:s}".format("DOWN" if match_down else "UP")) + + log.info( + "Sky subtraction from image data: {:s}".format("ON" if subtract else "OFF") + ) + + # check that input file name is a list of either SkyImage or SkyGroup: + nimages = 0 + for img in images: + if isinstance(img, SkyImage): + nimages += 1 + elif isinstance(img, SkyGroup): + nimages += len(img) + else: + raise TypeError( + "Each element of the 'images' must be either a " + "'SkyImage' or a 'SkyGroup'" + ) + + if nimages == 0: + raise ValueError("Argument 'images' must contain at least one image") + + log.debug( + "Total number of images to be sky-subtracted and/or matched: {:d}".format( + nimages + ) + ) + + # Print conversion factors + log.debug(" ") + log.debug("---- Image data conversion factors:") + + for img in images: + img_type = "Image" if isinstance(img, SkyImage) else "Group" + + if img_type == "Group": + log.debug(f" * Group ID={img.id}. Conversion factors:") + for im in img: + log.debug( + " - Image ID={}. Conversion factor = {:G}".format( + im.id, im.convf + ) + ) + else: + log.debug(f" * Image ID={img.id}. Conversion factor = {img.convf:G}") + + # 1. Method: "match" (or "global+match"). + # Find sky "deltas" that will match sky across all + # (intersecting) images. 
+ if do_match: + log.info(" ") + log.info("---- Computing differences in sky values in " "overlapping regions.") + + # find "optimum" sky changes: + sky_deltas = _find_optimum_sky_deltas(images, apply_sky=not subtract) + sky_good = np.isfinite(sky_deltas) + + if np.any(sky_good): + # match sky "Up" or "Down": + if match_down: + refsky = np.amin(sky_deltas[sky_good]) + else: + refsky = np.amax(sky_deltas[sky_good]) + sky_deltas[sky_good] -= refsky + + # convert to Python list and replace numpy.nan with None + sky_deltas = [skd if np.isfinite(skd) else None for skd in sky_deltas] + + _apply_sky(images, sky_deltas, False, subtract, show_old) + show_old = True + + # 2. Method: "local". Compute the minimum sky background + # value in each sky group/image. + # This is an improved (use of masks) replacement + # for the classical 'subtract' used by astrodrizzle. + # + # NOTE: incompatible with "match"-containing + # 'skymethod' modes. + # + # 3. Method: "global". Compute the minimum sky background + # value *across* *all* sky line members. + if do_global or not do_match: + log.info(" ") + if do_global: + minsky = None + log.info( + '---- Computing "global" sky - smallest sky value ' + "across *all* input images." 
def _apply_sky(images, sky_deltas, do_global, do_skysub, show_old):
    """Add sky increments to the stored sky values and log the outcome.

    Parameters
    ----------
    images : list
        ``SkyImage`` objects and/or groups of them. Any element that is not
        a ``SkyImage`` instance is handled as a group (iterated and indexed).

    sky_deltas : iterable
        One increment per element of ``images``. `None` means that no value
        is available; it is replaced by ``0.0`` before being applied.

    do_global : bool
        If `True`, a `None` increment keeps the validity flag that is already
        set (for a group, the flag of its first member is propagated to all
        members) and any other increment marks the sky as valid. If `False`,
        the sky is valid exactly when the increment is not `None`, and a
        warning is logged otherwise.

    do_skysub : bool
        If `True`, the increment is also subtracted from the pixel data
        through each image's ``_image`` accessor.

    show_old : bool
        If `True`, the log lines also report the previous sky value and the
        applied increment.

    """
    for target, delta in zip(images, sky_deltas):
        is_group = not isinstance(target, SkyImage)
        missing = delta is None

        if do_global:
            if missing:
                # nothing new was computed: keep the existing validity flag
                valid = target[0].is_sky_valid if is_group else target.is_sky_valid
            else:
                valid = True
        else:
            valid = not missing
            if missing:
                log.warning(
                    " * {:s} ID={}: Unable to compute sky value".format(
                        "Group" if is_group else "Image", target.id
                    )
                )

        if missing:
            delta = 0.0

        if not is_group:
            # single image: a sky of None is treated as 0 before incrementing
            previous = 0 if target.sky is None else target.sky
            if do_skysub:
                shifted = target._image.get_data() - delta
                target._image.set_data(shifted)
            if target.sky is None:
                target.sky = 0

            target.sky += delta
            current = target.sky

            # logged values are divided by the image conversion factor
            scale = 1.0 / target.convf
            if show_old:
                template = (
                    " * Image ID={}. Sky background: {:G} (old={:G}, delta={:G})"
                )
                message = template.format(
                    target.id, scale * current, scale * previous, scale * delta
                )
            else:
                template = " * Image ID={}. Sky background: {:G}"
                message = template.format(target.id, scale * current)
            log.info(message)

            target.is_sky_valid = valid

        else:
            # group: remember member sky values, then shift the whole group
            before = [member.sky for member in target]
            if do_skysub:
                for member in target:
                    shifted = member._image.get_data() - delta
                    member._image.set_data(shifted)
            target.sky += delta
            after = [member.sky for member in target]

            log.info(
                " * Group ID={}. Sky background of component images:".format(
                    target.id
                )
            )

            for member, previous, current in zip(target, before, after):
                scale = 1.0 / member.convf
                if show_old:
                    template = (
                        " - Image ID={}. Sky background: {:G} "
                        "(old={:G}, delta={:G})"
                    )
                    message = template.format(
                        member.id,
                        scale * current,
                        scale * previous,
                        scale * delta,
                    )
                else:
                    template = " - Image ID={}. Sky background: {:G}"
                    message = template.format(member.id, scale * current)
                log.info(message)

                member.is_sky_valid = valid


def _overlap_matrix(images, apply_sky=True):
    """Measure the sky of every image inside its overlap with every other one.

    For each pair ``i < j`` both ``images[i].calc_sky(overlap=images[j])``
    and ``images[j].calc_sky(overlap=images[i])`` are called (with
    ``delta=apply_sky``); each call is unpacked as ``(sky, weight, area)``.
    Pairs for which either area equals ``0.0`` or either sky is `None` are
    left as zeros.

    Returns
    -------
    A, W : numpy.ndarray
        Square float arrays. ``A[j, i]`` / ``W[j, i]`` hold the sky and
        weight reported by image ``i`` for its overlap with image ``j``.

    """
    # TODO: to improve performance, the pair loop could be parallelized
    # since calc_sky() here can be called independently from previous steps.
    count = len(images)
    skies = np.zeros((count, count), dtype=float)
    weights = np.zeros((count, count), dtype=float)

    index_pairs = ((i, j) for i in range(count) for j in range(i + 1, count))
    for i, j in index_pairs:
        sky_i, weight_i, area_i = images[i].calc_sky(
            overlap=images[j], delta=apply_sky
        )
        sky_j, weight_j, area_j = images[j].calc_sky(
            overlap=images[i], delta=apply_sky
        )

        no_overlap = area_i == 0.0 or area_j == 0.0
        if no_overlap or sky_i is None or sky_j is None:
            continue

        skies[j, i], weights[j, i] = sky_i, weight_i
        skies[i, j], weights[i, j] = sky_j, weight_j

    return skies, weights


def _find_optimum_sky_deltas(images, apply_sky=True):
    """Solve for the sky increments that equalize sky in overlapping images.

    One equation is built for every pair of images whose overlap weights are
    positive in *both* directions: the difference of the two unknown
    increments must equal the difference of the skies measured in the common
    area, with both sides scaled by the average of the two weights. The
    system is solved with a pseudo-inverse.

    Returns
    -------
    deltas : numpy.ndarray
        One value per input image. Images that take part in no equation get
        ``numpy.nan``; all values are ``numpy.nan`` if the rank computation
        raises ``numpy.linalg.LinAlgError``.

    """
    n_images = len(images)
    A, W = _overlap_matrix(images, apply_sky=apply_sky)

    # NOTE: for now use only pairs that *both* have weights > 0 (but a
    # different scenario when only one image has a valid weight can be
    # considered). Collecting the pairs first also tells us how many
    # equations (rows) the arrays below need.
    valid_pairs = [
        (i, j)
        for i in range(n_images)
        for j in range(i + 1, n_images)
        if W[i, j] > 0 and W[j, i] > 0
    ]
    n_equations = len(valid_pairs)

    # average weights:
    mean_weights = 0.5 * (W + W.T)

    # coefficients and free terms:
    K = np.zeros((n_equations, n_images), dtype=float)
    F = np.zeros(n_equations, dtype=float)
    unmatched = np.ones(n_images, dtype=bool)

    for row, (i, j) in enumerate(valid_pairs):
        weight = mean_weights[i, j]
        K[row, i] = weight
        K[row, j] = -weight
        F[row] = weight * (A[j, i] - A[i, j])
        unmatched[i] = False
        unmatched[j] = False

    try:
        rank = np.linalg.matrix_rank(K, 1.0e-12)
    except np.linalg.LinAlgError:
        log.warning("Unable to compute sky: No valid data in common image areas")
        return np.full(n_images, np.nan, dtype=float)

    if rank < n_images - 1:
        log.warning(
            f"There are more unknown sky values ({n_images}) to be solved for"
        )
        log.warning(
            "than there are independent equations available (matrix rank={}).".format(
                rank
            )
        )
        log.warning("Sky matching (delta) values will be computed only for")
        log.warning("a subset (or more independent subsets) of input images.")

    pseudo_inverse = np.linalg.pinv(K, rcond=1.0e-12)
    deltas = np.dot(pseudo_inverse, F)
    deltas[unmatched] = np.nan
    return deltas
class SkyStats:
    """
    Wrapper around :py:class:`stsci.imagestats.ImageStats` with "persistent
    settings": the statistics parameters are supplied once, when the object
    is created, and are then re-used for every subsequent computation on
    different data.
    """

    def __init__(
        self,
        skystat="mean",
        lower=None,
        upper=None,
        nclip=5,
        lsig=4.0,
        usig=4.0,
        binwidth=0.1,
        **kwargs,
    ):
        """Initializes the SkyStats object.

        Parameters
        ----------
        skystat : {'mean', 'median', 'mode', 'midpt'}, optional
            Statistic returned by `~SkyStats.calc_sky`. 'mean', 'mode' and
            'midpt' have the same meaning as in the
            ``stsdas.toolbox.imgtools.gstatistics`` task, while 'median'
            is the median of the distribution. Any other value raises
            `KeyError`.

        lower : float, None, optional
            Lower limit of usable pixel values for computing the sky.
            This value should be specified in the units of the input image(s).

        upper : float, None, optional
            Upper limit of usable pixel values for computing the sky.
            This value should be specified in the units of the input image(s).

        nclip : int, optional
            A non-negative number of clipping iterations to use when computing
            the sky value.

        lsig : float, optional
            Lower clipping limit, in sigma, used when computing the sky value.

        usig : float, optional
            Upper clipping limit, in sigma, used when computing the sky value.

        binwidth : float, optional
            Bin width, in sigma, used to sample the distribution of pixel
            brightness values in order to compute the sky background
            statistics.

        kwargs : dict
            Additional keyword arguments forwarded to `ImageStats`. The
            entries are deep-copied; ``fields`` and ``image`` are discarded
            because they are supplied by `~SkyStats.calc_sky` itself.

        """
        # results of the most recent calc_sky() call
        self.npix = None
        self.skyval = None

        # always request the pixel count together with the chosen statistic
        self._fields = f"npix,{skystat}"

        options = deepcopy(kwargs)
        for reserved in ("fields", "image"):
            options.pop(reserved, None)
        options.update(
            lower=lower,
            upper=upper,
            nclip=nclip,
            lsig=lsig,
            usig=usig,
            binwidth=binwidth,
        )
        self._kwargs = options

        extractors = {
            "mean": self._extract_mean,
            "mode": self._extract_mode,
            "median": self._extract_median,
            "midpt": self._extract_midpt,
        }
        self._skystat = extractors[skystat]

    def _extract_mean(self, imstat):
        """Return the ``mean`` attribute of the statistics object."""
        return imstat.mean

    def _extract_median(self, imstat):
        """Return the ``median`` attribute of the statistics object."""
        return imstat.median

    def _extract_mode(self, imstat):
        """Return the ``mode`` attribute of the statistics object."""
        return imstat.mode

    def _extract_midpt(self, imstat):
        """Return the ``midpt`` attribute of the statistics object."""
        return imstat.midpt

    def calc_sky(self, data):
        """Computes statistics on data.

        Parameters
        ----------
        data : numpy.ndarray
            A numpy array of values for which the statistics needs to be
            computed.

        Returns
        -------
        statistics : tuple
            A tuple of two values: (`skyvalue`, `npix`), where `skyvalue`
            is the statistics selected by the `skystat` parameter when the
            `SkyStats` object was created and `npix` is the number of pixels
            reported by `ImageStats` for that computation. Both values are
            also stored in the ``skyval`` and ``npix`` attributes.

        """
        result = ImageStats(image=data, fields=self._fields, **self._kwargs)
        self.skyval = self._skystat(result)  # dict or scalar
        self.npix = result.npix
        return self.skyval, self.npix

    def __call__(self, data):
        """Shortcut for `~SkyStats.calc_sky`."""
        return self.calc_sky(data)