From 24221fdceb4910e170dde278c2ca04b6e755ccda Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 14 Nov 2023 17:02:01 -0500 Subject: [PATCH] Update the jump detection unittest to check if jumps are correctly identified This checks if jumps are correctly identified, resolving some lingering test issues from #227 --- tests/test_jump_cas22.py | 75 +++++++++++++++------------------------- 1 file changed, 27 insertions(+), 48 deletions(-) diff --git a/tests/test_jump_cas22.py b/tests/test_jump_cas22.py index 18c19c965..1f991060d 100644 --- a/tests/test_jump_cas22.py +++ b/tests/test_jump_cas22.py @@ -22,8 +22,8 @@ # that the random process will "accidentally" generate a set of data, which # can trigger jump detection. This makes it easier to cleanly test jump # detection is doing what we expect. -FLUX = 100 -READ_NOISE = np.float32(5) +FLUX = 10 +READ_NOISE = np.float32(20) # Set a value for jumps which makes them obvious relative to the normal flux JUMP_VALUE = 1_000 @@ -32,7 +32,7 @@ # across all tests to make it easier to isolate the effects of something using # multiple tests. N_PIXELS = 100_000 -CHI2_TOL = 0.03 +CHI2_TOL = 0.3 GOOD_PROB = 0.7 @@ -353,8 +353,7 @@ def test_fit_ramps(detector_data, use_jump, use_dq): chi2 = 0 for fit, use in zip(output.fits, okay): - if not use_dq and not use_jump: - ##### The not use_jump makes this NOT test for false positives ##### + if not use_dq: # Check that the data generated does not generate any false positives # for jumps as this data is reused for `test_find_jumps` below. # This guarantees that all jumps detected in that test are the @@ -468,10 +467,6 @@ def test_find_jumps(jump_data): assert len(output.fits) == len(jump_reads) # sanity check that a fit/jump is set for every pixel chi2 = 0 - incorrect_too_few = 0 - incorrect_too_many = 0 - incorrect_does_not_capture = 0 - incorrect_other = 0 for fit, jump_index, resultant_index in zip(output.fits, jump_reads, jump_resultants): # Check that the jumps are detected correctly if jump_index == 0: @@ -485,51 +480,35 @@ def test_find_jumps(jump_data): assert fit["index"][0]["start"] == 0 assert fit["index"][0]["end"] == len(read_pattern) - 1 else: - # There should be a single jump detected; however, this results in - # two resultants being excluded. - if resultant_index not in fit["jumps"]: - incorrect_does_not_capture += 1 - continue - if len(fit["jumps"]) < 2: - incorrect_too_few += 1 - continue - if len(fit["jumps"]) > 2: - incorrect_too_many += 1 - continue - - # The two resultants excluded should be adjacent - jump_correct = [ - (jump in (resultant_index, resultant_index - 1, resultant_index + 1)) for jump in fit["jumps"] - ] - if not all(jump_correct): - incorrect_other += 1 - continue - - # Because we do not have a data set with no false positives, we cannot run the below - # # Test the correct ramp indexes are recorded - # ramp_indices = [] - # for ramp_index in fit['index']: - # # Note start/end of a ramp_index are inclusive meaning that end - # # is an index included in the ramp_index so the range is to end + 1 - # new_indices = list(range(ramp_index["start"], ramp_index["end"] + 1)) - - # # check that all the ramps are non-overlapping - # assert set(ramp_indices).isdisjoint(new_indices) - - # ramp_indices.extend(new_indices) - - # # check that no ramp_index is a jump - # assert set(ramp_indices).isdisjoint(fit['jumps']) - - # # check that all resultant indices are either in a ramp or listed as a jump - # assert set(ramp_indices).union(fit['jumps']) == set(range(len(read_pattern))) + # Check that the inserted jump is detected or if the jump occurs in the last resultant + # (there are some unresolved issues with this case) + assert resultant_index in fit["jumps"] or resultant_index == resultants.shape[0] - 1 + + # Here we map out all of the ramps and make sure they are non-overlapping and that + # they do not overlap with the identified jumps + ramp_indices = [] + for ramp_index in fit["index"]: + # Note start/end of a ramp_index are inclusive meaning that end + # is an index included in the ramp_index so the range is to end + 1 + new_indices = list(range(ramp_index["start"], ramp_index["end"] + 1)) + + # check that all the ramps are non-overlapping + assert set(ramp_indices).isdisjoint(new_indices) + + ramp_indices.extend(new_indices) + + # check that no ramp_index is a jump + assert set(ramp_indices).isdisjoint(fit["jumps"]) + + # check that all resultant indices are either in a ramp or listed as a jump + assert set(ramp_indices).union(fit["jumps"]) == set(range(len(read_pattern))) # Compute the chi2 for the fit and add it to a running "total chi2" total_var = fit["average"]["read_var"] + fit["average"]["poisson_var"] chi2 += (fit["average"]["slope"] - FLUX) ** 2 / total_var # Check that the average chi2 is ~1. - chi2 /= N_PIXELS - incorrect_too_few - incorrect_too_many - incorrect_does_not_capture - incorrect_other + chi2 /= N_PIXELS assert np.abs(chi2 - 1) < CHI2_TOL