
Merge pull request #31 from int-brain-lab/extract_wf
Updated waveform extraction
chris-langfield authored Apr 18, 2024
2 parents 0bff850 + 3b5ed81 commit e84a1d7
Showing 3 changed files with 160 additions and 91 deletions.
2 changes: 2 additions & 0 deletions release_notes.md
@@ -1,4 +1,6 @@
# 0.10.0
## 0.10.3 2024-04-18
- Patch fixing memory leaks for `waveform_extraction` module.
## 0.10.2 2024-04-10
- Add `waveform_extraction` module to `ibldsp`. This includes the `extract_wfs_array` and `extract_wfs_cbin` methods.
- Add code for performing subsample shifts of waveforms.
2 changes: 1 addition & 1 deletion setup.py
@@ -8,7 +8,7 @@

setuptools.setup(
name="ibl-neuropixel",
version="0.10.2",
version="0.10.3",
author="The International Brain Laboratory",
description="Collection of tools for Neuropixel 1.0 and 2.0 probes data",
long_description=long_description,
247 changes: 157 additions & 90 deletions src/ibldsp/waveform_extraction.py
@@ -1,16 +1,19 @@
import logging

import scipy
import pandas as pd
import numpy as np
from numpy.lib.format import open_memmap
import neuropixel
import spikeglx

from joblib import Parallel, delayed, cpu_count

import neuropixel
import spikeglx
from ibldsp.voltage import detect_bad_channels, interpolate_bad_channels, car
from ibldsp.fourier import fshift
from ibldsp.utils import make_channel_index

logger = logging.getLogger(__name__)


def extract_wfs_array(
arr,
@@ -83,7 +86,12 @@ def _get_channel_labels(sr, num_snippets=20, verbose=True):
if verbose:
from tqdm import trange

start = (np.linspace(100, int(sr.rl) - 100, num_snippets) * sr.fs).astype(int)
# for most recordings we use a 100 s buffer at each end, but shrink it for shorter recordings
buffer_left_right = np.minimum(100, sr.rl * 0.03)
start = (
np.linspace(buffer_left_right, int(sr.rl) - buffer_left_right, num_snippets)
* sr.fs
).astype(int)
end = start + int(sr.fs)

_channel_labels = np.zeros((384, num_snippets), int)
@@ -101,10 +109,9 @@ def _get_channel_labels(sr, num_snippets=20, verbose=True):

def _make_wfs_table(
sr,
spike_times,
spike_samples,
spike_clusters,
spike_channels,
chunksize_t=10,
max_wf=256,
trough_offset=42,
spike_length_samples=128,
@@ -118,8 +125,8 @@
"""
# exclude spikes without a buffer on either end
# of recording
allowed_idx = (spike_times > trough_offset) & (
spike_times < sr.ns - (spike_length_samples - trough_offset)
allowed_idx = (spike_samples > trough_offset) & (
spike_samples < sr.ns - (spike_length_samples - trough_offset)
)

rng = np.random.default_rng(seed=2024) # numpy 1.23.5
@@ -136,7 +143,7 @@
nspikes = u_spikeidx.shape[0]
unit_nspikes[i] = nspikes
# uniformly select up to 500 spikes
u_wf_idx = rng.choice(u_spikeidx, min(max_wf, nspikes))
u_wf_idx = rng.choice(u_spikeidx, min(max_wf, nspikes), replace=False)
unit_wf_idx[u, : min(max_wf, nspikes)] = u_wf_idx

# all wf indices in order
@@ -145,13 +152,12 @@
wf_idx = wf_idx[np.nonzero(wf_idx)[0][0]:]

# get sample times, clusters, channels

wf_flat = pd.DataFrame(
{
"indices": np.arange(wf_idx.shape[0]),
"samples": spike_times[wf_idx].astype(int),
"clusters": spike_clusters[wf_idx].astype(int),
"channels": spike_channels[wf_idx].astype(int),
"index": np.arange(wf_idx.shape[0]),
"sample": spike_samples[wf_idx].astype(int),
"cluster": spike_clusters[wf_idx].astype(int),
"peak_channel": spike_channels[wf_idx].astype(int),
}
)

@@ -176,6 +182,9 @@ def write_wfs_chunk(
Parallel job to extract waveforms from chunk `i_chunk` of a recording `sr` and
write them to the correct spot in the output .npy file `wfs_fn`.
"""
if len(wf_flat) == 0:
return

my_sr = spikeglx.Reader(cbin)
s0, s1 = sr_sl

@@ -197,13 +206,13 @@
else:
offset = trough_offset

sample = wf_flat["samples"].astype(int) + offset - i_chunk * chunksize_samples
peak_channel = wf_flat["channels"]
sample = wf_flat["sample"].astype(int) + offset - i_chunk * chunksize_samples
peak_channel = wf_flat["peak_channel"]

df = pd.DataFrame({"sample": sample, "peak_channel": peak_channel})

snip = my_sr[
s0 - offset: s1 + spike_length_samples - trough_offset, : -my_sr.nsync
s0 - offset:s1 + spike_length_samples - trough_offset, :-my_sr.nsync
]
snip0 = interpolate_bad_channels(
fshift(
@@ -216,98 +225,109 @@
# car
snip1 = np.full((my_sr.nc, snip0.shape[1]), np.nan)
snip1[:-1, :] = car_func(snip0)
wfs_mmap[wf_flat["indices"], :, :] = extract_wfs_array(
wfs_mmap[wf_flat["index"], :, :] = extract_wfs_array(
snip1.T, df, channel_neighbors
)[0]
wfs_mmap.flush()


def extract_wfs_cbin(
cbin_file,
output_file,
spike_times,
output_dir,
spike_samples,
spike_clusters,
spike_channels,
h=None,
wf_extract_params=None,
nprocesses=None,
channel_labels=None,
max_wf=256,
trough_offset=42,
spike_length_samples=128,
chunksize_samples=int(3000),
n_jobs=None,
):
"""
Given a cbin file and locations of spikes, extract waveforms for each unit, compute
the templates, and save to `output_file`.
If `output_file=Path("/path/to/example_clusters.npy")`, this array will be of shape
`(num_units, max_wf, nc, spike_length_samples)` where by default `max_wf=256, nc=40,
spike_length_samples=128`.
The file "path/to/example_clusters_templates.npy" will also be generated, of shape
`(num_units, nc, spike_length_samples)`, where the median across waveforms is taken
for each unit.
The parquet file "path/to/example_clusters.pqt" contains the samples and max channels
of each waveform, indexed by unit.
the templates, and save the results in `output_dir`. The waveforms come from chunks
of raw data which are phase-corrected to account for the ADC, high-pass filtered in
time with an order 3 Butterworth filter with a 300 Hz cutoff, and common-average
referenced in the spatial dimension.
The following files will be generated:
- waveforms.traces.npy: `(num_units, max_wf, nc, spike_length_samples)`
This file contains the lightly processed waveforms indexed by cluster in the first
dimension. By default `max_wf=256, nc=40, spike_length_samples=128`.
- waveforms.templates.npy: `(num_units, nc, spike_length_samples)`
This file contains the median across individual waveforms for each unit.
- waveforms.channels.npz: `(num_units * max_wf, nc)`
The i'th row contains the ordered indices of the `nc`-channel neighborhood used
to extract the i'th waveform. A row of -1s means the waveform is missing because
the unit it was supposed to come from has fewer than `max_wf` spikes total in the
recording.
- waveforms.table.pqt: `num_units * max_wf` rows
For each waveform, gives the absolute sample number from the recording (i.e.
where to find it in `spikes.samples`), peak channel, cluster, and linear index.
A row of NaNs (apart from the cluster id) means the waveform is missing because
the unit it was supposed to come from has fewer than `max_wf` spikes total.
"""
if h is None:
h = neuropixel.trace_header()

if wf_extract_params is None:
wf_extract_params = {
"max_wf": 256,
"trough_offset": 42,
"spike_length_samples": 128,
"chunksize_t": 10,
}

output_path = output_file.parent

max_wf = wf_extract_params["max_wf"]
trough_offset = wf_extract_params["trough_offset"]
spike_length_samples = wf_extract_params["spike_length_samples"]
chunksize_t = wf_extract_params["chunksize_t"]
n_jobs = n_jobs or int(cpu_count() / 2)

sr = spikeglx.Reader(cbin_file)
chunksize_samples = chunksize_t * 30_000
s0_arr = np.arange(0, sr.ns, chunksize_samples)
s1_arr = s0_arr + chunksize_samples
s1_arr[-1] = sr.ns

# selects spikes from throughout the recording for each unit
wf_flat, unit_ids = _make_wfs_table(
sr, spike_times, spike_clusters, spike_channels, **wf_extract_params
sr,
spike_samples,
spike_clusters,
spike_channels,
max_wf,
trough_offset,
spike_length_samples,
)
num_chunks = s0_arr.shape[0]
print(f"Chunk size: {chunksize_t}")
print(f"Num chunks: {num_chunks}")

print("Running channel detection")
channel_labels = _get_channel_labels(sr)
logger.info(f"Chunk size samples: {chunksize_samples}")
logger.info(f"Num chunks: {num_chunks}")

logger.info("Running channel detection")
if channel_labels is None:
channel_labels = _get_channel_labels(sr)

nwf = wf_flat["samples"].shape[0]
nwf = len(wf_flat)
nu = unit_ids.shape[0]
print(f"Extracting {nwf} waveforms from {nu} units")
logger.info(f"Extracting {nwf} waveforms from {nu} units")

# get channel geometry
geom = np.c_[h["x"], h["y"]]
channel_neighbors = make_channel_index(geom)
nc = channel_neighbors.shape[1]

fn = output_path.joinpath("_wf_extract_intermediate.npy")
# this intermediate memmap is written to in parallel
# the waveforms are ordered only by their chronological position
# in the recording, as we are reading them in time chunks
int_fn = output_dir.joinpath("_wf_extract_intermediate.npy")
wfs = open_memmap(
fn, mode="w+", shape=(nwf, nc, spike_length_samples), dtype=np.float32
int_fn, mode="w+", shape=(nwf, nc, spike_length_samples), dtype=np.float32
)

slices = [
slice(
*(np.searchsorted(wf_flat["samples"], [s0_arr[i], s1_arr[i]]).astype(int))
)
slice(*(np.searchsorted(wf_flat["sample"], [s0_arr[i], s1_arr[i]]).astype(int)))
for i in range(num_chunks)
]

nprocesses = nprocesses or int(cpu_count() - cpu_count() / 4)
_ = Parallel(n_jobs=nprocesses)(
_ = Parallel(n_jobs=n_jobs)(
delayed(write_wfs_chunk)(
i,
cbin_file,
fn,
int_fn,
wfs.shape,
h,
channel_labels,
@@ -321,34 +341,81 @@ def extract_wfs_cbin(
for i in range(num_chunks)
)

wfs = open_memmap(
fn, mode="r+", shape=(nwf, nc, spike_length_samples), dtype=np.float32
)
# bookkeeping
wfs_by_unit = np.full(
(nu, max_wf, nc, spike_length_samples), np.nan, dtype=np.float16
# output files
traces_fn = output_dir.joinpath("waveforms.traces.npy")
templates_fn = output_dir.joinpath("waveforms.templates.npy")
table_fn = output_dir.joinpath("waveforms.table.pqt")
channels_fn = output_dir.joinpath("waveforms.channels.npz")

## rearrange and save traces by unit
# store medians across waveforms
wfs_templates = np.full((nu, nc, spike_length_samples), np.nan, dtype=np.float32)
# create waveform output file (~2-3 GB)
traces_by_unit = open_memmap(
traces_fn,
mode="w+",
shape=(nu, max_wf, nc, spike_length_samples),
dtype=np.float16,
)
wfs_medians = np.full((nu, nc, spike_length_samples), np.nan, dtype=np.float32)
print("Computing templates")
for i, u in enumerate(unit_ids):
_wfs_unit = wfs[wf_flat["clusters"] == u]
nwf_u = _wfs_unit.shape[0]
wfs_by_unit[i, : min(max_wf, nwf_u), :, :] = _wfs_unit.astype(np.float16)
wfs_medians[i, :, :] = np.nanmedian(_wfs_unit, axis=0)
logger.info("Writing to output files")

df = pd.DataFrame(
for i, u in enumerate(unit_ids):
idx = np.where(wf_flat["cluster"] == u)[0]
nwf_u = idx.shape[0]
# reopening these memmaps on each iteration
# forces Python to clean up each large array it loads
# and prevent a memory leak
wfs = open_memmap(
int_fn, mode="r+", shape=(nwf, nc, spike_length_samples), dtype=np.float32
)
traces_by_unit = open_memmap(
traces_fn,
mode="r+",
shape=(nu, max_wf, nc, spike_length_samples),
dtype=np.float16,
)
# write up to 256 waveforms and leave the rest of dimensions 1-3 as NaNs
traces_by_unit[i, : min(max_wf, nwf_u), :, :] = wfs[idx].astype(np.float16)
traces_by_unit.flush()
# populate this array in memory as it's 256x smaller
wfs_templates[i, :, :] = np.nanmedian(wfs[idx], axis=0)

# cleanup intermediate file
int_fn.unlink()

# save templates
np.save(templates_fn, wfs_templates)

# add in dummy rows and order by unit, and then sample
unit_counts = wf_flat.groupby("cluster")["sample"].count().reset_index(name="count")
unit_counts["missing"] = 256 - unit_counts["count"]
missing_wf = unit_counts[unit_counts["missing"] > 0]
total_missing = sum(missing_wf.missing)
extra_rows = pd.DataFrame(
{
"sample": wf_flat["samples"],
"peak_channel": wf_flat["channels"],
"cluster": wf_flat["clusters"],
"sample": [np.nan] * total_missing,
"peak_channel": [np.nan] * total_missing,
"index": [np.nan] * total_missing,
"cluster": sum(
[[row["cluster"]] * row["missing"] for _, row in missing_wf.iterrows()],
[],
),
}
)
df = df.sort_values(["cluster", "sample"]).set_index(["cluster", "sample"])

np.save(output_file, wfs_by_unit)
# medians
avg_file = output_file.parent.joinpath(output_file.stem + "_templates.npy")
np.save(avg_file, wfs_medians)
df.to_parquet(output_file.with_suffix(".pqt"))

fn.unlink()
save_df = pd.concat([wf_flat, extra_rows])
# now the waveforms are arranged by cluster, and then in time
# these match dimensions 0 and 1 of waveforms.traces.npy
save_df.sort_values(["cluster", "sample"], inplace=True)
save_df.to_parquet(table_fn)

# save channel map for each waveform
# these values are now reordered so that they match the pqt
# and the traces file
peak_channel = np.nan_to_num(save_df["peak_channel"].to_numpy(), nan=-1).astype(
np.int16
)
dummy_idx = np.where(peak_channel >= 0)[0]
# leave "missing" waveforms as -1 since we can't have NaN with int dtype
chan_map = np.ones((max_wf * nu, nc), np.int16) * -1
chan_map[dummy_idx] = channel_neighbors[peak_channel[dummy_idx].astype(int)]
np.savez(channels_fn, channels=chan_map)
