diff --git a/app.py b/app.py index c22e52d..175522b 100644 --- a/app.py +++ b/app.py @@ -53,7 +53,7 @@ def create_pipeline_inputs(self): dd_lang = gr.Dropdown(choices=self.whisper_inf.available_langs + [AUTOMATIC_DETECTION], value=AUTOMATIC_DETECTION if whisper_params["lang"] == AUTOMATIC_DETECTION.unwrap() else whisper_params["lang"], label=_("Language")) - dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value="SRT", label=_("File Format")) + dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value=whisper_params["file_format"], label=_("File Format")) with gr.Row(): cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label=_("Translate to English?"), interactive=True) diff --git a/configs/default_parameters.yaml b/configs/default_parameters.yaml index 3a12d14..e317d25 100644 --- a/configs/default_parameters.yaml +++ b/configs/default_parameters.yaml @@ -1,5 +1,6 @@ whisper: model_size: "large-v2" + file_format: "SRT" lang: "Automatic Detection" is_translate: false beam_size: 5 diff --git a/modules/whisper/base_transcription_pipeline.py b/modules/whisper/base_transcription_pipeline.py index 4882abd..2791dc6 100644 --- a/modules/whisper/base_transcription_pipeline.py +++ b/modules/whisper/base_transcription_pipeline.py @@ -71,6 +71,7 @@ def update_model(self, def run(self, audio: Union[str, BinaryIO, np.ndarray], progress: gr.Progress = gr.Progress(), + file_format: str = "SRT", add_timestamp: bool = True, *pipeline_params, ) -> Tuple[List[Segment], float]: @@ -86,6 +87,8 @@ def run(self, Audio input. This can be file path or binary type. progress: gr.Progress Indicator to show progress directly in gradio. + file_format: str + Subtitle file format between ["SRT", "WebVTT", "txt", "LRC"] add_timestamp: bool Whether to add a timestamp at the end of the filename. 
*pipeline_params: tuple @@ -168,6 +171,7 @@ def run(self, self.cache_parameters( params=params, + file_format=file_format, add_timestamp=add_timestamp ) return result, elapsed_time @@ -224,6 +228,7 @@ def transcribe_file(self, transcribed_segments, time_for_task = self.run( file, progress, + file_format, add_timestamp, *pipeline_params, ) @@ -298,6 +303,7 @@ def transcribe_mic(self, transcribed_segments, time_for_task = self.run( mic_audio, progress, + file_format, add_timestamp, *pipeline_params, ) @@ -364,6 +370,7 @@ def transcribe_youtube(self, transcribed_segments, time_for_task = self.run( audio, progress, + file_format, add_timestamp, *pipeline_params, ) @@ -513,7 +520,8 @@ def validate_gradio_values(params: TranscriptionPipelineParams): @staticmethod def cache_parameters( params: TranscriptionPipelineParams, - add_timestamp: bool + file_format: str = "SRT", + add_timestamp: bool = True ): """Cache parameters to the yaml file""" cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) @@ -521,6 +529,7 @@ def cache_parameters( cached_yaml = {**cached_params, **param_to_cache} cached_yaml["whisper"]["add_timestamp"] = add_timestamp + cached_yaml["whisper"]["file_format"] = file_format supress_token = cached_yaml["whisper"].get("suppress_tokens", None) if supress_token and isinstance(supress_token, list):