Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-langfield committed Aug 1, 2024
1 parent 0553bc2 commit 9db80a5
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 51 deletions.
89 changes: 44 additions & 45 deletions src/ibldsp/waveform_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,18 +474,19 @@ def extract_wfs_cbin(
chan_map[dummy_idx] = channel_neighbors[peak_channel[dummy_idx].astype(int)]
np.savez(channels_fn, channels=chan_map)


class WaveformsLoader:

def __init__(
self,
data_dir,
max_wf=256,
trough_offset=42,
spike_length_samples=128,
num_channels=40,
wfs_dtype=np.float32
):
self,
data_dir,
max_wf=256,
trough_offset=42,
spike_length_samples=128,
num_channels=40,
wfs_dtype=np.float32
):

self.data_dir = Path(data_dir)
self.max_wf = max_wf
self.trough_offset = trough_offset
Expand All @@ -505,6 +506,8 @@ def __init__(

# ingest parquet table
self.table = pd.read_parquet(self.table_fp).reset_index(drop=True).drop(columns=["index"])
self.table["sample"] = self.table["sample"].astype("Int64")
self.table["peak_channel"] = self.table["peak_channel"].astype("Int64")
self.num_labels = self.table["cluster"].nunique()
self.labels = np.array(self.table["cluster"].unique())
self.total_wfs = sum(~self.table["peak_channel"].isna())
Expand All @@ -525,45 +528,56 @@ def __repr__(self):

return s1 + s2 + s3

def load_waveforms(self, labels=None, indices=None, return_info=True, flatten=False):
    """
    Load stored waveforms for the requested cluster labels and waveform numbers.

    :param labels: Cluster label, or array-like of cluster labels, to load.
        Defaults to every label in ``self.labels``.
    :param indices: Waveform numbers to load. A 1D array is applied to every
        requested label; a 2D array supplies one row of waveform numbers per
        label. Defaults to ``np.arange(self.max_wf)``.
    :param return_info: If True, also return the matching rows of
        ``self.table`` and the channel rows looked up through their
        ``linear_index`` column.
    :param flatten: If True, merge the label and waveform axes so the result
        has shape ``(-1, self.num_channels, self.spike_length_samples)``.
    :return: ``wfs`` (float32) when ``return_info`` is False, otherwise the
        tuple ``(wfs, info, channels)``.
    :raises ValueError: If a requested label is not present in ``self.labels``.
    """
    if labels is None:
        labels = self.labels
    if indices is None:
        indices = np.arange(self.max_wf)

    # atleast_1d lets a single scalar label be passed directly.
    labels = np.atleast_1d(np.asarray(labels))
    indices = np.asarray(indices)

    known_labels = np.asarray(self.labels)
    unknown = labels[~np.isin(labels, known_labels)]
    if unknown.size:
        raise ValueError(f"Unknown cluster label(s): {unknown.tolist()}")

    # Position of each requested label along the first axis of self.traces.
    # The lookup must go through self.labels: searching the requested `labels`
    # array itself always yields 0..n-1 and silently loads the wrong clusters
    # whenever a subset or a reordering of the labels is requested.
    # NOTE(review): assumes the first axis of self.traces is ordered like
    # self.labels -- confirm against the extraction code.
    label_idx = np.array(
        [np.flatnonzero(known_labels == label)[0] for label in labels],
        dtype=int,
    )

    num_labels = labels.shape[0]

    # Broadcast a shared 1D selection to one row of waveform numbers per label.
    if indices.ndim == 1:
        indices = np.tile(indices, (num_labels, 1))

    wfs = self.traces[label_idx[:, None], indices].astype(np.float32)

    if flatten:
        wfs = wfs.reshape(-1, self.num_channels, self.spike_length_samples)

    # Restrict the table to the requested clusters, then pick the requested
    # waveform numbers label by label so each label can use its own selection.
    # NOTE: rows within a label keep the table's order, so for unsorted
    # `indices` they may not line up one-to-one with the waveforms in `wfs`.
    candidates = self.table[self.table["cluster"].isin(labels)]
    per_label = []
    for row, label in enumerate(labels):
        wf_numbers = indices[row]
        keep = (candidates["wf_number"].isin(wf_numbers)) & (candidates["cluster"] == label)
        per_label.append(candidates[keep])
    if per_label:
        info = pd.concat(per_label).reset_index(drop=True)
    else:
        # pd.concat refuses an empty list; an empty selection is an empty table.
        info = candidates.reset_index(drop=True)

    channels = self.channels[info["linear_index"].to_numpy()].astype(int)

    n_nan = int(info["sample"].isna().sum())
    if n_nan > 0:
        logger.warning(f"{n_nan} NaN waveforms included in result.")
    if return_info:
        return wfs, info, channels

    logger.info("Use return_info=True and check the table for details.")

    return wfs

def random_waveforms(
self,
labels=None,
num_random_labels=None,
num_random_waveforms=None,
return_info=True,
seed=None,
flatten=True
flatten=False
):

rg = np.random.default_rng(seed=seed)

if labels is None:
Expand All @@ -575,6 +589,7 @@ def random_waveforms(
assert num_random_labels is None, "labels and num_random_labels cannot both be set"

labels = np.array(labels)
label_idx = np.array([np.where(labels == label)[0][0] for label in labels])

num_labels = labels.shape[0]

Expand All @@ -584,29 +599,29 @@ def random_waveforms(
# now select random non-NaN indices for each label
indices = np.zeros((num_labels, num_random_waveforms), int)
for u, label in enumerate(labels):
_t = self.table[self.table["cluster"]==label]
_t = self.table[self.table["cluster"] == label]
nanidx = _t["sample"].isna()
valid = _t[~nanidx]

num_valid_waveforms = len(valid)
if num_valid_waveforms >= num_random_waveforms:
indices[u, :] = rg.choice(
valid.wf_number.to_numpy(),
num_random_waveforms,
valid.wf_number.to_numpy(),
num_random_waveforms,
replace=False
)
continue

num_nan_waveforms = num_random_waveforms - num_valid_waveforms
indices[u, :num_valid_waveforms] = rg.choice(
valid.wf_number.to_numpy(),
num_valid_waveforms,
valid.wf_number.to_numpy(),
num_valid_waveforms,
replace=False
)

indices[u, num_valid_waveforms:] = np.arange(num_valid_waveforms, num_valid_waveforms + num_nan_waveforms)

wfs = self.traces[labels[:, None], indices].astype(np.float32)
wfs = self.traces[label_idx[:, None], indices].astype(np.float32)

if flatten:
wfs = wfs.reshape(-1, self.num_channels, self.spike_length_samples)
Expand All @@ -615,33 +630,17 @@ def random_waveforms(
dfs = []
for i, l in enumerate(labels):
_idx = indices[i]
dfs.append(info[(info["wf_number"].isin(_idx))&(info["cluster"]==l)])
info = pd.concat(dfs)
dfs.append(info[(info["wf_number"].isin(_idx)) & (info["cluster"] == l)])
info = pd.concat(dfs).reset_index(drop=True)

channels = self.channels[info["linear_index"].to_numpy()].astype(int)

n_nan = sum(info["sample"].isna())
if n_nan > 0:
logger.warn(f"{n_nan} NaN waveforms included in result.")
logger.warning(f"{n_nan} NaN waveforms included in result.")
if return_info:
return wfs, info, channels

logger.info("Use return_info=True and check the table for details.")

return wfs
















36 changes: 30 additions & 6 deletions src/tests/unit/cpu/test_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,22 +323,24 @@ def setUp(self):
data = np.tile(np.arange(0, 385), (1, self.ns)).astype(np.float32)
data.tofile(self.bin_file)

h = trace_header()
self.geom = np.c_[h["x"], h["y"]]
self.chan_map = utils.make_channel_index(self.geom)

def tearDown(self):
    # Delete the directory tree at self.tmpdir so files written during a test
    # (presumably created in setUp -- confirm) do not persist afterwards.
    shutil.rmtree(self.tmpdir)

def _ground_truth_values(self):
h = trace_header()
geom = np.c_[h["x"], h["y"]]
chan_map = utils.make_channel_index(geom)
nc_extract = chan_map.shape[1]

nc_extract = self.chan_map.shape[1]
gt_templates = np.ones((self.n_clusters, nc_extract, self.ns_extract), np.float32) * np.nan
gt_waveforms = np.ones((self.n_clusters, self.max_wf, nc_extract, self.ns_extract), np.float32) * np.nan

c0_chans = chan_map[100].astype(np.float32)
c0_chans = self.chan_map[100].astype(np.float32)
gt_templates[0, :, :] = np.tile(c0_chans, (self.ns_extract, 1)).T
gt_waveforms[0, :self.max_wf - 1, :, :] = np.tile(c0_chans, (self.max_wf - 1, self.ns_extract, 1)).swapaxes(1, 2)

c1_chans = chan_map[368].astype(np.float32)
c1_chans = self.chan_map[368].astype(np.float32)
c1_chans[c1_chans == 384] = np.nan
gt_templates[1, :, :] = np.tile(c1_chans, (self.ns_extract, 1)).T
gt_waveforms[1, :self.max_wf - 1, :, :] = np.tile(c1_chans, (self.max_wf - 1, self.ns_extract, 1)).swapaxes(1, 2)
Expand Down Expand Up @@ -367,3 +369,25 @@ def test_extract_waveforms_bin(self):

assert np.allclose(np.nan_to_num(gt_templates), np.nan_to_num(templates))
assert np.allclose(np.nan_to_num(gt_waveforms), np.nan_to_num(waveforms))

wfl = waveform_extraction.WaveformsLoader(self.tmpdir, max_wf=self.max_wf)

wfs = wfl.load_waveforms(return_info=False)
assert np.allclose(np.nan_to_num(waveforms), np.nan_to_num(wfs))

labels = np.array([1, 2])
indices = np.arange(10)

wfs, info, channels = wfl.load_waveforms(labels=labels, indices=indices)
# right waveforms
assert np.allclose(np.nan_to_num(waveforms[:, :10]), np.nan_to_num(wfs))
# right channels
assert np.all(channels == self.chan_map[info.peak_channel.astype(int).to_numpy()])

wfs, info, channels = wfl.load_waveforms(labels=labels, indices=np.array([[1, 2, 3], [5, 6, 7]]))

# right waveforms
assert np.allclose(np.nan_to_num(waveforms[0, [1, 2, 3]]), np.nan_to_num(wfs[0]))
assert np.allclose(np.nan_to_num(waveforms[1, [5, 6, 7]]), np.nan_to_num(wfs[1]))
# right channels
assert np.all(channels == self.chan_map[info.peak_channel.astype(int).to_numpy()])

0 comments on commit 9db80a5

Please sign in to comment.