From 6e5af6f912f33c77615ef11e1fb455ecd30b0af9 Mon Sep 17 00:00:00 2001 From: Tim Widrick Date: Sat, 20 Mar 2021 16:46:04 -0400 Subject: [PATCH] BUG: fix dsp.waterfall to use 'which=None' properly --- pyyeti/dsp.py | 12 +++++++----- tests/test_dsp.py | 10 ++++++++++ 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/pyyeti/dsp.py b/pyyeti/dsp.py index 8a04bea..821bd97 100644 --- a/pyyeti/dsp.py +++ b/pyyeti/dsp.py @@ -2062,10 +2062,9 @@ def waterfall( Setting `which` to None is not the same as setting it to 0. Using None means that the function only returns amplitudes, while a 0 indicates that the output of `func` - must be indexed by 0 to get the amplitudes. An example - might make this more clear: if the function has ``return - amps``, use ``which=None``; if the function has ``return - (amps,)``, use ``which=0``. + must be indexed by 0 to get the amplitudes. For example: if + the function has ``return amps``, use ``which=None``; if + the function has ``return (amps,)``, use ``which=0``. freq : integer or vector If integer, it is the index of the output of `func` @@ -2217,9 +2216,12 @@ def slicefunc(a): raise ValueError("`which` cannot be None when `freq` is an integer") freq = res[freq] flen = len(freq) + res_dtype = res[which].dtype else: flen = len(freq) - mp = np.zeros((flen, tlen), res[which].dtype) + res_dtype = res.dtype if which is None else res[which].dtype + + mp = np.zeros((flen, tlen), res_dtype) mp[:, 0] = res[which] for j in range(1, tlen): diff --git a/tests/test_dsp.py b/tests/test_dsp.py index 4de1e48..49c03e1 100644 --- a/tests/test_dsp.py +++ b/tests/test_dsp.py @@ -103,6 +103,16 @@ def func(s): mp, t, f = dsp.waterfall(sig, sr, 2, 0.5, func, which=0, freq=1) mp, t, f = dsp.waterfall(sig2, sr, 2, 0.5, func, which=0, freq=1) + + def func2(s): + return srs.srs(s, sr, frq, Q) + + mp2, t2, f2 = dsp.waterfall(sig2, sr, 2, 0.5, func2, which=None, freq=frq) + + assert np.allclose(mp, mp2) + assert np.allclose(t, t2) + assert np.allclose(f, f2) + assert_raises(ValueError, dsp.waterfall, sig, sr, 2, 0.5, func, which=None, freq=1) assert_raises( ValueError, dsp.waterfall, sig, sr, 2, 1.5, func, which=None, freq=frq