From ccbfe763c1b9ebd89e0c57c744b74442280732da Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Sun, 27 Oct 2024 15:52:38 +0900 Subject: [PATCH 01/41] Rename file --- app.py | 2 +- modules/translation/translation_base.py | 2 +- modules/whisper/{whisper_parameter.py => data_classes.py} | 0 modules/whisper/faster_whisper_inference.py | 2 +- modules/whisper/insanely_fast_whisper_inference.py | 2 +- modules/whisper/whisper_Inference.py | 2 +- modules/whisper/whisper_base.py | 2 +- tests/test_bgm_separation.py | 2 +- tests/test_diarization.py | 2 +- tests/test_transcription.py | 2 +- tests/test_vad.py | 2 +- 11 files changed, 10 insertions(+), 10 deletions(-) rename modules/whisper/{whisper_parameter.py => data_classes.py} (100%) diff --git a/app.py b/app.py index ea094c50..3fb9b22f 100644 --- a/app.py +++ b/app.py @@ -17,7 +17,7 @@ from modules.utils.cli_manager import str2bool from modules.utils.youtube_manager import get_ytmetas from modules.translation.deepl_api import DeepLAPI -from modules.whisper.whisper_parameter import * +from modules.whisper.data_classes import * class App: diff --git a/modules/translation/translation_base.py b/modules/translation/translation_base.py index abc7f44c..3dbf024e 100644 --- a/modules/translation/translation_base.py +++ b/modules/translation/translation_base.py @@ -6,7 +6,7 @@ from datetime import datetime import modules.translation.nllb_inference as nllb -from modules.whisper.whisper_parameter import * +from modules.whisper.data_classes import * from modules.utils.subtitle_manager import * from modules.utils.files_manager import load_yaml, save_yaml from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR diff --git a/modules/whisper/whisper_parameter.py b/modules/whisper/data_classes.py similarity index 100% rename from modules/whisper/whisper_parameter.py rename to modules/whisper/data_classes.py diff --git 
a/modules/whisper/faster_whisper_inference.py b/modules/whisper/faster_whisper_inference.py index f12fc01a..f9edbc93 100644 --- a/modules/whisper/faster_whisper_inference.py +++ b/modules/whisper/faster_whisper_inference.py @@ -12,7 +12,7 @@ from argparse import Namespace from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR) -from modules.whisper.whisper_parameter import * +from modules.whisper.data_classes import * from modules.whisper.whisper_base import WhisperBase diff --git a/modules/whisper/insanely_fast_whisper_inference.py b/modules/whisper/insanely_fast_whisper_inference.py index fe6f4fdb..11f94bdf 100644 --- a/modules/whisper/insanely_fast_whisper_inference.py +++ b/modules/whisper/insanely_fast_whisper_inference.py @@ -12,7 +12,7 @@ from argparse import Namespace from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR) -from modules.whisper.whisper_parameter import * +from modules.whisper.data_classes import * from modules.whisper.whisper_base import WhisperBase diff --git a/modules/whisper/whisper_Inference.py b/modules/whisper/whisper_Inference.py index f87fbe5d..16ec9645 100644 --- a/modules/whisper/whisper_Inference.py +++ b/modules/whisper/whisper_Inference.py @@ -9,7 +9,7 @@ from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR) from modules.whisper.whisper_base import WhisperBase -from modules.whisper.whisper_parameter import * +from modules.whisper.data_classes import * class WhisperInference(WhisperBase): diff --git a/modules/whisper/whisper_base.py b/modules/whisper/whisper_base.py index 51c87ddf..c8c08dbb 100644 --- a/modules/whisper/whisper_base.py +++ b/modules/whisper/whisper_base.py @@ -18,7 +18,7 @@ from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename from modules.utils.youtube_manager import get_ytdata, get_ytaudio from 
modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml -from modules.whisper.whisper_parameter import * +from modules.whisper.data_classes import * from modules.diarize.diarizer import Diarizer from modules.vad.silero_vad import SileroVAD diff --git a/tests/test_bgm_separation.py b/tests/test_bgm_separation.py index cc4a6f80..a8178ea6 100644 --- a/tests/test_bgm_separation.py +++ b/tests/test_bgm_separation.py @@ -1,6 +1,6 @@ from modules.utils.paths import * from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.whisper_parameter import WhisperValues +from modules.whisper.data_classes import WhisperValues from test_config import * from test_transcription import download_file, test_transcribe diff --git a/tests/test_diarization.py b/tests/test_diarization.py index 54e7244e..2a4c77af 100644 --- a/tests/test_diarization.py +++ b/tests/test_diarization.py @@ -1,6 +1,6 @@ from modules.utils.paths import * from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.whisper_parameter import WhisperValues +from modules.whisper.data_classes import WhisperValues from test_config import * from test_transcription import download_file, test_transcribe diff --git a/tests/test_transcription.py b/tests/test_transcription.py index 4b5ab98f..1e1560e7 100644 --- a/tests/test_transcription.py +++ b/tests/test_transcription.py @@ -1,5 +1,5 @@ from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.whisper_parameter import WhisperValues +from modules.whisper.data_classes import WhisperValues from modules.utils.paths import WEBUI_DIR from test_config import * diff --git a/tests/test_vad.py b/tests/test_vad.py index 124a043d..d2a30df6 100644 --- a/tests/test_vad.py +++ b/tests/test_vad.py @@ -1,6 +1,6 @@ from modules.utils.paths import * from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.whisper_parameter import WhisperValues +from 
modules.whisper.data_classes import WhisperValues from test_config import * from test_transcription import download_file, test_transcribe From 5a3afa2820a82d0dcbaaef247639b66a26a356e1 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Sun, 27 Oct 2024 16:01:52 +0900 Subject: [PATCH 02/41] Rename model --- app.py | 2 +- modules/whisper/faster_whisper_inference.py | 2 +- modules/whisper/insanely_fast_whisper_inference.py | 2 +- modules/whisper/whisper_Inference.py | 2 +- modules/whisper/whisper_base.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/app.py b/app.py index 3fb9b22f..d9985ba1 100644 --- a/app.py +++ b/app.py @@ -196,7 +196,7 @@ def create_whisper_parameters(self): dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate]) return ( - WhisperParameters( + WhisperGradioComponents( model_size=dd_model, lang=dd_lang, is_translate=cb_translate, beam_size=nb_beam_size, log_prob_threshold=nb_log_prob_threshold, no_speech_threshold=nb_no_speech_threshold, compute_type=dd_compute_type, best_of=nb_best_of, patience=nb_patience, diff --git a/modules/whisper/faster_whisper_inference.py b/modules/whisper/faster_whisper_inference.py index f9edbc93..d3f52c9d 100644 --- a/modules/whisper/faster_whisper_inference.py +++ b/modules/whisper/faster_whisper_inference.py @@ -62,7 +62,7 @@ def transcribe(self, """ start_time = time.time() - params = WhisperParameters.as_value(*whisper_params) + params = WhisperGradioComponents.as_value(*whisper_params) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) diff --git a/modules/whisper/insanely_fast_whisper_inference.py b/modules/whisper/insanely_fast_whisper_inference.py index 11f94bdf..0baae0e3 100644 --- a/modules/whisper/insanely_fast_whisper_inference.py +++ 
b/modules/whisper/insanely_fast_whisper_inference.py @@ -61,7 +61,7 @@ def transcribe(self, elapsed time for transcription """ start_time = time.time() - params = WhisperParameters.as_value(*whisper_params) + params = WhisperGradioComponents.as_value(*whisper_params) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) diff --git a/modules/whisper/whisper_Inference.py b/modules/whisper/whisper_Inference.py index 16ec9645..14887355 100644 --- a/modules/whisper/whisper_Inference.py +++ b/modules/whisper/whisper_Inference.py @@ -51,7 +51,7 @@ def transcribe(self, elapsed time for transcription """ start_time = time.time() - params = WhisperParameters.as_value(*whisper_params) + params = WhisperGradioComponents.as_value(*whisper_params) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) diff --git a/modules/whisper/whisper_base.py b/modules/whisper/whisper_base.py index c8c08dbb..2ff5ae87 100644 --- a/modules/whisper/whisper_base.py +++ b/modules/whisper/whisper_base.py @@ -99,7 +99,7 @@ def run(self, elapsed_time: float elapsed time for running """ - params = WhisperParameters.as_value(*whisper_params) + params = WhisperGradioComponents.as_value(*whisper_params) self.cache_parameters( whisper_params=params, From 927f1ef9f2be59b657e59e9bf2abfb7547346578 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Sun, 27 Oct 2024 16:58:44 +0900 Subject: [PATCH 03/41] Rename model --- app.py | 2 +- modules/whisper/faster_whisper_inference.py | 2 +- modules/whisper/insanely_fast_whisper_inference.py | 2 +- modules/whisper/whisper_Inference.py | 2 +- modules/whisper/whisper_base.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/app.py b/app.py index 
d9985ba1..90f2e4ba 100644 --- a/app.py +++ b/app.py @@ -196,7 +196,7 @@ def create_whisper_parameters(self): dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate]) return ( - WhisperGradioComponents( + TranscriptionPipelineGradioComponents( model_size=dd_model, lang=dd_lang, is_translate=cb_translate, beam_size=nb_beam_size, log_prob_threshold=nb_log_prob_threshold, no_speech_threshold=nb_no_speech_threshold, compute_type=dd_compute_type, best_of=nb_best_of, patience=nb_patience, diff --git a/modules/whisper/faster_whisper_inference.py b/modules/whisper/faster_whisper_inference.py index d3f52c9d..94eb18f5 100644 --- a/modules/whisper/faster_whisper_inference.py +++ b/modules/whisper/faster_whisper_inference.py @@ -62,7 +62,7 @@ def transcribe(self, """ start_time = time.time() - params = WhisperGradioComponents.as_value(*whisper_params) + params = TranscriptionPipelineGradioComponents.as_value(*whisper_params) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) diff --git a/modules/whisper/insanely_fast_whisper_inference.py b/modules/whisper/insanely_fast_whisper_inference.py index 0baae0e3..c7476d10 100644 --- a/modules/whisper/insanely_fast_whisper_inference.py +++ b/modules/whisper/insanely_fast_whisper_inference.py @@ -61,7 +61,7 @@ def transcribe(self, elapsed time for transcription """ start_time = time.time() - params = WhisperGradioComponents.as_value(*whisper_params) + params = TranscriptionPipelineGradioComponents.as_value(*whisper_params) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) diff --git a/modules/whisper/whisper_Inference.py b/modules/whisper/whisper_Inference.py index 14887355..338e04b7 100644 --- a/modules/whisper/whisper_Inference.py 
+++ b/modules/whisper/whisper_Inference.py @@ -51,7 +51,7 @@ def transcribe(self, elapsed time for transcription """ start_time = time.time() - params = WhisperGradioComponents.as_value(*whisper_params) + params = TranscriptionPipelineGradioComponents.as_value(*whisper_params) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) diff --git a/modules/whisper/whisper_base.py b/modules/whisper/whisper_base.py index 2ff5ae87..c7708538 100644 --- a/modules/whisper/whisper_base.py +++ b/modules/whisper/whisper_base.py @@ -99,7 +99,7 @@ def run(self, elapsed_time: float elapsed time for running """ - params = WhisperGradioComponents.as_value(*whisper_params) + params = TranscriptionPipelineGradioComponents.as_value(*whisper_params) self.cache_parameters( whisper_params=params, From a88b526e18531d07302abaf82c1e98a3aca3a714 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Sun, 27 Oct 2024 17:11:01 +0900 Subject: [PATCH 04/41] Rename model --- modules/whisper/whisper_base.py | 2 +- tests/test_bgm_separation.py | 2 +- tests/test_diarization.py | 2 +- tests/test_transcription.py | 4 ++-- tests/test_vad.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/whisper/whisper_base.py b/modules/whisper/whisper_base.py index c7708538..e891ec73 100644 --- a/modules/whisper/whisper_base.py +++ b/modules/whisper/whisper_base.py @@ -514,7 +514,7 @@ def remove_input_files(file_paths: List[str]): @staticmethod def cache_parameters( - whisper_params: WhisperValues, + whisper_params: TranscriptionPipelineParams, add_timestamp: bool ): """cache parameters to the yaml file""" diff --git a/tests/test_bgm_separation.py b/tests/test_bgm_separation.py index a8178ea6..504be03f 100644 --- a/tests/test_bgm_separation.py +++ b/tests/test_bgm_separation.py @@ -1,6 +1,6 @@ from modules.utils.paths 
import * from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.data_classes import WhisperValues +from modules.whisper.data_classes import TranscriptionPipelineParams from test_config import * from test_transcription import download_file, test_transcribe diff --git a/tests/test_diarization.py b/tests/test_diarization.py index 2a4c77af..daf41475 100644 --- a/tests/test_diarization.py +++ b/tests/test_diarization.py @@ -1,6 +1,6 @@ from modules.utils.paths import * from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.data_classes import WhisperValues +from modules.whisper.data_classes import TranscriptionPipelineParams from test_config import * from test_transcription import download_file, test_transcribe diff --git a/tests/test_transcription.py b/tests/test_transcription.py index 1e1560e7..dc85c2df 100644 --- a/tests/test_transcription.py +++ b/tests/test_transcription.py @@ -1,5 +1,5 @@ from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.data_classes import WhisperValues +from modules.whisper.data_classes import TranscriptionPipelineParams from modules.utils.paths import WEBUI_DIR from test_config import * @@ -37,7 +37,7 @@ def test_transcribe( f"""Diarization Device: {whisper_inferencer.diarizer.device}""" ) - hparams = WhisperValues( + hparams = TranscriptionPipelineParams( model_size=TEST_WHISPER_MODEL, vad_filter=vad_filter, is_bgm_separate=bgm_separation, diff --git a/tests/test_vad.py b/tests/test_vad.py index d2a30df6..fd7fd6c9 100644 --- a/tests/test_vad.py +++ b/tests/test_vad.py @@ -1,6 +1,6 @@ from modules.utils.paths import * from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.data_classes import WhisperValues +from modules.whisper.data_classes import TranscriptionPipelineParams from test_config import * from test_transcription import download_file, test_transcribe From df6fc6fdb9255f827b63c2ded3c622ba0ca1adcc Mon Sep 17 00:00:00 2001 From: 
jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Sun, 27 Oct 2024 22:40:41 +0900 Subject: [PATCH 05/41] Refactor dataclasses --- modules/whisper/data_classes.py | 801 ++++++++++++++++++-------------- 1 file changed, 462 insertions(+), 339 deletions(-) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index 19115fc2..fd3d6210 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -1,371 +1,494 @@ -from dataclasses import dataclass, fields import gradio as gr -from typing import Optional, Dict +import torch +from typing import Optional, Dict, List +from pydantic import BaseModel, Field, field_validator +from gradio_i18n import Translate, gettext as _ +from enum import Enum import yaml from modules.utils.constants import AUTOMATIC_DETECTION -@dataclass -class WhisperParameters: - model_size: gr.Dropdown - lang: gr.Dropdown - is_translate: gr.Checkbox - beam_size: gr.Number - log_prob_threshold: gr.Number - no_speech_threshold: gr.Number - compute_type: gr.Dropdown - best_of: gr.Number - patience: gr.Number - condition_on_previous_text: gr.Checkbox - prompt_reset_on_temperature: gr.Slider - initial_prompt: gr.Textbox - temperature: gr.Slider - compression_ratio_threshold: gr.Number - vad_filter: gr.Checkbox - threshold: gr.Slider - min_speech_duration_ms: gr.Number - max_speech_duration_s: gr.Number - min_silence_duration_ms: gr.Number - speech_pad_ms: gr.Number - batch_size: gr.Number - is_diarize: gr.Checkbox - hf_token: gr.Textbox - diarization_device: gr.Dropdown - length_penalty: gr.Number - repetition_penalty: gr.Number - no_repeat_ngram_size: gr.Number - prefix: gr.Textbox - suppress_blank: gr.Checkbox - suppress_tokens: gr.Textbox - max_initial_timestamp: gr.Number - word_timestamps: gr.Checkbox - prepend_punctuations: gr.Textbox - append_punctuations: gr.Textbox - max_new_tokens: gr.Number - chunk_length: gr.Number - hallucination_silence_threshold: gr.Number - hotwords: gr.Textbox - 
language_detection_threshold: gr.Number - language_detection_segments: gr.Number - is_bgm_separate: gr.Checkbox - uvr_model_size: gr.Dropdown - uvr_device: gr.Dropdown - uvr_segment_size: gr.Number - uvr_save_file: gr.Checkbox - uvr_enable_offload: gr.Checkbox - """ - A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing. - This data class is used to mitigate the key-value problem between Gradio components and function parameters. - Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471 - See more about Gradio pre-processing: https://www.gradio.app/docs/components +class WhisperImpl(Enum): + WHISPER = "whisper" + FASTER_WHISPER = "faster-whisper" + INSANELY_FAST_WHISPER = "insanely_fast_whisper" - Attributes - ---------- - model_size: gr.Dropdown - Whisper model size. - - lang: gr.Dropdown - Source language of the file to transcribe. - - is_translate: gr.Checkbox - Boolean value that determines whether to translate to English. - It's Whisper's feature to translate speech from another language directly into English end-to-end. - - beam_size: gr.Number - Int value that is used for decoding option. - - log_prob_threshold: gr.Number - If the average log probability over sampled tokens is below this value, treat as failed. - - no_speech_threshold: gr.Number - If the no_speech probability is higher than this value AND - the average log probability over sampled tokens is below `log_prob_threshold`, - consider the segment as silent. - - compute_type: gr.Dropdown - compute type for transcription. - see more info : https://opennmt.net/CTranslate2/quantization.html - - best_of: gr.Number - Number of candidates when sampling with non-zero temperature. - - patience: gr.Number - Beam search patience factor. 
- - condition_on_previous_text: gr.Checkbox - if True, the previous output of the model is provided as a prompt for the next window; - disabling may make the text inconsistent across windows, but the model becomes less prone to - getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. - - initial_prompt: gr.Textbox - Optional text to provide as a prompt for the first window. This can be used to provide, or - "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns - to make it more likely to predict those word correctly. - - temperature: gr.Slider - Temperature for sampling. It can be a tuple of temperatures, - which will be successively used upon failures according to either - `compression_ratio_threshold` or `log_prob_threshold`. - - compression_ratio_threshold: gr.Number - If the gzip compression ratio is above this value, treat as failed - - vad_filter: gr.Checkbox - Enable the voice activity detection (VAD) to filter out parts of the audio - without speech. This step is using the Silero VAD model - https://github.com/snakers4/silero-vad. - - threshold: gr.Slider - This parameter is related with Silero VAD. Speech threshold. - Silero VAD outputs speech probabilities for each audio chunk, - probabilities ABOVE this value are considered as SPEECH. It is better to tune this - parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. - - min_speech_duration_ms: gr.Number - This parameter is related with Silero VAD. Final speech chunks shorter min_speech_duration_ms are thrown out. - - max_speech_duration_s: gr.Number - This parameter is related with Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer - than max_speech_duration_s will be split at the timestamp of the last silence that - lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be - split aggressively just before max_speech_duration_s. 
- - min_silence_duration_ms: gr.Number - This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms - before separating it - - speech_pad_ms: gr.Number - This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side - - batch_size: gr.Number - This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe - - is_diarize: gr.Checkbox - This parameter is related with whisperx. Boolean value that determines whether to diarize or not. - - hf_token: gr.Textbox - This parameter is related with whisperx. Huggingface token is needed to download diarization models. - Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements - - diarization_device: gr.Dropdown - This parameter is related with whisperx. Device to run diarization model - - length_penalty: gr.Number - This parameter is related to faster-whisper. Exponential length penalty constant. - - repetition_penalty: gr.Number - This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens - (set > 1 to penalize). - no_repeat_ngram_size: gr.Number - This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable). +class VadParams(BaseModel): + """Voice Activity Detection parameters""" + vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts") + threshold: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Speech threshold for Silero VAD. 
Probabilities above this value are considered speech" + ) + min_speech_duration_ms: int = Field( + default=250, + ge=0, + description="Final speech chunks shorter than this are discarded" + ) + max_speech_duration_s: float = Field( + default=float("inf"), + gt=0, + description="Maximum duration of speech chunks in seconds" + ) + min_silence_duration_ms: int = Field( + default=2000, + ge=0, + description="Minimum silence duration between speech chunks" + ) + speech_pad_ms: int = Field( + default=400, + ge=0, + description="Padding added to each side of speech chunks" + ) - prefix: gr.Textbox - This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window. + def to_dict(self) -> Dict: + return self.model_dump() - suppress_blank: gr.Checkbox - This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling. + @classmethod + def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]: + defaults = defaults or {} + return [ + gr.Checkbox(label=_("Enable Silero VAD Filter"), value=defaults.get("vad_filter", cls.vad_filter), + interactive=True, + info=_("Enable this to transcribe only detected voice")), + gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", + value=defaults.get("threshold", cls.threshold), + info="Lower it to be more sensitive to small sounds."), + gr.Number(label="Minimum Speech Duration (ms)", precision=0, + value=defaults.get("min_speech_duration_ms", cls.min_speech_duration_ms), + info="Final speech chunks shorter than this time are thrown out"), + gr.Number(label="Maximum Speech Duration (s)", + value=defaults.get("max_speech_duration_s", cls.max_speech_duration_s), + info="Maximum duration of speech chunks in \"seconds\"."), + gr.Number(label="Minimum Silence Duration (ms)", precision=0, + value=defaults.get("min_silence_duration_ms", cls.min_silence_duration_ms), + info="In the end of each speech chunk wait for 
this time" + " before separating it"), + gr.Number(label="Speech Padding (ms)", precision=0, + value=defaults.get("speech_pad_ms", cls.speech_pad_ms), + info="Final speech chunks are padded by this time each side") + ] - suppress_tokens: gr.Textbox - This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set - of symbols as defined in the model config.json file. - max_initial_timestamp: gr.Number - This parameter is related to faster-whisper. The initial timestamp cannot be later than this. - word_timestamps: gr.Checkbox - This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern - and dynamic time warping, and include the timestamps for each word in each segment. +class DiarizationParams(BaseModel): + """Speaker diarization parameters""" + is_diarize: bool = Field(default=False, description="Enable speaker diarization") + hf_token: str = Field( + default="", + description="Hugging Face token for downloading diarization models" + ) - prepend_punctuations: gr.Textbox - This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols - with the next word. + def to_dict(self) -> Dict: + return self.model_dump() - append_punctuations: gr.Textbox - This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols - with the previous word. 
+ @classmethod + def to_gradio_inputs(cls, + defaults: Optional[Dict] = None, + available_devices: Optional[List] = None, + device: Optional[str] = None) -> List[gr.components.base.FormComponent]: + defaults = defaults or {} + return [ + gr.Checkbox( + label=_("Enable Diarization"), + value=defaults.get("is_diarize", cls.is_diarize), + info=_("Enable speaker diarization") + ), + gr.Textbox( + label=_("HuggingFace Token"), + value=defaults.get("hf_token", cls.hf_token), + info=_("This is only needed the first time you download the model") + ), + gr.Dropdown( + label=_("Device"), + choices=["cpu", "cuda"] if available_devices is None else available_devices, + value="cuda" if device is None else device, + info=_("Device to run diarization model") + ) + ] - max_new_tokens: gr.Number - This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set, - the maximum will be set by the default max_length. - chunk_length: gr.Number - This parameter is related to faster-whisper and insanely-fast-whisper. The length of audio segments in seconds. - If it is not None, it will overwrite the default chunk_length of the FeatureExtractor. +class BGMSeparationParams(BaseModel): + """Background music separation parameters""" + is_separate_bgm: bool = Field(default=False, description="Enable background music separation") + model_size: str = Field( + default="UVR-MDX-NET-Inst_HQ_4", + description="UVR model size" + ) + segment_size: int = Field( + default=256, + gt=0, + description="Segment size for UVR model" + ) + save_file: bool = Field( + default=False, + description="Whether to save separated audio files" + ) + enable_offload: bool = Field( + default=True, + description="Offload UVR model after transcription" + ) - hallucination_silence_threshold: gr.Number - This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold - (in seconds) when a possible hallucination is detected. 
+ def to_dict(self) -> Dict: + return self.model_dump() - hotwords: gr.Textbox - This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None. + @classmethod + def to_gradio_input(cls, + defaults: Optional[Dict] = None, + available_devices: Optional[List] = None, + device: Optional[str] = None, + available_models: Optional[List] = None) -> List[gr.components.base.FormComponent]: + defaults = defaults or {} + return [ + gr.Checkbox( + label=_("Enable Background Music Remover Filter"), + value=defaults.get("is_separate_bgm", cls.is_separate_bgm), + interactive=True, + info=_("Enabling this will remove background music") + ), + gr.Dropdown( + label=_("Device"), + choices=["cpu", "cuda"] if available_devices is None else available_devices, + value="cuda" if device is None else device, + info=_("Device to run UVR model") + ), + gr.Dropdown( + label=_("Model"), + choices=["UVR-MDX-NET-Inst_HQ_4", "UVR-MDX-NET-Inst_3"] if available_models is None else available_models, + value=defaults.get("model_size", cls.model_size), + info=_("UVR model size") + ), + gr.Number( + label="Segment Size", + value=defaults.get("segment_size", cls.segment_size), + precision=0, + info="Segment size for UVR model" + ), + gr.Checkbox( + label=_("Save separated files to output"), + value=defaults.get("save_file", cls.save_file), + info=_("Whether to save separated audio files") + ), + gr.Checkbox( + label=_("Offload sub model after removing background music"), + value=defaults.get("enable_offload", cls.enable_offload), + info=_("Offload UVR model after transcription") + ) + ] - language_detection_threshold: gr.Number - This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected. - language_detection_segments: gr.Number - This parameter is related to faster-whisper. Number of segments to consider for the language detection. 
- - is_separate_bgm: gr.Checkbox - This parameter is related to UVR. Boolean value that determines whether to separate bgm or not. - - uvr_model_size: gr.Dropdown - This parameter is related to UVR. UVR model size. - - uvr_device: gr.Dropdown - This parameter is related to UVR. Device to run UVR model. - - uvr_segment_size: gr.Number - This parameter is related to UVR. Segment size for UVR model. - - uvr_save_file: gr.Checkbox - This parameter is related to UVR. Boolean value that determines whether to save the file or not. - - uvr_enable_offload: gr.Checkbox - This parameter is related to UVR. Boolean value that determines whether to offload the UVR model or not - after each transcription. - """ +class WhisperParams(BaseModel): + """Whisper parameters""" + model_size: str = Field(default="large-v2", description="Whisper model size") + lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe") + is_translate: bool = Field(default=False, description="Translate speech to English end-to-end") + beam_size: int = Field(default=5, ge=1, description="Beam size for decoding") + log_prob_threshold: float = Field( + default=-1.0, + description="Threshold for average log probability of sampled tokens" + ) + no_speech_threshold: float = Field( + default=0.6, + ge=0.0, + le=1.0, + description="Threshold for detecting silence" + ) + compute_type: str = Field(default="float16", description="Computation type for transcription") + best_of: int = Field(default=5, ge=1, description="Number of candidates when sampling") + patience: float = Field(default=1.0, gt=0, description="Beam search patience factor") + condition_on_previous_text: bool = Field( + default=True, + description="Use previous output as prompt for next window" + ) + prompt_reset_on_temperature: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Temperature threshold for resetting prompt" + ) + initial_prompt: Optional[str] = Field(default=None, description="Initial 
prompt for first window") + temperature: float = Field( + default=0.0, + ge=0.0, + description="Temperature for sampling" + ) + compression_ratio_threshold: float = Field( + default=2.4, + gt=0, + description="Threshold for gzip compression ratio" + ) + batch_size: int = Field(default=24, gt=0, description="Batch size for processing") + length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty") + repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens") + no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition") + prefix: Optional[str] = Field(default=None, description="Prefix text for first window") + suppress_blank: bool = Field( + default=True, + description="Suppress blank outputs at start of sampling" + ) + suppress_tokens: Optional[str] = Field(default="[-1]", description="Token IDs to suppress") + max_initial_timestamp: float = Field( + default=0.0, + ge=0.0, + description="Maximum initial timestamp" + ) + word_timestamps: bool = Field(default=False, description="Extract word-level timestamps") + prepend_punctuations: Optional[str] = Field( + default="\"'“¿([{-", + description="Punctuations to merge with next word" + ) + append_punctuations: Optional[str] = Field( + default="\"'.。,,!!??::”)]}、", + description="Punctuations to merge with previous word" + ) + max_new_tokens: Optional[int] = Field(default=None, description="Maximum number of new tokens per chunk") + chunk_length: Optional[int] = Field(default=30, description="Length of audio segments in seconds") + hallucination_silence_threshold: Optional[float] = Field( + default=None, + description="Threshold for skipping silent periods in hallucination detection" + ) + hotwords: Optional[str] = Field(default=None, description="Hotwords/hint phrases for the model") + language_detection_threshold: Optional[float] = Field( + default=None, + description="Threshold for language detection 
probability" + ) + language_detection_segments: int = Field( + default=1, + gt=0, + description="Number of segments for language detection" + ) - def as_list(self) -> list: - """ - Converts the data class attributes into a list, Use in Gradio UI before Gradio pre-processing. - See more about Gradio pre-processing: : https://www.gradio.app/docs/components + def to_dict(self): + return self.model_dump() - Returns - ---------- - A list of Gradio components - """ - return [getattr(self, f.name) for f in fields(self)] + @field_validator('lang') + def validate_lang(cls, v): + from modules.utils.constants import AUTOMATIC_DETECTION + return None if v == AUTOMATIC_DETECTION.unwrap() else v - @staticmethod - def as_value(*args) -> 'WhisperValues': - """ - To use Whisper parameters in function after Gradio post-processing. - See more about Gradio post-processing: : https://www.gradio.app/docs/components + @classmethod + def to_gradio_inputs(cls, + defaults: Optional[Dict] = None, + only_advanced: Optional[bool] = True, + whisper_type: Optional[WhisperImpl] = None): + defaults = {} if defaults is None else defaults + whisper_type = WhisperImpl.FASTER_WHISPER if whisper_type is None else whisper_type - Returns - ---------- - WhisperValues - Data class that has values of parameters - """ - return WhisperValues(*args) + inputs = [] + if not only_advanced: + inputs += [ + gr.Dropdown( + label="Model Size", + choices=["small", "medium", "large-v2"], + value=defaults.get("model_size", cls.model_size), + info="Whisper model size" + ), + gr.Textbox( + label="Language", + value=defaults.get("lang", cls.lang), + info="Source language of the file to transcribe" + ), + gr.Checkbox( + label="Translate to English", + value=defaults.get("is_translate", cls.is_translate), + info="Translate speech to English end-to-end" + ), + ] + inputs += [ + gr.Number( + label="Beam Size", + value=defaults.get("beam_size", cls.beam_size), + precision=0, + info="Beam size for decoding" + ), + gr.Number( + 
label="Log Probability Threshold", + value=defaults.get("log_prob_threshold", cls.log_prob_threshold), + info="Threshold for average log probability of sampled tokens" + ), + gr.Number( + label="No Speech Threshold", + value=defaults.get("no_speech_threshold", cls.no_speech_threshold), + info="Threshold for detecting silence" + ), + gr.Dropdown( + label="Compute Type", + choices=["float16", "int8", "int16"], + value=defaults.get("compute_type", cls.compute_type), + info="Computation type for transcription" + ), + gr.Number( + label="Best Of", + value=defaults.get("best_of", cls.best_of), + precision=0, + info="Number of candidates when sampling" + ), + gr.Number( + label="Patience", + value=defaults.get("patience", cls.patience), + info="Beam search patience factor" + ), + gr.Checkbox( + label="Condition On Previous Text", + value=defaults.get("condition_on_previous_text", cls.condition_on_previous_text), + info="Use previous output as prompt for next window" + ), + gr.Slider( + label="Prompt Reset On Temperature", + value=defaults.get("prompt_reset_on_temperature", cls.prompt_reset_on_temperature), + minimum=0, + maximum=1, + step=0.01, + info="Temperature threshold for resetting prompt" + ), + gr.Textbox( + label="Initial Prompt", + value=defaults.get("initial_prompt", cls.initial_prompt), + info="Initial prompt for first window" + ), + gr.Slider( + label="Temperature", + value=defaults.get("temperature", cls.temperature), + minimum=0.0, + step=0.01, + maximum=1.0, + info="Temperature for sampling" + ), + gr.Number( + label="Compression Ratio Threshold", + value=defaults.get("compression_ratio_threshold", cls.compression_ratio_threshold), + info="Threshold for gzip compression ratio" + ) + ] + if whisper_type == WhisperImpl.FASTER_WHISPER: + inputs += [ + gr.Number( + label="Length Penalty", + value=defaults.get("length_penalty", cls.length_penalty), + info="Exponential length penalty", + visible=whisper_type=="faster_whisper" + ), + gr.Number( + 
label="Repetition Penalty", + value=defaults.get("repetition_penalty", cls.repetition_penalty), + info="Penalty for repeated tokens" + ), + gr.Number( + label="No Repeat N-gram Size", + value=defaults.get("no_repeat_ngram_size", cls.no_repeat_ngram_size), + precision=0, + info="Size of n-grams to prevent repetition" + ), + gr.Textbox( + label="Prefix", + value=defaults.get("prefix", cls.prefix), + info="Prefix text for first window" + ), + gr.Checkbox( + label="Suppress Blank", + value=defaults.get("suppress_blank", cls.suppress_blank), + info="Suppress blank outputs at start of sampling" + ), + gr.Textbox( + label="Suppress Tokens", + value=defaults.get("suppress_tokens", cls.suppress_tokens), + info="Token IDs to suppress" + ), + gr.Number( + label="Max Initial Timestamp", + value=defaults.get("max_initial_timestamp", cls.max_initial_timestamp), + info="Maximum initial timestamp" + ), + gr.Checkbox( + label="Word Timestamps", + value=defaults.get("word_timestamps", cls.word_timestamps), + info="Extract word-level timestamps" + ), + gr.Textbox( + label="Prepend Punctuations", + value=defaults.get("prepend_punctuations", cls.prepend_punctuations), + info="Punctuations to merge with next word" + ), + gr.Textbox( + label="Append Punctuations", + value=defaults.get("append_punctuations", cls.append_punctuations), + info="Punctuations to merge with previous word" + ), + gr.Number( + label="Max New Tokens", + value=defaults.get("max_new_tokens", cls.max_new_tokens), + precision=0, + info="Maximum number of new tokens per chunk" + ), + gr.Number( + label="Chunk Length (s)", + value=defaults.get("chunk_length", cls.chunk_length), + precision=0, + info="Length of audio segments in seconds" + ), + gr.Number( + label="Hallucination Silence Threshold (sec)", + value=defaults.get("hallucination_silence_threshold", cls.hallucination_silence_threshold), + info="Threshold for skipping silent periods in hallucination detection" + ), + gr.Textbox( + label="Hotwords", + 
value=defaults.get("hotwords", cls.hotwords), + info="Hotwords/hint phrases for the model" + ), + gr.Number( + label="Language Detection Threshold", + value=defaults.get("language_detection_threshold", cls.language_detection_threshold), + info="Threshold for language detection probability" + ), + gr.Number( + label="Language Detection Segments", + value=defaults.get("language_detection_segments", cls.language_detection_segments), + precision=0, + info="Number of segments for language detection" + ) + ] -@dataclass -class WhisperValues: - model_size: str = "large-v2" - lang: Optional[str] = None - is_translate: bool = False - beam_size: int = 5 - log_prob_threshold: float = -1.0 - no_speech_threshold: float = 0.6 - compute_type: str = "float16" - best_of: int = 5 - patience: float = 1.0 - condition_on_previous_text: bool = True - prompt_reset_on_temperature: float = 0.5 - initial_prompt: Optional[str] = None - temperature: float = 0.0 - compression_ratio_threshold: float = 2.4 - vad_filter: bool = False - threshold: float = 0.5 - min_speech_duration_ms: int = 250 - max_speech_duration_s: float = float("inf") - min_silence_duration_ms: int = 2000 - speech_pad_ms: int = 400 - batch_size: int = 24 - is_diarize: bool = False - hf_token: str = "" - diarization_device: str = "cuda" - length_penalty: float = 1.0 - repetition_penalty: float = 1.0 - no_repeat_ngram_size: int = 0 - prefix: Optional[str] = None - suppress_blank: bool = True - suppress_tokens: Optional[str] = "[-1]" - max_initial_timestamp: float = 0.0 - word_timestamps: bool = False - prepend_punctuations: Optional[str] = "\"'“¿([{-" - append_punctuations: Optional[str] = "\"'.。,,!!??::”)]}、" - max_new_tokens: Optional[int] = None - chunk_length: Optional[int] = 30 - hallucination_silence_threshold: Optional[float] = None - hotwords: Optional[str] = None - language_detection_threshold: Optional[float] = None - language_detection_segments: int = 1 - is_bgm_separate: bool = False - uvr_model_size: str = 
"UVR-MDX-NET-Inst_HQ_4" - uvr_device: str = "cuda" - uvr_segment_size: int = 256 - uvr_save_file: bool = False - uvr_enable_offload: bool = True - """ - A data class to use Whisper parameters. - """ + if whisper_type == WhisperImpl.INSANELY_FAST_WHISPER: + inputs += [ + gr.Number( + label="Batch Size", + value=defaults.get("batch_size", cls.batch_size), + precision=0, + info="Batch size for processing", + visible=whisper_type == "insanely_fast_whisper" + ) + ] + return inputs - def to_yaml(self) -> Dict: + +class TranscriptionPipelineParams(BaseModel): + """Transcription pipeline parameters""" + whisper: WhisperParams = Field(default_factory=WhisperParams) + vad: VadParams = Field(default_factory=VadParams) + diarization: DiarizationParams = Field(default_factory=DiarizationParams) + bgm_separation: BGMSeparationParams = Field(default_factory=BGMSeparationParams) + + def to_dict(self) -> Dict: data = { - "whisper": { - "model_size": self.model_size, - "lang": AUTOMATIC_DETECTION.unwrap() if self.lang is None else self.lang, - "is_translate": self.is_translate, - "beam_size": self.beam_size, - "log_prob_threshold": self.log_prob_threshold, - "no_speech_threshold": self.no_speech_threshold, - "best_of": self.best_of, - "patience": self.patience, - "condition_on_previous_text": self.condition_on_previous_text, - "prompt_reset_on_temperature": self.prompt_reset_on_temperature, - "initial_prompt": None if not self.initial_prompt else self.initial_prompt, - "temperature": self.temperature, - "compression_ratio_threshold": self.compression_ratio_threshold, - "batch_size": self.batch_size, - "length_penalty": self.length_penalty, - "repetition_penalty": self.repetition_penalty, - "no_repeat_ngram_size": self.no_repeat_ngram_size, - "prefix": None if not self.prefix else self.prefix, - "suppress_blank": self.suppress_blank, - "suppress_tokens": self.suppress_tokens, - "max_initial_timestamp": self.max_initial_timestamp, - "word_timestamps": self.word_timestamps, - 
"prepend_punctuations": self.prepend_punctuations, - "append_punctuations": self.append_punctuations, - "max_new_tokens": self.max_new_tokens, - "chunk_length": self.chunk_length, - "hallucination_silence_threshold": self.hallucination_silence_threshold, - "hotwords": None if not self.hotwords else self.hotwords, - "language_detection_threshold": self.language_detection_threshold, - "language_detection_segments": self.language_detection_segments, - }, - "vad": { - "vad_filter": self.vad_filter, - "threshold": self.threshold, - "min_speech_duration_ms": self.min_speech_duration_ms, - "max_speech_duration_s": self.max_speech_duration_s, - "min_silence_duration_ms": self.min_silence_duration_ms, - "speech_pad_ms": self.speech_pad_ms, - }, - "diarization": { - "is_diarize": self.is_diarize, - "hf_token": self.hf_token - }, - "bgm_separation": { - "is_separate_bgm": self.is_bgm_separate, - "model_size": self.uvr_model_size, - "segment_size": self.uvr_segment_size, - "save_file": self.uvr_save_file, - "enable_offload": self.uvr_enable_offload - }, + "whisper": self.whisper.to_dict(), + "vad": self.vad.to_dict(), + "diarization": self.diarization.to_dict(), + "bgm_separation": self.bgm_separation.to_dict() } return data - def as_list(self) -> list: - """ - Converts the data class attributes into a list - - Returns - ---------- - A list of Whisper parameters - """ - return [getattr(self, f.name) for f in fields(self)] + # def as_list(self) -> list: + # """ + # Converts the data class attributes into a list + # + # Returns + # ---------- + # A list of Whisper parameters + # """ + # return [getattr(self, f.name) for f in fields(self)] From db355d3824025bc0eb48719b8a56b2937fa2eaac Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Sun, 27 Oct 2024 23:55:13 +0900 Subject: [PATCH 06/41] Add `as_list()` to use in gradio function --- modules/whisper/data_classes.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) 
diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index fd3d6210..79363d86 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -274,7 +274,9 @@ def validate_lang(cls, v): def to_gradio_inputs(cls, defaults: Optional[Dict] = None, only_advanced: Optional[bool] = True, - whisper_type: Optional[WhisperImpl] = None): + whisper_type: Optional[WhisperImpl] = None, + available_compute_types: Optional[List] = None, + compute_type: Optional[str] = None): defaults = {} if defaults is None else defaults whisper_type = WhisperImpl.FASTER_WHISPER if whisper_type is None else whisper_type @@ -318,8 +320,8 @@ def to_gradio_inputs(cls, ), gr.Dropdown( label="Compute Type", - choices=["float16", "int8", "int16"], - value=defaults.get("compute_type", cls.compute_type), + choices=["float16", "int8", "int16"] if available_compute_types is None else available_compute_types, + value=defaults.get("compute_type", compute_type), info="Computation type for transcription" ), gr.Number( @@ -483,12 +485,9 @@ def to_dict(self) -> Dict: } return data - # def as_list(self) -> list: - # """ - # Converts the data class attributes into a list - # - # Returns - # ---------- - # A list of Whisper parameters - # """ - # return [getattr(self, f.name) for f in fields(self)] + def as_list(self) -> List: + whisper_list = [value for key, value in self.whisper.to_dict().items()] + vad_list = [value for key, value in self.vad.to_dict().items()] + diarization_list = [value for key, value in self.vad.to_dict().items()] + bgm_sep_list = [value for key, value in self.bgm_separation.to_dict().items()] + return whisper_list + vad_list + diarization_list + bgm_sep_list From 3be2b51e9c73bda0d487894091ac4dd3c425ba50 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 00:39:51 +0900 Subject: [PATCH 07/41] Add `from_list()` to use in gradio function --- modules/whisper/data_classes.py | 64 
+++++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index 79363d86..88aece04 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field, field_validator from gradio_i18n import Translate, gettext as _ from enum import Enum +from copy import deepcopy import yaml from modules.utils.constants import AUTOMATIC_DETECTION @@ -15,7 +16,20 @@ class WhisperImpl(Enum): INSANELY_FAST_WHISPER = "insanely_fast_whisper" -class VadParams(BaseModel): +class BaseParams(BaseModel): + def to_dict(self) -> Dict: + return self.model_dump() + + def to_list(self) -> List: + return list(self.model_dump().values()) + + @classmethod + def from_list(cls, data_list: List) -> 'BaseParams': + field_names = list(cls.model_fields.keys()) + return cls(**dict(zip(field_names, data_list))) + + +class VadParams(BaseParams): """Voice Activity Detection parameters""" vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts") threshold: float = Field( @@ -45,9 +59,6 @@ class VadParams(BaseModel): description="Padding added to each side of speech chunks" ) - def to_dict(self) -> Dict: - return self.model_dump() - @classmethod def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]: defaults = defaults or {} @@ -74,8 +85,7 @@ def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components ] - -class DiarizationParams(BaseModel): +class DiarizationParams(BaseParams): """Speaker diarization parameters""" is_diarize: bool = Field(default=False, description="Enable speaker diarization") hf_token: str = Field( @@ -83,9 +93,6 @@ class DiarizationParams(BaseModel): description="Hugging Face token for downloading diarization models" ) - def to_dict(self) -> Dict: - return self.model_dump() - @classmethod def 
to_gradio_inputs(cls, defaults: Optional[Dict] = None, @@ -112,7 +119,7 @@ def to_gradio_inputs(cls, ] -class BGMSeparationParams(BaseModel): +class BGMSeparationParams(BaseParams): """Background music separation parameters""" is_separate_bgm: bool = Field(default=False, description="Enable background music separation") model_size: str = Field( @@ -133,9 +140,6 @@ class BGMSeparationParams(BaseModel): description="Offload UVR model after transcription" ) - def to_dict(self) -> Dict: - return self.model_dump() - @classmethod def to_gradio_input(cls, defaults: Optional[Dict] = None, @@ -181,7 +185,7 @@ def to_gradio_input(cls, ] -class WhisperParams(BaseModel): +class WhisperParams(BaseParams): """Whisper parameters""" model_size: str = Field(default="large-v2", description="Whisper model size") lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe") @@ -262,9 +266,6 @@ class WhisperParams(BaseModel): description="Number of segments for language detection" ) - def to_dict(self): - return self.model_dump() - @field_validator('lang') def validate_lang(cls, v): from modules.utils.constants import AUTOMATIC_DETECTION @@ -485,9 +486,36 @@ def to_dict(self) -> Dict: } return data - def as_list(self) -> List: + def to_list(self) -> List: + """ + Convert data class to the list because I have to pass the parameters as a list in the gradio. 
+ Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471 + See more about Gradio pre-processing: https://www.gradio.app/docs/components + """ whisper_list = [value for key, value in self.whisper.to_dict().items()] vad_list = [value for key, value in self.vad.to_dict().items()] diarization_list = [value for key, value in self.vad.to_dict().items()] bgm_sep_list = [value for key, value in self.bgm_separation.to_dict().items()] return whisper_list + vad_list + diarization_list + bgm_sep_list + + @staticmethod + def from_list(pipeline_list: List) -> 'TranscriptionPipelineParams': + """Convert list to the data class again to use it in a function.""" + data_list = deepcopy(pipeline_list) + whisper_list, data_list = data_list[0:len(WhisperParams.__annotations__)] + data_list = data_list[len(WhisperParams.__annotations__):] + + vad_list = data_list[0:len(VadParams.__annotations__)] + data_list = data_list[len(VadParams.__annotations__):] + + diarization_list = data_list[0:len(DiarizationParams.__annotations__)] + data_list = data_list[len(DiarizationParams.__annotations__)] + + bgm_sep_list = data_list[0:len(BGMSeparationParams.__annotations__)] + + return TranscriptionPipelineParams( + whisper=WhisperParams.from_list(whisper_list), + vad=VadParams.from_list(vad_list), + diarization=DiarizationParams.from_list(diarization_list), + bgm_separation=BGMSeparationParams.from_list(bgm_sep_list) + ) From 25f5dd698625e5977689d0724bc8b286bb4d8347 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 01:02:51 +0900 Subject: [PATCH 08/41] Remove meaningless line --- modules/whisper/data_classes.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index 88aece04..cf7424f0 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -61,7 +61,6 @@ class VadParams(BaseParams): @classmethod def to_gradio_inputs(cls, defaults: 
Optional[Dict] = None) -> List[gr.components.base.FormComponent]: - defaults = defaults or {} return [ gr.Checkbox(label=_("Enable Silero VAD Filter"), value=defaults.get("vad_filter", cls.vad_filter), interactive=True, @@ -98,7 +97,6 @@ def to_gradio_inputs(cls, defaults: Optional[Dict] = None, available_devices: Optional[List] = None, device: Optional[str] = None) -> List[gr.components.base.FormComponent]: - defaults = defaults or {} return [ gr.Checkbox( label=_("Enable Diarization"), @@ -146,7 +144,6 @@ def to_gradio_input(cls, available_devices: Optional[List] = None, device: Optional[str] = None, available_models: Optional[List] = None) -> List[gr.components.base.FormComponent]: - defaults = defaults or {} return [ gr.Checkbox( label=_("Enable Background Music Remover Filter"), @@ -278,7 +275,6 @@ def to_gradio_inputs(cls, whisper_type: Optional[WhisperImpl] = None, available_compute_types: Optional[List] = None, compute_type: Optional[str] = None): - defaults = {} if defaults is None else defaults whisper_type = WhisperImpl.FASTER_WHISPER if whisper_type is None else whisper_type inputs = [] From f137136e8f8cf5a953ef2a58b691e1827013620d Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 01:21:42 +0900 Subject: [PATCH 09/41] Add missing parameter --- modules/whisper/data_classes.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index cf7424f0..0944f0cc 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -87,6 +87,7 @@ def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components class DiarizationParams(BaseParams): """Speaker diarization parameters""" is_diarize: bool = Field(default=False, description="Enable speaker diarization") + device: str = Field(default="cuda", description="Device to run Diarization model.") hf_token: str = Field( default="", 
description="Hugging Face token for downloading diarization models" @@ -95,8 +96,7 @@ class DiarizationParams(BaseParams): @classmethod def to_gradio_inputs(cls, defaults: Optional[Dict] = None, - available_devices: Optional[List] = None, - device: Optional[str] = None) -> List[gr.components.base.FormComponent]: + available_devices: Optional[List] = None) -> List[gr.components.base.FormComponent]: return [ gr.Checkbox( label=_("Enable Diarization"), @@ -111,7 +111,7 @@ def to_gradio_inputs(cls, gr.Dropdown( label=_("Device"), choices=["cpu", "cuda"] if available_devices is None else available_devices, - value="cuda" if device is None else device, + value=defaults.get("device", cls.device), info=_("Device to run diarization model") ) ] @@ -124,6 +124,7 @@ class BGMSeparationParams(BaseParams): default="UVR-MDX-NET-Inst_HQ_4", description="UVR model size" ) + device: str = Field(default="cuda", description="Device to run UVR model.") segment_size: int = Field( default=256, gt=0, @@ -142,7 +143,6 @@ class BGMSeparationParams(BaseParams): def to_gradio_input(cls, defaults: Optional[Dict] = None, available_devices: Optional[List] = None, - device: Optional[str] = None, available_models: Optional[List] = None) -> List[gr.components.base.FormComponent]: return [ gr.Checkbox( @@ -154,7 +154,7 @@ def to_gradio_input(cls, gr.Dropdown( label=_("Device"), choices=["cpu", "cuda"] if available_devices is None else available_devices, - value="cuda" if device is None else device, + value=defaults.get("device", cls.device), info=_("Device to run UVR model") ), gr.Dropdown( From 37be7739d94e19874115058ee62c2ed368821e87 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 02:49:03 +0900 Subject: [PATCH 10/41] Fix to_list error --- modules/whisper/data_classes.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index 0944f0cc..3e94f8a7 
100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -488,24 +488,25 @@ def to_list(self) -> List: Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471 See more about Gradio pre-processing: https://www.gradio.app/docs/components """ - whisper_list = [value for key, value in self.whisper.to_dict().items()] - vad_list = [value for key, value in self.vad.to_dict().items()] - diarization_list = [value for key, value in self.vad.to_dict().items()] - bgm_sep_list = [value for key, value in self.bgm_separation.to_dict().items()] + whisper_list = self.whisper.to_list() + vad_list = self.vad.to_list() + diarization_list = self.diarization.to_list() + bgm_sep_list = self.bgm_separation.to_list() return whisper_list + vad_list + diarization_list + bgm_sep_list @staticmethod def from_list(pipeline_list: List) -> 'TranscriptionPipelineParams': """Convert list to the data class again to use it in a function.""" data_list = deepcopy(pipeline_list) - whisper_list, data_list = data_list[0:len(WhisperParams.__annotations__)] + + whisper_list = data_list[0:len(WhisperParams.__annotations__)] data_list = data_list[len(WhisperParams.__annotations__):] vad_list = data_list[0:len(VadParams.__annotations__)] data_list = data_list[len(VadParams.__annotations__):] diarization_list = data_list[0:len(DiarizationParams.__annotations__)] - data_list = data_list[len(DiarizationParams.__annotations__)] + data_list = data_list[len(DiarizationParams.__annotations__):] bgm_sep_list = data_list[0:len(BGMSeparationParams.__annotations__)] From 501c4045549d2bf9b74711a6d88001161b46d86a Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 02:50:22 +0900 Subject: [PATCH 11/41] Update model usage --- modules/whisper/faster_whisper_inference.py | 2 +- .../insanely_fast_whisper_inference.py | 2 +- modules/whisper/whisper_Inference.py | 2 +- modules/whisper/whisper_base.py | 73 ++++++++++--------- 
tests/test_transcription.py | 24 ++++-- 5 files changed, 57 insertions(+), 46 deletions(-) diff --git a/modules/whisper/faster_whisper_inference.py b/modules/whisper/faster_whisper_inference.py index 94eb18f5..524cd247 100644 --- a/modules/whisper/faster_whisper_inference.py +++ b/modules/whisper/faster_whisper_inference.py @@ -62,7 +62,7 @@ def transcribe(self, """ start_time = time.time() - params = TranscriptionPipelineGradioComponents.as_value(*whisper_params) + params = WhisperParams.from_list(list(whisper_params)) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) diff --git a/modules/whisper/insanely_fast_whisper_inference.py b/modules/whisper/insanely_fast_whisper_inference.py index c7476d10..57483f44 100644 --- a/modules/whisper/insanely_fast_whisper_inference.py +++ b/modules/whisper/insanely_fast_whisper_inference.py @@ -61,7 +61,7 @@ def transcribe(self, elapsed time for transcription """ start_time = time.time() - params = TranscriptionPipelineGradioComponents.as_value(*whisper_params) + params = WhisperParams.from_list(list(whisper_params)) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) diff --git a/modules/whisper/whisper_Inference.py b/modules/whisper/whisper_Inference.py index 338e04b7..8f567b49 100644 --- a/modules/whisper/whisper_Inference.py +++ b/modules/whisper/whisper_Inference.py @@ -51,7 +51,7 @@ def transcribe(self, elapsed time for transcription """ start_time = time.time() - params = TranscriptionPipelineGradioComponents.as_value(*whisper_params) + params = WhisperParams.from_list(list(whisper_params)) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, 
params.compute_type, progress) diff --git a/modules/whisper/whisper_base.py b/modules/whisper/whisper_base.py index e891ec73..08ed0453 100644 --- a/modules/whisper/whisper_base.py +++ b/modules/whisper/whisper_base.py @@ -74,7 +74,7 @@ def run(self, audio: Union[str, BinaryIO, np.ndarray], progress: gr.Progress = gr.Progress(), add_timestamp: bool = True, - *whisper_params, + *pipeline_params, ) -> Tuple[List[dict], float]: """ Run transcription with conditional pre-processing and post-processing. @@ -89,8 +89,8 @@ def run(self, Indicator to show progress directly in gradio. add_timestamp: bool Whether to add a timestamp at the end of the filename. - *whisper_params: tuple - Parameters related with whisper. This will be dealt with "WhisperParameters" data class + *pipeline_params: tuple + Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class Returns ---------- @@ -99,28 +99,29 @@ def run(self, elapsed_time: float elapsed time for running """ - params = TranscriptionPipelineGradioComponents.as_value(*whisper_params) + params = TranscriptionPipelineParams.from_list(list(pipeline_params)) + bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization self.cache_parameters( - whisper_params=params, + params=params, add_timestamp=add_timestamp ) - if params.lang is None: + if whisper_params.lang is None: pass - elif params.lang == AUTOMATIC_DETECTION: - params.lang = None + elif whisper_params.lang == AUTOMATIC_DETECTION: + whisper_params.lang = None else: language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()} - params.lang = language_code_dict[params.lang] + whisper_params.lang = language_code_dict[params.lang] - if params.is_bgm_separate: + if bgm_params.is_separate_bgm: music, audio, _ = self.music_separator.separate( audio=audio, - model_name=params.uvr_model_size, - device=params.uvr_device, - 
segment_size=params.uvr_segment_size, - save_file=params.uvr_save_file, + model_name=bgm_params.model_size, + device=bgm_params.device, + segment_size=bgm_params.segment_size, + save_file=bgm_params.save_file, progress=progress ) @@ -132,20 +133,20 @@ def run(self, origin_sample_rate = self.music_separator.audio_info.sample_rate audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate) - if params.uvr_enable_offload: + if bgm_params.enable_offload: self.music_separator.offload() - if params.vad_filter: + if vad_params.vad_filter: # Explicit value set for float('inf') from gr.Number() - if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999: - params.max_speech_duration_s = float('inf') + if vad_params.max_speech_duration_s is None or vad_params.max_speech_duration_s >= 9999: + vad_params.max_speech_duration_s = float('inf') vad_options = VadOptions( - threshold=params.threshold, - min_speech_duration_ms=params.min_speech_duration_ms, - max_speech_duration_s=params.max_speech_duration_s, - min_silence_duration_ms=params.min_silence_duration_ms, - speech_pad_ms=params.speech_pad_ms + threshold=vad_params.threshold, + min_speech_duration_ms=vad_params.min_speech_duration_ms, + max_speech_duration_s=vad_params.max_speech_duration_s, + min_silence_duration_ms=vad_params.min_silence_duration_ms, + speech_pad_ms=vad_params.speech_pad_ms ) audio, speech_chunks = self.vad.run( @@ -157,20 +158,21 @@ def run(self, result, elapsed_time = self.transcribe( audio, progress, - *astuple(params) + *whisper_params.to_list() ) - if params.vad_filter: + if vad_params.vad_filter: result = self.vad.restore_speech_timestamps( segments=result, - speech_chunks=speech_chunks, + speech_chunks=vad_params.speech_chunks, ) - if params.is_diarize: + if diarization_params.is_diarize: result, elapsed_time_diarization = self.diarizer.run( audio=audio, - use_auth_token=params.hf_token, + use_auth_token=diarization_params.hf_token, 
transcribed_result=result, + device=diarization_params.device ) elapsed_time += elapsed_time_diarization return result, elapsed_time @@ -181,7 +183,7 @@ def transcribe_file(self, file_format: str = "SRT", add_timestamp: bool = True, progress=gr.Progress(), - *whisper_params, + *params, ) -> list: """ Write subtitle file from Files @@ -199,8 +201,8 @@ def transcribe_file(self, Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename. progress: gr.Progress Indicator to show progress directly in gradio. - *whisper_params: tuple - Parameters related with whisper. This will be dealt with "WhisperParameters" data class + *params: tuple + Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class Returns ---------- @@ -223,7 +225,7 @@ def transcribe_file(self, file, progress, add_timestamp, - *whisper_params, + *params, ) file_name, file_ext = os.path.splitext(os.path.basename(file)) @@ -514,13 +516,14 @@ def remove_input_files(file_paths: List[str]): @staticmethod def cache_parameters( - whisper_params: TranscriptionPipelineParams, + params: TranscriptionPipelineParams, add_timestamp: bool ): """cache parameters to the yaml file""" cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) - cached_whisper_param = whisper_params.to_yaml() - cached_yaml = {**cached_params, **cached_whisper_param} + param_to_cache = params.to_dict() + + cached_yaml = {**cached_params, **param_to_cache} cached_yaml["whisper"]["add_timestamp"] = add_timestamp save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH) diff --git a/tests/test_transcription.py b/tests/test_transcription.py index dc85c2df..f9dc3cc2 100644 --- a/tests/test_transcription.py +++ b/tests/test_transcription.py @@ -1,5 +1,5 @@ from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.data_classes import TranscriptionPipelineParams +from modules.whisper.data_classes import * from modules.utils.paths 
import WEBUI_DIR from test_config import * @@ -38,13 +38,21 @@ def test_transcribe( ) hparams = TranscriptionPipelineParams( - model_size=TEST_WHISPER_MODEL, - vad_filter=vad_filter, - is_bgm_separate=bgm_separation, - compute_type=whisper_inferencer.current_compute_type, - uvr_enable_offload=True, - is_diarize=diarization, - ).as_list() + whisper=WhisperParams( + model_size=TEST_WHISPER_MODEL, + compute_type=whisper_inferencer.current_compute_type + ), + vad=VadParams( + vad_filter=vad_filter + ), + bgm_separation=BGMSeparationParams( + is_separate_bgm=bgm_separation, + enable_offload=True + ), + diarization=DiarizationParams( + is_diarize=diarization + ), + ).to_list() subtitle_str, file_path = whisper_inferencer.transcribe_file( [audio_path], From 393a9c32f3366aa8a79e3065f89ac3218f709033 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 17:04:35 +0900 Subject: [PATCH 12/41] Fix class method attribute access --- modules/whisper/data_classes.py | 138 ++++++++++++++++++-------------- 1 file changed, 78 insertions(+), 60 deletions(-) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index 3e94f8a7..67c37939 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -62,25 +62,37 @@ class VadParams(BaseParams): @classmethod def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]: return [ - gr.Checkbox(label=_("Enable Silero VAD Filter"), value=defaults.get("vad_filter", cls.vad_filter), - interactive=True, - info=_("Enable this to transcribe only detected voice")), - gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", - value=defaults.get("threshold", cls.threshold), - info="Lower it to be more sensitive to small sounds."), - gr.Number(label="Minimum Speech Duration (ms)", precision=0, - value=defaults.get("min_speech_duration_ms", cls.min_speech_duration_ms), - info="Final speech chunks 
shorter than this time are thrown out"), - gr.Number(label="Maximum Speech Duration (s)", - value=defaults.get("max_speech_duration_s", cls.max_speech_duration_s), - info="Maximum duration of speech chunks in \"seconds\"."), - gr.Number(label="Minimum Silence Duration (ms)", precision=0, - value=defaults.get("min_silence_duration_ms", cls.min_silence_duration_ms), - info="In the end of each speech chunk wait for this time" - " before separating it"), - gr.Number(label="Speech Padding (ms)", precision=0, - value=defaults.get("speech_pad_ms", cls.speech_pad_ms), - info="Final speech chunks are padded by this time each side") + gr.Checkbox( + label=_("Enable Silero VAD Filter"), + value=defaults.get("vad_filter", cls.__fields__["vad_filter"].default), + interactive=True, + info=_("Enable this to transcribe only detected voice") + ), + gr.Slider( + minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", + value=defaults.get("threshold", cls.__fields__["threshold"].default), + info="Lower it to be more sensitive to small sounds." + ), + gr.Number( + label="Minimum Speech Duration (ms)", precision=0, + value=defaults.get("min_speech_duration_ms", cls.__fields__["min_speech_duration_ms"].default), + info="Final speech chunks shorter than this time are thrown out" + ), + gr.Number( + label="Maximum Speech Duration (s)", + value=defaults.get("max_speech_duration_s", cls.__fields__["max_speech_duration_s"].default), + info="Maximum duration of speech chunks in \"seconds\"." 
+ ), + gr.Number( + label="Minimum Silence Duration (ms)", precision=0, + value=defaults.get("min_silence_duration_ms", cls.__fields__["min_silence_duration_ms"].default), + info="In the end of each speech chunk wait for this time before separating it" + ), + gr.Number( + label="Speech Padding (ms)", precision=0, + value=defaults.get("speech_pad_ms", cls.__fields__["speech_pad_ms"].default), + info="Final speech chunks are padded by this time each side" + ) ] @@ -100,18 +112,18 @@ def to_gradio_inputs(cls, return [ gr.Checkbox( label=_("Enable Diarization"), - value=defaults.get("is_diarize", cls.is_diarize), + value=defaults.get("is_diarize", cls.__fields__["is_diarize"].default), info=_("Enable speaker diarization") ), gr.Textbox( label=_("HuggingFace Token"), - value=defaults.get("hf_token", cls.hf_token), + value=defaults.get("hf_token", cls.__fields__["hf_token"].default), info=_("This is only needed the first time you download the model") ), gr.Dropdown( label=_("Device"), choices=["cpu", "cuda"] if available_devices is None else available_devices, - value=defaults.get("device", cls.device), + value=defaults.get("device", cls.__fields__["device"].default), info=_("Device to run diarization model") ) ] @@ -147,36 +159,37 @@ def to_gradio_input(cls, return [ gr.Checkbox( label=_("Enable Background Music Remover Filter"), - value=defaults.get("is_separate_bgm", cls.is_separate_bgm), + value=defaults.get("is_separate_bgm", cls.__fields__["is_separate_bgm"].default), interactive=True, info=_("Enabling this will remove background music") ), gr.Dropdown( label=_("Device"), choices=["cpu", "cuda"] if available_devices is None else available_devices, - value=defaults.get("device", cls.device), + value=defaults.get("device", cls.__fields__["device"].default), info=_("Device to run UVR model") ), gr.Dropdown( label=_("Model"), - choices=["UVR-MDX-NET-Inst_HQ_4", "UVR-MDX-NET-Inst_3"] if available_models is None else available_models, - value=defaults.get("model_size", 
cls.model_size), + choices=["UVR-MDX-NET-Inst_HQ_4", + "UVR-MDX-NET-Inst_3"] if available_models is None else available_models, + value=defaults.get("model_size", cls.__fields__["model_size"].default), info=_("UVR model size") ), gr.Number( label="Segment Size", - value=defaults.get("segment_size", cls.segment_size), + value=defaults.get("segment_size", cls.__fields__["segment_size"].default), precision=0, info="Segment size for UVR model" ), gr.Checkbox( label=_("Save separated files to output"), - value=defaults.get("save_file", cls.save_file), + value=defaults.get("save_file", cls.__fields__["save_file"].default), info=_("Whether to save separated audio files") ), gr.Checkbox( label=_("Offload sub model after removing background music"), - value=defaults.get("enable_offload", cls.enable_offload), + value=defaults.get("enable_offload", cls.__fields__["enable_offload"].default), info=_("Offload UVR model after transcription") ) ] @@ -283,17 +296,17 @@ def to_gradio_inputs(cls, gr.Dropdown( label="Model Size", choices=["small", "medium", "large-v2"], - value=defaults.get("model_size", cls.model_size), + value=defaults.get("model_size", cls.__fields__["model_size"].default), info="Whisper model size" ), gr.Textbox( label="Language", - value=defaults.get("lang", cls.lang), + value=defaults.get("lang", cls.__fields__["lang"].default), info="Source language of the file to transcribe" ), gr.Checkbox( label="Translate to English", - value=defaults.get("is_translate", cls.is_translate), + value=defaults.get("is_translate", cls.__fields__["is_translate"].default), info="Translate speech to English end-to-end" ), ] @@ -301,18 +314,18 @@ def to_gradio_inputs(cls, inputs += [ gr.Number( label="Beam Size", - value=defaults.get("beam_size", cls.beam_size), + value=defaults.get("beam_size", cls.__fields__["beam_size"].default), precision=0, info="Beam size for decoding" ), gr.Number( label="Log Probability Threshold", - value=defaults.get("log_prob_threshold", 
cls.log_prob_threshold), + value=defaults.get("log_prob_threshold", cls.__fields__["log_prob_threshold"].default), info="Threshold for average log probability of sampled tokens" ), gr.Number( label="No Speech Threshold", - value=defaults.get("no_speech_threshold", cls.no_speech_threshold), + value=defaults.get("no_speech_threshold", cls.__fields__["no_speech_threshold"].default), info="Threshold for detecting silence" ), gr.Dropdown( @@ -323,23 +336,24 @@ def to_gradio_inputs(cls, ), gr.Number( label="Best Of", - value=defaults.get("best_of", cls.best_of), + value=defaults.get("best_of", cls.__fields__["best_of"].default), precision=0, info="Number of candidates when sampling" ), gr.Number( label="Patience", - value=defaults.get("patience", cls.patience), + value=defaults.get("patience", cls.__fields__["patience"].default), info="Beam search patience factor" ), gr.Checkbox( label="Condition On Previous Text", - value=defaults.get("condition_on_previous_text", cls.condition_on_previous_text), + value=defaults.get("condition_on_previous_text", cls.__fields__["condition_on_previous_text"].default), info="Use previous output as prompt for next window" ), gr.Slider( label="Prompt Reset On Temperature", - value=defaults.get("prompt_reset_on_temperature", cls.prompt_reset_on_temperature), + value=defaults.get("prompt_reset_on_temperature", + cls.__fields__["prompt_reset_on_temperature"].default), minimum=0, maximum=1, step=0.01, @@ -347,12 +361,12 @@ def to_gradio_inputs(cls, ), gr.Textbox( label="Initial Prompt", - value=defaults.get("initial_prompt", cls.initial_prompt), + value=defaults.get("initial_prompt", cls.__fields__["initial_prompt"].default), info="Initial prompt for first window" ), gr.Slider( label="Temperature", - value=defaults.get("temperature", cls.temperature), + value=defaults.get("temperature", cls.__fields__["temperature"].default), minimum=0.0, step=0.01, maximum=1.0, @@ -360,7 +374,8 @@ def to_gradio_inputs(cls, ), gr.Number( label="Compression 
Ratio Threshold", - value=defaults.get("compression_ratio_threshold", cls.compression_ratio_threshold), + value=defaults.get("compression_ratio_threshold", + cls.__fields__["compression_ratio_threshold"].default), info="Threshold for gzip compression ratio" ) ] @@ -368,86 +383,89 @@ def to_gradio_inputs(cls, inputs += [ gr.Number( label="Length Penalty", - value=defaults.get("length_penalty", cls.length_penalty), + value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default), info="Exponential length penalty", - visible=whisper_type=="faster_whisper" + visible=whisper_type == "faster_whisper" ), gr.Number( label="Repetition Penalty", - value=defaults.get("repetition_penalty", cls.repetition_penalty), + value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default), info="Penalty for repeated tokens" ), gr.Number( label="No Repeat N-gram Size", - value=defaults.get("no_repeat_ngram_size", cls.no_repeat_ngram_size), + value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default), precision=0, info="Size of n-grams to prevent repetition" ), gr.Textbox( label="Prefix", - value=defaults.get("prefix", cls.prefix), + value=defaults.get("prefix", cls.__fields__["prefix"].default), info="Prefix text for first window" ), gr.Checkbox( label="Suppress Blank", - value=defaults.get("suppress_blank", cls.suppress_blank), + value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default), info="Suppress blank outputs at start of sampling" ), gr.Textbox( label="Suppress Tokens", - value=defaults.get("suppress_tokens", cls.suppress_tokens), + value=defaults.get("suppress_tokens", cls.__fields__["suppress_tokens"].default), info="Token IDs to suppress" ), gr.Number( label="Max Initial Timestamp", - value=defaults.get("max_initial_timestamp", cls.max_initial_timestamp), + value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default), info="Maximum initial timestamp" ), 
gr.Checkbox( label="Word Timestamps", - value=defaults.get("word_timestamps", cls.word_timestamps), + value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default), info="Extract word-level timestamps" ), gr.Textbox( label="Prepend Punctuations", - value=defaults.get("prepend_punctuations", cls.prepend_punctuations), + value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default), info="Punctuations to merge with next word" ), gr.Textbox( label="Append Punctuations", - value=defaults.get("append_punctuations", cls.append_punctuations), + value=defaults.get("append_punctuations", cls.__fields__["append_punctuations"].default), info="Punctuations to merge with previous word" ), gr.Number( label="Max New Tokens", - value=defaults.get("max_new_tokens", cls.max_new_tokens), + value=defaults.get("max_new_tokens", cls.__fields__["max_new_tokens"].default), precision=0, info="Maximum number of new tokens per chunk" ), gr.Number( label="Chunk Length (s)", - value=defaults.get("chunk_length", cls.chunk_length), + value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default), precision=0, info="Length of audio segments in seconds" ), gr.Number( label="Hallucination Silence Threshold (sec)", - value=defaults.get("hallucination_silence_threshold", cls.hallucination_silence_threshold), + value=defaults.get("hallucination_silence_threshold", + cls.__fields__["hallucination_silence_threshold"].default), info="Threshold for skipping silent periods in hallucination detection" ), gr.Textbox( label="Hotwords", - value=defaults.get("hotwords", cls.hotwords), + value=defaults.get("hotwords", cls.__fields__["hotwords"].default), info="Hotwords/hint phrases for the model" ), gr.Number( label="Language Detection Threshold", - value=defaults.get("language_detection_threshold", cls.language_detection_threshold), + value=defaults.get("language_detection_threshold", + cls.__fields__["language_detection_threshold"].default), 
info="Threshold for language detection probability" ), gr.Number( label="Language Detection Segments", - value=defaults.get("language_detection_segments", cls.language_detection_segments), + value=defaults.get("language_detection_segments", + cls.__fields__["language_detection_segments"].default), precision=0, info="Number of segments for language detection" ) @@ -457,7 +475,7 @@ def to_gradio_inputs(cls, inputs += [ gr.Number( label="Batch Size", - value=defaults.get("batch_size", cls.batch_size), + value=defaults.get("batch_size", cls.__fields__["batch_size"].default), precision=0, info="Batch size for processing", visible=whisper_type == "insanely_fast_whisper" From e667af97dd159bc79d154af03bf96c6139c9376f Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 17:10:45 +0900 Subject: [PATCH 13/41] Update comment --- modules/whisper/whisper_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/whisper/whisper_base.py b/modules/whisper/whisper_base.py index 08ed0453..466d327d 100644 --- a/modules/whisper/whisper_base.py +++ b/modules/whisper/whisper_base.py @@ -80,6 +80,7 @@ def run(self, Run transcription with conditional pre-processing and post-processing. The VAD will be performed to remove noise from the audio input in pre-processing, if enabled. The diarization will be performed in post-processing, if enabled. + Due to the integration with gradio, the parameters have to be specified with a `*` wildcard. 
Parameters ---------- From 806824baa5c8e0c922c929c9fc26a9c4c68558a2 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 17:11:02 +0900 Subject: [PATCH 14/41] Refactor to gradio functions --- app.py | 155 ++++----------------------------------------------------- 1 file changed, 10 insertions(+), 145 deletions(-) diff --git a/app.py b/app.py index 90f2e4ba..72026de8 100644 --- a/app.py +++ b/app.py @@ -66,158 +66,23 @@ def create_whisper_parameters(self): interactive=True) with gr.Accordion(_("Advanced Parameters"), open=False): - nb_beam_size = gr.Number(label="Beam Size", value=whisper_params["beam_size"], precision=0, - interactive=True, - info="Beam size to use for decoding.") - nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", - value=whisper_params["log_prob_threshold"], interactive=True, - info="If the average log probability over sampled tokens is below this value, treat as failed.") - nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=whisper_params["no_speech_threshold"], - interactive=True, - info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.") - dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, - value=self.whisper_inf.current_compute_type, interactive=True, - allow_custom_value=True, - info="Select the type of computation to perform.") - nb_best_of = gr.Number(label="Best Of", value=whisper_params["best_of"], interactive=True, - info="Number of candidates when sampling with non-zero temperature.") - nb_patience = gr.Number(label="Patience", value=whisper_params["patience"], interactive=True, - info="Beam search patience factor.") - cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", - value=whisper_params["condition_on_previous_text"], - interactive=True, - 
info="Condition on previous text during decoding.") - sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature", - value=whisper_params["prompt_reset_on_temperature"], - minimum=0, maximum=1, step=0.01, interactive=True, - info="Resets prompt if temperature is above this value." - " Arg has effect only if 'Condition On Previous Text' is True.") - tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True, - info="Initial prompt to use for decoding.") - sd_temperature = gr.Slider(label="Temperature", value=whisper_params["temperature"], minimum=0.0, - step=0.01, maximum=1.0, interactive=True, - info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.") - nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", - value=whisper_params["compression_ratio_threshold"], - interactive=True, - info="If the gzip compression ratio is above this value, treat as failed.") - nb_chunk_length = gr.Number(label="Chunk Length (s)", value=lambda: whisper_params["chunk_length"], - precision=0, - info="The length of audio segments. 
If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.") - with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)): - nb_length_penalty = gr.Number(label="Length Penalty", value=whisper_params["length_penalty"], - info="Exponential length penalty constant.") - nb_repetition_penalty = gr.Number(label="Repetition Penalty", - value=whisper_params["repetition_penalty"], - info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).") - nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size", - value=whisper_params["no_repeat_ngram_size"], - precision=0, - info="Prevent repetitions of n-grams with this size (set 0 to disable).") - tb_prefix = gr.Textbox(label="Prefix", value=lambda: whisper_params["prefix"], - info="Optional text to provide as a prefix for the first window.") - cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=whisper_params["suppress_blank"], - info="Suppress blank outputs at the beginning of the sampling.") - tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value=whisper_params["suppress_tokens"], - info="List of token IDs to suppress. 
-1 will suppress a default set of symbols as defined in the model config.json file.") - nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp", - value=whisper_params["max_initial_timestamp"], - info="The initial timestamp cannot be later than this.") - cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=whisper_params["word_timestamps"], - info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.") - tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations", - value=whisper_params["prepend_punctuations"], - info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.") - tb_append_punctuations = gr.Textbox(label="Append Punctuations", - value=whisper_params["append_punctuations"], - info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.") - nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: whisper_params["max_new_tokens"], - precision=0, - info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.") - nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)", - value=lambda: whisper_params[ - "hallucination_silence_threshold"], - info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.") - tb_hotwords = gr.Textbox(label="Hotwords", value=lambda: whisper_params["hotwords"], - info="Hotwords/hint phrases to provide the model with. 
Has no effect if prefix is not None.") - nb_language_detection_threshold = gr.Number(label="Language Detection Threshold", - value=lambda: whisper_params[ - "language_detection_threshold"], - info="If the maximum probability of the language tokens is higher than this value, the language is detected.") - nb_language_detection_segments = gr.Number(label="Language Detection Segments", - value=lambda: whisper_params["language_detection_segments"], - precision=0, - info="Number of segments to consider for the language detection.") - with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)): - nb_batch_size = gr.Number(label="Batch Size", value=whisper_params["batch_size"], precision=0) + whisper_inputs = WhisperParams.to_gradio_inputs(defaults=whisper_params, only_advanced=True) with gr.Accordion(_("Background Music Remover Filter"), open=False): - cb_bgm_separation = gr.Checkbox(label=_("Enable Background Music Remover Filter"), - value=uvr_params["is_separate_bgm"], - interactive=True, - info=_("Enabling this will remove background music")) - dd_uvr_device = gr.Dropdown(label=_("Device"), value=self.whisper_inf.music_separator.device, - choices=self.whisper_inf.music_separator.available_devices) - dd_uvr_model_size = gr.Dropdown(label=_("Model"), value=uvr_params["model_size"], - choices=self.whisper_inf.music_separator.available_models) - nb_uvr_segment_size = gr.Number(label="Segment Size", value=uvr_params["segment_size"], precision=0) - cb_uvr_save_file = gr.Checkbox(label=_("Save separated files to output"), value=uvr_params["save_file"]) - cb_uvr_enable_offload = gr.Checkbox(label=_("Offload sub model after removing background music"), - value=uvr_params["enable_offload"]) + uvr_inputs = BGMSeparationParams.to_gradio_input(defaults=uvr_params) with gr.Accordion(_("Voice Detection Filter"), open=False): - cb_vad_filter = gr.Checkbox(label=_("Enable Silero VAD Filter"), value=vad_params["vad_filter"], - interactive=True, - info=_("Enable 
this to transcribe only detected voice")) - sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", - value=vad_params["threshold"], - info="Lower it to be more sensitive to small sounds.") - nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, - value=vad_params["min_speech_duration_ms"], - info="Final speech chunks shorter than this time are thrown out") - nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", - value=vad_params["max_speech_duration_s"], - info="Maximum duration of speech chunks in \"seconds\".") - nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, - value=vad_params["min_silence_duration_ms"], - info="In the end of each speech chunk wait for this time" - " before separating it") - nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=vad_params["speech_pad_ms"], - info="Final speech chunks are padded by this time each side") + vad_inputs = VadParams.to_gradio_inputs(defaults=vad_params) with gr.Accordion(_("Diarization"), open=False): - cb_diarize = gr.Checkbox(label=_("Enable Diarization"), value=diarization_params["is_diarize"]) - tb_hf_token = gr.Text(label=_("HuggingFace Token"), value=diarization_params["hf_token"], - info=_("This is only needed the first time you download the model")) - dd_diarization_device = gr.Dropdown(label=_("Device"), - choices=self.whisper_inf.diarizer.get_available_device(), - value=self.whisper_inf.diarizer.get_device()) + diarization_inputs = DiarizationParams.to_gradio_inputs(defaults=diarization_params) dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate]) + inputs = [dd_model, dd_lang, cb_translate] + whisper_inputs + vad_inputs + diarization_inputs + uvr_inputs + return ( - TranscriptionPipelineGradioComponents( - model_size=dd_model, lang=dd_lang, is_translate=cb_translate, beam_size=nb_beam_size, - 
log_prob_threshold=nb_log_prob_threshold, no_speech_threshold=nb_no_speech_threshold, - compute_type=dd_compute_type, best_of=nb_best_of, patience=nb_patience, - condition_on_previous_text=cb_condition_on_previous_text, initial_prompt=tb_initial_prompt, - temperature=sd_temperature, compression_ratio_threshold=nb_compression_ratio_threshold, - vad_filter=cb_vad_filter, threshold=sd_threshold, min_speech_duration_ms=nb_min_speech_duration_ms, - max_speech_duration_s=nb_max_speech_duration_s, min_silence_duration_ms=nb_min_silence_duration_ms, - speech_pad_ms=nb_speech_pad_ms, chunk_length=nb_chunk_length, batch_size=nb_batch_size, - is_diarize=cb_diarize, hf_token=tb_hf_token, diarization_device=dd_diarization_device, - length_penalty=nb_length_penalty, repetition_penalty=nb_repetition_penalty, - no_repeat_ngram_size=nb_no_repeat_ngram_size, prefix=tb_prefix, suppress_blank=cb_suppress_blank, - suppress_tokens=tb_suppress_tokens, max_initial_timestamp=nb_max_initial_timestamp, - word_timestamps=cb_word_timestamps, prepend_punctuations=tb_prepend_punctuations, - append_punctuations=tb_append_punctuations, max_new_tokens=nb_max_new_tokens, - hallucination_silence_threshold=nb_hallucination_silence_threshold, hotwords=tb_hotwords, - language_detection_threshold=nb_language_detection_threshold, - language_detection_segments=nb_language_detection_segments, - prompt_reset_on_temperature=sld_prompt_reset_on_temperature, is_bgm_separate=cb_bgm_separation, - uvr_device=dd_uvr_device, uvr_model_size=dd_uvr_model_size, uvr_segment_size=nb_uvr_segment_size, - uvr_save_file=cb_uvr_save_file, uvr_enable_offload=cb_uvr_enable_offload - ), + inputs, dd_file_format, cb_timestamp ) @@ -254,7 +119,7 @@ def launch(self): params = [input_file, tb_input_folder, dd_file_format, cb_timestamp] btn_run.click(fn=self.whisper_inf.transcribe_file, - inputs=params + whisper_params.as_list(), + inputs=params + whisper_params, outputs=[tb_indicator, files_subtitles]) 
btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None) @@ -280,7 +145,7 @@ def launch(self): params = [tb_youtubelink, dd_file_format, cb_timestamp] btn_run.click(fn=self.whisper_inf.transcribe_youtube, - inputs=params + whisper_params.as_list(), + inputs=params + whisper_params, outputs=[tb_indicator, files_subtitles]) tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink], outputs=[img_thumbnail, tb_title, tb_description]) @@ -302,7 +167,7 @@ def launch(self): params = [mic_input, dd_file_format, cb_timestamp] btn_run.click(fn=self.whisper_inf.transcribe_mic, - inputs=params + whisper_params.as_list(), + inputs=params + whisper_params, outputs=[tb_indicator, files_subtitles]) btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None) From 4df5628dad8836896715ee542853f6ff2c033d74 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 19:02:41 +0900 Subject: [PATCH 15/41] Rename function and variable --- app.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/app.py b/app.py index 72026de8..ac5010c1 100644 --- a/app.py +++ b/app.py @@ -44,7 +44,7 @@ def __init__(self, args): print(f"Use \"{self.args.whisper_type}\" implementation\n" f"Device \"{self.whisper_inf.device}\" is detected") - def create_whisper_parameters(self): + def create_pipeline_inputs(self): whisper_params = self.default_params["whisper"] vad_params = self.default_params["vad"] diarization_params = self.default_params["diarization"] @@ -108,7 +108,7 @@ def launch(self): visible=self.args.colab, value="") - whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters() + pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs() with gr.Row(): btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary") @@ -119,7 +119,7 @@ def launch(self): params = [input_file, tb_input_folder, dd_file_format, cb_timestamp] 
btn_run.click(fn=self.whisper_inf.transcribe_file, - inputs=params + whisper_params, + inputs=params + pipeline_params, outputs=[tb_indicator, files_subtitles]) btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None) @@ -133,7 +133,7 @@ def launch(self): tb_title = gr.Label(label=_("Youtube Title")) tb_description = gr.Textbox(label=_("Youtube Description"), max_lines=15) - whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters() + pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs() with gr.Row(): btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary") @@ -145,7 +145,7 @@ def launch(self): params = [tb_youtubelink, dd_file_format, cb_timestamp] btn_run.click(fn=self.whisper_inf.transcribe_youtube, - inputs=params + whisper_params, + inputs=params + pipeline_params, outputs=[tb_indicator, files_subtitles]) tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink], outputs=[img_thumbnail, tb_title, tb_description]) @@ -155,7 +155,7 @@ def launch(self): with gr.Row(): mic_input = gr.Microphone(label=_("Record with Mic"), type="filepath", interactive=True) - whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters() + pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs() with gr.Row(): btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary") @@ -167,7 +167,7 @@ def launch(self): params = [mic_input, dd_file_format, cb_timestamp] btn_run.click(fn=self.whisper_inf.transcribe_mic, - inputs=params + whisper_params, + inputs=params + pipeline_params, outputs=[tb_indicator, files_subtitles]) btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None) From e1367680451ba3babaa087a0cc04ec9f6d17dc85 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 19:07:14 +0900 Subject: [PATCH 16/41] Remove meaningless info --- modules/whisper/data_classes.py | 1 
- 1 file changed, 1 deletion(-) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index 67c37939..8572e805 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -113,7 +113,6 @@ def to_gradio_inputs(cls, gr.Checkbox( label=_("Enable Diarization"), value=defaults.get("is_diarize", cls.__fields__["is_diarize"].default), - info=_("Enable speaker diarization") ), gr.Textbox( label=_("HuggingFace Token"), From 9ea886232c87be7bd90f158191d10304814e79df Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 19:10:21 +0900 Subject: [PATCH 17/41] Remove meaningless info --- modules/whisper/data_classes.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index 8572e805..52431b5a 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -123,7 +123,6 @@ def to_gradio_inputs(cls, label=_("Device"), choices=["cpu", "cuda"] if available_devices is None else available_devices, value=defaults.get("device", cls.__fields__["device"].default), - info=_("Device to run diarization model") ) ] @@ -166,14 +165,12 @@ def to_gradio_input(cls, label=_("Device"), choices=["cpu", "cuda"] if available_devices is None else available_devices, value=defaults.get("device", cls.__fields__["device"].default), - info=_("Device to run UVR model") ), gr.Dropdown( label=_("Model"), choices=["UVR-MDX-NET-Inst_HQ_4", "UVR-MDX-NET-Inst_3"] if available_models is None else available_models, value=defaults.get("model_size", cls.__fields__["model_size"].default), - info=_("UVR model size") ), gr.Number( label="Segment Size", @@ -184,12 +181,10 @@ def to_gradio_input(cls, gr.Checkbox( label=_("Save separated files to output"), value=defaults.get("save_file", cls.__fields__["save_file"].default), - info=_("Whether to save separated audio files") ), gr.Checkbox( label=_("Offload sub model after removing 
background music"), value=defaults.get("enable_offload", cls.__fields__["enable_offload"].default), - info=_("Offload UVR model after transcription") ) ] From c93f2d37c63ec8238dc7cadb53b1ea3669ff780a Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 19:31:38 +0900 Subject: [PATCH 18/41] Receive device as param --- modules/whisper/data_classes.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index 52431b5a..042c9911 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -108,7 +108,8 @@ class DiarizationParams(BaseParams): @classmethod def to_gradio_inputs(cls, defaults: Optional[Dict] = None, - available_devices: Optional[List] = None) -> List[gr.components.base.FormComponent]: + available_devices: Optional[List] = None, + device: Optional[str] = None) -> List[gr.components.base.FormComponent]: return [ gr.Checkbox( label=_("Enable Diarization"), @@ -122,7 +123,7 @@ def to_gradio_inputs(cls, gr.Dropdown( label=_("Device"), choices=["cpu", "cuda"] if available_devices is None else available_devices, - value=defaults.get("device", cls.__fields__["device"].default), + value=defaults.get("device", device), ) ] @@ -153,6 +154,7 @@ class BGMSeparationParams(BaseParams): def to_gradio_input(cls, defaults: Optional[Dict] = None, available_devices: Optional[List] = None, + device: Optional[str] = None, available_models: Optional[List] = None) -> List[gr.components.base.FormComponent]: return [ gr.Checkbox( @@ -164,7 +166,7 @@ def to_gradio_input(cls, gr.Dropdown( label=_("Device"), choices=["cpu", "cuda"] if available_devices is None else available_devices, - value=defaults.get("device", cls.__fields__["device"].default), + value=defaults.get("device", device), ), gr.Dropdown( label=_("Model"), From e9e1347e5afd80518ef6f6bc61161ce254b2a571 Mon Sep 17 00:00:00 2001 From: jhj0517 
<97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 19:31:45 +0900 Subject: [PATCH 19/41] Use enum --- modules/whisper/whisper_factory.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/modules/whisper/whisper_factory.py b/modules/whisper/whisper_factory.py index 6bda8c58..4abc8f99 100644 --- a/modules/whisper/whisper_factory.py +++ b/modules/whisper/whisper_factory.py @@ -7,6 +7,7 @@ from modules.whisper.whisper_Inference import WhisperInference from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference from modules.whisper.whisper_base import WhisperBase +from modules.whisper.data_classes import * class WhisperFactory: @@ -51,30 +52,21 @@ def create_whisper_inference( # Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144 os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' - whisper_type = whisper_type.lower().strip() - - faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"] - whisper_typos = ["whisper"] - insanely_fast_whisper_typos = [ - "insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper", - "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper" - ] - - if whisper_type in faster_whisper_typos: + if whisper_type == WhisperImpl.FASTER_WHISPER: return FasterWhisperInference( model_dir=faster_whisper_model_dir, output_dir=output_dir, diarization_model_dir=diarization_model_dir, uvr_model_dir=uvr_model_dir ) - elif whisper_type in whisper_typos: + elif whisper_type in WhisperImpl.WHISPER: return WhisperInference( model_dir=whisper_model_dir, output_dir=output_dir, diarization_model_dir=diarization_model_dir, uvr_model_dir=uvr_model_dir ) - elif whisper_type in insanely_fast_whisper_typos: + elif whisper_type in WhisperImpl.INSANELY_FAST_WHISPER: return InsanelyFastWhisperInference( model_dir=insanely_fast_whisper_model_dir, output_dir=output_dir, From 0da25b667a3c20a509a523231f56d904ea3e8493 Mon Sep 17 
00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 19:31:59 +0900 Subject: [PATCH 20/41] Pass device as param --- app.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/app.py b/app.py index ac5010c1..b1dd9d03 100644 --- a/app.py +++ b/app.py @@ -66,16 +66,24 @@ def create_pipeline_inputs(self): interactive=True) with gr.Accordion(_("Advanced Parameters"), open=False): - whisper_inputs = WhisperParams.to_gradio_inputs(defaults=whisper_params, only_advanced=True) + whisper_inputs = WhisperParams.to_gradio_inputs(defaults=whisper_params, only_advanced=True, + whisper_type=self.args.whisper_type, + available_compute_types=self.whisper_inf.available_compute_types, + compute_type=self.whisper_inf.current_compute_type) with gr.Accordion(_("Background Music Remover Filter"), open=False): - uvr_inputs = BGMSeparationParams.to_gradio_input(defaults=uvr_params) + uvr_inputs = BGMSeparationParams.to_gradio_input(defaults=uvr_params, + available_models=self.whisper_inf.music_separator.available_models, + available_devices=self.whisper_inf.music_separator.available_devices, + device=self.whisper_inf.music_separator.device) with gr.Accordion(_("Voice Detection Filter"), open=False): vad_inputs = VadParams.to_gradio_inputs(defaults=vad_params) with gr.Accordion(_("Diarization"), open=False): - diarization_inputs = DiarizationParams.to_gradio_inputs(defaults=diarization_params) + diarization_inputs = DiarizationParams.to_gradio_inputs(defaults=diarization_params, + available_devices=self.whisper_inf.diarizer.available_device, + device=self.whisper_inf.diarizer.device) dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate]) @@ -312,8 +320,8 @@ def on_change_models(model_size: str): parser = argparse.ArgumentParser() -parser.add_argument('--whisper_type', type=str, default="faster-whisper", - choices=["whisper", "faster-whisper", "insanely-fast-whisper"], 
+parser.add_argument('--whisper_type', type=str, default=WhisperImpl.FASTER_WHISPER, + choices=[item.value for item in WhisperImpl], help='A type of the whisper implementation (Github repo name)') parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio share value') parser.add_argument('--server_name', type=str, default=None, help='Gradio server host') From 21bbf6db0309489d381b4f6de73473bb7fa9f250 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 21:24:37 +0900 Subject: [PATCH 21/41] Update visibility by whisper implementation --- modules/whisper/data_classes.py | 212 +++++++++++++++++--------------- 1 file changed, 111 insertions(+), 101 deletions(-) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index 042c9911..7748bffb 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -1,7 +1,7 @@ import gradio as gr import torch from typing import Optional, Dict, List -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, ConfigDict from gradio_i18n import Translate, gettext as _ from enum import Enum from copy import deepcopy @@ -17,6 +17,8 @@ class WhisperImpl(Enum): class BaseParams(BaseModel): + model_config = ConfigDict(protected_namespaces=()) + def to_dict(self) -> Dict: return self.model_dump() @@ -231,7 +233,6 @@ class WhisperParams(BaseParams): gt=0, description="Threshold for gzip compression ratio" ) - batch_size: int = Field(default=24, gt=0, description="Batch size for processing") length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty") repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens") no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition") @@ -271,6 +272,7 @@ class WhisperParams(BaseParams): gt=0, description="Number 
of segments for language detection" ) + batch_size: int = Field(default=24, gt=0, description="Batch size for processing") @field_validator('lang') def validate_lang(cls, v): @@ -375,108 +377,116 @@ def to_gradio_inputs(cls, info="Threshold for gzip compression ratio" ) ] + + faster_whisper_inputs = [ + gr.Number( + label="Length Penalty", + value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default), + info="Exponential length penalty", + ), + gr.Number( + label="Repetition Penalty", + value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default), + info="Penalty for repeated tokens" + ), + gr.Number( + label="No Repeat N-gram Size", + value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default), + precision=0, + info="Size of n-grams to prevent repetition" + ), + gr.Textbox( + label="Prefix", + value=defaults.get("prefix", cls.__fields__["prefix"].default), + info="Prefix text for first window" + ), + gr.Checkbox( + label="Suppress Blank", + value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default), + info="Suppress blank outputs at start of sampling" + ), + gr.Textbox( + label="Suppress Tokens", + value=defaults.get("suppress_tokens", cls.__fields__["suppress_tokens"].default), + info="Token IDs to suppress" + ), + gr.Number( + label="Max Initial Timestamp", + value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default), + info="Maximum initial timestamp" + ), + gr.Checkbox( + label="Word Timestamps", + value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default), + info="Extract word-level timestamps" + ), + gr.Textbox( + label="Prepend Punctuations", + value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default), + info="Punctuations to merge with next word" + ), + gr.Textbox( + label="Append Punctuations", + value=defaults.get("append_punctuations", 
cls.__fields__["append_punctuations"].default), + info="Punctuations to merge with previous word" + ), + gr.Number( + label="Max New Tokens", + value=defaults.get("max_new_tokens", cls.__fields__["max_new_tokens"].default), + precision=0, + info="Maximum number of new tokens per chunk" + ), + gr.Number( + label="Chunk Length (s)", + value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default), + precision=0, + info="Length of audio segments in seconds" + ), + gr.Number( + label="Hallucination Silence Threshold (sec)", + value=defaults.get("hallucination_silence_threshold", + cls.__fields__["hallucination_silence_threshold"].default), + info="Threshold for skipping silent periods in hallucination detection" + ), + gr.Textbox( + label="Hotwords", + value=defaults.get("hotwords", cls.__fields__["hotwords"].default), + info="Hotwords/hint phrases for the model" + ), + gr.Number( + label="Language Detection Threshold", + value=defaults.get("language_detection_threshold", + cls.__fields__["language_detection_threshold"].default), + info="Threshold for language detection probability" + ), + gr.Number( + label="Language Detection Segments", + value=defaults.get("language_detection_segments", + cls.__fields__["language_detection_segments"].default), + precision=0, + info="Number of segments for language detection" + ) + ] + + insanely_fast_whisper_inputs = [ + gr.Number( + label="Batch Size", + value=defaults.get("batch_size", cls.__fields__["batch_size"].default), + precision=0, + info="Batch size for processing" + ) + ] + if whisper_type == WhisperImpl.FASTER_WHISPER: - inputs += [ - gr.Number( - label="Length Penalty", - value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default), - info="Exponential length penalty", - visible=whisper_type == "faster_whisper" - ), - gr.Number( - label="Repetition Penalty", - value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default), - info="Penalty for repeated tokens" - ), - 
gr.Number( - label="No Repeat N-gram Size", - value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default), - precision=0, - info="Size of n-grams to prevent repetition" - ), - gr.Textbox( - label="Prefix", - value=defaults.get("prefix", cls.__fields__["prefix"].default), - info="Prefix text for first window" - ), - gr.Checkbox( - label="Suppress Blank", - value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default), - info="Suppress blank outputs at start of sampling" - ), - gr.Textbox( - label="Suppress Tokens", - value=defaults.get("suppress_tokens", cls.__fields__["suppress_tokens"].default), - info="Token IDs to suppress" - ), - gr.Number( - label="Max Initial Timestamp", - value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default), - info="Maximum initial timestamp" - ), - gr.Checkbox( - label="Word Timestamps", - value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default), - info="Extract word-level timestamps" - ), - gr.Textbox( - label="Prepend Punctuations", - value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default), - info="Punctuations to merge with next word" - ), - gr.Textbox( - label="Append Punctuations", - value=defaults.get("append_punctuations", cls.__fields__["append_punctuations"].default), - info="Punctuations to merge with previous word" - ), - gr.Number( - label="Max New Tokens", - value=defaults.get("max_new_tokens", cls.__fields__["max_new_tokens"].default), - precision=0, - info="Maximum number of new tokens per chunk" - ), - gr.Number( - label="Chunk Length (s)", - value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default), - precision=0, - info="Length of audio segments in seconds" - ), - gr.Number( - label="Hallucination Silence Threshold (sec)", - value=defaults.get("hallucination_silence_threshold", - cls.__fields__["hallucination_silence_threshold"].default), - info="Threshold for 
skipping silent periods in hallucination detection" - ), - gr.Textbox( - label="Hotwords", - value=defaults.get("hotwords", cls.__fields__["hotwords"].default), - info="Hotwords/hint phrases for the model" - ), - gr.Number( - label="Language Detection Threshold", - value=defaults.get("language_detection_threshold", - cls.__fields__["language_detection_threshold"].default), - info="Threshold for language detection probability" - ), - gr.Number( - label="Language Detection Segments", - value=defaults.get("language_detection_segments", - cls.__fields__["language_detection_segments"].default), - precision=0, - info="Number of segments for language detection" - ) - ] + for input_component in faster_whisper_inputs: + input_component.visible = True if whisper_type == WhisperImpl.INSANELY_FAST_WHISPER: - inputs += [ - gr.Number( - label="Batch Size", - value=defaults.get("batch_size", cls.__fields__["batch_size"].default), - precision=0, - info="Batch size for processing", - visible=whisper_type == "insanely_fast_whisper" - ) - ] + for input_component in insanely_fast_whisper_inputs: + input_component.visible = True + + inputs += faster_whisper_inputs + insanely_fast_whisper_inputs + return inputs From 6b0fe26b64ef7fff484428d3b88769e8d8bab997 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 21:25:10 +0900 Subject: [PATCH 22/41] Rename variable --- app.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app.py b/app.py index b1dd9d03..ab692590 100644 --- a/app.py +++ b/app.py @@ -87,10 +87,10 @@ def create_pipeline_inputs(self): dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate]) - inputs = [dd_model, dd_lang, cb_translate] + whisper_inputs + vad_inputs + diarization_inputs + uvr_inputs + pipeline_inputs = [dd_model, dd_lang, cb_translate] + whisper_inputs + vad_inputs + diarization_inputs + uvr_inputs return ( - inputs, + pipeline_inputs, dd_file_format, 
cb_timestamp ) From e730b1b97ded7b670ac714d70d49f165d35b25fc Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 22:03:51 +0900 Subject: [PATCH 23/41] Fix order --- modules/whisper/data_classes.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index 7748bffb..36c69ee7 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -117,16 +117,16 @@ def to_gradio_inputs(cls, label=_("Enable Diarization"), value=defaults.get("is_diarize", cls.__fields__["is_diarize"].default), ), + gr.Dropdown( + label=_("Device"), + choices=["cpu", "cuda"] if available_devices is None else available_devices, + value=defaults.get("device", device), + ), gr.Textbox( label=_("HuggingFace Token"), value=defaults.get("hf_token", cls.__fields__["hf_token"].default), info=_("This is only needed the first time you download the model") ), - gr.Dropdown( - label=_("Device"), - choices=["cpu", "cuda"] if available_devices is None else available_devices, - value=defaults.get("device", device), - ) ] @@ -165,17 +165,17 @@ def to_gradio_input(cls, interactive=True, info=_("Enabling this will remove background music") ), - gr.Dropdown( - label=_("Device"), - choices=["cpu", "cuda"] if available_devices is None else available_devices, - value=defaults.get("device", device), - ), gr.Dropdown( label=_("Model"), choices=["UVR-MDX-NET-Inst_HQ_4", "UVR-MDX-NET-Inst_3"] if available_models is None else available_models, value=defaults.get("model_size", cls.__fields__["model_size"].default), ), + gr.Dropdown( + label=_("Device"), + choices=["cpu", "cuda"] if available_devices is None else available_devices, + value=defaults.get("device", device), + ), gr.Number( label="Segment Size", value=defaults.get("segment_size", cls.__fields__["segment_size"].default), From c7ebe8c056e25a4e5b22577075e1d766e5360813 Mon Sep 17 00:00:00 
2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 22:03:58 +0900 Subject: [PATCH 24/41] Post cache --- modules/whisper/whisper_base.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/modules/whisper/whisper_base.py b/modules/whisper/whisper_base.py index 466d327d..0cf3c9dc 100644 --- a/modules/whisper/whisper_base.py +++ b/modules/whisper/whisper_base.py @@ -103,11 +103,6 @@ def run(self, params = TranscriptionPipelineParams.from_list(list(pipeline_params)) bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization - self.cache_parameters( - params=params, - add_timestamp=add_timestamp - ) - if whisper_params.lang is None: pass elif whisper_params.lang == AUTOMATIC_DETECTION: @@ -176,6 +171,11 @@ def run(self, device=diarization_params.device ) elapsed_time += elapsed_time_diarization + + self.cache_parameters( + params=params, + add_timestamp=add_timestamp + ) return result, elapsed_time def transcribe_file(self, @@ -521,12 +521,18 @@ def cache_parameters( add_timestamp: bool ): """cache parameters to the yaml file""" + cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) param_to_cache = params.to_dict() + print(param_to_cache) + cached_yaml = {**cached_params, **param_to_cache} cached_yaml["whisper"]["add_timestamp"] = add_timestamp + if cached_yaml["whisper"].get("lang", None) is None: + cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION + save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH) @staticmethod From 250b9b4b2d16076102bf7fc57f6a85db6b1efd40 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 22:09:53 +0900 Subject: [PATCH 25/41] Fix component type --- modules/whisper/data_classes.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index 
36c69ee7..b7db887f 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -284,6 +284,8 @@ def to_gradio_inputs(cls, defaults: Optional[Dict] = None, only_advanced: Optional[bool] = True, whisper_type: Optional[WhisperImpl] = None, + available_models: Optional[List] = None, + available_langs: Optional[List] = None, available_compute_types: Optional[List] = None, compute_type: Optional[str] = None): whisper_type = WhisperImpl.FASTER_WHISPER if whisper_type is None else whisper_type @@ -292,20 +294,18 @@ def to_gradio_inputs(cls, if not only_advanced: inputs += [ gr.Dropdown( - label="Model Size", - choices=["small", "medium", "large-v2"], + label=_("Model"), + choices=available_models, value=defaults.get("model_size", cls.__fields__["model_size"].default), - info="Whisper model size" ), - gr.Textbox( - label="Language", + gr.Dropdown( + label=_("Language"), + choices=available_langs, value=defaults.get("lang", cls.__fields__["lang"].default), - info="Source language of the file to transcribe" ), gr.Checkbox( - label="Translate to English", + label=_("Translate to English?"), value=defaults.get("is_translate", cls.__fields__["is_translate"].default), - info="Translate speech to English end-to-end" ), ] From cee12df07ba4dc3b7f8f0fc296907fab587ec5b1 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 23:09:54 +0900 Subject: [PATCH 26/41] Handle gradio None values --- modules/whisper/data_classes.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index b7db887f..23d697ce 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -1,6 +1,6 @@ import gradio as gr import torch -from typing import Optional, Dict, List +from typing import Optional, Dict, List, Union from pydantic import BaseModel, Field, field_validator, ConfigDict from gradio_i18n import 
Translate, gettext as _ from enum import Enum @@ -241,7 +241,7 @@ class WhisperParams(BaseParams): default=True, description="Suppress blank outputs at start of sampling" ) - suppress_tokens: Optional[str] = Field(default="[-1]", description="Token IDs to suppress") + suppress_tokens: Optional[Union[List, str]] = Field(default=[-1], description="Token IDs to suppress") max_initial_timestamp: float = Field( default=0.0, ge=0.0, @@ -279,6 +279,20 @@ def validate_lang(cls, v): from modules.utils.constants import AUTOMATIC_DETECTION return None if v == AUTOMATIC_DETECTION.unwrap() else v + @field_validator('suppress_tokens') + def validate_supress_tokens(cls, v): + import ast + try: + if isinstance(v, str): + suppress_tokens = ast.literal_eval(v) + if not isinstance(suppress_tokens, list): + raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]") + return suppress_tokens + if isinstance(v, list): + return v + except Exception as e: + raise ValueError(f"Invalid Suppress Tokens. 
The value must be type of List[int]: {e}") + @classmethod def to_gradio_inputs(cls, defaults: Optional[Dict] = None, @@ -301,7 +315,7 @@ def to_gradio_inputs(cls, gr.Dropdown( label=_("Language"), choices=available_langs, - value=defaults.get("lang", cls.__fields__["lang"].default), + value=defaults.get("lang", AUTOMATIC_DETECTION), ), gr.Checkbox( label=_("Translate to English?"), @@ -407,7 +421,7 @@ def to_gradio_inputs(cls, ), gr.Textbox( label="Suppress Tokens", - value=defaults.get("suppress_tokens", cls.__fields__["suppress_tokens"].default), + value=defaults.get("suppress_tokens", "[-1]"), info="Token IDs to suppress" ), gr.Number( From e862b084328e18a1856e9dd27c453d2aa233b1ea Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 23:25:39 +0900 Subject: [PATCH 27/41] Handle gradio none values --- modules/whisper/faster_whisper_inference.py | 10 ---- modules/whisper/whisper_base.py | 56 +++++++++++++++------ 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/modules/whisper/faster_whisper_inference.py b/modules/whisper/faster_whisper_inference.py index 524cd247..43d23ffe 100644 --- a/modules/whisper/faster_whisper_inference.py +++ b/modules/whisper/faster_whisper_inference.py @@ -67,16 +67,6 @@ def transcribe(self, if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) - # None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723 - if not params.initial_prompt: - params.initial_prompt = None - if not params.prefix: - params.prefix = None - if not params.hotwords: - params.hotwords = None - - params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens) - segments, info = self.model.transcribe( audio=audio, language=params.lang, diff --git a/modules/whisper/whisper_base.py b/modules/whisper/whisper_base.py index 
0cf3c9dc..eb86e845 100644 --- a/modules/whisper/whisper_base.py +++ b/modules/whisper/whisper_base.py @@ -1,5 +1,6 @@ import os import torch +import ast import whisper import ctranslate2 import gradio as gr @@ -14,7 +15,7 @@ from modules.uvr.music_separator import MusicSeparator from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH, UVR_MODELS_DIR) -from modules.utils.constants import AUTOMATIC_DETECTION +from modules.utils.constants import AUTOMATIC_DETECTION, GRADIO_NONE_VALUES from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename from modules.utils.youtube_manager import get_ytdata, get_ytaudio from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml @@ -101,16 +102,9 @@ def run(self, elapsed time for running """ params = TranscriptionPipelineParams.from_list(list(pipeline_params)) + params = self.handle_gradio_values(params) bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization - if whisper_params.lang is None: - pass - elif whisper_params.lang == AUTOMATIC_DETECTION: - whisper_params.lang = None - else: - language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()} - whisper_params.lang = language_code_dict[params.lang] - if bgm_params.is_separate_bgm: music, audio, _ = self.music_separator.separate( audio=audio, @@ -515,25 +509,57 @@ def remove_input_files(file_paths: List[str]): if file_path and os.path.exists(file_path): os.remove(file_path) + @staticmethod + def handle_gradio_values(params: TranscriptionPipelineParams): + """ + Handle gradio specific values that can't be displayed as None in the UI. 
+ Related issue : https://github.com/gradio-app/gradio/issues/8723 + """ + if params.whisper.lang is None: + pass + elif params.whisper.lang == AUTOMATIC_DETECTION: + params.whisper.lang = None + else: + language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()} + params.whisper.lang = language_code_dict[params.lang] + + if not params.whisper.initial_prompt: + params.whisper.initial_prompt = None + if not params.whisper.prefix: + params.whisper.prefix = None + if not params.whisper.hotwords: + params.whisper.hotwords = None + if params.whisper.max_new_tokens == 0: + params.whisper.max_new_tokens = None + if params.whisper.hallucination_silence_threshold == 0: + params.whisper.hallucination_silence_threshold = None + if params.whisper.language_detection_threshold == 0: + params.whisper.language_detection_threshold = None + if params.whisper.max_speech_duration_s >= 9999: + params.whisper.max_speech_duration_s = float('inf') + return params + @staticmethod def cache_parameters( params: TranscriptionPipelineParams, add_timestamp: bool ): - """cache parameters to the yaml file""" - + """Cache parameters to the yaml file""" cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) param_to_cache = params.to_dict() - print(param_to_cache) - cached_yaml = {**cached_params, **param_to_cache} cached_yaml["whisper"]["add_timestamp"] = add_timestamp + supress_token = cached_yaml["whisper"].get("suppress_tokens", None) + if supress_token and isinstance(supress_token, list): + cached_yaml["whisper"]["suppress_tokens"] = str(supress_token) + if cached_yaml["whisper"].get("lang", None) is None: - cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION + cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap() - save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH) + if cached_yaml is not None and cached_yaml: + save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH) @staticmethod def resample_audio(audio: Union[str, np.ndarray], From 
988860fd34f375e479ab5390fbd90c1ee0025037 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 23:29:12 +0900 Subject: [PATCH 28/41] Fix factory function --- modules/whisper/whisper_factory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/whisper/whisper_factory.py b/modules/whisper/whisper_factory.py index 4abc8f99..74c30f09 100644 --- a/modules/whisper/whisper_factory.py +++ b/modules/whisper/whisper_factory.py @@ -59,14 +59,14 @@ def create_whisper_inference( diarization_model_dir=diarization_model_dir, uvr_model_dir=uvr_model_dir ) - elif whisper_type in WhisperImpl.WHISPER: + elif whisper_type == WhisperImpl.WHISPER: return WhisperInference( model_dir=whisper_model_dir, output_dir=output_dir, diarization_model_dir=diarization_model_dir, uvr_model_dir=uvr_model_dir ) - elif whisper_type in WhisperImpl.INSANELY_FAST_WHISPER: + elif whisper_type == WhisperImpl.INSANELY_FAST_WHISPER: return InsanelyFastWhisperInference( model_dir=insanely_fast_whisper_model_dir, output_dir=output_dir, From 6db9d8dd548e82b65053746b967c1cae62a940d4 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 23:34:46 +0900 Subject: [PATCH 29/41] Use enum for string --- app.py | 2 +- modules/whisper/whisper_factory.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/app.py b/app.py index ab692590..0d4c56f8 100644 --- a/app.py +++ b/app.py @@ -320,7 +320,7 @@ def on_change_models(model_size: str): parser = argparse.ArgumentParser() -parser.add_argument('--whisper_type', type=str, default=WhisperImpl.FASTER_WHISPER, +parser.add_argument('--whisper_type', type=str, default=WhisperImpl.FASTER_WHISPER.value, choices=[item.value for item in WhisperImpl], help='A type of the whisper implementation (Github repo name)') parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio share value') diff --git 
a/modules/whisper/whisper_factory.py b/modules/whisper/whisper_factory.py index 74c30f09..96f167d1 100644 --- a/modules/whisper/whisper_factory.py +++ b/modules/whisper/whisper_factory.py @@ -52,21 +52,23 @@ def create_whisper_inference( # Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144 os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' - if whisper_type == WhisperImpl.FASTER_WHISPER: + whisper_type = whisper_type.strip().lower() + + if whisper_type == WhisperImpl.FASTER_WHISPER.value: return FasterWhisperInference( model_dir=faster_whisper_model_dir, output_dir=output_dir, diarization_model_dir=diarization_model_dir, uvr_model_dir=uvr_model_dir ) - elif whisper_type == WhisperImpl.WHISPER: + elif whisper_type == WhisperImpl.WHISPER.value: return WhisperInference( model_dir=whisper_model_dir, output_dir=output_dir, diarization_model_dir=diarization_model_dir, uvr_model_dir=uvr_model_dir ) - elif whisper_type == WhisperImpl.INSANELY_FAST_WHISPER: + elif whisper_type == WhisperImpl.INSANELY_FAST_WHISPER.value: return InsanelyFastWhisperInference( model_dir=insanely_fast_whisper_model_dir, output_dir=output_dir, From aa98cf2adb5d6bd57763306459d3bcdf55563526 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 23:35:00 +0900 Subject: [PATCH 30/41] Fix param validation --- modules/whisper/whisper_base.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/modules/whisper/whisper_base.py b/modules/whisper/whisper_base.py index eb86e845..c0dd685f 100644 --- a/modules/whisper/whisper_base.py +++ b/modules/whisper/whisper_base.py @@ -15,7 +15,7 @@ from modules.uvr.music_separator import MusicSeparator from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH, UVR_MODELS_DIR) -from modules.utils.constants import AUTOMATIC_DETECTION, GRADIO_NONE_VALUES +from modules.utils.constants import AUTOMATIC_DETECTION from 
modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename from modules.utils.youtube_manager import get_ytdata, get_ytaudio from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml @@ -127,10 +127,6 @@ def run(self, self.music_separator.offload() if vad_params.vad_filter: - # Explicit value set for float('inf') from gr.Number() - if vad_params.max_speech_duration_s is None or vad_params.max_speech_duration_s >= 9999: - vad_params.max_speech_duration_s = float('inf') - vad_options = VadOptions( threshold=vad_params.threshold, min_speech_duration_ms=vad_params.min_speech_duration_ms, @@ -535,8 +531,8 @@ def handle_gradio_values(params: TranscriptionPipelineParams): params.whisper.hallucination_silence_threshold = None if params.whisper.language_detection_threshold == 0: params.whisper.language_detection_threshold = None - if params.whisper.max_speech_duration_s >= 9999: - params.whisper.max_speech_duration_s = float('inf') + if params.vad.max_speech_duration_s >= 9999: + params.vad.max_speech_duration_s = float('inf') return params @staticmethod From a85d7d2620e0f2231deaf19edaa3cdaae69ae68f Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 23:35:39 +0900 Subject: [PATCH 31/41] Rename function --- modules/whisper/whisper_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/whisper/whisper_base.py b/modules/whisper/whisper_base.py index c0dd685f..1977678b 100644 --- a/modules/whisper/whisper_base.py +++ b/modules/whisper/whisper_base.py @@ -102,7 +102,7 @@ def run(self, elapsed time for running """ params = TranscriptionPipelineParams.from_list(list(pipeline_params)) - params = self.handle_gradio_values(params) + params = self.validate_gradio_values(params) bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization if 
bgm_params.is_separate_bgm: @@ -506,9 +506,9 @@ def remove_input_files(file_paths: List[str]): os.remove(file_path) @staticmethod - def handle_gradio_values(params: TranscriptionPipelineParams): + def validate_gradio_values(params: TranscriptionPipelineParams): """ - Handle gradio specific values that can't be displayed as None in the UI. + Validate gradio specific values that can't be displayed as None in the UI. Related issue : https://github.com/gradio-app/gradio/issues/8723 """ if params.whisper.lang is None: From a3de4546becf1fdfba30b91a95efc66117ad36ea Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 23:46:51 +0900 Subject: [PATCH 32/41] Rename class & file --- .../{whisper_base.py => base_transcription_pipeline.py} | 4 ++-- modules/whisper/faster_whisper_inference.py | 4 ++-- modules/whisper/insanely_fast_whisper_inference.py | 4 ++-- modules/whisper/whisper_Inference.py | 4 ++-- modules/whisper/whisper_factory.py | 6 +++--- 5 files changed, 11 insertions(+), 11 deletions(-) rename modules/whisper/{whisper_base.py => base_transcription_pipeline.py} (99%) diff --git a/modules/whisper/whisper_base.py b/modules/whisper/base_transcription_pipeline.py similarity index 99% rename from modules/whisper/whisper_base.py rename to modules/whisper/base_transcription_pipeline.py index 1977678b..689d6274 100644 --- a/modules/whisper/whisper_base.py +++ b/modules/whisper/base_transcription_pipeline.py @@ -24,7 +24,7 @@ from modules.vad.silero_vad import SileroVAD -class WhisperBase(ABC): +class BaseTranscriptionPipeline(ABC): def __init__(self, model_dir: str = WHISPER_MODELS_DIR, diarization_model_dir: str = DIARIZATION_MODELS_DIR, @@ -464,7 +464,7 @@ def get_device(): if torch.cuda.is_available(): return "cuda" elif torch.backends.mps.is_available(): - if not WhisperBase.is_sparse_api_supported(): + if not BaseTranscriptionPipeline.is_sparse_api_supported(): # Device `SparseMPS` is not supported for now. 
See : https://github.com/pytorch/pytorch/issues/87886 return "cpu" return "mps" diff --git a/modules/whisper/faster_whisper_inference.py b/modules/whisper/faster_whisper_inference.py index 43d23ffe..b4a2b4ce 100644 --- a/modules/whisper/faster_whisper_inference.py +++ b/modules/whisper/faster_whisper_inference.py @@ -13,10 +13,10 @@ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR) from modules.whisper.data_classes import * -from modules.whisper.whisper_base import WhisperBase +from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline -class FasterWhisperInference(WhisperBase): +class FasterWhisperInference(BaseTranscriptionPipeline): def __init__(self, model_dir: str = FASTER_WHISPER_MODELS_DIR, diarization_model_dir: str = DIARIZATION_MODELS_DIR, diff --git a/modules/whisper/insanely_fast_whisper_inference.py b/modules/whisper/insanely_fast_whisper_inference.py index 57483f44..bca9a628 100644 --- a/modules/whisper/insanely_fast_whisper_inference.py +++ b/modules/whisper/insanely_fast_whisper_inference.py @@ -13,10 +13,10 @@ from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR) from modules.whisper.data_classes import * -from modules.whisper.whisper_base import WhisperBase +from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline -class InsanelyFastWhisperInference(WhisperBase): +class InsanelyFastWhisperInference(BaseTranscriptionPipeline): def __init__(self, model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR, diarization_model_dir: str = DIARIZATION_MODELS_DIR, diff --git a/modules/whisper/whisper_Inference.py b/modules/whisper/whisper_Inference.py index 8f567b49..825bbe3e 100644 --- a/modules/whisper/whisper_Inference.py +++ b/modules/whisper/whisper_Inference.py @@ -8,11 +8,11 @@ from argparse import Namespace from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, 
OUTPUT_DIR, UVR_MODELS_DIR) -from modules.whisper.whisper_base import WhisperBase +from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline from modules.whisper.data_classes import * -class WhisperInference(WhisperBase): +class WhisperInference(BaseTranscriptionPipeline): def __init__(self, model_dir: str = WHISPER_MODELS_DIR, diarization_model_dir: str = DIARIZATION_MODELS_DIR, diff --git a/modules/whisper/whisper_factory.py b/modules/whisper/whisper_factory.py index 96f167d1..b5ae33a7 100644 --- a/modules/whisper/whisper_factory.py +++ b/modules/whisper/whisper_factory.py @@ -6,7 +6,7 @@ from modules.whisper.faster_whisper_inference import FasterWhisperInference from modules.whisper.whisper_Inference import WhisperInference from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference -from modules.whisper.whisper_base import WhisperBase +from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline from modules.whisper.data_classes import * @@ -20,7 +20,7 @@ def create_whisper_inference( diarization_model_dir: str = DIARIZATION_MODELS_DIR, uvr_model_dir: str = UVR_MODELS_DIR, output_dir: str = OUTPUT_DIR, - ) -> "WhisperBase": + ) -> "BaseTranscriptionPipeline": """ Create a whisper inference class based on the provided whisper_type. @@ -46,7 +46,7 @@ def create_whisper_inference( Returns ------- - WhisperBase + BaseTranscriptionPipeline An instance of the appropriate whisper inference class based on the whisper_type. 
""" # Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144 From ce2f852032b75eb05d5c6bca87c21e2f3ba50269 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 28 Oct 2024 23:46:59 +0900 Subject: [PATCH 33/41] Clean import --- app.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/app.py b/app.py index 0d4c56f8..b18cbb25 100644 --- a/app.py +++ b/app.py @@ -7,11 +7,8 @@ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, WHISPER_MODELS_DIR, INSANELY_FAST_WHISPER_MODELS_DIR, NLLB_MODELS_DIR, DEFAULT_PARAMETERS_CONFIG_PATH, UVR_MODELS_DIR, I18N_YAML_PATH) -from modules.utils.constants import AUTOMATIC_DETECTION from modules.utils.files_manager import load_yaml from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.faster_whisper_inference import FasterWhisperInference -from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference from modules.translation.nllb_inference import NLLBInference from modules.ui.htmls import * from modules.utils.cli_manager import str2bool @@ -290,7 +287,6 @@ def launch(self): # Launch the app with optional gradio settings args = self.args - self.app.queue( api_open=args.api_open ).launch( From 19e342ad3b6a96481fa6370a477f2bdc194d39ab Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Tue, 29 Oct 2024 00:02:57 +0900 Subject: [PATCH 34/41] Add gradio value validation --- modules/whisper/base_transcription_pipeline.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/whisper/base_transcription_pipeline.py b/modules/whisper/base_transcription_pipeline.py index 689d6274..2e9d0da4 100644 --- a/modules/whisper/base_transcription_pipeline.py +++ b/modules/whisper/base_transcription_pipeline.py @@ -554,6 +554,9 @@ def cache_parameters( if cached_yaml["whisper"].get("lang", None) is None: cached_yaml["whisper"]["lang"] = 
AUTOMATIC_DETECTION.unwrap() + if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'): + cached_yaml["vad"]["max_speech_duration_s"] = 9999 + if cached_yaml is not None and cached_yaml: save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH) From 2a2f7c60fa88506f5378a3d792bfe0386e97a86f Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Tue, 29 Oct 2024 00:15:41 +0900 Subject: [PATCH 35/41] Use constant for gradio none validation values --- modules/utils/constants.py | 3 +++ modules/whisper/base_transcription_pipeline.py | 18 +++++++++--------- modules/whisper/data_classes.py | 14 +++++++------- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/modules/utils/constants.py b/modules/utils/constants.py index e9309bc3..49b45c8c 100644 --- a/modules/utils/constants.py +++ b/modules/utils/constants.py @@ -1,3 +1,6 @@ from gradio_i18n import Translate, gettext as _ AUTOMATIC_DETECTION = _("Automatic Detection") +GRADIO_NONE_STR = "" +GRADIO_NONE_NUMBER_MAX = 9999 +GRADIO_NONE_NUMBER_MIN = 0 diff --git a/modules/whisper/base_transcription_pipeline.py b/modules/whisper/base_transcription_pipeline.py index 2e9d0da4..b9a3ae05 100644 --- a/modules/whisper/base_transcription_pipeline.py +++ b/modules/whisper/base_transcription_pipeline.py @@ -15,7 +15,7 @@ from modules.uvr.music_separator import MusicSeparator from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH, UVR_MODELS_DIR) -from modules.utils.constants import AUTOMATIC_DETECTION +from modules.utils.constants import * from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename from modules.utils.youtube_manager import get_ytdata, get_ytaudio from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml @@ -519,19 +519,19 @@ def validate_gradio_values(params: TranscriptionPipelineParams): language_code_dict = 
{value: key for key, value in whisper.tokenizer.LANGUAGES.items()} params.whisper.lang = language_code_dict[params.lang] - if not params.whisper.initial_prompt: + if params.whisper.initial_prompt == GRADIO_NONE_STR: params.whisper.initial_prompt = None - if not params.whisper.prefix: + if params.whisper.prefix == GRADIO_NONE_STR: params.whisper.prefix = None - if not params.whisper.hotwords: + if params.whisper.hotwords == GRADIO_NONE_STR: params.whisper.hotwords = None - if params.whisper.max_new_tokens == 0: + if params.whisper.max_new_tokens == GRADIO_NONE_NUMBER_MIN: params.whisper.max_new_tokens = None - if params.whisper.hallucination_silence_threshold == 0: + if params.whisper.hallucination_silence_threshold == GRADIO_NONE_NUMBER_MIN: params.whisper.hallucination_silence_threshold = None - if params.whisper.language_detection_threshold == 0: + if params.whisper.language_detection_threshold == GRADIO_NONE_NUMBER_MIN: params.whisper.language_detection_threshold = None - if params.vad.max_speech_duration_s >= 9999: + if params.vad.max_speech_duration_s == GRADIO_NONE_NUMBER_MAX: params.vad.max_speech_duration_s = float('inf') return params @@ -555,7 +555,7 @@ def cache_parameters( cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap() if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'): - cached_yaml["vad"]["max_speech_duration_s"] = 9999 + cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX if cached_yaml is not None and cached_yaml: save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index 23d697ce..f23a5793 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -7,7 +7,7 @@ from copy import deepcopy import yaml -from modules.utils.constants import AUTOMATIC_DETECTION +from modules.utils.constants import * class WhisperImpl(Enum): @@ -82,7 +82,7 @@ def to_gradio_inputs(cls, defaults: 
Optional[Dict] = None) -> List[gr.components ), gr.Number( label="Maximum Speech Duration (s)", - value=defaults.get("max_speech_duration_s", cls.__fields__["max_speech_duration_s"].default), + value=defaults.get("max_speech_duration_s", GRADIO_NONE_NUMBER_MAX), info="Maximum duration of speech chunks in \"seconds\"." ), gr.Number( @@ -373,7 +373,7 @@ def to_gradio_inputs(cls, ), gr.Textbox( label="Initial Prompt", - value=defaults.get("initial_prompt", cls.__fields__["initial_prompt"].default), + value=defaults.get("initial_prompt", GRADIO_NONE_STR), info="Initial prompt for first window" ), gr.Slider( @@ -411,7 +411,7 @@ def to_gradio_inputs(cls, ), gr.Textbox( label="Prefix", - value=defaults.get("prefix", cls.__fields__["prefix"].default), + value=defaults.get("prefix", GRADIO_NONE_STR), info="Prefix text for first window" ), gr.Checkbox( @@ -446,7 +446,7 @@ def to_gradio_inputs(cls, ), gr.Number( label="Max New Tokens", - value=defaults.get("max_new_tokens", cls.__fields__["max_new_tokens"].default), + value=defaults.get("max_new_tokens", GRADIO_NONE_NUMBER_MIN), precision=0, info="Maximum number of new tokens per chunk" ), @@ -459,7 +459,7 @@ def to_gradio_inputs(cls, gr.Number( label="Hallucination Silence Threshold (sec)", value=defaults.get("hallucination_silence_threshold", - cls.__fields__["hallucination_silence_threshold"].default), + GRADIO_NONE_NUMBER_MIN), info="Threshold for skipping silent periods in hallucination detection" ), gr.Textbox( @@ -470,7 +470,7 @@ def to_gradio_inputs(cls, gr.Number( label="Language Detection Threshold", value=defaults.get("language_detection_threshold", - cls.__fields__["language_detection_threshold"].default), + GRADIO_NONE_NUMBER_MIN), info="Threshold for language detection probability" ), gr.Number( From 4c4ecfe230f77033be1d3c0a937c0c1175c7c621 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Tue, 29 Oct 2024 00:33:07 +0900 Subject: [PATCH 36/41] Update to use enum --- 
tests/test_bgm_separation.py | 14 +++++++------- tests/test_diarization.py | 8 ++++---- tests/test_transcription.py | 6 +++--- tests/test_vad.py | 8 ++++---- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/test_bgm_separation.py b/tests/test_bgm_separation.py index 504be03f..95b77a0b 100644 --- a/tests/test_bgm_separation.py +++ b/tests/test_bgm_separation.py @@ -1,6 +1,6 @@ from modules.utils.paths import * from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.data_classes import TranscriptionPipelineParams +from modules.whisper.data_classes import * from test_config import * from test_transcription import download_file, test_transcribe @@ -17,9 +17,9 @@ @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ - ("whisper", False, True, False), - ("faster-whisper", False, True, False), - ("insanely_fast_whisper", False, True, False) + (WhisperImpl.WHISPER.value, False, True, False), + (WhisperImpl.FASTER_WHISPER.value, False, True, False), + (WhisperImpl.INSANELY_FAST_WHISPER.value, False, True, False) ] ) def test_bgm_separation_pipeline( @@ -38,9 +38,9 @@ def test_bgm_separation_pipeline( @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ - ("whisper", True, True, False), - ("faster-whisper", True, True, False), - ("insanely_fast_whisper", True, True, False) + (WhisperImpl.WHISPER.value, True, True, False), + (WhisperImpl.FASTER_WHISPER.value, True, True, False), + (WhisperImpl.INSANELY_FAST_WHISPER.value, True, True, False) ] ) def test_bgm_separation_with_vad_pipeline( diff --git a/tests/test_diarization.py b/tests/test_diarization.py index daf41475..f18a2633 100644 --- a/tests/test_diarization.py +++ b/tests/test_diarization.py @@ -1,6 +1,6 @@ from modules.utils.paths import * from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.data_classes import TranscriptionPipelineParams +from modules.whisper.data_classes import * from 
test_config import * from test_transcription import download_file, test_transcribe @@ -16,9 +16,9 @@ @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ - ("whisper", False, False, True), - ("faster-whisper", False, False, True), - ("insanely_fast_whisper", False, False, True) + (WhisperImpl.WHISPER.value, False, False, True), + (WhisperImpl.FASTER_WHISPER.value, False, False, True), + (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, True) ] ) def test_diarization_pipeline( diff --git a/tests/test_transcription.py b/tests/test_transcription.py index f9dc3cc2..3353782b 100644 --- a/tests/test_transcription.py +++ b/tests/test_transcription.py @@ -12,9 +12,9 @@ @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ - ("whisper", False, False, False), - ("faster-whisper", False, False, False), - ("insanely_fast_whisper", False, False, False) + (WhisperImpl.WHISPER.value, False, False, False), + (WhisperImpl.FASTER_WHISPER.value, False, False, False), + (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, False) ] ) def test_transcribe( diff --git a/tests/test_vad.py b/tests/test_vad.py index fd7fd6c9..cb3dc054 100644 --- a/tests/test_vad.py +++ b/tests/test_vad.py @@ -1,6 +1,6 @@ from modules.utils.paths import * from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.data_classes import TranscriptionPipelineParams +from modules.whisper.data_classes import * from test_config import * from test_transcription import download_file, test_transcribe @@ -12,9 +12,9 @@ @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ - ("whisper", True, False, False), - ("faster-whisper", True, False, False), - ("insanely_fast_whisper", True, False, False) + (WhisperImpl.WHISPER.value, True, False, False), + (WhisperImpl.FASTER_WHISPER.value, True, False, False), + (WhisperImpl.INSANELY_FAST_WHISPER.value, True, False, False) ] ) def test_vad_pipeline( From 
54a7493b979ffa62a8aace1878f6d85dd1a53b8f Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Tue, 29 Oct 2024 01:05:02 +0900 Subject: [PATCH 37/41] Update model --- tests/test_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_config.py b/tests/test_config.py index 0f60aa58..0020eca1 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,7 +6,7 @@ TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav" TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav") TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer" -TEST_WHISPER_MODEL = "tiny" +TEST_WHISPER_MODEL = "tiny.en" TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4" TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M" TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt") From 0184fe2ff8e8c3b7fd651839e8492d0e520e7c74 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:49:24 +0900 Subject: [PATCH 38/41] Fix gradio input visibility by implementation type --- modules/whisper/data_classes.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index f23a5793..878aaf17 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -297,12 +297,12 @@ def validate_supress_tokens(cls, v): def to_gradio_inputs(cls, defaults: Optional[Dict] = None, only_advanced: Optional[bool] = True, - whisper_type: Optional[WhisperImpl] = None, + whisper_type: Optional[str] = None, available_models: Optional[List] = None, available_langs: Optional[List] = None, available_compute_types: Optional[List] = None, compute_type: Optional[str] = None): - whisper_type = WhisperImpl.FASTER_WHISPER if whisper_type is None else whisper_type + whisper_type = WhisperImpl.FASTER_WHISPER.value if 
whisper_type is None else whisper_type.strip().lower() inputs = [] if not only_advanced: @@ -491,13 +491,13 @@ def to_gradio_inputs(cls, ) ] - if whisper_type == WhisperImpl.FASTER_WHISPER: + if whisper_type != WhisperImpl.FASTER_WHISPER.value: for input_component in faster_whisper_inputs: - input_component.visible = True + input_component.visible = False - if whisper_type == WhisperImpl.INSANELY_FAST_WHISPER: + if whisper_type != WhisperImpl.INSANELY_FAST_WHISPER.value: for input_component in insanely_fast_whisper_inputs: - input_component.visible = True + input_component.visible = False inputs += faster_whisper_inputs + insanely_fast_whisper_inputs From 78d8e1888268ea53ed541658571965ea85993dcb Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Tue, 29 Oct 2024 19:57:21 +0900 Subject: [PATCH 39/41] Add Segment model --- modules/whisper/data_classes.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index 878aaf17..247a62e3 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -16,6 +16,15 @@ class WhisperImpl(Enum): INSANELY_FAST_WHISPER = "insanely_fast_whisper" +class Segment(BaseModel): + text: Optional[str] = Field(default=None, + description="Transcription text of the segment") + start: Optional[float] = Field(default=None, + description="Start time of the segment") + end: Optional[float] = Field(default=None, + description="End time of the segment") + + class BaseParams(BaseModel): model_config = ConfigDict(protected_namespaces=()) From ddbe0b6bab42b6e2cb9bfc9bde27582d2e08a518 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Tue, 29 Oct 2024 19:58:00 +0900 Subject: [PATCH 40/41] Apply Segment model to the pipeline --- modules/diarize/diarize_pipeline.py | 1 + modules/diarize/diarizer.py | 24 +++++++---- modules/utils/subtitle_manager.py | 11 +++++ modules/vad/silero_vad.py | 
11 ++--- modules/whisper/faster_whisper_inference.py | 16 ++++---- .../insanely_fast_whisper_inference.py | 20 ++++++--- modules/whisper/whisper_Inference.py | 41 +++++++++++-------- 7 files changed, 80 insertions(+), 44 deletions(-) diff --git a/modules/diarize/diarize_pipeline.py b/modules/diarize/diarize_pipeline.py index b4109e84..360f88b7 100644 --- a/modules/diarize/diarize_pipeline.py +++ b/modules/diarize/diarize_pipeline.py @@ -44,6 +44,7 @@ def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speaker def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False): transcript_segments = transcript_result["segments"] for seg in transcript_segments: + seg = seg.dict() # assign speaker to segment (if any) diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start']) diff --git a/modules/diarize/diarizer.py b/modules/diarize/diarizer.py index 2dd3f94a..4c0727ab 100644 --- a/modules/diarize/diarizer.py +++ b/modules/diarize/diarizer.py @@ -1,6 +1,6 @@ import os import torch -from typing import List, Union, BinaryIO, Optional +from typing import List, Union, BinaryIO, Optional, Tuple import numpy as np import time import logging @@ -8,6 +8,7 @@ from modules.utils.paths import DIARIZATION_MODELS_DIR from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers from modules.diarize.audio_loader import load_audio +from modules.whisper.data_classes import * class Diarizer: @@ -23,10 +24,10 @@ def __init__(self, def run(self, audio: Union[str, BinaryIO, np.ndarray], - transcribed_result: List[dict], + transcribed_result: List[Segment], use_auth_token: str, device: Optional[str] = None - ): + ) -> Tuple[List[Segment], float]: """ Diarize transcribed result as a post-processing @@ -34,7 +35,7 @@ def run(self, ---------- audio: Union[str, BinaryIO, np.ndarray] Audio input. This can be file path or binary type. 
- transcribed_result: List[dict] + transcribed_result: List[Segment] transcribed result through whisper. use_auth_token: str Huggingface token with READ permission. This is only needed the first time you download the model. @@ -44,8 +45,8 @@ def run(self, Returns ---------- - segments_result: List[dict] - list of dicts that includes start, end timestamps and transcribed text + segments_result: List[Segment] + list of Segment that includes start, end timestamps and transcribed text elapsed_time: float elapsed time for running """ @@ -68,14 +69,21 @@ def run(self, {"segments": transcribed_result} ) + segments_result = [] for segment in diarized_result["segments"]: + segment = segment.dict() speaker = "None" if "speaker" in segment: speaker = segment["speaker"] - segment["text"] = speaker + "|" + segment["text"].strip() + diarized_text = speaker + "|" + segment["text"].strip() + segments_result.append(Segment( + start=segment["start"], + end=segment["end"], + text=diarized_text + )) elapsed_time = time.time() - start_time - return diarized_result["segments"], elapsed_time + return segments_result, elapsed_time def update_pipe(self, use_auth_token: str, diff --git a/modules/utils/subtitle_manager.py b/modules/utils/subtitle_manager.py index b260cb4a..87568a39 100644 --- a/modules/utils/subtitle_manager.py +++ b/modules/utils/subtitle_manager.py @@ -1,5 +1,7 @@ import re +from modules.whisper.data_classes import Segment + def timeformat_srt(time): hours = time // 3600 @@ -23,6 +25,9 @@ def write_file(subtitle, output_file): def get_srt(segments): + if segments and isinstance(segments[0], Segment): + segments = [seg.dict() for seg in segments] + output = "" for i, segment in enumerate(segments): output += f"{i + 1}\n" @@ -34,6 +39,9 @@ def get_srt(segments): def get_vtt(segments): + if segments and isinstance(segments[0], Segment): + segments = [seg.dict() for seg in segments] + output = "WEBVTT\n\n" for i, segment in enumerate(segments): output += 
f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n" @@ -44,6 +52,9 @@ def get_vtt(segments): def get_txt(segments): + if segments and isinstance(segments[0], Segment): + segments = [seg.dict() for seg in segments] + output = "" for i, segment in enumerate(segments): if segment['text'].startswith(' '): diff --git a/modules/vad/silero_vad.py b/modules/vad/silero_vad.py index bb5c9192..d44c26fc 100644 --- a/modules/vad/silero_vad.py +++ b/modules/vad/silero_vad.py @@ -5,7 +5,8 @@ from typing import BinaryIO, Union, List, Optional, Tuple import warnings import faster_whisper -from faster_whisper.transcribe import SpeechTimestampsMap, Segment +from modules.whisper.data_classes import * +from faster_whisper.transcribe import SpeechTimestampsMap import gradio as gr @@ -247,18 +248,18 @@ def format_timestamp( def restore_speech_timestamps( self, - segments: List[dict], + segments: List[Segment], speech_chunks: List[dict], sampling_rate: Optional[int] = None, - ) -> List[dict]: + ) -> List[Segment]: if sampling_rate is None: sampling_rate = self.sampling_rate ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate) for segment in segments: - segment["start"] = ts_map.get_original_time(segment["start"]) - segment["end"] = ts_map.get_original_time(segment["end"]) + segment.start = ts_map.get_original_time(segment.start) + segment.end = ts_map.get_original_time(segment.end) return segments diff --git a/modules/whisper/faster_whisper_inference.py b/modules/whisper/faster_whisper_inference.py index b4a2b4ce..5dad6edb 100644 --- a/modules/whisper/faster_whisper_inference.py +++ b/modules/whisper/faster_whisper_inference.py @@ -40,7 +40,7 @@ def transcribe(self, audio: Union[str, BinaryIO, np.ndarray], progress: gr.Progress = gr.Progress(), *whisper_params, - ) -> Tuple[List[dict], float]: + ) -> Tuple[List[Segment], float]: """ transcribe method for faster-whisper. 
@@ -55,8 +55,8 @@ def transcribe(self, Returns ---------- - segments_result: List[dict] - list of dicts that includes start, end timestamps and transcribed text + segments_result: List[Segment] + list of Segment that includes start, end timestamps and transcribed text elapsed_time: float elapsed time for transcription """ @@ -102,11 +102,11 @@ def transcribe(self, segments_result = [] for segment in segments: progress(segment.start / info.duration, desc="Transcribing..") - segments_result.append({ - "start": segment.start, - "end": segment.end, - "text": segment.text - }) + segments_result.append(Segment( + start=segment.start, + end=segment.end, + text=segment.text + )) elapsed_time = time.time() - start_time return segments_result, elapsed_time diff --git a/modules/whisper/insanely_fast_whisper_inference.py b/modules/whisper/insanely_fast_whisper_inference.py index bca9a628..c1f7cb2c 100644 --- a/modules/whisper/insanely_fast_whisper_inference.py +++ b/modules/whisper/insanely_fast_whisper_inference.py @@ -40,7 +40,7 @@ def transcribe(self, audio: Union[str, np.ndarray, torch.Tensor], progress: gr.Progress = gr.Progress(), *whisper_params, - ) -> Tuple[List[dict], float]: + ) -> Tuple[List[Segment], float]: """ transcribe method for faster-whisper. 
@@ -55,8 +55,8 @@ def transcribe(self, Returns ---------- - segments_result: List[dict] - list of dicts that includes start, end timestamps and transcribed text + segments_result: List[Segment] + list of Segment that includes start, end timestamps and transcribed text elapsed_time: float elapsed time for transcription """ @@ -95,9 +95,17 @@ def transcribe(self, generate_kwargs=kwargs ) - segments_result = self.format_result( - transcribed_result=segments, - ) + segments_result = [] + for item in segments["chunks"]: + start, end = item["timestamp"][0], item["timestamp"][1] + if end is None: + end = start + segments_result.append(Segment( + text=item["text"], + start=start, + end=end + )) + elapsed_time = time.time() - start_time return segments_result, elapsed_time diff --git a/modules/whisper/whisper_Inference.py b/modules/whisper/whisper_Inference.py index 825bbe3e..ccd4bbb4 100644 --- a/modules/whisper/whisper_Inference.py +++ b/modules/whisper/whisper_Inference.py @@ -30,7 +30,7 @@ def transcribe(self, audio: Union[str, np.ndarray, torch.Tensor], progress: gr.Progress = gr.Progress(), *whisper_params, - ) -> Tuple[List[dict], float]: + ) -> Tuple[List[Segment], float]: """ transcribe method for faster-whisper. 
@@ -45,8 +45,8 @@ def transcribe(self, Returns ---------- - segments_result: List[dict] - list of dicts that includes start, end timestamps and transcribed text + segments_result: List[Segment] + list of Segment that includes start, end timestamps and transcribed text elapsed_time: float elapsed time for transcription """ @@ -59,21 +59,28 @@ def transcribe(self, def progress_callback(progress_value): progress(progress_value, desc="Transcribing..") - segments_result = self.model.transcribe(audio=audio, - language=params.lang, - verbose=False, - beam_size=params.beam_size, - logprob_threshold=params.log_prob_threshold, - no_speech_threshold=params.no_speech_threshold, - task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe", - fp16=True if params.compute_type == "float16" else False, - best_of=params.best_of, - patience=params.patience, - temperature=params.temperature, - compression_ratio_threshold=params.compression_ratio_threshold, - progress_callback=progress_callback,)["segments"] - elapsed_time = time.time() - start_time + result = self.model.transcribe(audio=audio, + language=params.lang, + verbose=False, + beam_size=params.beam_size, + logprob_threshold=params.log_prob_threshold, + no_speech_threshold=params.no_speech_threshold, + task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe", + fp16=True if params.compute_type == "float16" else False, + best_of=params.best_of, + patience=params.patience, + temperature=params.temperature, + compression_ratio_threshold=params.compression_ratio_threshold, + progress_callback=progress_callback,)["segments"] + segments_result = [] + for segment in result: + segments_result.append(Segment( + start=segment["start"], + end=segment["end"], + text=segment["text"] + )) + elapsed_time = time.time() - start_time return segments_result, elapsed_time def update_model(self, From 
eec0c16128570239db2896a077f4fd7423907144 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Tue, 29 Oct 2024 01:21:30 +0900 Subject: [PATCH 41/41] Fix VAD syntax & add vad handling case --- modules/whisper/base_transcription_pipeline.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/modules/whisper/base_transcription_pipeline.py b/modules/whisper/base_transcription_pipeline.py index b9a3ae05..808a47b4 100644 --- a/modules/whisper/base_transcription_pipeline.py +++ b/modules/whisper/base_transcription_pipeline.py @@ -135,12 +135,17 @@ def run(self, speech_pad_ms=vad_params.speech_pad_ms ) - audio, speech_chunks = self.vad.run( + vad_processed, speech_chunks = self.vad.run( audio=audio, vad_parameters=vad_options, progress=progress ) + if vad_processed.size > 0: + audio = vad_processed + else: + vad_params.vad_filter = False + result, elapsed_time = self.transcribe( audio, progress, @@ -150,7 +155,7 @@ def run(self, if vad_params.vad_filter: result = self.vad.restore_speech_timestamps( segments=result, - speech_chunks=vad_params.speech_chunks, + speech_chunks=speech_chunks, ) if diarization_params.is_diarize: