[Whisper] Fix whisper integration tests #34111
Conversation
Hey @eustlb, let me know when you need a proper review.
In the meantime, I've left some formatting comments and a question: these are the expected results when generating with the OpenAI code, right? Are they the results when padding up to 30s (like we do) or when appending 30s of zero-padded audio (as OpenAI does)?
gen_kwargs = {
    "return_timestamps": True,
    "no_speech_threshold": 0.6,  # necessary to trigger no speech detection
    "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    "compression_ratio_threshold": 1.35,
    "condition_on_prev_tokens": False,
    "logprob_threshold": -2.0,  # necessary to avoid triggering temp fallback that will introduce randomness since we are comparing to openai EXPECTED_TEXT
}
Some of these are already used by default, let's remove them to improve readability
All the ones with numerical values here are not set by default (neither in generate nor in the model's generation_config.json). Concerning return_timestamps and condition_on_prev_tokens, I find it clearer to have them explicitly mentioned.
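For context, a minimal hedged sketch of how such a dict is consumed in a test: it is simply unpacked into Whisper's generate. The checkpoint and the dummy LibriSpeech sample are illustrative choices, not the exact setup from this PR.

from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")

gen_kwargs = {  # same keys as the dict quoted in the diff above
    "return_timestamps": True,
    "no_speech_threshold": 0.6,
    "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    "compression_ratio_threshold": 1.35,
    "condition_on_prev_tokens": False,
    "logprob_threshold": -2.0,
}

# None of the numerical values are Whisper generate defaults; they only take
# effect because the dict is unpacked into the call.
generated_ids = model.generate(inputs.input_features, **gen_kwargs)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))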
Indeed, these are the expected results generated with the OAI code, inferred on mel input features extracted through our WhisperFeatureExtractor. See the detailed explanation in the PR description below ("why we work from a whisper fork") for how the two input cases are handled.
Co-authored-by: Yoach Lacombe <[email protected]>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
As mentioned here, we won't return special tokens anymore with generate for Whisper. Let's adapt the tests a bit for that.
Thanks for this PR @eustlb, this work, although a bit time-consuming, should have been done a long time ago! Congratulations on doing it so thoroughly.
Also, thanks for providing code to reproduce every result. It'll greatly help future efforts on Whisper.
Most of the comments I've made are tiny, the PR looks great to me!
I think we might want to merge this at the same time or a bit before your other PR, but in the meantime, cc @ArthurZucker and @LysandreJik for a final review
@@ -2866,6 +2833,7 @@ def test_whisper_longform_single_batch_beam(self):
    "compression_ratio_threshold": 1.35,
    "condition_on_prev_tokens": True,
    "logprob_threshold": -1.0,
    "renormalize_logits": True,  # necessary to match OAI beam search implementation
Great! Let's not forget to mention this somewhere in the docs
I am thinking about setting it to True by default in Whisper's generate (and adding it to the doc this way) and removing it from here. WDYT?
I think it would make more sense to leave it like that here, and change it (setting it to True by default in Whisper's generate and adding it to the doc this way) in this PR, since those changes are intended for it anyway.
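For illustration, a hedged sketch of where the flag enters a beam-search call; the checkpoint, dataset and kwargs are illustrative, not the exact test values. renormalize_logits asks generate to re-normalize scores after all logits processors have run, which is what is needed to match OAI's beam search.

from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
sample = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")[0]["audio"]
inputs = processor(sample["array"], sampling_rate=16_000, return_tensors="pt")

generated_ids = model.generate(
    inputs.input_features,
    num_beams=5,
    renormalize_logits=True,  # re-normalize scores after logits processors, as OAI's beam search does
    return_timestamps=True,
)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))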
self.assertListEqual(generated_ids.tolist()[0], generated_ids_forced.tolist()[0])

@slow
def test_generate_with_prompt_ids_task_language(self):
Great test here! 🔥
LGTM, thanks for the meticulous work!
Let's update to main and merge this just a tiny bit before the other PRs. Also, let's run the slow Whisper tests on this PR, so that we can verify that your two other PRs fix these new tests.
Thanks a lot for this PR. The whisper codebase evolved quite a lot since the first release and it's nice to freshen things up!
Very nice that you have a full reproducing recipe; this is something that was quite lacking on my end, thanks for improving our port! 🤗
If the tests fail, we get no help; using torch.testing.assert_close will tell us how close / far we are!
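To make the suggestion concrete, a small sketch of the difference; the token ids below are made up for illustration.

import torch

expected = torch.tensor([50258, 50363, 2425, 11, 452, 1536, 318])
generated = torch.tensor([50258, 50363, 2425, 11, 452, 1536, 319])

try:
    # Unlike assertListEqual, which only says the lists differ, this reports the
    # number of mismatched elements and the greatest absolute/relative difference.
    torch.testing.assert_close(generated, expected)
except AssertionError as err:
    print(err)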
* fix test_tiny_timestamp_generation
* fix test_large_timestamp_generation
* fix test_whisper_shortform_single_batch_prev_cond
* fix test_whisper_shortform_multi_batch_hard_prev_cond
* return_timestamps necessary with long form
* fix test_default_multilingual_transcription_long_form
* fix test_tiny_token_timestamp_generation_longform
* fix test_whisper_longform_multi_batch_hard
* Update tests/models/whisper/test_modeling_whisper.py
  Co-authored-by: Yoach Lacombe <[email protected]>
* fix typo
* do not expect special tokens
* fix test_whisper_longform_single_batch_beam
* fix test_whisper_longform_multi_batch_hard_prev_cond
* update test_whisper_longform_multi_batch_hard_prev_cond
* update test_whisper_longform_multi_batch_hard_prev_cond
* these tests do not make sense anymore
* this test does not make sense anymore
* make fixup
* suggested nits
* add test with forced_decoder_ids
* this test does not make sense anymore
* change assert for unittest test cases
* make fixup
* test with prompt_ids and task and language
* fix unittest test case call
* fix test_tiny_generation
* fix test_tiny_en_generation
* fix test_tiny_en_batched_generation
* fix test_tiny_longform_timestamps_generation
* fix test_tiny_timestamp_generation
* fix test_large_generation
* fix test_large_batched_generation
* fix test_large_generation_multilingual
* fix test_large_timestamp_generation
* fix test_large_timestamp_generation
* fix test_tiny_token_timestamp_generation_longform
* fix test_tiny_en_batched_generation
* make fixup
* [run-slow] whisper

Co-authored-by: Yoach Lacombe <[email protected]>
What does this PR do?
This PR fixes multiple errors in Whisper integration tests and expected outputs.
To compute the correct expected outputs, it is necessary to work from a very simple fork of the original OpenAI Whisper implementation. Indeed, the extraction of the mel spectrogram in WhisperFeatureExtractor diverges slightly from OpenAI's: we pad the audio array to 30sec / to the longest with 0.0s and then extract the spectrogram through a batched STFT, while OpenAI adds 30sec of 0.0s to the audio array (and does not pad to 30sec). This way, we are sure that the model inputs for our and OpenAI's implementations are exactly the same. With this, we can use the following protocol to compute the expected outputs for the tests:
Important
Code to reproduce the outputs for each of the verified tests can be found here.
Edit: some more details about why we work from a whisper fork
In Transformers, we have two input possibilities for Whisper:
case 1
With an audio <= 30sec, the difference between our implementation and OAI's is that we first pad the audio to 30sec with 0.0s, then extract features, and this is the input to the model's forward; OAI instead pads the audio by appending 30sec of 0.0s, extracts features, slices the mel to the exact number of frames covered by the audio, and then pads the mel spectrogram to 3000 frames with 0.0s.
To understand better, for an audio of 10secs (see the sketch after these two points):
Transformers: audio + 20sec of 0.0s → mel spectrogram of shape [80, 3000], where the last 2000 frames (from frame 1000 on) are close to, but not exactly, 0.0
OAI: audio + 30sec of 0.0s → mel spectrogram of shape [80, 4000] → sliced to the duration of the audio (so up to frame 1000) and then padded with 0.0s to 3000 frames: the last 2000 frames are exactly 0.0
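A hedged sketch of those two pipelines for the 10sec example, assuming a recent openai-whisper and the whisper-tiny feature extractor; the helper calls and numbers illustrate the description above and are not code from this PR.

import numpy as np
import whisper  # openai-whisper
from transformers import WhisperFeatureExtractor

SR = 16_000
audio = np.random.randn(10 * SR).astype(np.float32)  # stand-in for a 10 s clip

# Transformers: pad the waveform to 30 s with 0.0s, then a batched STFT -> (80, 3000)
fe = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
hf_mel = fe(audio, sampling_rate=SR, return_tensors="np").input_features[0]

# OpenAI: append 30 s of 0.0s to the waveform, compute the mel (-> (80, 4000)),
# slice to the audio duration (1000 frames), then zero-pad the mel itself to 3000 frames
oai_mel = whisper.log_mel_spectrogram(audio, padding=30 * SR)        # (80, 4000)
oai_segment = whisper.pad_or_trim(oai_mel[:, :1000], 3000).numpy()   # (80, 3000)

print(hf_mel.shape, oai_segment.shape)      # both (80, 3000)
print(np.abs(oai_segment[:, 1000:]).max())  # exactly 0.0: the mel itself was zero-padded
print(np.abs(hf_mel[:, 1000:]).max())       # per the description above, not exactly 0.0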
case 2
No differences (other than numerical differences due to the STFT implementation).
About the implementation in the simple whisper fork:
We just take the mel spectrogram and concatenate it with 3000 frames of 0.0s; this emulates the 30sec of 0.0s added originally (a minimal sketch follows below).
For case 1, the duration considered by OAI is 30sec (see this line), and therefore the audio segment given to the forward is the exact mel input that was given.
For case 2, likewise, the duration considered is that of the given mel input.
By running OAI inference directly on the mel spectrogram (either exactly 3000 frames, or more than 3000 frames), we ensure that each forward pass of OAI's Whisper and of ours gets the exact same mel spectrogram. This ensures that the expected results we have in the tests are indeed the results that should be expected from the OAI implementation given the same input mel.
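As a rough sketch of that idea (hypothetical code, not the actual fork; the real change lives inside the fork's transcribe):

import torch

def emulate_oai_padding(mel: torch.Tensor) -> torch.Tensor:
    """Append 3000 frames of 0.0 to a precomputed (n_mels, T) mel spectrogram,
    emulating the 30 s of zero samples OpenAI's transcribe() normally appends
    to the waveform before feature extraction."""
    return torch.cat([mel, torch.zeros(mel.shape[0], 3000, dtype=mel.dtype)], dim=-1)

# transcribe() then treats the original T frames as the audio content, so every
# forward pass sees exactly the mel we fed in (cases 1 and 2 above).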
Note
For tests that require batched inference, which is not supported by the OAI implementation, I simply run the samples sequentially to get the outputs.
TODO
Tests to be verified and corrected where needed
✅ for a correct test
❌ for an incorrect one
"all-segments"
has no equivalent in OAI implem