Add numpy alternative to FE using torchaudio #26339

Merged
14 commits merged on Nov 8, 2023
27 changes: 4 additions & 23 deletions src/transformers/__init__.py
@@ -146,6 +146,7 @@
"models.audio_spectrogram_transformer": [
"AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"ASTConfig",
"ASTFeatureExtractor",
],
"models.auto": [
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
@@ -535,6 +536,7 @@
"models.speech_to_text": [
"SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Speech2TextConfig",
"Speech2TextFeatureExtractor",
"Speech2TextProcessor",
],
"models.speech_to_text_2": [
@@ -913,20 +915,6 @@
else:
_import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"]

# Speech-specific objects
try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_speech_objects

_import_structure["utils.dummy_speech_objects"] = [
name for name in dir(dummy_speech_objects) if not name.startswith("_")
]
else:
_import_structure["models.audio_spectrogram_transformer"].append("ASTFeatureExtractor")
_import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")

# Tensorflow-text-specific objects
try:
if not is_tensorflow_text_available():
@@ -4352,6 +4340,7 @@
from .models.audio_spectrogram_transformer import (
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
ASTConfig,
ASTFeatureExtractor,
)
from .models.auto import (
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -4722,6 +4711,7 @@
from .models.speech_to_text import (
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
Speech2TextConfig,
Speech2TextFeatureExtractor,
Speech2TextProcessor,
)
from .models.speech_to_text_2 import (
@@ -5067,15 +5057,6 @@
else:
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer

try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_speech_objects import *
else:
from .models.audio_spectrogram_transformer import ASTFeatureExtractor
from .models.speech_to_text import Speech2TextFeatureExtractor

try:
if not is_tensorflow_text_available():
raise OptionalDependencyNotAvailable()
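A minimal sketch of what removing the speech-specific guard above means in practice, assuming torchaudio is not installed: previously these names resolved to dummy speech objects that raise when instantiated, whereas now they import normally and build their numpy fallback state.

# Sketch only: assumes torchaudio is NOT installed.
# Before this PR, both names came from utils.dummy_speech_objects and raised on use;
# after it, they are ordinary lazy imports with a numpy fallback.
from transformers import ASTFeatureExtractor, Speech2TextFeatureExtractor

ast_fe = ASTFeatureExtractor()          # builds numpy mel filters + Hann window
s2t_fe = Speech2TextFeatureExtractor()  # builds numpy mel filters + Povey window
print(type(ast_fe).__name__, type(s2t_fe).__name__)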
11 changes: 6 additions & 5 deletions src/transformers/feature_extraction_utils.py
@@ -584,14 +584,15 @@ def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrain

def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary.

Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance.
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["feature_extractor_type"] = self.__class__.__name__

if "mel_filters" in output:
del output["mel_filters"]
if "window" in output:
del output["window"]
return output

@classmethod
Comment on lines 585 to 598
Contributor Author:
As requested I've modified to_dict directly in feature_extraction_utils.py
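A small sketch of the effect of this change, whichever audio backend is active: the array-valued `mel_filters` and `window` attributes (populated by the numpy fallback) are stripped before serialization, so the dictionary written by `save_pretrained` stays small and JSON-friendly.

# Sketch only: the asserts hold with or without torchaudio installed,
# since to_dict() now drops the array-valued attributes when present.
from transformers import ASTFeatureExtractor

extractor = ASTFeatureExtractor()
config = extractor.to_dict()

assert "mel_filters" not in config    # numpy filter bank not serialized
assert "window" not in config         # analysis window not serialized
assert config["feature_extractor_type"] == "ASTFeatureExtractor"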

21 changes: 4 additions & 17 deletions src/transformers/models/audio_spectrogram_transformer/__init__.py
@@ -13,14 +13,15 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_speech_available, is_torch_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available


_import_structure = {
"configuration_audio_spectrogram_transformer": [
"AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"ASTConfig",
]
],
"feature_extraction_audio_spectrogram_transformer": ["ASTFeatureExtractor"],
}

try:
@@ -36,19 +37,13 @@
"ASTPreTrainedModel",
]

try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_audio_spectrogram_transformer"] = ["ASTFeatureExtractor"]

if TYPE_CHECKING:
from .configuration_audio_spectrogram_transformer import (
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
ASTConfig,
)
from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor

try:
if not is_torch_available():
@@ -63,14 +58,6 @@
ASTPreTrainedModel,
)

try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor


else:
import sys
src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py
@@ -19,12 +19,18 @@
from typing import List, Optional, Union

import numpy as np
import torch
import torchaudio.compliance.kaldi as ta_kaldi

from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, logging
from ...utils import TensorType, is_speech_available, is_torch_available, logging


if is_speech_available():
import torchaudio.compliance.kaldi as ta_kaldi

if is_torch_available():
Comment on lines +29 to +32
Contributor:
In speech-to-text we bundle these imports into one:

if is_speech_available():
    import torchaudio.compliance.kaldi as ta_kaldi
    import torch

Should we do the same here since we can only use torch if torchaudio is available?

Contributor Author:
Actually, torch is also used here even when torchaudio isn't used. I can maybe refactor the code to change that, but I'm not sure it's worth the time, WDYT ?

import torch


logger = logging.get_logger(__name__)
@@ -37,8 +43,8 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
most of the main methods. Users should refer to this superclass for more information regarding those methods.

This class extracts mel-filter bank features from raw speech using TorchAudio, pads/truncates them to a fixed
length and normalizes them using a mean and standard deviation.
This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy
otherwise, pads/truncates them to a fixed length and normalizes them using a mean and standard deviation.

Args:
feature_size (`int`, *optional*, defaults to 1):
@@ -83,6 +89,21 @@ def __init__(
self.std = std
self.return_attention_mask = return_attention_mask

if not is_speech_available():
mel_filters = mel_filter_bank(
num_frequency_bins=256,
num_mel_filters=self.num_mel_bins,
min_frequency=20,
max_frequency=sampling_rate // 2,
sampling_rate=sampling_rate,
norm=None,
mel_scale="kaldi",
triangularize_in_mel_space=True,
)

self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
self.window = window_function(400, "hann", periodic=False)

def _extract_fbank_features(
self,
waveform: np.ndarray,
@@ -93,17 +114,32 @@ def _extract_fbank_features(
and hence the waveform should not be normalized before feature extraction.
"""
# waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
waveform = torch.from_numpy(waveform).unsqueeze(0)
fbank = ta_kaldi.fbank(
waveform,
htk_compat=True,
sample_frequency=self.sampling_rate,
use_energy=False,
window_type="hanning",
num_mel_bins=self.num_mel_bins,
dither=0.0,
frame_shift=10,
)
Comment on lines -97 to -106
Contributor Author:
I also took the opportunity to remove some unnecessary parameters here

if is_speech_available():
waveform = torch.from_numpy(waveform).unsqueeze(0)
fbank = ta_kaldi.fbank(
waveform,
sample_frequency=self.sampling_rate,
window_type="hanning",
num_mel_bins=self.num_mel_bins,
)
else:
waveform = np.squeeze(waveform)
fbank = spectrogram(
waveform,
self.window,
frame_length=400,
hop_length=160,
fft_length=512,
power=2.0,
center=False,
preemphasis=0.97,
mel_filters=self.mel_filters,
log_mel="log",
mel_floor=1.192092955078125e-07,
remove_dc_offset=True,
).T

fbank = torch.from_numpy(fbank)

n_frames = fbank.shape[0]
difference = max_length - n_frames
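A hedged usage sketch of the fallback added above, which also illustrates the thread about imports: torch is still needed even on the numpy path, since the extracted fbank is converted back to a torch tensor. The call site is identical for both backends; only `_extract_fbank_features` branches between `ta_kaldi.fbank` and the numpy `spectrogram` helper. The shape in the comment assumes the default `max_length=1024` and `num_mel_bins=128`.

# Sketch only: assumes torch is installed; torchaudio is optional.
import numpy as np
from transformers import ASTFeatureExtractor
from transformers.utils import is_speech_available

extractor = ASTFeatureExtractor()
# one second of a 440 Hz tone at the default 16 kHz sampling rate
waveform = np.sin(2 * np.pi * 440 * np.arange(16_000) / 16_000).astype(np.float32)

inputs = extractor(waveform, sampling_rate=16000, return_tensors="np")
print("torchaudio backend:", is_speech_available())
print(inputs["input_values"].shape)   # (1, 1024, 128) with default settings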
19 changes: 2 additions & 17 deletions src/transformers/models/speech_to_text/__init__.py
@@ -17,14 +17,14 @@
OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_speech_available,
is_tf_available,
is_torch_available,
)


_import_structure = {
"configuration_speech_to_text": ["SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "Speech2TextConfig"],
"feature_extraction_speech_to_text": ["Speech2TextFeatureExtractor"],
"processing_speech_to_text": ["Speech2TextProcessor"],
}

@@ -36,14 +36,6 @@
else:
_import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]

try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_speech_to_text"] = ["Speech2TextFeatureExtractor"]

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
@@ -73,6 +65,7 @@

if TYPE_CHECKING:
from .configuration_speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor
from .processing_speech_to_text import Speech2TextProcessor

try:
@@ -83,14 +76,6 @@
else:
from .tokenization_speech_to_text import Speech2TextTokenizer

try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
@@ -19,14 +19,17 @@
from typing import List, Optional, Union

import numpy as np
import torch
import torchaudio.compliance.kaldi as ta_kaldi

from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging
from ...utils import PaddingStrategy, TensorType, is_speech_available, logging


if is_speech_available():
import torch
import torchaudio.compliance.kaldi as ta_kaldi

logger = logging.get_logger(__name__)


@@ -37,8 +40,8 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
This feature extractor inherits from [`Speech2TextFeatureExtractor`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.

This class extracts mel-filter bank features from raw speech using TorchAudio and applies utterance-level cepstral
mean and variance normalization to the extracted features.
This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy
otherwise, and applies utterance-level cepstral mean and variance normalization to the extracted features.

Args:
feature_size (`int`, *optional*, defaults to 80):
@@ -77,6 +80,21 @@ def __init__(
self.normalize_vars = normalize_vars
self.return_attention_mask = True

if not is_speech_available():
mel_filters = mel_filter_bank(
num_frequency_bins=256,
num_mel_filters=self.num_mel_bins,
min_frequency=20,
max_frequency=sampling_rate // 2,
sampling_rate=sampling_rate,
norm=None,
mel_scale="kaldi",
triangularize_in_mel_space=True,
)

self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
self.window = window_function(400, "povey", periodic=False)

def _extract_fbank_features(
self,
waveform: np.ndarray,
@@ -86,9 +104,27 @@ def _extract_fbank_features(
and hence the waveform should not be normalized before feature extraction.
"""
waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
waveform = torch.from_numpy(waveform).unsqueeze(0)
features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate)
return features.numpy()
if is_speech_available():
waveform = torch.from_numpy(waveform).unsqueeze(0)
features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate)
features = features.numpy()
else:
waveform = np.squeeze(waveform)
features = spectrogram(
waveform,
self.window,
frame_length=400,
hop_length=160,
fft_length=512,
power=2.0,
center=False,
preemphasis=0.97,
mel_filters=self.mel_filters,
log_mel="log",
mel_floor=1.192092955078125e-07,
remove_dc_offset=True,
).T
return features

@staticmethod
def utterance_cmvn(
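A matching sketch for Speech2Text: the numpy path reuses the same `spectrogram` helper (keeping the Povey window and the 2**15 Kaldi scaling) before the existing utterance-level CMVN, so downstream code does not change. The output key and shape below follow the current extractor defaults and are illustrative only.

# Sketch only: runs with or without torchaudio installed.
import numpy as np
from transformers import Speech2TextFeatureExtractor

extractor = Speech2TextFeatureExtractor()               # 80 mel bins, 16 kHz defaults
waveform = np.random.randn(16_000).astype(np.float32)   # 1 s of placeholder audio

inputs = extractor(waveform, sampling_rate=16000, return_tensors="np")
print(inputs["input_features"].shape)                   # (1, num_frames, 80)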