Skip to content

Commit

Permalink
Clean up imports
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamJamieson committed Nov 21, 2023
1 parent d54e74b commit 194f72c
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 31 deletions.
19 changes: 6 additions & 13 deletions src/stcal/ramp_fitting/_ols_cas22/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
from enum import Enum

import numpy as np

from ._fit import JUMP_DET, FixedOffsets, Parameter, PixelOffsets, Variance, fit_ramps


class DefaultThreshold(Enum):
INTERCEPT = np.float32(5.5)
CONSTANT = np.float32(1 / 3)

"""
This subpackage exists to hold the Cython implementation of the OLS cas22 algorithm
This subpackage is private, and should not be imported directly by users. Instead,
import from stcal.ramp_fitting.ols_cas22.
"""
from ._fit import FixedOffsets, Parameter, PixelOffsets, Variance, fit_ramps

__all__ = [
"fit_ramps",
"Parameter",
"Variance",
"PixelOffsets",
"FixedOffsets",
"JUMP_DET",
"DefaultThreshold",
]
31 changes: 20 additions & 11 deletions src/stcal/ramp_fitting/ols_cas22.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,21 @@
So the routines in these packages construct these different matrices, store
them, and interpolate between them for different different fluxes and ratios.
"""
from enum import Enum
from typing import NamedTuple

import numpy as np
from astropy import units as u

from . import _ols_cas22
from ._ols_cas22 import FixedOffsets, Parameter, PixelOffsets, Variance
from ._ols_cas22 import fit_ramps as _fit_ramps

__all__ = ["fit_ramps", "Parameter", "Variance", "DefaultThreshold", "RampFitOutputs"]


class DefaultThreshold(Enum):
INTERCEPT = np.float32(5.5)
CONSTANT = np.float32(1 / 3)


class RampFitOutputs(NamedTuple):
Expand Down Expand Up @@ -65,16 +74,16 @@ class RampFitOutputs(NamedTuple):
dq: np.ndarray


def fit_ramps_casertano(
def fit_ramps(
resultants,
dq,
read_noise,
read_time,
read_pattern,
use_jump=False,
*,
threshold_intercept=_ols_cas22.DefaultThreshold.INTERCEPT.value,
threshold_constant=_ols_cas22.DefaultThreshold.CONSTANT.value,
threshold_intercept=DefaultThreshold.INTERCEPT.value,
threshold_constant=DefaultThreshold.CONSTANT.value,
include_diagnostic=False,
):
"""Fit ramps following Casertano+2022, including averaging partial ramps.
Expand Down Expand Up @@ -154,25 +163,25 @@ def fit_ramps_casertano(

# Pre-allocate the output arrays
n_pixels = np.prod(resultants.shape[1:])
parameters = np.empty((n_pixels, _ols_cas22.Parameter.n_param), dtype=np.float32)
variances = np.empty((n_pixels, _ols_cas22.Variance.n_var), dtype=np.float32)
parameters = np.empty((n_pixels, Parameter.n_param), dtype=np.float32)
variances = np.empty((n_pixels, Variance.n_var), dtype=np.float32)

# Pre-allocate the working memory arrays
# This prevents bouncing to and from cython for this allocation, which
# is slower than just doing it all in python to start.
t_bar, tau, n_reads = _create_metadata(read_pattern, read_time)
if use_jump:
single_pixel = np.empty((_ols_cas22.PixelOffsets.n_pixel_offsets, n_resultants - 1), dtype=np.float32)
double_pixel = np.empty((_ols_cas22.PixelOffsets.n_pixel_offsets, n_resultants - 2), dtype=np.float32)
single_fixed = np.empty((_ols_cas22.FixedOffsets.n_fixed_offsets, n_resultants - 1), dtype=np.float32)
double_fixed = np.empty((_ols_cas22.FixedOffsets.n_fixed_offsets, n_resultants - 2), dtype=np.float32)
single_pixel = np.empty((PixelOffsets.n_pixel_offsets, n_resultants - 1), dtype=np.float32)
double_pixel = np.empty((PixelOffsets.n_pixel_offsets, n_resultants - 2), dtype=np.float32)
single_fixed = np.empty((FixedOffsets.n_fixed_offsets, n_resultants - 1), dtype=np.float32)
double_fixed = np.empty((FixedOffsets.n_fixed_offsets, n_resultants - 2), dtype=np.float32)
else:
single_pixel = np.empty((0, 0), dtype=np.float32)
double_pixel = np.empty((0, 0), dtype=np.float32)
single_fixed = np.empty((0, 0), dtype=np.float32)
double_fixed = np.empty((0, 0), dtype=np.float32)

_ols_cas22.fit_ramps(
_fit_ramps(
resultants.reshape(resultants.shape[0], -1),
dq.reshape(resultants.shape[0], -1),
read_noise.reshape(-1),
Expand Down
9 changes: 3 additions & 6 deletions tests/test_jump_cas22.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,14 @@
import pytest
from numpy.testing import assert_allclose

from stcal.ramp_fitting._ols_cas22 import JUMP_DET, DefaultThreshold, fit_ramps
from stcal.ramp_fitting._ols_cas22 import FixedOffsets, Parameter, PixelOffsets, Variance, fit_ramps
from stcal.ramp_fitting._ols_cas22._fit import (
FixedOffsets,
Parameter,
PixelOffsets,
Variance,
JUMP_DET,
_fill_fixed_values,
_fill_pixel_values,
_init_ramps,
)
from stcal.ramp_fitting.ols_cas22 import _create_metadata
from stcal.ramp_fitting.ols_cas22 import DefaultThreshold, _create_metadata

# Purposefully set a fixed seed so that the tests in this module are deterministic
RNG = np.random.default_rng(619)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ramp_fitting_cas22.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_simulated_ramps(use_unit, use_dq):
bad = RNG.uniform(size=resultants.shape) > 0.7
dq |= bad

output = ramp.fit_ramps_casertano(
output = ramp.fit_ramps(
resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, threshold_constant=0, threshold_intercept=0
) # set the threshold parameters
# to demo the interface. This
Expand Down

0 comments on commit 194f72c

Please sign in to comment.