Skip to content

Commit

Permalink
Clean up ramp_fit outputs so they are named tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamJamieson committed Oct 6, 2023
1 parent c55eb6c commit dcded7a
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 32 deletions.
4 changes: 2 additions & 2 deletions src/stcal/ramp_fitting/ols_cas22/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ._fit_ramps import fit_ramps
from ._fit_ramps import fit_ramps, RampFitOutputs
from ._core import Parameter, Variance, Diff, RampJumpDQ

__all__ = ['fit_ramps', 'Parameter', 'Variance', 'Diff', 'RampJumpDQ']
__all__ = ['fit_ramps', 'RampFitOutputs', 'Parameter', 'Variance', 'Diff', 'RampJumpDQ']
21 changes: 20 additions & 1 deletion src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,31 @@ from stcal.ramp_fitting.ols_cas22._core cimport (RampFits, RampIndex, Thresh,
from stcal.ramp_fitting.ols_cas22._fixed cimport fixed_values_from_metadata, FixedValues
from stcal.ramp_fitting.ols_cas22._pixel cimport make_pixel

from typing import NamedTuple


# Fix the default Threshold values at compile time these values cannot be overridden
# dynamically at runtime.
DEF DefaultIntercept = 5.5
DEF DefaultConstant = 1/3.0

class RampFitOutputs(NamedTuple):
"""
Simple tuple wrapper for outputs from the ramp fitting algorithm
This clarifies the meaning of the outputs via naming them something
descriptive.
"""

fits: list
parameters: np.ndarray
variances: np.ndarray
dq: np.ndarray
# def __init__(self, fits, parameters, variances, dq):
# self.fits = fits
# self.parameters = parameters
# self.variances = variances
# self.dq = dq


@cython.boundscheck(False)
@cython.wraparound(False)
Expand Down Expand Up @@ -107,4 +126,4 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,

ramp_fits.push_back(fit)

return ramp_fits, parameters, variances, fit_dq
return RampFitOutputs(ramp_fits, parameters, variances, fit_dq)
12 changes: 7 additions & 5 deletions src/stcal/ramp_fitting/ols_cas22_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def fit_ramps_casertano(
dq = dq.reshape(orig_shape + (1,))
read_noise = read_noise.reshape(orig_shape[1:] + (1,))

Check warning on line 112 in src/stcal/ramp_fitting/ols_cas22_fit.py

View check run for this annotation

Codecov / codecov/patch

src/stcal/ramp_fitting/ols_cas22_fit.py#L110-L112

Added lines #L110 - L112 were not covered by tests

_, parameters, variances, _ = ols_cas22.fit_ramps(
output = ols_cas22.fit_ramps(
resultants.reshape(resultants.shape[0], -1),
dq.reshape(resultants.shape[0], -1),
read_noise.reshape(-1),
Expand All @@ -120,11 +120,13 @@ def fit_ramps_casertano(
use_jump,
**kwargs)

parameters = output.parameters
variances = output.variances
if resultants.shape != orig_shape:
parameters = parameters[0]
variances = variances[0]
parameters = output.parameters[0]
variances = output.variances[0]

Check warning on line 127 in src/stcal/ramp_fitting/ols_cas22_fit.py

View check run for this annotation

Codecov / codecov/patch

src/stcal/ramp_fitting/ols_cas22_fit.py#L126-L127

Added lines #L126 - L127 were not covered by tests

if resultants_unit is not None:
parameters = parameters * resultants_unit
parameters = output.parameters * resultants_unit

return parameters, variances
return ols_cas22.RampFitOutputs(output.fits, parameters, variances, output.dq)
31 changes: 14 additions & 17 deletions tests/test_jump_cas22.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,11 @@ def test_fit_ramps(detector_data, use_jump, use_dq):
if not use_dq:
assert okay.all()

fits, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)
assert len(fits) == N_PIXELS # sanity check that a fit is output for each pixel
output = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)
assert len(output.fits) == N_PIXELS # sanity check that a fit is output for each pixel

chi2 = 0
for fit, use in zip(fits, okay):
for fit, use in zip(output.fits, okay):
if not use_dq:
assert len(fit['fits']) == 1 # only one fit per pixel since no dq/jump in this case

Expand Down Expand Up @@ -381,11 +381,9 @@ def test_fit_ramps_array_outputs(detector_data, use_jump):
resultants, read_noise, read_pattern = detector_data
dq = np.zeros(resultants.shape, dtype=np.int32)

fits, parameters, variances, _ = fit_ramps(
resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump
)
output = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)

for fit, par, var in zip(fits, parameters, variances):
for fit, par, var in zip(output.fits, output.parameters, output.variances):
assert par[Parameter.intercept] == 0
assert par[Parameter.slope] == fit['average']['slope']

Expand Down Expand Up @@ -455,11 +453,11 @@ def test_find_jumps(jump_data):
resultants, read_noise, read_pattern, jump_reads, jump_resultants = jump_data
dq = np.zeros(resultants.shape, dtype=np.int32)

fits, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True)
assert len(fits) == len(jump_reads) # sanity check that a fit/jump is set for every pixel
output = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True)
assert len(output.fits) == len(jump_reads) # sanity check that a fit/jump is set for every pixel

