def make_channel_index(geom, radius=200., pad_val=384):
    """
    Build, for each channel, the list of channels located within `radius`.

    :param geom: Channel coordinates, array-like (channels, dimensions).
    :param radius: Neighborhood radius, in the same units as `geom`. The distance
        test is strict (`distance < radius`). (defaults to 200.)
    :param pad_val: Value used to fill the rows that have fewer neighbors than the
        widest row. If None, the number of channels is used. (defaults to 384)
    :return: Integer array (channels, max number of neighbors). Row `c` holds the
        indices of the channels within `radius` of channel `c` in ascending order
        (channel `c` itself included whenever radius > 0), followed by `pad_val`.
    """
    # accept any array-like (e.g. a list of coordinates), not only ndarrays
    geom = np.asarray(geom)
    nc = geom.shape[0]
    if pad_val is None:
        pad_val = nc

    # boolean (channels, channels) adjacency matrix from the pairwise distances
    dist = scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(geom))
    neighbors = dist < radius
    # width of the output: the largest neighborhood found on the probe
    n_nbors = np.max(np.sum(neighbors, 0))

    channel_idx = np.full((nc, n_nbors), pad_val, dtype=int)
    for c, is_neighbor in enumerate(neighbors):
        ch_idx = np.flatnonzero(is_neighbor)
        # left-align the neighbors, the remainder of the row keeps pad_val
        channel_idx[c, :ch_idx.size] = ch_idx

    return channel_idx
def extract_wfs_array(arr, df, channel_neighbors, trough_offset=42, spike_length_samples=121):
    """
    Extract waveforms at specified samples and peak channels
    as a stack.

    :param arr: Array of traces. (samples, channels). Every index found in
        `channel_neighbors` must be a valid column of `arr`: when the neighbor matrix
        is padded with an out-of-probe index, `arr` needs a matching extra column
        (e.g. filled with NaNs).
    :param df: df containing "sample" and "peak_channel" columns.
    :param channel_neighbors: Channel neighbor index matrix (channels, neighbors).
        Row `c` lists the columns of `arr` to extract for a spike whose peak channel
        is `c`.
    :param trough_offset: Number of samples to include before peak.
        (defaults to 42)
    :param spike_length_samples: Total length of wf in samples.
        (defaults to 121)
    :return: Array of waveforms (spikes, neighbors, spike_length_samples) with the
        dtype of `arr`.
    :raises IndexError: If the window of a spike starts before the first sample or
        ends after the last sample of `arr`.
    """
    # Get channel indices: (spikes, neighbors)
    cind = channel_neighbors[df["peak_channel"].to_numpy()]

    # Get sample indices: (spikes, spike_length_samples)
    window = np.arange(spike_length_samples) - trough_offset
    sind = df["sample"].to_numpy()[:, np.newaxis] + window

    # Negative indices would not fail: numpy wraps them around and silently returns
    # samples taken from the end of the recording, so reject them explicitly.
    if sind.size and (sind.min() < 0 or sind.max() >= arr.shape[0]):
        raise IndexError(
            f"Spike windows span samples {sind.min()} to {sind.max()}, outside of the "
            f"recording ({arr.shape[0]} samples)."
        )

    # Broadcast both index arrays to (spikes, samples, neighbors) so that
    # wfs[i, t, k] = arr[sind[i, t], cind[i, k]]
    wfs = arr[sind[:, :, np.newaxis], cind[:, np.newaxis, :]]

    return wfs.swapaxes(1, 2)
def test_extract_waveforms():
    # create sample array with 9 single-point wfs at different
    # channel locations
    trough_offset = 42
    ns = 1000
    nc = 384
    samples = np.arange(100, ns, 100)
    channels = np.arange(12, nc, 45)

    # the extra last column is all NaN: it is the trace selected by the padding
    # index (pad_val defaults to 384) of the channel neighbor matrix
    arr = np.zeros((ns, nc + 1), np.float32)
    for i, (s, c) in enumerate(zip(samples, channels)):
        arr[s, c] = float(i + 1)
    arr[:, -1] = np.nan

    df = pd.DataFrame({"sample": samples, "peak_channel": channels})
    # generate channel neighbor matrix for NP1, default radius 200um
    geom_dict = trace_header(version=1)
    geom = np.c_[geom_dict["x"], geom_dict["y"]]
    channel_neighbors = utils.make_channel_index(geom, radius=200.)
    # radius = 200um, 38 chans
    num_channels = 38
    wfs = waveforms.extract_wfs_array(arr, df, channel_neighbors)
    assert wfs.shape[:2] == (len(df), num_channels)

    # first wf is a special case: its peak channel is so close to the first
    # channels of the probe that its neighborhood is truncated. The center index
    # is then the actual channel index, and the rest of the wf has been padded
    # with NaNs
    assert wfs[0, channels[0], trough_offset] == 1.
    assert np.all(np.isnan(wfs[0, num_channels // 2 + channels[0] + 1:, :]))

    # all the other wfs, including the last one whose neighborhood is truncated at
    # the other end of the probe, have their peak on the center channel
    for i in range(1, len(channels)):
        # center channel depends on odd/even of channel
        centered_channel_idx = 18 if channels[i] % 2 == 0 else 19
        assert wfs[i, centered_channel_idx, trough_offset] == float(i + 1)