-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move outlier detection utility functions from jwst to stcal #270
Changes from all commits
c284f74
ac93830
741edb3
546fd07
17631e5
d9007b6
a957d99
d7936bf
a57d618
6080156
252ee7c
63b84f0
fd738a8
fc59880
f3d34c2
4ebbeec
20207b1
61aeee8
573b5d9
aa64f1d
2112773
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
Description | ||
============ | ||
|
||
This sub-package contains functions useful for outlier detection. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
.. _outlier_detection: | ||
|
||
======================= | ||
Outlier Detection Utils | ||
======================= | ||
|
||
.. toctree:: | ||
:maxdepth: 2 | ||
|
||
description.rst | ||
|
||
.. automodapi:: stcal.outlier_detection.utils |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,4 @@ Package Index | |
ramp_fitting/index.rst | ||
alignment/index.rst | ||
tweakreg/index.rst | ||
outlier_detection/index.rst |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,339 @@ | ||
""" | ||
Utility functions for outlier detection routines | ||
""" | ||
import warnings | ||
|
||
import numpy as np | ||
from astropy.stats import sigma_clip | ||
from drizzle.cdrizzle import tblot | ||
from scipy import ndimage | ||
from skimage.util import view_as_windows | ||
import gwcs | ||
|
||
from stcal.alignment.util import wcs_bbox_from_shape | ||
|
||
import logging | ||
log = logging.getLogger(__name__) | ||
log.setLevel(logging.DEBUG) | ||
|
||
|
||
__all__ = [ | ||
"medfilt", | ||
"compute_weight_threshold", | ||
"flag_crs", | ||
"flag_resampled_crs", | ||
"gwcs_blot", | ||
"calc_gwcs_pixmap", | ||
"reproject", | ||
] | ||
|
||
|
||
def medfilt(arr, kern_size): | ||
""" | ||
scipy.signal.medfilt (and many other median filters) have undefined behavior | ||
for nan inputs. See: https://github.com/scipy/scipy/issues/4800 | ||
|
||
Parameters | ||
---------- | ||
arr : numpy.ndarray | ||
The input array | ||
|
||
kern_size : list of int | ||
List of kernel dimensions, length must be equal to arr.ndim. | ||
|
||
Returns | ||
------- | ||
filtered_arr : numpy.ndarray | ||
Input array median filtered with a kernel of size kern_size | ||
""" | ||
padded = np.pad(arr, [[k // 2] for k in kern_size]) | ||
windows = view_as_windows(padded, kern_size, np.ones(len(kern_size), dtype='int')) | ||
return np.nanmedian(windows, axis=np.arange(-len(kern_size), 0)) | ||
|
||
|
||
def compute_weight_threshold(weight, maskpt): | ||
''' | ||
Compute the weight threshold for a single image or cube. | ||
|
||
Parameters | ||
---------- | ||
weight : numpy.ndarray | ||
The weight array | ||
|
||
maskpt : float | ||
The percentage of the mean weight to use as a threshold for masking. | ||
|
||
Returns | ||
------- | ||
float | ||
The weight threshold for this integration. | ||
''' | ||
# necessary in order to assure that mask gets applied correctly | ||
if hasattr(weight, '_mask'): | ||
del weight._mask | ||
mask_zero_weight = np.equal(weight, 0.) | ||
mask_nans = np.isnan(weight) | ||
# Combine the masks | ||
weight_masked = np.ma.array(weight, mask=np.logical_or( | ||
mask_zero_weight, mask_nans)) | ||
# Sigma-clip the unmasked data | ||
weight_masked = sigma_clip(weight_masked, sigma=3, maxiters=5) | ||
mean_weight = np.mean(weight_masked) | ||
# Mask pixels where weight falls below maskpt percent | ||
weight_threshold = mean_weight * maskpt | ||
return weight_threshold | ||
|
||
|
||
def _abs_deriv(array): | ||
""" | ||
Do not use this function. | ||
|
||
Take the absolute derivative of a numpy array. | ||
|
||
This function assumes off-edge pixel values are 0 | ||
and leads to erroneous derivative values and should | ||
likely not be used. | ||
""" | ||
tmp = np.zeros(array.shape, dtype=np.float64) | ||
out = np.zeros(array.shape, dtype=np.float64) | ||
|
||
tmp[1:, :] = array[:-1, :] | ||
tmp, out = _absolute_subtract(array, tmp, out) | ||
tmp[:-1, :] = array[1:, :] | ||
tmp, out = _absolute_subtract(array, tmp, out) | ||
|
||
tmp[:, 1:] = array[:, :-1] | ||
tmp, out = _absolute_subtract(array, tmp, out) | ||
tmp[:, :-1] = array[:, 1:] | ||
tmp, out = _absolute_subtract(array, tmp, out) | ||
|
||
return out | ||
|
||
|
||
def _absolute_subtract(array, tmp, out): | ||
""" | ||
Do not use this function. | ||
|
||
A helper function for _abs_deriv. | ||
""" | ||
tmp = np.abs(array - tmp) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing docstring. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added a brief docstring to this function (see #270 (comment) for more details). I attempted to add more but given the number of lines of code and the array "switch-a-roo" that this function is performing I think reading the code is the best bet for trying to understand what's happening. Also this function will be removed if the candidate code for fixing spacetelescope/jwst#8636 is acceptable. |
||
out = np.maximum(tmp, out) | ||
tmp = tmp * 0. | ||
return tmp, out | ||
|
||
|
||
def flag_crs( | ||
sci_data, | ||
sci_err, | ||
blot_data, | ||
snr, | ||
): | ||
""" | ||
Straightforward detection of outliers for non-dithered data since | ||
sci_err includes all noise sources (photon, read, and flat for baseline). | ||
|
||
Parameters | ||
---------- | ||
sci_data : numpy.ndarray | ||
"Science" data possibly containing outliers. | ||
|
||
sci_err : numpy.ndarray | ||
Error estimates for sci_data. | ||
|
||
blot_data : numpy.ndarray | ||
Reference data used to detect outliers. | ||
|
||
snr : float | ||
Signal-to-noise ratio used during detection. | ||
|
||
Returns | ||
------- | ||
cr_mask : numpy.ndarray | ||
Boolean array where outliers (CRs) are true. | ||
""" | ||
return np.greater(np.abs(sci_data - blot_data), snr * np.nan_to_num(sci_err)) | ||
|
||
|
||
def flag_resampled_crs( | ||
sci_data, | ||
sci_err, | ||
blot_data, | ||
snr1, | ||
snr2, | ||
scale1, | ||
scale2, | ||
backg, | ||
): | ||
""" | ||
Detect outliers (CRs) using resampled reference data. | ||
|
||
Parameters | ||
---------- | ||
|
||
sci_data : numpy.ndarray | ||
"Science" data possibly containing outliers | ||
|
||
sci_err : numpy.ndarray | ||
Error estimates for sci_data | ||
|
||
blot_data : numpy.ndarray | ||
Reference data used to detect outliers. | ||
|
||
snr1 : float | ||
Signal-to-noise ratio threshold used prior to smoothing. | ||
|
||
snr2 : float | ||
Signal-to-noise ratio threshold used after smoothing. | ||
|
||
scale1 : float | ||
Scale used prior to smoothing. | ||
|
||
scale2 : float | ||
Scale used after smoothing. | ||
|
||
backg : float | ||
Scalar background to subtract from the difference. | ||
|
||
Returns | ||
------- | ||
cr_mask : numpy.ndarray | ||
boolean array where outliers (CRs) are true | ||
""" | ||
err_data = np.nan_to_num(sci_err) | ||
|
||
blot_deriv = _abs_deriv(blot_data) | ||
diff_noise = np.abs(sci_data - blot_data - backg) | ||
|
||
# Create a boolean mask based on a scaled version of | ||
# the derivative image (dealing with interpolating issues?) | ||
# and the standard n*sigma above the noise | ||
threshold1 = scale1 * blot_deriv + snr1 * err_data | ||
mask1 = np.greater(diff_noise, threshold1) | ||
|
||
# Smooth the boolean mask with a 3x3 boxcar kernel | ||
kernel = np.ones((3, 3), dtype=int) | ||
mask1_smoothed = ndimage.convolve(mask1, kernel, mode='nearest') | ||
|
||
# Create a 2nd boolean mask based on the 2nd set of | ||
# scale and threshold values | ||
threshold2 = scale2 * blot_deriv + snr2 * err_data | ||
mask2 = np.greater(diff_noise, threshold2) | ||
|
||
# Final boolean mask | ||
return mask1_smoothed & mask2 | ||
|
||
|
||
def gwcs_blot(median_data, median_wcs, blot_shape, blot_wcs, pix_ratio): | ||
""" | ||
Resample the median data to recreate an input image based on | ||
the blot wcs. | ||
|
||
Parameters | ||
---------- | ||
median_data : numpy.ndarray | ||
The data to blot. | ||
|
||
median_wcs : gwcs.wcs.WCS | ||
The wcs for the median data. | ||
|
||
blot_shape : list of int | ||
The target blot data shape. | ||
|
||
blot_wcs : gwcs.wcs.WCS | ||
The target/blotted wcs. | ||
|
||
pix_ratio : float | ||
Pixel ratio. | ||
|
||
Returns | ||
------- | ||
blotted : numpy.ndarray | ||
The blotted median data. | ||
|
||
blot_img : datamodel | ||
Datamodel containing header and WCS to define the 'blotted' image | ||
""" | ||
# Compute the mapping between the input and output pixel coordinates | ||
pixmap = calc_gwcs_pixmap(blot_wcs, median_wcs, blot_shape) | ||
log.debug("Pixmap shape: {}".format(pixmap[:, :, 0].shape)) | ||
log.debug("Sci shape: {}".format(blot_shape)) | ||
log.info('Blotting {} <-- {}'.format(blot_shape, median_data.shape)) | ||
|
||
outsci = np.zeros(blot_shape, dtype=np.float32) | ||
|
||
# Currently tblot cannot handle nans in the pixmap, so we need to give some | ||
# other value. -1 is not optimal and may have side effects. But this is | ||
# what we've been doing up until now, so more investigation is needed | ||
# before a change is made. Preferably, fix tblot in drizzle. | ||
pixmap[np.isnan(pixmap)] = -1 | ||
tblot(median_data, pixmap, outsci, scale=pix_ratio, kscale=1.0, | ||
interp='linear', exptime=1.0, misval=0.0, sinscl=1.0) | ||
|
||
return outsci | ||
|
||
|
||
def calc_gwcs_pixmap(in_wcs, out_wcs, in_shape): | ||
emolter marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Return a pixel grid map from input frame to output frame. | ||
|
||
Parameters | ||
---------- | ||
in_wcs : gwcs.wcs.WCS | ||
Input/source wcs. | ||
|
||
out_wcs : gwcs.wcs.WCS | ||
Output/projected wcs. | ||
|
||
in_shape : list of int | ||
Input shape used to compute the input bounding box. | ||
|
||
Returns | ||
------- | ||
pixmap : numpy.ndarray | ||
Computed pixmap. | ||
""" | ||
bb = wcs_bbox_from_shape(in_shape) | ||
log.debug("Bounding box from data shape: {}".format(bb)) | ||
|
||
grid = gwcs.wcstools.grid_from_bounding_box(bb) | ||
return np.dstack(reproject(in_wcs, out_wcs)(grid[0], grid[1])) | ||
|
||
|
||
def reproject(wcs1, wcs2): | ||
braingram marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Given two WCSs return a function which takes pixel | ||
coordinates in wcs1 and computes them in wcs2. | ||
|
||
It performs the forward transformation of ``wcs1`` followed by the | ||
inverse of ``wcs2``. | ||
|
||
Parameters | ||
---------- | ||
wcs1, wcs2 : gwcs.wcs.WCS | ||
WCS objects that have `pixel_to_world_values` and `world_to_pixel_values` | ||
methods. | ||
|
||
Returns | ||
------- | ||
_reproject : | ||
Function to compute the transformations. It takes x, y | ||
positions in ``wcs1`` and returns x, y positions in ``wcs2``. | ||
""" | ||
|
||
try: | ||
forward_transform = wcs1.pixel_to_world_values | ||
backward_transform = wcs2.world_to_pixel_values | ||
except AttributeError as err: | ||
raise TypeError("Input should be a WCS") from err | ||
|
||
def _reproject(x, y): | ||
sky = forward_transform(x, y) | ||
flat_sky = [] | ||
for axis in sky: | ||
flat_sky.append(axis.flatten()) | ||
det = backward_transform(*tuple(flat_sky)) | ||
det_reshaped = [] | ||
for axis in det: | ||
det_reshaped.append(axis.reshape(x.shape)) | ||
return tuple(det_reshaped) | ||
return _reproject |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing docstring.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added short docstrings in 2112773
This function (and
_absolute_subtract
) should likely not be used since they appear to introduce erroneous cr flags due to treating off-edge pixels as 0. However, fixing this bug would result in changes to regression test results and I think makes more sense in a separate PR (where the new results can be verified). There is a JP ticket for this spacetelescope/jwst#8636 and candidate new code here https://github.com/spacetelescope/jwst/pull/8635/filesFor the scope of this PR I added "Do not use this function" to the docstrings as moving this code to stcal did not introduce the bug.