diff --git a/src/stcal/ramp_fitting/ols_cas22/__init__.py b/src/stcal/ramp_fitting/ols_cas22/__init__.py index 8b2a69abb..d07f4ddff 100644 --- a/src/stcal/ramp_fitting/ols_cas22/__init__.py +++ b/src/stcal/ramp_fitting/ols_cas22/__init__.py @@ -1,4 +1,4 @@ -from ._fit_ramps import fit_ramps +from ._fit_ramps import fit_ramps, RampFitOutputs from ._core import Parameter, Variance, Diff, RampJumpDQ -__all__ = ['fit_ramps', 'Parameter', 'Variance', 'Diff', 'RampJumpDQ'] +__all__ = ['fit_ramps', 'RampFitOutputs', 'Parameter', 'Variance', 'Diff', 'RampJumpDQ'] diff --git a/src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx b/src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx index 65f40d904..8f0d132eb 100644 --- a/src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx +++ b/src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx @@ -12,12 +12,31 @@ from stcal.ramp_fitting.ols_cas22._core cimport (RampFits, RampIndex, Thresh, from stcal.ramp_fitting.ols_cas22._fixed cimport fixed_values_from_metadata, FixedValues from stcal.ramp_fitting.ols_cas22._pixel cimport make_pixel +from typing import NamedTuple + # Fix the default Threshold values at compile time these values cannot be overridden # dynamically at runtime. DEF DefaultIntercept = 5.5 DEF DefaultConstant = 1/3.0 +class RampFitOutputs(NamedTuple): + """ + Simple tuple wrapper for outputs from the ramp fitting algorithm + This clarifies the meaning of the outputs via naming them something + descriptive. + """ + + fits: list + parameters: np.ndarray + variances: np.ndarray + dq: np.ndarray + # def __init__(self, fits, parameters, variances, dq): + # self.fits = fits + # self.parameters = parameters + # self.variances = variances + # self.dq = dq + @cython.boundscheck(False) @cython.wraparound(False) @@ -107,4 +126,4 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants, ramp_fits.push_back(fit) - return ramp_fits, parameters, variances, fit_dq + return RampFitOutputs(ramp_fits, parameters, variances, fit_dq) diff --git a/src/stcal/ramp_fitting/ols_cas22_fit.py b/src/stcal/ramp_fitting/ols_cas22_fit.py index 94d382425..945797e87 100644 --- a/src/stcal/ramp_fitting/ols_cas22_fit.py +++ b/src/stcal/ramp_fitting/ols_cas22_fit.py @@ -111,7 +111,7 @@ def fit_ramps_casertano( dq = dq.reshape(orig_shape + (1,)) read_noise = read_noise.reshape(orig_shape[1:] + (1,)) - _, parameters, variances, _ = ols_cas22.fit_ramps( + output = ols_cas22.fit_ramps( resultants.reshape(resultants.shape[0], -1), dq.reshape(resultants.shape[0], -1), read_noise.reshape(-1), @@ -120,11 +120,13 @@ def fit_ramps_casertano( use_jump, **kwargs) + parameters = output.parameters + variances = output.variances if resultants.shape != orig_shape: - parameters = parameters[0] - variances = variances[0] + parameters = output.parameters[0] + variances = output.variances[0] if resultants_unit is not None: - parameters = parameters * resultants_unit + parameters = output.parameters * resultants_unit - return parameters, variances + return ols_cas22.RampFitOutputs(output.fits, parameters, variances, output.dq) diff --git a/tests/test_jump_cas22.py b/tests/test_jump_cas22.py index f372179af..70c699dac 100644 --- a/tests/test_jump_cas22.py +++ b/tests/test_jump_cas22.py @@ -349,11 +349,11 @@ def test_fit_ramps(detector_data, use_jump, use_dq): if not use_dq: assert okay.all() - fits, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump) - assert len(fits) == N_PIXELS # sanity check that a fit is output for each pixel + output = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump) + assert len(output.fits) == N_PIXELS # sanity check that a fit is output for each pixel chi2 = 0 - for fit, use in zip(fits, okay): + for fit, use in zip(output.fits, okay): if not use_dq: assert len(fit['fits']) == 1 # only one fit per pixel since no dq/jump in this case @@ -381,11 +381,9 @@ def test_fit_ramps_array_outputs(detector_data, use_jump): resultants, read_noise, read_pattern = detector_data dq = np.zeros(resultants.shape, dtype=np.int32) - fits, parameters, variances, _ = fit_ramps( - resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump - ) + output = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump) - for fit, par, var in zip(fits, parameters, variances): + for fit, par, var in zip(output.fits, output.parameters, output.variances): assert par[Parameter.intercept] == 0 assert par[Parameter.slope] == fit['average']['slope'] @@ -455,11 +453,11 @@ def test_find_jumps(jump_data): resultants, read_noise, read_pattern, jump_reads, jump_resultants = jump_data dq = np.zeros(resultants.shape, dtype=np.int32) - fits, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True) - assert len(fits) == len(jump_reads) # sanity check that a fit/jump is set for every pixel + output = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True) + assert len(output.fits) == len(jump_reads) # sanity check that a fit/jump is set for every pixel chi2 = 0 - for fit, jump_index, resultant_index in zip(fits, jump_reads, jump_resultants): + for fit, jump_index, resultant_index in zip(output.fits, jump_reads, jump_resultants): # Check that the jumps are detected correctly if jump_index == 0: @@ -514,13 +512,13 @@ def test_override_default_threshold(jump_data): resultants, read_noise, read_pattern, jump_reads, jump_resultants = jump_data dq = np.zeros(resultants.shape, dtype=np.int32) - _, standard, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True) - _, override, *_ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True, - intercept=0, constant=0) + standard = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True) + override = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True, + intercept=0, constant=0) # All this is intended to do is show that with all other things being equal passing non-default # threshold parameters changes the results. - assert (standard != override).any() + assert (standard.parameters != override.parameters).any() def test_jump_dq_set(jump_data): @@ -530,12 +528,11 @@ def test_jump_dq_set(jump_data): resultants, read_noise, read_pattern, jump_reads, jump_resultants = jump_data dq = np.zeros(resultants.shape, dtype=np.int32) - fits, *_, fit_dq = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True) + output = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True) - for fit, pixel_dq in zip(fits, fit_dq.transpose()): + for fit, pixel_dq in zip(output.fits, output.dq.transpose()): # Check that all jumps found get marked assert (pixel_dq[fit['jumps']] == RampJumpDQ.JUMP_DET).all() # Check that dq flags for jumps are only set if the jump is marked assert set(np.where(pixel_dq == RampJumpDQ.JUMP_DET)[0]) == set(fit['jumps']) - diff --git a/tests/test_ramp_fitting_cas22.py b/tests/test_ramp_fitting_cas22.py index 0ed481bd6..60c6ff0c3 100644 --- a/tests/test_ramp_fitting_cas22.py +++ b/tests/test_ramp_fitting_cas22.py @@ -2,7 +2,9 @@ """ Unit tests for ramp-fitting functions. """ +import astropy.units as u import numpy as np +import pytest from stcal.ramp_fitting import ols_cas22_fit as ramp @@ -13,23 +15,33 @@ ROMAN_READ_TIME = 3.04 -def test_simulated_ramps(): +@pytest.mark.parametrize("use_unit", [True, False]) +def test_simulated_ramps(use_unit): ntrial = 100000 read_pattern, flux, read_noise, resultants = simulate_many_ramps(ntrial=ntrial) + if use_unit: + resultants = resultants * u.electron + dq = np.zeros(resultants.shape, dtype=np.int32) read_noise = np.ones(resultants.shape[1], dtype=np.float32) * read_noise - par, var = ramp.fit_ramps_casertano( + output = ramp.fit_ramps_casertano( resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern) - chi2dof_slope = np.sum((par[:, 1] - flux)**2 / var[:, 2]) / ntrial + if use_unit: + assert output.parameters.unit == u.electron + parameters = output.parameters.value + else: + parameters = output.parameters + + chi2dof_slope = np.sum((parameters[:, 1] - flux)**2 / output.variances[:, 2]) / ntrial assert np.abs(chi2dof_slope - 1) < 0.03 # now let's mark a bunch of the ramps as compromised. bad = np.random.uniform(size=resultants.shape) > 0.7 dq |= bad - par, var = ramp.fit_ramps_casertano( + output = ramp.fit_ramps_casertano( resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, threshold_constant=0, threshold_intercept=0) # set the threshold parameters # to demo the interface. This @@ -42,10 +54,17 @@ def test_simulated_ramps(): # ramps passing the below criterion have at least two adjacent valid reads # i.e., we can make a measurement from them. m = np.sum((dq[1:, :] == 0) & (dq[:-1, :] == 0), axis=0) != 0 - chi2dof_slope = np.sum((par[m, 1] - flux)**2 / var[m, 2]) / np.sum(m) + + if use_unit: + assert output.parameters.unit == u.electron + parameters = output.parameters.value + else: + parameters = output.parameters + + chi2dof_slope = np.sum((parameters[m, 1] - flux)**2 / output.variances[m, 2]) / np.sum(m) assert np.abs(chi2dof_slope - 1) < 0.03 - assert np.all(par[~m, 1] == 0) - assert np.all(var[~m, 1] == 0) + assert np.all(parameters[~m, 1] == 0) + assert np.all(output.variances[~m, 1] == 0) # #########