chi2 = 0
for fit, jump_index, resultant_index in zip(fits, jump_reads, jump_resultants):
for fit, jump_index, resultant_index in zip(output.fits, jump_reads, jump_resultants):

# Check that the jumps are detected correctly
if jump_index == 0:
Expand Down Expand Up @@ -514,13 +512,13 @@ def test_override_default_threshold(jump_data):
resultants, read_noise, read_pattern, jump_reads, jump_resultants = jump_data
dq = np.zeros(resultants.shape, dtype=np.int32)

_, standard, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True)
_, override, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True,
intercept=0, constant=0)
standard = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True)
override = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True,
intercept=0, constant=0)

# All this is intended to do is show that with all other things being equal passing non-default
# threshold parameters changes the results.
assert (standard != override).any()
assert (standard.parameters != override.parameters).any()


def test_jump_dq_set(jump_data):
Expand All @@ -530,12 +528,11 @@ def test_jump_dq_set(jump_data):
resultants, read_noise, read_pattern, jump_reads, jump_resultants = jump_data
dq = np.zeros(resultants.shape, dtype=np.int32)

fits, *_, fit_dq = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True)
output = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True)

for fit, pixel_dq in zip(fits, fit_dq.transpose()):
for fit, pixel_dq in zip(output.fits, output.dq.transpose()):
# Check that all jumps found get marked
assert (pixel_dq[fit['jumps']] == RampJumpDQ.JUMP_DET).all()

# Check that dq flags for jumps are only set if the jump is marked
assert set(np.where(pixel_dq == RampJumpDQ.JUMP_DET)[0]) == set(fit['jumps'])

33 changes: 26 additions & 7 deletions tests/test_ramp_fitting_cas22.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
"""
Unit tests for ramp-fitting functions.
"""
import astropy.units as u
import numpy as np
import pytest

from stcal.ramp_fitting import ols_cas22_fit as ramp

Expand All @@ -13,23 +15,33 @@
ROMAN_READ_TIME = 3.04


def test_simulated_ramps():
@pytest.mark.parametrize("use_unit", [True, False])
def test_simulated_ramps(use_unit):
ntrial = 100000
read_pattern, flux, read_noise, resultants = simulate_many_ramps(ntrial=ntrial)

if use_unit:
resultants = resultants * u.electron

dq = np.zeros(resultants.shape, dtype=np.int32)
read_noise = np.ones(resultants.shape[1], dtype=np.float32) * read_noise

par, var = ramp.fit_ramps_casertano(
output = ramp.fit_ramps_casertano(
resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern)

chi2dof_slope = np.sum((par[:, 1] - flux)**2 / var[:, 2]) / ntrial
if use_unit:
assert output.parameters.unit == u.electron
parameters = output.parameters.value
else:
parameters = output.parameters

chi2dof_slope = np.sum((parameters[:, 1] - flux)**2 / output.variances[:, 2]) / ntrial
assert np.abs(chi2dof_slope - 1) < 0.03

# now let's mark a bunch of the ramps as compromised.
bad = np.random.uniform(size=resultants.shape) > 0.7
dq |= bad
par, var = ramp.fit_ramps_casertano(
output = ramp.fit_ramps_casertano(
resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern,
threshold_constant=0, threshold_intercept=0) # set the threshold parameters
# to demo the interface. This
Expand All @@ -42,10 +54,17 @@ def test_simulated_ramps():
# ramps passing the below criterion have at least two adjacent valid reads
# i.e., we can make a measurement from them.
m = np.sum((dq[1:, :] == 0) & (dq[:-1, :] == 0), axis=0) != 0
chi2dof_slope = np.sum((par[m, 1] - flux)**2 / var[m, 2]) / np.sum(m)

if use_unit:
assert output.parameters.unit == u.electron
parameters = output.parameters.value
else:
parameters = output.parameters

chi2dof_slope = np.sum((parameters[m, 1] - flux)**2 / output.variances[m, 2]) / np.sum(m)
assert np.abs(chi2dof_slope - 1) < 0.03
assert np.all(par[~m, 1] == 0)
assert np.all(var[~m, 1] == 0)
assert np.all(parameters[~m, 1] == 0)
assert np.all(output.variances[~m, 1] == 0)


# #########
Expand Down

0 comments on commit dcded7a

Please sign in to comment.