Skip to content

Commit

Permalink
docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-langfield committed Aug 1, 2024
1 parent 9db80a5 commit 1bdf8e3
Showing 1 changed file with 63 additions and 2 deletions.
65 changes: 63 additions & 2 deletions src/ibldsp/waveform_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,34 @@ def extract_wfs_cbin(

class WaveformsLoader:

"""
Interface to the output of `extract_wfs_cbin`. Requires the following four files to
exist in `data_dir`:
- waveforms.traces.npy: `(num_units, max_wf, nc, spike_length_samples)`
This file contains the lightly processed waveforms indexed by cluster in the first
dimension. By default `max_wf=256, nc=40, spike_length_samples=128`.
- waveforms.templates.npy: `(num_units, nc, spike_length_samples)`
This file contains the median across individual waveforms for each unit.
- waveforms.channels.npz: `(num_units * max_wf, nc)`
The i'th row contains the ordered indices of the `nc`-channel neighborhood used
to extract the i'th waveform. A NaN means the waveform is missing because the
unit it was supposed to come from has less than `max_wf` spikes total in the
recording.
- waveforms.table.pqt: `num_units * max_wf` rows
For each waveform, gives the absolute sample number from the recording (i.e.
where to find it in `spikes.samples`), peak channel, cluster, and linear index.
A row of -1s implies that the waveform is missing because the unit is was supposed
to come from has less than `max_wf` spikes total.
WaveformsLoader.load_waveforms() and random_waveforms() allow selection of a subset of
waveforms.
"""

def __init__(
self,
data_dir,
Expand Down Expand Up @@ -528,8 +556,28 @@ def __repr__(self):

return s1 + s2 + s3

def load_waveforms(self, labels=None, indices=None, return_info=True, flatten=False):
@property
def wf_counts(self):
"""
pandas Series containing number of (non-NaN) waveforms for each label.
"""
return self.table.groupby("cluster").count()["sample"].rename("num_wfs")

def load_waveforms(self, labels=None, indices=None, return_info=True, flatten=False):
"""
Returns a specified subset of waveforms from the dataset.
:param labels: (list, NumPy array) Label ids (usually clusters) from which to get waveforms.
:param indices: (list, NumPy array) Waveform indices to grab for each waveform.
Can be 1D in which case the same indices are returned for each waveform, or
2D with first dimension == len(labels) to return a specific set of indices for
each waveform.
:param return_info: If True, returns waveforms, table, channels, where table is a DF containing
information about the waveforms returned, and channels is the channel map for each waveform.
:param flatten: If True, returns all waveforms stacked along dimension zero, otherwise returns
array of shape (num_labels, num_indices_per_label, num_channels, spike_length_samples)
"""
if labels is None:
labels = self.labels
if indices is None:
Expand Down Expand Up @@ -577,7 +625,20 @@ def random_waveforms(
seed=None,
flatten=False
):

"""
Returns a random subset of waveforms from the dataset.
:param labels: (list, NumPy array) Label ids (usually clusters) from which to get waveforms.
If None, 10 random labels are chosen.
:param num_random_labels: If set, this number of random labels are chosen
:param num_random_waveforms: If set, this number of random waveforms are chosen for each label.
If None, 10 random waveforms are chosen for each label.
:param return_info: If True, returns waveforms, table, channels, where table is a DF containing
information about the waveforms returned, and channels is the channel map for each waveform.
:param flatten: If True, returns all waveforms stacked along dimension zero, otherwise returns
array of shape (num_labels, num_indices_per_label, num_channels, spike_length_samples)
"""
rg = np.random.default_rng(seed=seed)

if labels is None:
Expand Down

0 comments on commit 1bdf8e3

Please sign in to comment.