diff --git a/src/ibldsp/waveform_extraction.py b/src/ibldsp/waveform_extraction.py index e17fd21..1b18f3c 100644 --- a/src/ibldsp/waveform_extraction.py +++ b/src/ibldsp/waveform_extraction.py @@ -477,6 +477,34 @@ def extract_wfs_cbin( class WaveformsLoader: + """ + Interface to the output of `extract_wfs_cbin`. Requires the following four files to + exist in `data_dir`: + + - waveforms.traces.npy: `(num_units, max_wf, nc, spike_length_samples)` + This file contains the lightly processed waveforms indexed by cluster in the first + dimension. By default `max_wf=256, nc=40, spike_length_samples=128`. + + - waveforms.templates.npy: `(num_units, nc, spike_length_samples)` + This file contains the median across individual waveforms for each unit. + + - waveforms.channels.npz: `(num_units * max_wf, nc)` + The i'th row contains the ordered indices of the `nc`-channel neighborhood used + to extract the i'th waveform. A NaN means the waveform is missing because the + unit it was supposed to come from has fewer than `max_wf` spikes total in the + recording. + + - waveforms.table.pqt: `num_units * max_wf` rows + For each waveform, gives the absolute sample number from the recording (i.e. + where to find it in `spikes.samples`), peak channel, cluster, and linear index. + A row of -1s implies that the waveform is missing because the unit it was supposed + to come from has fewer than `max_wf` spikes total. + + WaveformsLoader.load_waveforms() and random_waveforms() allow selection of a subset of + waveforms. + + """ + def __init__( self, data_dir, @@ -528,8 +556,28 @@ def __repr__(self): return s1 + s2 + s3 - def load_waveforms(self, labels=None, indices=None, return_info=True, flatten=False): + @property + def wf_counts(self): + """ + pandas Series containing number of (non-NaN) waveforms for each label. 
+ """ + return self.table.groupby("cluster").count()["sample"].rename("num_wfs") + def load_waveforms(self, labels=None, indices=None, return_info=True, flatten=False): + """ + Returns a specified subset of waveforms from the dataset. + + :param labels: (list, NumPy array) Label ids (usually clusters) from which to get waveforms. + :param indices: (list, NumPy array) Waveform indices to grab for each label. + Can be 1D, in which case the same indices are returned for each label, or + 2D with first dimension == len(labels) to return a specific set of indices for + each label. + :param return_info: If True, returns waveforms, table, channels, where table is a DF containing + information about the waveforms returned, and channels is the channel map for each waveform. + :param flatten: If True, returns all waveforms stacked along dimension zero, otherwise returns + array of shape (num_labels, num_indices_per_label, num_channels, spike_length_samples) + + """ if labels is None: labels = self.labels if indices is None: @@ -577,7 +625,20 @@ def random_waveforms( seed=None, flatten=False ): - + """ + Returns a random subset of waveforms from the dataset. + + :param labels: (list, NumPy array) Label ids (usually clusters) from which to get waveforms. + If None, 10 random labels are chosen. + :param num_random_labels: If set, this number of random labels are chosen + :param num_random_waveforms: If set, this number of random waveforms are chosen for each label. + If None, 10 random waveforms are chosen for each label. + :param return_info: If True, returns waveforms, table, channels, where table is a DF containing + information about the waveforms returned, and channels is the channel map for each waveform. + :param flatten: If True, returns all waveforms stacked along dimension zero, otherwise returns + array of shape (num_labels, num_indices_per_label, num_channels, spike_length_samples) + + """ rg = np.random.default_rng(seed=seed) if labels is None: