-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add numpy alternative to FE using torchaudio #26339
Changes from 10 commits
f36232b
fbc40d3
5ad48e0
73a4c06
f06db42
608644b
605ed14
b16b5a9
8af2313
5e98476
4aa7100
d2e2714
c18ee1a
4dd6207
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,12 +19,18 @@ | |
from typing import List, Optional, Union | ||
|
||
import numpy as np | ||
import torch | ||
import torchaudio.compliance.kaldi as ta_kaldi | ||
|
||
from ...audio_utils import mel_filter_bank, spectrogram, window_function | ||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor | ||
from ...feature_extraction_utils import BatchFeature | ||
from ...utils import TensorType, logging | ||
from ...utils import TensorType, is_speech_available, is_torch_available, logging | ||
|
||
|
||
if is_speech_available(): | ||
import torchaudio.compliance.kaldi as ta_kaldi | ||
|
||
if is_torch_available(): | ||
Comment on lines
+29
to
+32
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In speech-to-text we bundle these imports into one: if is_speech_available():
import torchaudio.compliance.kaldi as ta_kaldi
import torch Should we do the same here since we can only use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, |
||
import torch | ||
|
||
|
||
logger = logging.get_logger(__name__) | ||
|
@@ -37,8 +43,8 @@ class ASTFeatureExtractor(SequenceFeatureExtractor): | |
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains | ||
most of the main methods. Users should refer to this superclass for more information regarding those methods. | ||
|
||
This class extracts mel-filter bank features from raw speech using TorchAudio, pads/truncates them to a fixed | ||
length and normalizes them using a mean and standard deviation. | ||
This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy | ||
otherwise, pads/truncates them to a fixed length and normalizes them using a mean and standard deviation. | ||
|
||
Args: | ||
feature_size (`int`, *optional*, defaults to 1): | ||
|
@@ -83,6 +89,21 @@ def __init__( | |
self.std = std | ||
self.return_attention_mask = return_attention_mask | ||
|
||
if not is_speech_available(): | ||
mel_filters = mel_filter_bank( | ||
num_frequency_bins=256, | ||
num_mel_filters=self.num_mel_bins, | ||
min_frequency=20, | ||
max_frequency=sampling_rate // 2, | ||
sampling_rate=sampling_rate, | ||
norm=None, | ||
mel_scale="kaldi", | ||
triangularize_in_mel_space=True, | ||
) | ||
|
||
self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0))) | ||
self.window = window_function(400, "hann", periodic=False) | ||
|
||
def _extract_fbank_features( | ||
self, | ||
waveform: np.ndarray, | ||
|
@@ -93,17 +114,32 @@ def _extract_fbank_features( | |
and hence the waveform should not be normalized before feature extraction. | ||
""" | ||
# waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers | ||
waveform = torch.from_numpy(waveform).unsqueeze(0) | ||
fbank = ta_kaldi.fbank( | ||
waveform, | ||
htk_compat=True, | ||
sample_frequency=self.sampling_rate, | ||
use_energy=False, | ||
window_type="hanning", | ||
num_mel_bins=self.num_mel_bins, | ||
dither=0.0, | ||
frame_shift=10, | ||
) | ||
Comment on lines
-97
to
-106
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I also took the opportunity to remove some unnecessary parameters here |
||
if is_speech_available(): | ||
waveform = torch.from_numpy(waveform).unsqueeze(0) | ||
fbank = ta_kaldi.fbank( | ||
waveform, | ||
sample_frequency=self.sampling_rate, | ||
window_type="hanning", | ||
num_mel_bins=self.num_mel_bins, | ||
) | ||
else: | ||
waveform = np.squeeze(waveform) | ||
fbank = spectrogram( | ||
waveform, | ||
self.window, | ||
frame_length=400, | ||
hop_length=160, | ||
fft_length=512, | ||
power=2.0, | ||
center=False, | ||
preemphasis=0.97, | ||
mel_filters=self.mel_filters, | ||
log_mel="log", | ||
mel_floor=1.192092955078125e-07, | ||
remove_dc_offset=True, | ||
).T | ||
|
||
fbank = torch.from_numpy(fbank) | ||
|
||
n_frames = fbank.shape[0] | ||
difference = max_length - n_frames | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As requested I've modified
to_dict
directly infeature_extraction_utils.py