diff --git a/changes/306.apichange.rst b/changes/306.apichange.rst new file mode 100644 index 00000000..d04413a5 --- /dev/null +++ b/changes/306.apichange.rst @@ -0,0 +1,2 @@ +Add maximum_shower_amplitude parameter to MIRI cosmic rays showers routine +to fix accidental flagging of bright science pixels. diff --git a/src/stcal/jump/jump.py b/src/stcal/jump/jump.py index 9eaecf66..229f8ed9 100644 --- a/src/stcal/jump/jump.py +++ b/src/stcal/jump/jump.py @@ -60,6 +60,7 @@ def detect_jumps( min_diffs_single_pass=10, mask_persist_grps_next_int=True, persist_grps_flagged=25, + max_shower_amplitude=12 ): """ This is the high-level controlling routine for the jump detection process. @@ -220,6 +221,8 @@ def detect_jumps( then all differences are processed at once. min_diffs_single_pass : int The minimum number of groups to switch to flagging all outliers in a single pass. + max_shower_amplitude : float + The maximum possible amplitude for flagged MIRI showers in DN/group Returns ------- @@ -298,46 +301,7 @@ def detect_jumps( dqflags['DO_NOT_USE'] gdq[gdq == np.bitwise_or(dqflags['SATURATED'], dqflags['JUMP_DET'])] = \ dqflags['SATURATED'] - # This is the flag that controls the flagging of snowballs. - if expand_large_events: - gdq, total_snowballs = flag_large_events( - gdq, - jump_flag, - sat_flag, - min_sat_area=min_sat_area, - min_jump_area=min_jump_area, - expand_factor=expand_factor, - sat_required_snowball=sat_required_snowball, - min_sat_radius_extend=min_sat_radius_extend, - edge_size=edge_size, - sat_expand=sat_expand, - max_extended_radius=max_extended_radius, - mask_persist_grps_next_int=mask_persist_grps_next_int, - persist_grps_flagged=persist_grps_flagged, - ) - log.info("Total snowballs = %i", total_snowballs) - number_extended_events = total_snowballs - if find_showers: - gdq, num_showers = find_faint_extended( - data, - gdq, - pdq, - readnoise_2d, - frames_per_group, - minimum_sigclip_groups, - dqflags, - snr_threshold=extend_snr_threshold, - min_shower_area=extend_min_area, - inner=extend_inner_radius, - outer=extend_outer_radius, - sat_flag=sat_flag, - jump_flag=jump_flag, - ellipse_expand=extend_ellipse_expand_ratio, - num_grps_masked=grps_masked_after_shower, - max_extended_radius=max_extended_radius, - ) - log.info("Total showers= %i", num_showers) - number_extended_events = num_showers + else: yinc = int(n_rows // n_slices) slices = [] @@ -463,46 +427,50 @@ def detect_jumps( gdq[gdq == np.bitwise_or(dqflags['SATURATED'], dqflags['JUMP_DET'])] = \ dqflags['SATURATED'] - # This is the flag that controls the flagging of snowballs. - if expand_large_events: - gdq, total_snowballs = flag_large_events( - gdq, - jump_flag, - sat_flag, - min_sat_area=min_sat_area, - min_jump_area=min_jump_area, - expand_factor=expand_factor, - sat_required_snowball=sat_required_snowball, - min_sat_radius_extend=min_sat_radius_extend, - edge_size=edge_size, - sat_expand=sat_expand, - max_extended_radius=max_extended_radius, - mask_persist_grps_next_int=mask_persist_grps_next_int, - persist_grps_flagged=persist_grps_flagged, - ) - log.info("Total snowballs = %i", total_snowballs) - number_extended_events = total_snowballs - if find_showers: - gdq, num_showers = find_faint_extended( - data, - gdq, - pdq, - readnoise_2d, - frames_per_group, - minimum_sigclip_groups, - dqflags, - snr_threshold=extend_snr_threshold, - min_shower_area=extend_min_area, - inner=extend_inner_radius, - outer=extend_outer_radius, - sat_flag=sat_flag, - jump_flag=jump_flag, - ellipse_expand=extend_ellipse_expand_ratio, - num_grps_masked=grps_masked_after_shower, - max_extended_radius=max_extended_radius, - ) - log.info("Total showers= %i", num_showers) - number_extended_events = num_showers + # Look for snowballs in near-IR data + if expand_large_events: + gdq, total_snowballs = flag_large_events( + gdq, + jump_flag, + sat_flag, + min_sat_area=min_sat_area, + min_jump_area=min_jump_area, + expand_factor=expand_factor, + sat_required_snowball=sat_required_snowball, + min_sat_radius_extend=min_sat_radius_extend, + edge_size=edge_size, + sat_expand=sat_expand, + max_extended_radius=max_extended_radius, + mask_persist_grps_next_int=mask_persist_grps_next_int, + persist_grps_flagged=persist_grps_flagged, + ) + log.info("Total snowballs = %i", total_snowballs) + number_extended_events = total_snowballs + + # Look for showers in mid-IR data + if find_showers: + gdq, num_showers = find_faint_extended( + data, + gdq, + pdq, + readnoise_2d, + frames_per_group, + minimum_sigclip_groups, + dqflags, + snr_threshold=extend_snr_threshold, + min_shower_area=extend_min_area, + inner=extend_inner_radius, + outer=extend_outer_radius, + sat_flag=sat_flag, + jump_flag=jump_flag, + ellipse_expand=extend_ellipse_expand_ratio, + num_grps_masked=grps_masked_after_shower, + max_extended_radius=max_extended_radius, + max_shower_amplitude=max_shower_amplitude + ) + log.info("Total showers= %i", num_showers) + number_extended_events = num_showers + elapsed = time.time() - start log.info("Total elapsed time = %g sec", elapsed) @@ -878,6 +846,7 @@ def near_edge(jump, low_threshold, high_threshold): ) +# MIRI cosmic ray showers code def find_faint_extended( indata, ingdq, @@ -897,6 +866,7 @@ def find_faint_extended( num_grps_masked=25, max_extended_radius=200, min_diffs_for_shower=10, + max_shower_amplitude=6, ): """ Parameters @@ -931,6 +901,8 @@ def find_faint_extended( The upper limit for the extension of saturation and jump minimum_sigclip_groups : int The minimum number of groups to use sigma clipping. + max_shower_amplitude : float + The maximum amplitude of shower artifacts to correct in DN/group Returns @@ -948,6 +920,7 @@ def find_faint_extended( nints = data.shape[0] ngrps = data.shape[1] num_grps_donotuse = 0 + for integ in range(nints): for grp in range(ngrps): if np.all(np.bitwise_and(gdq[integ, grp, :, :], donotuse_flag)): @@ -1028,6 +1001,8 @@ def find_faint_extended( masked_smoothed_ratio = convolve(masked_ratio.filled(np.nan), ring_2D_kernel) # mask out the pixels that got refilled by the convolution masked_smoothed_ratio[dnuy, dnux] = np.nan + masked_smoothed_ratio[saty, satx] = np.nan + masked_smoothed_ratio[jumpy, jumpx] = np.nan nrows = ratio.shape[1] ncols = ratio.shape[2] extended_emission = np.zeros(shape=(nrows, ncols), dtype=np.uint8) @@ -1111,6 +1086,41 @@ def find_faint_extended( num_grps_masked=num_grps_masked, max_extended_radius=max_extended_radius ) + + # Ensure that flagging showers didn't change final fluxes by more than the allowed amount + for intg in range(nints): + # Consider DO_NOT_USE, SATURATION, and JUMP_DET flags + invalid_flags = donotuse_flag | sat_flag | jump_flag + + # Approximate pre-shower rates + tempdata = indata[intg, :, :, :].copy() + # Ignore any groups flagged in the original gdq array + tempdata[ingdq[intg, :, :, :] & invalid_flags != 0] = np.nan + # Compute group differences + diff = np.diff(tempdata, axis=0) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning, message="All-NaN") + warnings.filterwarnings("ignore", category=RuntimeWarning, message="Mean of empty slice") + image1 = np.nanmean(diff, axis=0) + + # Approximate post-shower rates + tempdata = indata[intg, :, :, :].copy() + # Ignore any groups flagged in the shower gdq array + tempdata[gdq[intg, :, :, :] & invalid_flags != 0] = np.nan + # Compute group differences + diff = np.diff(tempdata, axis=0) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning, message="All-NaN") + warnings.filterwarnings("ignore", category=RuntimeWarning, message="Mean of empty slice") + image2 = np.nanmean(diff, axis=0) + + # Revert the group flags to the pre-shower flags for any pixels whose rates + # became NaN or changed by more than the amount reasonable for a real CR shower + # Note that max_shower_amplitude should now be in DN/group not DN/s + diff = np.abs(image1 - image2) + indx = np.where((np.isfinite(diff) == False) | (diff > max_shower_amplitude)) + gdq[intg, :, indx[0], indx[1]] = ingdq[intg, :, indx[0], indx[1]] + return gdq, total_showers def find_first_good_group(int_gdq, do_not_use): diff --git a/tests/test_jump.py b/tests/test_jump.py index c06a76a9..b841d48e 100644 --- a/tests/test_jump.py +++ b/tests/test_jump.py @@ -383,6 +383,7 @@ def test_find_faint_extended(tmp_path): jump_flag=4, ellipse_expand=1., num_grps_masked=1, + max_shower_amplitude=10 ) # Check that all the expected samples in group 2 are flagged as jump and # that they are not flagged outside @@ -405,36 +406,6 @@ def test_find_faint_extended(tmp_path): # Check that the flags are not applied in the 3rd group after the event assert np.all(gdq[0, 4, 12:22, 14:23]) == 0 - def test_find_faint_extended(): - nint, ngrps, ncols, nrows = 1, 66, 5, 5 - data = np.zeros(shape=(nint, ngrps, nrows, ncols), dtype=np.float32) - gdq = np.zeros_like(data, dtype=np.uint32) - pdq = np.zeros(shape=(nrows, ncols), dtype=np.uint32) - pdq[0, 0] = 1 - pdq[1, 1] = 2147483648 - # pdq = np.zeros(shape=(data.shape[2], data.shape[3]), dtype=np.uint8) - gain = 4 - readnoise = np.ones(shape=(nrows, ncols), dtype=np.float32) * 6.0 * gain - rng = np.random.default_rng(12345) - data[0, 1:, 14:20, 15:20] = 6 * gain * 6.0 * np.sqrt(2) - data = data + rng.normal(size=(nint, ngrps, nrows, ncols)) * readnoise - gdq, num_showers = find_faint_extended( - data, - gdq, - pdq, - readnoise * np.sqrt(2), - 1, - 100, - snr_threshold=3, - min_shower_area=10, - inner=1, - outer=2.6, - sat_flag=2, - jump_flag=4, - ellipse_expand=1.1, - num_grps_masked=0, - ) - # No shower is found because the event is identical in all ints def test_find_faint_extended_sigclip():