Skip to content

Commit

Permalink
Handle gradio none values
Browse files Browse the repository at this point in the history
  • Loading branch information
jhj0517 committed Oct 28, 2024
1 parent e58ee71 commit d48aaa6
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 25 deletions.
10 changes: 0 additions & 10 deletions modules/whisper/faster_whisper_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,6 @@ def transcribe(self,
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
self.update_model(params.model_size, params.compute_type, progress)

# None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
if not params.initial_prompt:
params.initial_prompt = None
if not params.prefix:
params.prefix = None
if not params.hotwords:
params.hotwords = None

params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)

segments, info = self.model.transcribe(
audio=audio,
language=params.lang,
Expand Down
56 changes: 41 additions & 15 deletions modules/whisper/whisper_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import torch
import ast
import whisper
import ctranslate2
import gradio as gr
Expand All @@ -14,7 +15,7 @@
from modules.uvr.music_separator import MusicSeparator
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
UVR_MODELS_DIR)
from modules.utils.constants import AUTOMATIC_DETECTION
from modules.utils.constants import AUTOMATIC_DETECTION, GRADIO_NONE_VALUES
from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
Expand Down Expand Up @@ -101,16 +102,9 @@ def run(self,
elapsed time for running
"""
params = TranscriptionPipelineParams.from_list(list(pipeline_params))
params = self.handle_gradio_values(params)
bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization

if whisper_params.lang is None:
pass
elif whisper_params.lang == AUTOMATIC_DETECTION:
whisper_params.lang = None
else:
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
whisper_params.lang = language_code_dict[params.lang]

if bgm_params.is_separate_bgm:
music, audio, _ = self.music_separator.separate(
audio=audio,
Expand Down Expand Up @@ -515,25 +509,57 @@ def remove_input_files(file_paths: List[str]):
if file_path and os.path.exists(file_path):
os.remove(file_path)

@staticmethod
def handle_gradio_values(params: TranscriptionPipelineParams):
"""
Handle gradio specific values that can't be displayed as None in the UI.
Related issue : https://github.com/gradio-app/gradio/issues/8723
"""
if params.whisper.lang is None:
pass
elif params.whisper.lang == AUTOMATIC_DETECTION:
params.whisper.lang = None
else:
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
params.whisper.lang = language_code_dict[params.lang]

if not params.whisper.initial_prompt:
params.whisper.initial_prompt = None
if not params.whisper.prefix:
params.whisper.prefix = None
if not params.whisper.hotwords:
params.whisper.hotwords = None
if params.whisper.max_new_tokens == 0:
params.whisper.max_new_tokens = None
if params.whisper.hallucination_silence_threshold == 0:
params.whisper.hallucination_silence_threshold = None
if params.whisper.language_detection_threshold == 0:
params.whisper.language_detection_threshold = None
if params.whisper.max_speech_duration_s >= 9999:
params.whisper.max_speech_duration_s = float('inf')
return params

@staticmethod
def cache_parameters(
params: TranscriptionPipelineParams,
add_timestamp: bool
):
"""cache parameters to the yaml file"""

"""Cache parameters to the yaml file"""
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
param_to_cache = params.to_dict()

print(param_to_cache)

cached_yaml = {**cached_params, **param_to_cache}
cached_yaml["whisper"]["add_timestamp"] = add_timestamp

supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
if supress_token and isinstance(supress_token, list):
cached_yaml["whisper"]["suppress_tokens"] = str(supress_token)

if cached_yaml["whisper"].get("lang", None) is None:
cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION
cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()

save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
if cached_yaml is not None and cached_yaml:
save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)

@staticmethod
def resample_audio(audio: Union[str, np.ndarray],
Expand Down

0 comments on commit d48aaa6

Please sign in to comment.