Skip to content

Commit

Permalink
Add full jump detection test with ramp fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamJamieson committed Oct 5, 2023
1 parent e688185 commit 52025e2
Showing 1 changed file with 131 additions and 75 deletions.
206 changes: 131 additions & 75 deletions tests/test_jump_cas22.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,16 @@

@pytest.fixture(scope="module")
def base_ramp_data():
"""Basic data for simulating ramps for testing (not unpacked)"""
"""
Basic data for simulating ramps for testing (not unpacked)
Returns
-------
read_pattern : list[list[int]]
The example read pattern
metadata : dict
The metadata computed from the read pattern
"""
read_pattern = [
[1, 2, 3, 4],
[5],
Expand Down Expand Up @@ -100,12 +109,26 @@ def test_threshold():

@pytest.fixture(scope="module")
def ramp_data(base_ramp_data):
"""Unpacked data for simulating ramps for testing"""
t_bar = np.array(base_ramp_data[1]['t_bar'], dtype=np.float32)
tau = np.array(base_ramp_data[1]['tau'], dtype=np.float32)
n_reads = np.array(base_ramp_data[1]['n_reads'], dtype=np.int32)
"""
Unpacked metadata for simulating ramps for testing
Returns
-------
read_pattern:
The read pattern used for testing
t_bar:
The t_bar values for the read pattern
tau:
The tau values for the read pattern
n_reads:
The number of reads for the read pattern
"""
read_pattern, read_pattern_metadata = base_ramp_data
t_bar = np.array(read_pattern_metadata['t_bar'], dtype=np.float32)
tau = np.array(read_pattern_metadata['tau'], dtype=np.float32)
n_reads = np.array(read_pattern_metadata['n_reads'], dtype=np.int32)

yield base_ramp_data[0], t_bar, tau, n_reads
yield read_pattern, t_bar, tau, n_reads


@pytest.mark.parametrize("use_jump", [True, False])
Expand Down Expand Up @@ -169,7 +192,19 @@ def test_fixed_values_from_metadata(ramp_data, use_jump):


def _generate_resultants(read_pattern, n_pixels=1):
"""Generate a set of resultants for a pixel"""
"""
Generate a set of resultants for a pixel
Parameters:
read_pattern : list[list[int]]
The read pattern to use
n_pixels:
The number of pixels to generate resultants for. Default is 1.
Returns:
resultants
The resultants generated
"""
resultants = np.zeros((len(read_pattern), n_pixels), dtype=np.float32)

# Use Poisson process to simulate the accumulation of the ramp
Expand Down Expand Up @@ -200,7 +235,19 @@ def _generate_resultants(read_pattern, n_pixels=1):

@pytest.fixture(scope="module")
def pixel_data(ramp_data):
"""Create data for a single pixel"""
"""
Create data for a single pixel
Returns:
resultants
Resultants for a single pixel
t_bar:
The t_bar values for the read pattern used for the resultants
tau:
The tau values for the read pattern used for the resultants
n_reads:
The number of reads for the read pattern used for the resultants
"""

read_pattern, t_bar, tau, n_reads = ramp_data
resultants = _generate_resultants(read_pattern)
Expand Down Expand Up @@ -259,20 +306,29 @@ def detector_data(ramp_data):
"""
Generate a set of data with no jumps, as if for a single detector, as it
would be passed in by the supporting code.
Returns:
resultants
The resultants for a large number of pixels
read_noise:
The read noise vector for those pixels
read_pattern:
The read pattern used for the resultants
"""
read_pattern, *_ = ramp_data
read_noise = np.ones(N_PIXELS, dtype=np.float32) * READ_NOISE

resultants = _generate_resultants(read_pattern, n_pixels=N_PIXELS)

return resultants, read_noise, read_pattern, N_PIXELS
return resultants, read_noise, read_pattern

@pytest.mark.parametrize("use_jump", [True, False])
def test_fit_ramps_array_outputs(detector_data, use_jump):
"""
Test that the array outputs line up with the dictionary output
"""
resultants, read_noise, read_pattern, n_pixels = detector_data
resultants, read_noise, read_pattern = detector_data
dq = np.zeros(resultants.shape, dtype=np.int32)

fits, parameters, variances = fit_ramps(
Expand All @@ -297,11 +353,11 @@ def test_fit_ramps_no_dq(detector_data, use_jump):
Since no jumps are simulated in the data, jump detection shouldn't pick
up any jumps.
"""
resultants, read_noise, read_pattern, n_pixels = detector_data
resultants, read_noise, read_pattern = detector_data
dq = np.zeros(resultants.shape, dtype=np.int32)

fits, _, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)
assert len(fits) == n_pixels # sanity check that a fit is output for each pixel
assert len(fits) == N_PIXELS # sanity check that a fit is output for each pixel

# Check that the chi2 for the resulting fit relative to the assumed flux is ~1
chi2 = 0
Expand All @@ -311,7 +367,7 @@ def test_fit_ramps_no_dq(detector_data, use_jump):
total_var = fit['average']['read_var'] + fit['average']['poisson_var']
chi2 += (fit['average']['slope'] - FLUX)**2 / total_var

chi2 /= n_pixels
chi2 /= N_PIXELS

assert np.abs(chi2 - 1) < CHI2_TOL

Expand All @@ -323,7 +379,7 @@ def test_fit_ramps_dq(detector_data, use_jump):
Since no jumps are simulated in the data, jump detection shouldn't pick
up any jumps.
"""
resultants, read_noise, read_pattern, n_pixels = detector_data
resultants, read_noise, read_pattern = detector_data
dq = (RNG.uniform(size=resultants.shape) > 1).astype(np.int32)

# only use okay ramps
Expand All @@ -332,7 +388,7 @@ def test_fit_ramps_dq(detector_data, use_jump):
okay = np.sum((dq[1:, :] == 0) & (dq[:-1, :] == 0), axis=0) != 0

fits, _, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=use_jump)
assert len(fits) == n_pixels # sanity check that a fit is output for each pixel
assert len(fits) == N_PIXELS # sanity check that a fit is output for each pixel

chi2 = 0
for fit, use in zip(fits, okay):
Expand All @@ -351,81 +407,82 @@ def test_fit_ramps_dq(detector_data, use_jump):


@pytest.fixture(scope="module")
def jump_data():
def jump_data(detector_data):
"""
Generate a set of data were jumps are simulated in each possible read.
- jumps should occur in read of same index as the pixel index.
Generate resultants with single jumps in them for testing jump detection.
Note this specifically checks that we can detect jumps in any read, meaning
it has an insurance check that a jump has been placed in every single
read position.
"""
resultants, read_noise, read_pattern = detector_data

# Choose read to place a single jump in for each pixel
num_reads = read_pattern[-1][-1]
jump_reads = RNG.integers(num_reads - 1, size=N_PIXELS)

# This shows that a jump has been placed in every single possible
# read position. Technically, this check can fail; however,
# N_PIXELS >> num_reads so it is very unlikely in practice since
# all reads are equally likely to be chosen for a jump.
# It is a good check that we can detect a jump occurring in any read except
# the first read.
assert set(jump_reads) == set(range(num_reads - 1))

# Fill out jump reads with jump values
jump_flux = np.zeros((num_reads, N_PIXELS), dtype=np.float32)
for index, jump in enumerate(jump_reads):
jump_flux[jump:, index] = JUMP_VALUE

# Average the reads into the resultants
jump_resultants = np.zeros(N_PIXELS, dtype=np.int32)
for index, reads in enumerate(read_pattern):
indices = np.array(reads) - 1
resultants[index, :] += np.mean(jump_flux[indices, :], axis=0)
for read in reads:
jump_resultants[np.where(jump_reads == read)] = index

# Generate a read pattern with 8 reads per resultant
shape = (8, 8)
read_pattern = np.arange(np.prod(shape)).reshape(shape).tolist()

resultants = np.zeros((len(read_pattern), np.prod(shape)), dtype=np.float32)
jumps = np.zeros((len(read_pattern), np.prod(shape)), dtype=bool)
jump_res = -1
for jump_index in range(np.prod(shape)):
read_values = np.zeros(np.prod(shape), dtype=np.float32)
for index in range(np.prod(shape)):
if index >= jump_index:
read_values[index] = JUMP_VALUE

if jump_index % shape[1] == 0:
# Start indicating a new resultant
jump_res += 1
jumps[jump_res, jump_index] = True

resultants[:, jump_index] = np.mean(read_values.reshape(shape), axis=1).astype(np.float32)

n_pixels = np.prod(shape)
read_noise = np.ones(n_pixels, dtype=np.float32) * READ_NOISE

# Add actual ramp data in addition to the jump data
resultants += _generate_resultants(read_pattern, n_pixels=n_pixels)

return resultants, read_noise, read_pattern, n_pixels, jumps.transpose()
return resultants, read_noise, read_pattern, jump_reads, jump_resultants


def test_find_jumps(jump_data):
"""
Check that we can locate all the jumps in a given ramp
Full unit tests to demonstrate that we can detect jumps in any read (except
the first one) and that we correctly remove these reads from the fit to recover
the correct FLUX/slope.
"""
resultants, read_noise, read_pattern, n_pixels, jumps = jump_data
resultants, read_noise, read_pattern, jump_reads, jump_resultants = jump_data
dq = np.zeros(resultants.shape, dtype=np.int32)

fits, _, _ = fit_ramps(resultants, dq, read_noise, ROMAN_READ_TIME, read_pattern, use_jump=True)
assert len(fits) == len(jump_reads) # sanity check that a fit/jump is set for every pixel

chi2 = 0
for fit, jump_index, resultant_index in zip(fits, jump_reads, jump_resultants):

# Check that all the jumps have been located per the algorithm's constraints
for index, (fit, jump) in enumerate(zip(fits, jumps)):
print(f"{index=}, {fit['jumps']=}, {jump=}")
# sanity check that only one jump should have been added
assert np.where(jump)[0].shape == (1,)
if index == 0:
# Check that the jumps are detected correctly
if jump_index == 0:
# There is no way to detect a jump if it is in the very first read
# The very first pixel in this case has a jump in the first read
assert len(fit['jumps']) == 0
assert jump[0] # sanity check that the jump is in the first resultant still
assert not np.all(jump[1:])
assert resultant_index == 0 # sanity check that the jump is indeed in the first resultant

# Test that the correct index was recorded
# Test the correct ramp_index was recorded:
assert len(fit['index']) == 1
assert fit['index'][0]['start'] == 0
assert fit['index'][0]['end'] == len(read_pattern) - 1
else:
# Select the single jump and check that it is recorded as a jump
assert np.where(jump)[0][0] in fit['jumps']

# In all cases here we have to exclude two resultants
# There should be a single jump detected; however, this results in
# two resultants being excluded.
assert len(fit['jumps']) == 2
assert resultant_index in fit['jumps']

# Test that all the jumps recorded are +/- 1 of the real jump
# This is due to the need to exclude two resultants
for jump_index in fit['jumps']:
assert jump[jump_index] or jump[jump_index - 1] or jump[jump_index + 1]
# The two resultants excluded should be adjacent
for jump in fit['jumps']:
assert jump == resultant_index or jump == resultant_index - 1 or jump == resultant_index + 1

# Test that the correct indexes are recorded
# Test the correct ramp indexes are recorded
ramp_indices = []
for ramp_index in fit["index"]:
for ramp_index in fit['index']:
# Note start/end of a ramp_index are inclusive meaning that end
# is an index included in the ramp_index so the range is to end + 1
new_indices = list(range(ramp_index["start"], ramp_index["end"] + 1))
Expand All @@ -438,14 +495,13 @@ def test_find_jumps(jump_data):
# check that no ramp_index is a jump
assert set(ramp_indices).isdisjoint(fit['jumps'])

# check that all resultant indicies are either in a ramp or listed as a jump
# check that all resultant indices are either in a ramp or listed as a jump
assert set(ramp_indices).union(fit['jumps']) == set(range(len(read_pattern)))

# Check that the slopes have been estimated reasonably well
# There are not that many pixels to test this against and many resultants
# have been thrown out due to the jumps. Thus we only check the slope is
# "fairly close" to the expected value. This is purposely a loose check
# because the main purpose of this test is to verify that the jumps are
# being detected correctly, above.
for fit in fits:
assert_allclose(fit['average']['slope'], FLUX, rtol=3)
# Compute the chi2 for the fit and add it to a running "total chi2"
total_var = fit['average']['read_var'] + fit['average']['poisson_var']
chi2 += (fit['average']['slope'] - FLUX)**2 / total_var

# Check that the average chi2 is ~1.
chi2 /= N_PIXELS
assert np.abs(chi2 - 1) < CHI2_TOL

0 comments on commit 52025e2

Please sign in to comment.