From 378d6bd2c0e0e98847f4783e74c0c91e59a0a6ab Mon Sep 17 00:00:00 2001 From: Christopher Langfield Date: Mon, 8 Jan 2024 10:59:21 -0800 Subject: [PATCH] extract_wfs test --- src/tests/unit/cpu/test_waveforms.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/tests/unit/cpu/test_waveforms.py b/src/tests/unit/cpu/test_waveforms.py index ded4996..3ef4951 100644 --- a/src/tests/unit/cpu/test_waveforms.py +++ b/src/tests/unit/cpu/test_waveforms.py @@ -5,6 +5,7 @@ import ibldsp.waveforms as waveforms from neurowaveforms.model import generate_waveform +from neuropixel import dense_layout def make_array_peak_through_tip(): @@ -166,3 +167,26 @@ def test_weights_all_channels(): def test_generate_waveforms(): wav = generate_waveform() assert wav.shape == (121, 40) + + +def test_extract_waveforms(): + # create sample array with 8 point wfs at different + # channel locations + trough_offset = 42 + ns = 1000 + nc = 384 + samples = np.arange(100, ns - 100, 100) + channels = np.arange(4, 384, 50) + # default extract radius = 200um, 40 chans + centered_channel_idx = 19 + arr = np.zeros((ns, nc), np.float32) + for i in range(8): + s, c = samples[i], channels[i] + arr[s, c] = float(i + 1) + + df = pd.DataFrame({"sample": samples, "peak_channel": channels}) + + wfs = waveforms.extract_wfs_array(arr, df, h=dense_layout()) + + for i in range(8): + assert wfs[i, trough_offset, centered_channel_idx] == float(i + 1)