Skip to content

Commit

Permalink
fix numpy errors in waveforms tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-langfield committed Jul 2, 2024
1 parent 935149d commit edcc883
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/tests/unit/cpu/test_ibldsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def test_timestamps(self):

class TestParabolicMax(unittest.TestCase):
# expected values
maxi = np.array([np.NaN, 0, 3.04166667, 3.04166667, 5, 5])
ipeak = np.array([np.NaN, 0, 5.166667, 2.166667, 0, 7])
maxi = np.array([np.nan, 0, 3.04166667, 3.04166667, 5, 5])
ipeak = np.array([np.nan, 0, 5.166667, 2.166667, 0, 7])
# input
x = np.array(
[
[0, 0, 0, 0, 0, np.NaN, 0, 0], # some NaNs
[0, 0, 0, 0, 0, np.nan, 0, 0], # some nans
[0, 0, 0, 0, 0, 0, 0, 0], # all flat
[0, 0, 0, 0, 1, 3, 2, 0],
[0, 1, 3, 2, 0, 0, 0, 0],
Expand Down
8 changes: 5 additions & 3 deletions src/tests/unit/cpu/test_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ def test_wave_shift_corrmax(self):
for sample_shift in sample_shifts:
for sig_len in sig_lens:
spike = scipy.signal.morlet2(sig_len, 8.0, 2.0)
spike = -np.fft.irfft(np.fft.rfft(spike) * np.exp(1j * 45 / 180 * np.pi))
spike = -np.fft.irfft(np.fft.rfft(np.real(spike)) * np.exp(1j * 45 / 180 * np.pi))
import pdb
pdb.set_trace

spike2 = fshift(spike, sample_shift)
spike3, shift_computed = waveforms.wave_shift_corrmax(spike, spike2)
Expand All @@ -259,7 +261,7 @@ def test_wave_shift_phase(self):
# Resynch in time spike2 onto spike
sample_shift_original = 0.323
spike = scipy.signal.morlet2(100, 8.5, 2.0)
spike = -np.fft.irfft(np.fft.rfft(spike) * np.exp(1j * 45 / 180 * np.pi))
spike = -np.fft.irfft(np.fft.rfft(np.real(spike)) * np.exp(1j * 45 / 180 * np.pi))
spike = np.append(spike, np.zeros((1, 25)))
spike2 = fshift(spike, sample_shift_original)
# Resynch
Expand All @@ -270,7 +272,7 @@ def test_wave_shift_waveform(self):
sample_shift_original = 15.32
# Create peak channel spike
spike_peak = scipy.signal.morlet2(100, 8.5, 2.0) # 100 time samples
spike_peak = -np.fft.irfft(np.fft.rfft(spike_peak) * np.exp(1j * 45 / 180 * np.pi))
spike_peak = -np.fft.irfft(np.fft.rfft(np.real(spike_peak)) * np.exp(1j * 45 / 180 * np.pi))
# Create other channel spikes
spike_oth = spike_peak * 0.3
# Create shifted spike
Expand Down

0 comments on commit edcc883

Please sign in to comment.