diff --git a/app.py b/app.py index ea094c50..b18cbb25 100644 --- a/app.py +++ b/app.py @@ -7,17 +7,14 @@ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, WHISPER_MODELS_DIR, INSANELY_FAST_WHISPER_MODELS_DIR, NLLB_MODELS_DIR, DEFAULT_PARAMETERS_CONFIG_PATH, UVR_MODELS_DIR, I18N_YAML_PATH) -from modules.utils.constants import AUTOMATIC_DETECTION from modules.utils.files_manager import load_yaml from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.faster_whisper_inference import FasterWhisperInference -from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference from modules.translation.nllb_inference import NLLBInference from modules.ui.htmls import * from modules.utils.cli_manager import str2bool from modules.utils.youtube_manager import get_ytmetas from modules.translation.deepl_api import DeepLAPI -from modules.whisper.whisper_parameter import * +from modules.whisper.data_classes import * class App: @@ -44,7 +41,7 @@ def __init__(self, args): print(f"Use \"{self.args.whisper_type}\" implementation\n" f"Device \"{self.whisper_inf.device}\" is detected") - def create_whisper_parameters(self): + def create_pipeline_inputs(self): whisper_params = self.default_params["whisper"] vad_params = self.default_params["vad"] diarization_params = self.default_params["diarization"] @@ -66,158 +63,31 @@ def create_whisper_parameters(self): interactive=True) with gr.Accordion(_("Advanced Parameters"), open=False): - nb_beam_size = gr.Number(label="Beam Size", value=whisper_params["beam_size"], precision=0, - interactive=True, - info="Beam size to use for decoding.") - nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", - value=whisper_params["log_prob_threshold"], interactive=True, - info="If the average log probability over sampled tokens is below this value, treat as failed.") - nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=whisper_params["no_speech_threshold"], - interactive=True, - info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.") - dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, - value=self.whisper_inf.current_compute_type, interactive=True, - allow_custom_value=True, - info="Select the type of computation to perform.") - nb_best_of = gr.Number(label="Best Of", value=whisper_params["best_of"], interactive=True, - info="Number of candidates when sampling with non-zero temperature.") - nb_patience = gr.Number(label="Patience", value=whisper_params["patience"], interactive=True, - info="Beam search patience factor.") - cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", - value=whisper_params["condition_on_previous_text"], - interactive=True, - info="Condition on previous text during decoding.") - sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature", - value=whisper_params["prompt_reset_on_temperature"], - minimum=0, maximum=1, step=0.01, interactive=True, - info="Resets prompt if temperature is above this value." 
- " Arg has effect only if 'Condition On Previous Text' is True.") - tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True, - info="Initial prompt to use for decoding.") - sd_temperature = gr.Slider(label="Temperature", value=whisper_params["temperature"], minimum=0.0, - step=0.01, maximum=1.0, interactive=True, - info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.") - nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", - value=whisper_params["compression_ratio_threshold"], - interactive=True, - info="If the gzip compression ratio is above this value, treat as failed.") - nb_chunk_length = gr.Number(label="Chunk Length (s)", value=lambda: whisper_params["chunk_length"], - precision=0, - info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.") - with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)): - nb_length_penalty = gr.Number(label="Length Penalty", value=whisper_params["length_penalty"], - info="Exponential length penalty constant.") - nb_repetition_penalty = gr.Number(label="Repetition Penalty", - value=whisper_params["repetition_penalty"], - info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).") - nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size", - value=whisper_params["no_repeat_ngram_size"], - precision=0, - info="Prevent repetitions of n-grams with this size (set 0 to disable).") - tb_prefix = gr.Textbox(label="Prefix", value=lambda: whisper_params["prefix"], - info="Optional text to provide as a prefix for the first window.") - cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=whisper_params["suppress_blank"], - info="Suppress blank outputs at the beginning of the sampling.") - tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value=whisper_params["suppress_tokens"], - info="List of token IDs to suppress. -1 will suppress a default set of symbols as defined in the model config.json file.") - nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp", - value=whisper_params["max_initial_timestamp"], - info="The initial timestamp cannot be later than this.") - cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=whisper_params["word_timestamps"], - info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.") - tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations", - value=whisper_params["prepend_punctuations"], - info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.") - tb_append_punctuations = gr.Textbox(label="Append Punctuations", - value=whisper_params["append_punctuations"], - info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.") - nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: whisper_params["max_new_tokens"], - precision=0, - info="Maximum number of new tokens to generate per-chunk. 
If not set, the maximum will be set by the default max_length.") - nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)", - value=lambda: whisper_params[ - "hallucination_silence_threshold"], - info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.") - tb_hotwords = gr.Textbox(label="Hotwords", value=lambda: whisper_params["hotwords"], - info="Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.") - nb_language_detection_threshold = gr.Number(label="Language Detection Threshold", - value=lambda: whisper_params[ - "language_detection_threshold"], - info="If the maximum probability of the language tokens is higher than this value, the language is detected.") - nb_language_detection_segments = gr.Number(label="Language Detection Segments", - value=lambda: whisper_params["language_detection_segments"], - precision=0, - info="Number of segments to consider for the language detection.") - with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)): - nb_batch_size = gr.Number(label="Batch Size", value=whisper_params["batch_size"], precision=0) + whisper_inputs = WhisperParams.to_gradio_inputs(defaults=whisper_params, only_advanced=True, + whisper_type=self.args.whisper_type, + available_compute_types=self.whisper_inf.available_compute_types, + compute_type=self.whisper_inf.current_compute_type) with gr.Accordion(_("Background Music Remover Filter"), open=False): - cb_bgm_separation = gr.Checkbox(label=_("Enable Background Music Remover Filter"), - value=uvr_params["is_separate_bgm"], - interactive=True, - info=_("Enabling this will remove background music")) - dd_uvr_device = gr.Dropdown(label=_("Device"), value=self.whisper_inf.music_separator.device, - choices=self.whisper_inf.music_separator.available_devices) - dd_uvr_model_size = gr.Dropdown(label=_("Model"), value=uvr_params["model_size"], - choices=self.whisper_inf.music_separator.available_models) - nb_uvr_segment_size = gr.Number(label="Segment Size", value=uvr_params["segment_size"], precision=0) - cb_uvr_save_file = gr.Checkbox(label=_("Save separated files to output"), value=uvr_params["save_file"]) - cb_uvr_enable_offload = gr.Checkbox(label=_("Offload sub model after removing background music"), - value=uvr_params["enable_offload"]) + uvr_inputs = BGMSeparationParams.to_gradio_input(defaults=uvr_params, + available_models=self.whisper_inf.music_separator.available_models, + available_devices=self.whisper_inf.music_separator.available_devices, + device=self.whisper_inf.music_separator.device) with gr.Accordion(_("Voice Detection Filter"), open=False): - cb_vad_filter = gr.Checkbox(label=_("Enable Silero VAD Filter"), value=vad_params["vad_filter"], - interactive=True, - info=_("Enable this to transcribe only detected voice")) - sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", - value=vad_params["threshold"], - info="Lower it to be more sensitive to small sounds.") - nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, - value=vad_params["min_speech_duration_ms"], - info="Final speech chunks shorter than this time are thrown out") - nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", - value=vad_params["max_speech_duration_s"], - info="Maximum duration of speech chunks in \"seconds\".") - nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration 
(ms)", precision=0, - value=vad_params["min_silence_duration_ms"], - info="In the end of each speech chunk wait for this time" - " before separating it") - nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=vad_params["speech_pad_ms"], - info="Final speech chunks are padded by this time each side") + vad_inputs = VadParams.to_gradio_inputs(defaults=vad_params) with gr.Accordion(_("Diarization"), open=False): - cb_diarize = gr.Checkbox(label=_("Enable Diarization"), value=diarization_params["is_diarize"]) - tb_hf_token = gr.Text(label=_("HuggingFace Token"), value=diarization_params["hf_token"], - info=_("This is only needed the first time you download the model")) - dd_diarization_device = gr.Dropdown(label=_("Device"), - choices=self.whisper_inf.diarizer.get_available_device(), - value=self.whisper_inf.diarizer.get_device()) + diarization_inputs = DiarizationParams.to_gradio_inputs(defaults=diarization_params, + available_devices=self.whisper_inf.diarizer.available_device, + device=self.whisper_inf.diarizer.device) dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate]) + pipeline_inputs = [dd_model, dd_lang, cb_translate] + whisper_inputs + vad_inputs + diarization_inputs + uvr_inputs + return ( - WhisperParameters( - model_size=dd_model, lang=dd_lang, is_translate=cb_translate, beam_size=nb_beam_size, - log_prob_threshold=nb_log_prob_threshold, no_speech_threshold=nb_no_speech_threshold, - compute_type=dd_compute_type, best_of=nb_best_of, patience=nb_patience, - condition_on_previous_text=cb_condition_on_previous_text, initial_prompt=tb_initial_prompt, - temperature=sd_temperature, compression_ratio_threshold=nb_compression_ratio_threshold, - vad_filter=cb_vad_filter, threshold=sd_threshold, min_speech_duration_ms=nb_min_speech_duration_ms, - max_speech_duration_s=nb_max_speech_duration_s, min_silence_duration_ms=nb_min_silence_duration_ms, - speech_pad_ms=nb_speech_pad_ms, chunk_length=nb_chunk_length, batch_size=nb_batch_size, - is_diarize=cb_diarize, hf_token=tb_hf_token, diarization_device=dd_diarization_device, - length_penalty=nb_length_penalty, repetition_penalty=nb_repetition_penalty, - no_repeat_ngram_size=nb_no_repeat_ngram_size, prefix=tb_prefix, suppress_blank=cb_suppress_blank, - suppress_tokens=tb_suppress_tokens, max_initial_timestamp=nb_max_initial_timestamp, - word_timestamps=cb_word_timestamps, prepend_punctuations=tb_prepend_punctuations, - append_punctuations=tb_append_punctuations, max_new_tokens=nb_max_new_tokens, - hallucination_silence_threshold=nb_hallucination_silence_threshold, hotwords=tb_hotwords, - language_detection_threshold=nb_language_detection_threshold, - language_detection_segments=nb_language_detection_segments, - prompt_reset_on_temperature=sld_prompt_reset_on_temperature, is_bgm_separate=cb_bgm_separation, - uvr_device=dd_uvr_device, uvr_model_size=dd_uvr_model_size, uvr_segment_size=nb_uvr_segment_size, - uvr_save_file=cb_uvr_save_file, uvr_enable_offload=cb_uvr_enable_offload - ), + pipeline_inputs, dd_file_format, cb_timestamp ) @@ -243,7 +113,7 @@ def launch(self): visible=self.args.colab, value="") - whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters() + pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs() with gr.Row(): btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary") @@ -254,7 +124,7 @@ def launch(self): params = [input_file, tb_input_folder, dd_file_format, cb_timestamp] 
btn_run.click(fn=self.whisper_inf.transcribe_file, - inputs=params + whisper_params.as_list(), + inputs=params + pipeline_params, outputs=[tb_indicator, files_subtitles]) btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None) @@ -268,7 +138,7 @@ def launch(self): tb_title = gr.Label(label=_("Youtube Title")) tb_description = gr.Textbox(label=_("Youtube Description"), max_lines=15) - whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters() + pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs() with gr.Row(): btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary") @@ -280,7 +150,7 @@ def launch(self): params = [tb_youtubelink, dd_file_format, cb_timestamp] btn_run.click(fn=self.whisper_inf.transcribe_youtube, - inputs=params + whisper_params.as_list(), + inputs=params + pipeline_params, outputs=[tb_indicator, files_subtitles]) tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink], outputs=[img_thumbnail, tb_title, tb_description]) @@ -290,7 +160,7 @@ def launch(self): with gr.Row(): mic_input = gr.Microphone(label=_("Record with Mic"), type="filepath", interactive=True) - whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters() + pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs() with gr.Row(): btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary") @@ -302,7 +172,7 @@ def launch(self): params = [mic_input, dd_file_format, cb_timestamp] btn_run.click(fn=self.whisper_inf.transcribe_mic, - inputs=params + whisper_params.as_list(), + inputs=params + pipeline_params, outputs=[tb_indicator, files_subtitles]) btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None) @@ -417,7 +287,6 @@ def launch(self): # Launch the app with optional gradio settings args = self.args - self.app.queue( api_open=args.api_open ).launch( @@ -447,8 +316,8 @@ def on_change_models(model_size: str): parser = argparse.ArgumentParser() -parser.add_argument('--whisper_type', type=str, default="faster-whisper", - choices=["whisper", "faster-whisper", "insanely-fast-whisper"], +parser.add_argument('--whisper_type', type=str, default=WhisperImpl.FASTER_WHISPER.value, + choices=[item.value for item in WhisperImpl], help='A type of the whisper implementation (Github repo name)') parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio share value') parser.add_argument('--server_name', type=str, default=None, help='Gradio server host') diff --git a/modules/diarize/diarize_pipeline.py b/modules/diarize/diarize_pipeline.py index b4109e84..360f88b7 100644 --- a/modules/diarize/diarize_pipeline.py +++ b/modules/diarize/diarize_pipeline.py @@ -44,6 +44,7 @@ def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speaker def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False): transcript_segments = transcript_result["segments"] for seg in transcript_segments: + seg = seg.dict() # assign speaker to segment (if any) diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start']) diff --git a/modules/diarize/diarizer.py b/modules/diarize/diarizer.py index 2dd3f94a..4c0727ab 100644 --- a/modules/diarize/diarizer.py +++ b/modules/diarize/diarizer.py @@ -1,6 +1,6 @@ import os import torch -from typing import List, Union, BinaryIO, Optional +from typing import List, Union, BinaryIO, Optional, Tuple import numpy as np import time 
import logging @@ -8,6 +8,7 @@ from modules.utils.paths import DIARIZATION_MODELS_DIR from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers from modules.diarize.audio_loader import load_audio +from modules.whisper.data_classes import * class Diarizer: @@ -23,10 +24,10 @@ def __init__(self, def run(self, audio: Union[str, BinaryIO, np.ndarray], - transcribed_result: List[dict], + transcribed_result: List[Segment], use_auth_token: str, device: Optional[str] = None - ): + ) -> Tuple[List[Segment], float]: """ Diarize transcribed result as a post-processing @@ -34,7 +35,7 @@ def run(self, ---------- audio: Union[str, BinaryIO, np.ndarray] Audio input. This can be file path or binary type. - transcribed_result: List[dict] + transcribed_result: List[Segment] transcribed result through whisper. use_auth_token: str Huggingface token with READ permission. This is only needed the first time you download the model. @@ -44,8 +45,8 @@ def run(self, Returns ---------- - segments_result: List[dict] - list of dicts that includes start, end timestamps and transcribed text + segments_result: List[Segment] + list of Segment that includes start, end timestamps and transcribed text elapsed_time: float elapsed time for running """ @@ -68,14 +69,21 @@ def run(self, {"segments": transcribed_result} ) + segments_result = [] for segment in diarized_result["segments"]: + segment = segment.dict() speaker = "None" if "speaker" in segment: speaker = segment["speaker"] - segment["text"] = speaker + "|" + segment["text"].strip() + diarized_text = speaker + "|" + segment["text"].strip() + segments_result.append(Segment( + start=segment["start"], + end=segment["end"], + text=diarized_text + )) elapsed_time = time.time() - start_time - return diarized_result["segments"], elapsed_time + return segments_result, elapsed_time def update_pipe(self, use_auth_token: str, diff --git a/modules/translation/translation_base.py b/modules/translation/translation_base.py index abc7f44c..3dbf024e 100644 --- a/modules/translation/translation_base.py +++ b/modules/translation/translation_base.py @@ -6,7 +6,7 @@ from datetime import datetime import modules.translation.nllb_inference as nllb -from modules.whisper.whisper_parameter import * +from modules.whisper.data_classes import * from modules.utils.subtitle_manager import * from modules.utils.files_manager import load_yaml, save_yaml from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR diff --git a/modules/utils/constants.py b/modules/utils/constants.py index e9309bc3..49b45c8c 100644 --- a/modules/utils/constants.py +++ b/modules/utils/constants.py @@ -1,3 +1,6 @@ from gradio_i18n import Translate, gettext as _ AUTOMATIC_DETECTION = _("Automatic Detection") +GRADIO_NONE_STR = "" +GRADIO_NONE_NUMBER_MAX = 9999 +GRADIO_NONE_NUMBER_MIN = 0 diff --git a/modules/utils/subtitle_manager.py b/modules/utils/subtitle_manager.py index b260cb4a..87568a39 100644 --- a/modules/utils/subtitle_manager.py +++ b/modules/utils/subtitle_manager.py @@ -1,5 +1,7 @@ import re +from modules.whisper.data_classes import Segment + def timeformat_srt(time): hours = time // 3600 @@ -23,6 +25,9 @@ def write_file(subtitle, output_file): def get_srt(segments): + if segments and isinstance(segments[0], Segment): + segments = [seg.dict() for seg in segments] + output = "" for i, segment in enumerate(segments): output += f"{i + 1}\n" @@ -34,6 +39,9 @@ def get_srt(segments): def get_vtt(segments): + if segments and isinstance(segments[0], 
Segment):
+        segments = [seg.dict() for seg in segments]
+
     output = "WEBVTT\n\n"
     for i, segment in enumerate(segments):
         output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
