Skip to content

Commit

Permalink
pandas fixes and return channels
Browse files: browse the repository at this point in the history
  • Loading branch information
chris-langfield committed Jul 16, 2024
1 parent 4ac6f80 commit a642578
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions src/ibldsp/waveform_extraction.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -504,11 +504,12 @@ def __init__(
assert self.channels_fp.exists(), "waveforms.channels.npz file missing!"

# ingest parquet table
self.table = pd.read_parquet(self.table_fp)
self.table = pd.read_parquet(self.table_fp).reset_index(drop=True).drop(columns=["index"])
self.num_labels = self.table["cluster"].nunique()
self.labels = np.array(self.table["cluster"].unique())
self.total_wfs = sum(~self.table["peak_channel"].isna())
self.table["wf_number"] = np.tile(np.arange(self.max_wf), self.num_labels)
self.table["linear_index"] = np.arange(len(self.table))

traces_shape = (self.num_labels, max_wf, num_channels, spike_length_samples)
templates_shape = (self.num_labels, num_channels, spike_length_samples)
Expand All @@ -517,7 +518,14 @@ def __init__(
self.templates = np.memmap(self.templates_fp, dtype=np.float32, shape=templates_shape)
self.channels = np.load(self.channels_fp, allow_pickle="True")["channels"]

def load_waveforms(self, labels=None, indices=None, return_info=True):
def __repr__(self):
# Build a three-line, human-readable summary of this loader instance.
# Each piece ends with "\n", so the returned string ends with a newline.
# Line 1: waveform count (total_wfs), the waveform dtype attribute
# (wfs_dtype) and the number of labels (num_labels).
s1 = f"WaveformsLoader with {self.total_wfs} waveforms in {self.wfs_dtype} from {self.num_labels} labels.\n"
# Line 2: the data directory attribute (data_dir).
s2 = f"Data path: {self.data_dir}\n"
# Line 3: samples per spike, channel count, and the maximum number of
# waveforms per label (max_wf).
s3 = f"{self.spike_length_samples} samples, {self.num_channels} channels, {self.max_wf} max waveforms per label\n"

# Concatenate the three lines in order and return them as one string.
return s1 + s2 + s3

def load_waveforms(self, labels=None, indices=None, return_info=True, flatten=True):

if labels is None:
labels = self.labels
Expand All @@ -526,10 +534,14 @@ def load_waveforms(self, labels=None, indices=None, return_info=True):

wfs = self.traces[np.array(labels)][:, indices, :, :]

if flatten:
wfs = wfs.reshape(-1, self.num_channels, self.spike_length_samples).astype(np.float32)

if return_info:
_table = self.table[self.table["cluster"].isin(labels)].copy()
_table = _table[_table["wf_number"].isin(indices)]
return wfs, _table
_table = _table[_table["wf_number"].isin(indices)].reset_index(drop=True)
channels = self.channels[_table["linear_index"].to_numpy()].astype(int)
return wfs, _table, channels

return wfs

Expand All @@ -539,7 +551,8 @@ def random_waveforms(
num_random_labels=None,
num_random_waveforms=None,
return_info=True,
seed=None
seed=None,
flatten=True
):

rg = np.random.default_rng(seed=seed)
Expand All @@ -555,6 +568,8 @@ def random_waveforms(
wfs = self.traces[np.array(labels)][:, :, :, :]







Expand Down

0 comments on commit a642578

Please sign in to comment.