wf extract spruce up
chris-langfield committed Jul 2, 2024
1 parent 145ec6e commit 935149d
Showing 2 changed files with 129 additions and 27 deletions.
83 changes: 58 additions & 25 deletions src/ibldsp/waveform_extraction.py
@@ -142,7 +142,7 @@ def _make_wfs_table(
unit_nspikes[i] = nspikes
# uniformly select up to max_wf spikes
u_wf_idx = rng.choice(u_spikeidx, min(max_wf, nspikes), replace=False)
unit_wf_idx[u, : min(max_wf, nspikes)] = u_wf_idx
unit_wf_idx[i, : min(max_wf, nspikes)] = u_wf_idx

# all wf indices in order
wf_idx = np.sort(unit_wf_idx.flatten())
@@ -175,6 +175,8 @@ def write_wfs_chunk(
chunksize_samples,
trough_offset,
spike_length_samples,
reader_kwargs,
preprocess,
):
"""
Parallel job to extract waveforms from chunk `i_chunk` of a recording `sr` and
@@ -183,22 +185,11 @@
if len(wf_flat) == 0:
return

my_sr = spikeglx.Reader(cbin)
my_sr = spikeglx.Reader(cbin, **reader_kwargs)
s0, s1 = sr_sl

wfs_mmap = open_memmap(wfs_fn, shape=mmap_shape, mode="r+", dtype=np.float32)

# create filters
butter_kwargs = {"N": 3, "Wn": 300 / my_sr.fs * 2, "btype": "highpass"}
sos = scipy.signal.butter(**butter_kwargs, output="sos")
k_kwargs = {
"ntr_pad": 60,
"ntr_tap": 0,
"lagc": int(my_sr.fs / 10),
"butter_kwargs": {"N": 3, "Wn": 0.01, "btype": "highpass"},
}
car_func = lambda dat: car(dat, **k_kwargs) # noqa: E731

if i_chunk == 0:
offset = 0
else:
@@ -212,6 +203,24 @@
snip = my_sr[
s0 - offset:s1 + spike_length_samples - trough_offset, :-my_sr.nsync
]

if not preprocess:
wfs_mmap[wf_flat["index"], :, :] = extract_wfs_array(
snip, df, channel_neighbors, add_nan_trace=True
)[0]
return

# create filters
butter_kwargs = {"N": 3, "Wn": 300 / my_sr.fs * 2, "btype": "highpass"}
sos = scipy.signal.butter(**butter_kwargs, output="sos")
k_kwargs = {
"ntr_pad": 60,
"ntr_tap": 0,
"lagc": int(my_sr.fs / 10),
"butter_kwargs": {"N": 3, "Wn": 0.01, "btype": "highpass"},
}
car_func = lambda dat: car(dat, **k_kwargs) # noqa: E731
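# lagc is an automatic-gain-control window in samples (fs / 10, i.e. 100 ms); all of
# these kwargs are passed through to car(), which applies the spatial common-average
# reference described in the extract_waveforms_cbin docstring.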

snip0 = interpolate_bad_channels(
fshift(
scipy.signal.sosfiltfilt(sos, snip.T), geom_dict["sample_shift"], axis=1
@@ -229,8 +238,8 @@
wfs_mmap.flush()


def extract_wfs_cbin(
cbin_file,
def extract_waveforms_cbin(
bin_file,
output_dir,
spike_samples,
spike_clusters,
@@ -241,13 +250,16 @@
trough_offset=42,
spike_length_samples=128,
chunksize_samples=int(3000),
reader_kwargs={},
n_jobs=None,
wfs_dtype=np.float32,
preprocess=False,
):
"""
Given a cbin file and locations of spikes, extract waveforms for each unit, compute
the templates, and save the results in `output_path`. The waveforms come from chunks
of raw data which are phase-corrected to account for the ADC, high-pass filtered in
time with an order 3 Butterworth filter with a 300Hz cutoff, and a common-average
Given a bin file and locations of spikes, extract waveforms for each unit, compute
the templates, and save the results in `output_path`. If preprocess=True, the waveforms
come from chunks of raw data which are phase-corrected to account for the ADC, high-pass
filtered in time with an order 3 Butterworth filter with a 300Hz cutoff, and a common-average
reference procedure is applied in the spatial dimension.
The following files will be generated:
@@ -269,10 +281,29 @@
where to find it in `spikes.samples`), peak channel, cluster, and linear index.
A row of -1s indicates that the waveform is missing because the unit it was supposed
to come from has fewer than `max_wf` spikes in total.
Parameters:
:param bin_file: Path to cbin or bin file to be read by spikeglx.Reader
:param output_dir: Folder where waveform extraction files will be saved
:param spike_samples: Spike times in samples
:param spike_clusters: Spike cluster labels
:param spike_channels: Peak channel around which to extract waveform for each spike
:param h: Geometry header file for probe (default: NP1)
:param channel_labels: Array of channel labels used for bad channel interpolation
(0: good, 1: dead, 2: noisy, 3: out of brain). If not set and preprocess=True,
channel detection will be run in this function.
:param max_wf: Max number of waveforms to extract per cluster (default: 256)
:param trough_offset: Location of peak in spike, in samples (default: 42)
:param spike_length_samples: Number of samples to extract per spike (default: 128)
:param chunksize_samples: Length of chunk to process at a time in samples (default: 3000)
:param reader_kwargs: Kwargs to pass to spikeglx.Reader()
:param n_jobs: Number of parallel jobs to run. Defaults to half of the available CPUs.
:param wfs_dtype: Data type of raw waveforms saved (default np.float32)
:param preprocess: Whether to preprocess the raw data (phase shift, high-pass filter,
bad-channel interpolation, and common-average reference) before extraction (default: False)
"""
n_jobs = n_jobs or int(cpu_count() / 2)

sr = spikeglx.Reader(cbin_file)
sr = spikeglx.Reader(bin_file, **reader_kwargs)
if h is None:
h = sr.geometry

@@ -324,7 +355,7 @@ def extract_wfs_cbin(
_ = Parallel(n_jobs=n_jobs)(
delayed(write_wfs_chunk)(
i,
cbin_file,
bin_file,
int_fn,
wfs.shape,
h,
@@ -335,6 +366,8 @@
chunksize_samples,
trough_offset,
spike_length_samples,
reader_kwargs,
preprocess
)
for i in range(num_chunks)
)
@@ -353,7 +386,7 @@
traces_fn,
mode="w+",
shape=(nu, max_wf, nc, spike_length_samples),
dtype=np.float16,
dtype=wfs_dtype,
)
logger.info("Writing to output files")

@@ -370,10 +403,10 @@
traces_fn,
mode="r+",
shape=(nu, max_wf, nc, spike_length_samples),
dtype=np.float16,
dtype=wfs_dtype,
)
# write up to max_wf waveforms and leave the rest of dimensions 1-3 as NaNs
traces_by_unit[i, : min(max_wf, nwf_u), :, :] = wfs[idx].astype(np.float16)
traces_by_unit[i, : min(max_wf, nwf_u), :, :] = wfs[idx].astype(wfs_dtype)
traces_by_unit.flush()
# populate this array in memory as it's 256x smaller
wfs_templates[i, :, :] = np.nanmedian(wfs[idx], axis=0)
@@ -386,7 +419,7 @@

# add in dummy rows and order by unit, and then sample
unit_counts = wf_flat.groupby("cluster")["sample"].count().reset_index(name="count")
unit_counts["missing"] = 256 - unit_counts["count"]
unit_counts["missing"] = max_wf - unit_counts["count"]
missing_wf = unit_counts[unit_counts["missing"] > 0]
total_missing = sum(missing_wf.missing)
extra_rows = pd.DataFrame(
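A minimal usage sketch of the renamed entry point with the new `reader_kwargs`, `wfs_dtype`, and `preprocess` arguments (separate from the diff above): the session path and the .npy file names for the spike-sorting outputs are hypothetical stand-ins, and only arguments documented in the docstring are used.

import numpy as np
from pathlib import Path
from ibldsp.waveform_extraction import extract_waveforms_cbin

session = Path("/data/example_session")        # hypothetical session folder
output_dir = session / "waveforms"
output_dir.mkdir(parents=True, exist_ok=True)  # make sure the output folder exists

extract_waveforms_cbin(
    session / "raw_ephys.ap.cbin",             # bin_file, read by spikeglx.Reader
    output_dir,
    np.load(session / "spikes.samples.npy"),   # spike times in samples
    np.load(session / "spikes.clusters.npy"),  # cluster label per spike
    np.load(session / "spikes.channels.npy"),  # peak channel per spike
    max_wf=256,                                # up to 256 waveforms per cluster
    trough_offset=42,
    spike_length_samples=128,
    reader_kwargs={},                          # forwarded to spikeglx.Reader()
    wfs_dtype=np.float16,                      # dtype of the saved raw waveforms
    preprocess=True,                           # filter / CAR the chunks before extraction
)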
73 changes: 71 additions & 2 deletions src/tests/unit/cpu/test_waveforms.py
@@ -2,6 +2,8 @@

import numpy as np
import pandas as pd
import tempfile
import shutil

import ibldsp.utils as utils
import ibldsp.waveforms as waveforms
@@ -13,6 +15,8 @@

import unittest

TEST_PATH = Path(__file__).parent.joinpath("fixtures")


def make_array_peak_through_tip():
arr = np.array(
@@ -166,7 +170,7 @@ def test_generate_waveforms():
assert wav.shape == (40, 121)


class TestWaveformExtractor(unittest.TestCase):
class TestWaveformExtractorArray(unittest.TestCase):
# create sample array with 10 point wfs at different
# channel locations
trough_offset = 42
@@ -189,7 +193,7 @@ class TestWaveformExtractor(unittest.TestCase):
# radius = 200um, 38 chans
num_channels = 40

def test_extract_waveforms(self):
def test_extract_waveforms_array(self):
wfs, _, _ = waveform_extraction.extract_wfs_array(
self.arr, self.df, self.channel_neighbors
)
@@ -295,3 +299,68 @@ def test_wave_shift_waveform(self):
# Test last waveform shift applied is minus the original shift, and the rest 511 waveforms are 0
np.testing.assert_equal(-sample_shift_original, np.around(shift_applied[-1], decimals=2))
np.testing.assert_equal(np.zeros(n_wav), np.abs(np.around(shift_applied[0:-1], decimals=2)))


class TestWaveformExtractorBin(unittest.TestCase):
ns = 38502
nc = 385
n_clusters = 2
ns_extract = 128
max_wf = 25

# 2 clusters
spike_samples = np.repeat(np.arange(0, ns, 1600), 2) # 50 spikes
spike_channels = np.tile(np.array([100, 368]), 25)
spike_clusters = np.tile(np.array([1, 2]), 25)

def setUp(self):
self.workdir = TEST_PATH
self.tmpdir = Path(tempfile.gettempdir()) / "test_wfs"
self.tmpdir.mkdir(exist_ok=True)
self.bin_file = self.tmpdir.joinpath("wfs_test.bin")
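# Synthetic recording: every sample is np.arange(385), so channel c always reads the
# constant value c; extracted waveforms should therefore equal the channel indices of
# each spike's neighborhood (see _ground_truth_values below).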
data = np.tile(np.arange(0, 385), (1, self.ns)).astype(np.float32)
data.tofile(self.bin_file)

def tearDown(self):
shutil.rmtree(self.tmpdir)

def _ground_truth_values(self):
h = trace_header()
geom = np.c_[h["x"], h["y"]]
chan_map = utils.make_channel_index(geom)
nc_extract = chan_map.shape[1]
gt_templates = np.ones((self.n_clusters, nc_extract, self.ns_extract), np.float32) * np.nan
gt_waveforms = np.ones((self.n_clusters, self.max_wf, nc_extract, self.ns_extract), np.float32) * np.nan

c0_chans = chan_map[100].astype(np.float32)
gt_templates[0, :, :] = np.tile(c0_chans, (self.ns_extract, 1)).T
gt_waveforms[0, :self.max_wf - 1, :, :] = np.tile(c0_chans, (self.max_wf - 1, self.ns_extract, 1)).swapaxes(1, 2)

c1_chans = chan_map[368].astype(np.float32)
c1_chans[c1_chans == 384] = np.nan
gt_templates[1, :, :] = np.tile(c1_chans, (self.ns_extract, 1)).T
gt_waveforms[1, :self.max_wf - 1, :, :] = np.tile(c1_chans, (self.max_wf - 1, self.ns_extract, 1)).swapaxes(1, 2)

return gt_templates, gt_waveforms

def test_extract_waveforms_bin(self):
waveform_extraction.extract_waveforms_cbin(
self.bin_file,
self.tmpdir,
self.spike_samples,
self.spike_clusters,
self.spike_channels,
reader_kwargs={"ns": self.ns, "nc": self.nc, "nsync": 1, "dtype": "float32"},
max_wf=self.max_wf,
h=trace_header()
)
templates = np.load(self.tmpdir.joinpath("waveforms.templates.npy"))
waveforms = np.load(self.tmpdir.joinpath("waveforms.traces.npy"))

for u in [0, 1]:
assert np.allclose(np.nan_to_num(templates[u]), np.nanmedian(waveforms[u], axis=0))

gt_templates, gt_waveforms = self._ground_truth_values()

assert np.allclose(np.nan_to_num(gt_templates), np.nan_to_num(templates))
assert np.allclose(np.nan_to_num(gt_waveforms), np.nan_to_num(waveforms))
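A sketch of how the saved outputs checked above might be inspected after a real extraction; the output_dir below is a hypothetical location, and the file names and array layout are taken from the test and from extract_waveforms_cbin.

import numpy as np
from pathlib import Path

output_dir = Path("/data/example_session/waveforms")          # hypothetical
templates = np.load(output_dir / "waveforms.templates.npy")   # (n_units, n_channels, n_samples)
traces = np.load(output_dir / "waveforms.traces.npy")         # (n_units, max_wf, n_channels, n_samples)

# As in the test above, each unit's template is the per-sample median over its
# extracted waveforms, ignoring the NaN padding used for missing spikes/channels;
# agreement is limited by the precision of wfs_dtype.
recomputed = np.nanmedian(traces, axis=1)
print(np.allclose(np.nan_to_num(recomputed), np.nan_to_num(templates), atol=1e-2))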
