Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WaveformsLoader #36

Merged
merged 6 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 233 additions & 0 deletions src/ibldsp/waveform_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import scipy
import pandas as pd
import numpy as np
from pathlib import Path
from numpy.lib.format import open_memmap
from joblib import Parallel, delayed, cpu_count

Expand Down Expand Up @@ -472,3 +473,235 @@ def extract_wfs_cbin(
chan_map = np.ones((max_wf * nu, nc), np.int16) * -1
chan_map[dummy_idx] = channel_neighbors[peak_channel[dummy_idx].astype(int)]
np.savez(channels_fn, channels=chan_map)


class WaveformsLoader:

    """
    Interface to the output of `extract_wfs_cbin`. Requires the following four files to
    exist in `data_dir`:

    - waveforms.traces.npy: `(num_units, max_wf, nc, spike_length_samples)`
        This file contains the lightly processed waveforms indexed by cluster in the first
        dimension. By default `max_wf=256, nc=40, spike_length_samples=128`.

    - waveforms.templates.npy: `(num_units, nc, spike_length_samples)`
        This file contains the median across individual waveforms for each unit.

    - waveforms.channels.npz: `(num_units * max_wf, nc)`
        The i'th row contains the ordered indices of the `nc`-channel neighborhood used
        to extract the i'th waveform. A fill value of -1 means the waveform is missing
        because the unit it was supposed to come from has less than `max_wf` spikes
        total in the recording.

    - waveforms.table.pqt: `num_units * max_wf` rows
        For each waveform, gives the absolute sample number from the recording (i.e.
        where to find it in `spikes.samples`), peak channel, cluster, and linear index.
        A row of -1s implies that the waveform is missing because the unit it was supposed
        to come from has less than `max_wf` spikes total.

    WaveformsLoader.load_waveforms() and random_waveforms() allow selection of a subset of
    waveforms.

    """

    def __init__(
        self,
        data_dir,
        max_wf=256,
        trough_offset=42,
        spike_length_samples=128,
        num_channels=40,
        wfs_dtype=np.float32
    ):
        """
        :param data_dir: Directory containing the four `extract_wfs_cbin` output files.
        :param max_wf: Maximum number of waveforms stored per unit.
        :param trough_offset: Sample index of the trough within each snippet.
        :param spike_length_samples: Number of samples per waveform snippet.
        :param num_channels: Number of channels in each waveform's neighborhood.
        :param wfs_dtype: On-disk dtype of waveforms.traces.npy.
        """
        self.data_dir = Path(data_dir)
        self.max_wf = max_wf
        self.trough_offset = trough_offset
        self.spike_length_samples = spike_length_samples
        self.num_channels = num_channels
        self.wfs_dtype = wfs_dtype

        self.traces_fp = self.data_dir.joinpath("waveforms.traces.npy")
        self.templates_fp = self.data_dir.joinpath("waveforms.templates.npy")
        self.table_fp = self.data_dir.joinpath("waveforms.table.pqt")
        self.channels_fp = self.data_dir.joinpath("waveforms.channels.npz")

        assert self.traces_fp.exists(), "waveforms.traces.npy file missing!"
        assert self.templates_fp.exists(), "waveforms.templates.npy file missing!"
        assert self.table_fp.exists(), "waveforms.table.pqt file missing!"
        assert self.channels_fp.exists(), "waveforms.channels.npz file missing!"

        # ingest parquet table
        self.table = pd.read_parquet(self.table_fp).reset_index(drop=True).drop(columns=["index"])
        # nullable Int64 so missing waveforms can be represented as pd.NA
        self.table["sample"] = self.table["sample"].astype("Int64")
        self.table["peak_channel"] = self.table["peak_channel"].astype("Int64")
        self.num_labels = self.table["cluster"].nunique()
        # unit labels in on-disk (first-axis) order of waveforms.traces.npy
        self.labels = np.array(self.table["cluster"].unique())
        self.total_wfs = int((~self.table["peak_channel"].isna()).sum())
        # per-unit waveform number (0..max_wf-1) and flat row index into channels/traces
        self.table["wf_number"] = np.tile(np.arange(self.max_wf), self.num_labels)
        self.table["linear_index"] = np.arange(len(self.table))

        traces_shape = (self.num_labels, max_wf, num_channels, spike_length_samples)
        templates_shape = (self.num_labels, num_channels, spike_length_samples)

        # NB: dtype/shape are only applied by open_memmap in write modes; passed here
        # to document the expected on-disk layout.
        self.traces = open_memmap(self.traces_fp, dtype=wfs_dtype, shape=traces_shape)
        self.templates = open_memmap(self.templates_fp, dtype=np.float32, shape=templates_shape)
        # fix: allow_pickle expects a bool, not the string "True" (which only worked by truthiness)
        self.channels = np.load(self.channels_fp, allow_pickle=True)["channels"]

    def __repr__(self):
        s1 = f"WaveformsLoader with {self.total_wfs} waveforms in {self.wfs_dtype} from {self.num_labels} labels.\n"
        s2 = f"Data path: {self.data_dir}\n"
        s3 = f"{self.spike_length_samples} samples, {self.num_channels} channels, {self.max_wf} max waveforms per label\n"

        return s1 + s2 + s3

    @property
    def wf_counts(self):
        """
        pandas Series containing number of (non-NaN) waveforms for each label.
        """
        return self.table.groupby("cluster").count()["sample"].rename("num_wfs")

    def _gather(self, labels, indices, return_info, flatten):
        """
        Slice traces for the given labels/indices and assemble the info table and
        per-waveform channel map. Shared tail of load_waveforms / random_waveforms.

        :param labels: 1D array of label ids (must all exist in `self.labels`).
        :param indices: 2D int array `(len(labels), n_wf)` of per-label waveform numbers.
        :param return_info: If True, return (wfs, info, channels) instead of wfs alone.
        :param flatten: If True, stack all waveforms along dimension zero.
        """
        missing = [label for label in labels if label not in self.labels]
        if missing:
            raise ValueError(f"Labels not found in dataset: {missing}")
        # fix: map requested labels to their row in the on-disk traces array by
        # searching self.labels; the original searched the *requested* array, which
        # always yielded 0..n-1 and returned the wrong units' waveforms.
        label_idx = np.array([np.flatnonzero(self.labels == label)[0] for label in labels])

        wfs = self.traces[label_idx[:, None], indices].astype(np.float32)
        if flatten:
            wfs = wfs.reshape(-1, self.num_channels, self.spike_length_samples)

        info = self.table[self.table["cluster"].isin(labels)].copy()
        dfs = [
            info[(info["wf_number"].isin(indices[i])) & (info["cluster"] == label)]
            for i, label in enumerate(labels)
        ]
        info = pd.concat(dfs).reset_index(drop=True)

        channels = self.channels[info["linear_index"].to_numpy()].astype(int)

        n_nan = int(info["sample"].isna().sum())
        if n_nan > 0:
            logger.warning(f"{n_nan} NaN waveforms included in result.")
        if return_info:
            return wfs, info, channels

        logger.info("Use return_info=True and check the table for details.")

        return wfs

    def load_waveforms(self, labels=None, indices=None, return_info=True, flatten=False):
        """
        Returns a specified subset of waveforms from the dataset.

        :param labels: (list, NumPy array) Label ids (usually clusters) from which to get waveforms.
        :param indices: (list, NumPy array) Waveform indices to grab for each waveform.
            Can be 1D in which case the same indices are returned for each waveform, or
            2D with first dimension == len(labels) to return a specific set of indices for
            each waveform.
        :param return_info: If True, returns waveforms, table, channels, where table is a DF containing
            information about the waveforms returned, and channels is the channel map for each waveform.
        :param flatten: If True, returns all waveforms stacked along dimension zero, otherwise returns
            array of shape (num_labels, num_indices_per_label, num_channels, spike_length_samples)

        """
        if labels is None:
            labels = self.labels
        if indices is None:
            indices = np.arange(self.max_wf)

        labels = np.array(labels)
        indices = np.array(indices)

        # broadcast a 1D index selection to every requested label
        if indices.ndim == 1:
            indices = np.tile(indices, (labels.shape[0], 1))

        return self._gather(labels, indices, return_info, flatten)

    def random_waveforms(
        self,
        labels=None,
        num_random_labels=None,
        num_random_waveforms=None,
        return_info=True,
        seed=None,
        flatten=False
    ):
        """
        Returns a random subset of waveforms from the dataset.

        :param labels: (list, NumPy array) Label ids (usually clusters) from which to get waveforms.
            If None, 10 random labels are chosen.
        :param num_random_labels: If set, this number of random labels are chosen
        :param num_random_waveforms: If set, this number of random waveforms are chosen for each label.
            If None, 10 random waveforms are chosen for each label.
        :param return_info: If True, returns waveforms, table, channels, where table is a DF containing
            information about the waveforms returned, and channels is the channel map for each waveform.
        :param seed: Optional seed for the random generator, for reproducible selections.
        :param flatten: If True, returns all waveforms stacked along dimension zero, otherwise returns
            array of shape (num_labels, num_indices_per_label, num_channels, spike_length_samples)

        """
        rg = np.random.default_rng(seed=seed)

        if labels is None:
            if num_random_labels is None:
                # fix: sample without replacement (the original default branch used
                # replace=True and could return duplicate labels); cap at the number
                # of available labels so small datasets don't raise.
                labels = rg.choice(self.labels, min(10, self.num_labels), replace=False)
            else:
                labels = rg.choice(self.labels, num_random_labels, replace=False)
        else:
            assert num_random_labels is None, "labels and num_random_labels cannot both be set"

        labels = np.array(labels)
        num_labels = labels.shape[0]

        if num_random_waveforms is None:
            num_random_waveforms = 10

        # now select random non-NaN indices for each label; when a unit has fewer
        # valid waveforms than requested, pad with the (NaN) waveform numbers that
        # follow the valid ones so the output shape stays rectangular.
        indices = np.zeros((num_labels, num_random_waveforms), int)
        for u, label in enumerate(labels):
            _t = self.table[self.table["cluster"] == label]
            valid = _t[~_t["sample"].isna()]

            num_valid_waveforms = len(valid)
            if num_valid_waveforms >= num_random_waveforms:
                indices[u, :] = rg.choice(
                    valid.wf_number.to_numpy(),
                    num_random_waveforms,
                    replace=False
                )
                continue

            num_nan_waveforms = num_random_waveforms - num_valid_waveforms
            indices[u, :num_valid_waveforms] = rg.choice(
                valid.wf_number.to_numpy(),
                num_valid_waveforms,
                replace=False
            )
            indices[u, num_valid_waveforms:] = np.arange(num_valid_waveforms, num_valid_waveforms + num_nan_waveforms)

        return self._gather(labels, indices, return_info, flatten)
36 changes: 30 additions & 6 deletions src/tests/unit/cpu/test_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,22 +323,24 @@ def setUp(self):
data = np.tile(np.arange(0, 385), (1, self.ns)).astype(np.float32)
data.tofile(self.bin_file)

h = trace_header()
self.geom = np.c_[h["x"], h["y"]]
self.chan_map = utils.make_channel_index(self.geom)

def tearDown(self):
shutil.rmtree(self.tmpdir)

def _ground_truth_values(self):
h = trace_header()
geom = np.c_[h["x"], h["y"]]
chan_map = utils.make_channel_index(geom)
nc_extract = chan_map.shape[1]

nc_extract = self.chan_map.shape[1]
gt_templates = np.ones((self.n_clusters, nc_extract, self.ns_extract), np.float32) * np.nan
gt_waveforms = np.ones((self.n_clusters, self.max_wf, nc_extract, self.ns_extract), np.float32) * np.nan

c0_chans = chan_map[100].astype(np.float32)
c0_chans = self.chan_map[100].astype(np.float32)
gt_templates[0, :, :] = np.tile(c0_chans, (self.ns_extract, 1)).T
gt_waveforms[0, :self.max_wf - 1, :, :] = np.tile(c0_chans, (self.max_wf - 1, self.ns_extract, 1)).swapaxes(1, 2)

c1_chans = chan_map[368].astype(np.float32)
c1_chans = self.chan_map[368].astype(np.float32)
c1_chans[c1_chans == 384] = np.nan
gt_templates[1, :, :] = np.tile(c1_chans, (self.ns_extract, 1)).T
gt_waveforms[1, :self.max_wf - 1, :, :] = np.tile(c1_chans, (self.max_wf - 1, self.ns_extract, 1)).swapaxes(1, 2)
Expand Down Expand Up @@ -367,3 +369,25 @@ def test_extract_waveforms_bin(self):

assert np.allclose(np.nan_to_num(gt_templates), np.nan_to_num(templates))
assert np.allclose(np.nan_to_num(gt_waveforms), np.nan_to_num(waveforms))

wfl = waveform_extraction.WaveformsLoader(self.tmpdir, max_wf=self.max_wf)

wfs = wfl.load_waveforms(return_info=False)
assert np.allclose(np.nan_to_num(waveforms), np.nan_to_num(wfs))

labels = np.array([1, 2])
indices = np.arange(10)

wfs, info, channels = wfl.load_waveforms(labels=labels, indices=indices)
# right waveforms
assert np.allclose(np.nan_to_num(waveforms[:, :10]), np.nan_to_num(wfs))
# right channels
assert np.all(channels == self.chan_map[info.peak_channel.astype(int).to_numpy()])

wfs, info, channels = wfl.load_waveforms(labels=labels, indices=np.array([[1, 2, 3], [5, 6, 7]]))

# right waveforms
assert np.allclose(np.nan_to_num(waveforms[0, [1, 2, 3]]), np.nan_to_num(wfs[0]))
assert np.allclose(np.nan_to_num(waveforms[1, [5, 6, 7]]), np.nan_to_num(wfs[1]))
# right channels
assert np.all(channels == self.chan_map[info.peak_channel.astype(int).to_numpy()])
Loading