Skip to content

Commit

Permalink
Merge pull request #380 from linuxlurak/master-mod
Browse files Browse the repository at this point in the history
Load the default value of file_format from the config file.
  • Loading branch information
jhj0517 authored Nov 4, 2024
2 parents 02f3c63 + e284444 commit e9891d2
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def create_pipeline_inputs(self):
dd_lang = gr.Dropdown(choices=self.whisper_inf.available_langs + [AUTOMATIC_DETECTION],
value=AUTOMATIC_DETECTION if whisper_params["lang"] == AUTOMATIC_DETECTION.unwrap()
else whisper_params["lang"], label=_("Language"))
dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value="SRT", label=_("File Format"))
dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value=whisper_params["file_format"], label=_("File Format"))
with gr.Row():
cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label=_("Translate to English?"),
interactive=True)
Expand Down
1 change: 1 addition & 0 deletions configs/default_parameters.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
whisper:
model_size: "large-v2"
file_format: "SRT"
lang: "Automatic Detection"
is_translate: false
beam_size: 5
Expand Down
11 changes: 10 additions & 1 deletion modules/whisper/base_transcription_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def update_model(self,
def run(self,
audio: Union[str, BinaryIO, np.ndarray],
progress: gr.Progress = gr.Progress(),
file_format: str = "SRT",
add_timestamp: bool = True,
*pipeline_params,
) -> Tuple[List[Segment], float]:
Expand All @@ -86,6 +87,8 @@ def run(self,
Audio input. This can be file path or binary type.
progress: gr.Progress
Indicator to show progress directly in gradio.
file_format: str
Subtitle file format between ["SRT", "WebVTT", "txt", "lrc"]
add_timestamp: bool
Whether to add a timestamp at the end of the filename.
*pipeline_params: tuple
Expand Down Expand Up @@ -168,6 +171,7 @@ def run(self,

self.cache_parameters(
params=params,
file_format=file_format,
add_timestamp=add_timestamp
)
return result, elapsed_time
Expand Down Expand Up @@ -224,6 +228,7 @@ def transcribe_file(self,
transcribed_segments, time_for_task = self.run(
file,
progress,
file_format,
add_timestamp,
*pipeline_params,
)
Expand Down Expand Up @@ -298,6 +303,7 @@ def transcribe_mic(self,
transcribed_segments, time_for_task = self.run(
mic_audio,
progress,
file_format,
add_timestamp,
*pipeline_params,
)
Expand Down Expand Up @@ -364,6 +370,7 @@ def transcribe_youtube(self,
transcribed_segments, time_for_task = self.run(
audio,
progress,
file_format,
add_timestamp,
*pipeline_params,
)
Expand Down Expand Up @@ -513,14 +520,16 @@ def validate_gradio_values(params: TranscriptionPipelineParams):
@staticmethod
def cache_parameters(
params: TranscriptionPipelineParams,
add_timestamp: bool
file_format: str = "SRT",
add_timestamp: bool = True
):
"""Cache parameters to the yaml file"""
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
param_to_cache = params.to_dict()

cached_yaml = {**cached_params, **param_to_cache}
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
cached_yaml["whisper"]["file_format"] = file_format

supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
if supress_token and isinstance(supress_token, list):
Expand Down

0 comments on commit e9891d2

Please sign in to comment.