JP-3562: Shower flagging enhancement (#248)
* Update jump.py

* fix missing pdq info

* test

* updates

* Update test_jump.py

* Update jump.py

* Update jump.py

* fixed pdq

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* tests

* Update jump.py

* Update jump.py

* fixes

* updates

* Update jump.py

* Update jump.py

* Update twopoint_difference.py

* Update twopoint_difference.py

* Update twopoint_difference.py

* Update twopoint_difference.py

* skip flagged groups

* Update jump.py

* Update jump.py

* fixes

* updates

* Update jump.py

* Update jump.py

* Update jump.py

* tests

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* Update jump.py

* fix test

* update tests

* Update refpix dqvalue

* Update CHANGES.rst

* Update CHANGES.rst

* Update jump.py

* Update test_jump.py

* Update test_jump.py

* Update test_jump.py

---------

Co-authored-by: Howard Bushouse <[email protected]>
Co-authored-by: Zach Burnett <[email protected]>
3 people authored Mar 21, 2024
1 parent 4b5ea2f commit 95f1643
Showing 3 changed files with 189 additions and 84 deletions.
7 changes: 7 additions & 0 deletions CHANGES.rst
@@ -30,6 +30,13 @@ ramp_fitting
Bug Fixes
---------

jump
~~~~

- Updated the shower flagging code to mask reference pixels, require a minimum
number of groups to trigger the detection, and use all integrations to determine
the median value. [#248]

ramp_fitting
~~~~~~~~~~~~

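The reference-pixel masking called out in the jump changelog entry above can be illustrated with a small, self-contained sketch. The flag values and array shapes below are placeholders chosen for the example (the step itself reads the real values from the dqflags dictionary it receives); only the masking pattern mirrors the new code in find_faint_extended.

import numpy as np

# Illustrative flag values only; the real values come from dqflags.
DO_NOT_USE = 1           # matches the donotuse_flag default in the diff
REFERENCE_PIXEL = 2**31  # assumed placeholder for this sketch

nints, ngroups, nrows, ncols = 2, 5, 8, 8
pdq = np.zeros((nrows, ncols), dtype=np.uint32)
gdq = np.zeros((nints, ngroups, nrows, ncols), dtype=np.uint32)

# Pretend the outer border of the detector holds reference pixels.
pdq[0, :] = pdq[-1, :] = pdq[:, 0] = pdq[:, -1] = REFERENCE_PIXEL

# Flag every group of every reference pixel as DO_NOT_USE so the shower
# search skips them, as the pdq-based masking in find_faint_extended does.
refy, refx = np.where(pdq == REFERENCE_PIXEL)
gdq[:, :, refy, refx] = DO_NOT_USE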
137 changes: 103 additions & 34 deletions src/stcal/jump/jump.py
@@ -1,6 +1,7 @@
import logging
import multiprocessing
import time
import warnings

import numpy as np
import cv2 as cv
@@ -320,9 +321,11 @@ def detect_jumps(
gdq, num_showers = find_faint_extended(
data,
gdq,
pdq,
readnoise_2d,
frames_per_group,
minimum_sigclip_groups,
dqflags,
snr_threshold=extend_snr_threshold,
min_shower_area=extend_min_area,
inner=extend_inner_radius,
@@ -353,9 +356,9 @@ def detect_jumps(
slices.insert(
i,
(
data[:, :, i * yinc : (i + 1) * yinc, :],
gdq[:, :, i * yinc : (i + 1) * yinc, :].copy(),
readnoise_2d[i * yinc : (i + 1) * yinc, :],
data[:, :, i * yinc: (i + 1) * yinc, :],
gdq[:, :, i * yinc: (i + 1) * yinc, :].copy(),
readnoise_2d[i * yinc: (i + 1) * yinc, :],
rejection_thresh,
three_grp_thresh,
four_grp_thresh,
@@ -380,9 +383,9 @@ def detect_jumps(
slices.insert(
n_slices - 1,
(
data[:, :, (n_slices - 1) * yinc : n_rows, :],
gdq[:, :, (n_slices - 1) * yinc : n_rows, :].copy() ,
readnoise_2d[(n_slices - 1) * yinc : n_rows, :],
data[:, :, (n_slices - 1) * yinc: n_rows, :],
gdq[:, :, (n_slices - 1) * yinc: n_rows, :].copy(),
readnoise_2d[(n_slices - 1) * yinc: n_rows, :],
rejection_thresh,
three_grp_thresh,
four_grp_thresh,
@@ -426,15 +429,15 @@ def detect_jumps(
stddev = np.zeros((nrows, ncols), dtype=np.float32)
for resultslice in real_result:
if len(real_result) == k + 1: # last result
gdq[:, :, k * yinc : n_rows, :] = resultslice[0]
gdq[:, :, k * yinc: n_rows, :] = resultslice[0]
if only_use_ints:
stddev[:, k * yinc : n_rows, :] = resultslice[4]
stddev[:, k * yinc: n_rows, :] = resultslice[4]
else:
stddev[k * yinc : n_rows, :] = resultslice[4]
stddev[k * yinc: n_rows, :] = resultslice[4]
else:
gdq[:, :, k * yinc : (k + 1) * yinc, :] = resultslice[0]
gdq[:, :, k * yinc: (k + 1) * yinc, :] = resultslice[0]
if only_use_ints:
stddev[:, k * yinc : (k + 1) * yinc, :] = resultslice[4]
stddev[:, k * yinc: (k + 1) * yinc, :] = resultslice[4]
else:
stddev[k * yinc : (k + 1) * yinc, :] = resultslice[4]
row_below_gdq[:, :, :] = resultslice[1]
@@ -456,9 +459,9 @@ def detect_jumps(
# remove redundant bits in pixels that have jump flagged but were
# already flagged as do_not_use or saturated.
gdq[gdq == np.bitwise_or(dqflags['DO_NOT_USE'], dqflags['JUMP_DET'])] = \
dqflags['DO_NOT_USE']
dqflags['DO_NOT_USE']
gdq[gdq == np.bitwise_or(dqflags['SATURATED'], dqflags['JUMP_DET'])] = \
dqflags['SATURATED']
dqflags['SATURATED']

# This is the flag that controls the flagging of snowballs.
if expand_large_events:
@@ -483,9 +486,11 @@ def detect_jumps(
gdq, num_showers = find_faint_extended(
data,
gdq,
pdq,
readnoise_2d,
frames_per_group,
minimum_sigclip_groups,
dqflags,
snr_threshold=extend_snr_threshold,
min_shower_area=extend_min_area,
inner=extend_inner_radius,
@@ -569,7 +574,7 @@ def flag_large_events(
subsequent integrations
Returns
-------
Nothing, gdq array is modified.
total Snowballs
"""
log.info("Flagging Snowballs")
@@ -877,19 +882,23 @@ def near_edge(jump, low_threshold, high_threshold):

def find_faint_extended(
indata,
gdq,
ingdq,
pdq,
readnoise_2d,
nframes,
minimum_sigclip_groups,
dqflags,
snr_threshold=1.3,
min_shower_area=40,
inner=1,
outer=2,
donotuse_flag = 1,
sat_flag=2,
jump_flag=4,
ellipse_expand=1.1,
num_grps_masked=25,
max_extended_radius=200,
min_diffs_for_shower=10,
):
"""
Parameters
@@ -934,47 +943,93 @@ def find_faint_extended(
Total number of showers detected.
"""
read_noise_2d_sqr = readnoise_2d**2
log.info("Flagging Showers")
refpix_flag = dqflags["REFERENCE_PIXEL"]
gdq = ingdq.copy()
data = indata.copy()
nints = data.shape[0]
ngrps = data.shape[1]
num_grps_donotuse = 0
for integ in range(nints):
for grp in range(ngrps):
if np.all(np.bitwise_and(gdq[integ, grp, :, :], donotuse_flag)):
num_grps_donotuse += 1
total_diffs = nints * (ngrps - 1) - num_grps_donotuse
if total_diffs < min_diffs_for_shower:
log.warning("Not enough differences for shower detections")
return ingdq, 0
read_noise_2 = readnoise_2d**2
jump_dnu_flag = jump_flag + donotuse_flag
sat_dnu_flag = sat_flag + donotuse_flag
data[gdq == jump_dnu_flag] = np.nan
data[gdq == sat_dnu_flag] = np.nan
data[gdq == sat_flag] = np.nan
data[gdq == 1] = np.nan
data[gdq == jump_flag] = np.nan
all_ellipses = []
data[gdq == donotuse_flag] = np.nan
refy, refx = np.where(pdq == refpix_flag)
gdq[:, :, refy, refx] = donotuse_flag
first_diffs = np.diff(data, axis=1)

all_ellipses = []

first_diffs_masked = np.ma.masked_array(first_diffs, mask=np.isnan(first_diffs))
nints = data.shape[0]
if nints > minimum_sigclip_groups:
mean, median, stddev = stats.sigma_clipped_stats(first_diffs_masked, sigma=5, axis=0)
else:
median_diffs = np.nanmedian(first_diffs_masked, axis=(0, 1))
sigma = np.sqrt(np.abs(median_diffs) + read_noise_2 / nframes)

warnings.filterwarnings("ignore")
for intg in range(nints):
# calculate sigma for each pixel
if nints <= minimum_sigclip_groups:
median_diffs = np.nanmedian(first_diffs_masked[intg], axis=0)
sigma = np.sqrt(np.abs(median_diffs) + read_noise_2d_sqr / nframes)
if nints < minimum_sigclip_groups:
# The difference from the median difference for each group
e_jump = first_diffs_masked[intg] - median_diffs[np.newaxis, :, :]
# SNR ratio of each diff.
ratio = np.abs(e_jump) / sigma[np.newaxis, :, :]
if intg > 0:
e_jump = first_diffs_masked[intg] - median_diffs[np.newaxis, :, :]

# SNR ratio of each diff.
ratio = np.abs(e_jump) / sigma[np.newaxis, :, :]
else:
median_diffs = np.nanmedian(first_diffs_masked[intg], axis=0)
sigma = np.sqrt(np.abs(median_diffs) + read_noise_2 / nframes)
# The difference from the median difference for each group
e_jump = first_diffs_masked[intg] - median_diffs[np.newaxis, :, :]
# SNR ratio of each diff.
ratio = np.abs(e_jump) / sigma[np.newaxis, :, :]
median_diffs = np.nanmedian(first_diffs_masked, axis=(0, 1))
sigma = np.sqrt(np.abs(median_diffs) + read_noise_2 / nframes)
# The convolution kernel creation
ring_2D_kernel = Ring2DKernel(inner, outer)
ngrps = data.shape[1]
for grp in range(1, ngrps):
if nints > minimum_sigclip_groups:
first_good_group = find_first_good_group(gdq[intg, :, :, :], donotuse_flag)
for grp in range(first_good_group + 1, ngrps):
if nints >= minimum_sigclip_groups:
median_diffs = median[grp - 1]
sigma = stddev[grp - 1]
# The difference from the median difference for each group
e_jump = first_diffs_masked[intg] - median_diffs[np.newaxis, :, :]
# SNR ratio of each diff.
ratio = np.abs(e_jump) / sigma[np.newaxis, :, :]
masked_ratio = ratio[grp - 1].copy()
jumpy, jumpx = np.where(gdq[intg, grp, :, :] == jump_flag)
# mask pix. that are already flagged as jump
masked_ratio[jumpy, jumpx] = np.nan

saty, satx = np.where(gdq[intg, grp, :, :] == sat_flag)
# mask pixels that are already flagged as jump
combined_pixel_mask = np.bitwise_or(gdq[intg, grp, :, :], pdq[:, :])
jump_pixels_array = np.bitwise_and(combined_pixel_mask, jump_flag)
jumpy, jumpx = np.where(jump_pixels_array == jump_flag)
masked_ratio[jumpy, jumpx] = np.nan

# mask pix. that are already flagged as sat.
# mask pixels that are already flagged as sat.
sat_pixels_array = np.bitwise_and(combined_pixel_mask, sat_flag)
saty, satx = np.where(sat_pixels_array == sat_flag)
masked_ratio[saty, satx] = np.nan
masked_smoothed_ratio = convolve(masked_ratio, ring_2D_kernel)

# mask pixels that are already flagged as do not use
dnu_pixels_array = np.bitwise_and(combined_pixel_mask, 1)
dnuy, dnux = np.where(dnu_pixels_array == 1)
masked_ratio[dnuy, dnux] = np.nan

masked_smoothed_ratio = convolve(masked_ratio.filled(np.nan), ring_2D_kernel)
# mask out the pixels that got refilled by the convolution
masked_smoothed_ratio[dnuy, dnux] = np.nan
nrows = ratio.shape[1]
ncols = ratio.shape[2]
extended_emission = np.zeros(shape=(nrows, ncols), dtype=np.uint8)
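The changelog also notes that all integrations are now used to determine the median when enough of them are available. Below is a rough sketch of the two noise-estimation branches in the hunk above; the shapes, random data, and threshold value are invented for the example, and only the branch structure follows the diff.

import numpy as np
from astropy.stats import sigma_clipped_stats

minimum_sigclip_groups = 100   # example threshold; the caller supplies the real one
nframes = 1
nints, ngroups, nrows, ncols = 3, 10, 16, 16
rng = np.random.default_rng(0)
data = rng.normal(100.0, 5.0, size=(nints, ngroups, nrows, ncols))
readnoise_2d = np.full((nrows, ncols), 10.0)

# First differences between consecutive groups, masked where undefined.
first_diffs = np.diff(data, axis=1)
first_diffs_masked = np.ma.masked_array(first_diffs, mask=np.isnan(first_diffs))

if nints >= minimum_sigclip_groups:
    # Enough integrations: sigma-clip each group's differences across
    # integrations to get per-group median and standard deviation maps.
    mean, median, stddev = sigma_clipped_stats(first_diffs_masked, sigma=5, axis=0)
else:
    # Too few integrations: fall back to one median difference per pixel
    # over all integrations and groups, with a read-noise contribution.
    median_diffs = np.nanmedian(first_diffs_masked, axis=(0, 1))
    sigma = np.sqrt(np.abs(median_diffs) + readnoise_2d**2 / nframes)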
@@ -1032,7 +1087,10 @@ def find_faint_extended(
if len(ellipses) > 0:
# add all the showers for this integration to the list
all_ellipses.append([intg, grp, ellipses])
# Reset the warnings filter to its original state
warnings.resetwarnings()
total_showers = 0

if all_ellipses:
# Now we actually do the flagging of the pixels inside showers.
# This is deferred until all showers are detected. because the showers
Expand All @@ -1057,6 +1115,17 @@ def find_faint_extended(
)
return gdq, total_showers

def find_first_good_group(int_gdq, do_not_use):
ngrps = int_gdq.shape[0]
skip_grp = True
first_good_group = 0
for grp in range(ngrps):
mask = np.bitwise_and(int_gdq[grp], do_not_use)
skip_grp = np.all(mask)
if not skip_grp:
first_good_group = grp
break
return first_good_group

def calc_num_slices(n_rows, max_cores, max_available):
n_slices = 1
(The diff for the third changed file is not shown here.)
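The remaining headline change, requiring a minimum number of usable group differences before shower detection is attempted, can be sketched as follows. The array shapes and the fully flagged first group are hypothetical; DO_NOT_USE = 1 matches the donotuse_flag default and min_diffs_for_shower = 10 matches the new keyword default in the diff.

import numpy as np

DO_NOT_USE = 1
min_diffs_for_shower = 10

nints, ngroups, nrows, ncols = 1, 8, 4, 4
gdq = np.zeros((nints, ngroups, nrows, ncols), dtype=np.uint32)
gdq[0, 0, :, :] = DO_NOT_USE   # e.g. a first group flagged as unusable

# Count groups in which every pixel carries DO_NOT_USE; such groups
# contribute no usable first differences.
num_grps_donotuse = 0
for integ in range(nints):
    for grp in range(ngroups):
        if np.all(np.bitwise_and(gdq[integ, grp], DO_NOT_USE)):
            num_grps_donotuse += 1

total_diffs = nints * (ngroups - 1) - num_grps_donotuse
if total_diffs < min_diffs_for_shower:
    print("Not enough differences for shower detection; returning the input gdq.")

# The find_first_good_group helper added above plays a related role: it
# returns the index of the first group that is not entirely DO_NOT_USE,
# so the per-group loop can start at first_good_group + 1.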
