Skip to content

Commit

Permalink
Minor cleanups with enum values
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamJamieson committed Oct 4, 2023
1 parent 8580147 commit a7dc16b
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 18 deletions.
3 changes: 2 additions & 1 deletion src/stcal/ramp_fitting/ols_cas22/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._fit_ramps import fit_ramps
from ._core import Parameter, Variance, Diff

__all__ = ['fit_ramps']
__all__ = ['fit_ramps', 'Parameter', 'Variance', 'Diff']
13 changes: 12 additions & 1 deletion src/stcal/ramp_fitting/ols_cas22/_core.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,22 @@ cdef struct Thresh:
float constant


cdef enum Diff:
cpdef enum Diff:
single = 0
double = 1


cpdef enum Parameter:
intercept = 0
slope = 1


cpdef enum Variance:
read_var = 0
poisson_var = 1
total_var = 2


cdef float threshold(Thresh thresh, float slope)
cdef float get_power(float s)
cdef deque[stack[RampIndex]] init_ramps(int[:, :] dq)
Expand Down
10 changes: 5 additions & 5 deletions src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ from libcpp.deque cimport deque
cimport cython

from stcal.ramp_fitting.ols_cas22._core cimport (
RampFits, RampIndex, Thresh, read_data, init_ramps)
RampFits, RampIndex, Thresh, read_data, init_ramps, Parameter, Variance)
from stcal.ramp_fitting.ols_cas22._fixed cimport make_fixed, Fixed
from stcal.ramp_fitting.ols_cas22._pixel cimport make_pixel

Expand Down Expand Up @@ -80,11 +80,11 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,
fit = make_pixel(fixed, read_noise[index],
resultants[:, index]).fit_ramps(pixel_ramps[index])

parameters[index, 1] = fit.average.slope
parameters[index, Parameter.slope] = fit.average.slope

variances[index, 0] = fit.average.read_var
variances[index, 1] = fit.average.poisson_var
variances[index, 2] = fit.average.read_var + fit.average.poisson_var
variances[index, Variance.read_var] = fit.average.read_var
variances[index, Variance.poisson_var] = fit.average.poisson_var
variances[index, Variance.total_var] = fit.average.read_var + fit.average.poisson_var

ramp_fits.push_back(fit)

Expand Down
22 changes: 11 additions & 11 deletions tests/test_jump_cas22.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from stcal.ramp_fitting.ols_cas22._wrappers import init_ramps
from stcal.ramp_fitting.ols_cas22._wrappers import run_threshold, make_fixed, make_pixel, fit_ramp

from stcal.ramp_fitting.ols_cas22 import fit_ramps
from stcal.ramp_fitting.ols_cas22 import fit_ramps, Parameter, Variance, Diff


RNG = np.random.default_rng(619)
Expand Down Expand Up @@ -128,8 +128,8 @@ def test_make_fixed(ramp_data, use_jump):
# These are computed via vectorized operations in the main code, here we
# check using item-by-item operations
if use_jump:
single_gen = zip(fixed['t_bar_diff'][0], fixed['recip'][0], fixed['slope_var'][0])
double_gen = zip(fixed['t_bar_diff'][1], fixed['recip'][1], fixed['slope_var'][1])
single_gen = zip(fixed['t_bar_diff'][Diff.single], fixed['recip'][Diff.single], fixed['slope_var'][Diff.single])
double_gen = zip(fixed['t_bar_diff'][Diff.double], fixed['recip'][Diff.double], fixed['slope_var'][Diff.double])

for index, (t_bar_1, recip_1, slope_var_1) in enumerate(single_gen):
assert t_bar_1 == t_bar[index + 1] - t_bar[index]
Expand Down Expand Up @@ -210,8 +210,8 @@ def test_make_pixel(pixel_data, use_jump):
# These are computed via vectorized operations in the main code, here we
# check using item-by-item operations
if use_jump:
single_gen = zip(pixel['delta'][0], pixel['sigma'][0])
double_gen = zip(pixel['delta'][1], pixel['sigma'][1])
single_gen = zip(pixel['delta'][Diff.single], pixel['sigma'][Diff.single])
double_gen = zip(pixel['delta'][Diff.double], pixel['sigma'][Diff.double])

for index, (delta_1, sigma_1) in enumerate(single_gen):
assert delta_1 == (resultants[index + 1] - resultants[index]) / (t_bar[index + 1] - t_bar[index])
Expand Down Expand Up @@ -260,12 +260,12 @@ def test_fit_ramps_array_outputs(detector_data, use_jump):
fits, parameters, variances = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)

for fit, par, var in zip(fits, parameters, variances):
assert par[0] == 0
assert par[1] == fit['average']['slope']
assert par[Parameter.intercept] == 0
assert par[Parameter.slope] == fit['average']['slope']

assert var[0] == fit['average']['read_var']
assert var[1] == fit['average']['poisson_var']
assert var[2] == np.float32(fit['average']['read_var'] + fit['average']['poisson_var'])
assert var[Variance.read_var] == fit['average']['read_var']
assert var[Variance.poisson_var] == fit['average']['poisson_var']
assert var[Variance.total_var] == np.float32(fit['average']['read_var'] + fit['average']['poisson_var'])


@pytest.mark.parametrize("use_jump", [True, False])
Expand Down Expand Up @@ -381,7 +381,7 @@ def test_find_jumps(jump_data):
# There is no way to detect a jump if it is in the very first read
# The very first pixel in this case has a jump in the first read
assert len(fit['jumps']) == 0
assert jump[0]
assert jump[0] # sanity check that the jump is in the first resultant still
assert not np.all(jump[1:])

# Test that the correct index was recorded
Expand Down

0 comments on commit a7dc16b

Please sign in to comment.