Move outlier detection utility functions from jwst to stcal #270

Merged on Jul 30, 2024 (21 commits)

3 changes: 2 additions & 1 deletion CHANGES.rst
@@ -9,7 +9,8 @@ General
Changes to API
--------------

-
- Add ``outlier_detection`` submodule with ``utils`` included
  from jwst. [#270]

Bug Fixes
---------
4 changes: 4 additions & 0 deletions docs/stcal/outlier_detection/description.rst
@@ -0,0 +1,4 @@
Description
============

This sub-package contains functions useful for outlier detection.
12 changes: 12 additions & 0 deletions docs/stcal/outlier_detection/index.rst
@@ -0,0 +1,12 @@
.. _outlier_detection:

=======================
Outlier Detection Utils
=======================

.. toctree::
   :maxdepth: 2

   description.rst

.. automodapi:: stcal.outlier_detection.utils
1 change: 1 addition & 0 deletions docs/stcal/package_index.rst
@@ -8,3 +8,4 @@ Package Index
   ramp_fitting/index.rst
   alignment/index.rst
   tweakreg/index.rst
   outlier_detection/index.rst
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -14,7 +14,9 @@ classifiers = [
]
dependencies = [
    "astropy >=5.0.4",
    "drizzle>=1.15.0",
    "scipy >=1.7.2",
    "scikit-image>=0.19",
    "numpy >=1.21.2",
    "opencv-python-headless >=4.6.0.66",
    "asdf >=2.15.0",
@@ -209,6 +211,7 @@ module = [
    "stdatamodels.*",
    "asdf.*",
    "scipy.*",
    "drizzle.*",
    # don't complain about the installed c parts of this library
    "stcal.ramp_fitting.ols_cas22._fit",
    "stcal.ramp_fitting.ols_cas22._jump",
Empty file.
339 changes: 339 additions & 0 deletions src/stcal/outlier_detection/utils.py
@@ -0,0 +1,339 @@
"""
Utility functions for outlier detection routines
"""
import warnings

import numpy as np
from astropy.stats import sigma_clip
from drizzle.cdrizzle import tblot
from scipy import ndimage
from skimage.util import view_as_windows
import gwcs

from stcal.alignment.util import wcs_bbox_from_shape

import logging
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)


__all__ = [
    "medfilt",
    "compute_weight_threshold",
    "flag_crs",
    "flag_resampled_crs",
    "gwcs_blot",
    "calc_gwcs_pixmap",
    "reproject",
]


def medfilt(arr, kern_size):
    """
    scipy.signal.medfilt (and many other median filters) have undefined behavior
    for nan inputs. See: https://github.com/scipy/scipy/issues/4800

    Parameters
    ----------
    arr : numpy.ndarray
        The input array

    kern_size : list of int
        List of kernel dimensions, length must be equal to arr.ndim.

    Returns
    -------
    filtered_arr : numpy.ndarray
        Input array median filtered with a kernel of size kern_size
    """
    padded = np.pad(arr, [[k // 2] for k in kern_size])
    windows = view_as_windows(padded, kern_size, np.ones(len(kern_size), dtype='int'))
    return np.nanmedian(windows, axis=np.arange(-len(kern_size), 0))
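
Note: the NaN handling of `medfilt` can be tried without scikit-image. The sketch below is an assumption-laden stand-in, not the code in this PR: it swaps `view_as_windows` for NumPy's `sliding_window_view`, and the name `nan_medfilt_sketch` and the sample array are illustrative only. It shows that a NaN inside a window is ignored rather than propagated, and that the zero padding takes part in the median near the edges.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def nan_medfilt_sketch(arr, kern_size):
    # Zero-pad by half the kernel size on every side, as medfilt does.
    padded = np.pad(arr, [[k // 2] for k in kern_size])
    # One window per input pixel; the window axes are appended at the end.
    windows = sliding_window_view(padded, tuple(kern_size))
    # Median over the window axes, ignoring NaNs inside each window.
    axes = tuple(range(-len(kern_size), 0))
    return np.nanmedian(windows, axis=axes)


arr = np.array([
    [1.0, 2.0, 3.0],
    [4.0, np.nan, 6.0],
    [7.0, 8.0, 9.0],
])
filtered = nan_medfilt_sketch(arr, [3, 3])
print(filtered)
```

With an odd kernel size the output has the same shape as the input. The center value here is the median of the eight finite neighbors, and the corner values are pulled toward 0 by the padding.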


def compute_weight_threshold(weight, maskpt):
    '''
    Compute the weight threshold for a single image or cube.

    Parameters
    ----------
    weight : numpy.ndarray
        The weight array

    maskpt : float
        The percentage of the mean weight to use as a threshold for masking.

    Returns
    -------
    float
        The weight threshold for this integration.
    '''
    # necessary in order to assure that mask gets applied correctly
    if hasattr(weight, '_mask'):
        del weight._mask

    mask_zero_weight = np.equal(weight, 0.)
    mask_nans = np.isnan(weight)
    # Combine the masks
    weight_masked = np.ma.array(weight, mask=np.logical_or(
        mask_zero_weight, mask_nans))
    # Sigma-clip the unmasked data
    weight_masked = sigma_clip(weight_masked, sigma=3, maxiters=5)
    mean_weight = np.mean(weight_masked)
    # Mask pixels where weight falls below maskpt percent
    weight_threshold = mean_weight * maskpt
    return weight_threshold
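
Note: a NumPy-only sketch of the arithmetic above, assuming the sigma-clipping step is skipped for brevity, so it is not a substitute for `compute_weight_threshold`. The name `weight_threshold_sketch` and the sample weights are illustrative. It shows that zero and NaN weights are excluded from the mean, and that `maskpt` multiplies the mean directly, so a value of 0.5 corresponds to half of the mean weight.

```python
import numpy as np


def weight_threshold_sketch(weight, maskpt):
    # Exclude zero-weight and NaN pixels, as the function above does.
    bad = np.logical_or(np.equal(weight, 0.0), np.isnan(weight))
    weight_masked = np.ma.array(weight, mask=bad)
    # NOTE: the function in this PR sigma-clips weight_masked at this
    # point; that step is deliberately omitted in this sketch.
    return np.mean(weight_masked) * maskpt


weight = np.array([[0.0, 2.0, 2.0], [np.nan, 2.0, 2.0]])
threshold = weight_threshold_sketch(weight, maskpt=0.5)
print(threshold)
```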


def _abs_deriv(array):
    """
    Do not use this function.

    Take the absolute derivative of a numpy array.

    This function assumes off-edge pixel values are 0
    and leads to erroneous derivative values and should
    likely not be used.
    """
    tmp = np.zeros(array.shape, dtype=np.float64)

Collaborator

Missing docstring.

Collaborator Author

I added short docstrings in 2112773.

This function (and _absolute_subtract) should likely not be used, since they appear to introduce erroneous CR flags by treating off-edge pixels as 0. However, fixing this bug would change regression test results, and I think that makes more sense in a separate PR (where the new results can be verified). There is a JP ticket for this, spacetelescope/jwst#8636, and candidate new code here: https://github.com/spacetelescope/jwst/pull/8635/files

For the scope of this PR I added "Do not use this function" to the docstrings, as moving this code to stcal did not introduce the bug.
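
Note: the off-edge behavior described in this thread can be reproduced with a small sketch. It restates the logic of `_abs_deriv` and `_absolute_subtract` from this diff as a single loop (equivalent for finite inputs, as far as I can tell); the name `abs_deriv_zero_edges` and the test image are illustrative only. A flat image has no real gradient, yet the border pixels get a derivative equal to the pixel value, because the shifted copy holds zeros where there is no neighbor.

```python
import numpy as np


def abs_deriv_zero_edges(array):
    # Condensed restatement of _abs_deriv from this diff (hypothetical
    # name, for illustration only).
    out = np.zeros(array.shape, dtype=np.float64)
    shifts = [
        (np.s_[1:, :], np.s_[:-1, :]),   # neighbor above
        (np.s_[:-1, :], np.s_[1:, :]),   # neighbor below
        (np.s_[:, 1:], np.s_[:, :-1]),   # neighbor to the left
        (np.s_[:, :-1], np.s_[:, 1:]),   # neighbor to the right
    ]
    for dst, src in shifts:
        # tmp starts as zeros, so the row or column with no neighbor
        # keeps the value 0: the "off-edge pixels are 0" assumption.
        tmp = np.zeros(array.shape, dtype=np.float64)
        tmp[dst] = array[src]
        # Running maximum of the absolute differences.
        out = np.maximum(np.abs(array - tmp), out)
    return out


flat = np.full((4, 4), 5.0)
deriv = abs_deriv_zero_edges(flat)
print(deriv)  # border pixels are 5.0, interior pixels are 0.0
```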

    out = np.zeros(array.shape, dtype=np.float64)

    tmp[1:, :] = array[:-1, :]
    tmp, out = _absolute_subtract(array, tmp, out)
    tmp[:-1, :] = array[1:, :]
    tmp, out = _absolute_subtract(array, tmp, out)

    tmp[:, 1:] = array[:, :-1]
    tmp, out = _absolute_subtract(array, tmp, out)
    tmp[:, :-1] = array[:, 1:]
    tmp, out = _absolute_subtract(array, tmp, out)

    return out


def _absolute_subtract(array, tmp, out):
    """
    Do not use this function.

    A helper function for _abs_deriv.
    """
    tmp = np.abs(array - tmp)

Collaborator

Missing docstring.

Collaborator Author

I added a brief docstring to this function (see #270 (comment) for more details).

I attempted to add more, but given the number of lines of code and the array "switch-a-roo" that this function performs, I think reading the code is the best way to understand what is happening.

Also, this function will be removed if the candidate code for fixing spacetelescope/jwst#8636 is acceptable.

    out = np.maximum(tmp, out)
    tmp = tmp * 0.
    return tmp, out


def flag_crs(
    sci_data,
    sci_err,
    blot_data,
    snr,
):
    """
    Straightforward detection of outliers for non-dithered data since
    sci_err includes all noise sources (photon, read, and flat for baseline).

    Parameters
    ----------
    sci_data : numpy.ndarray
        "Science" data possibly containing outliers.

    sci_err : numpy.ndarray
        Error estimates for sci_data.

    blot_data : numpy.ndarray
        Reference data used to detect outliers.

    snr : float
        Signal-to-noise ratio used during detection.

    Returns
    -------
    cr_mask : numpy.ndarray
        Boolean array where outliers (CRs) are true.
    """
    return np.greater(np.abs(sci_data - blot_data), snr * np.nan_to_num(sci_err))
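
Note: a small sketch of the `flag_crs` comparison on synthetic data. The function body is copied from this diff under the illustrative name `flag_crs_sketch`; the arrays and the `snr` value are made up. It also shows one consequence of `np.nan_to_num`: a NaN error becomes 0, so the threshold at that pixel is 0 and any nonzero difference there is flagged.

```python
import numpy as np


def flag_crs_sketch(sci_data, sci_err, blot_data, snr):
    # Same comparison as flag_crs in this diff.
    return np.greater(np.abs(sci_data - blot_data), snr * np.nan_to_num(sci_err))


blot = np.full((3, 3), 10.0)
sci = blot.copy()
sci[1, 1] += 50.0  # strong outlier
sci[0, 0] += 1.0   # small deviation, well within the noise
err = np.full((3, 3), 2.0)

# Threshold is snr * err = 10, so only the strong outlier is flagged.
mask = flag_crs_sketch(sci, err, blot, snr=5.0)
print(mask)

# With a NaN error at [0, 0] the threshold there drops to 0, so the
# small deviation is flagged as well.
err_nan = err.copy()
err_nan[0, 0] = np.nan
mask_nan = flag_crs_sketch(sci, err_nan, blot, snr=5.0)
print(mask_nan)
```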


def flag_resampled_crs(
    sci_data,
    sci_err,
    blot_data,
    snr1,
    snr2,
    scale1,
    scale2,
    backg,
):
    """
    Detect outliers (CRs) using resampled reference data.

    Parameters
    ----------

    sci_data : numpy.ndarray
        "Science" data possibly containing outliers

    sci_err : numpy.ndarray
        Error estimates for sci_data

    blot_data : numpy.ndarray
        Reference data used to detect outliers.

    snr1 : float
        Signal-to-noise ratio threshold used prior to smoothing.

    snr2 : float
        Signal-to-noise ratio threshold used after smoothing.

    scale1 : float
        Scale used prior to smoothing.

    scale2 : float
        Scale used after smoothing.

    backg : float
        Scalar background to subtract from the difference.

    Returns
    -------
    cr_mask : numpy.ndarray
        boolean array where outliers (CRs) are true
    """
    err_data = np.nan_to_num(sci_err)

    blot_deriv = _abs_deriv(blot_data)
    diff_noise = np.abs(sci_data - blot_data - backg)

    # Create a boolean mask based on a scaled version of
    # the derivative image (dealing with interpolating issues?)
    # and the standard n*sigma above the noise
    threshold1 = scale1 * blot_deriv + snr1 * err_data
    mask1 = np.greater(diff_noise, threshold1)

    # Smooth the boolean mask with a 3x3 boxcar kernel
    kernel = np.ones((3, 3), dtype=int)
    mask1_smoothed = ndimage.convolve(mask1, kernel, mode='nearest')

    # Create a 2nd boolean mask based on the 2nd set of
    # scale and threshold values
    threshold2 = scale2 * blot_deriv + snr2 * err_data
    mask2 = np.greater(diff_noise, threshold2)

    # Final boolean mask
    return mask1_smoothed & mask2


def gwcs_blot(median_data, median_wcs, blot_shape, blot_wcs, pix_ratio):
    """
    Resample the median data to recreate an input image based on
    the blot wcs.

    Parameters
    ----------
    median_data : numpy.ndarray
        The data to blot.

    median_wcs : gwcs.wcs.WCS
        The wcs for the median data.

    blot_shape : list of int
        The target blot data shape.

    blot_wcs : gwcs.wcs.WCS
        The target/blotted wcs.

    pix_ratio : float
        Pixel ratio.

    Returns
    -------
    blotted : numpy.ndarray
        The blotted median data.

    blot_img : datamodel
        Datamodel containing header and WCS to define the 'blotted' image
    """
    # Compute the mapping between the input and output pixel coordinates
    pixmap = calc_gwcs_pixmap(blot_wcs, median_wcs, blot_shape)
    log.debug("Pixmap shape: {}".format(pixmap[:, :, 0].shape))
    log.debug("Sci shape: {}".format(blot_shape))
    log.info('Blotting {} <-- {}'.format(blot_shape, median_data.shape))

    outsci = np.zeros(blot_shape, dtype=np.float32)

    # Currently tblot cannot handle nans in the pixmap, so we need to give some
    # other value. -1 is not optimal and may have side effects. But this is
    # what we've been doing up until now, so more investigation is needed
    # before a change is made. Preferably, fix tblot in drizzle.
    pixmap[np.isnan(pixmap)] = -1
    tblot(median_data, pixmap, outsci, scale=pix_ratio, kscale=1.0,
          interp='linear', exptime=1.0, misval=0.0, sinscl=1.0)

    return outsci


def calc_gwcs_pixmap(in_wcs, out_wcs, in_shape):
    """
    Return a pixel grid map from input frame to output frame.

    Parameters
    ----------
    in_wcs : gwcs.wcs.WCS
        Input/source wcs.

    out_wcs : gwcs.wcs.WCS
        Output/projected wcs.

    in_shape : list of int
        Input shape used to compute the input bounding box.

    Returns
    -------
    pixmap : numpy.ndarray
        Computed pixmap.
    """
    bb = wcs_bbox_from_shape(in_shape)
    log.debug("Bounding box from data shape: {}".format(bb))

    grid = gwcs.wcstools.grid_from_bounding_box(bb)
    return np.dstack(reproject(in_wcs, out_wcs)(grid[0], grid[1]))


def reproject(wcs1, wcs2):
    """
    Given two WCSs return a function which takes pixel
    coordinates in wcs1 and computes them in wcs2.

    It performs the forward transformation of ``wcs1`` followed by the
    inverse of ``wcs2``.

    Parameters
    ----------
    wcs1, wcs2 : gwcs.wcs.WCS
        WCS objects that have `pixel_to_world_values` and `world_to_pixel_values`
        methods.

    Returns
    -------
    _reproject :
        Function to compute the transformations. It takes x, y
        positions in ``wcs1`` and returns x, y positions in ``wcs2``.
    """

    try:
        forward_transform = wcs1.pixel_to_world_values
        backward_transform = wcs2.world_to_pixel_values
    except AttributeError as err:
        raise TypeError("Input should be a WCS") from err

    def _reproject(x, y):
        sky = forward_transform(x, y)
        flat_sky = []
        for axis in sky:
            flat_sky.append(axis.flatten())
        det = backward_transform(*tuple(flat_sky))
        det_reshaped = []
        for axis in det:
            det_reshaped.append(axis.reshape(x.shape))
        return tuple(det_reshaped)
    return _reproject
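
Note: `reproject` only relies on the `pixel_to_world_values` and `world_to_pixel_values` methods, so its behavior can be sketched without gwcs. In the example below, `ShiftWCS` is a hypothetical stand-in (world = pixel + offset), not a real WCS class, and `reproject_sketch` restates the function from this diff in condensed form. It shows the forward transform of the first WCS followed by the inverse of the second, the `np.dstack` step used by `calc_gwcs_pixmap`, and the `TypeError` raised for an object without the expected methods.

```python
import numpy as np


class ShiftWCS:
    """Hypothetical stand-in for a WCS: world = pixel + offset."""

    def __init__(self, dx, dy):
        self.dx, self.dy = dx, dy

    def pixel_to_world_values(self, x, y):
        return x + self.dx, y + self.dy

    def world_to_pixel_values(self, ra, dec):
        return ra - self.dx, dec - self.dy


def reproject_sketch(wcs1, wcs2):
    # Condensed restatement of reproject from this diff.
    try:
        forward = wcs1.pixel_to_world_values
        backward = wcs2.world_to_pixel_values
    except AttributeError as err:
        raise TypeError("Input should be a WCS") from err

    def _reproject(x, y):
        sky = forward(x, y)
        # Flatten before the inverse transform, then restore the grid shape.
        det = backward(*[axis.flatten() for axis in sky])
        return tuple(axis.reshape(x.shape) for axis in det)

    return _reproject


# Pixel grid of shape (2, 3), as float arrays.
y, x = np.mgrid[0:2, 0:3].astype(float)
to_wcs2 = reproject_sketch(ShiftWCS(10.0, 20.0), ShiftWCS(7.0, 18.0))
x2, y2 = to_wcs2(x, y)

# Stack into an (ny, nx, 2) pixmap, as calc_gwcs_pixmap does.
pixmap = np.dstack((x2, y2))
print(pixmap.shape)

try:
    reproject_sketch(object(), ShiftWCS(0.0, 0.0))
except TypeError as err:
    error_message = str(err)
    print(error_message)
```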
Empty file.