Skip to content

Commit

Permalink
Update model usage
Browse files Browse the repository at this point in the history
  • Loading branch information
jhj0517 committed Oct 27, 2024
1 parent a6b915f commit 7de504c
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 46 deletions.
2 changes: 1 addition & 1 deletion modules/whisper/faster_whisper_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def transcribe(self,
"""
start_time = time.time()

params = TranscriptionPipelineGradioComponents.as_value(*whisper_params)
params = WhisperParams.from_list(list(whisper_params))

if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
self.update_model(params.model_size, params.compute_type, progress)
Expand Down
2 changes: 1 addition & 1 deletion modules/whisper/insanely_fast_whisper_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def transcribe(self,
elapsed time for transcription
"""
start_time = time.time()
params = TranscriptionPipelineGradioComponents.as_value(*whisper_params)
params = WhisperParams.from_list(list(whisper_params))

if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
self.update_model(params.model_size, params.compute_type, progress)
Expand Down
2 changes: 1 addition & 1 deletion modules/whisper/whisper_Inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def transcribe(self,
elapsed time for transcription
"""
start_time = time.time()
params = TranscriptionPipelineGradioComponents.as_value(*whisper_params)
params = WhisperParams.from_list(list(whisper_params))

if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
self.update_model(params.model_size, params.compute_type, progress)
Expand Down
73 changes: 38 additions & 35 deletions modules/whisper/whisper_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def run(self,
audio: Union[str, BinaryIO, np.ndarray],
progress: gr.Progress = gr.Progress(),
add_timestamp: bool = True,
*whisper_params,
*pipeline_params,
) -> Tuple[List[dict], float]:
"""
Run transcription with conditional pre-processing and post-processing.
Expand All @@ -89,8 +89,8 @@ def run(self,
Indicator to show progress directly in gradio.
add_timestamp: bool
Whether to add a timestamp at the end of the filename.
*whisper_params: tuple
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
*pipeline_params: tuple
Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class
Returns
----------
Expand All @@ -99,28 +99,29 @@ def run(self,
elapsed_time: float
elapsed time for running
"""
params = TranscriptionPipelineGradioComponents.as_value(*whisper_params)
params = TranscriptionPipelineParams.from_list(list(pipeline_params))
bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization

self.cache_parameters(
whisper_params=params,
params=params,
add_timestamp=add_timestamp
)

if params.lang is None:
if whisper_params.lang is None:
pass
elif params.lang == AUTOMATIC_DETECTION:
params.lang = None
elif whisper_params.lang == AUTOMATIC_DETECTION:
whisper_params.lang = None
else:
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
params.lang = language_code_dict[params.lang]
whisper_params.lang = language_code_dict[params.lang]

if params.is_bgm_separate:
if bgm_params.is_separate_bgm:
music, audio, _ = self.music_separator.separate(
audio=audio,
model_name=params.uvr_model_size,
device=params.uvr_device,
segment_size=params.uvr_segment_size,
save_file=params.uvr_save_file,
model_name=bgm_params.model_size,
device=bgm_params.device,
segment_size=bgm_params.segment_size,
save_file=bgm_params.save_file,
progress=progress
)

Expand All @@ -132,20 +133,20 @@ def run(self,
origin_sample_rate = self.music_separator.audio_info.sample_rate
audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)

if params.uvr_enable_offload:
if bgm_params.enable_offload:
self.music_separator.offload()

if params.vad_filter:
if vad_params.vad_filter:
# Explicit value set for float('inf') from gr.Number()
if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
params.max_speech_duration_s = float('inf')
if vad_params.max_speech_duration_s is None or vad_params.max_speech_duration_s >= 9999:
vad_params.max_speech_duration_s = float('inf')

vad_options = VadOptions(
threshold=params.threshold,
min_speech_duration_ms=params.min_speech_duration_ms,
max_speech_duration_s=params.max_speech_duration_s,
min_silence_duration_ms=params.min_silence_duration_ms,
speech_pad_ms=params.speech_pad_ms
threshold=vad_params.threshold,
min_speech_duration_ms=vad_params.min_speech_duration_ms,
max_speech_duration_s=vad_params.max_speech_duration_s,
min_silence_duration_ms=vad_params.min_silence_duration_ms,
speech_pad_ms=vad_params.speech_pad_ms
)

audio, speech_chunks = self.vad.run(
Expand All @@ -157,20 +158,21 @@ def run(self,
result, elapsed_time = self.transcribe(
audio,
progress,
*astuple(params)
*whisper_params.to_list()
)

if params.vad_filter:
if vad_params.vad_filter:
result = self.vad.restore_speech_timestamps(
segments=result,
speech_chunks=speech_chunks,
speech_chunks=vad_params.speech_chunks,
)

if params.is_diarize:
if diarization_params.is_diarize:
result, elapsed_time_diarization = self.diarizer.run(
audio=audio,
use_auth_token=params.hf_token,
use_auth_token=diarization_params.hf_token,
transcribed_result=result,
device=diarization_params.device
)
elapsed_time += elapsed_time_diarization
return result, elapsed_time
Expand All @@ -181,7 +183,7 @@ def transcribe_file(self,
file_format: str = "SRT",
add_timestamp: bool = True,
progress=gr.Progress(),
*whisper_params,
*params,
) -> list:
"""
Write subtitle file from Files
Expand All @@ -199,8 +201,8 @@ def transcribe_file(self,
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
progress: gr.Progress
Indicator to show progress directly in gradio.
*whisper_params: tuple
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
*params: tuple
Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class
Returns
----------
Expand All @@ -223,7 +225,7 @@ def transcribe_file(self,
file,
progress,
add_timestamp,
*whisper_params,
*params,
)

file_name, file_ext = os.path.splitext(os.path.basename(file))
Expand Down Expand Up @@ -514,13 +516,14 @@ def remove_input_files(file_paths: List[str]):

@staticmethod
def cache_parameters(
whisper_params: TranscriptionPipelineParams,
params: TranscriptionPipelineParams,
add_timestamp: bool
):
"""cache parameters to the yaml file"""
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
cached_whisper_param = whisper_params.to_yaml()
cached_yaml = {**cached_params, **cached_whisper_param}
param_to_cache = params.to_dict()

cached_yaml = {**cached_params, **param_to_cache}
cached_yaml["whisper"]["add_timestamp"] = add_timestamp

save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
Expand Down
24 changes: 16 additions & 8 deletions tests/test_transcription.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.data_classes import TranscriptionPipelineParams
from modules.whisper.data_classes import *
from modules.utils.paths import WEBUI_DIR
from test_config import *

Expand Down Expand Up @@ -38,13 +38,21 @@ def test_transcribe(
)

hparams = TranscriptionPipelineParams(
model_size=TEST_WHISPER_MODEL,
vad_filter=vad_filter,
is_bgm_separate=bgm_separation,
compute_type=whisper_inferencer.current_compute_type,
uvr_enable_offload=True,
is_diarize=diarization,
).as_list()
whisper=WhisperParams(
model_size=TEST_WHISPER_MODEL,
compute_type=whisper_inferencer.current_compute_type
),
vad=VadParams(
vad_filter=vad_filter
),
bgm_separation=BGMSeparationParams(
is_separate_bgm=bgm_separation,
enable_offload=True
),
diarization=DiarizationParams(
is_diarize=diarization
),
).to_list()

subtitle_str, file_path = whisper_inferencer.transcribe_file(
[audio_path],
Expand Down

0 comments on commit 7de504c

Please sign in to comment.