Skip to content

Commit

Permalink
Extract wf (#22)
Browse files Browse the repository at this point in the history
* extract wfs

* default radius 200 not 160

* extract_wfs test

* WIP :)

* flake

---------

Co-authored-by: Christopher Langfield <[email protected]>
  • Loading branch information
oliche and chris-langfield authored Feb 27, 2024
1 parent 68211b0 commit 38a6a83
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/ibldsp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,21 @@ def rms(x, axis=-1):
return np.sqrt(np.mean(x ** 2, axis=axis))


def make_channel_index(geom, radius=200., pad_val=384):
neighbors = scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(geom)) < radius
n_nbors = np.max(np.sum(neighbors, 0))

nc = geom.shape[0]
if pad_val is None:
pad_val = nc
channel_idx = np.full((nc, n_nbors), pad_val, dtype=int)
for c in range(nc):
ch_idx = np.flatnonzero(neighbors[c, :])
channel_idx[c, :ch_idx.shape[0]] = ch_idx

return channel_idx


class WindowGenerator(object):
"""
`wg = WindowGenerator(ns, nswin, overlap)`
Expand Down
29 changes: 29 additions & 0 deletions src/ibldsp/waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,3 +548,32 @@ def compute_spike_features(arr_in, fs=30000, recovery_duration_ms=0.16, return_p
return df, arr_peak_real
else:
return df


def extract_wfs_array(arr, df, channel_neighbors, trough_offset=42, spike_length_samples=121):
"""
Extract waveforms at specified samples and peak channels
as a stack.
:param arr: Array of traces. (samples, channels)
:param df: df containing "sample" and "peak_channel" columns.
:param channel_neighbors: Channel neighbor matrix (384x384)
:param trough_offset: Number of samples to include before peak.
(defaults to 42)
:param spike_length_samples: Total length of wf in samples.
(defaults to 121)
"""
nwf = len(df)

# Get channel indices
cind = channel_neighbors[df["peak_channel"].to_numpy()]

# Get sample indices
sind = df["sample"].to_numpy()[:, np.newaxis] + (np.arange(spike_length_samples) - trough_offset)
nchan = cind.shape[1]

wfs = np.zeros((nwf, spike_length_samples, nchan), arr.dtype)
for i in range(nwf):
wfs[i, :, :] = arr[sind[i], :][:, cind[i]]

return wfs.swapaxes(1, 2)
49 changes: 49 additions & 0 deletions src/tests/unit/cpu/test_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import numpy as np
import pandas as pd

import ibldsp.utils as utils
import ibldsp.waveforms as waveforms
from neurowaveforms.model import generate_waveform
from neuropixel import trace_header


def make_array_peak_through_tip():
Expand Down Expand Up @@ -166,3 +168,50 @@ def test_weights_all_channels():
def test_generate_waveforms():
wav = generate_waveform()
assert wav.shape == (121, 40)


def test_extract_waveforms():
# create sample array with 10 point wfs at different
# channel locations
trough_offset = 42
ns = 1000
nc = 384
samples = np.arange(100, ns, 100)
channels = np.arange(12, 384, 45)

arr = np.zeros((ns, nc + 1), np.float32)
for i in range(9):
s, c = samples[i], channels[i]
arr[s, c] = float(i + 1)
arr[:, -1] = np.nan

df = pd.DataFrame({"sample": samples, "peak_channel": channels})
# generate channel neighbor matrix for NP1, default radius 200um
geom_dict = trace_header(version=1)
geom = np.c_[geom_dict["x"], geom_dict["y"]]
channel_neighbors = utils.make_channel_index(geom, radius=200.)
# radius = 200um, 38 chans
num_channels = 38
wfs = waveforms.extract_wfs_array(arr, df, channel_neighbors)

# first wf is a special case: it's at the top of the probe so the center
# index is the actual channel index, and the rest of the wf has been padded
# with NaNs
assert wfs[0, channels[0], trough_offset] == 1.
assert np.all(np.isnan(wfs[0, num_channels // 2 + channels[0] + 1:, :]))

for i in range(1, 8):
# center channel depends on odd/even of channel
if channels[i] % 2 == 0:
centered_channel_idx = 18
else:
centered_channel_idx = 19
assert wfs[i, centered_channel_idx, trough_offset] == float(i + 1)

# last wf is a special case analogous to the first wf, but at the bottom
# of the probe
if channels[-1] % 2 == 0:
centered_channel_idx = 18
else:
centered_channel_idx = 19
assert wfs[-1, centered_channel_idx, trough_offset] == 9.

0 comments on commit 38a6a83

Please sign in to comment.