diff --git a/src/stcal/ramp_fitting/ols_cas22/__init__.py b/src/stcal/ramp_fitting/ols_cas22/__init__.py index 4a5480d52..8b2a69abb 100644 --- a/src/stcal/ramp_fitting/ols_cas22/__init__.py +++ b/src/stcal/ramp_fitting/ols_cas22/__init__.py @@ -1,4 +1,4 @@ from ._fit_ramps import fit_ramps -from ._core import Parameter, Variance, Diff +from ._core import Parameter, Variance, Diff, RampJumpDQ -__all__ = ['fit_ramps', 'Parameter', 'Variance', 'Diff'] +__all__ = ['fit_ramps', 'Parameter', 'Variance', 'Diff', 'RampJumpDQ'] diff --git a/src/stcal/ramp_fitting/ols_cas22/_core.pxd b/src/stcal/ramp_fitting/ols_cas22/_core.pxd index 8b5494e33..9f14c7571 100644 --- a/src/stcal/ramp_fitting/ols_cas22/_core.pxd +++ b/src/stcal/ramp_fitting/ols_cas22/_core.pxd @@ -48,6 +48,10 @@ cpdef enum Variance: total_var = 2 +cpdef enum RampJumpDQ: + JUMP_DET = 4 + + cdef float threshold(Thresh thresh, float slope) cdef float get_power(float s) cdef deque[stack[RampIndex]] init_ramps(int[:, :] dq) diff --git a/src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx b/src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx index 32031f7fa..65f40d904 100644 --- a/src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx +++ b/src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx @@ -8,7 +8,7 @@ cimport cython from stcal.ramp_fitting.ols_cas22._core cimport (RampFits, RampIndex, Thresh, metadata_from_read_pattern, init_ramps, - Parameter, Variance) + Parameter, Variance, RampJumpDQ) from stcal.ramp_fitting.ols_cas22._fixed cimport fixed_values_from_metadata, FixedValues from stcal.ramp_fitting.ols_cas22._pixel cimport make_pixel @@ -85,6 +85,9 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants, cdef np.ndarray[float, ndim=2] parameters = np.zeros((n_pixels, 2), dtype=np.float32) cdef np.ndarray[float, ndim=2] variances = np.zeros((n_pixels, 3), dtype=np.float32) + # Copy the dq array so we can modify it without fear + cdef np.ndarray[int, ndim=2] fit_dq = dq.copy() + # Perform all of the fits cdef RampFits fit cdef int index @@ -99,6 +102,9 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants, variances[index, Variance.poisson_var] = fit.average.poisson_var variances[index, Variance.total_var] = fit.average.read_var + fit.average.poisson_var + for jump in fit.jumps: + fit_dq[jump, index] = RampJumpDQ.JUMP_DET + ramp_fits.push_back(fit) - return ramp_fits, parameters, variances + return ramp_fits, parameters, variances, fit_dq diff --git a/src/stcal/ramp_fitting/ols_cas22_fit.py b/src/stcal/ramp_fitting/ols_cas22_fit.py index df3a8b4e3..94d382425 100644 --- a/src/stcal/ramp_fitting/ols_cas22_fit.py +++ b/src/stcal/ramp_fitting/ols_cas22_fit.py @@ -111,7 +111,7 @@ def fit_ramps_casertano( dq = dq.reshape(orig_shape + (1,)) read_noise = read_noise.reshape(orig_shape[1:] + (1,)) - _, parameters, variances = ols_cas22.fit_ramps( + _, parameters, variances, _ = ols_cas22.fit_ramps( resultants.reshape(resultants.shape[0], -1), dq.reshape(resultants.shape[0], -1), read_noise.reshape(-1), diff --git a/tests/test_jump_cas22.py b/tests/test_jump_cas22.py index 9c10c4c32..cf34a989f 100644 --- a/tests/test_jump_cas22.py +++ b/tests/test_jump_cas22.py @@ -6,7 +6,7 @@ from stcal.ramp_fitting.ols_cas22._wrappers import init_ramps from stcal.ramp_fitting.ols_cas22._wrappers import run_threshold, fixed_values_from_metadata, make_pixel -from stcal.ramp_fitting.ols_cas22 import fit_ramps, Parameter, Variance, Diff +from stcal.ramp_fitting.ols_cas22 import fit_ramps, Parameter, Variance, Diff, RampJumpDQ RNG = np.random.default_rng(619) @@ -331,7 +331,7 @@ def test_fit_ramps_array_outputs(detector_data, use_jump): resultants, read_noise, read_pattern = detector_data dq = np.zeros(resultants.shape, dtype=np.int32) - fits, parameters, variances = fit_ramps( + fits, parameters, variances, _ = fit_ramps( resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump ) @@ -356,7 +356,7 @@ def test_fit_ramps_no_dq(detector_data, use_jump): resultants, read_noise, read_pattern = detector_data dq = np.zeros(resultants.shape, dtype=np.int32) - fits, _, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump) + fits, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump) assert len(fits) == N_PIXELS # sanity check that a fit is output for each pixel # Check that the chi2 for the resulting fit relative to the assumed flux is ~1 @@ -387,7 +387,7 @@ def test_fit_ramps_dq(detector_data, use_jump): # i.e., we can make a measurement from them. okay = np.sum((dq[1:, :] == 0) & (dq[:-1, :] == 0), axis=0) != 0 - fits, _, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump) + fits, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump) assert len(fits) == N_PIXELS # sanity check that a fit is output for each pixel chi2 = 0 @@ -465,7 +465,7 @@ def test_find_jumps(jump_data): resultants, read_noise, read_pattern, jump_reads, jump_resultants = jump_data dq = np.zeros(resultants.shape, dtype=np.int32) - fits, _, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True) + fits, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True) assert len(fits) == len(jump_reads) # sanity check that a fit/jump is set for every pixel chi2 = 0 @@ -524,11 +524,28 @@ def test_override_default_threshold(jump_data): resultants, read_noise, read_pattern, jump_reads, jump_resultants = jump_data dq = np.zeros(resultants.shape, dtype=np.int32) - _, standard, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True) - _, override, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True, + _, standard, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True) + _, override, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True, intercept=0, constant=0) # All this is intended to do is show that with all other things being equal passing non-default # threshold parameters changes the results. assert (standard != override).any() + + +def test_jump_dq_set(jump_data): + # Check the DQ flag value to start + assert RampJumpDQ.JUMP_DET == 2**2 + + resultants, read_noise, read_pattern, jump_reads, jump_resultants = jump_data + dq = np.zeros(resultants.shape, dtype=np.int32) + + fits, *_, fit_dq = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True) + + for fit, pixel_dq in zip(fits, fit_dq.transpose()): + # Check that all jumps found get marked + assert (pixel_dq[fit['jumps']] == RampJumpDQ.JUMP_DET).all() + + # Check that dq flags for jumps are only set if the jump is marked + assert set(np.where(pixel_dq == RampJumpDQ.JUMP_DET)[0]) == set(fit['jumps'])