@@ -44,6 +52,9 @@
 def get_txt(segments):
+    if segments and isinstance(segments[0], Segment):
+        segments = [seg.dict() for seg in segments]
+
     output = ""
     for i, segment in enumerate(segments):
         if segment['text'].startswith(' '):
diff --git a/modules/vad/silero_vad.py b/modules/vad/silero_vad.py
index bb5c9192..d44c26fc 100644
--- a/modules/vad/silero_vad.py
+++ b/modules/vad/silero_vad.py
@@ -5,7 +5,8 @@
 from typing import BinaryIO, Union, List, Optional, Tuple
 import warnings
 import faster_whisper
-from faster_whisper.transcribe import SpeechTimestampsMap, Segment
+from modules.whisper.data_classes import *
+from faster_whisper.transcribe import SpeechTimestampsMap
 import gradio as gr
@@ -247,18 +248,18 @@ def format_timestamp(
     def restore_speech_timestamps(
             self,
-            segments: List[dict],
+            segments: List[Segment],
             speech_chunks: List[dict],
             sampling_rate: Optional[int] = None,
-    ) -> List[dict]:
+    ) -> List[Segment]:
         if sampling_rate is None:
             sampling_rate = self.sampling_rate
         ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
         for segment in segments:
-            segment["start"] = ts_map.get_original_time(segment["start"])
-            segment["end"] = ts_map.get_original_time(segment["end"])
+            segment.start = ts_map.get_original_time(segment.start)
+            segment.end = ts_map.get_original_time(segment.end)
         return segments
