Clean up style issues
WilliamJamieson committed Oct 4, 2023
1 parent 1f72ae2 commit a8a6ab6
Showing 3 changed files with 44 additions and 37 deletions.
20 changes: 0 additions & 20 deletions src/stcal/ramp_fitting/ols_cas22/_wrappers.pyx
@@ -143,23 +143,3 @@ def make_pixel(np.ndarray[float, ndim=1] resultants,
read_noise=pixel.read_noise,
delta=delta,
sigma=sigma)


def fit_ramp(np.ndarray[float, ndim=1] resultants,
np.ndarray[float, ndim=1] t_bar,
np.ndarray[float, ndim=1] tau,
np.ndarray[int, ndim=1] n_reads,
float read_noise,
int start,
int end):

cdef DerivedData data = DerivedData(t_bar, tau, n_reads)
cdef Thresh threshold = Thresh(0, 1)
cdef Fixed fixed = c_make_fixed(data, threshold, False)

cdef Pixel pixel = c_make_pixel(fixed, read_noise, resultants)
cdef RampIndex ramp_index = RampIndex(start, end)

cdef RampFit ramp_fit = pixel.fit_ramp(ramp_index)

return ramp_fit
26 changes: 19 additions & 7 deletions src/stcal/ramp_fitting/ols_cas22_fit.py
@@ -33,10 +33,22 @@
import numpy as np

from . import ols_cas22
from .ols_cas22_util import ma_table_to_tau, ma_table_to_tbar, read_pattern_to_ma_table, ma_table_to_read_pattern


def fit_ramps_casertano(resultants, dq, read_noise, read_time, ma_table=None, read_pattern=None, use_jump=False):
from .ols_cas22_util import (
ma_table_to_tau,
ma_table_to_tbar,
ma_table_to_read_pattern
)


def fit_ramps_casertano(
resultants,
dq,
read_noise,
read_time,
ma_table=None,
read_pattern=None,
use_jump=False
):
"""Fit ramps following Casertano+2022, including averaging partial ramps.
Ramps are broken where dq != 0, and fits are performed on each sub-ramp.
@@ -94,9 +106,9 @@ def fit_ramps_casertano(resultants, dq, read_noise, read_time, ma_table=None, re
orig_shape = resultants.shape
if len(resultants.shape) == 1:
# single ramp.
resultants = resultants.reshape(origshape + (1,))
dq = dq.reshape(origshape + (1,))
read_noise = read_noise.reshape(origshape[1:] + (1,))
resultants = resultants.reshape(orig_shape + (1,))
dq = dq.reshape(orig_shape + (1,))
read_noise = read_noise.reshape(orig_shape[1:] + (1,))

ramp_fits, parameters, variances = ols_cas22.fit_ramps(
resultants.reshape(resultants.shape[0], -1),
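
As context for the docstring shown above ("Ramps are broken where dq != 0, and fits are performed on each sub-ramp"), here is a minimal usage sketch of the updated fit_ramps_casertano signature. The array shapes, read pattern, and read time below are illustrative assumptions, not part of the commit:

    import numpy as np
    from stcal.ramp_fitting.ols_cas22_fit import fit_ramps_casertano

    # Hypothetical read pattern: four resultants built from 1, 2, 3, and 4 reads.
    read_pattern = [[1], [2, 3], [4, 5, 6], [7, 8, 9, 10]]
    read_time = 3.04  # assumed seconds per read

    n_pixels = 100
    resultants = np.zeros((len(read_pattern), n_pixels), dtype=np.float32)  # n_resultants x n_pixels
    dq = np.zeros(resultants.shape, dtype=np.int32)         # ramps break wherever dq != 0
    read_noise = np.full(n_pixels, 5.0, dtype=np.float32)   # one value per pixel

    fit = fit_ramps_casertano(
        resultants, dq, read_noise, read_time,
        read_pattern=read_pattern, use_jump=False,
    )
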
35 changes: 25 additions & 10 deletions tests/test_jump_cas22.py
@@ -4,7 +4,7 @@

from stcal.ramp_fitting.ols_cas22._wrappers import read_data
from stcal.ramp_fitting.ols_cas22._wrappers import init_ramps
from stcal.ramp_fitting.ols_cas22._wrappers import run_threshold, make_fixed, make_pixel, fit_ramp
from stcal.ramp_fitting.ols_cas22._wrappers import run_threshold, make_fixed, make_pixel

from stcal.ramp_fitting.ols_cas22 import fit_ramps, Parameter, Variance, Diff

@@ -128,8 +136,16 @@ def test_make_fixed(ramp_data, use_jump):
# These are computed via vectorized operations in the main code, here we
# check using item-by-item operations
if use_jump:
single_gen = zip(fixed['t_bar_diff'][Diff.single], fixed['recip'][Diff.single], fixed['slope_var'][Diff.single])
double_gen = zip(fixed['t_bar_diff'][Diff.double], fixed['recip'][Diff.double], fixed['slope_var'][Diff.double])
single_gen = zip(
fixed['t_bar_diff'][Diff.single],
fixed['recip'][Diff.single],
fixed['slope_var'][Diff.single]
)
double_gen = zip(
fixed['t_bar_diff'][Diff.double],
fixed['recip'][Diff.double],
fixed['slope_var'][Diff.double]
)

for index, (t_bar_1, recip_1, slope_var_1) in enumerate(single_gen):
assert t_bar_1 == t_bar[index + 1] - t_bar[index]
@@ -167,7 +175,9 @@ def _generate_resultants(read_pattern, flux, read_noise, n_pixels=1):
# - Poisson process for the flux
# - Gaussian process for the read noise
ramp_value += RNG.poisson(flux * ROMAN_READ_TIME, size=n_pixels).astype(np.float32)
ramp_value += RNG.standard_normal(size=n_pixels, dtype=np.float32) * read_noise / np.sqrt(len(reads))
ramp_value += (
            RNG.standard_normal(size=n_pixels, dtype=np.float32) * read_noise / np.sqrt(len(reads))
)

# Add to running total for the resultant
resultant_total += ramp_value
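
The comments above describe how the synthetic resultants are built: a Poisson draw per read for the flux, plus Gaussian read noise that averages down as 1/sqrt(n_reads) over the resultant. A self-contained sketch of that idea follows; the helper name, seed, read time, and final normalization are assumptions reconstructed from the visible fragment, not the test module's exact code:

    import numpy as np

    RNG = np.random.default_rng(619)  # assumed seed
    ROMAN_READ_TIME = 3.04            # assumed seconds per read

    def simulate_resultant(reads, flux, read_noise, n_pixels=1):
        """Build one resultant by averaging its reads (hypothetical helper)."""
        ramp_value = np.zeros(n_pixels, dtype=np.float32)
        resultant_total = np.zeros(n_pixels, dtype=np.float32)
        for _ in reads:
            # Poisson process for the flux accumulated during each read
            ramp_value += RNG.poisson(flux * ROMAN_READ_TIME, size=n_pixels).astype(np.float32)
            # Gaussian process for the read noise, scaled by 1/sqrt(n_reads)
            ramp_value += RNG.standard_normal(size=n_pixels, dtype=np.float32) * read_noise / np.sqrt(len(reads))
            # Add to running total for the resultant
            resultant_total += ramp_value
        return (resultant_total / len(reads)).astype(np.float32)

    resultant = simulate_resultant(reads=[1, 2, 3], flux=100.0, read_noise=5.0)
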
@@ -225,7 +235,9 @@ def test_make_pixel(pixel_data, use_jump):
assert np.isnan(delta_2)
assert np.isnan(sigma_2)
else:
assert delta_2 == (resultants[index + 2] - resultants[index]) / (t_bar[index + 2] - t_bar[index])
assert delta_2 == (
(resultants[index + 2] - resultants[index]) / (t_bar[index + 2] - t_bar[index])
)
assert sigma_2 == read_noise * (
np.float32(1 / n_reads[index + 2]) + np.float32(1 / n_reads[index])
)
@@ -257,15 +269,19 @@ def test_fit_ramps_array_outputs(detector_data, use_jump):
resultants, read_noise, read_pattern, n_pixels, flux = detector_data
dq = np.zeros(resultants.shape, dtype=np.int32)

fits, parameters, variances = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)
fits, parameters, variances = fit_ramps(
resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump
)

for fit, par, var in zip(fits, parameters, variances):
assert par[Parameter.intercept] == 0
assert par[Parameter.slope] == fit['average']['slope']

assert var[Variance.read_var] == fit['average']['read_var']
assert var[Variance.poisson_var] == fit['average']['poisson_var']
assert var[Variance.total_var] == np.float32(fit['average']['read_var'] + fit['average']['poisson_var'])
assert var[Variance.total_var] == np.float32(
fit['average']['read_var'] + fit['average']['poisson_var']
)


@pytest.mark.parametrize("use_jump", [True, False])
@@ -302,7 +318,7 @@ def test_fit_ramps_dq(detector_data, use_jump):
up any jumps.
"""
resultants, read_noise, read_pattern, n_pixels, flux = detector_data
dq = np.zeros(resultants.shape, dtype=np.int32) + (RNG.uniform(size=resultants.shape) > 1).astype(np.int32)
dq = (RNG.uniform(size=resultants.shape) > 1).astype(np.int32)

# only use okay ramps
# ramps passing the below criterion have at least two adjacent valid reads
@@ -352,7 +368,7 @@ def jump_data():
# Start indicating a new resultant
jump_res += 1
jumps[jump_res, jump_index] = True

resultants[:, jump_index] = np.mean(read_values.reshape(shape), axis=1).astype(np.float32)

n_pixels = np.prod(shape)
@@ -424,6 +440,5 @@ def test_find_jumps(jump_data):
# "fairly close" to the expected value. This is purposely a loose check
# because the main purpose of this test is to verify that the jumps are
# being detected correctly, above.
chi2 = 0
for fit in fits:
assert_allclose(fit['average']['slope'], FLUX, rtol=3)
