Skip to content

Commit

Permalink
RCAL-702: Fix memory use issue for Cas22 jump detection (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
schlafly authored Nov 6, 2023
2 parents e404ec8 + c6f90be commit a4849f7
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 130 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ramp_fitting

- Refactor Casertano, et.al, 2022 uneven ramp fitting and incorporate the matching
jump detection algorithm into it. [#215]
- Fix memory issue with uneven ramp fitting [#226]

Changes to API
--------------
Expand Down
8 changes: 3 additions & 5 deletions src/stcal/ramp_fitting/ols_cas22/_core.pxd
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from libcpp.vector cimport vector
from libcpp.stack cimport stack
from libcpp.deque cimport deque


cdef struct RampIndex:
Expand All @@ -15,10 +13,10 @@ cdef struct RampFit:


cdef struct RampFits:
RampFit average
vector[int] jumps
vector[RampFit] fits
vector[RampIndex] index
vector[int] jumps
RampFit average


cdef struct ReadPatternMetadata:
Expand Down Expand Up @@ -54,5 +52,5 @@ cpdef enum RampJumpDQ:

cpdef float threshold(Thresh thresh, float slope)
cdef float get_power(float s)
cdef deque[stack[RampIndex]] init_ramps(int[:, :] dq)
cpdef vector[RampIndex] init_ramps(int[:, :] dq, int n_resultants, int index_pixel)
cpdef ReadPatternMetadata metadata_from_read_pattern(list[list[int]] read_pattern, float read_time)
122 changes: 40 additions & 82 deletions src/stcal/ramp_fitting/ols_cas22/_core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ Functions
- cpdef gives a python wrapper, but the python version of this method
is considered private, only to be used for testing
"""
from libcpp.stack cimport stack
from libcpp.deque cimport deque
from libcpp.vector cimport vector
from libc.math cimport log10

import numpy as np
Expand Down Expand Up @@ -134,7 +133,7 @@ cpdef inline float threshold(Thresh thresh, float slope):

@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline deque[stack[RampIndex]] init_ramps(int[:, :] dq):
cpdef inline vector[RampIndex] init_ramps(int[:, :] dq, int n_resultants, int index_pixel):
"""
Create the initial ramp stack for each pixel
if dq[index_resultant, index_pixel] == 0, then the resultant is in a ramp
Expand All @@ -144,94 +143,53 @@ cdef inline deque[stack[RampIndex]] init_ramps(int[:, :] dq):
----------
dq : int[n_resultants, n_pixel]
DQ array
n_resultants : int
Number of resultants
index_pixel : int
The index of the pixel to find ramps for
Returns
-------
deque of stacks of RampIndex objects
- deque with entry for each pixel
Chosen to be deque because need element access to loop
- stack with entry for each ramp found (top of stack is last ramp found)
vector of RampIndex objects
- vector with entry for each ramp found (last entry is last ramp found)
- RampIndex with start and end indices of the ramp in the resultants
"""
cdef int n_pixel, n_resultants

n_resultants, n_pixel = np.array(dq).shape
cdef deque[stack[RampIndex]] pixel_ramps

cdef int index_resultant, index_pixel
cdef stack[RampIndex] ramps
cdef RampIndex ramp

for index_pixel in range(n_pixel):
ramps = stack[RampIndex]()

# Note: if start/end are -1, then no value has been assigned
# ramp.start == -1 means we have not started a ramp
# dq[index_resultant, index_pixel] == 0 means resultant is in ramp
ramp = RampIndex(-1, -1)
for index_resultant in range(n_resultants):
if ramp.start == -1:
# Looking for the start of a ramp
if dq[index_resultant, index_pixel] == 0:
# We have found the start of a ramp!
ramp.start = index_resultant
else:
# This is not the start of the ramp yet
continue
cdef vector[RampIndex] ramps = vector[RampIndex]()

# Note: if start/end are -1, then no value has been assigned
# ramp.start == -1 means we have not started a ramp
# dq[index_resultant, index_pixel] == 0 means resultant is in ramp
cdef RampIndex ramp = RampIndex(-1, -1)
for index_resultant in range(n_resultants):
if ramp.start == -1:
# Looking for the start of a ramp
if dq[index_resultant, index_pixel] == 0:
# We have found the start of a ramp!
ramp.start = index_resultant
else:
# Looking for the end of a ramp
if dq[index_resultant, index_pixel] == 0:
# This pixel is in the ramp do nothing
continue
else:
# This pixel is not in the ramp
# => index_resultant - 1 is the end of the ramp
ramp.end = index_resultant - 1

# Add completed ramp to stack and reset ramp
ramps.push(ramp)
ramp = RampIndex(-1, -1)

# Handle case where last resultant is in ramp (so no end has been set)
if ramp.start != -1 and ramp.end == -1:
# Last resultant is end of the ramp => set then add to stack
ramp.end = n_resultants - 1
ramps.push(ramp)

# Add ramp stack for pixel to list
pixel_ramps.push_back(ramps)
# This is not the start of the ramp yet
continue
else:
# Looking for the end of a ramp
if dq[index_resultant, index_pixel] == 0:
# This pixel is in the ramp do nothing
continue
else:
# This pixel is not in the ramp
# => index_resultant - 1 is the end of the ramp
ramp.end = index_resultant - 1

return pixel_ramps
# Add completed ramp to stack and reset ramp
ramps.push_back(ramp)
ramp = RampIndex(-1, -1)

# Handle case where last resultant is in ramp (so no end has been set)
if ramp.start != -1 and ramp.end == -1:
# Last resultant is end of the ramp => set then add to stack
ramp.end = n_resultants - 1
ramps.push_back(ramp)

def _init_ramps_list(np.ndarray[int, ndim=2] dq):
"""
This is a wrapper for init_ramps so that it can be fully inspected from pure
python. A cpdef cannot be used in that case becase a stack has no direct python
analog. Instead this function turns that stack into a list ordered in the same
order as the stack; meaning that, the first element of the list is the top of
the stack.
Note this function is for testing purposes only and so is marked as private
within this private module
"""
cdef deque[stack[RampIndex]] raw = init_ramps(dq)

# Have to turn deque and stack into python compatible objects
cdef RampIndex index
cdef stack[RampIndex] ramp
cdef list out = []
cdef list stack_out
for ramp in raw:
stack_out = []
while not ramp.empty():
index = ramp.top()
ramp.pop()
# So top of stack is first item of list
stack_out = [index] + stack_out

out.append(stack_out)

return out
return ramps


@cython.boundscheck(False)
Expand Down
28 changes: 14 additions & 14 deletions src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import numpy as np
cimport numpy as np
from libcpp cimport bool
from libcpp.stack cimport stack
from libcpp.list cimport list as cpp_list
from libcpp.deque cimport deque
cimport cython

from stcal.ramp_fitting.ols_cas22._core cimport (RampFits, RampIndex, Thresh,
Expand All @@ -12,7 +10,7 @@ from stcal.ramp_fitting.ols_cas22._core cimport (RampFits, RampIndex, Thresh,
from stcal.ramp_fitting.ols_cas22._fixed cimport fixed_values_from_metadata, FixedValues
from stcal.ramp_fitting.ols_cas22._pixel cimport make_pixel

from typing import NamedTuple
from typing import NamedTuple, Optional


# Fix the default Threshold values at compile time these values cannot be overridden
Expand All @@ -28,9 +26,6 @@ class RampFitOutputs(NamedTuple):
Attributes
----------
fits: list of RampFits
the raw ramp fit outputs, these are all structs which will get mapped to
python dictionaries.
parameters: np.ndarray[n_pixel, 2]
the slope and intercept for each pixel's ramp fit. see Parameter enum
for indexing indicating slope/intercept in the second dimension.
Expand All @@ -41,11 +36,14 @@ class RampFitOutputs(NamedTuple):
dq: np.ndarray[n_resultants, n_pixel]
the dq array, with additional flags set for jumps detected by the
jump detection algorithm.
fits: list of RampFits
the raw ramp fit outputs, these are all structs which will get mapped to
python dictionaries.
"""
fits: list
parameters: np.ndarray
variances: np.ndarray
dq: np.ndarray
fits: Optional[list] = None


@cython.boundscheck(False)
Expand All @@ -57,7 +55,8 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,
list[list[int]] read_pattern,
bool use_jump=False,
float intercept=DefaultIntercept,
float constant=DefaultConstant):
float constant=DefaultConstant,
bool include_diagnostic=False):
"""Fit ramps using the Casertano+22 algorithm.
This implementation uses the Cas22 algorithm to fit ramps, where
ramps are fit between bad resultants marked by dq flags for each pixel
Expand All @@ -84,6 +83,8 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,
The intercept value for the threshold function. Default=5.5
constant : float
The constant value for the threshold function. Default=1/3.0
include_diagnostic : bool
If True, include the raw ramp fits in the output. Default=False
Returns
-------
Expand All @@ -102,9 +103,6 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,
Thresh(intercept, constant),
use_jump)

# Compute all the initial sets of ramps
cdef deque[stack[RampIndex]] pixel_ramps = init_ramps(dq)

# Use list because this might grow very large which would require constant
# reallocation. We don't need random access, and this gets cast to a python
# list in the end.
Expand All @@ -119,7 +117,7 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,
for index in range(n_pixels):
# Fit all the ramps for the given pixel
fit = make_pixel(fixed, read_noise[index],
resultants[:, index]).fit_ramps(pixel_ramps[index])
resultants[:, index]).fit_ramps(init_ramps(dq, n_resultants, index), include_diagnostic)

parameters[index, Parameter.slope] = fit.average.slope

Expand All @@ -130,6 +128,8 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,
for jump in fit.jumps:
dq[jump, index] = RampJumpDQ.JUMP_DET

ramp_fits.push_back(fit)
if include_diagnostic:
ramp_fits.push_back(fit)

return RampFitOutputs(ramp_fits, parameters, variances, dq)
# return RampFitOutputs(ramp_fits, parameters, variances, dq)
return RampFitOutputs(parameters, variances, dq, ramp_fits if include_diagnostic else None)
1 change: 1 addition & 0 deletions src/stcal/ramp_fitting/ols_cas22/_fixed.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Functions
import numpy as np
cimport numpy as np
cimport cython
from libcpp cimport bool

from stcal.ramp_fitting.ols_cas22._core cimport Thresh, ReadPatternMetadata, Diff
from stcal.ramp_fitting.ols_cas22._fixed cimport FixedValues
Expand Down
5 changes: 3 additions & 2 deletions src/stcal/ramp_fitting/ols_cas22/_pixel.pxd
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from libcpp.stack cimport stack
from libcpp cimport bool
from libcpp.vector cimport vector

from stcal.ramp_fitting.ols_cas22._core cimport RampFit, RampFits, RampIndex
from stcal.ramp_fitting.ols_cas22._fixed cimport FixedValues
Expand All @@ -17,7 +18,7 @@ cdef class Pixel:
cdef float correction(Pixel self, RampIndex ramp, float slope)
cdef float stat(Pixel self, float slope, RampIndex ramp, int index, int diff)
cdef float[:] stats(Pixel self, float slope, RampIndex ramp)
cdef RampFits fit_ramps(Pixel self, stack[RampIndex] ramps)
cdef RampFits fit_ramps(Pixel self, vector[RampIndex] ramps, bool include_diagnostic)


cpdef Pixel make_pixel(FixedValues fixed, float read_noise, float [:] resultants)
31 changes: 17 additions & 14 deletions src/stcal/ramp_fitting/ols_cas22/_pixel.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ Functions
- cpdef gives a python wrapper, but the python version of this method
is considered private, only to be used for testing
"""
from libcpp cimport bool
from libc.math cimport sqrt, fabs
from libcpp.vector cimport vector
from libcpp.stack cimport stack

import numpy as np
cimport numpy as np
Expand Down Expand Up @@ -330,15 +330,15 @@ cdef class Pixel:
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cdef inline RampFits fit_ramps(Pixel self, stack[RampIndex] ramps):
cdef inline RampFits fit_ramps(Pixel self, vector[RampIndex] ramps, bool include_diagnostic):
"""
Compute all the ramps for a single pixel using the Casertano+22 algorithm
with jump detection.
Parameters
----------
ramps : stack[RampIndex]
Stack of initial ramps to fit for a single pixel
ramps : vector[RampIndex]
Vector of initial ramps to fit for a single pixel
multiple ramps are possible due to dq flags
Returns
Expand All @@ -361,8 +361,8 @@ cdef class Pixel:
# Run while the stack is non-empty
while not ramps.empty():
# Remove top ramp of the stack to use
ramp = ramps.top()
ramps.pop()
ramp = ramps.back()
ramps.pop_back()

# Compute fit
ramp_fit = self.fit_ramp(ramp)
Expand Down Expand Up @@ -398,8 +398,9 @@ cdef class Pixel:
# consideration.
jump0 = np.argmax(stats) + ramp.start
jump1 = jump0 + 1
ramp_fits.jumps.push_back(jump0)
ramp_fits.jumps.push_back(jump1)
if include_diagnostic:
ramp_fits.jumps.push_back(jump0)
ramp_fits.jumps.push_back(jump1)

# The two resultant indicies need to be skipped, therefore
# the two
Expand All @@ -424,7 +425,7 @@ cdef class Pixel:
# important that we exclude it.
# Note that jump0 < ramp.start is not possible because
# the argmax is always >= 0
ramps.push(RampIndex(ramp.start, jump0 - 1))
ramps.push_back(RampIndex(ramp.start, jump0 - 1))

if jump1 < ramp.end:
# Note that if jump1 == ramp.end, we have detected a
Expand All @@ -438,16 +439,17 @@ cdef class Pixel:
# resultants which are not considered part of the ramp
# under consideration. Therefore, we have to exlude all
# of those values.
ramps.push(RampIndex(jump1 + 1, ramp.end))
ramps.push_back(RampIndex(jump1 + 1, ramp.end))

continue

# Add ramp_fit to ramp_fits if no jump detection or stats are less
# than threshold
# Note that ramps are computed backward in time meaning we need to
# reverse the order of the fits at the end
ramp_fits.fits.push_back(ramp_fit)
ramp_fits.index.push_back(ramp)
if include_diagnostic:
ramp_fits.fits.push_back(ramp_fit)
ramp_fits.index.push_back(ramp)

# Start computing the averages
# Note we do not do anything in the NaN case for degenerate ramps
Expand All @@ -462,8 +464,9 @@ cdef class Pixel:
ramp_fits.average.poisson_var += weight**2 * ramp_fit.poisson_var

# Reverse to order in time
ramp_fits.fits = ramp_fits.fits[::-1]
ramp_fits.index = ramp_fits.index[::-1]
if include_diagnostic:
ramp_fits.fits = ramp_fits.fits[::-1]
ramp_fits.index = ramp_fits.index[::-1]

# Finish computing averages
ramp_fits.average.slope /= total_weight if total_weight != 0 else 1
Expand Down
Loading

0 comments on commit a4849f7

Please sign in to comment.