Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

modify wiggle plot to allow filling of positive and/or negative peaks #42

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 47 additions & 8 deletions src/ibldsp/waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import scipy

from ibldsp.utils import parabolic_max
from ibldsp.fourier import fshift

Expand Down Expand Up @@ -231,16 +233,24 @@ def find_tip_trough(arr_peak, arr_peak_real, df):
return df, arr_peak


def plot_wiggle(wav, fs=1, ax=None, scalar=0.3, clip=1.5, **axkwargs):
def plot_wiggle(wav, fs=1, ax=None, scale=0.3, clip=10, fill_sign=-1, plot_kwargs=None, fill_kwargs=None):
"""
Displays a multi-trace waveform in a wiggle traces with negative
amplitudes filled
:param wav: (nchannels, nsamples)
:param axkwargs: keyword arguments to feed to ax.set()
:return:
:param fs: sampling rate
:param ax: axis to plot on
:param scale: waveform amplitude that will be displayed as one inter-trace: if scale = 20e-6 one intertrace will be 20uV
:param clip: maximum value for the traces
:param fill_sign: -1 for negative (default for spikes), 1 for positive
:param plot_kwargs: kwargs for the line plot
:param fill_kwargs: kwargs for the fill
:return: axis
"""
if ax is None:
fig, ax = plt.subplots()
plot_kwargs = {'color': 'k', 'linewidth': 0.5} | (plot_kwargs or {})
fill_kwargs = {'color': 'k', 'aa': True} | (fill_kwargs or {})
nc, ns = wav.shape
vals = np.c_[wav, wav[:, :1] * np.nan].ravel() # flat view of the 2d array.
vect = np.arange(vals.size).astype(
Expand All @@ -255,22 +265,51 @@ def plot_wiggle(wav, fs=1, ax=None, scalar=0.3, clip=1.5, **axkwargs):
m = (y2 - y1) / (x2 - x1)
c = y1 - m * x1
# tack these values onto the end of the existing data
x = np.hstack([vals, np.zeros_like(c)]) * scalar
x = np.hstack([vals, np.zeros_like(c)]) / scale
x = np.maximum(np.minimum(x, clip), -clip)
y = np.hstack([vect, c])
# resort the data
order = np.argsort(y)
# shift from amplitudes to plotting coordinates
x_shift, y = y[order].__divmod__(ns + 1)
ax.plot(y / fs, x[order] + x_shift + 1, 'k', linewidth=.5)
x[x > 0] = np.nan
print(plot_kwargs)
ax.plot(y / fs, x[order] + x_shift + 1, **plot_kwargs)
if fill_sign < 0:
x[x > 0] = np.nan
else:
x[x < 0] = np.nan
x = x[order] + x_shift + 1
ax.fill(y / fs, x, 'k', aa=True)
ax.set(xlim=[0, ns / fs], ylim=[0, nc], xlabel='sample', ylabel='trace')
ax.fill(y / fs, x, **fill_kwargs)
ax.set(xlim=[0, ns / fs], ylim=[0, nc])
plt.tight_layout()
return ax


def double_wiggle(wav, fs=1, ax=None, colors=None, **kwargs):
"""
Double trouble: this wiggle colours both the negative and the postive values
:param wav: (nchannels, nsamples)
:param fs: sampling rate
:param ax: axis to plot on
:param scale: scale factor for the traces
:param clip: maximum value for the traces
:param fill_sign: -1 for negative (default for spikes), 1 for positive
:param plot_kwargs: kwargs for the line plot
:param fill_kwargs: kwargs for the fill
:return:
"""
if colors is None:
cmap = 'PuOr'
_cmap = mpl.colormaps.get_cmap(cmap)
colors = _cmap(np.linspace(0, 1, 256))
colors = [colors[50], colors[-50]]
if ax is None:
fig, ax = plt.subplots()
plot_wiggle(wav, fs=fs / 1e3, ax=ax, plot_kwargs={'linewidth': 0}, fill_kwargs={'color': colors[0]}, **kwargs)
plot_wiggle(wav, fs=fs / 1e3, ax=ax, fill_sign=1, plot_kwargs={'linewidth': 0.5}, fill_kwargs={'color': colors[1]}, **kwargs)
return ax


def plot_peaktiptrough(df, arr, ax, nth_wav=0, plot_grey=True, fs=30000):
# Time axix
nech, ntr = arr[nth_wav].shape
Expand Down
18 changes: 14 additions & 4 deletions src/tests/unit/cpu/test_waveforms.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from pathlib import Path
import shutil
import tempfile
import unittest

import numpy as np
import pandas as pd
import tempfile
import shutil
import matplotlib.pyplot as plt
import scipy

import ibldsp.utils as utils
import ibldsp.waveforms as waveforms
import ibldsp.waveform_extraction as waveform_extraction
from neurowaveforms.model import generate_waveform
from neuropixel import trace_header
from ibldsp.fourier import fshift
import scipy

import unittest

TEST_PATH = Path(__file__).parent.joinpath("fixtures")

Expand Down Expand Up @@ -391,3 +392,12 @@ def test_extract_waveforms_bin(self):
assert np.allclose(np.nan_to_num(waveforms[1, [5, 6, 7]]), np.nan_to_num(wfs[1]))
# right channels
assert np.all(channels == self.chan_map[info.peak_channel.astype(int).to_numpy()])


def test_wiggle():
wav = generate_waveform()
wav = wav / np.max(np.abs(wav)) * 120 * 1e-6
fig, ax = plt.subplots(1, 2)
waveforms.plot_wiggle(wav, scale=40 * 1e-6, ax=ax[0])
waveforms.double_wiggle(wav, scale=40 * 1e-6, fs=30_000, ax=ax[1])
plt.close('all')
Loading