Skip to content

Commit

Permalink
ENH: Add STFT function to Function class (#620)
Browse files Browse the repository at this point in the history
Squash of the following commits:
* Add STFT function to Function class
* Added feature: Short-Time Fourier Transform function
* Added feature: Short-Time Fourier Transform function
* "Variable name changes in stft"
* "Variable and function name formatting"
* "Better Example"
* Add STFT function to Function class
* Added feature: Short-Time Fourier Transform function
* Added feature: Short-Time Fourier Transform function
* "Variable name changes in stft"
* "Variable and function name formatting"
* "Better Example"
* Fixed the doctest
* "Spectrogram example"
* small fixes to STFT function

---------

Signed-off-by: AdvaitChandorkar07 <[email protected]>
Co-authored-by: Gui-FernandesBR <[email protected]>
  • Loading branch information
AdvaitChandorkar07 and Gui-FernandesBR authored Aug 18, 2024
1 parent 8b4c14a commit e40eecc
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
"copybutton",
"cstride",
"csys",
"cumsum",
"datapoints",
"datetime",
"dcsys",
Expand Down Expand Up @@ -149,6 +150,7 @@
"IGRA",
"imageio",
"imread",
"imshow",
"intc",
"interp",
"Interquartile",
Expand Down Expand Up @@ -258,6 +260,7 @@
"SRTM",
"SRTMGL",
"Stano",
"STFT",
"subintervals",
"suptitle",
"ticklabel",
Expand Down
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ Attention: The newest changes should be on top -->

### Added

- ENH: Rocket Axis Definition [#635](https://github.com/RocketPy-Team/RocketPy/pull/635)
- ENH: Add STFT function to Function class [#620](https://github.com/RocketPy-Team/RocketPy/pull/620)
- ENH: Rocket Axis Definition [#635](https://github.com/RocketPy-Team/RocketPy/pull/635)

### Changed

Expand Down
136 changes: 136 additions & 0 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,142 @@ def to_frequency_domain(self, lower, upper, sampling_frequency, remove_dc=True):
extrapolation="zero",
)

def short_time_fft(
self,
lower,
upper,
sampling_frequency,
window_size,
step_size,
remove_dc=True,
only_positive=True,
):
r"""
Performs the Short-Time Fourier Transform (STFT) of the Function and
returns the result. The STFT is computed by applying the Fourier
transform to overlapping windows of the Function.
Parameters
----------
lower : float
Lower bound of the time range.
upper : float
Upper bound of the time range.
sampling_frequency : float
Sampling frequency at which to perform the Fourier transform.
window_size : float
Size of the window for the STFT, in seconds.
step_size : float
Step size for the window, in seconds.
remove_dc : bool, optional
If True, the DC component is removed from each window before
computing the Fourier transform.
only_positive: bool, optional
If True, only the positive frequencies are returned.
Returns
-------
list[Function]
A list of Functions, each representing the STFT of a window.
Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from rocketpy import Function
Generate a signal with varying frequency:
>>> T_x, N = 1 / 20 , 1000 # 20 Hz sampling rate for 50 s signal
>>> t_x = np.arange(N) * T_x # time indexes for signal
>>> f_i = 1 * np.arctan((t_x - t_x[N // 2]) / 2) + 5 # varying frequency
>>> signal = np.sin(2 * np.pi * np.cumsum(f_i) * T_x) # the signal
Create the Function object and perform the STFT:
>>> time_domain = Function(np.array([t_x, signal]).T)
>>> stft_result = time_domain.short_time_fft(
... lower=0,
... upper=50,
... sampling_frequency=95,
... window_size=2,
... step_size=0.5,
... )
Plot the spectrogram:
>>> Sx = np.abs([window[:, 1] for window in stft_result])
>>> t_lo, t_hi = t_x[0], t_x[-1]
>>> fig1, ax1 = plt.subplots(figsize=(10, 6))
>>> im1 = ax1.imshow(
... Sx.T,
... origin='lower',
... aspect='auto',
... extent=[t_lo, t_hi, 0, 50],
... cmap='viridis'
... )
>>> _ = ax1.set_title(rf"STFT (2$\,s$ Gaussian window, $\sigma_t=0.4\,$s)")
>>> _ = ax1.set(
... xlabel=f"Time $t$ in seconds",
... ylabel=f"Freq. $f$ in Hz)",
... xlim=(t_lo, t_hi)
... )
>>> _ = ax1.plot(t_x, f_i, 'r--', alpha=.5, label='$f_i(t)$')
>>> _ = fig1.colorbar(im1, label="Magnitude $|S_x(t, f)|$")
>>> # Shade areas where window slices stick out to the side
>>> for t0_, t1_ in [(t_lo, 1), (49, t_hi)]:
... _ = ax1.axvspan(t0_, t1_, color='w', linewidth=0, alpha=.2)
>>> # Mark signal borders with vertical line
>>> for t_ in [t_lo, t_hi]:
... _ = ax1.axvline(t_, color='y', linestyle='--', alpha=0.5)
>>> # Add legend and finalize plot
>>> _ = ax1.legend()
>>> fig1.tight_layout()
>>> # plt.show() # uncomment to show the plot
References
----------
Example adapted from the SciPy documentation:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.ShortTimeFFT.html
"""
# Get the time domain data
sampling_time_step = 1.0 / sampling_frequency
sampling_range = np.arange(lower, upper, sampling_time_step)
sampled_points = self(sampling_range)
samples_per_window = int(window_size * sampling_frequency)
samples_skipped_per_step = int(step_size * sampling_frequency)
stft_results = []

max_start = len(sampled_points) - samples_per_window + 1

for start in range(0, max_start, samples_skipped_per_step):
windowed_samples = sampled_points[start : start + samples_per_window]
if remove_dc:
windowed_samples -= np.mean(windowed_samples)
fourier_amplitude = np.abs(
np.fft.fft(windowed_samples) / (samples_per_window / 2)
)
fourier_frequencies = np.fft.fftfreq(samples_per_window, sampling_time_step)

# Filter to keep only positive frequencies if specified
if only_positive:
positive_indices = fourier_frequencies > 0
fourier_frequencies = fourier_frequencies[positive_indices]
fourier_amplitude = fourier_amplitude[positive_indices]

stft_results.append(
Function(
source=np.array([fourier_frequencies, fourier_amplitude]).T,
inputs="Frequency (Hz)",
outputs="Amplitude",
interpolation="linear",
extrapolation="zero",
)
)

return stft_results

def low_pass_filter(self, alpha, file_path=None):
"""Implements a low pass filter with a moving average filter. This does
not mutate the original Function object, but returns a new one with the
Expand Down

0 comments on commit e40eecc

Please sign in to comment.