diff --git a/modules/whisper/whisper_base.py b/modules/whisper/base_transcription_pipeline.py
similarity index 80%
rename from modules/whisper/whisper_base.py
rename to modules/whisper/base_transcription_pipeline.py
index 51c87ddf..808a47b4 100644
--- a/modules/whisper/whisper_base.py
+++ b/modules/whisper/base_transcription_pipeline.py
@@ -1,5 +1,6 @@
 import os
 import torch
+import ast
 import whisper
 import ctranslate2
 import gradio as gr
@@ -14,16 +15,16 @@
 from modules.uvr.music_separator import MusicSeparator
 from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
                                  UVR_MODELS_DIR)
-from modules.utils.constants import AUTOMATIC_DETECTION
+from modules.utils.constants import *
 from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
 from modules.utils.youtube_manager import get_ytdata, get_ytaudio
 from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
-from modules.whisper.whisper_parameter import *
+from modules.whisper.data_classes import *
 from modules.diarize.diarizer import Diarizer
 from modules.vad.silero_vad import SileroVAD

-class WhisperBase(ABC):
+class BaseTranscriptionPipeline(ABC):
     def __init__(self,
                  model_dir: str = WHISPER_MODELS_DIR,
                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -74,12 +75,13 @@
     def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
             progress: gr.Progress = gr.Progress(),
             add_timestamp: bool = True,
-            *whisper_params,
+            *pipeline_params,
             ) -> Tuple[List[dict], float]:
         """
         Run transcription with conditional pre-processing and post-processing.
         The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
         The diarization will be performed in post-processing, if enabled.
+        Due to the integration with gradio, the parameters have to be specified with a `*` wildcard.
Parameters ---------- @@ -89,8 +91,8 @@ def run(self, Indicator to show progress directly in gradio. add_timestamp: bool Whether to add a timestamp at the end of the filename. - *whisper_params: tuple - Parameters related with whisper. This will be dealt with "WhisperParameters" data class + *pipeline_params: tuple + Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class Returns ---------- @@ -99,28 +101,17 @@ def run(self, elapsed_time: float elapsed time for running """ - params = WhisperParameters.as_value(*whisper_params) - - self.cache_parameters( - whisper_params=params, - add_timestamp=add_timestamp - ) + params = TranscriptionPipelineParams.from_list(list(pipeline_params)) + params = self.validate_gradio_values(params) + bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization - if params.lang is None: - pass - elif params.lang == AUTOMATIC_DETECTION: - params.lang = None - else: - language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()} - params.lang = language_code_dict[params.lang] - - if params.is_bgm_separate: + if bgm_params.is_separate_bgm: music, audio, _ = self.music_separator.separate( audio=audio, - model_name=params.uvr_model_size, - device=params.uvr_device, - segment_size=params.uvr_segment_size, - save_file=params.uvr_save_file, + model_name=bgm_params.model_size, + device=bgm_params.device, + segment_size=bgm_params.segment_size, + save_file=bgm_params.save_file, progress=progress ) @@ -132,47 +123,54 @@ def run(self, origin_sample_rate = self.music_separator.audio_info.sample_rate audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate) - if params.uvr_enable_offload: + if bgm_params.enable_offload: self.music_separator.offload() - if params.vad_filter: - # Explicit value set for float('inf') from gr.Number() - if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999: - params.max_speech_duration_s = float('inf') - + if vad_params.vad_filter: vad_options = VadOptions( - threshold=params.threshold, - min_speech_duration_ms=params.min_speech_duration_ms, - max_speech_duration_s=params.max_speech_duration_s, - min_silence_duration_ms=params.min_silence_duration_ms, - speech_pad_ms=params.speech_pad_ms + threshold=vad_params.threshold, + min_speech_duration_ms=vad_params.min_speech_duration_ms, + max_speech_duration_s=vad_params.max_speech_duration_s, + min_silence_duration_ms=vad_params.min_silence_duration_ms, + speech_pad_ms=vad_params.speech_pad_ms ) - audio, speech_chunks = self.vad.run( + vad_processed, speech_chunks = self.vad.run( audio=audio, vad_parameters=vad_options, progress=progress ) + if vad_processed.size > 0: + audio = vad_processed + else: + vad_params.vad_filter = False + result, elapsed_time = self.transcribe( audio, progress, - *astuple(params) + *whisper_params.to_list() ) - if params.vad_filter: + if vad_params.vad_filter: result = self.vad.restore_speech_timestamps( segments=result, speech_chunks=speech_chunks, ) - if params.is_diarize: + if diarization_params.is_diarize: result, elapsed_time_diarization = self.diarizer.run( audio=audio, - use_auth_token=params.hf_token, + use_auth_token=diarization_params.hf_token, transcribed_result=result, + device=diarization_params.device ) elapsed_time += elapsed_time_diarization + + self.cache_parameters( + params=params, + add_timestamp=add_timestamp + ) return result, elapsed_time def 
transcribe_file(self,
@@ -181,7 +179,7 @@ def transcribe_file(self,
                         file_format: str = "SRT",
                         add_timestamp: bool = True,
                         progress=gr.Progress(),
-                        *whisper_params,
+                        *params,
                         ) -> list:
         """
         Write subtitle file from Files
@@ -199,8 +197,8 @@
             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
         progress: gr.Progress
             Indicator to show progress directly in gradio.
-        *whisper_params: tuple
-            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
+        *params: tuple
+            Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class
         Returns
         ----------
@@ -223,7 +221,7 @@
                     file,
                     progress,
                     add_timestamp,
-                    *whisper_params,
+                    *params,
                 )
                 file_name, file_ext = os.path.splitext(os.path.basename(file))
@@ -471,7 +469,7 @@ get_device():
         if torch.cuda.is_available():
             return "cuda"
         elif torch.backends.mps.is_available():
-            if not WhisperBase.is_sparse_api_supported():
+            if not BaseTranscriptionPipeline.is_sparse_api_supported():
                 # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
                 return "cpu"
             return "mps"
@@ -512,18 +510,60 @@ def remove_input_files(file_paths: List[str]):
             if file_path and os.path.exists(file_path):
                 os.remove(file_path)

+    @staticmethod
+    def validate_gradio_values(params: TranscriptionPipelineParams):
+        """
+        Validate gradio specific values that can't be displayed as None in the UI.
+        Related issue : https://github.com/gradio-app/gradio/issues/8723
+        """
+        if params.whisper.lang is None:
+            pass
+        elif params.whisper.lang == AUTOMATIC_DETECTION:
+            params.whisper.lang = None
+        else:
+            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
+            params.whisper.lang = language_code_dict[params.whisper.lang]
+
+        if params.whisper.initial_prompt == GRADIO_NONE_STR:
+            params.whisper.initial_prompt = None
+        if params.whisper.prefix == GRADIO_NONE_STR:
+            params.whisper.prefix = None
+        if params.whisper.hotwords == GRADIO_NONE_STR:
+            params.whisper.hotwords = None
+        if params.whisper.max_new_tokens == GRADIO_NONE_NUMBER_MIN:
+            params.whisper.max_new_tokens = None
+        if params.whisper.hallucination_silence_threshold == GRADIO_NONE_NUMBER_MIN:
+            params.whisper.hallucination_silence_threshold = None
+        if params.whisper.language_detection_threshold == GRADIO_NONE_NUMBER_MIN:
+            params.whisper.language_detection_threshold = None
+        if params.vad.max_speech_duration_s == GRADIO_NONE_NUMBER_MAX:
+            params.vad.max_speech_duration_s = float('inf')
+        return params
+
     @staticmethod
     def cache_parameters(
-        whisper_params: WhisperValues,
+        params: TranscriptionPipelineParams,
         add_timestamp: bool
     ):
-        """cache parameters to the yaml file"""
+        """Cache parameters to the yaml file"""
         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
-        cached_whisper_param = whisper_params.to_yaml()
-        cached_yaml = {**cached_params, **cached_whisper_param}
+        param_to_cache = params.to_dict()
+
+        cached_yaml = {**cached_params, **param_to_cache}
         cached_yaml["whisper"]["add_timestamp"] = add_timestamp
-        save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
+        supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
+        if supress_token and isinstance(supress_token, list):
+            cached_yaml["whisper"]["suppress_tokens"] = str(supress_token)
+
+        if cached_yaml["whisper"].get("lang", None) is None:
+            cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
+
+        if
cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'): + cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX + + if cached_yaml is not None and cached_yaml: + save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH) @staticmethod def resample_audio(audio: Union[str, np.ndarray], diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py new file mode 100644 index 00000000..247a62e3 --- /dev/null +++ b/modules/whisper/data_classes.py @@ -0,0 +1,565 @@ +import gradio as gr +import torch +from typing import Optional, Dict, List, Union +from pydantic import BaseModel, Field, field_validator, ConfigDict +from gradio_i18n import Translate, gettext as _ +from enum import Enum +from copy import deepcopy +import yaml + +from modules.utils.constants import * + + +class WhisperImpl(Enum): + WHISPER = "whisper" + FASTER_WHISPER = "faster-whisper" + INSANELY_FAST_WHISPER = "insanely_fast_whisper" + + +class Segment(BaseModel): + text: Optional[str] = Field(default=None, + description="Transcription text of the segment") + start: Optional[float] = Field(default=None, + description="Start time of the segment") + end: Optional[float] = Field(default=None, + description="End time of the segment") + + +class BaseParams(BaseModel): + model_config = ConfigDict(protected_namespaces=()) + + def to_dict(self) -> Dict: + return self.model_dump() + + def to_list(self) -> List: + return list(self.model_dump().values()) + + @classmethod + def from_list(cls, data_list: List) -> 'BaseParams': + field_names = list(cls.model_fields.keys()) + return cls(**dict(zip(field_names, data_list))) + + +class VadParams(BaseParams): + """Voice Activity Detection parameters""" + vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts") + threshold: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Speech threshold for Silero VAD. Probabilities above this value are considered speech" + ) + min_speech_duration_ms: int = Field( + default=250, + ge=0, + description="Final speech chunks shorter than this are discarded" + ) + max_speech_duration_s: float = Field( + default=float("inf"), + gt=0, + description="Maximum duration of speech chunks in seconds" + ) + min_silence_duration_ms: int = Field( + default=2000, + ge=0, + description="Minimum silence duration between speech chunks" + ) + speech_pad_ms: int = Field( + default=400, + ge=0, + description="Padding added to each side of speech chunks" + ) + + @classmethod + def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]: + return [ + gr.Checkbox( + label=_("Enable Silero VAD Filter"), + value=defaults.get("vad_filter", cls.__fields__["vad_filter"].default), + interactive=True, + info=_("Enable this to transcribe only detected voice") + ), + gr.Slider( + minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", + value=defaults.get("threshold", cls.__fields__["threshold"].default), + info="Lower it to be more sensitive to small sounds." + ), + gr.Number( + label="Minimum Speech Duration (ms)", precision=0, + value=defaults.get("min_speech_duration_ms", cls.__fields__["min_speech_duration_ms"].default), + info="Final speech chunks shorter than this time are thrown out" + ), + gr.Number( + label="Maximum Speech Duration (s)", + value=defaults.get("max_speech_duration_s", GRADIO_NONE_NUMBER_MAX), + info="Maximum duration of speech chunks in \"seconds\"." 
+ ), + gr.Number( + label="Minimum Silence Duration (ms)", precision=0, + value=defaults.get("min_silence_duration_ms", cls.__fields__["min_silence_duration_ms"].default), + info="In the end of each speech chunk wait for this time before separating it" + ), + gr.Number( + label="Speech Padding (ms)", precision=0, + value=defaults.get("speech_pad_ms", cls.__fields__["speech_pad_ms"].default), + info="Final speech chunks are padded by this time each side" + ) + ] + + +class DiarizationParams(BaseParams): + """Speaker diarization parameters""" + is_diarize: bool = Field(default=False, description="Enable speaker diarization") + device: str = Field(default="cuda", description="Device to run Diarization model.") + hf_token: str = Field( + default="", + description="Hugging Face token for downloading diarization models" + ) + + @classmethod + def to_gradio_inputs(cls, + defaults: Optional[Dict] = None, + available_devices: Optional[List] = None, + device: Optional[str] = None) -> List[gr.components.base.FormComponent]: + return [ + gr.Checkbox( + label=_("Enable Diarization"), + value=defaults.get("is_diarize", cls.__fields__["is_diarize"].default), + ), + gr.Dropdown( + label=_("Device"), + choices=["cpu", "cuda"] if available_devices is None else available_devices, + value=defaults.get("device", device), + ), + gr.Textbox( + label=_("HuggingFace Token"), + value=defaults.get("hf_token", cls.__fields__["hf_token"].default), + info=_("This is only needed the first time you download the model") + ), + ] + + +class BGMSeparationParams(BaseParams): + """Background music separation parameters""" + is_separate_bgm: bool = Field(default=False, description="Enable background music separation") + model_size: str = Field( + default="UVR-MDX-NET-Inst_HQ_4", + description="UVR model size" + ) + device: str = Field(default="cuda", description="Device to run UVR model.") + segment_size: int = Field( + default=256, + gt=0, + description="Segment size for UVR model" + ) + save_file: bool = Field( + default=False, + description="Whether to save separated audio files" + ) + enable_offload: bool = Field( + default=True, + description="Offload UVR model after transcription" + ) + + @classmethod + def to_gradio_input(cls, + defaults: Optional[Dict] = None, + available_devices: Optional[List] = None, + device: Optional[str] = None, + available_models: Optional[List] = None) -> List[gr.components.base.FormComponent]: + return [ + gr.Checkbox( + label=_("Enable Background Music Remover Filter"), + value=defaults.get("is_separate_bgm", cls.__fields__["is_separate_bgm"].default), + interactive=True, + info=_("Enabling this will remove background music") + ), + gr.Dropdown( + label=_("Model"), + choices=["UVR-MDX-NET-Inst_HQ_4", + "UVR-MDX-NET-Inst_3"] if available_models is None else available_models, + value=defaults.get("model_size", cls.__fields__["model_size"].default), + ), + gr.Dropdown( + label=_("Device"), + choices=["cpu", "cuda"] if available_devices is None else available_devices, + value=defaults.get("device", device), + ), + gr.Number( + label="Segment Size", + value=defaults.get("segment_size", cls.__fields__["segment_size"].default), + precision=0, + info="Segment size for UVR model" + ), + gr.Checkbox( + label=_("Save separated files to output"), + value=defaults.get("save_file", cls.__fields__["save_file"].default), + ), + gr.Checkbox( + label=_("Offload sub model after removing background music"), + value=defaults.get("enable_offload", cls.__fields__["enable_offload"].default), + ) + ] + + +class 
WhisperParams(BaseParams): + """Whisper parameters""" + model_size: str = Field(default="large-v2", description="Whisper model size") + lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe") + is_translate: bool = Field(default=False, description="Translate speech to English end-to-end") + beam_size: int = Field(default=5, ge=1, description="Beam size for decoding") + log_prob_threshold: float = Field( + default=-1.0, + description="Threshold for average log probability of sampled tokens" + ) + no_speech_threshold: float = Field( + default=0.6, + ge=0.0, + le=1.0, + description="Threshold for detecting silence" + ) + compute_type: str = Field(default="float16", description="Computation type for transcription") + best_of: int = Field(default=5, ge=1, description="Number of candidates when sampling") + patience: float = Field(default=1.0, gt=0, description="Beam search patience factor") + condition_on_previous_text: bool = Field( + default=True, + description="Use previous output as prompt for next window" + ) + prompt_reset_on_temperature: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Temperature threshold for resetting prompt" + ) + initial_prompt: Optional[str] = Field(default=None, description="Initial prompt for first window") + temperature: float = Field( + default=0.0, + ge=0.0, + description="Temperature for sampling" + ) + compression_ratio_threshold: float = Field( + default=2.4, + gt=0, + description="Threshold for gzip compression ratio" + ) + length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty") + repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens") + no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition") + prefix: Optional[str] = Field(default=None, description="Prefix text for first window") + suppress_blank: bool = Field( + default=True, + description="Suppress blank outputs at start of sampling" + ) + suppress_tokens: Optional[Union[List, str]] = Field(default=[-1], description="Token IDs to suppress") + max_initial_timestamp: float = Field( + default=0.0, + ge=0.0, + description="Maximum initial timestamp" + ) + word_timestamps: bool = Field(default=False, description="Extract word-level timestamps") + prepend_punctuations: Optional[str] = Field( + default="\"'“¿([{-", + description="Punctuations to merge with next word" + ) + append_punctuations: Optional[str] = Field( + default="\"'.。,,!!??::”)]}、", + description="Punctuations to merge with previous word" + ) + max_new_tokens: Optional[int] = Field(default=None, description="Maximum number of new tokens per chunk") + chunk_length: Optional[int] = Field(default=30, description="Length of audio segments in seconds") + hallucination_silence_threshold: Optional[float] = Field( + default=None, + description="Threshold for skipping silent periods in hallucination detection" + ) + hotwords: Optional[str] = Field(default=None, description="Hotwords/hint phrases for the model") + language_detection_threshold: Optional[float] = Field( + default=None, + description="Threshold for language detection probability" + ) + language_detection_segments: int = Field( + default=1, + gt=0, + description="Number of segments for language detection" + ) + batch_size: int = Field(default=24, gt=0, description="Batch size for processing") + + @field_validator('lang') + def validate_lang(cls, v): + from modules.utils.constants import AUTOMATIC_DETECTION + 
return None if v == AUTOMATIC_DETECTION.unwrap() else v + + @field_validator('suppress_tokens') + def validate_supress_tokens(cls, v): + import ast + try: + if isinstance(v, str): + suppress_tokens = ast.literal_eval(v) + if not isinstance(suppress_tokens, list): + raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]") + return suppress_tokens + if isinstance(v, list): + return v + except Exception as e: + raise ValueError(f"Invalid Suppress Tokens. The value must be type of List[int]: {e}") + + @classmethod + def to_gradio_inputs(cls, + defaults: Optional[Dict] = None, + only_advanced: Optional[bool] = True, + whisper_type: Optional[str] = None, + available_models: Optional[List] = None, + available_langs: Optional[List] = None, + available_compute_types: Optional[List] = None, + compute_type: Optional[str] = None): + whisper_type = WhisperImpl.FASTER_WHISPER.value if whisper_type is None else whisper_type.strip().lower() + + inputs = [] + if not only_advanced: + inputs += [ + gr.Dropdown( + label=_("Model"), + choices=available_models, + value=defaults.get("model_size", cls.__fields__["model_size"].default), + ), + gr.Dropdown( + label=_("Language"), + choices=available_langs, + value=defaults.get("lang", AUTOMATIC_DETECTION), + ), + gr.Checkbox( + label=_("Translate to English?"), + value=defaults.get("is_translate", cls.__fields__["is_translate"].default), + ), + ] + + inputs += [ + gr.Number( + label="Beam Size", + value=defaults.get("beam_size", cls.__fields__["beam_size"].default), + precision=0, + info="Beam size for decoding" + ), + gr.Number( + label="Log Probability Threshold", + value=defaults.get("log_prob_threshold", cls.__fields__["log_prob_threshold"].default), + info="Threshold for average log probability of sampled tokens" + ), + gr.Number( + label="No Speech Threshold", + value=defaults.get("no_speech_threshold", cls.__fields__["no_speech_threshold"].default), + info="Threshold for detecting silence" + ), + gr.Dropdown( + label="Compute Type", + choices=["float16", "int8", "int16"] if available_compute_types is None else available_compute_types, + value=defaults.get("compute_type", compute_type), + info="Computation type for transcription" + ), + gr.Number( + label="Best Of", + value=defaults.get("best_of", cls.__fields__["best_of"].default), + precision=0, + info="Number of candidates when sampling" + ), + gr.Number( + label="Patience", + value=defaults.get("patience", cls.__fields__["patience"].default), + info="Beam search patience factor" + ), + gr.Checkbox( + label="Condition On Previous Text", + value=defaults.get("condition_on_previous_text", cls.__fields__["condition_on_previous_text"].default), + info="Use previous output as prompt for next window" + ), + gr.Slider( + label="Prompt Reset On Temperature", + value=defaults.get("prompt_reset_on_temperature", + cls.__fields__["prompt_reset_on_temperature"].default), + minimum=0, + maximum=1, + step=0.01, + info="Temperature threshold for resetting prompt" + ), + gr.Textbox( + label="Initial Prompt", + value=defaults.get("initial_prompt", GRADIO_NONE_STR), + info="Initial prompt for first window" + ), + gr.Slider( + label="Temperature", + value=defaults.get("temperature", cls.__fields__["temperature"].default), + minimum=0.0, + step=0.01, + maximum=1.0, + info="Temperature for sampling" + ), + gr.Number( + label="Compression Ratio Threshold", + value=defaults.get("compression_ratio_threshold", + cls.__fields__["compression_ratio_threshold"].default), + info="Threshold for gzip compression 
ratio" + ) + ] + + faster_whisper_inputs = [ + gr.Number( + label="Length Penalty", + value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default), + info="Exponential length penalty", + ), + gr.Number( + label="Repetition Penalty", + value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default), + info="Penalty for repeated tokens" + ), + gr.Number( + label="No Repeat N-gram Size", + value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default), + precision=0, + info="Size of n-grams to prevent repetition" + ), + gr.Textbox( + label="Prefix", + value=defaults.get("prefix", GRADIO_NONE_STR), + info="Prefix text for first window" + ), + gr.Checkbox( + label="Suppress Blank", + value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default), + info="Suppress blank outputs at start of sampling" + ), + gr.Textbox( + label="Suppress Tokens", + value=defaults.get("suppress_tokens", "[-1]"), + info="Token IDs to suppress" + ), + gr.Number( + label="Max Initial Timestamp", + value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default), + info="Maximum initial timestamp" + ), + gr.Checkbox( + label="Word Timestamps", + value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default), + info="Extract word-level timestamps" + ), + gr.Textbox( + label="Prepend Punctuations", + value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default), + info="Punctuations to merge with next word" + ), + gr.Textbox( + label="Append Punctuations", + value=defaults.get("append_punctuations", cls.__fields__["append_punctuations"].default), + info="Punctuations to merge with previous word" + ), + gr.Number( + label="Max New Tokens", + value=defaults.get("max_new_tokens", GRADIO_NONE_NUMBER_MIN), + precision=0, + info="Maximum number of new tokens per chunk" + ), + gr.Number( + label="Chunk Length (s)", + value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default), + precision=0, + info="Length of audio segments in seconds" + ), + gr.Number( + label="Hallucination Silence Threshold (sec)", + value=defaults.get("hallucination_silence_threshold", + GRADIO_NONE_NUMBER_MIN), + info="Threshold for skipping silent periods in hallucination detection" + ), + gr.Textbox( + label="Hotwords", + value=defaults.get("hotwords", cls.__fields__["hotwords"].default), + info="Hotwords/hint phrases for the model" + ), + gr.Number( + label="Language Detection Threshold", + value=defaults.get("language_detection_threshold", + GRADIO_NONE_NUMBER_MIN), + info="Threshold for language detection probability" + ), + gr.Number( + label="Language Detection Segments", + value=defaults.get("language_detection_segments", + cls.__fields__["language_detection_segments"].default), + precision=0, + info="Number of segments for language detection" + ) + ] + + insanely_fast_whisper_inputs = [ + gr.Number( + label="Batch Size", + value=defaults.get("batch_size", cls.__fields__["batch_size"].default), + precision=0, + info="Batch size for processing" + ) + ] + + if whisper_type != WhisperImpl.FASTER_WHISPER.value: + for input_component in faster_whisper_inputs: + input_component.visible = False + + if whisper_type != WhisperImpl.INSANELY_FAST_WHISPER.value: + for input_component in insanely_fast_whisper_inputs: + input_component.visible = False + + inputs += faster_whisper_inputs + insanely_fast_whisper_inputs + + return inputs + + +class TranscriptionPipelineParams(BaseModel): + 
"""Transcription pipeline parameters""" + whisper: WhisperParams = Field(default_factory=WhisperParams) + vad: VadParams = Field(default_factory=VadParams) + diarization: DiarizationParams = Field(default_factory=DiarizationParams) + bgm_separation: BGMSeparationParams = Field(default_factory=BGMSeparationParams) + + def to_dict(self) -> Dict: + data = { + "whisper": self.whisper.to_dict(), + "vad": self.vad.to_dict(), + "diarization": self.diarization.to_dict(), + "bgm_separation": self.bgm_separation.to_dict() + } + return data + + def to_list(self) -> List: + """ + Convert data class to the list because I have to pass the parameters as a list in the gradio. + Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471 + See more about Gradio pre-processing: https://www.gradio.app/docs/components + """ + whisper_list = self.whisper.to_list() + vad_list = self.vad.to_list() + diarization_list = self.diarization.to_list() + bgm_sep_list = self.bgm_separation.to_list() + return whisper_list + vad_list + diarization_list + bgm_sep_list + + @staticmethod + def from_list(pipeline_list: List) -> 'TranscriptionPipelineParams': + """Convert list to the data class again to use it in a function.""" + data_list = deepcopy(pipeline_list) + + whisper_list = data_list[0:len(WhisperParams.__annotations__)] + data_list = data_list[len(WhisperParams.__annotations__):] + + vad_list = data_list[0:len(VadParams.__annotations__)] + data_list = data_list[len(VadParams.__annotations__):] + + diarization_list = data_list[0:len(DiarizationParams.__annotations__)] + data_list = data_list[len(DiarizationParams.__annotations__):] + + bgm_sep_list = data_list[0:len(BGMSeparationParams.__annotations__)] + + return TranscriptionPipelineParams( + whisper=WhisperParams.from_list(whisper_list), + vad=VadParams.from_list(vad_list), + diarization=DiarizationParams.from_list(diarization_list), + bgm_separation=BGMSeparationParams.from_list(bgm_sep_list) + ) diff --git a/modules/whisper/faster_whisper_inference.py b/modules/whisper/faster_whisper_inference.py index f12fc01a..5dad6edb 100644 --- a/modules/whisper/faster_whisper_inference.py +++ b/modules/whisper/faster_whisper_inference.py @@ -12,11 +12,11 @@ from argparse import Namespace from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR) -from modules.whisper.whisper_parameter import * -from modules.whisper.whisper_base import WhisperBase +from modules.whisper.data_classes import * +from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline -class FasterWhisperInference(WhisperBase): +class FasterWhisperInference(BaseTranscriptionPipeline): def __init__(self, model_dir: str = FASTER_WHISPER_MODELS_DIR, diarization_model_dir: str = DIARIZATION_MODELS_DIR, @@ -40,7 +40,7 @@ def transcribe(self, audio: Union[str, BinaryIO, np.ndarray], progress: gr.Progress = gr.Progress(), *whisper_params, - ) -> Tuple[List[dict], float]: + ) -> Tuple[List[Segment], float]: """ transcribe method for faster-whisper. 
@@ -55,28 +55,18 @@ def transcribe(self, Returns ---------- - segments_result: List[dict] - list of dicts that includes start, end timestamps and transcribed text + segments_result: List[Segment] + list of Segment that includes start, end timestamps and transcribed text elapsed_time: float elapsed time for transcription """ start_time = time.time() - params = WhisperParameters.as_value(*whisper_params) + params = WhisperParams.from_list(list(whisper_params)) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) - # None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723 - if not params.initial_prompt: - params.initial_prompt = None - if not params.prefix: - params.prefix = None - if not params.hotwords: - params.hotwords = None - - params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens) - segments, info = self.model.transcribe( audio=audio, language=params.lang, @@ -112,11 +102,11 @@ def transcribe(self, segments_result = [] for segment in segments: progress(segment.start / info.duration, desc="Transcribing..") - segments_result.append({ - "start": segment.start, - "end": segment.end, - "text": segment.text - }) + segments_result.append(Segment( + start=segment.start, + end=segment.end, + text=segment.text + )) elapsed_time = time.time() - start_time return segments_result, elapsed_time diff --git a/modules/whisper/insanely_fast_whisper_inference.py b/modules/whisper/insanely_fast_whisper_inference.py index fe6f4fdb..c1f7cb2c 100644 --- a/modules/whisper/insanely_fast_whisper_inference.py +++ b/modules/whisper/insanely_fast_whisper_inference.py @@ -12,11 +12,11 @@ from argparse import Namespace from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR) -from modules.whisper.whisper_parameter import * -from modules.whisper.whisper_base import WhisperBase +from modules.whisper.data_classes import * +from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline -class InsanelyFastWhisperInference(WhisperBase): +class InsanelyFastWhisperInference(BaseTranscriptionPipeline): def __init__(self, model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR, diarization_model_dir: str = DIARIZATION_MODELS_DIR, @@ -40,7 +40,7 @@ def transcribe(self, audio: Union[str, np.ndarray, torch.Tensor], progress: gr.Progress = gr.Progress(), *whisper_params, - ) -> Tuple[List[dict], float]: + ) -> Tuple[List[Segment], float]: """ transcribe method for faster-whisper. 
@@ -55,13 +55,13 @@ def transcribe(self, Returns ---------- - segments_result: List[dict] - list of dicts that includes start, end timestamps and transcribed text + segments_result: List[Segment] + list of Segment that includes start, end timestamps and transcribed text elapsed_time: float elapsed time for transcription """ start_time = time.time() - params = WhisperParameters.as_value(*whisper_params) + params = WhisperParams.from_list(list(whisper_params)) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) @@ -95,9 +95,17 @@ def transcribe(self, generate_kwargs=kwargs ) - segments_result = self.format_result( - transcribed_result=segments, - ) + segments_result = [] + for item in segments["chunks"]: + start, end = item["timestamp"][0], item["timestamp"][1] + if end is None: + end = start + segments_result.append(Segment( + text=item["text"], + start=start, + end=end + )) + elapsed_time = time.time() - start_time return segments_result, elapsed_time diff --git a/modules/whisper/whisper_Inference.py b/modules/whisper/whisper_Inference.py index f87fbe5d..ccd4bbb4 100644 --- a/modules/whisper/whisper_Inference.py +++ b/modules/whisper/whisper_Inference.py @@ -8,11 +8,11 @@ from argparse import Namespace from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR) -from modules.whisper.whisper_base import WhisperBase -from modules.whisper.whisper_parameter import * +from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline +from modules.whisper.data_classes import * -class WhisperInference(WhisperBase): +class WhisperInference(BaseTranscriptionPipeline): def __init__(self, model_dir: str = WHISPER_MODELS_DIR, diarization_model_dir: str = DIARIZATION_MODELS_DIR, @@ -30,7 +30,7 @@ def transcribe(self, audio: Union[str, np.ndarray, torch.Tensor], progress: gr.Progress = gr.Progress(), *whisper_params, - ) -> Tuple[List[dict], float]: + ) -> Tuple[List[Segment], float]: """ transcribe method for faster-whisper. 
@@ -45,13 +45,13 @@ def transcribe(self, Returns ---------- - segments_result: List[dict] - list of dicts that includes start, end timestamps and transcribed text + segments_result: List[Segment] + list of Segment that includes start, end timestamps and transcribed text elapsed_time: float elapsed time for transcription """ start_time = time.time() - params = WhisperParameters.as_value(*whisper_params) + params = WhisperParams.from_list(list(whisper_params)) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) @@ -59,21 +59,28 @@ def transcribe(self, def progress_callback(progress_value): progress(progress_value, desc="Transcribing..") - segments_result = self.model.transcribe(audio=audio, - language=params.lang, - verbose=False, - beam_size=params.beam_size, - logprob_threshold=params.log_prob_threshold, - no_speech_threshold=params.no_speech_threshold, - task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe", - fp16=True if params.compute_type == "float16" else False, - best_of=params.best_of, - patience=params.patience, - temperature=params.temperature, - compression_ratio_threshold=params.compression_ratio_threshold, - progress_callback=progress_callback,)["segments"] - elapsed_time = time.time() - start_time + result = self.model.transcribe(audio=audio, + language=params.lang, + verbose=False, + beam_size=params.beam_size, + logprob_threshold=params.log_prob_threshold, + no_speech_threshold=params.no_speech_threshold, + task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe", + fp16=True if params.compute_type == "float16" else False, + best_of=params.best_of, + patience=params.patience, + temperature=params.temperature, + compression_ratio_threshold=params.compression_ratio_threshold, + progress_callback=progress_callback,)["segments"] + segments_result = [] + for segment in result: + segments_result.append(Segment( + start=segment["start"], + end=segment["end"], + text=segment["text"] + )) + elapsed_time = time.time() - start_time return segments_result, elapsed_time def update_model(self, diff --git a/modules/whisper/whisper_factory.py b/modules/whisper/whisper_factory.py index 6bda8c58..b5ae33a7 100644 --- a/modules/whisper/whisper_factory.py +++ b/modules/whisper/whisper_factory.py @@ -6,7 +6,8 @@ from modules.whisper.faster_whisper_inference import FasterWhisperInference from modules.whisper.whisper_Inference import WhisperInference from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference -from modules.whisper.whisper_base import WhisperBase +from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline +from modules.whisper.data_classes import * class WhisperFactory: @@ -19,7 +20,7 @@ def create_whisper_inference( diarization_model_dir: str = DIARIZATION_MODELS_DIR, uvr_model_dir: str = UVR_MODELS_DIR, output_dir: str = OUTPUT_DIR, - ) -> "WhisperBase": + ) -> "BaseTranscriptionPipeline": """ Create a whisper inference class based on the provided whisper_type. @@ -45,36 +46,29 @@ def create_whisper_inference( Returns ------- - WhisperBase + BaseTranscriptionPipeline An instance of the appropriate whisper inference class based on the whisper_type. 
""" # Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144 os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' - whisper_type = whisper_type.lower().strip() + whisper_type = whisper_type.strip().lower() - faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"] - whisper_typos = ["whisper"] - insanely_fast_whisper_typos = [ - "insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper", - "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper" - ] - - if whisper_type in faster_whisper_typos: + if whisper_type == WhisperImpl.FASTER_WHISPER.value: return FasterWhisperInference( model_dir=faster_whisper_model_dir, output_dir=output_dir, diarization_model_dir=diarization_model_dir, uvr_model_dir=uvr_model_dir ) - elif whisper_type in whisper_typos: + elif whisper_type == WhisperImpl.WHISPER.value: return WhisperInference( model_dir=whisper_model_dir, output_dir=output_dir, diarization_model_dir=diarization_model_dir, uvr_model_dir=uvr_model_dir ) - elif whisper_type in insanely_fast_whisper_typos: + elif whisper_type == WhisperImpl.INSANELY_FAST_WHISPER.value: return InsanelyFastWhisperInference( model_dir=insanely_fast_whisper_model_dir, output_dir=output_dir, diff --git a/modules/whisper/whisper_parameter.py b/modules/whisper/whisper_parameter.py deleted file mode 100644 index 19115fc2..00000000 --- a/modules/whisper/whisper_parameter.py +++ /dev/null @@ -1,371 +0,0 @@ -from dataclasses import dataclass, fields -import gradio as gr -from typing import Optional, Dict -import yaml - -from modules.utils.constants import AUTOMATIC_DETECTION - - -@dataclass -class WhisperParameters: - model_size: gr.Dropdown - lang: gr.Dropdown - is_translate: gr.Checkbox - beam_size: gr.Number - log_prob_threshold: gr.Number - no_speech_threshold: gr.Number - compute_type: gr.Dropdown - best_of: gr.Number - patience: gr.Number - condition_on_previous_text: gr.Checkbox - prompt_reset_on_temperature: gr.Slider - initial_prompt: gr.Textbox - temperature: gr.Slider - compression_ratio_threshold: gr.Number - vad_filter: gr.Checkbox - threshold: gr.Slider - min_speech_duration_ms: gr.Number - max_speech_duration_s: gr.Number - min_silence_duration_ms: gr.Number - speech_pad_ms: gr.Number - batch_size: gr.Number - is_diarize: gr.Checkbox - hf_token: gr.Textbox - diarization_device: gr.Dropdown - length_penalty: gr.Number - repetition_penalty: gr.Number - no_repeat_ngram_size: gr.Number - prefix: gr.Textbox - suppress_blank: gr.Checkbox - suppress_tokens: gr.Textbox - max_initial_timestamp: gr.Number - word_timestamps: gr.Checkbox - prepend_punctuations: gr.Textbox - append_punctuations: gr.Textbox - max_new_tokens: gr.Number - chunk_length: gr.Number - hallucination_silence_threshold: gr.Number - hotwords: gr.Textbox - language_detection_threshold: gr.Number - language_detection_segments: gr.Number - is_bgm_separate: gr.Checkbox - uvr_model_size: gr.Dropdown - uvr_device: gr.Dropdown - uvr_segment_size: gr.Number - uvr_save_file: gr.Checkbox - uvr_enable_offload: gr.Checkbox - """ - A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing. - This data class is used to mitigate the key-value problem between Gradio components and function parameters. - Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471 - See more about Gradio pre-processing: https://www.gradio.app/docs/components - - Attributes - ---------- - model_size: gr.Dropdown - Whisper model size. 
-
-    lang: gr.Dropdown
-        Source language of the file to transcribe.
-
-    is_translate: gr.Checkbox
-        Boolean value that determines whether to translate to English.
-        It's Whisper's feature to translate speech from another language directly into English end-to-end.
-
-    beam_size: gr.Number
-        Int value that is used for decoding option.
-
-    log_prob_threshold: gr.Number
-        If the average log probability over sampled tokens is below this value, treat as failed.
-
-    no_speech_threshold: gr.Number
-        If the no_speech probability is higher than this value AND
-        the average log probability over sampled tokens is below `log_prob_threshold`,
-        consider the segment as silent.
-
-    compute_type: gr.Dropdown
-        compute type for transcription.
-        see more info : https://opennmt.net/CTranslate2/quantization.html
-
-    best_of: gr.Number
-        Number of candidates when sampling with non-zero temperature.
-
-    patience: gr.Number
-        Beam search patience factor.
-
-    condition_on_previous_text: gr.Checkbox
-        if True, the previous output of the model is provided as a prompt for the next window;
-        disabling may make the text inconsistent across windows, but the model becomes less prone to
-        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
-
-    initial_prompt: gr.Textbox
-        Optional text to provide as a prompt for the first window. This can be used to provide, or
-        "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
-        to make it more likely to predict those word correctly.
-
-    temperature: gr.Slider
-        Temperature for sampling. It can be a tuple of temperatures,
-        which will be successively used upon failures according to either
-        `compression_ratio_threshold` or `log_prob_threshold`.
-
-    compression_ratio_threshold: gr.Number
-        If the gzip compression ratio is above this value, treat as failed
-
-    vad_filter: gr.Checkbox
-        Enable the voice activity detection (VAD) to filter out parts of the audio
-        without speech. This step is using the Silero VAD model
-        https://github.com/snakers4/silero-vad.
-
-    threshold: gr.Slider
-        This parameter is related with Silero VAD. Speech threshold.
-        Silero VAD outputs speech probabilities for each audio chunk,
-        probabilities ABOVE this value are considered as SPEECH. It is better to tune this
-        parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
-
-    min_speech_duration_ms: gr.Number
-        This parameter is related with Silero VAD. Final speech chunks shorter min_speech_duration_ms are thrown out.
-
-    max_speech_duration_s: gr.Number
-        This parameter is related with Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer
-        than max_speech_duration_s will be split at the timestamp of the last silence that
-        lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
-        split aggressively just before max_speech_duration_s.
-
-    min_silence_duration_ms: gr.Number
-        This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms
-        before separating it
-
-    speech_pad_ms: gr.Number
-        This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side
-
-    batch_size: gr.Number
-        This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe
-
-    is_diarize: gr.Checkbox
-        This parameter is related with whisperx. Boolean value that determines whether to diarize or not.
-
-    hf_token: gr.Textbox
-        This parameter is related with whisperx. Huggingface token is needed to download diarization models.
-        Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements
-
-    diarization_device: gr.Dropdown
-        This parameter is related with whisperx. Device to run diarization model
-
-    length_penalty: gr.Number
-        This parameter is related to faster-whisper. Exponential length penalty constant.
-
-    repetition_penalty: gr.Number
-        This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens
-        (set > 1 to penalize).
-
-    no_repeat_ngram_size: gr.Number
-        This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable).
-
-    prefix: gr.Textbox
-        This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window.
-
-    suppress_blank: gr.Checkbox
-        This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling.
-
-    suppress_tokens: gr.Textbox
-        This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set
-        of symbols as defined in the model config.json file.
-
-    max_initial_timestamp: gr.Number
-        This parameter is related to faster-whisper. The initial timestamp cannot be later than this.
-
-    word_timestamps: gr.Checkbox
-        This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern
-        and dynamic time warping, and include the timestamps for each word in each segment.
-
-    prepend_punctuations: gr.Textbox
-        This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
-        with the next word.
-
-    append_punctuations: gr.Textbox
-        This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
-        with the previous word.
-
-    max_new_tokens: gr.Number
-        This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set,
-        the maximum will be set by the default max_length.
-
-    chunk_length: gr.Number
-        This parameter is related to faster-whisper and insanely-fast-whisper. The length of audio segments in seconds.
-        If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.
-
-    hallucination_silence_threshold: gr.Number
-        This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold
-        (in seconds) when a possible hallucination is detected.
-
-    hotwords: gr.Textbox
-        This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
-
-    language_detection_threshold: gr.Number
-        This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected.
-
-    language_detection_segments: gr.Number
-        This parameter is related to faster-whisper. Number of segments to consider for the language detection.
-
-    is_separate_bgm: gr.Checkbox
-        This parameter is related to UVR. Boolean value that determines whether to separate bgm or not.
-
-    uvr_model_size: gr.Dropdown
-        This parameter is related to UVR. UVR model size.
-
-    uvr_device: gr.Dropdown
-        This parameter is related to UVR. Device to run UVR model.
-
-    uvr_segment_size: gr.Number
-        This parameter is related to UVR. Segment size for UVR model.
-
-    uvr_save_file: gr.Checkbox
-        This parameter is related to UVR. Boolean value that determines whether to save the file or not.
-
-    uvr_enable_offload: gr.Checkbox
-        This parameter is related to UVR. Boolean value that determines whether to offload the UVR model or not
-        after each transcription.
-    """
-
-    def as_list(self) -> list:
-        """
-        Converts the data class attributes into a list, Use in Gradio UI before Gradio pre-processing.
-        See more about Gradio pre-processing: : https://www.gradio.app/docs/components
-
-        Returns
-        ----------
-        A list of Gradio components
-        """
-        return [getattr(self, f.name) for f in fields(self)]
-
-    @staticmethod
-    def as_value(*args) -> 'WhisperValues':
-        """
-        To use Whisper parameters in function after Gradio post-processing.
-        See more about Gradio post-processing: : https://www.gradio.app/docs/components
-
-        Returns
-        ----------
-        WhisperValues
-            Data class that has values of parameters
-        """
-        return WhisperValues(*args)
-
-
-@dataclass
-class WhisperValues:
-    model_size: str = "large-v2"
-    lang: Optional[str] = None
-    is_translate: bool = False
-    beam_size: int = 5
-    log_prob_threshold: float = -1.0
-    no_speech_threshold: float = 0.6
-    compute_type: str = "float16"
-    best_of: int = 5
-    patience: float = 1.0
-    condition_on_previous_text: bool = True
-    prompt_reset_on_temperature: float = 0.5
-    initial_prompt: Optional[str] = None
-    temperature: float = 0.0
-    compression_ratio_threshold: float = 2.4
-    vad_filter: bool = False
-    threshold: float = 0.5
-    min_speech_duration_ms: int = 250
-    max_speech_duration_s: float = float("inf")
-    min_silence_duration_ms: int = 2000
-    speech_pad_ms: int = 400
-    batch_size: int = 24
-    is_diarize: bool = False
-    hf_token: str = ""
-    diarization_device: str = "cuda"
-    length_penalty: float = 1.0
-    repetition_penalty: float = 1.0
-    no_repeat_ngram_size: int = 0
-    prefix: Optional[str] = None
-    suppress_blank: bool = True
-    suppress_tokens: Optional[str] = "[-1]"
-    max_initial_timestamp: float = 0.0
-    word_timestamps: bool = False
-    prepend_punctuations: Optional[str] = "\"'“¿([{-"
-    append_punctuations: Optional[str] = "\"'.。,,!!??::”)]}、"
-    max_new_tokens: Optional[int] = None
-    chunk_length: Optional[int] = 30
-    hallucination_silence_threshold: Optional[float] = None
-    hotwords: Optional[str] = None
-    language_detection_threshold: Optional[float] = None
-    language_detection_segments: int = 1
-    is_bgm_separate: bool = False
-    uvr_model_size: str = "UVR-MDX-NET-Inst_HQ_4"
-    uvr_device: str = "cuda"
-    uvr_segment_size: int = 256
-    uvr_save_file: bool = False
-    uvr_enable_offload: bool = True
-    """
-    A data class to use Whisper parameters.
- """ - - def to_yaml(self) -> Dict: - data = { - "whisper": { - "model_size": self.model_size, - "lang": AUTOMATIC_DETECTION.unwrap() if self.lang is None else self.lang, - "is_translate": self.is_translate, - "beam_size": self.beam_size, - "log_prob_threshold": self.log_prob_threshold, - "no_speech_threshold": self.no_speech_threshold, - "best_of": self.best_of, - "patience": self.patience, - "condition_on_previous_text": self.condition_on_previous_text, - "prompt_reset_on_temperature": self.prompt_reset_on_temperature, - "initial_prompt": None if not self.initial_prompt else self.initial_prompt, - "temperature": self.temperature, - "compression_ratio_threshold": self.compression_ratio_threshold, - "batch_size": self.batch_size, - "length_penalty": self.length_penalty, - "repetition_penalty": self.repetition_penalty, - "no_repeat_ngram_size": self.no_repeat_ngram_size, - "prefix": None if not self.prefix else self.prefix, - "suppress_blank": self.suppress_blank, - "suppress_tokens": self.suppress_tokens, - "max_initial_timestamp": self.max_initial_timestamp, - "word_timestamps": self.word_timestamps, - "prepend_punctuations": self.prepend_punctuations, - "append_punctuations": self.append_punctuations, - "max_new_tokens": self.max_new_tokens, - "chunk_length": self.chunk_length, - "hallucination_silence_threshold": self.hallucination_silence_threshold, - "hotwords": None if not self.hotwords else self.hotwords, - "language_detection_threshold": self.language_detection_threshold, - "language_detection_segments": self.language_detection_segments, - }, - "vad": { - "vad_filter": self.vad_filter, - "threshold": self.threshold, - "min_speech_duration_ms": self.min_speech_duration_ms, - "max_speech_duration_s": self.max_speech_duration_s, - "min_silence_duration_ms": self.min_silence_duration_ms, - "speech_pad_ms": self.speech_pad_ms, - }, - "diarization": { - "is_diarize": self.is_diarize, - "hf_token": self.hf_token - }, - "bgm_separation": { - "is_separate_bgm": self.is_bgm_separate, - "model_size": self.uvr_model_size, - "segment_size": self.uvr_segment_size, - "save_file": self.uvr_save_file, - "enable_offload": self.uvr_enable_offload - }, - } - return data - - def as_list(self) -> list: - """ - Converts the data class attributes into a list - - Returns - ---------- - A list of Whisper parameters - """ - return [getattr(self, f.name) for f in fields(self)] diff --git a/tests/test_bgm_separation.py b/tests/test_bgm_separation.py index cc4a6f80..95b77a0b 100644 --- a/tests/test_bgm_separation.py +++ b/tests/test_bgm_separation.py @@ -1,6 +1,6 @@ from modules.utils.paths import * from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.whisper_parameter import WhisperValues +from modules.whisper.data_classes import * from test_config import * from test_transcription import download_file, test_transcribe @@ -17,9 +17,9 @@ @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ - ("whisper", False, True, False), - ("faster-whisper", False, True, False), - ("insanely_fast_whisper", False, True, False) + (WhisperImpl.WHISPER.value, False, True, False), + (WhisperImpl.FASTER_WHISPER.value, False, True, False), + (WhisperImpl.INSANELY_FAST_WHISPER.value, False, True, False) ] ) def test_bgm_separation_pipeline( @@ -38,9 +38,9 @@ def test_bgm_separation_pipeline( @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ - ("whisper", True, True, False), - ("faster-whisper", True, True, False), - ("insanely_fast_whisper", 
True, True, False) + (WhisperImpl.WHISPER.value, True, True, False), + (WhisperImpl.FASTER_WHISPER.value, True, True, False), + (WhisperImpl.INSANELY_FAST_WHISPER.value, True, True, False) ] ) def test_bgm_separation_with_vad_pipeline( diff --git a/tests/test_config.py b/tests/test_config.py index 0f60aa58..0020eca1 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,7 +6,7 @@ TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav" TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav") TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer" -TEST_WHISPER_MODEL = "tiny" +TEST_WHISPER_MODEL = "tiny.en" TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4" TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M" TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt") diff --git a/tests/test_diarization.py b/tests/test_diarization.py index 54e7244e..f18a2633 100644 --- a/tests/test_diarization.py +++ b/tests/test_diarization.py @@ -1,6 +1,6 @@ from modules.utils.paths import * from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.whisper_parameter import WhisperValues +from modules.whisper.data_classes import * from test_config import * from test_transcription import download_file, test_transcribe @@ -16,9 +16,9 @@ @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ - ("whisper", False, False, True), - ("faster-whisper", False, False, True), - ("insanely_fast_whisper", False, False, True) + (WhisperImpl.WHISPER.value, False, False, True), + (WhisperImpl.FASTER_WHISPER.value, False, False, True), + (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, True) ] ) def test_diarization_pipeline( diff --git a/tests/test_transcription.py b/tests/test_transcription.py index 4b5ab98f..3353782b 100644 --- a/tests/test_transcription.py +++ b/tests/test_transcription.py @@ -1,5 +1,5 @@ from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.whisper_parameter import WhisperValues +from modules.whisper.data_classes import * from modules.utils.paths import WEBUI_DIR from test_config import * @@ -12,9 +12,9 @@ @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ - ("whisper", False, False, False), - ("faster-whisper", False, False, False), - ("insanely_fast_whisper", False, False, False) + (WhisperImpl.WHISPER.value, False, False, False), + (WhisperImpl.FASTER_WHISPER.value, False, False, False), + (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, False) ] ) def test_transcribe( @@ -37,14 +37,22 @@ def test_transcribe( f"""Diarization Device: {whisper_inferencer.diarizer.device}""" ) - hparams = WhisperValues( - model_size=TEST_WHISPER_MODEL, - vad_filter=vad_filter, - is_bgm_separate=bgm_separation, - compute_type=whisper_inferencer.current_compute_type, - uvr_enable_offload=True, - is_diarize=diarization, - ).as_list() + hparams = TranscriptionPipelineParams( + whisper=WhisperParams( + model_size=TEST_WHISPER_MODEL, + compute_type=whisper_inferencer.current_compute_type + ), + vad=VadParams( + vad_filter=vad_filter + ), + bgm_separation=BGMSeparationParams( + is_separate_bgm=bgm_separation, + enable_offload=True + ), + diarization=DiarizationParams( + is_diarize=diarization + ), + ).to_list() subtitle_str, file_path = whisper_inferencer.transcribe_file( [audio_path], diff --git a/tests/test_vad.py b/tests/test_vad.py index 124a043d..cb3dc054 100644 --- a/tests/test_vad.py +++ 
b/tests/test_vad.py @@ -1,6 +1,6 @@ from modules.utils.paths import * from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.whisper_parameter import WhisperValues +from modules.whisper.data_classes import * from test_config import * from test_transcription import download_file, test_transcribe @@ -12,9 +12,9 @@ @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ - ("whisper", True, False, False), - ("faster-whisper", True, False, False), - ("insanely_fast_whisper", True, False, False) + (WhisperImpl.WHISPER.value, True, False, False), + (WhisperImpl.FASTER_WHISPER.value, True, False, False), + (WhisperImpl.INSANELY_FAST_WHISPER.value, True, False, False) ] ) def test_vad_pipeline(