Skip to content

Commit

Permalink
seed
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-langfield committed Jul 9, 2024
1 parent 3273b14 commit d7f659e
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/ibldsp/waveform_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def _make_wfs_table(
max_wf=256,
trough_offset=42,
spike_length_samples=128,
seed=None
):
"""
Given a recording `sr` and spike detections, pick up to `max_wf`
Expand All @@ -127,7 +128,7 @@ def _make_wfs_table(
allowed_idx = (spike_samples > trough_offset) & (
spike_samples < sr.ns - (spike_length_samples - trough_offset)
)
rng = np.random.default_rng(seed=2024) # numpy 1.23.5
rng = np.random.default_rng(seed=seed) # numpy 1.23.5

unit_ids = np.unique(spike_clusters)
nu = unit_ids.shape[0]
Expand Down Expand Up @@ -220,7 +221,6 @@ def write_wfs_chunk(
geom_dict["y"],
)


k_kwargs = {
"ntr_pad": 60,
"ntr_tap": 0,
Expand All @@ -232,10 +232,9 @@ def write_wfs_chunk(
snip = car_func(snip)

if "kfilt" in preprocess_steps:
kfilt_func = lambda dat: kfilt(dat, **k_kwargs) # noqa: E731
kfilt_func = lambda dat: kfilt(dat, **k_kwargs) # noqa: E731
snip = kfilt_func(snip)


wfs_mmap[wf_flat["index"], :, :] = extract_wfs_array(
snip, df, channel_neighbors, add_nan_trace=True
)[0]
Expand All @@ -259,6 +258,7 @@ def extract_wfs_cbin(
n_jobs=None,
wfs_dtype=np.float32,
preprocess_steps=[],
seed=None
):
"""
Given a bin file and locations of spikes, extract waveforms for each unit, compute
Expand Down Expand Up @@ -339,6 +339,7 @@ def extract_wfs_cbin(
max_wf,
trough_offset,
spike_length_samples,
seed,
)
num_chunks = s0_arr.shape[0]

Expand All @@ -349,7 +350,7 @@ def extract_wfs_cbin(
logger.info("Running channel detection")
channel_labels = _get_channel_labels(sr)
else:
channel_labels = channel_labels or np.zeros(sr.nc-sr.nsync)
channel_labels = channel_labels or np.zeros(sr.nc - sr.nsync)

nwf = len(wf_flat)
nu = unit_ids.shape[0]
Expand Down

0 comments on commit d7f659e

Please sign in to comment.