diff --git a/release_notes.md b/release_notes.md
index 0253a63..07f5073 100644
--- a/release_notes.md
+++ b/release_notes.md
@@ -1,4 +1,7 @@
 # 0.10.0
+## 0.10.2 2024-04-10
+- Add `waveform_extraction` module to `ibldsp`. This includes the `extract_wfs_array` and `extract_wfs_cbin` functions.
+- Add code for performing subsample shifts of waveforms.
 ## 0.10.1 2024-03-19
 - ensure compatibility with spikeglx 202309 metadata coordinates
 ## 0.10.0 2024-03-14
diff --git a/setup.py b/setup.py
index 9f1667a..edb0434 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@ setuptools.setup(
     name="ibl-neuropixel",
-    version="0.10.1",
+    version="0.10.2",
     author="The International Brain Laboratory",
     description="Collection of tools for Neuropixel 1.0 and 2.0 probes data",
     long_description=long_description,
diff --git a/src/ibldsp/utils.py b/src/ibldsp/utils.py
index 55962bf..bf24e20 100644
--- a/src/ibldsp/utils.py
+++ b/src/ibldsp/utils.py
@@ -215,8 +215,17 @@ def rms(x, axis=-1):


 def make_channel_index(geom, radius=200.0, pad_val=384):
+    """
+    Given a neuropixels geometry dict `geom`, returns an array with nc rows,
+    where the i'th row contains the channel ids that fall within `radius` um
+    of channel i. The number of columns is the maximum number of neighbors a
+    channel can have and will depend on the geometry and the radius chosen.
+
+    For channels at the edges of the probe, which have fewer than the maximum
+    possible number of neighbors, the remaining indices in the row are filled with `pad_val`.
+    """
     neighbors = (
-        scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(geom)) < radius
+        scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(geom)) <= radius
     )
     n_nbors = np.max(np.sum(neighbors, 0))
diff --git a/src/ibldsp/waveform_extraction.py b/src/ibldsp/waveform_extraction.py
new file mode 100644
index 0000000..2c0e66d
--- /dev/null
+++ b/src/ibldsp/waveform_extraction.py
@@ -0,0 +1,354 @@
+import scipy
+import pandas as pd
+import numpy as np
+from numpy.lib.format import open_memmap
+import neuropixel
+import spikeglx
+
+from joblib import Parallel, delayed, cpu_count
+
+from ibldsp.voltage import detect_bad_channels, interpolate_bad_channels, car
+from ibldsp.fourier import fshift
+from ibldsp.utils import make_channel_index
+
+
+def extract_wfs_array(
+    arr,
+    df,
+    channel_neighbors,
+    trough_offset=42,
+    spike_length_samples=128,
+    add_nan_trace=False,
+    verbose=False,
+):
+    """
+    Extract waveforms at specified samples and peak channels
+    as a stack.
+
+    :param arr: Array of traces. (samples, channels). The last trace of the array should be a
+        row of non-data NaNs. If this has not been added, set the `add_nan_trace` flag.
+    :param df: df containing "sample" and "peak_channel" columns.
+    :param channel_neighbors: Channel neighbor matrix (384x384)
+    :param trough_offset: Number of samples to include before peak.
+        (defaults to 42)
+    :param spike_length_samples: Total length of wf in samples.
+        (defaults to 128)
+    :param add_nan_trace: Whether to add a row of `NaN`s as the last trace.
+        (If False, code assumes this has already been added)
+    """
+    # This is to do fast index assignment to assign missing channels (out of the probe) to NaN
+    if add_nan_trace:
+        newcol = np.empty((arr.shape[0], 1))
+        newcol[:] = np.nan
+        arr = np.hstack([arr, newcol])
+
+    # check that the spike window is included in the recording:
+    last_idx = df["sample"].iloc[-1]
+    assert (
+        last_idx + (spike_length_samples - trough_offset) < arr.shape[0]
+    ), f"Spike index {last_idx} extends past end of recording ({arr.shape[0]} samples)."
+
+    nwf = len(df)
+
+    # Get channel indices
+    cind = channel_neighbors[df["peak_channel"].to_numpy()]
+
+    # Get sample indices
+    sind = df["sample"].to_numpy()[:, np.newaxis] + (
+        np.arange(spike_length_samples) - trough_offset
+    )
+    nchan = cind.shape[1]
+
+    wfs = np.zeros((nwf, spike_length_samples, nchan), arr.dtype)
+    fun = range
+    if verbose:
+        try:
+            from tqdm import trange
+
+            fun = trange
+        except ImportError:
+            pass
+    for i in fun(nwf):
+        wfs[i, :, :] = arr[sind[i], :][:, cind[i]]
+
+    return wfs.swapaxes(1, 2), cind, trough_offset
+
+
+def _get_channel_labels(sr, num_snippets=20, verbose=True):
+    """
+    Given a spikeglx Reader object, samples `num_snippets` 1-second
+    segments of the recording and returns, for each channel, the most
+    common (mode) channel label across the segments.
+    """
+    # fall back to a plain range when progress bars are disabled
+    if verbose:
+        from tqdm import trange as fun
+    else:
+        fun = range
+
+    start = (np.linspace(100, int(sr.rl) - 100, num_snippets) * sr.fs).astype(int)
+    end = start + int(sr.fs)
+
+    _channel_labels = np.zeros((384, num_snippets), int)
+
+    for i in fun(num_snippets):
+        s0 = start[i]
+        s1 = end[i]
+        arr = sr[s0:s1, : -sr.nsync].T
+        _channel_labels[:, i] = detect_bad_channels(arr, fs=30_000)[0]
+
+    channel_labels = scipy.stats.mode(_channel_labels, axis=1, keepdims=True)[0].T
+
+    return channel_labels
+
+
+def _make_wfs_table(
+    sr,
+    spike_times,
+    spike_clusters,
+    spike_channels,
+    chunksize_t=10,
+    max_wf=256,
+    trough_offset=42,
+    spike_length_samples=128,
+):
+    """
+    Given a recording `sr` and spike detections, pick up to `max_wf`
+    waveforms uniformly for each unit and return their times, peak channels,
+    and unit assignments.
+
+    :return: wf_flat, unit_ids: a dataframe of waveform information and an array of unit ids.
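+        `wf_flat` has one row per selected waveform, sorted by sample time, with
+        columns "indices", "samples", "clusters" and "channels".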
+ """ + # exclude spikes without a buffer on either end + # of recording + allowed_idx = (spike_times > trough_offset) & ( + spike_times < sr.ns - (spike_length_samples - trough_offset) + ) + + rng = np.random.default_rng(seed=2024) # numpy 1.23.5 + + unit_ids = np.unique(spike_clusters) + nu = unit_ids.shape[0] + + # this array contains the (up to) max_wf *indices* of the wfs + # we are going to extract for that unit + unit_wf_idx = np.zeros((nu, max_wf), int) + unit_nspikes = np.zeros(nu, int) + for i, u in enumerate(unit_ids): + u_spikeidx = np.where((spike_clusters == u) & allowed_idx)[0] + nspikes = u_spikeidx.shape[0] + unit_nspikes[i] = nspikes + # uniformly select up to 500 spikes + u_wf_idx = rng.choice(u_spikeidx, min(max_wf, nspikes)) + unit_wf_idx[u, : min(max_wf, nspikes)] = u_wf_idx + + # all wf indices in order + wf_idx = np.sort(unit_wf_idx.flatten()) + # remove initial zeros + wf_idx = wf_idx[np.nonzero(wf_idx)[0][0]:] + + # get sample times, clusters, channels + + wf_flat = pd.DataFrame( + { + "indices": np.arange(wf_idx.shape[0]), + "samples": spike_times[wf_idx].astype(int), + "clusters": spike_clusters[wf_idx].astype(int), + "channels": spike_channels[wf_idx].astype(int), + } + ) + + return wf_flat, unit_ids + + +def write_wfs_chunk( + i_chunk, + cbin, + wfs_fn, + mmap_shape, + geom_dict, + channel_labels, + channel_neighbors, + wf_flat, + sr_sl, + chunksize_samples, + trough_offset, + spike_length_samples, +): + """ + Parallel job to extract waveforms from chunk `i_chunk` of a recording `sr` and + write them to the correct spot in the output .npy file `wfs_fn`. + """ + my_sr = spikeglx.Reader(cbin) + s0, s1 = sr_sl + + wfs_mmap = open_memmap(wfs_fn, shape=mmap_shape, mode="r+", dtype=np.float32) + + # create filters + butter_kwargs = {"N": 3, "Wn": 300 / my_sr.fs * 2, "btype": "highpass"} + sos = scipy.signal.butter(**butter_kwargs, output="sos") + k_kwargs = { + "ntr_pad": 60, + "ntr_tap": 0, + "lagc": int(my_sr.fs / 10), + "butter_kwargs": {"N": 3, "Wn": 0.01, "btype": "highpass"}, + } + car_func = lambda dat: car(dat, **k_kwargs) # noqa: E731 + + if i_chunk == 0: + offset = 0 + else: + offset = trough_offset + + sample = wf_flat["samples"].astype(int) + offset - i_chunk * chunksize_samples + peak_channel = wf_flat["channels"] + + df = pd.DataFrame({"sample": sample, "peak_channel": peak_channel}) + + snip = my_sr[ + s0 - offset: s1 + spike_length_samples - trough_offset, : -my_sr.nsync + ] + snip0 = interpolate_bad_channels( + fshift( + scipy.signal.sosfiltfilt(sos, snip.T), geom_dict["sample_shift"], axis=1 + ), + channel_labels, + geom_dict["x"], + geom_dict["y"], + ) + # car + snip1 = np.full((my_sr.nc, snip0.shape[1]), np.nan) + snip1[:-1, :] = car_func(snip0) + wfs_mmap[wf_flat["indices"], :, :] = extract_wfs_array( + snip1.T, df, channel_neighbors + )[0] + wfs_mmap.flush() + + +def extract_wfs_cbin( + cbin_file, + output_file, + spike_times, + spike_clusters, + spike_channels, + h=None, + wf_extract_params=None, + nprocesses=None, +): + """ + Given a cbin file and locations of spikes, extract waveforms for each unit, compute + the templates, and save to `output_file`. + + If `output_file=Path("/path/to/example_clusters.npy")`, this array will be of shape + `(num_units, max_wf, nc, spike_length_samples)` where by default `max_wf=256, nc=40, + spike_length_samples=128`. + + The file "path/to/example_clusters_templates.npy" will also be generated, of shape + `(num_units, nc, spike_length_samples)`, where the median across waveforms is taken + for each unit. 
+
+    The parquet file "path/to/example_clusters.pqt" contains the samples and peak channels
+    of each waveform, indexed by unit.
+    """
+    if h is None:
+        h = neuropixel.trace_header()
+
+    if wf_extract_params is None:
+        wf_extract_params = {
+            "max_wf": 256,
+            "trough_offset": 42,
+            "spike_length_samples": 128,
+            "chunksize_t": 10,
+        }
+
+    output_path = output_file.parent
+
+    max_wf = wf_extract_params["max_wf"]
+    trough_offset = wf_extract_params["trough_offset"]
+    spike_length_samples = wf_extract_params["spike_length_samples"]
+    chunksize_t = wf_extract_params["chunksize_t"]
+
+    sr = spikeglx.Reader(cbin_file)
+    chunksize_samples = chunksize_t * 30_000
+    s0_arr = np.arange(0, sr.ns, chunksize_samples)
+    s1_arr = s0_arr + chunksize_samples
+    s1_arr[-1] = sr.ns
+
+    wf_flat, unit_ids = _make_wfs_table(
+        sr, spike_times, spike_clusters, spike_channels, **wf_extract_params
+    )
+    num_chunks = s0_arr.shape[0]
+    print(f"Chunk size: {chunksize_t}s")
+    print(f"Num chunks: {num_chunks}")
+
+    print("Running channel detection")
+    channel_labels = _get_channel_labels(sr)
+
+    nwf = wf_flat["samples"].shape[0]
+    nu = unit_ids.shape[0]
+    print(f"Extracting {nwf} waveforms from {nu} units")
+
+    # get channel geometry
+    geom = np.c_[h["x"], h["y"]]
+    channel_neighbors = make_channel_index(geom)
+    nc = channel_neighbors.shape[1]
+
+    fn = output_path.joinpath("_wf_extract_intermediate.npy")
+    wfs = open_memmap(
+        fn, mode="w+", shape=(nwf, nc, spike_length_samples), dtype=np.float32
+    )
+
+    slices = [
+        slice(
+            *(np.searchsorted(wf_flat["samples"], [s0_arr[i], s1_arr[i]]).astype(int))
+        )
+        for i in range(num_chunks)
+    ]
+
+    nprocesses = nprocesses or int(cpu_count() - cpu_count() / 4)
+    _ = Parallel(n_jobs=nprocesses)(
+        delayed(write_wfs_chunk)(
+            i,
+            cbin_file,
+            fn,
+            wfs.shape,
+            h,
+            channel_labels,
+            channel_neighbors,
+            wf_flat.iloc[slices[i]],
+            (s0_arr[i], s1_arr[i]),
+            chunksize_samples,
+            trough_offset,
+            spike_length_samples,
+        )
+        for i in range(num_chunks)
+    )
+
+    wfs = open_memmap(
+        fn, mode="r+", shape=(nwf, nc, spike_length_samples), dtype=np.float32
+    )
+    # bookkeeping
+    wfs_by_unit = np.full(
+        (nu, max_wf, nc, spike_length_samples), np.nan, dtype=np.float16
+    )
+    wfs_medians = np.full((nu, nc, spike_length_samples), np.nan, dtype=np.float32)
+    print("Computing templates")
+    for i, u in enumerate(unit_ids):
+        _wfs_unit = wfs[wf_flat["clusters"] == u]
+        nwf_u = _wfs_unit.shape[0]
+        wfs_by_unit[i, : min(max_wf, nwf_u), :, :] = _wfs_unit.astype(np.float16)
+        wfs_medians[i, :, :] = np.nanmedian(_wfs_unit, axis=0)
+
+    df = pd.DataFrame(
+        {
+            "sample": wf_flat["samples"],
+            "peak_channel": wf_flat["channels"],
+            "cluster": wf_flat["clusters"],
+        }
+    )
+    df = df.sort_values(["cluster", "sample"]).set_index(["cluster", "sample"])
+
+    np.save(output_file, wfs_by_unit)
+    # medians
+    avg_file = output_file.parent.joinpath(output_file.stem + "_templates.npy")
+    np.save(avg_file, wfs_medians)
+    df.to_parquet(output_file.with_suffix(".pqt"))
+
+    fn.unlink()
diff --git a/src/ibldsp/waveforms.py b/src/ibldsp/waveforms.py
index d08c5a5..b05882d 100644
--- a/src/ibldsp/waveforms.py
+++ b/src/ibldsp/waveforms.py
@@ -6,6 +6,9 @@
 import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
+import scipy
+from ibldsp.utils import parabolic_max
+from ibldsp.fourier import fshift


 def _validate_arr_in(arr_in):
@@ -228,7 +231,7 @@ def find_tip_trough(arr_peak, arr_peak_real, df):
     return df, arr_peak


-def plot_wiggle(wav, ax=None, scalar=0.3, clip=1.5, **axkwargs):
+def plot_wiggle(wav, fs=1, ax=None,
 scalar=0.3, clip=1.5, **axkwargs):
     """
     Displays a multi-trace waveform as wiggle traces with negative amplitudes filled
@@ -259,11 +262,11 @@
     order = np.argsort(y)
     # shift from amplitudes to plotting coordinates
     x_shift, y = y[order].__divmod__(ns + 1)
-    ax.plot(y, x[order] + x_shift + 1, "k", linewidth=0.5)
+    ax.plot(y / fs, x[order] + x_shift + 1, 'k', linewidth=.5)
     x[x > 0] = np.nan
     x = x[order] + x_shift + 1
-    ax.fill(y, x, "k", aa=True)
-    ax.set(xlim=[0, ns], ylim=[0, nc], xlabel="sample", ylabel="trace")
+    ax.fill(y / fs, x, 'k', aa=True)
+    ax.set(xlim=[0, ns / fs], ylim=[0, nc], xlabel='sample', ylabel='trace')
     plt.tight_layout()
     return ax

@@ -609,61 +612,149 @@ def compute_spike_features(
     return df


-def extract_wfs_array(
-    arr,
-    df,
-    channel_neighbors,
-    trough_offset=42,
-    spike_length_samples=121,
-    add_nan_trace=False,
-):
-    """
-    Extract waveforms at specified samples and peak channels
-    as a stack.
-
-    :param arr: Array of traces. (samples, channels). The last trace of the array should be a
-    row of non-data NaNs. If this has not been added set the `add_nan_trace` flag.
-    :param df: df containing "sample" and "peak_channel" columns.
-    :param channel_neighbors: Channel neighbor matrix (384x384)
-    :param trough_offset: Number of samples to include before peak.
-    (defaults to 42)
-    :param spike_length_samples: Total length of wf in samples.
-    (defaults to 121)
-    :param add_nan_trace: Whether to add a row of `NaN`s as the last trace.
-    (If False, code assumes this has already been added)
-    """
-    # This is to do fast index assignment to assign missing channels (out of the probe) to NaN
-    if add_nan_trace:
-        newcol = np.empty((arr.shape[0], 1))
-        newcol[:] = np.nan
-        arr = np.hstack([arr, newcol])
-
-    # check that the spike window is included in the recording:
-    last_idx = df["sample"].iloc[-1]
-    assert (
-        last_idx + (spike_length_samples - trough_offset) < arr.shape[0]
-    ), f"Spike index {last_idx} extends past end of recording ({arr.shape[0]} samples)."
-
-    nwf = len(df)
-
-    # Get channel indices
-    cind = channel_neighbors[df["peak_channel"].to_numpy()]
-
-    # Get sample indices
-    sind = df["sample"].to_numpy()[:, np.newaxis] + (
-        np.arange(spike_length_samples) - trough_offset
-    )
-    nchan = cind.shape[1]
-
-    wfs = np.zeros((nwf, spike_length_samples, nchan), arr.dtype)
-
-    try:
-        from tqdm import trange
+def wave_shift_corrmax(spike, spike2):
+    '''
+    Shifts `spike2` in time (with sub-sample precision) so that it aligns to `spike`.
+    (For residual subtraction, `spike2` would typically be the template.)
+    :param spike: 1D array of float (e.g. on peak channel); same size as spike2
+    :param spike2: 1D array of float
+    :return: spike_resync: 1D array of float, shift_computed: shift in time samples (e.g. -4.03)
+    '''
+    # Numpy implementation of correlation centers it in the middle at np.floor(len_sample/2)
+    assert spike.shape[0] == spike2.shape[0]
+    sig_len = spike.shape[0]
+    c = scipy.signal.correlate(spike, spike2, mode='same')
+    ipeak, maxi = parabolic_max(c)
+    shift_computed = (ipeak - np.floor(sig_len / 2)) * -1
+    spike_resync = fshift(spike2, -shift_computed)
+    return spike_resync, shift_computed
+
+# -------------------------------------------------------------
+# Functions to fit the phase slope, and find the relationship between phase slope and sample shift
+
+
+def line_fit(x, a, b):  # function to fit a line and get the slope out
+    return a * x + b
+
+
+def get_apf_from2spikes(spike, spike2, fs):
+    fscale = np.fft.rfftfreq(spike.size, 1 / fs)
+    C = np.fft.rfft(spike) * np.conj(np.fft.rfft(spike2))
+
+    # amplitude and unwrapped phase of the cross-spectrum
+    amp = np.abs(C)
+    phase = np.unwrap(np.angle(C))
+    return amp, phase, fscale
+
+
+def get_phase_slope(amp, phase, fscale, q=90):
+    # Take the q-th (default: 90th) percentile of the distribution to find high-amplitude frequencies
+    thresh_amp = np.percentile(amp, q)
+    indx_highamp = np.where(amp >= thresh_amp)[0]
+    # Perform linear fit to get the slope
+    popt, _ = scipy.optimize.curve_fit(line_fit, xdata=fscale[indx_highamp], ydata=phase[indx_highamp])
+    a, b = popt
+    return a, b
+
+
+def fit_phaseshift(phase_slopes, sample_shifts):
+    # Get parameters for the phase slope / sample shift curve
+    popt, _ = scipy.optimize.curve_fit(line_fit, xdata=sample_shifts, ydata=phase_slopes)
+    a, b = popt
+    return a, b
+
+
+def get_phase_from_fit(sample_shifts, a, b):
+    # phases = line_fit(np.abs(sample_shifts), a, b) * np.sign(sample_shifts)
+    phases = line_fit(sample_shifts, a, b)
+    return phases
+
+
+def get_shift_from_fit(phases, a, b):
+    # Invert the line function: x = (y-b)/a
+    sample_shifts = (phases - b) / a
+    return sample_shifts
+
+
+def get_spike_slopeparams(spike, fs, num_estim=50):
+    sample_shifts = np.linspace(-1, 1, num=num_estim)
+    phase_slopes = np.empty(shape=sample_shifts.shape)
+
+    for i_shift, sample_shift in enumerate(sample_shifts):
+        spike2 = fshift(spike, sample_shift)
+        # Get amplitude, phase, fscale
+        amp, phase, fscale = get_apf_from2spikes(spike, spike2, fs)
+        # Perform linear fit to get the slope
+        a, b = get_phase_slope(amp, phase, fscale)
+        phase_slopes[i_shift] = a
+
+    a_pslope, b_pslope = fit_phaseshift(phase_slopes, sample_shifts)
+    return a_pslope, b_pslope, sample_shifts, phase_slopes
+
+
+def wave_shift_phase(spike, spike2, fs, a_pos=None, b_pos=None):
+    '''
+    Resynch spike2 onto spike using the phase spectrum's slope
+    (this works perfectly in theory, but does not work well with raw data sampled at 30kHz!)
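+
+    :param spike: 1D array of float, the reference spike
+    :param spike2: 1D array of float, the spike to resynchronize onto `spike`
+    :param fs: sampling frequency (Hz)
+    :param a_pos, b_pos: optional precomputed slope/intercept of the phase-slope fit;
+        computed with `get_spike_slopeparams` when not provided
+    :return: spike_resync, sample_shift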
+    '''
+    # Get template parameters if not passed in
+    if a_pos is None or b_pos is None:
+        a_pos, b_pos, _, _ = get_spike_slopeparams(spike, fs)
+    # Get amplitude, phase, fscale
+    amp, phase, fscale = get_apf_from2spikes(spike, spike2, fs)
+    # Perform linear fit to get the slope
+    a, b = get_phase_slope(amp, phase, fscale)
+    phase_slope = a
+    # Get sample shift
+    sample_shift = get_shift_from_fit(phase_slope, a_pos, b_pos)
+
+    # Resynch in time given phase slope
+    spike_resync = fshift(spike2, -sample_shift)  # Use negative to re-synch
+    return spike_resync, sample_shift
+
+# End of functions
+# -------------------------------------------------------------
+
+
+def shift_waveform(wf_cluster):
+    '''
+    :param wf_cluster: A matrix of spike waveforms per cluster (N spike, trace, time)
+    :return: wf_out (same shape as wf_cluster): A matrix with the waveforms shifted in time
+    '''
+    # First take the median across spikes as the template to compute the shifts against
+    wfs_avg = np.nanmedian(wf_cluster, axis=0)
+    # Find the peak channel from template
+    template = np.transpose(wfs_avg.copy())  # wfs_avg is 2D (trace, time) -> transpose: (time, trace)
+    arr_temp = np.expand_dims(template, axis=0)  # dimensions have to be (wav, time, trace) -> add 1 dimension (ax=0)
+    df_temp = find_peak(arr_temp)
+    spike_template = arr_temp[:, :, df_temp['peak_trace_idx'][0]]  # Take template at peak trace
+    spike_template = np.squeeze(spike_template)
+
+    # Take the raw spikes at that channel
+    # Create df for all spikes
+    '''
+    Note: we deliberately do NOT recompute the peak channel of each waveform, but reuse the one from the
+    template. This is because the function that finds the peak assumes the waveform has been denoised
+    and uses the maximum amplitude value, which here would lead to failures in the case of collisions.
+    '''
+    df = pd.DataFrame()
+    df['peak_trace_idx'] = [df_temp['peak_trace_idx'][0]] * wf_cluster.shape[0]
-        fun = trange
-    except ImportError:
-        fun = range
-    for i in fun(nwf):
-        wfs[i, :, :] = arr[sind[i], :][:, cind[i]]
+
+    # Per waveform, keep only the trace that contains the peak
+    arr_in = np.swapaxes(wf_cluster, axis1=1, axis2=2)  # wfs size (wav, trace, time) -> swap (wav, time, trace)
+    arr_peak_real = get_array_peak(arr_in, df)
-    return wfs.swapaxes(1, 2), cind, trough_offset
+
+    # Resynch 1 spike with 1 template (using only the peak channel); apply the shift to all traces
+    wf_out = np.zeros(wf_cluster.shape)
+    shift_applied = np.zeros(wf_cluster.shape[0])
+    for i_spike in range(0, wf_cluster.shape[0]):
+        # Raw spike at peak channel
+        spike_raw = arr_peak_real[i_spike, :]
+        # Resynch
+        spike_template_resynch, shift_computed = wave_shift_corrmax(spike_raw, spike_template)
+        # Apply shift to all traces at once
+        wfs_avg_resync = fshift(wf_cluster[i_spike, :, :], shift_computed)
+        wf_out[i_spike, :, :] = wfs_avg_resync
+        shift_applied[i_spike] = shift_computed
+
+    return wf_out, shift_applied
diff --git a/src/neurowaveforms/model.py b/src/neurowaveforms/model.py
index 6e43e93..0dd4bfc 100644
--- a/src/neurowaveforms/model.py
+++ b/src/neurowaveforms/model.py
@@ -2,9 +2,7 @@
 from ibldsp.fourier import fshift


-def generate_waveform(
-    spike=None, sxy=None, wxy=None, fs=30000, vertical_velocity_mps=3
-):
+def generate_waveform(spike=None, sxy=None, wxy=None, fs=30000, vertical_velocity_mps=3, decay_exponent=3.0):
     """
     Generate a waveform from a spike and a set of coordinates
     :param spike: the single trace spike waveform
@@ -12,6 +10,7 @@
     :param wxy: ntraces by 3 np.array containing the generated traces coordinates
     :param fs: sampling frequency
    :param vertical_velocity_mps: vertical velocity of the spike in m/s
+    :param decay_exponent: exponent of the amplitude decay with distance (the amplitude scales as 1 / (r + 50) ** decay_exponent)
     :return: the generated waveform ns by ntraces
     """
     # spike coordinates
@@ -195,6 +194,6 @@
     r = np.sqrt(np.sum(np.square(sxy - wxy), axis=1))
     sample_shift = (wxy[:, 1] - np.mean(wxy[:, 1])) / 1e6 * vertical_velocity_mps * fs
     # spherical divergence
-    wav = spike * 1 / (r[..., np.newaxis] + 50) ** 2
-    wav = fshift(wav, sample_shift, axis=-1).T
+    wav = (spike * 1 / (r[..., np.newaxis] + 50) ** decay_exponent)
+    wav = fshift(wav, sample_shift, axis=-1)
     return wav
diff --git a/src/tests/unit/cpu/test_waveforms.py b/src/tests/unit/cpu/test_waveforms.py
index 0eb0bec..a59ce82 100644
--- a/src/tests/unit/cpu/test_waveforms.py
+++ b/src/tests/unit/cpu/test_waveforms.py
@@ -5,8 +5,11 @@
 import ibldsp.utils as utils
 import ibldsp.waveforms as waveforms
+import ibldsp.waveform_extraction as waveform_extraction
 from neurowaveforms.model import generate_waveform
 from neuropixel import trace_header
+from ibldsp.fourier import fshift
+import scipy

 import unittest

@@ -160,7 +163,7 @@
 def test_generate_waveforms():
     wav = generate_waveform()
-    assert wav.shape == (121, 40)
+    assert wav.shape == (40, 121)


 class TestWaveformExtractor(unittest.TestCase):
@@ -184,10 +187,10 @@
     geom = np.c_[geom_dict["x"], geom_dict["y"]]
     channel_neighbors = utils.make_channel_index(geom, radius=200.0)
-    # radius = 200um, 38 chans
-    num_channels = 38
+    # radius = 200um, 40 chans
+    num_channels = 40

     def test_extract_waveforms(self):
-        wfs, _, _ = waveforms.extract_wfs_array(
+        wfs, _, _ = waveform_extraction.extract_wfs_array(
             self.arr, self.df, self.channel_neighbors
         )
@@ -199,21 +202,22 @@ def test_extract_waveforms(self):
             np.isnan(wfs[0, self.num_channels // 2 + self.channels[0] + 1:, :])
         )

-        for i in range(1, 8):
+        for i in range(1, 9):
             # center channel depends on odd/even of channel
             if self.channels[i] % 2 == 0:
-                centered_channel_idx = 18
-            else:
                 centered_channel_idx = 19
+            else:
+                centered_channel_idx = 20
             assert wfs[i, centered_channel_idx, self.trough_offset] == float(i + 1)

         # last wf is a special case analogous to the first wf, but at the bottom
         # of the probe
         if self.channels[-1] % 2 == 0:
-            centered_channel_idx = 18
-        else:
             centered_channel_idx = 19
+        else:
+            centered_channel_idx = 20
         assert wfs[-1, centered_channel_idx, self.trough_offset] == 9.0
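+
+    # Note: `extract_wfs_array` asserts that every spike window
+    # [sample - trough_offset, sample + spike_length_samples - trough_offset)
+    # fits inside the input array; test_spike_window below exercises the failure mode.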
def test_spike_window(self): # check that we have an error when the last spike window @@ -221,14 +225,73 @@ def test_spike_window(self): df = self.df.copy() df["sample"].iloc[-1] = 996 with self.assertRaisesRegex(AssertionError, "extends"): - _ = waveforms.extract_wfs_array(self.arr, df, self.channel_neighbors) + _ = waveform_extraction.extract_wfs_array(self.arr, df, self.channel_neighbors) def test_nan_channel(self): # test that if user does not fill last column with NaNs # the user can set the flag and the result will be the same arr = self.arr.copy()[:, :-1] - wfs = waveforms.extract_wfs_array(self.arr, self.df, self.channel_neighbors) - wfs_nan = waveforms.extract_wfs_array( + wfs = waveform_extraction.extract_wfs_array(self.arr, self.df, self.channel_neighbors) + wfs_nan = waveform_extraction.extract_wfs_array( arr, self.df, self.channel_neighbors, add_nan_trace=True ) np.testing.assert_equal(wfs, wfs_nan) + + def test_wave_shift_corrmax(self): + sample_shifts = [4.43, -1.0] + sig_lens = [100, 101] + for sample_shift in sample_shifts: + for sig_len in sig_lens: + spike = scipy.signal.morlet2(sig_len, 8.0, 2.0) + spike = -np.fft.irfft(np.fft.rfft(spike) * np.exp(1j * 45 / 180 * np.pi)) + + spike2 = fshift(spike, sample_shift) + spike3, shift_computed = waveforms.wave_shift_corrmax(spike, spike2) + + np.testing.assert_equal(sample_shift, np.around(shift_computed, decimals=2)) + + def test_wave_shift_phase(self): + fs = 30000 + # Resynch in time spike2 onto spike + sample_shift_original = 0.323 + spike = scipy.signal.morlet2(100, 8.5, 2.0) + spike = -np.fft.irfft(np.fft.rfft(spike) * np.exp(1j * 45 / 180 * np.pi)) + spike = np.append(spike, np.zeros((1, 25))) + spike2 = fshift(spike, sample_shift_original) + # Resynch + spike_resync, sample_shift_applied = waveforms.wave_shift_phase(spike, spike2, fs) + np.testing.assert_equal(sample_shift_original, np.around(sample_shift_applied, decimals=3)) + + def test_wave_shift_waveform(self): + sample_shift_original = 15.32 + # Create peak channel spike + spike_peak = scipy.signal.morlet2(100, 8.5, 2.0) # 100 time samples + spike_peak = -np.fft.irfft(np.fft.rfft(spike_peak) * np.exp(1j * 45 / 180 * np.pi)) + # Create other channel spikes + spike_oth = spike_peak * 0.3 + # Create shifted spike + spike_peak2 = fshift(spike_peak, sample_shift_original) + spike_oth2 = fshift(spike_oth, sample_shift_original) + + # Create matrix N=512 wavs: 511 spikes the same, 1 with shifted (one spike will have 2 channels) + wav_normal = np.stack([spike_peak, spike_oth]) # size (trace, time) : (2, 100) + wav_shifted = np.stack([spike_peak2, spike_oth2]) + + n_wav = 511 + wav_rep = np.repeat(wav_normal[:, :, np.newaxis], n_wav, axis=2) + wav_all = np.dstack((wav_rep, wav_shifted)) # size (trace, time, N spike) : (2, 100, 512) + + # Change axis to (N spike, trace, time) : (512, 2, 100) + wav_cluster = np.swapaxes(wav_all, axis1=1, axis2=2) # (2, 512, 100) + wav_cluster = np.swapaxes(wav_cluster, axis1=1, axis2=0) + # The last wav (-1) has the shift after all this swapping - checked visually by plotting below + ''' + import matplotlib.pyplot as plt + fig, axs = plt.subplots(2, 1) + axs[0].imshow(-np.flipud(wav_cluster[0, :, :]), cmap="Grays") + axs[1].imshow(-np.flipud(wav_cluster[-1, :, :]), cmap="Grays") + ''' + wav_out, shift_applied = waveforms.shift_waveform(wav_cluster) + # Test last waveform shift applied is minus the original shift, and the rest 511 waveforms are 0 + np.testing.assert_equal(-sample_shift_original, np.around(shift_applied[-1], decimals=2)) + 
np.testing.assert_equal(np.zeros(n_wav), np.abs(np.around(shift_applied[0:-1], decimals=2)))
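
Usage note: a minimal end-to-end sketch of the new `extract_wfs_cbin` entry point.
The paths and spike-sorting arrays below are hypothetical placeholders; `spike_times`
is in samples, and the defaults extract up to 256 waveforms of 128 samples on 40
channels per unit.

    from pathlib import Path
    import numpy as np
    from ibldsp.waveform_extraction import extract_wfs_cbin

    # hypothetical inputs: a compressed recording plus spike-sorting output arrays
    cbin_file = Path("/path/to/recording.ap.cbin")
    output_file = Path("/path/to/example_clusters.npy")
    spike_times = np.load("/path/to/spike_samples.npy")      # spike times, in samples
    spike_clusters = np.load("/path/to/spike_clusters.npy")  # unit id per spike
    spike_channels = np.load("/path/to/peak_channels.npy")   # peak channel per spike

    extract_wfs_cbin(cbin_file, output_file, spike_times, spike_clusters, spike_channels)

    # outputs: per-unit waveforms (num_units, 256, 40, 128), per-unit median
    # templates in "*_templates.npy", and a parquet table indexed by unit
    wfs = np.load(output_file)
    templates = np.load(output_file.parent.joinpath(output_file.stem + "_templates.npy"))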