Skip to content

Commit

Permalink
Add support for fleurs, librispeech and custom datasets (#218)
Browse files Browse the repository at this point in the history
* update: add tokenized_sentence initialization for transcript field

* uncomment out install_requires

* fix tokenized_sentence

* add BitsAndBytesConfig, model configurations

* add lang code mapping to compatible whisper lang codes

* fix fleurs_map import statement

* fix fleurs_map import statement

* refactor fleurs_to_whisper to whisper_model_prep module

* add transcription as possible key

* comment out install_requires

* update documentation

* add function to check and convert language code
  • Loading branch information
KevKibe authored Nov 20, 2024
1 parent 81e627c commit 9c37011
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 19 deletions.
5 changes: 3 additions & 2 deletions DOCS/gettingstarted.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@
``` py
# Set the parameters (refer to the 'Usage on VM' section for more details)
huggingface_token = " " # make sure token has write permissions
dataset_name = "mozilla-foundation/common_voice_16_1"
language_abbr= [ ] # Example `["ti", "yi"]`. see abbreviations here https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1.
dataset_name = "mozilla-foundation/common_voice_16_1" # Also supports "google/fleurs" and "facebook/multilingual_librispeech".
# For custom datasets, ensure the text key is one of the following: "sentence", "transcript", or "transcription".
language_abbr= [ ] # Example `["af"]`. see specific dataset for language code.
model_id= "model-id" # Example openai/whisper-small, openai/whisper-medium
processing_task= "translate" # translate or transcribe
wandb_api_key = " "
Expand Down
2 changes: 1 addition & 1 deletion DOCS/home.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

# Features

- 🔧 **Fine-Tuning**: Fine-tune the [Whisper](https://huggingface.co/collections/openai/whisper-release-6501bba2cf999715fd953013) model on any audio dataset from Huggingface, e.g., [Mozilla's](https://huggingface.co/mozilla-foundation) Common Voice datasets.
- 🔧 **Fine-Tuning**: Fine-tune the [Whisper](https://huggingface.co/collections/openai/whisper-release-6501bba2cf999715fd953013) model on any audio dataset from Huggingface, e.g., [Mozilla's](https://huggingface.co/mozilla-foundation) Common Voice, [Fleurs](https://huggingface.co/datasets/google/fleurs), [LibriSpeech](https://huggingface.co/datasets/facebook/multilingual_librispeech), or your own custom private/public dataset.

- 📊 **Metrics Monitoring**: View training run metrics on [Wandb](https://wandb.ai/).

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

## Features

- 🔧 **Fine-Tuning**: Fine-tune the [Whisper](https://huggingface.co/collections/openai/whisper-release-6501bba2cf999715fd953013) model on any audio dataset from Huggingface, e.g., [Mozilla's](https://huggingface.co/mozilla-foundation) Common Voice datasets.
- 🔧 **Fine-Tuning**: Fine-tune the [Whisper](https://huggingface.co/collections/openai/whisper-release-6501bba2cf999715fd953013) model on any audio dataset from Huggingface, e.g., [Mozilla's](https://huggingface.co/mozilla-foundation) Common Voice, [Fleurs](https://huggingface.co/datasets/google/fleurs), [LibriSpeech](https://huggingface.co/datasets/facebook/multilingual_librispeech), or your own custom private/public dataset.

- 📊 **Metrics Monitoring**: View training run metrics on [Wandb](https://wandb.ai/).

Expand Down
5 changes: 2 additions & 3 deletions src/training/audio_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ def resampled_dataset(self, sample: Dict[str, Any]) -> dict:
sample["audio"]["sampling_rate"] = 16000

audio_features = self.feature_extractor(resampled_audio, sampling_rate=16000).input_features[0]

tokenized_sentence = self.tokenizer(sample["sentence"]).input_ids

text = sample.get("transcription", sample.get("transcript", sample.get("sentence")))
tokenized_sentence = self.tokenizer(text).input_ids
sample["input_features"] = audio_features
sample["labels"] = tokenized_sentence

Expand Down
2 changes: 1 addition & 1 deletion src/training/wandb_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def dataset_to_records(self, dataset) -> pd.DataFrame:
audio_data = item['input_features']
audio_duration = len(audio_data) / 16000
record["audio_with_spec"] = wandb.Html(self.record_to_html(item))
record["sentence"] = item["sentence"]
record["sentence"] = item.get("sentence", item.get("transcript", item.get("transcription")))
record["length"] = audio_duration
records.append(record)
records = pd.DataFrame(records)
Expand Down
103 changes: 92 additions & 11 deletions src/training/whisper_model_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
WhisperProcessor,
WhisperTokenizerFast,
WhisperTokenizer,

BitsAndBytesConfig
)
import torch
warnings.filterwarnings("ignore")
Expand Down Expand Up @@ -98,17 +98,18 @@ def initialize_model(self) -> WhisperForConditionalGeneration:
WhisperForConditionalGeneration: The configured Whisper model ready for conditional generation tasks.
"""
processor = self.initialize_processor()
whisper_lang_code = convert_language_code(self.language)
if self.use_peft:
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = WhisperForConditionalGeneration.from_pretrained(
self.model_id,
load_in_8bit=True,
quantization_config=quantization_config,
device_map="auto",
)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task=self.processing_task)
# model.config.suppress_tokens = []
model.config.use_cache = True
model.generation_config.language = self.language if self.processing_task == "transcribe" else "en"
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.config.use_cache = False
model.generation_config.language = whisper_lang_code if self.processing_task == "transcribe" else "en"
model.generation_config.task = self.processing_task
model = prepare_model_for_kbit_training(model)
config = LoraConfig(
Expand All @@ -126,14 +127,94 @@ def initialize_model(self) -> WhisperForConditionalGeneration:
self.model_id,
low_cpu_mem_usage = True
)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task=self.processing_task)
# model.config.suppress_tokens = []
model.config.use_cache = True
model.generation_config.language = self.language if self.processing_task == "transcribe" else "en"
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.config.use_cache = False
model.generation_config.language = whisper_lang_code if self.processing_task == "transcribe" else "en"
model.generation_config.task = self.processing_task
model = model.to("cuda") if torch.cuda.is_available() else model
return model

# Mapping from FLEURS dataset language codes (e.g. "af_za") to the
# two-letter language tokens the Whisper tokenizer expects (e.g. "af").
fleurs_to_whisper = {
    "af_za": "af",  # Afrikaans
    "am_et": "am",  # Amharic
    "ar_eg": "ar",  # Arabic
    "as_in": "as",  # Assamese
    "az_az": "az",  # Azerbaijani
    "be_by": "be",  # Belarusian
    "bg_bg": "bg",  # Bulgarian
    "bn_in": "bn",  # Bengali
    "bs_ba": "bs",  # Bosnian
    "ca_es": "ca",  # Catalan
    "cs_cz": "cs",  # Czech
    "cy_gb": "cy",  # Welsh
    "da_dk": "da",  # Danish
    "de_de": "de",  # German
    "el_gr": "el",  # Greek
    "en_us": "en",  # English
    "es_es": "es",  # Spanish
    "et_ee": "et",  # Estonian
    "fa_ir": "fa",  # Persian
    "fi_fi": "fi",  # Finnish
    "fr_fr": "fr",  # French
    "ga_ie": "ga",  # Irish
    "gl_es": "gl",  # Galician
    "gu_in": "gu",  # Gujarati
    "he_il": "he",  # Hebrew
    "hi_in": "hi",  # Hindi
    "hr_hr": "hr",  # Croatian
    "hu_hu": "hu",  # Hungarian
    "hy_am": "hy",  # Armenian
    "id_id": "id",  # Indonesian
    "is_is": "is",  # Icelandic
    "it_it": "it",  # Italian
    "ja_jp": "ja",  # Japanese
    "jv_id": "jv",  # Javanese
    "ka_ge": "ka",  # Georgian
    "kk_kz": "kk",  # Kazakh
    "km_kh": "km",  # Khmer
    "kn_in": "kn",  # Kannada
    "ko_kr": "ko",  # Korean
    "lo_la": "lo",  # Lao
    "lt_lt": "lt",  # Lithuanian
    "lv_lv": "lv",  # Latvian
    "mk_mk": "mk",  # Macedonian
    "ml_in": "ml",  # Malayalam
    "mn_mn": "mn",  # Mongolian
    "mr_in": "mr",  # Marathi
    "ms_my": "ms",  # Malay
    "my_mm": "my",  # Burmese
    "ne_np": "ne",  # Nepali
    "nl_nl": "nl",  # Dutch
    "no_no": "no",  # Norwegian
    "or_in": "or",  # Odia
    "pa_in": "pa",  # Punjabi
    "pl_pl": "pl",  # Polish
    "pt_br": "pt",  # Portuguese
    "ro_ro": "ro",  # Romanian
    "ru_ru": "ru",  # Russian
    "si_lk": "si",  # Sinhala
    "sk_sk": "sk",  # Slovak
    "sl_si": "sl",  # Slovenian
    "sq_al": "sq",  # Albanian
    "sr_rs": "sr",  # Serbian
    "sv_se": "sv",  # Swedish
    "sw_ke": "sw",  # Swahili
    "ta_in": "ta",  # Tamil
    "te_in": "te",  # Telugu
    "th_th": "th",  # Thai
    "tl_ph": "tl",  # Filipino
    "tr_tr": "tr",  # Turkish
    "uk_ua": "uk",  # Ukrainian
    "ur_pk": "ur",  # Urdu
    "vi_vn": "vi",  # Vietnamese
    "zh_cn": "zh",  # Chinese
}

# Valid Whisper codes are accepted as-is; precompute the value set once so
# each call does an O(1) membership test instead of scanning .values().
_WHISPER_LANG_CODES = frozenset(fleurs_to_whisper.values())


def convert_language_code(language_code: 'str') -> 'str | None':
    """Normalize a language code to a Whisper-compatible token.

    Args:
        language_code: Either a FLEURS code such as ``"af_za"`` or an
            already-valid Whisper code such as ``"af"``.

    Returns:
        The two-letter Whisper language code, or ``None`` when the code
        is not recognized as either form.
    """
    if language_code in _WHISPER_LANG_CODES:
        return language_code
    return fleurs_to_whisper.get(language_code, None)

#################################

Expand Down

0 comments on commit 9c37011

Please sign in to comment.