From 63fc63effd358d95cbae4acc7fca3196eb1009a8 Mon Sep 17 00:00:00 2001 From: owinter Date: Mon, 2 Sep 2024 15:28:40 +0100 Subject: [PATCH 1/2] modify wiggle plot to allow filling of positive and/or negative peaks --- src/ibldsp/waveforms.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/ibldsp/waveforms.py b/src/ibldsp/waveforms.py index b05882d..094b9c1 100644 --- a/src/ibldsp/waveforms.py +++ b/src/ibldsp/waveforms.py @@ -231,16 +231,24 @@ def find_tip_trough(arr_peak, arr_peak_real, df): return df, arr_peak -def plot_wiggle(wav, fs=1, ax=None, scalar=0.3, clip=1.5, **axkwargs): +def plot_wiggle(wav, fs=1, ax=None, scalar=0.3, clip=10, fill_sign=-1, plot_kwargs=None, fill_kwargs=None): """ Displays a multi-trace waveform in a wiggle traces with negative amplitudes filled :param wav: (nchannels, nsamples) - :param axkwargs: keyword arguments to feed to ax.set() + :param fs: sampling rate + :param ax: axis to plot on + :param scalar: scale factor for the traces + :param clip: maximum value for the traces + :param fill_sign: -1 for negative (default for spikes), 1 for positive + :param plot_kwargs: kwargs for the line plot + :param fill_kwargs: kwargs for the fill :return: """ if ax is None: fig, ax = plt.subplots() + plot_kwargs = {'color': 'k', 'linewidth': 0.5} | (plot_kwargs or {}) + fill_kwargs = {'color': 'k', 'aa': True} | (fill_kwargs or {}) nc, ns = wav.shape vals = np.c_[wav, wav[:, :1] * np.nan].ravel() # flat view of the 2d array. vect = np.arange(vals.size).astype( @@ -262,11 +270,15 @@ def plot_wiggle(wav, fs=1, ax=None, scalar=0.3, clip=1.5, **axkwargs): order = np.argsort(y) # shift from amplitudes to plotting coordinates x_shift, y = y[order].__divmod__(ns + 1) - ax.plot(y / fs, x[order] + x_shift + 1, 'k', linewidth=.5) - x[x > 0] = np.nan + print(plot_kwargs) + ax.plot(y / fs, x[order] + x_shift + 1, **plot_kwargs) + if fill_sign < 0: + x[x > 0] = np.nan + else: + x[x < 0] = np.nan x = x[order] + x_shift + 1 - ax.fill(y / fs, x, 'k', aa=True) - ax.set(xlim=[0, ns / fs], ylim=[0, nc], xlabel='sample', ylabel='trace') + ax.fill(y / fs, x, **fill_kwargs) + ax.set(xlim=[0, ns / fs], ylim=[0, nc]) plt.tight_layout() return ax From 6bf1f5c797dc95e1b3f615ef6fbf0b68301a8dbd Mon Sep 17 00:00:00 2001 From: owinter Date: Tue, 10 Sep 2024 13:30:47 +0100 Subject: [PATCH 2/2] double wiggle: double trouble. And tests. --- src/ibldsp/waveforms.py | 35 ++++++++++++++++++++++++---- src/tests/unit/cpu/test_waveforms.py | 18 ++++++++++---- 2 files changed, 45 insertions(+), 8 deletions(-) diff --git a/src/ibldsp/waveforms.py b/src/ibldsp/waveforms.py index 094b9c1..edae543 100644 --- a/src/ibldsp/waveforms.py +++ b/src/ibldsp/waveforms.py @@ -6,7 +6,9 @@ import numpy as np import pandas as pd import matplotlib.pyplot as plt +import matplotlib as mpl import scipy + from ibldsp.utils import parabolic_max from ibldsp.fourier import fshift @@ -231,19 +233,19 @@ def find_tip_trough(arr_peak, arr_peak_real, df): return df, arr_peak -def plot_wiggle(wav, fs=1, ax=None, scalar=0.3, clip=10, fill_sign=-1, plot_kwargs=None, fill_kwargs=None): +def plot_wiggle(wav, fs=1, ax=None, scale=0.3, clip=10, fill_sign=-1, plot_kwargs=None, fill_kwargs=None): """ Displays a multi-trace waveform in a wiggle traces with negative amplitudes filled :param wav: (nchannels, nsamples) :param fs: sampling rate :param ax: axis to plot on - :param scalar: scale factor for the traces + :param scale: waveform amplitude that will be displayed as one inter-trace: if scale = 20e-6 one intertrace will be 20uV :param clip: maximum value for the traces :param fill_sign: -1 for negative (default for spikes), 1 for positive :param plot_kwargs: kwargs for the line plot :param fill_kwargs: kwargs for the fill - :return: + :return: axis """ if ax is None: fig, ax = plt.subplots() @@ -263,7 +265,7 @@ def plot_wiggle(wav, fs=1, ax=None, scalar=0.3, clip=10, fill_sign=-1, plot_kwar m = (y2 - y1) / (x2 - x1) c = y1 - m * x1 # tack these values onto the end of the existing data - x = np.hstack([vals, np.zeros_like(c)]) * scalar + x = np.hstack([vals, np.zeros_like(c)]) / scale x = np.maximum(np.minimum(x, clip), -clip) y = np.hstack([vect, c]) # resort the data @@ -283,6 +285,31 @@ def plot_wiggle(wav, fs=1, ax=None, scalar=0.3, clip=10, fill_sign=-1, plot_kwar return ax +def double_wiggle(wav, fs=1, ax=None, colors=None, **kwargs): + """ + Double trouble: this wiggle colours both the negative and the postive values + :param wav: (nchannels, nsamples) + :param fs: sampling rate + :param ax: axis to plot on + :param scale: scale factor for the traces + :param clip: maximum value for the traces + :param fill_sign: -1 for negative (default for spikes), 1 for positive + :param plot_kwargs: kwargs for the line plot + :param fill_kwargs: kwargs for the fill + :return: + """ + if colors is None: + cmap = 'PuOr' + _cmap = mpl.colormaps.get_cmap(cmap) + colors = _cmap(np.linspace(0, 1, 256)) + colors = [colors[50], colors[-50]] + if ax is None: + fig, ax = plt.subplots() + plot_wiggle(wav, fs=fs / 1e3, ax=ax, plot_kwargs={'linewidth': 0}, fill_kwargs={'color': colors[0]}, **kwargs) + plot_wiggle(wav, fs=fs / 1e3, ax=ax, fill_sign=1, plot_kwargs={'linewidth': 0.5}, fill_kwargs={'color': colors[1]}, **kwargs) + return ax + + def plot_peaktiptrough(df, arr, ax, nth_wav=0, plot_grey=True, fs=30000): # Time axix nech, ntr = arr[nth_wav].shape diff --git a/src/tests/unit/cpu/test_waveforms.py b/src/tests/unit/cpu/test_waveforms.py index e97b397..765cefd 100644 --- a/src/tests/unit/cpu/test_waveforms.py +++ b/src/tests/unit/cpu/test_waveforms.py @@ -1,9 +1,12 @@ from pathlib import Path +import shutil +import tempfile +import unittest import numpy as np import pandas as pd -import tempfile -import shutil +import matplotlib.pyplot as plt +import scipy import ibldsp.utils as utils import ibldsp.waveforms as waveforms @@ -11,9 +14,7 @@ from neurowaveforms.model import generate_waveform from neuropixel import trace_header from ibldsp.fourier import fshift -import scipy -import unittest TEST_PATH = Path(__file__).parent.joinpath("fixtures") @@ -391,3 +392,12 @@ def test_extract_waveforms_bin(self): assert np.allclose(np.nan_to_num(waveforms[1, [5, 6, 7]]), np.nan_to_num(wfs[1])) # right channels assert np.all(channels == self.chan_map[info.peak_channel.astype(int).to_numpy()]) + + +def test_wiggle(): + wav = generate_waveform() + wav = wav / np.max(np.abs(wav)) * 120 * 1e-6 + fig, ax = plt.subplots(1, 2) + waveforms.plot_wiggle(wav, scale=40 * 1e-6, ax=ax[0]) + waveforms.double_wiggle(wav, scale=40 * 1e-6, fs=30_000, ax=ax[1]) + plt.close('all')