diff --git a/src/ibldsp/waveform_extraction.py b/src/ibldsp/waveform_extraction.py index b6c2e0c..0d2f84c 100644 --- a/src/ibldsp/waveform_extraction.py +++ b/src/ibldsp/waveform_extraction.py @@ -176,7 +176,7 @@ def write_wfs_chunk( trough_offset, spike_length_samples, reader_kwargs, - preprocess, + preprocess_steps, ): """ Parallel job to extract waveforms from chunk `i_chunk` of a recording `sr` and @@ -204,37 +204,36 @@ def write_wfs_chunk( s0 - offset:s1 + spike_length_samples - trough_offset, :-my_sr.nsync ] - if not preprocess: - wfs_mmap[wf_flat["index"], :, :] = extract_wfs_array( - snip, df, channel_neighbors, add_nan_trace=True - )[0] - return + if "butterworth" in preprocess_steps: + butter_kwargs = {"N": 3, "Wn": 300 / my_sr.fs * 2, "btype": "highpass"} + sos = scipy.signal.butter(**butter_kwargs, output="sos") + snip = scipy.signal.sosfiltfilt(sos, snip.T).T + + if "phase_shift" in preprocess_steps: + snip = fshift(snip, geom_dict["sample_shift"], axis=0) + + if "bad_channel_interpolation" in preprocess_steps: + snip = interpolate_bad_channels( + snip, + channel_labels, + geom_dict["x"], + geom_dict["y"], + ) + + if "car" in preprocess_steps: + k_kwargs = { + "ntr_pad": 60, + "ntr_tap": 0, + "lagc": int(my_sr.fs / 10), + "butter_kwargs": {"N": 3, "Wn": 0.01, "btype": "highpass"}, + } + car_func = lambda dat: car(dat, **k_kwargs) # noqa: E731 + snip = car_func(snip) - # create filters - butter_kwargs = {"N": 3, "Wn": 300 / my_sr.fs * 2, "btype": "highpass"} - sos = scipy.signal.butter(**butter_kwargs, output="sos") - k_kwargs = { - "ntr_pad": 60, - "ntr_tap": 0, - "lagc": int(my_sr.fs / 10), - "butter_kwargs": {"N": 3, "Wn": 0.01, "btype": "highpass"}, - } - car_func = lambda dat: car(dat, **k_kwargs) # noqa: E731 - - snip0 = interpolate_bad_channels( - fshift( - scipy.signal.sosfiltfilt(sos, snip.T), geom_dict["sample_shift"], axis=1 - ), - channel_labels, - geom_dict["x"], - geom_dict["y"], - ) - # car - snip1 = np.full((my_sr.nc, snip0.shape[1]), np.nan) - snip1[:-1, :] = car_func(snip0) wfs_mmap[wf_flat["index"], :, :] = extract_wfs_array( - snip1.T, df, channel_neighbors + snip, df, channel_neighbors, add_nan_trace=True )[0] + wfs_mmap.flush() @@ -253,7 +252,7 @@ def extract_waveforms_cbin( reader_kwargs={}, n_jobs=None, wfs_dtype=np.float32, - preprocess=False, + preprocess_steps=[], ): """ Given a bin file and locations of spikes, extract waveforms for each unit, compute @@ -303,6 +302,16 @@ def extract_waveforms_cbin( """ n_jobs = n_jobs or int(cpu_count() / 2) + assert set(preprocess_steps).issubset( + { + "phase_shift", + "bad_channel_interpolation", + "butterworth", + "car", + "k_filt" + } + ) + sr = spikeglx.Reader(bin_file, **reader_kwargs) if h is None: h = sr.geometry @@ -367,7 +376,7 @@ def extract_waveforms_cbin( trough_offset, spike_length_samples, reader_kwargs, - preprocess + preprocess_steps, ) for i in range(num_chunks) ) diff --git a/src/tests/unit/cpu/test_waveforms.py b/src/tests/unit/cpu/test_waveforms.py index 2223c41..2fbad4f 100644 --- a/src/tests/unit/cpu/test_waveforms.py +++ b/src/tests/unit/cpu/test_waveforms.py @@ -354,7 +354,8 @@ def test_extract_waveforms_bin(self): self.spike_channels, reader_kwargs={"ns": self.ns, "nc": self.nc, "nsync": 1, "dtype": "float32"}, max_wf=self.max_wf, - h=trace_header() + h=trace_header(), + preprocess_steps=[], ) templates = np.load(self.tmpdir.joinpath("waveforms.templates.npy")) waveforms = np.load(self.tmpdir.joinpath("waveforms.traces.npy"))