-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add numpy alternative to FE using torchaudio #26339
Changes from 5 commits
f36232b
fbc40d3
5ad48e0
73a4c06
f06db42
608644b
605ed14
b16b5a9
8af2313
5e98476
4aa7100
d2e2714
c18ee1a
4dd6207
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,12 +16,14 @@ | |
Feature extractor class for Audio Spectrogram Transformer. | ||
""" | ||
|
||
import copy | ||
from typing import List, Optional, Union | ||
|
||
import numpy as np | ||
import torch | ||
import torchaudio.compliance.kaldi as ta_kaldi | ||
|
||
from ...audio_utils import mel_filter_bank, spectrogram, window_function | ||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor | ||
from ...feature_extraction_utils import BatchFeature | ||
from ...utils import TensorType, logging | ||
|
@@ -58,6 +60,9 @@ class ASTFeatureExtractor(SequenceFeatureExtractor): | |
by default. | ||
return_attention_mask (`bool`, *optional*, defaults to `False`): | ||
Whether or not [`~ASTFeatureExtractor.__call__`] should return `attention_mask`. | ||
use_torchaudio (`bool`, *optional*, defaults to `True`): | ||
Whether or not to use torchaudio implementation of mel-filter banks. If `False`, use a numpy porting of | ||
torchaudio mel-filter banks implementation. | ||
""" | ||
|
||
model_input_names = ["input_values", "attention_mask"] | ||
|
@@ -73,6 +78,7 @@ def __init__( | |
mean=-4.2677393, | ||
std=4.5689974, | ||
return_attention_mask=False, | ||
use_torchaudio=True, | ||
**kwargs, | ||
): | ||
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) | ||
|
@@ -83,6 +89,22 @@ def __init__( | |
self.std = std | ||
self.return_attention_mask = return_attention_mask | ||
|
||
self.use_torchaudio = use_torchaudio | ||
if not use_torchaudio: | ||
mel_filters = mel_filter_bank( | ||
num_frequency_bins=256, | ||
num_mel_filters=self.num_mel_bins, | ||
min_frequency=20, | ||
max_frequency=sampling_rate // 2, | ||
sampling_rate=sampling_rate, | ||
norm=None, | ||
mel_scale="kaldi", | ||
triangularize_in_mel_space=True, | ||
) | ||
|
||
self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0))) | ||
self.window = window_function(400, "hann", periodic=False) | ||
|
||
def _extract_fbank_features( | ||
self, | ||
waveform: np.ndarray, | ||
|
@@ -93,17 +115,32 @@ def _extract_fbank_features( | |
and hence the waveform should not be normalized before feature extraction. | ||
""" | ||
# waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers | ||
waveform = torch.from_numpy(waveform).unsqueeze(0) | ||
fbank = ta_kaldi.fbank( | ||
waveform, | ||
htk_compat=True, | ||
sample_frequency=self.sampling_rate, | ||
use_energy=False, | ||
window_type="hanning", | ||
num_mel_bins=self.num_mel_bins, | ||
dither=0.0, | ||
frame_shift=10, | ||
) | ||
if self.use_torchaudio: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think there's a need to have the if_torchaudio_is_available():
# do legacy code
else:
# do numpy code There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm also fine with removing the legacy torchaudio code altogether. I know this makes the feature extraction quite a bit slower, but I think it is fine to remove the extra dependencies to bring these models in line with the rest of the audio library. Personally, I would favour this approach over supporting both methods for feature extraction (torchaudio and numpy). IMO having both methods convolutes the code quite a lot, which is something we want to avoid. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fine with me to remove the previous code, though it won’t be backward compatible performance-wise 🫠 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's go with the first option here then? Decorate with |
||
waveform = torch.from_numpy(waveform).unsqueeze(0) | ||
fbank = ta_kaldi.fbank( | ||
waveform, | ||
sample_frequency=self.sampling_rate, | ||
window_type="hanning", | ||
num_mel_bins=self.num_mel_bins, | ||
) | ||
else: | ||
waveform = np.squeeze(waveform) | ||
fbank = spectrogram( | ||
waveform, | ||
self.window, | ||
frame_length=400, | ||
hop_length=160, | ||
fft_length=512, | ||
power=2.0, | ||
center=False, | ||
preemphasis=0.97, | ||
mel_filters=self.mel_filters, | ||
log_mel="log", | ||
mel_floor=1.192092955078125e-07, | ||
remove_dc_offset=True, | ||
).T | ||
|
||
fbank = torch.from_numpy(fbank) | ||
|
||
n_frames = fbank.shape[0] | ||
difference = max_length - n_frames | ||
|
@@ -198,3 +235,16 @@ def __call__( | |
padded_inputs = padded_inputs.convert_to_tensors(return_tensors) | ||
|
||
return padded_inputs | ||
|
||
def to_dict(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this method strictly necessary? If it is, shouldn't it go in the base There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Which method do you mean? to_dict? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, it's about time we add this to the base feature class! |
||
""" | ||
Serializes this instance to a Python dictionary. Returns: | ||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. | ||
""" | ||
output = copy.deepcopy(self.__dict__) | ||
output["feature_extractor_type"] = self.__class__.__name__ | ||
if "mel_filters" in output: | ||
del output["mel_filters"] | ||
if "window" in output: | ||
del output["window"] | ||
return output |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also took the opportunity to remove some unnecessary parameters here