Skip to content

Commit

Permalink
[Flax Examples] Seq2Seq ASR Fine-Tuning Script (#21764)
Browse files Browse the repository at this point in the history
* from seq2seq speech

* [Flax] Example script for speech seq2seq

* tests and fixes

* make style

* fix: label padding tokens

* fix: label padding tokens over list

* update ln names for Whisper

* try datasets iter loader

* create readme and append results

* style

* make style

* adjust lr

* use pt dataloader

* make fast

* pin gen max len

* finish

* add pt to requirements for test

* fix pt -> torch

* add accelerate
  • Loading branch information
sanchit-gandhi authored Sep 29, 2023
1 parent 3911774 commit 68e85fc
Show file tree
Hide file tree
Showing 5 changed files with 967 additions and 1 deletion.
4 changes: 3 additions & 1 deletion examples/flax/_tests_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ nltk
rouge-score
seqeval
tensorboard
evaluate >= 0.2.0
evaluate >= 0.2.0
torch
accelerate
68 changes: 68 additions & 0 deletions examples/flax/speech-recognition/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
<!---
Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Automatic Speech Recognition - Flax Examples

## Sequence to Sequence

The script [`run_flax_speech_recognition_seq2seq.py`](https://github.com/huggingface/transformers/blob/main/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py)
can be used to fine-tune any [Flax Speech Sequence-to-Sequence Model](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.FlaxAutoModelForSpeechSeq2Seq)
for automatic speech recognition on one of the [official speech recognition datasets](https://huggingface.co/datasets?task_ids=task_ids:automatic-speech-recognition)
or a custom dataset. This includes the Whisper model from OpenAI, or a warm-started Speech-Encoder-Decoder Model,
an example for which is included below.

### Whisper Model

We can load all components of the Whisper model directly from the pretrained checkpoint, including the pretrained model
weights, feature extractor and tokenizer. We simply have to specify the id of fine-tuning dataset and the necessary
training hyperparameters.

The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint
on the Hindi subset of the [Common Voice 13](https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0) dataset.
Note that before running this script you must accept the dataset's [terms of use](https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0)
and register your Hugging Face Hub token on your device by running `huggingface-hub login`.

```bash
python run_flax_speech_recognition_seq2seq.py \
--model_name_or_path="openai/whisper-small" \
--dataset_name="mozilla-foundation/common_voice_13_0" \
--dataset_config_name="hi" \
--language="hindi" \
--train_split_name="train+validation" \
--eval_split_name="test" \
--output_dir="./whisper-small-hi-flax" \
--per_device_train_batch_size="16" \
--per_device_eval_batch_size="16" \
--num_train_epochs="10" \
--learning_rate="1e-4" \
--warmup_steps="500" \
--logging_steps="25" \
--generation_max_length="40" \
--preprocessing_num_workers="32" \
--dataloader_num_workers="32" \
--max_duration_in_seconds="30" \
--text_column_name="sentence" \
--overwrite_output_dir \
--do_train \
--do_eval \
--predict_with_generate \
--push_to_hub \
--use_auth_token
```

On a TPU v4-8, training should take approximately 25 minutes, with a final cross-entropy loss of 0.02 and word error
rate of **34%**. See the checkpoint [sanchit-gandhi/whisper-small-hi-flax](https://huggingface.co/sanchit-gandhi/whisper-small-hi-flax)
for an example training run.
8 changes: 8 additions & 0 deletions examples/flax/speech-recognition/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
datasets[audio]>=2.14.0
jax>=0.3.6
jaxlib>=0.3.6
flax>=0.4.1
optax>=0.0.8
torch>=1.9.0
jiwer
evaluate
Loading

0 comments on commit 68e85fc

Please sign in to comment.