Skip to content

Commit

Permalink
More significant changes for performance.
Browse files Browse the repository at this point in the history
  • Loading branch information
Timothy Brandt committed Dec 18, 2024
1 parent fee894a commit 062ee80
Showing 1 changed file with 126 additions and 111 deletions.
237 changes: 126 additions & 111 deletions src/stcal/jump/twopoint_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import warnings
from astropy import stats
from astropy.utils.exceptions import AstropyUserWarning

log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
Expand Down Expand Up @@ -128,12 +129,11 @@ def find_crs(
pixels above current row also to be flagged as a CR
"""
# copy data and group DQ array
# copy group DQ array. data array never gets modified.
dat = dataa
if copy_arrs:
dat = dataa.copy()
gdq = group_dq.copy()
else:
dat = dataa
gdq = group_dq
# Get data characteristics
nints, ngroups, nrows, ncols = dataa.shape
Expand Down Expand Up @@ -172,15 +172,17 @@ def find_crs(
dtype=np.float32)
return gdq, row_below_gdq, row_above_gdq, 0, dummy
else:
# set 'saturated' or 'do not use' pixels to nan in data
dat[np.where(np.bitwise_and(gdq, dnu_flag | sat_flag))] = np.nan
# boolean array for 'saturated' or 'do not use' pixels
set_to_nan = gdq & (dnu_flag | sat_flag) != 0

# calculate the differences between adjacent groups (first diffs)
# use mask on data, so the results will have sat/donotuse groups masked
first_diffs = np.diff(dat, axis=1)
first_diffs[set_to_nan[:, 1:] | set_to_nan[:, :-1]] = np.nan
del set_to_nan
first_diffs_finite = np.isfinite(first_diffs)

# calc. the median of first_diffs for each pixel along the group axis
first_diffs_masked = np.ma.masked_array(first_diffs, mask=np.isnan(first_diffs))
median_diffs = np.nanmedian(first_diffs.reshape((-1, nrows, ncols)), axis=0)
# calculate sigma for each pixel
sigma = np.sqrt(np.abs(median_diffs) + read_noise_2 / nframes)
Expand All @@ -196,23 +198,24 @@ def find_crs(
warnings.filterwarnings("ignore", ".*All-NaN slice encountered.*", RuntimeWarning)
warnings.filterwarnings("ignore", ".*Mean of empty slice.*", RuntimeWarning)
warnings.filterwarnings("ignore", ".*Degrees of freedom <= 0.*", RuntimeWarning)
warnings.filterwarnings("ignore", ".*Input data contains invalid values*", AstropyUserWarning)

Check warning on line 201 in src/stcal/jump/twopoint_difference.py

View check run for this annotation

Codecov / codecov/patch

src/stcal/jump/twopoint_difference.py#L201

Added line #L201 was not covered by tests

if only_use_ints:
mean, median, stddev = stats.sigma_clipped_stats(first_diffs_masked, sigma=normal_rej_thresh,
axis=0)
clipped_diffs = stats.sigma_clip(first_diffs_masked, sigma=normal_rej_thresh,
axis=0, masked=True)
clipped_diffs, alow, ahigh = stats.sigma_clip(first_diffs, sigma=normal_rej_thresh,

Check warning on line 204 in src/stcal/jump/twopoint_difference.py

View check run for this annotation

Codecov / codecov/patch

src/stcal/jump/twopoint_difference.py#L204

Added line #L204 was not covered by tests
axis=0, masked=True, return_bounds=True)
else:
mean, median, stddev = stats.sigma_clipped_stats(first_diffs_masked, sigma=normal_rej_thresh,
axis=(0, 1))
clipped_diffs = stats.sigma_clip(first_diffs_masked, sigma=normal_rej_thresh,
axis=(0, 1), masked=True)
jump_mask = np.logical_and(clipped_diffs.mask, np.logical_not(first_diffs_masked.mask))
jump_mask[np.bitwise_and(jump_mask, gdq[:, 1:, :, :] == sat_flag)] = False
jump_mask[np.bitwise_and(jump_mask, gdq[:, 1:, :, :] == dnu_flag)] = False
jump_mask[np.bitwise_and(jump_mask, gdq[:, 1:, :, :] == (dnu_flag | sat_flag))] = False
gdq[:, 1:, :, :] = np.bitwise_or(gdq[:, 1:, :, :], jump_mask *
np.uint8(dqflags["JUMP_DET"]))
clipped_diffs, alow, ahigh = stats.sigma_clip(first_diffs, sigma=normal_rej_thresh,

Check warning on line 207 in src/stcal/jump/twopoint_difference.py

View check run for this annotation

Codecov / codecov/patch

src/stcal/jump/twopoint_difference.py#L207

Added line #L207 was not covered by tests
axis=(0, 1), masked=True, return_bounds=True)

# get the standard deviation from the bounds of sigma clipping
stddev = 0.5*(ahigh - alow)/normal_rej_thresh

Check warning on line 211 in src/stcal/jump/twopoint_difference.py

View check run for this annotation

Codecov / codecov/patch

src/stcal/jump/twopoint_difference.py#L211

Added line #L211 was not covered by tests

jump_candidates = clipped_diffs.mask
sat_or_dnu_not_set = gdq[:, 1:] & (sat_flag | dnu_flag) == 0
jump_mask = jump_candidates & first_diffs_finite & sat_or_dnu_not_set
del clipped_diffs
gdq[:, 1:] |= jump_mask * np.uint8(jump_flag)

Check warning on line 217 in src/stcal/jump/twopoint_difference.py

View check run for this annotation

Codecov / codecov/patch

src/stcal/jump/twopoint_difference.py#L213-L217

Added lines #L213 - L217 were not covered by tests

# if grp is all jump set to do not use
for integ in range(nints):
for grp in range(ngrps):
Expand All @@ -230,14 +233,15 @@ def find_crs(
# compute 'ratio' for each group. this is the value that will be
# compared to 'threshold' to classify jumps. subtract the median of
# first_diffs from first_diffs, take the abs. value and divide by sigma.
ratio = (np.abs(first_diffs - median_diffs[np.newaxis, np.newaxis, :]) / sigma[np.newaxis, :]).astype(np.float32)
masked_ratio = np.ma.masked_greater(ratio, normal_rej_thresh)
del ratio
# The jump mask is the ratio greater than the threshold and the difference is usable
jump_mask = np.logical_and(masked_ratio.mask, np.logical_not(first_diffs_masked.mask))
gdq[:, 1:, :, :] = np.bitwise_or(gdq[:, 1:, :, :], jump_mask *
np.uint8(dqflags["JUMP_DET"]))
del masked_ratio
# The jump mask is the ratio greater than the threshold and the
# difference is usable. Loop over integrations to minimize the memory
# footprint.

jump_mask = np.zeros(first_diffs.shape, dtype=bool)
for i in range(nints):
jump_candidates = np.abs(first_diffs[i] - median_diffs[np.newaxis, :]) / sigma[np.newaxis, :] > normal_rej_thresh
jump_mask = jump_candidates & first_diffs_finite[i]
gdq[i, 1:] |= jump_mask * np.uint8(jump_flag)

else: # low number of diffs requires iterative flagging
# calculate the differences between adjacent groups (first diffs)
Expand Down Expand Up @@ -334,91 +338,64 @@ def find_crs(
gdq[:, 1:, all_crs_row[j], all_crs_col[j]],
dqflags["JUMP_DET"] * np.invert(pix_cr_mask),
)
cr_integ, cr_group, cr_row, cr_col = np.where(np.bitwise_and(gdq, jump_flag))
num_primary_crs = len(cr_group)
if flag_4_neighbors: # iterate over each 'jump' pixel
for j in range(len(cr_group)):
_i, _j, _k, _l = (cr_integ[j], cr_group[j] - 1, cr_row[j], cr_col[j])
ratio_this_pix = np.abs(first_diffs[_i, _j, _k, _l] - median_diffs[_k, _l])/sigma[_k, _l]

# Jumps must be in a certain range to have neighbors flagged
if (ratio_this_pix < max_jump_to_flag_neighbors) and (
ratio_this_pix > min_jump_to_flag_neighbors
):
integ = cr_integ[j]
group = cr_group[j]
row = cr_row[j]
col = cr_col[j]

# This section saves flagged neighbors that are above or
# below the current range of row. If this method
# running in a single process, the row above and below are
# not used. If it is running in multiprocessing mode, then
# the rows above and below need to be returned to
# find_jumps to use when it reconstructs the full group dq
# array from the slices.

# Only flag adjacent pixels if they do not already have the
# 'SATURATION' or 'DONOTUSE' flag set
if row != 0:
if (gdq[integ, group, row - 1, col] & sat_flag) == 0 and (
gdq[integ, group, row - 1, col] & dnu_flag
) == 0:
gdq[integ, group, row - 1, col] = np.bitwise_or(
gdq[integ, group, row - 1, col], jump_flag
)
else:
row_below_gdq[integ, cr_group[j], cr_col[j]] = jump_flag

if row != nrows - 1:
if (gdq[integ, group, row + 1, col] & sat_flag) == 0 and (
gdq[integ, group, row + 1, col] & dnu_flag
) == 0:
gdq[integ, group, row + 1, col] = np.bitwise_or(
gdq[integ, group, row + 1, col], jump_flag
)
else:
row_above_gdq[integ, cr_group[j], cr_col[j]] = jump_flag

# Here we are just checking that we don't flag neighbors of
# jumps that are off the detector.
if (
cr_col[j] != 0
and (gdq[integ, group, row, col - 1] & sat_flag) == 0
and (gdq[integ, group, row, col - 1] & dnu_flag) == 0
):
gdq[integ, group, row, col - 1] = np.bitwise_or(
gdq[integ, group, row, col - 1], jump_flag
)

if (
cr_col[j] != ncols - 1
and (gdq[integ, group, row, col + 1] & sat_flag) == 0
and (gdq[integ, group, row, col + 1] & dnu_flag) == 0
):
gdq[integ, group, row, col + 1] = np.bitwise_or(
gdq[integ, group, row, col + 1], jump_flag
)
num_primary_crs = np.sum(gdq & jump_flag == jump_flag)

# Flag the four neighbors using bitwise or, shifting the reference
# boolean flag on pixel right, then left, then up, then down.

if flag_4_neighbors:
for i in range(nints):
for j in range(ngroups - 1):

ratio = np.abs(first_diffs[i, j] - median_diffs)/sigma
jump_set = gdq[i, j + 1] & jump_flag != 0

flag = (ratio < max_jump_to_flag_neighbors) & \
(ratio > min_jump_to_flag_neighbors) & \
(jump_set)

flagsave = flag.copy()
flag[1:] |= flagsave[:-1]
flag[:-1] |= flagsave[1:]
flag[:, 1:] |= flagsave[:, :-1]
flag[:, :-1] |= flagsave[:, 1:]

sat_or_dnu_notset = gdq[i, j + 1] & (sat_flag | dnu_flag) == 0

gdq[i, j + 1][sat_or_dnu_notset & flag] |= jump_flag

row_below_gdq[i, j + 1] = np.uint8(jump_flag) * flag[0]
row_above_gdq[i, j + 1] = np.uint8(jump_flag) * flag[-1]

# flag n groups after jumps above the specified thresholds to
# account for the transient seen after ramp jumps. Again, use
# boolean arrays; the propagation happens in a separate function.

if after_jump_flag_n1 > 0 or after_jump_flag_n2 > 0:
for i in range(nints):

ejump = first_diffs[i] - median_diffs[np.newaxis, :]
jump_set = gdq[i] & jump_flag != 0

# flag n groups after jumps above the specified thresholds to account for
# the transient seen after ramp jumps
flag_e_threshold = [after_jump_flag_e1, after_jump_flag_e2]
flag_groups = [after_jump_flag_n1, after_jump_flag_n2]
for cthres, cgroup in zip(flag_e_threshold, flag_groups):
if cgroup > 0:
cr_intg, cr_group, cr_row, cr_col = np.where(np.bitwise_and(gdq, jump_flag))
for j in range(len(cr_group)):
intg = cr_intg[j]
group = cr_group[j]
row = cr_row[j]
col = cr_col[j]
ejump_this_pix = first_diffs[intg, group - 1, row, col] - median_diffs[row, col]
if ejump_this_pix >= cthres:
for kk in range(group, min(group + cgroup + 1, ngroups)):
if (gdq[intg, kk, row, col] & sat_flag) == 0 and (
gdq[intg, kk, row, col] & dnu_flag
) == 0:
gdq[intg, kk, row, col] = np.bitwise_or(gdq[intg, kk, row, col], jump_flag)
bigjump = np.zeros(jump_set.shape, dtype=bool)
verybigjump = np.zeros(jump_set.shape, dtype=bool)

bigjump[1:] = (ejump >= after_jump_flag_e1) & jump_set[1:]
verybigjump[1:] = (ejump >= after_jump_flag_e2) & jump_set[1:]

# Propagate flags forward

propagate_flags(bigjump, after_jump_flag_n1)
propagate_flags(verybigjump, after_jump_flag_n2)

# Set the flags for pixels after these jumps that are not
# already flagged as saturated or do not use.

sat_or_dnu_notset = gdq[i] & (sat_flag | dnu_flag) == 0

addflag = (bigjump | verybigjump) & sat_or_dnu_notset
gdq[i][addflag] |= jump_flag

if "stddev" in locals():
return gdq, row_below_gdq, row_above_gdq, num_primary_crs, stddev
Expand All @@ -430,6 +407,44 @@ def find_crs(
return gdq, row_below_gdq, row_above_gdq, num_primary_crs, dummy


def propagate_flags(boolean_flag, n_groups_flag):
"""Propagate a boolean flag array npix groups along the first axis.
If the number of groups to propagate is not too large, or if a
high percentage of pixels are flagged, use boolean or on the
array. Otherwise use np.where. In both cases operate on the
array in-place.
Parameters
----------
boolean_flag : 3D boolean array
Should be True where the flag is to be propagated.
n_groups_flag : int
Number of groups to propagate flags forward.
Returns
-------
None
"""
ngroups = boolean_flag.shape[0]
jmax = min(n_groups_flag, ngroups - 2)

# Option A: iteratively propagate all flags forward by one
# group at a time.
if (jmax <= 50 and jmax > 0) or np.mean(boolean_flag) > 1e-3:
for j in range(jmax):
boolean_flag[j + 1:] |= boolean_flag[j:-1]

# Option B: find the flags and propagate them individually.
elif jmax > 0:
igrp, icol, irow = np.where(boolean_flag)
for j in range(len(igrp)):
boolean_flag[igrp[j]:igrp[j] + n_groups_flag + 1, icol[j], irow[j]] = True

Check warning on line 443 in src/stcal/jump/twopoint_difference.py

View check run for this annotation

Codecov / codecov/patch

src/stcal/jump/twopoint_difference.py#L443

Added line #L443 was not covered by tests

return


def calc_med_first_diffs(in_first_diffs):
"""Calculate the median of `first diffs` along the group axis.
Expand Down

0 comments on commit 062ee80

Please sign in to comment.