Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamJamieson committed Nov 1, 2023
1 parent 17e96c8 commit 5770243
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 94 deletions.
6 changes: 3 additions & 3 deletions src/stcal/ramp_fitting/ols_cas22/_core.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ cdef struct RampFit:


cdef struct RampFits:
# vector[RampFit] fits
# vector[RampIndex] index
vector[int] jumps
RampFit average
vector[int] jumps
vector[RampFit] fits
vector[RampIndex] index


cdef struct ReadPatternMetadata:
Expand Down
87 changes: 24 additions & 63 deletions src/stcal/ramp_fitting/ols_cas22/_core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -193,69 +193,30 @@ cdef inline stack[RampIndex] init_ramps(int[:, :] dq, int n_resultants, int inde
return ramps


# @cython.boundscheck(False)
# @cython.wraparound(False)
# cdef inline deque[stack[RampIndex]] init_ramps(int[:, :] dq):
# """
# Create the initial ramp stack for each pixel
# if dq[index_resultant, index_pixel] == 0, then the resultant is in a ramp
# otherwise, the resultant is not in a ramp

# Parameters
# ----------
# dq : int[n_resultants, n_pixel]
# DQ array

# Returns
# -------
# deque of stacks of RampIndex objects
# - deque with entry for each pixel
# Chosen to be deque because need element access to loop
# - stack with entry for each ramp found (top of stack is last ramp found)
# - RampIndex with start and end indices of the ramp in the resultants
# """
# cdef int n_pixel, n_resultants

# n_resultants, n_pixel = np.array(dq).shape
# cdef deque[stack[RampIndex]] pixel_ramps

# cdef int index_pixel

# for index_pixel in range(n_pixel):
# # Add ramp stack for pixel to list
# pixel_ramps.push_back(_init_ramps_pixel(dq, n_resultants, index_pixel))

# return pixel_ramps


# def _init_ramps_list(np.ndarray[int, ndim=2] dq):
# """
# This is a wrapper for init_ramps so that it can be fully inspected from pure
# python. A cpdef cannot be used in that case becase a stack has no direct python
# analog. Instead this function turns that stack into a list ordered in the same
# order as the stack; meaning that, the first element of the list is the top of
# the stack.
# Note this function is for testing purposes only and so is marked as private
# within this private module
# """
# cdef deque[stack[RampIndex]] raw = init_ramps(dq)

# # Have to turn deque and stack into python compatible objects
# cdef RampIndex index
# cdef stack[RampIndex] ramp
# cdef list out = []
# cdef list stack_out
# for ramp in raw:
# stack_out = []
# while not ramp.empty():
# index = ramp.top()
# ramp.pop()
# # So top of stack is first item of list
# stack_out = [index] + stack_out

# out.append(stack_out)

# return out
def _init_ramps_list(np.ndarray[int, ndim=2] dq, int n_resultants, int index_pixel):
"""
This is a wrapper for init_ramps so that it can be fully inspected from pure
python. A cpdef cannot be used in that case becase a stack has no direct python
analog. Instead this function turns that stack into a list ordered in the same
order as the stack; meaning that, the first element of the list is the top of
the stack.
Note this function is for testing purposes only and so is marked as private
within this private module
"""
cdef stack[RampIndex] ramp = init_ramps(dq, n_resultants, index_pixel)

# Have to turn deque and stack into python compatible objects
cdef RampIndex index
cdef list out = []

out = []
while not ramp.empty():
index = ramp.top()
ramp.pop()
# So top of stack is first item of list
out = [index] + out

return out


@cython.boundscheck(False)
Expand Down
24 changes: 14 additions & 10 deletions src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ from stcal.ramp_fitting.ols_cas22._core cimport (RampFits, RampIndex, Thresh,
from stcal.ramp_fitting.ols_cas22._fixed cimport fixed_values_from_metadata, FixedValues
from stcal.ramp_fitting.ols_cas22._pixel cimport make_pixel

from typing import NamedTuple
from typing import NamedTuple, Optional


# Fix the default Threshold values at compile time these values cannot be overridden
Expand All @@ -26,9 +26,6 @@ class RampFitOutputs(NamedTuple):
Attributes
----------
fits: list of RampFits
the raw ramp fit outputs, these are all structs which will get mapped to
python dictionaries.
parameters: np.ndarray[n_pixel, 2]
the slope and intercept for each pixel's ramp fit. see Parameter enum
for indexing indicating slope/intercept in the second dimension.
Expand All @@ -39,11 +36,14 @@ class RampFitOutputs(NamedTuple):
dq: np.ndarray[n_resultants, n_pixel]
the dq array, with additional flags set for jumps detected by the
jump detection algorithm.
fits: list of RampFits
the raw ramp fit outputs, these are all structs which will get mapped to
python dictionaries.
"""
# fits: list
parameters: np.ndarray
variances: np.ndarray
dq: np.ndarray
fits: Optional[list] = None


@cython.boundscheck(False)
Expand All @@ -55,7 +55,8 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,
list[list[int]] read_pattern,
bool use_jump=False,
float intercept=DefaultIntercept,
float constant=DefaultConstant):
float constant=DefaultConstant,
bool include_diagnostic=False):
"""Fit ramps using the Casertano+22 algorithm.
This implementation uses the Cas22 algorithm to fit ramps, where
ramps are fit between bad resultants marked by dq flags for each pixel
Expand All @@ -82,6 +83,8 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,
The intercept value for the threshold function. Default=5.5
constant : float
The constant value for the threshold function. Default=1/3.0
include_diagnostic : bool
If True, include the raw ramp fits in the output. Default=False
Returns
-------
Expand All @@ -103,7 +106,7 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,
# Use list because this might grow very large which would require constant
# reallocation. We don't need random access, and this gets cast to a python
# list in the end.
# cdef cpp_list[RampFits] ramp_fits
cdef cpp_list[RampFits] ramp_fits

cdef np.ndarray[float, ndim=2] parameters = np.zeros((n_pixels, 2), dtype=np.float32)
cdef np.ndarray[float, ndim=2] variances = np.zeros((n_pixels, 3), dtype=np.float32)
Expand All @@ -114,7 +117,7 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,
for index in range(n_pixels):
# Fit all the ramps for the given pixel
fit = make_pixel(fixed, read_noise[index],
resultants[:, index]).fit_ramps(init_ramps(dq, n_resultants, index))
resultants[:, index]).fit_ramps(init_ramps(dq, n_resultants, index), include_diagnostic)

parameters[index, Parameter.slope] = fit.average.slope

Expand All @@ -125,7 +128,8 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,
for jump in fit.jumps:
dq[jump, index] = RampJumpDQ.JUMP_DET

# ramp_fits.push_back(fit)
if include_diagnostic:
ramp_fits.push_back(fit)

# return RampFitOutputs(ramp_fits, parameters, variances, dq)
return RampFitOutputs(parameters, variances, dq)
return RampFitOutputs(parameters, variances, dq, ramp_fits if include_diagnostic else None)
1 change: 1 addition & 0 deletions src/stcal/ramp_fitting/ols_cas22/_fixed.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Functions
import numpy as np
cimport numpy as np
cimport cython
from libcpp cimport bool

from stcal.ramp_fitting.ols_cas22._core cimport Thresh, ReadPatternMetadata, Diff
from stcal.ramp_fitting.ols_cas22._fixed cimport FixedValues
Expand Down
3 changes: 2 additions & 1 deletion src/stcal/ramp_fitting/ols_cas22/_pixel.pxd
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from libcpp cimport bool
from libcpp.stack cimport stack

from stcal.ramp_fitting.ols_cas22._core cimport RampFit, RampFits, RampIndex
Expand All @@ -17,7 +18,7 @@ cdef class Pixel:
cdef float correction(Pixel self, RampIndex ramp, float slope)
cdef float stat(Pixel self, float slope, RampIndex ramp, int index, int diff)
cdef float[:] stats(Pixel self, float slope, RampIndex ramp)
cdef RampFits fit_ramps(Pixel self, stack[RampIndex] ramps)
cdef RampFits fit_ramps(Pixel self, stack[RampIndex] ramps, bool include_diagnostic)


cpdef Pixel make_pixel(FixedValues fixed, float read_noise, float [:] resultants)
18 changes: 11 additions & 7 deletions src/stcal/ramp_fitting/ols_cas22/_pixel.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Functions
- cpdef gives a python wrapper, but the python version of this method
is considered private, only to be used for testing
"""
from libcpp cimport bool
from libc.math cimport sqrt, fabs
from libcpp.vector cimport vector
from libcpp.stack cimport stack
Expand Down Expand Up @@ -330,7 +331,7 @@ cdef class Pixel:
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cdef inline RampFits fit_ramps(Pixel self, stack[RampIndex] ramps):
cdef inline RampFits fit_ramps(Pixel self, stack[RampIndex] ramps, bool include_diagnostic):
"""
Compute all the ramps for a single pixel using the Casertano+22 algorithm
with jump detection.
Expand Down Expand Up @@ -398,8 +399,9 @@ cdef class Pixel:
# consideration.
jump0 = np.argmax(stats) + ramp.start
jump1 = jump0 + 1
ramp_fits.jumps.push_back(jump0)
ramp_fits.jumps.push_back(jump1)
if include_diagnostic:
ramp_fits.jumps.push_back(jump0)
ramp_fits.jumps.push_back(jump1)

# The two resultant indicies need to be skipped, therefore
# the two
Expand Down Expand Up @@ -446,8 +448,9 @@ cdef class Pixel:
# than threshold
# Note that ramps are computed backward in time meaning we need to
# reverse the order of the fits at the end
# ramp_fits.fits.push_back(ramp_fit)
# ramp_fits.index.push_back(ramp)
if include_diagnostic:
ramp_fits.fits.push_back(ramp_fit)
ramp_fits.index.push_back(ramp)

# Start computing the averages
# Note we do not do anything in the NaN case for degenerate ramps
Expand All @@ -462,8 +465,9 @@ cdef class Pixel:
ramp_fits.average.poisson_var += weight**2 * ramp_fit.poisson_var

# Reverse to order in time
# ramp_fits.fits = ramp_fits.fits[::-1]
# ramp_fits.index = ramp_fits.index[::-1]
if include_diagnostic:
ramp_fits.fits = ramp_fits.fits[::-1]
ramp_fits.index = ramp_fits.index[::-1]

# Finish computing averages
ramp_fits.average.slope /= total_weight if total_weight != 0 else 1
Expand Down
18 changes: 13 additions & 5 deletions src/stcal/ramp_fitting/ols_cas22_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,19 @@ def fit_ramps_casertano(
Returns
-------
par : np.ndarray[..., 2] (float)
the best fit pedestal and slope for each pixel
var : np.ndarray[..., 3, 2, 2] (float)
the covariance matrix of par, for each of three noise terms:
the read noise, Poisson source noise, and total noise.
RampFitOutputs
parameters: np.ndarray[n_pixel, 2]
the slope and intercept for each pixel's ramp fit. see Parameter enum
for indexing indicating slope/intercept in the second dimension.
variances: np.ndarray[n_pixel, 3]
the read, poisson, and total variances for each pixel's ramp fit.
see Variance enum for indexing indicating read/poisson/total in the
second dimension.
dq: np.ndarray[n_resultants, n_pixel]
the dq array, with additional flags set for jumps detected by the
jump detection algorithm.
fits: always None, this is a hold over which can contain the diagnostic
fit information from the jump detection algorithm.
"""

# Trickery to avoid having to specify the defaults for the threshold
Expand Down
12 changes: 7 additions & 5 deletions tests/test_jump_cas22.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def test_init_ramps():
[0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1],
[0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1]], dtype=np.int32)

ramps = _init_ramps_list(dq)
n_resultants, n_pixels = dq.shape
ramps = [_init_ramps_list(dq, n_resultants, index_pixel) for index_pixel in range(n_pixels)]

assert len(ramps) == dq.shape[1] == 16

# Check that the ramps are correct
Expand Down Expand Up @@ -418,7 +420,7 @@ def test_fit_ramps(detector_data, use_jump, use_dq):
if not use_dq:
assert okay.all()

output = fit_ramps(resultants, dq, read_noise, READ_TIME, read_pattern, use_jump=use_jump)
output = fit_ramps(resultants, dq, read_noise, READ_TIME, read_pattern, use_jump=use_jump, include_diagnostic=True)
assert len(output.fits) == N_PIXELS # sanity check that a fit is output for each pixel

chi2 = 0
Expand Down Expand Up @@ -456,7 +458,7 @@ def test_fit_ramps_array_outputs(detector_data, use_jump):
resultants, read_noise, read_pattern = detector_data
dq = np.zeros(resultants.shape, dtype=np.int32)

output = fit_ramps(resultants, dq, read_noise, READ_TIME, read_pattern, use_jump=use_jump)
output = fit_ramps(resultants, dq, read_noise, READ_TIME, read_pattern, use_jump=use_jump, include_diagnostic=True)

for fit, par, var in zip(output.fits, output.parameters, output.variances):
assert par[Parameter.intercept] == 0
Expand Down Expand Up @@ -528,7 +530,7 @@ def test_find_jumps(jump_data):
resultants, read_noise, read_pattern, jump_reads, jump_resultants = jump_data
dq = np.zeros(resultants.shape, dtype=np.int32)

output = fit_ramps(resultants, dq, read_noise, READ_TIME, read_pattern, use_jump=True)
output = fit_ramps(resultants, dq, read_noise, READ_TIME, read_pattern, use_jump=True, include_diagnostic=True)
assert len(output.fits) == len(jump_reads) # sanity check that a fit/jump is set for every pixel

chi2 = 0
Expand Down Expand Up @@ -603,7 +605,7 @@ def test_jump_dq_set(jump_data):
resultants, read_noise, read_pattern, jump_reads, jump_resultants = jump_data
dq = np.zeros(resultants.shape, dtype=np.int32)

output = fit_ramps(resultants, dq, read_noise, READ_TIME, read_pattern, use_jump=True)
output = fit_ramps(resultants, dq, read_noise, READ_TIME, read_pattern, use_jump=True, include_diagnostic=True)

for fit, pixel_dq in zip(output.fits, output.dq.transpose()):
# Check that all jumps found get marked
Expand Down

0 comments on commit 5770243

Please sign in to comment.