diff --git a/src/stcal/ramp_fitting/ols_cas22/__init__.py b/src/stcal/ramp_fitting/ols_cas22/__init__.py index a5d0f6f89..4a5480d52 100644 --- a/src/stcal/ramp_fitting/ols_cas22/__init__.py +++ b/src/stcal/ramp_fitting/ols_cas22/__init__.py @@ -1,3 +1,4 @@ from ._fit_ramps import fit_ramps +from ._core import Parameter, Variance, Diff -__all__ = ['fit_ramps'] +__all__ = ['fit_ramps', 'Parameter', 'Variance', 'Diff'] diff --git a/src/stcal/ramp_fitting/ols_cas22/_core.pxd b/src/stcal/ramp_fitting/ols_cas22/_core.pxd index 731f7c587..554ab00fd 100644 --- a/src/stcal/ramp_fitting/ols_cas22/_core.pxd +++ b/src/stcal/ramp_fitting/ols_cas22/_core.pxd @@ -32,11 +32,22 @@ cdef struct Thresh: float constant -cdef enum Diff: +cpdef enum Diff: single = 0 double = 1 +cpdef enum Parameter: + intercept = 0 + slope = 1 + + +cpdef enum Variance: + read_var = 0 + poisson_var = 1 + total_var = 2 + + cdef float threshold(Thresh thresh, float slope) cdef float get_power(float s) cdef deque[stack[RampIndex]] init_ramps(int[:, :] dq) diff --git a/src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx b/src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx index 46eab337a..415220373 100644 --- a/src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx +++ b/src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx @@ -7,7 +7,7 @@ from libcpp.deque cimport deque cimport cython from stcal.ramp_fitting.ols_cas22._core cimport ( - RampFits, RampIndex, Thresh, read_data, init_ramps) + RampFits, RampIndex, Thresh, read_data, init_ramps, Parameter, Variance) from stcal.ramp_fitting.ols_cas22._fixed cimport make_fixed, Fixed from stcal.ramp_fitting.ols_cas22._pixel cimport make_pixel @@ -80,11 +80,11 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants, fit = make_pixel(fixed, read_noise[index], resultants[:, index]).fit_ramps(pixel_ramps[index]) - parameters[index, 1] = fit.average.slope + parameters[index, Parameter.slope] = fit.average.slope - variances[index, 0] = fit.average.read_var - variances[index, 1] = fit.average.poisson_var - variances[index, 2] = fit.average.read_var + fit.average.poisson_var + variances[index, Variance.read_var] = fit.average.read_var + variances[index, Variance.poisson_var] = fit.average.poisson_var + variances[index, Variance.total_var] = fit.average.read_var + fit.average.poisson_var ramp_fits.push_back(fit) diff --git a/tests/test_jump_cas22.py b/tests/test_jump_cas22.py index 5849e2c32..bdb3570aa 100644 --- a/tests/test_jump_cas22.py +++ b/tests/test_jump_cas22.py @@ -6,7 +6,7 @@ from stcal.ramp_fitting.ols_cas22._wrappers import init_ramps from stcal.ramp_fitting.ols_cas22._wrappers import run_threshold, make_fixed, make_pixel, fit_ramp -from stcal.ramp_fitting.ols_cas22 import fit_ramps +from stcal.ramp_fitting.ols_cas22 import fit_ramps, Parameter, Variance, Diff RNG = np.random.default_rng(619) @@ -128,8 +128,8 @@ def test_make_fixed(ramp_data, use_jump): # These are computed via vectorized operations in the main code, here we # check using item-by-item operations if use_jump: - single_gen = zip(fixed['t_bar_diff'][0], fixed['recip'][0], fixed['slope_var'][0]) - double_gen = zip(fixed['t_bar_diff'][1], fixed['recip'][1], fixed['slope_var'][1]) + single_gen = zip(fixed['t_bar_diff'][Diff.single], fixed['recip'][Diff.single], fixed['slope_var'][Diff.single]) + double_gen = zip(fixed['t_bar_diff'][Diff.double], fixed['recip'][Diff.double], fixed['slope_var'][Diff.double]) for index, (t_bar_1, recip_1, slope_var_1) in enumerate(single_gen): assert t_bar_1 == t_bar[index + 1] - t_bar[index] @@ -210,8 +210,8 @@ def test_make_pixel(pixel_data, use_jump): # These are computed via vectorized operations in the main code, here we # check using item-by-item operations if use_jump: - single_gen = zip(pixel['delta'][0], pixel['sigma'][0]) - double_gen = zip(pixel['delta'][1], pixel['sigma'][1]) + single_gen = zip(pixel['delta'][Diff.single], pixel['sigma'][Diff.single]) + double_gen = zip(pixel['delta'][Diff.double], pixel['sigma'][Diff.double]) for index, (delta_1, sigma_1) in enumerate(single_gen): assert delta_1 == (resultants[index + 1] - resultants[index]) / (t_bar[index + 1] - t_bar[index]) @@ -260,12 +260,12 @@ def test_fit_ramps_array_outputs(detector_data, use_jump): fits, parameters, variances = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump) for fit, par, var in zip(fits, parameters, variances): - assert par[0] == 0 - assert par[1] == fit['average']['slope'] + assert par[Parameter.intercept] == 0 + assert par[Parameter.slope] == fit['average']['slope'] - assert var[0] == fit['average']['read_var'] - assert var[1] == fit['average']['poisson_var'] - assert var[2] == np.float32(fit['average']['read_var'] + fit['average']['poisson_var']) + assert var[Variance.read_var] == fit['average']['read_var'] + assert var[Variance.poisson_var] == fit['average']['poisson_var'] + assert var[Variance.total_var] == np.float32(fit['average']['read_var'] + fit['average']['poisson_var']) @pytest.mark.parametrize("use_jump", [True, False]) @@ -381,7 +381,7 @@ def test_find_jumps(jump_data): # There is no way to detect a jump if it is in the very first read # The very first pixel in this case has a jump in the first read assert len(fit['jumps']) == 0 - assert jump[0] + assert jump[0] # sanity check that the jump is in the first resultant still assert not np.all(jump[1:]) # Test that the correct index was recorded