support whisper long-form generation #469

Merged · 8 commits · Aug 14, 2024
comps/asr/whisper/whisper_model.py — 100 changes: 82 additions & 18 deletions
@@ -16,7 +16,7 @@
class WhisperModel:
"""Convert audio to text."""

def __init__(self, model_name_or_path="openai/whisper-small", language="english", device="cpu"):
def __init__(self, model_name_or_path="openai/whisper-small", language="english", device="cpu", hpu_max_len=8192):
if device == "hpu":
# Explicitly link HPU with Torch
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
@@ -31,12 +31,11 @@ def __init__(self, model_name_or_path="openai/whisper-small", language="english"
self.model.eval()

self.language = language
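# hpu_max_len is the fixed mel-frame length that long-form inputs are padded to
# below (the default 8192 frames ≈ 82 s at Whisper's 10 ms frame hop)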
self.hpu_max_len = hpu_max_len

if device == "hpu":
# do hpu graph warmup with a long enough input audio
# whisper has a receptive field of 30 seconds
# here we select a relatively long audio (~15 sec) to quickly warmup
self._warmup_whisper_hpu_graph("https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav")
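# warm up both a >30 s clip (long-form path) and a 30 s clip (short-form path) so
# both HPU graph shapes are compiled before the first real request (assumed intent)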
self._warmup_whisper_hpu_graph("https://github.com/Spycsh/assets/raw/main/ljspeech_60s_audio.wav")
self._warmup_whisper_hpu_graph("https://github.com/Spycsh/assets/raw/main/ljspeech_30s_audio.wav")

def _audiosegment_to_librosawav(self, audiosegment):
# https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentget_array_of_samples
@@ -59,11 +58,43 @@ def _warmup_whisper_hpu_graph(self, url):
print("[ASR] warmup...")
waveform = AudioSegment.from_file("warmup.wav").set_frame_rate(16000)
waveform = self._audiosegment_to_librosawav(waveform)
# pylint: disable=E1101
inputs = self.processor.feature_extractor(
waveform, return_tensors="pt", sampling_rate=16_000
).input_features.to(self.device)
_ = self.model.generate(inputs, language="chinese")

processed_inputs = self.processor(
waveform,
return_tensors="pt",
truncation=False,
padding="longest",
return_attention_mask=True,
sampling_rate=16000,
)

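# Whisper's feature extractor yields 3000 mel frames per 30 s window, so fewer than
# 3000 frames means a short-form clip; longer clips keep their attention mask and are
# padded below to a fixed length, presumably so the HPU graphs see static shapes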
if processed_inputs.input_features.shape[-1] < 3000:
# short-form
processed_inputs = self.processor(
waveform,
return_tensors="pt",
sampling_rate=16000,
)
else:
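# long-form: pad the features out to hpu_max_len (-1.5 is the pad value chosen in
# this PR) and extend the attention mask with zeros so padded frames are treated as absent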
processed_inputs["input_features"] = torch.nn.functional.pad(
processed_inputs.input_features,
(0, self.hpu_max_len - processed_inputs.input_features.size(-1)),
value=-1.5,
)
processed_inputs["attention_mask"] = torch.nn.functional.pad(
processed_inputs.attention_mask,
(0, self.hpu_max_len + 1 - processed_inputs.attention_mask.size(-1)),
value=0,
)

_ = self.model.generate(
**(
processed_inputs.to(
self.device,
)
),
language=self.language,
)

def audio2text(self, audio_path):
"""Convert audio to text.
@@ -80,11 +111,41 @@ def audio2text(self, audio_path):
audio_dataset = Dataset.from_dict({"audio": [audio_path]}).cast_column("audio", Audio(sampling_rate=16000))
waveform = audio_dataset[0]["audio"]["array"]

# pylint: disable=E1101
inputs = self.processor.feature_extractor(
waveform, return_tensors="pt", sampling_rate=16_000
).input_features.to(self.device)
predicted_ids = self.model.generate(inputs, language=self.language)
processed_inputs = self.processor(
waveform,
return_tensors="pt",
truncation=False,
padding="longest",
return_attention_mask=True,
sampling_rate=16000,
)
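# same split as in the warmup path: under 3000 frames (30 s) of features uses the plain
# short-form extractor; longer clips are padded to hpu_max_len with a matching mask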
if processed_inputs.input_features.shape[-1] < 3000:
# short-form
processed_inputs = self.processor(
waveform,
return_tensors="pt",
sampling_rate=16000,
)
else:
processed_inputs["input_features"] = torch.nn.functional.pad(
processed_inputs.input_features,
(0, self.hpu_max_len - processed_inputs.input_features.size(-1)),
value=-1.5,
)
processed_inputs["attention_mask"] = torch.nn.functional.pad(
processed_inputs.attention_mask,
(0, self.hpu_max_len + 1 - processed_inputs.attention_mask.size(-1)),
value=0,
)

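# generate() receives the (possibly padded) features plus the attention mask; with more
# than 3000 frames this takes Whisper's long-form generation path instead of a single
# 30 s window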
predicted_ids = self.model.generate(
**(
processed_inputs.to(
self.device,
)
),
language=self.language,
)
# pylint: disable=E1101
result = self.processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
if self.language in ["chinese", "mandarin"]:
@@ -96,20 +157,23 @@ def audio2text(self, audio_path):


if __name__ == "__main__":
asr = WhisperModel(language="english")
asr = WhisperModel(model_name_or_path="openai/whisper-small", language="english", device="cpu")

# Test multilanguage asr
asr.language = "chinese"
urllib.request.urlretrieve(
"https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav",
"sample.wav",
)
asr.language = "chinese"
text = asr.audio2text("sample.wav")

asr.language = "english"
urllib.request.urlretrieve(
"https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav",
"sample.wav",
)
text = asr.audio2text("sample.wav")

os.remove("sample.wav")
for i in [5, 10, 30, 60]:
urllib.request.urlretrieve(f"https://github.com/Spycsh/assets/raw/main/ljspeech_{i}s_audio.wav", "sample.wav")
text = asr.audio2text("sample.wav")
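
For reference, below is a minimal usage sketch of the updated class on a long clip. It is not part of the PR: the import path is inferred from the file location, the 60 s sample URL is the one already used for warmup above, and device="hpu" assumes a Gaudi host with optimum-habana installed.

# hypothetical driver script exercising the new long-form path
import os
import urllib.request

from comps.asr.whisper.whisper_model import WhisperModel  # path inferred from this PR's file location

asr = WhisperModel(
    model_name_or_path="openai/whisper-small",
    language="english",
    device="hpu",      # "cpu" also works; "hpu" runs the graph warmup in __init__
    hpu_max_len=8192,  # padded mel-frame length for long-form inputs (~82 s)
)

# a 60 s clip exceeds Whisper's 30 s window, so audio2text pads the features
# and passes an attention mask, triggering long-form generation
urllib.request.urlretrieve(
    "https://github.com/Spycsh/assets/raw/main/ljspeech_60s_audio.wav",
    "long_sample.wav",
)
print(asr.audio2text("long_sample.wav"))
os.remove("long_sample.wav")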