Skip to content

Commit

Permalink
Move array building to cython
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamJamieson committed Oct 4, 2023
1 parent e79bb89 commit 2e28281
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 19 deletions.
18 changes: 15 additions & 3 deletions src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,23 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,
# list in the end.
cdef cpp_list[RampFits] ramp_fits

cdef np.ndarray[float, ndim=2] parameters = np.zeros((n_pixels, 2), dtype=np.float32)
cdef np.ndarray[float, ndim=2] variances = np.zeros((n_pixels, 3), dtype=np.float32)

# Perform all of the fits
cdef RampFits fit
cdef int index
for index in range(n_pixels):
# Fit all the ramps for the given pixel
ramp_fits.push_back(make_pixel(fixed, read_noise[index],
resultants[:, index]).fit_ramps(pixel_ramps[index]))
fit = make_pixel(fixed, read_noise[index],
resultants[:, index]).fit_ramps(pixel_ramps[index])

parameters[index, 1] = fit.average.slope

variances[index, 0] = fit.average.read_var
variances[index, 1] = fit.average.poisson_var
variances[index, 2] = fit.average.read_var + fit.average.poisson_var

ramp_fits.push_back(fit)

return ramp_fits
return ramp_fits, parameters, variances
14 changes: 1 addition & 13 deletions src/stcal/ramp_fitting/ols_cas22_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,26 +98,14 @@ def fit_ramps_casertano(resultants, dq, read_noise, read_time, ma_table=None, re
dq = dq.reshape(origshape + (1,))
read_noise = read_noise.reshape(origshape[1:] + (1,))

ramp_fits = ols_cas22.fit_ramps(
ramp_fits, parameters, variances = ols_cas22.fit_ramps(
resultants.reshape(resultants.shape[0], -1),
dq.reshape(resultants.shape[0], -1),
read_noise.reshape(-1),
read_time,
read_pattern,
use_jump)

parameters = np.zeros((len(ramp_fits), 2), dtype=np.float32)
variances = np.zeros((len(ramp_fits), 3), dtype=np.float32)

# Extract the data request from the ramp fits
for index, ramp_fit in enumerate(ramp_fits):
parameters[index, 1] = ramp_fit['average']['slope']

variances[index, 0] = ramp_fit['average']['read_var']
variances[index, 1] = ramp_fit['average']['poisson_var']

variances[:, 2] = (variances[:, 0] + variances[:, 1]).astype(np.float32)

if resultants.shape != orig_shape:
parameters = parameters[0]
variances = variances[0]

Check warning on line 111 in src/stcal/ramp_fitting/ols_cas22_fit.py

View check run for this annotation

Codecov / codecov/patch

src/stcal/ramp_fitting/ols_cas22_fit.py#L110-L111

Added lines #L110 - L111 were not covered by tests
Expand Down
24 changes: 21 additions & 3 deletions tests/test_jump_cas22.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,24 @@ def detector_data(ramp_data):

return resultants, read_noise, read_pattern, N_PIXELS, FLUX

@pytest.mark.parametrize("use_jump", [True, False])
def test_fit_ramps_array_outputs(detector_data, use_jump):
"""
Test that the array outputs line up with the dictionary output
"""
resultants, read_noise, read_pattern, n_pixels, flux = detector_data
dq = np.zeros(resultants.shape, dtype=np.int32)

fits, parameters, variances = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)

for fit, par, var in zip(fits, parameters, variances):
assert par[0] == 0
assert par[1] == fit['average']['slope']

assert var[0] == fit['average']['read_var']
assert var[1] == fit['average']['poisson_var']
assert var[2] == np.float32(fit['average']['read_var'] + fit['average']['poisson_var'])


@pytest.mark.parametrize("use_jump", [True, False])
def test_fit_ramps_no_dq(detector_data, use_jump):
Expand All @@ -260,7 +278,7 @@ def test_fit_ramps_no_dq(detector_data, use_jump):
resultants, read_noise, read_pattern, n_pixels, flux = detector_data
dq = np.zeros(resultants.shape, dtype=np.int32)

fits = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)
fits, _, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)
assert len(fits) == n_pixels # sanity check that a fit is output for each pixel

# Check that the chi2 for the resulting fit relative to the assumed flux is ~1
Expand Down Expand Up @@ -291,7 +309,7 @@ def test_fit_ramps_dq(detector_data, use_jump):
# i.e., we can make a measurement from them.
okay = np.sum((dq[1:, :] == 0) & (dq[:-1, :] == 0), axis=0) != 0

fits = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)
fits, _, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)
assert len(fits) == n_pixels # sanity check that a fit is output for each pixel

chi2 = 0
Expand Down Expand Up @@ -353,7 +371,7 @@ def test_find_jumps(jump_data):
resultants, read_noise, read_pattern, n_pixels, jumps = jump_data
dq = np.zeros(resultants.shape, dtype=np.int32)

fits = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True)
fits, _, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True)

# Check that all the jumps have been located per the algorithm's constraints
for index, (fit, jump) in enumerate(zip(fits, jumps)):
Expand Down

0 comments on commit 2e28281

Please sign in to comment.