[Flax Examples] Seq2Seq ASR Fine-Tuning Script (#21764)
* from seq2seq speech
* [Flax] Example script for speech seq2seq
* tests and fixes
* make style
* fix: label padding tokens
* fix: label padding tokens over list
* update ln names for Whisper
* try datasets iter loader
* create readme and append results
* style
* make style
* adjust lr
* use pt dataloader
* make fast
* pin gen max len
* finish
* add pt to requirements for test
* fix pt -> torch
* add accelerate
1 parent 3911774, commit 68e85fc
Showing 5 changed files with 967 additions and 1 deletion.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,4 +5,6 @@ nltk | |
rouge-score | ||
seqeval | ||
tensorboard | ||
evaluate >= 0.2.0 | ||
evaluate >= 0.2.0 | ||
torch | ||
accelerate |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
<!--- | ||
Copyright 2023 The HuggingFace Team. All rights reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
--> | ||
|
||
# Automatic Speech Recognition - Flax Examples | ||
|
||
## Sequence to Sequence | ||
|
||
The script [`run_flax_speech_recognition_seq2seq.py`](https://github.com/huggingface/transformers/blob/main/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py) | ||
can be used to fine-tune any [Flax Speech Sequence-to-Sequence Model](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.FlaxAutoModelForSpeechSeq2Seq) | ||
for automatic speech recognition on one of the [official speech recognition datasets](https://huggingface.co/datasets?task_ids=task_ids:automatic-speech-recognition) | ||
or a custom dataset. This includes the Whisper model from OpenAI, or a warm-started Speech-Encoder-Decoder Model, | ||
an example of which is included below. | ||
|
||
### Whisper Model | ||
|
||
We can load all components of the Whisper model directly from the pretrained checkpoint, including the pretrained model | ||
weights, feature extractor and tokenizer. We simply have to specify the id of the fine-tuning dataset and the necessary | ||
training hyperparameters. | ||
|
||
The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint | ||
on the Hindi subset of the [Common Voice 13](https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0) dataset. | ||
Note that before running this script you must accept the dataset's [terms of use](https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0) | ||
and register your Hugging Face Hub token on your device by running `huggingface-cli login`. | ||
|
||
```bash | ||
python run_flax_speech_recognition_seq2seq.py \ | ||
--model_name_or_path="openai/whisper-small" \ | ||
--dataset_name="mozilla-foundation/common_voice_13_0" \ | ||
--dataset_config_name="hi" \ | ||
--language="hindi" \ | ||
--train_split_name="train+validation" \ | ||
--eval_split_name="test" \ | ||
--output_dir="./whisper-small-hi-flax" \ | ||
--per_device_train_batch_size="16" \ | ||
--per_device_eval_batch_size="16" \ | ||
--num_train_epochs="10" \ | ||
--learning_rate="1e-4" \ | ||
--warmup_steps="500" \ | ||
--logging_steps="25" \ | ||
--generation_max_length="40" \ | ||
--preprocessing_num_workers="32" \ | ||
--dataloader_num_workers="32" \ | ||
--max_duration_in_seconds="30" \ | ||
--text_column_name="sentence" \ | ||
--overwrite_output_dir \ | ||
--do_train \ | ||
--do_eval \ | ||
--predict_with_generate \ | ||
--push_to_hub \ | ||
--use_auth_token | ||
``` | ||
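
The `--learning_rate` and `--warmup_steps` flags in the command above imply a warmup schedule. The sketch below assumes a linear warmup to the peak rate followed by a linear decay to zero, which may not match the script exactly. The function name and the `total_steps` value are hypothetical: the real step count depends on the dataset size, batch size and number of devices.

```python
def linear_warmup_linear_decay(step, peak_lr=1e-4, warmup_steps=500, total_steps=5000):
    """Learning rate at a given optimizer step.

    Assumed shape: linear ramp from 0 to peak_lr over warmup_steps,
    then linear decay back to 0 at total_steps (hypothetical value).
    """
    if total_steps <= warmup_steps:
        raise ValueError("total_steps must be larger than warmup_steps")
    if step < warmup_steps:
        # Warmup phase: scale the peak rate by the fraction of warmup completed.
        return peak_lr * step / warmup_steps
    # Decay phase: scale the peak rate by the fraction of decay steps remaining.
    decay_steps = total_steps - warmup_steps
    remaining = max(total_steps - step, 0)
    return peak_lr * remaining / decay_steps


if __name__ == "__main__":
    for s in (0, 250, 500, 2750, 5000):
        print(s, linear_warmup_linear_decay(s))
```

With these assumed values, the rate peaks at step 500 and reaches zero at step 5000. A short warmup like this is commonly used to avoid large updates to pretrained weights at the start of fine-tuning.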
|
||
On a TPU v4-8, training should take approximately 25 minutes, with a final cross-entropy loss of 0.02 and a word error | ||
rate of **34%**. See the checkpoint [sanchit-gandhi/whisper-small-hi-flax](https://huggingface.co/sanchit-gandhi/whisper-small-hi-flax) | ||
for an example training run. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
datasets[audio]>=2.14.0 | ||
jax>=0.3.6 | ||
jaxlib>=0.3.6 | ||
flax>=0.4.1 | ||
optax>=0.0.8 | ||
torch>=1.9.0 | ||
jiwer | ||
evaluate |