[Whisper] 🚨 Fix whisper decoding 🚨 #34135

Merged: 75 commits merged into huggingface:main on Dec 18, 2024

Conversation

eustlb (Contributor) commented Oct 13, 2024

What does this PR do?

This PR finalizes #30984, which enabled short-form (<=30 sec) and long-form generation using temperature fallback. The original OpenAI implementation uses the same decode_with_fallback for both short-form and long-form audio, whereas we previously applied temperature fallback only to long-form audio.

It aims to solve issues and divergences with the original implementation:

  1. When decoding with timestamps, the Transformers implementation skips the last segment for short-form audio while the OpenAI implementation does not: it uses the sliding-window strategy and decodes up to the last generated timestamp (so there is in fact no difference between inferring on short-form and long-form audio). This was not caught by the tests because test_tiny_timestamp_generation and test_large_timestamp_generation are incorrect.
  2. Miscalculation of the avg_logprobs triggered temperature fallback when it should not have (see the fallback sketch right after this list).
  3. When decoding short-form / long-form audio, we returned:
    1. short-form: decoder_input_ids + predicted tokens (including the eos token)
    2. long-form: only the predicted tokens (without the eos token)
      Since short-form and long-form generation are now merged, we need a consistent way of returning outputs (at least for the tokens; we still need to differentiate for past_key_values, see here). To be consistent with the generate convention, and since the tokenizer has the skip_special_tokens argument, I went for option 1.
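For context, here is a minimal sketch of the OpenAI-style fallback loop that issue 2 refers to. It is a simplification, not the exact Transformers internals: decode_fn stands in for a single decoding pass, the no-speech check is omitted, and the thresholds are the openai/whisper transcribe defaults.

```python
def decode_with_fallback(decode_fn, segment,
                         temperatures=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
                         compression_ratio_threshold=2.4,
                         logprob_threshold=-1.0):
    """Try increasing temperatures until the candidate passes the quality checks."""
    decode_result = None
    for temperature in temperatures:
        decode_result = decode_fn(segment, temperature=temperature)

        needs_fallback = False
        if decode_result.compression_ratio > compression_ratio_threshold:
            needs_fallback = True  # transcription is too repetitive
        if decode_result.avg_logprob < logprob_threshold:
            # a miscalculated avg_logprob (issue 2 above) wrongly trips this check
            needs_fallback = True

        if not needs_fallback:
            break
    return decode_result
```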

🚨 Important changes 🚨

➡️ Previously:
• Short-form: Returned a ModelOutput or torch.LongTensor, including decoder input IDs and the EOS token ID.
• Long-form: Returned a Dict or torch.LongTensor, excluding decoder input IDs and the EOS token ID.

➡️ From now on:
Short-form and long-form generation are now treated identically, meaning output differentiation based on these modes is no longer applicable.

Decoder input IDs and the EOS token ID are never returned, except in one specific case: when return_dict_in_generate=True and either return_timestamps=False or force_unique_generate_call=True.

In that case, the output is a ModelOutput, i.e. the result of the single underlying call to GenerationMixin's generate: return_timestamps=False (or force_unique_generate_call=True) ensures that no seeking occurs and that generate is called exactly once, so this output includes both the decoder input IDs and the EOS token ID.
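As an illustration of the new contract, here is a sketch only; the tiny checkpoint and the dummy dataset are arbitrary choices for the example, not something mandated by the PR.

```python
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio = ds[0]["audio"]
inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")

# Default: a tensor of generated tokens only (no decoder input IDs, no EOS).
tokens = model.generate(inputs.input_features)

# return_dict_in_generate=True with return_timestamps=False means no seeking and a
# single underlying call to GenerationMixin's generate, so the ModelOutput keeps
# the decoder input IDs and the EOS token ID.
out = model.generate(
    inputs.input_features,
    return_dict_in_generate=True,
    return_timestamps=False,
)
text = processor.batch_decode(out.sequences, skip_special_tokens=True)[0]
```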

Testing

Note

This PR reconciles our implementation with OpenAI's. It should therefore be tested with the updated and verified tests introduced in #34111.

Evaluations

Let's verify the effectiveness of this fix by evaluating both accuracy and inference speed. We test on three short-form test sets and two long-form test sets (effectively filtering samples as ≤30 sec and >30 sec). For the short-form sets, we compare results with the current main branch of Transformers on the first 100 samples only, because issue 2 above (the miscalculation of the avg_logprobs) makes current main too slow to run on the full test sets.

⚠️ Moreover, it was also tested with this PR merged.

As explained in this PR, we need to use a simple Whisper fork to ensure both implementations are fed the same input features.
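For reference, a stripped-down version of such a short-form WER evaluation is sketched below. The checkpoint, the crude lowercasing normalization, and the sample-by-sample loop are assumptions on my side; the exact code (including the RTFx measurement) lives in the linked scripts.

```python
import evaluate
from datasets import load_dataset
from transformers import pipeline

# Assumed checkpoint; the linked scripts define the one actually used.
asr = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3")
wer_metric = evaluate.load("wer")

ds = load_dataset("google/fleurs", "en_us", split="test[:100]")  # short-form set 3

predictions, references = [], []
for sample in ds:
    audio = {"array": sample["audio"]["array"], "sampling_rate": sample["audio"]["sampling_rate"]}
    predictions.append(asr(audio)["text"].lower())
    references.append(sample["transcription"].lower())

print("WER:", 100 * wer_metric.compute(predictions=predictions, references=references))
```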

Important

TL;DR: We achieve perfect 1-to-1 matching of predictions with greedy decoding for both short-form and long-form samples, validating that our implementation matches OpenAI's original decoding algorithm. For short-form audio, this PR is ~5x faster than the current implementation, which loses time to multiple incorrect temperature fallbacks. ❗

Results 📊

→ short-form (first 100 samples)

Set 1: edinburghcstr/ami, config "ihm", split test[:100]
Set 2: distil-whisper/chime4, config "1-channel", split test[:100]
Set 3: google/fleurs, config "en_us", split test[:100]

Wandb results here!

|                            | Set 1 - WER | Set 1 - RTFx | Set 2 - WER | Set 2 - RTFx | Set 3 - WER | Set 3 - RTFx |
|----------------------------|-------------|--------------|-------------|--------------|-------------|--------------|
| whisper-orig (script)      | 17.35       | 6.55         | 4.62        | 12.21        | 3.78        | 14.61        |
| whisper-orig-fork (script) | 19.4        | 6.45         | 4.55        | 12           | 3.78        | 14.12        |
| this PR (script)           | 19.4        | 3.49         | 4.55        | 8.05         | 3.78        | 10.59        |
| current main (script)      | 25.8        | 0.73         | 5.67        | 1.49         | 7.09        | 1.9          |

→ short-form (full test sets)

Set 1: edinburghcstr/ami, config "ihm", split test
Set 2: distil-whisper/chime4, config "1-channel", split test
Set 3: google/fleurs, config "en_us", split test

Wandb results here!

|                            | Set 1 - WER | Set 1 - RTFx | Set 2 - WER | Set 2 - RTFx | Set 3 - WER | Set 3 - RTFx |
|----------------------------|-------------|--------------|-------------|--------------|-------------|--------------|
| whisper-orig (script)      | 16.22       | 7.86         | 10.74       | 11.45        | 4.21        | 13.79        |
| whisper-orig-fork (script) | 16.36       | 7.91         | 11.08       | 11.51        | 4.14        | 13.90        |
| this PR (script)           | 16.36       | 4.18         | 11.08       | 7.51         | 4.14        | 10.67        |

→ long-form

Set 1: distil-whisper/tedlium-long-form, config "default", split test
Set 2: distil-whisper/meanwhile, config "default", split test

Wandb results here!

|                            | Set 1 - WER | Set 1 - RTFx | Set 2 - WER | Set 2 - RTFx |
|----------------------------|-------------|--------------|-------------|--------------|
| whisper-orig (script)      | 172.15      | 4.98         | 264.66      | 4.09         |
| whisper-orig-fork (script) | 172.15      | 4.83         | 264.66      | 4.00         |
| this PR (script)           | 172.15      | 4.88         | 264.66      | 3.98         |
| current main (script)      | 172.15      | 4.92         | 264.66      | 3.98         |

eustlb (Contributor, Author) commented Oct 14, 2024

Even though we usually return the decoder_input_ids and the eos_token with generate in Transformers (which is why the tokenizer has a skip_special_tokens=True option), I think it is better here not to return the context tokens. Returning them would introduce a lot of ambiguity: since we actually make multiple calls to generate when doing sequential decoding, why would we include those tokens only for the first segment of the concatenated token sequence (and the eos token only for the last)? This would lead users to believe that generation happened in one shot and that the tokens really are the concatenated ones. It would be even worse for the returned segments: some would have the context tokens and some would not, depending on whether the segment is the first of a new call to generate.

For these reasons, I think it is better to stick with OpenAI's choice: return only the generated tokens. Moreover, this is how long-form generation is currently implemented in Transformers: it does not return context tokens. This also has the advantage of requiring fewer changes in the current codebase, reducing the potential for mistakes. WDYT @ylacombe? Also pinging @ArthurZucker since you've worked on the Whisper integration.

eustlb (Contributor, Author) commented Oct 18, 2024

Correction:
After discussion, it has been decided to go for option 2 instead: return only the generated tokens and skip all the special tokens (a caller-side sketch follows below).
Pros: fewer changes, no ambiguities.
Cons: we need to override the generic GenerationMixin tests.
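A minimal sketch of what option 2 looks like from the caller's side for a long-form input; the tiny checkpoint and the random waveform are placeholders, real audio would come from one of the test sets above.

```python
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Placeholder >30 sec waveform so that sequential (long-form) decoding kicks in.
waveform = np.random.randn(16_000 * 60).astype(np.float32)
inputs = processor(
    waveform,
    sampling_rate=16_000,
    return_tensors="pt",
    truncation=False,          # keep the full audio instead of the 30 sec window
    padding="longest",
    return_attention_mask=True,
)

# Sequential decoding with timestamps: despite the multiple internal calls to
# generate, the returned tensor contains only the generated tokens; the tokenizer
# then drops any remaining special tokens.
tokens = model.generate(**inputs, return_timestamps=True)
print(processor.batch_decode(tokens, skip_special_tokens=True)[0])
```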

ArthurZucker (Collaborator) commented:

Sounds good!

HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

eustlb changed the title from "[WIP] [Whisper] Fix whisper decoding" to "[Whisper] Fix whisper decoding" on Nov 22, 2024
eustlb (Contributor, Author) commented Dec 11, 2024

Thanks a lot @ylacombe for the review. I updated it based on your comments (see the updated PR comment)!

ylacombe (Contributor) left a review:

Thanks @eustlb for iterating! LGTM now!

(3 resolved review threads on src/transformers/models/whisper/generation_whisper.py, now outdated)
ArthurZucker (Collaborator) left a review:

Thanks for removing the wrapper!

(1 resolved review thread on src/transformers/models/whisper/generation_whisper.py, now outdated)
eustlb (Contributor, Author) commented Dec 18, 2024

The four failing slow tests are expected! Merging 🤗

eustlb merged commit da334bc into huggingface:main on Dec 18, 2024
25 checks passed