diff --git a/modules/whisper/faster_whisper_inference.py b/modules/whisper/faster_whisper_inference.py
index 94eb18f5..524cd247 100644
--- a/modules/whisper/faster_whisper_inference.py
+++ b/modules/whisper/faster_whisper_inference.py
@@ -62,7 +62,7 @@ def transcribe(self,
         """
         start_time = time.time()
 
-        params = TranscriptionPipelineGradioComponents.as_value(*whisper_params)
+        params = WhisperParams.from_list(list(whisper_params))
 
         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)
diff --git a/modules/whisper/insanely_fast_whisper_inference.py b/modules/whisper/insanely_fast_whisper_inference.py
index c7476d10..57483f44 100644
--- a/modules/whisper/insanely_fast_whisper_inference.py
+++ b/modules/whisper/insanely_fast_whisper_inference.py
@@ -61,7 +61,7 @@ def transcribe(self,
             elapsed time for transcription
         """
         start_time = time.time()
-        params = TranscriptionPipelineGradioComponents.as_value(*whisper_params)
+        params = WhisperParams.from_list(list(whisper_params))
 
         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)
diff --git a/modules/whisper/whisper_Inference.py b/modules/whisper/whisper_Inference.py
index 338e04b7..8f567b49 100644
--- a/modules/whisper/whisper_Inference.py
+++ b/modules/whisper/whisper_Inference.py
@@ -51,7 +51,7 @@ def transcribe(self,
             elapsed time for transcription
         """
         start_time = time.time()
-        params = TranscriptionPipelineGradioComponents.as_value(*whisper_params)
+        params = WhisperParams.from_list(list(whisper_params))
 
         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)
diff --git a/modules/whisper/whisper_base.py b/modules/whisper/whisper_base.py
index e891ec73..08ed0453 100644
--- a/modules/whisper/whisper_base.py
+++ b/modules/whisper/whisper_base.py
@@ -74,7 +74,7 @@ def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
             progress: gr.Progress = gr.Progress(),
             add_timestamp: bool = True,
-            *whisper_params,
+            *pipeline_params,
             ) -> Tuple[List[dict], float]:
         """
         Run transcription with conditional pre-processing and post-processing.
@@ -89,8 +89,8 @@ def run(self,
             Indicator to show progress directly in gradio.
         add_timestamp: bool
             Whether to add a timestamp at the end of the filename.
-        *whisper_params: tuple
-            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
+        *pipeline_params: tuple
+            Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class
         Returns
         ----------
         result: List[dict]
@@ -99,28 +99,29 @@ def run(self,
         elapsed_time: float
             elapsed time for running
         """
-        params = TranscriptionPipelineGradioComponents.as_value(*whisper_params)
+        params = TranscriptionPipelineParams.from_list(list(pipeline_params))
+        bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization
         self.cache_parameters(
-            whisper_params=params,
+            params=params,
             add_timestamp=add_timestamp
         )
 
-        if params.lang is None:
+        if whisper_params.lang is None:
             pass
-        elif params.lang == AUTOMATIC_DETECTION:
-            params.lang = None
+        elif whisper_params.lang == AUTOMATIC_DETECTION:
+            whisper_params.lang = None
         else:
             language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
-            params.lang = language_code_dict[params.lang]
+            whisper_params.lang = language_code_dict[params.lang]
 
-        if params.is_bgm_separate:
+        if bgm_params.is_separate_bgm:
             music, audio, _ = self.music_separator.separate(
                 audio=audio,
-                model_name=params.uvr_model_size,
-                device=params.uvr_device,
-                segment_size=params.uvr_segment_size,
-                save_file=params.uvr_save_file,
+                model_name=bgm_params.model_size,
+                device=bgm_params.device,
+                segment_size=bgm_params.segment_size,
+                save_file=bgm_params.save_file,
                 progress=progress
             )
@@ -132,20 +133,20 @@ def run(self,
             origin_sample_rate = self.music_separator.audio_info.sample_rate
             audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
 
-        if params.uvr_enable_offload:
+        if bgm_params.enable_offload:
             self.music_separator.offload()
 
-        if params.vad_filter:
+        if vad_params.vad_filter:
             # Explicit value set for float('inf') from gr.Number()
-            if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
-                params.max_speech_duration_s = float('inf')
+            if vad_params.max_speech_duration_s is None or vad_params.max_speech_duration_s >= 9999:
+                vad_params.max_speech_duration_s = float('inf')
 
             vad_options = VadOptions(
-                threshold=params.threshold,
-                min_speech_duration_ms=params.min_speech_duration_ms,
-                max_speech_duration_s=params.max_speech_duration_s,
-                min_silence_duration_ms=params.min_silence_duration_ms,
-                speech_pad_ms=params.speech_pad_ms
+                threshold=vad_params.threshold,
+                min_speech_duration_ms=vad_params.min_speech_duration_ms,
+                max_speech_duration_s=vad_params.max_speech_duration_s,
+                min_silence_duration_ms=vad_params.min_silence_duration_ms,
+                speech_pad_ms=vad_params.speech_pad_ms
             )
 
             audio, speech_chunks = self.vad.run(
@@ -157,20 +158,21 @@ def run(self,
         result, elapsed_time = self.transcribe(
             audio,
             progress,
-            *astuple(params)
+            *whisper_params.to_list()
         )
 
-        if params.vad_filter:
+        if vad_params.vad_filter:
             result = self.vad.restore_speech_timestamps(
                 segments=result,
-                speech_chunks=speech_chunks,
+                speech_chunks=vad_params.speech_chunks,
             )
 
-        if params.is_diarize:
+        if diarization_params.is_diarize:
             result, elapsed_time_diarization = self.diarizer.run(
                 audio=audio,
-                use_auth_token=params.hf_token,
+                use_auth_token=diarization_params.hf_token,
                 transcribed_result=result,
+                device=diarization_params.device
             )
             elapsed_time += elapsed_time_diarization
         return result, elapsed_time
@@ -181,7 +183,7 @@ def transcribe_file(self,
                         file_format: str = "SRT",
                         add_timestamp: bool = True,
                         progress=gr.Progress(),
-                        *whisper_params,
+                        *params,
                         ) -> list:
         """
         Write subtitle file from Files
@@ -199,8 +201,8 @@ def transcribe_file(self,
             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
         progress: gr.Progress
             Indicator to show progress directly in gradio.
-        *whisper_params: tuple
-            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
+        *params: tuple
+            Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class
 
         Returns
         ----------
@@ -223,7 +225,7 @@ def transcribe_file(self,
                     file,
                     progress,
                     add_timestamp,
-                    *whisper_params,
+                    *params,
                 )
 
                 file_name, file_ext = os.path.splitext(os.path.basename(file))
@@ -514,13 +516,14 @@ def remove_input_files(file_paths: List[str]):
 
     @staticmethod
     def cache_parameters(
-        whisper_params: TranscriptionPipelineParams,
+        params: TranscriptionPipelineParams,
         add_timestamp: bool
     ):
         """cache parameters to the yaml file"""
         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
-        cached_whisper_param = whisper_params.to_yaml()
-        cached_yaml = {**cached_params, **cached_whisper_param}
+        param_to_cache = params.to_dict()
+
+        cached_yaml = {**cached_params, **param_to_cache}
         cached_yaml["whisper"]["add_timestamp"] = add_timestamp
         save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
diff --git a/tests/test_transcription.py b/tests/test_transcription.py
index dc85c2df..f9dc3cc2 100644
--- a/tests/test_transcription.py
+++ b/tests/test_transcription.py
@@ -1,5 +1,5 @@
 from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.data_classes import TranscriptionPipelineParams
+from modules.whisper.data_classes import *
 from modules.utils.paths import WEBUI_DIR
 from test_config import *
@@ -38,13 +38,21 @@ def test_transcribe(
     )
 
     hparams = TranscriptionPipelineParams(
-        model_size=TEST_WHISPER_MODEL,
-        vad_filter=vad_filter,
-        is_bgm_separate=bgm_separation,
-        compute_type=whisper_inferencer.current_compute_type,
-        uvr_enable_offload=True,
-        is_diarize=diarization,
-    ).as_list()
+        whisper=WhisperParams(
+            model_size=TEST_WHISPER_MODEL,
+            compute_type=whisper_inferencer.current_compute_type
+        ),
+        vad=VadParams(
+            vad_filter=vad_filter
+        ),
+        bgm_separation=BGMSeparationParams(
+            is_separate_bgm=bgm_separation,
+            enable_offload=True
+        ),
+        diarization=DiarizationParams(
+            is_diarize=diarization
+        ),
+    ).to_list()
 
     subtitle_str, file_path = whisper_inferencer.transcribe_file(
         [audio_path],
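Note on the shape of the new parameter objects (not part of the patch): the hunks above assume each stage config exposes from_list()/to_list(), and that TranscriptionPipelineParams nests the four stage configs, flattens itself for Gradio's *args convention, and serializes itself for the YAML cache via to_dict(). The sketch below is a minimal, self-contained illustration of that contract using plain dataclasses; the real classes live in modules/whisper/data_classes, carry many more fields, and may be implemented differently, so every field and default shown here is a placeholder.

    from dataclasses import asdict, dataclass, field, fields
    from typing import Any, List, Optional


    @dataclass
    class BaseParams:
        """Shared flat-list (de)serialization bridging Gradio's *args convention."""

        def to_list(self) -> List[Any]:
            # Field values in declaration order, matching the order of the
            # Gradio components that feed them.
            return [getattr(self, f.name) for f in fields(self)]

        @classmethod
        def from_list(cls, values: List[Any]):
            return cls(*values)


    @dataclass
    class WhisperParams(BaseParams):        # placeholder fields
        model_size: str = "base"
        lang: Optional[str] = None
        compute_type: str = "float16"


    @dataclass
    class VadParams(BaseParams):            # placeholder fields
        vad_filter: bool = False
        threshold: float = 0.5


    @dataclass
    class BGMSeparationParams(BaseParams):  # placeholder fields
        is_separate_bgm: bool = False
        enable_offload: bool = True


    @dataclass
    class DiarizationParams(BaseParams):    # placeholder fields
        is_diarize: bool = False
        hf_token: str = ""


    @dataclass
    class TranscriptionPipelineParams:
        whisper: WhisperParams = field(default_factory=WhisperParams)
        vad: VadParams = field(default_factory=VadParams)
        bgm_separation: BGMSeparationParams = field(default_factory=BGMSeparationParams)
        diarization: DiarizationParams = field(default_factory=DiarizationParams)

        def to_list(self) -> List[Any]:
            # Flatten the four stage configs in a fixed order; from_list()
            # re-slices in the same order, so the order is the contract.
            return (self.whisper.to_list() + self.vad.to_list()
                    + self.bgm_separation.to_list() + self.diarization.to_list())

        @classmethod
        def from_list(cls, values: List[Any]) -> "TranscriptionPipelineParams":
            n_w = len(fields(WhisperParams))
            n_v = len(fields(VadParams))
            n_b = len(fields(BGMSeparationParams))
            return cls(
                whisper=WhisperParams.from_list(values[:n_w]),
                vad=VadParams.from_list(values[n_w:n_w + n_v]),
                bgm_separation=BGMSeparationParams.from_list(
                    values[n_w + n_v:n_w + n_v + n_b]),
                diarization=DiarizationParams.from_list(values[n_w + n_v + n_b:]),
            )

        def to_dict(self) -> dict:
            # Nested dict keyed by pipeline stage, mergeable into the cached
            # YAML the way cache_parameters() merges param_to_cache.
            return asdict(self)

Under these assumptions the round trip TranscriptionPipelineParams.from_list(params.to_list()) reconstructs an equal object, which is the property run() relies on when Gradio hands it a flat *pipeline_params tuple.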