From a8a6ab68dde9dcef3ec6087805021f973421ddef Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Wed, 4 Oct 2023 14:52:57 -0400 Subject: [PATCH] Clean up style issues --- .../ramp_fitting/ols_cas22/_wrappers.pyx | 20 ----------- src/stcal/ramp_fitting/ols_cas22_fit.py | 26 ++++++++++---- tests/test_jump_cas22.py | 35 +++++++++++++------ 3 files changed, 44 insertions(+), 37 deletions(-) diff --git a/src/stcal/ramp_fitting/ols_cas22/_wrappers.pyx b/src/stcal/ramp_fitting/ols_cas22/_wrappers.pyx index 7bdc6f93b..975010d6e 100644 --- a/src/stcal/ramp_fitting/ols_cas22/_wrappers.pyx +++ b/src/stcal/ramp_fitting/ols_cas22/_wrappers.pyx @@ -143,23 +143,3 @@ def make_pixel(np.ndarray[float, ndim=1] resultants, read_noise=pixel.read_noise, delta=delta, sigma=sigma) - - -def fit_ramp(np.ndarray[float, ndim=1] resultants, - np.ndarray[float, ndim=1] t_bar, - np.ndarray[float, ndim=1] tau, - np.ndarray[int, ndim=1] n_reads, - float read_noise, - int start, - int end): - - cdef DerivedData data = DerivedData(t_bar, tau, n_reads) - cdef Thresh threshold = Thresh(0, 1) - cdef Fixed fixed = c_make_fixed(data, threshold, False) - - cdef Pixel pixel = c_make_pixel(fixed, read_noise, resultants) - cdef RampIndex ramp_index = RampIndex(start, end) - - cdef RampFit ramp_fit = pixel.fit_ramp(ramp_index) - - return ramp_fit diff --git a/src/stcal/ramp_fitting/ols_cas22_fit.py b/src/stcal/ramp_fitting/ols_cas22_fit.py index 974c5d4e7..db6c2d5d3 100644 --- a/src/stcal/ramp_fitting/ols_cas22_fit.py +++ b/src/stcal/ramp_fitting/ols_cas22_fit.py @@ -33,10 +33,22 @@ import numpy as np from . import ols_cas22 -from .ols_cas22_util import ma_table_to_tau, ma_table_to_tbar, read_pattern_to_ma_table, ma_table_to_read_pattern - - -def fit_ramps_casertano(resultants, dq, read_noise, read_time, ma_table=None, read_pattern=None, use_jump=False): +from .ols_cas22_util import ( + ma_table_to_tau, + ma_table_to_tbar, + ma_table_to_read_pattern +) + + +def fit_ramps_casertano( + resultants, + dq, + read_noise, + read_time, + ma_table=None, + read_pattern=None, + use_jump=False +): """Fit ramps following Casertano+2022, including averaging partial ramps. Ramps are broken where dq != 0, and fits are performed on each sub-ramp. @@ -94,9 +106,9 @@ def fit_ramps_casertano(resultants, dq, read_noise, read_time, ma_table=None, re orig_shape = resultants.shape if len(resultants.shape) == 1: # single ramp. - resultants = resultants.reshape(origshape + (1,)) - dq = dq.reshape(origshape + (1,)) - read_noise = read_noise.reshape(origshape[1:] + (1,)) + resultants = resultants.reshape(orig_shape + (1,)) + dq = dq.reshape(orig_shape + (1,)) + read_noise = read_noise.reshape(orig_shape[1:] + (1,)) ramp_fits, parameters, variances = ols_cas22.fit_ramps( resultants.reshape(resultants.shape[0], -1), diff --git a/tests/test_jump_cas22.py b/tests/test_jump_cas22.py index bdb3570aa..3101fb928 100644 --- a/tests/test_jump_cas22.py +++ b/tests/test_jump_cas22.py @@ -4,7 +4,7 @@ from stcal.ramp_fitting.ols_cas22._wrappers import read_data from stcal.ramp_fitting.ols_cas22._wrappers import init_ramps -from stcal.ramp_fitting.ols_cas22._wrappers import run_threshold, make_fixed, make_pixel, fit_ramp +from stcal.ramp_fitting.ols_cas22._wrappers import run_threshold, make_fixed, make_pixel from stcal.ramp_fitting.ols_cas22 import fit_ramps, Parameter, Variance, Diff @@ -128,8 +128,16 @@ def test_make_fixed(ramp_data, use_jump): # These are computed via vectorized operations in the main code, here we # check using item-by-item operations if use_jump: - single_gen = zip(fixed['t_bar_diff'][Diff.single], fixed['recip'][Diff.single], fixed['slope_var'][Diff.single]) - double_gen = zip(fixed['t_bar_diff'][Diff.double], fixed['recip'][Diff.double], fixed['slope_var'][Diff.double]) + single_gen = zip( + fixed['t_bar_diff'][Diff.single], + fixed['recip'][Diff.single], + fixed['slope_var'][Diff.single] + ) + double_gen = zip( + fixed['t_bar_diff'][Diff.double], + fixed['recip'][Diff.double], + fixed['slope_var'][Diff.double] + ) for index, (t_bar_1, recip_1, slope_var_1) in enumerate(single_gen): assert t_bar_1 == t_bar[index + 1] - t_bar[index] @@ -167,7 +175,9 @@ def _generate_resultants(read_pattern, flux, read_noise, n_pixels=1): # - Poisson process for the flux # - Gaussian process for the read noise ramp_value += RNG.poisson(flux * ROMAN_READ_TIME, size=n_pixels).astype(np.float32) - ramp_value += RNG.standard_normal(size=n_pixels, dtype=np.float32) * read_noise / np.sqrt(len(reads)) + ramp_value += ( + RNG.standard_normal(size=n_pixels, dtype=np.float32)* read_noise / np.sqrt(len(reads)) + ) # Add to running total for the resultant resultant_total += ramp_value @@ -225,7 +235,9 @@ def test_make_pixel(pixel_data, use_jump): assert np.isnan(delta_2) assert np.isnan(sigma_2) else: - assert delta_2 == (resultants[index + 2] - resultants[index]) / (t_bar[index + 2] - t_bar[index]) + assert delta_2 == ( + (resultants[index + 2] - resultants[index]) / (t_bar[index + 2] - t_bar[index]) + ) assert sigma_2 == read_noise * ( np.float32(1 / n_reads[index + 2]) + np.float32(1 / n_reads[index]) ) @@ -257,7 +269,9 @@ def test_fit_ramps_array_outputs(detector_data, use_jump): resultants, read_noise, read_pattern, n_pixels, flux = detector_data dq = np.zeros(resultants.shape, dtype=np.int32) - fits, parameters, variances = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump) + fits, parameters, variances = fit_ramps( + resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump + ) for fit, par, var in zip(fits, parameters, variances): assert par[Parameter.intercept] == 0 @@ -265,7 +279,9 @@ def test_fit_ramps_array_outputs(detector_data, use_jump): assert var[Variance.read_var] == fit['average']['read_var'] assert var[Variance.poisson_var] == fit['average']['poisson_var'] - assert var[Variance.total_var] == np.float32(fit['average']['read_var'] + fit['average']['poisson_var']) + assert var[Variance.total_var] == np.float32( + fit['average']['read_var'] + fit['average']['poisson_var'] + ) @pytest.mark.parametrize("use_jump", [True, False]) @@ -302,7 +318,7 @@ def test_fit_ramps_dq(detector_data, use_jump): up any jumps. """ resultants, read_noise, read_pattern, n_pixels, flux = detector_data - dq = np.zeros(resultants.shape, dtype=np.int32) + (RNG.uniform(size=resultants.shape) > 1).astype(np.int32) + dq = (RNG.uniform(size=resultants.shape) > 1).astype(np.int32) # only use okay ramps # ramps passing the below criterion have at least two adjacent valid reads @@ -352,7 +368,7 @@ def jump_data(): # Start indicating a new resultant jump_res += 1 jumps[jump_res, jump_index] = True - + resultants[:, jump_index] = np.mean(read_values.reshape(shape), axis=1).astype(np.float32) n_pixels = np.prod(shape) @@ -424,6 +440,5 @@ def test_find_jumps(jump_data): # "fairly close" to the expected value. This is purposely a loose check # because the main purpose of this test is to verify that the jumps are # being detected correctly, above. - chi2 = 0 for fit in fits: assert_allclose(fit['average']['slope'], FLUX, rtol=3)