Skip to content

Commit

Permalink
Add ability to get output dq array marking jump resultants
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamJamieson committed Oct 6, 2023
1 parent cd74c64 commit d138d6a
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/stcal/ramp_fitting/ols_cas22/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ._fit_ramps import fit_ramps
from ._core import Parameter, Variance, Diff
from ._core import Parameter, Variance, Diff, RampJumpDQ

__all__ = ['fit_ramps', 'Parameter', 'Variance', 'Diff']
__all__ = ['fit_ramps', 'Parameter', 'Variance', 'Diff', 'RampJumpDQ']
4 changes: 4 additions & 0 deletions src/stcal/ramp_fitting/ols_cas22/_core.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ cpdef enum Variance:
total_var = 2


cpdef enum RampJumpDQ:
JUMP_DET = 4


cdef float threshold(Thresh thresh, float slope)
cdef float get_power(float s)
cdef deque[stack[RampIndex]] init_ramps(int[:, :] dq)
Expand Down
10 changes: 8 additions & 2 deletions src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ cimport cython

from stcal.ramp_fitting.ols_cas22._core cimport (RampFits, RampIndex, Thresh,
metadata_from_read_pattern, init_ramps,
Parameter, Variance)
Parameter, Variance, RampJumpDQ)
from stcal.ramp_fitting.ols_cas22._fixed cimport fixed_values_from_metadata, FixedValues
from stcal.ramp_fitting.ols_cas22._pixel cimport make_pixel

Expand Down Expand Up @@ -85,6 +85,9 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,
cdef np.ndarray[float, ndim=2] parameters = np.zeros((n_pixels, 2), dtype=np.float32)
cdef np.ndarray[float, ndim=2] variances = np.zeros((n_pixels, 3), dtype=np.float32)

# Copy the dq array so we can modify it without fear
cdef np.ndarray[int, ndim=2] fit_dq = dq.copy()

# Perform all of the fits
cdef RampFits fit
cdef int index
Expand All @@ -99,6 +102,9 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,
variances[index, Variance.poisson_var] = fit.average.poisson_var
variances[index, Variance.total_var] = fit.average.read_var + fit.average.poisson_var

for jump in fit.jumps:
fit_dq[jump, index] = RampJumpDQ.JUMP_DET

ramp_fits.push_back(fit)

return ramp_fits, parameters, variances
return ramp_fits, parameters, variances, fit_dq
2 changes: 1 addition & 1 deletion src/stcal/ramp_fitting/ols_cas22_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def fit_ramps_casertano(
dq = dq.reshape(orig_shape + (1,))
read_noise = read_noise.reshape(orig_shape[1:] + (1,))

Check warning on line 112 in src/stcal/ramp_fitting/ols_cas22_fit.py

View check run for this annotation

Codecov / codecov/patch

src/stcal/ramp_fitting/ols_cas22_fit.py#L110-L112

Added lines #L110 - L112 were not covered by tests

_, parameters, variances = ols_cas22.fit_ramps(
_, parameters, variances, _ = ols_cas22.fit_ramps(
resultants.reshape(resultants.shape[0], -1),
dq.reshape(resultants.shape[0], -1),
read_noise.reshape(-1),
Expand Down
31 changes: 24 additions & 7 deletions tests/test_jump_cas22.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from stcal.ramp_fitting.ols_cas22._wrappers import init_ramps
from stcal.ramp_fitting.ols_cas22._wrappers import run_threshold, fixed_values_from_metadata, make_pixel

from stcal.ramp_fitting.ols_cas22 import fit_ramps, Parameter, Variance, Diff
from stcal.ramp_fitting.ols_cas22 import fit_ramps, Parameter, Variance, Diff, RampJumpDQ


RNG = np.random.default_rng(619)
Expand Down Expand Up @@ -331,7 +331,7 @@ def test_fit_ramps_array_outputs(detector_data, use_jump):
resultants, read_noise, read_pattern = detector_data
dq = np.zeros(resultants.shape, dtype=np.int32)

fits, parameters, variances = fit_ramps(
fits, parameters, variances, _ = fit_ramps(
resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump
)

Expand All @@ -356,7 +356,7 @@ def test_fit_ramps_no_dq(detector_data, use_jump):
resultants, read_noise, read_pattern = detector_data
dq = np.zeros(resultants.shape, dtype=np.int32)

fits, _, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)
fits, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)
assert len(fits) == N_PIXELS # sanity check that a fit is output for each pixel

# Check that the chi2 for the resulting fit relative to the assumed flux is ~1
Expand Down Expand Up @@ -387,7 +387,7 @@ def test_fit_ramps_dq(detector_data, use_jump):
# i.e., we can make a measurement from them.
okay = np.sum((dq[1:, :] == 0) & (dq[:-1, :] == 0), axis=0) != 0

fits, _, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)
fits, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)
assert len(fits) == N_PIXELS # sanity check that a fit is output for each pixel

chi2 = 0
Expand Down Expand Up @@ -465,7 +465,7 @@ def test_find_jumps(jump_data):
resultants, read_noise, read_pattern, jump_reads, jump_resultants = jump_data
dq = np.zeros(resultants.shape, dtype=np.int32)

fits, _, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True)
fits, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True)
assert len(fits) == len(jump_reads) # sanity check that a fit/jump is set for every pixel

chi2 = 0
Expand Down Expand Up @@ -524,11 +524,28 @@ def test_override_default_threshold(jump_data):
resultants, read_noise, read_pattern, jump_reads, jump_resultants = jump_data
dq = np.zeros(resultants.shape, dtype=np.int32)

_, standard, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True)
_, override, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True,
_, standard, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True)
_, override, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True,
intercept=0, constant=0)

# All this is intended to do is show that with all other things being equal passing non-default
# threshold parameters changes the results.
assert (standard != override).any()


def test_jump_dq_set(jump_data):
# Check the DQ flag value to start
assert RampJumpDQ.JUMP_DET == 2**2

resultants, read_noise, read_pattern, jump_reads, jump_resultants = jump_data
dq = np.zeros(resultants.shape, dtype=np.int32)

fits, *_, fit_dq = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True)

for fit, pixel_dq in zip(fits, fit_dq.transpose()):
# Check that all jumps found get marked
assert (pixel_dq[fit['jumps']] == RampJumpDQ.JUMP_DET).all()

# Check that dq flags for jumps are only set if the jump is marked
assert set(np.where(pixel_dq == RampJumpDQ.JUMP_DET)[0]) == set(fit['jumps'])

0 comments on commit d138d6a

Please sign in to comment.