From a642578a2477fbaaef78d5ae40ed578964c2cf10 Mon Sep 17 00:00:00 2001 From: chris-langfield Date: Tue, 16 Jul 2024 18:53:30 -0400 Subject: [PATCH] pandas fixes and return channels --- src/ibldsp/waveform_extraction.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/ibldsp/waveform_extraction.py b/src/ibldsp/waveform_extraction.py index 692ac9e..3aca533 100644 --- a/src/ibldsp/waveform_extraction.py +++ b/src/ibldsp/waveform_extraction.py @@ -504,11 +504,12 @@ def __init__( assert self.channels_fp.exists(), "waveforms.channels.npz file missing!" # ingest parquet table - self.table = pd.read_parquet(self.table_fp) + self.table = pd.read_parquet(self.table_fp).reset_index(drop=True).drop(columns=["index"]) self.num_labels = self.table["cluster"].nunique() self.labels = np.array(self.table["cluster"].unique()) self.total_wfs = sum(~self.table["peak_channel"].isna()) self.table["wf_number"] = np.tile(np.arange(self.max_wf), self.num_labels) + self.table["linear_index"] = np.arange(len(self.table)) traces_shape = (self.num_labels, max_wf, num_channels, spike_length_samples) templates_shape = (self.num_labels, num_channels, spike_length_samples) @@ -517,7 +518,14 @@ def __init__( self.templates = np.memmap(self.templates_fp, dtype=np.float32, shape=templates_shape) self.channels = np.load(self.channels_fp, allow_pickle="True")["channels"] - def load_waveforms(self, labels=None, indices=None, return_info=True): + def __repr__(self): + s1 = f"WaveformsLoader with {self.total_wfs} waveforms in {self.wfs_dtype} from {self.num_labels} labels.\n" + s2 = f"Data path: {self.data_dir}\n" + s3 = f"{self.spike_length_samples} samples, {self.num_channels} channels, {self.max_wf} max waveforms per label\n" + + return s1 + s2 + s3 + + def load_waveforms(self, labels=None, indices=None, return_info=True, flatten=True): if labels is None: labels = self.labels @@ -526,10 +534,14 @@ def load_waveforms(self, labels=None, indices=None, return_info=True): wfs = self.traces[np.array(labels)][:, indices, :, :] + if flatten: + wfs = wfs.reshape(-1, self.num_channels, self.spike_length_samples).astype(np.float32) + if return_info: _table = self.table[self.table["cluster"].isin(labels)].copy() - _table = _table[_table["wf_number"].isin(indices)] - return wfs, _table + _table = _table[_table["wf_number"].isin(indices)].reset_index(drop=True) + channels = self.channels[_table["linear_index"].to_numpy()].astype(int) + return wfs, _table, channels return wfs @@ -539,7 +551,8 @@ def random_waveforms( num_random_labels=None, num_random_waveforms=None, return_info=True, - seed=None + seed=None, + flatten=True ): rg = np.random.default_rng(seed=seed) @@ -555,6 +568,8 @@ def random_waveforms( wfs = self.traces[np.array(labels)][:, :, :, :] + +