Skip to content

Commit

Permalink
add preproc options
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-langfield committed Jul 2, 2024
1 parent edcc883 commit 3641c1e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 32 deletions.
71 changes: 40 additions & 31 deletions src/ibldsp/waveform_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def write_wfs_chunk(
trough_offset,
spike_length_samples,
reader_kwargs,
preprocess,
preprocess_steps,
):
"""
Parallel job to extract waveforms from chunk `i_chunk` of a recording `sr` and
Expand Down Expand Up @@ -204,37 +204,36 @@ def write_wfs_chunk(
s0 - offset:s1 + spike_length_samples - trough_offset, :-my_sr.nsync
]

if not preprocess:
wfs_mmap[wf_flat["index"], :, :] = extract_wfs_array(
snip, df, channel_neighbors, add_nan_trace=True
)[0]
return
if "butterworth" in preprocess_steps:
butter_kwargs = {"N": 3, "Wn": 300 / my_sr.fs * 2, "btype": "highpass"}
sos = scipy.signal.butter(**butter_kwargs, output="sos")
snip = scipy.signal.sosfiltfilt(sos, snip.T).T

if "phase_shift" in preprocess_steps:
snip = fshift(snip, geom_dict["sample_shift"], axis=0)

if "bad_channel_interpolation" in preprocess_steps:
snip = interpolate_bad_channels(
snip,
channel_labels,
geom_dict["x"],
geom_dict["y"],
)

if "car" in preprocess_steps:
k_kwargs = {
"ntr_pad": 60,
"ntr_tap": 0,
"lagc": int(my_sr.fs / 10),
"butter_kwargs": {"N": 3, "Wn": 0.01, "btype": "highpass"},
}
car_func = lambda dat: car(dat, **k_kwargs) # noqa: E731
snip = car_func(snip)

# create filters
butter_kwargs = {"N": 3, "Wn": 300 / my_sr.fs * 2, "btype": "highpass"}
sos = scipy.signal.butter(**butter_kwargs, output="sos")
k_kwargs = {
"ntr_pad": 60,
"ntr_tap": 0,
"lagc": int(my_sr.fs / 10),
"butter_kwargs": {"N": 3, "Wn": 0.01, "btype": "highpass"},
}
car_func = lambda dat: car(dat, **k_kwargs) # noqa: E731

snip0 = interpolate_bad_channels(
fshift(
scipy.signal.sosfiltfilt(sos, snip.T), geom_dict["sample_shift"], axis=1
),
channel_labels,
geom_dict["x"],
geom_dict["y"],
)
# car
snip1 = np.full((my_sr.nc, snip0.shape[1]), np.nan)
snip1[:-1, :] = car_func(snip0)
wfs_mmap[wf_flat["index"], :, :] = extract_wfs_array(
snip1.T, df, channel_neighbors
snip, df, channel_neighbors, add_nan_trace=True
)[0]

wfs_mmap.flush()


Expand All @@ -253,7 +252,7 @@ def extract_waveforms_cbin(
reader_kwargs={},
n_jobs=None,
wfs_dtype=np.float32,
preprocess=False,
preprocess_steps=[],
):
"""
Given a bin file and locations of spikes, extract waveforms for each unit, compute
Expand Down Expand Up @@ -303,6 +302,16 @@ def extract_waveforms_cbin(
"""
n_jobs = n_jobs or int(cpu_count() / 2)

assert set(preprocess_steps).issubset(
{
"phase_shift",
"bad_channel_interpolation",
"butterworth",
"car",
"k_filt"
}
)

sr = spikeglx.Reader(bin_file, **reader_kwargs)
if h is None:
h = sr.geometry
Expand Down Expand Up @@ -367,7 +376,7 @@ def extract_waveforms_cbin(
trough_offset,
spike_length_samples,
reader_kwargs,
preprocess
preprocess_steps,
)
for i in range(num_chunks)
)
Expand Down
3 changes: 2 additions & 1 deletion src/tests/unit/cpu/test_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ def test_extract_waveforms_bin(self):
self.spike_channels,
reader_kwargs={"ns": self.ns, "nc": self.nc, "nsync": 1, "dtype": "float32"},
max_wf=self.max_wf,
h=trace_header()
h=trace_header(),
preprocess_steps=[],
)
templates = np.load(self.tmpdir.joinpath("waveforms.templates.npy"))
waveforms = np.load(self.tmpdir.joinpath("waveforms.traces.npy"))
Expand Down

0 comments on commit 3641c1e

Please sign in to comment.