From cfab18dbaccea9056d006202ed92abc1f47ef19c Mon Sep 17 00:00:00 2001
From: jhj0517 <97279763+jhj0517@users.noreply.github.com>
Date: Tue, 29 Oct 2024 19:58:00 +0900
Subject: [PATCH] Apply Segment model to the pipeline

---
 modules/diarize/diarize_pipeline.py          |  1 +
 modules/diarize/diarizer.py                  | 24 +++++++----
 modules/utils/subtitle_manager.py            | 11 +++++
 modules/vad/silero_vad.py                    | 11 ++---
 modules/whisper/faster_whisper_inference.py  | 16 ++++----
 .../insanely_fast_whisper_inference.py       | 20 ++++++---
 modules/whisper/whisper_Inference.py         | 41 +++++++++++--------
 7 files changed, 80 insertions(+), 44 deletions(-)

diff --git a/modules/diarize/diarize_pipeline.py b/modules/diarize/diarize_pipeline.py
index b4109e84..360f88b7 100644
--- a/modules/diarize/diarize_pipeline.py
+++ b/modules/diarize/diarize_pipeline.py
@@ -44,6 +44,7 @@ def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speaker
 def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
     transcript_segments = transcript_result["segments"]
     for seg in transcript_segments:
+        seg = seg.dict()
         # assign speaker to segment (if any)
         diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start'])
 
diff --git a/modules/diarize/diarizer.py b/modules/diarize/diarizer.py
index 2dd3f94a..4c0727ab 100644
--- a/modules/diarize/diarizer.py
+++ b/modules/diarize/diarizer.py
@@ -1,6 +1,6 @@
 import os
 import torch
-from typing import List, Union, BinaryIO, Optional
+from typing import List, Union, BinaryIO, Optional, Tuple
 import numpy as np
 import time
 import logging
@@ -8,6 +8,7 @@
 from modules.utils.paths import DIARIZATION_MODELS_DIR
 from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
 from modules.diarize.audio_loader import load_audio
+from modules.whisper.data_classes import *
 
 
 class Diarizer:
@@ -23,10 +24,10 @@ def __init__(self,
 
     def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
-            transcribed_result: List[dict],
+            transcribed_result: List[Segment],
             use_auth_token: str,
             device: Optional[str] = None
-            ):
+            ) -> Tuple[List[Segment], float]:
         """
         Diarize transcribed result as a post-processing
 
         Parameters
         ----------
         audio: Union[str, BinaryIO, np.ndarray]
             Audio input. This can be file path or binary type.
-        transcribed_result: List[dict]
+        transcribed_result: List[Segment]
             transcribed result through whisper.
         use_auth_token: str
             Huggingface token with READ permission. This is only needed the first time you download the model.
@@ -44,8 +45,8 @@ def run(self,
 
         Returns
         ----------
-        segments_result: List[dict]
-            list of dicts that includes start, end timestamps and transcribed text
+        segments_result: List[Segment]
+            list of Segment objects that include start/end timestamps and transcribed text
         elapsed_time: float
             elapsed time for running
         """
@@ -68,14 +69,21 @@ def run(self,
             {"segments": transcribed_result}
         )
 
+        segments_result = []
         for segment in diarized_result["segments"]:
+            segment = segment.dict()
             speaker = "None"
             if "speaker" in segment:
                 speaker = segment["speaker"]
-            segment["text"] = speaker + "|" + segment["text"].strip()
+            diarized_text = speaker + "|" + segment["text"].strip()
+            segments_result.append(Segment(
+                start=segment["start"],
+                end=segment["end"],
+                text=diarized_text
+            ))
 
         elapsed_time = time.time() - start_time
-        return diarized_result["segments"], elapsed_time
+        return segments_result, elapsed_time
 
     def update_pipe(self,
                     use_auth_token: str,
diff --git a/modules/utils/subtitle_manager.py b/modules/utils/subtitle_manager.py
index 4b484254..44751a82 100644
--- a/modules/utils/subtitle_manager.py
+++ b/modules/utils/subtitle_manager.py
@@ -1,5 +1,7 @@
 import re
 
+from modules.whisper.data_classes import Segment
+
 
 def timeformat_srt(time):
     hours = time // 3600
@@ -23,6 +25,9 @@ def write_file(subtitle, output_file):
 
 
 def get_srt(segments):
+    if segments and isinstance(segments[0], Segment):
+        segments = [seg.dict() for seg in segments]
+
     output = ""
     for i, segment in enumerate(segments):
         output += f"{i + 1}\n"
@@ -34,6 +39,9 @@ def get_srt(segments):
 
 
 def get_vtt(segments):
+    if segments and isinstance(segments[0], Segment):
+        segments = [seg.dict() for seg in segments]
+
     output = "WebVTT\n\n"
     for i, segment in enumerate(segments):
         output += f"{i + 1}\n"
@@ -45,6 +53,9 @@ def get_vtt(segments):
 
 
 def get_txt(segments):
+    if segments and isinstance(segments[0], Segment):
+        segments = [seg.dict() for seg in segments]
+
     output = ""
     for i, segment in enumerate(segments):
         if segment['text'].startswith(' '):
diff --git a/modules/vad/silero_vad.py b/modules/vad/silero_vad.py
index bb5c9192..d44c26fc 100644
--- a/modules/vad/silero_vad.py
+++ b/modules/vad/silero_vad.py
@@ -5,7 +5,8 @@
 from typing import BinaryIO, Union, List, Optional, Tuple
 import warnings
 import faster_whisper
-from faster_whisper.transcribe import SpeechTimestampsMap, Segment
+from modules.whisper.data_classes import *
+from faster_whisper.transcribe import SpeechTimestampsMap
 import gradio as gr
 
 
@@ -247,18 +248,18 @@ def format_timestamp(
 
     def restore_speech_timestamps(
             self,
-            segments: List[dict],
+            segments: List[Segment],
             speech_chunks: List[dict],
             sampling_rate: Optional[int] = None,
-    ) -> List[dict]:
+    ) -> List[Segment]:
         if sampling_rate is None:
             sampling_rate = self.sampling_rate
 
         ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
 
         for segment in segments:
-            segment["start"] = ts_map.get_original_time(segment["start"])
-            segment["end"] = ts_map.get_original_time(segment["end"])
+            segment.start = ts_map.get_original_time(segment.start)
+            segment.end = ts_map.get_original_time(segment.end)
 
         return segments
 
diff --git a/modules/whisper/faster_whisper_inference.py b/modules/whisper/faster_whisper_inference.py
index b4a2b4ce..5dad6edb 100644
--- a/modules/whisper/faster_whisper_inference.py
+++ b/modules/whisper/faster_whisper_inference.py
@@ -40,7 +40,7 @@ def transcribe(self,
                    audio: Union[str, BinaryIO, np.ndarray],
                    progress: gr.Progress = gr.Progress(),
                    *whisper_params,
-                   ) -> Tuple[List[dict], float]:
+                   ) -> Tuple[List[Segment], float]:
         """
         transcribe method for faster-whisper.
 
@@ -55,8 +55,8 @@ def transcribe(self,
 
         Returns
         ----------
-        segments_result: List[dict]
-            list of dicts that includes start, end timestamps and transcribed text
+        segments_result: List[Segment]
+            list of Segment objects that include start/end timestamps and transcribed text
         elapsed_time: float
             elapsed time for transcription
         """
@@ -102,11 +102,11 @@ def transcribe(self,
         segments_result = []
         for segment in segments:
             progress(segment.start / info.duration, desc="Transcribing..")
-            segments_result.append({
-                "start": segment.start,
-                "end": segment.end,
-                "text": segment.text
-            })
+            segments_result.append(Segment(
+                start=segment.start,
+                end=segment.end,
+                text=segment.text
+            ))
 
         elapsed_time = time.time() - start_time
         return segments_result, elapsed_time
diff --git a/modules/whisper/insanely_fast_whisper_inference.py b/modules/whisper/insanely_fast_whisper_inference.py
index bca9a628..c1f7cb2c 100644
--- a/modules/whisper/insanely_fast_whisper_inference.py
+++ b/modules/whisper/insanely_fast_whisper_inference.py
@@ -40,7 +40,7 @@ def transcribe(self,
                    audio: Union[str, np.ndarray, torch.Tensor],
                    progress: gr.Progress = gr.Progress(),
                    *whisper_params,
-                   ) -> Tuple[List[dict], float]:
+                   ) -> Tuple[List[Segment], float]:
         """
         transcribe method for faster-whisper.
 
@@ -55,8 +55,8 @@ def transcribe(self,
 
         Returns
         ----------
-        segments_result: List[dict]
-            list of dicts that includes start, end timestamps and transcribed text
+        segments_result: List[Segment]
+            list of Segment objects that include start/end timestamps and transcribed text
         elapsed_time: float
             elapsed time for transcription
         """
@@ -95,9 +95,17 @@ def transcribe(self,
             generate_kwargs=kwargs
         )
 
-        segments_result = self.format_result(
-            transcribed_result=segments,
-        )
+        segments_result = []
+        for item in segments["chunks"]:
+            start, end = item["timestamp"][0], item["timestamp"][1]
+            if end is None:
+                end = start
+            segments_result.append(Segment(
+                text=item["text"],
+                start=start,
+                end=end
+            ))
+
         elapsed_time = time.time() - start_time
 
         return segments_result, elapsed_time
diff --git a/modules/whisper/whisper_Inference.py b/modules/whisper/whisper_Inference.py
index 825bbe3e..ccd4bbb4 100644
--- a/modules/whisper/whisper_Inference.py
+++ b/modules/whisper/whisper_Inference.py
@@ -30,7 +30,7 @@ def transcribe(self,
                    audio: Union[str, np.ndarray, torch.Tensor],
                    progress: gr.Progress = gr.Progress(),
                    *whisper_params,
-                   ) -> Tuple[List[dict], float]:
+                   ) -> Tuple[List[Segment], float]:
         """
         transcribe method for faster-whisper.
 
@@ -45,8 +45,8 @@ def transcribe(self,
 
         Returns
        ----------
-        segments_result: List[dict]
-            list of dicts that includes start, end timestamps and transcribed text
+        segments_result: List[Segment]
+            list of Segment objects that include start/end timestamps and transcribed text
         elapsed_time: float
             elapsed time for transcription
         """
@@ -59,21 +59,28 @@ def transcribe(self,
         def progress_callback(progress_value):
             progress(progress_value, desc="Transcribing..")
 
-        segments_result = self.model.transcribe(audio=audio,
-                                                language=params.lang,
-                                                verbose=False,
-                                                beam_size=params.beam_size,
-                                                logprob_threshold=params.log_prob_threshold,
-                                                no_speech_threshold=params.no_speech_threshold,
-                                                task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
-                                                fp16=True if params.compute_type == "float16" else False,
-                                                best_of=params.best_of,
-                                                patience=params.patience,
-                                                temperature=params.temperature,
-                                                compression_ratio_threshold=params.compression_ratio_threshold,
-                                                progress_callback=progress_callback,)["segments"]
-        elapsed_time = time.time() - start_time
+        result = self.model.transcribe(audio=audio,
+                                       language=params.lang,
+                                       verbose=False,
+                                       beam_size=params.beam_size,
+                                       logprob_threshold=params.log_prob_threshold,
+                                       no_speech_threshold=params.no_speech_threshold,
+                                       task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
+                                       fp16=True if params.compute_type == "float16" else False,
+                                       best_of=params.best_of,
+                                       patience=params.patience,
+                                       temperature=params.temperature,
+                                       compression_ratio_threshold=params.compression_ratio_threshold,
+                                       progress_callback=progress_callback,)["segments"]
+        segments_result = []
+        for segment in result:
+            segments_result.append(Segment(
+                start=segment["start"],
+                end=segment["end"],
+                text=segment["text"]
+            ))
+        elapsed_time = time.time() - start_time
 
         return segments_result, elapsed_time
 
     def update_model(self,
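
Note: this patch assumes that modules/whisper/data_classes.py already defines the Segment model; the diff imports and constructs it but never shows its definition. Judging from the usage above (keyword construction with start/end/text, attribute access in restore_speech_timestamps, and serialization via seg.dict()), it appears to be a pydantic-style model. The sketch below is a minimal, hypothetical reconstruction offered only as a reading aid; the field names are inferred from the diff, and the real class may define more fields.

    # Hypothetical sketch of the Segment model this patch relies on (not part of the diff).
    # Only the start/end/text fields exercised above are shown; the actual
    # data_classes.py may carry additional metadata.
    from typing import Optional

    from pydantic import BaseModel, Field


    class Segment(BaseModel):
        start: Optional[float] = Field(default=None, description="Start time of the segment in seconds")
        end: Optional[float] = Field(default=None, description="End time of the segment in seconds")
        text: Optional[str] = Field(default=None, description="Transcribed text of the segment")

With a model shaped like this, Segment(start=0.0, end=1.5, text="hello").dict() returns {"start": 0.0, "end": 1.5, "text": "hello"}, which is why the subtitle helpers above can keep their dict-based logic after converting with seg.dict().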