Apply Segment model to the pipeline
jhj0517 committed Oct 29, 2024
1 parent 95073dd commit cfab18d
Showing 7 changed files with 80 additions and 44 deletions.
1 change: 1 addition & 0 deletions modules/diarize/diarize_pipeline.py
@@ -44,6 +44,7 @@ def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speaker
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
transcript_segments = transcript_result["segments"]
for seg in transcript_segments:
seg = seg.dict()
# assign speaker to segment (if any)
diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
seg['start'])
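The seg.dict() call above implies that transcription segments are now Pydantic models rather than plain dicts. The Segment model itself lives in modules/whisper/data_classes.py, which this diff does not touch, so the following is a minimal sketch reconstructed from how the model is used in this commit; the exact fields and defaults are assumptions.

from typing import Optional
from pydantic import BaseModel

class Segment(BaseModel):
    start: Optional[float] = None  # segment start time in seconds
    end: Optional[float] = None    # segment end time in seconds
    text: Optional[str] = None     # transcribed text of the segment

seg = Segment(start=0.0, end=1.5, text="hello")
print(seg.dict())  # {'start': 0.0, 'end': 1.5, 'text': 'hello'}

Note that .dict() is the Pydantic v1 serializer; under Pydantic v2 it still works but is deprecated in favor of model_dump().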
24 changes: 16 additions & 8 deletions modules/diarize/diarizer.py
@@ -1,13 +1,14 @@
import os
import torch
from typing import List, Union, BinaryIO, Optional
from typing import List, Union, BinaryIO, Optional, Tuple
import numpy as np
import time
import logging

from modules.utils.paths import DIARIZATION_MODELS_DIR
from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
from modules.diarize.audio_loader import load_audio
from modules.whisper.data_classes import *


class Diarizer:
@@ -23,18 +24,18 @@ def __init__(self,

def run(self,
audio: Union[str, BinaryIO, np.ndarray],
transcribed_result: List[dict],
transcribed_result: List[Segment],
use_auth_token: str,
device: Optional[str] = None
):
) -> Tuple[List[Segment], float]:
"""
Diarize the transcribed result as a post-processing step
Parameters
----------
audio: Union[str, BinaryIO, np.ndarray]
Audio input. This can be a file path, binary data, or a NumPy array.
transcribed_result: List[dict]
transcribed_result: List[Segment]
Transcribed result produced by Whisper.
use_auth_token: str
Huggingface token with READ permission. This is only needed the first time you download the model.
@@ -44,8 +45,8 @@ def run(self,
Returns
----------
segments_result: List[dict]
list of dicts that includes start, end timestamps and transcribed text
segments_result: List[Segment]
list of Segment objects, each with start and end timestamps and transcribed text
elapsed_time: float
elapsed time for running
"""
@@ -68,14 +69,21 @@ def run(self,
{"segments": transcribed_result}
)

segments_result = []
for segment in diarized_result["segments"]:
segment = segment.dict()
speaker = "None"
if "speaker" in segment:
speaker = segment["speaker"]
segment["text"] = speaker + "|" + segment["text"].strip()
diarized_text = speaker + "|" + segment["text"].strip()
segments_result.append(Segment(
start=segment["start"],
end=segment["end"],
text=diarized_text
))

elapsed_time = time.time() - start_time
return diarized_result["segments"], elapsed_time
return segments_result, elapsed_time

def update_pipe(self,
use_auth_token: str,
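With this change Diarizer.run both consumes and returns List[Segment]. A hypothetical usage sketch; the constructor argument, file name, and token placeholder below are assumptions, not part of this commit:

diarizer = Diarizer(model_dir=DIARIZATION_MODELS_DIR)  # assumed constructor signature
segments, elapsed_time = diarizer.run(
    audio="sample.wav",
    transcribed_result=transcribed_segments,  # List[Segment] from a transcriber
    use_auth_token="hf_...",                  # placeholder Huggingface token
)
for seg in segments:
    # seg.text now carries the speaker prefix, e.g. "SPEAKER_00|Hello there."
    print(f"[{seg.start:.2f}-{seg.end:.2f}] {seg.text}")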
11 changes: 11 additions & 0 deletions modules/utils/subtitle_manager.py
@@ -1,5 +1,7 @@
import re

from modules.whisper.data_classes import Segment


def timeformat_srt(time):
hours = time // 3600
@@ -23,6 +25,9 @@ def write_file(subtitle, output_file):


def get_srt(segments):
if segments and isinstance(segments[0], Segment):
segments = [seg.dict() for seg in segments]

output = ""
for i, segment in enumerate(segments):
output += f"{i + 1}\n"
@@ -34,6 +39,9 @@


def get_vtt(segments):
if segments and isinstance(segments[0], Segment):
segments = [seg.dict() for seg in segments]

output = "WebVTT\n\n"
for i, segment in enumerate(segments):
output += f"{i + 1}\n"
@@ -45,6 +53,9 @@


def get_txt(segments):
if segments and isinstance(segments[0], Segment):
segments = [seg.dict() for seg in segments]

output = ""
for i, segment in enumerate(segments):
if segment['text'].startswith(' '):
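The isinstance guard added to each writer keeps the module backward compatible: callers may pass either the new Segment models or the old plain dicts and get identical output. A small sketch, assuming the Segment model shown earlier:

from modules.whisper.data_classes import Segment
from modules.utils.subtitle_manager import get_srt

srt_from_models = get_srt([Segment(start=0.0, end=1.5, text="hello")])
srt_from_dicts = get_srt([{"start": 0.0, "end": 1.5, "text": "hello"}])
assert srt_from_models == srt_from_dicts  # the guard normalizes models to dicts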
11 changes: 6 additions & 5 deletions modules/vad/silero_vad.py
@@ -5,7 +5,8 @@
from typing import BinaryIO, Union, List, Optional, Tuple
import warnings
import faster_whisper
from faster_whisper.transcribe import SpeechTimestampsMap, Segment
from modules.whisper.data_classes import *
from faster_whisper.transcribe import SpeechTimestampsMap
import gradio as gr


@@ -247,18 +248,18 @@ def format_timestamp(

def restore_speech_timestamps(
self,
segments: List[dict],
segments: List[Segment],
speech_chunks: List[dict],
sampling_rate: Optional[int] = None,
) -> List[dict]:
) -> List[Segment]:
if sampling_rate is None:
sampling_rate = self.sampling_rate

ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)

for segment in segments:
segment["start"] = ts_map.get_original_time(segment["start"])
segment["end"] = ts_map.get_original_time(segment["end"])
segment.start = ts_map.get_original_time(segment.start)
segment.end = ts_map.get_original_time(segment.end)

return segments

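restore_speech_timestamps now uses attribute access (segment.start) instead of dict indexing, matching the Segment model. A usage sketch; the class name SileroVAD is assumed from the file name, and the speech_chunks entries follow faster-whisper's convention of sample indices rather than seconds:

vad = SileroVAD()
restored = vad.restore_speech_timestamps(
    segments=[Segment(start=0.0, end=1.0, text="hi")],
    speech_chunks=[{"start": 0, "end": 16000}],  # sample indices, not seconds
    sampling_rate=16000,
)
print(restored[0].start, restored[0].end)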
16 changes: 8 additions & 8 deletions modules/whisper/faster_whisper_inference.py
@@ -40,7 +40,7 @@ def transcribe(self,
audio: Union[str, BinaryIO, np.ndarray],
progress: gr.Progress = gr.Progress(),
*whisper_params,
) -> Tuple[List[dict], float]:
) -> Tuple[List[Segment], float]:
"""
transcribe method for faster-whisper.
@@ -55,8 +55,8 @@ def transcribe(self,
Returns
----------
segments_result: List[dict]
list of dicts that includes start, end timestamps and transcribed text
segments_result: List[Segment]
list of Segment objects, each with start and end timestamps and transcribed text
elapsed_time: float
elapsed time for transcription
"""
@@ -102,11 +102,11 @@ def transcribe(self,
segments_result = []
for segment in segments:
progress(segment.start / info.duration, desc="Transcribing..")
segments_result.append({
"start": segment.start,
"end": segment.end,
"text": segment.text
})
segments_result.append(Segment(
start=segment.start,
end=segment.end,
text=segment.text
))

elapsed_time = time.time() - start_time
return segments_result, elapsed_time
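For context, faster-whisper's model.transcribe returns a lazy generator of its own segment objects (distinct from this project's Segment model) plus a TranscriptionInfo; iterating the generator is what drives decoding, which is why progress is reported inside the loop against info.duration. A minimal sketch of that upstream API:

# segments is a generator; decoding happens as it is consumed
segments, info = model.transcribe("sample.wav", beam_size=5)
for s in segments:
    print(s.start, s.end, s.text)  # native faster-whisper segments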
20 changes: 14 additions & 6 deletions modules/whisper/insanely_fast_whisper_inference.py
@@ -40,7 +40,7 @@ def transcribe(self,
audio: Union[str, np.ndarray, torch.Tensor],
progress: gr.Progress = gr.Progress(),
*whisper_params,
) -> Tuple[List[dict], float]:
) -> Tuple[List[Segment], float]:
"""
transcribe method for insanely-fast-whisper.
@@ -55,8 +55,8 @@
Returns
----------
segments_result: List[dict]
list of dicts that includes start, end timestamps and transcribed text
segments_result: List[Segment]
list of Segment objects, each with start and end timestamps and transcribed text
elapsed_time: float
elapsed time for transcription
"""
@@ -95,9 +95,17 @@
generate_kwargs=kwargs
)

segments_result = self.format_result(
transcribed_result=segments,
)
segments_result = []
for item in segments["chunks"]:
start, end = item["timestamp"][0], item["timestamp"][1]
if end is None:
end = start
segments_result.append(Segment(
text=item["text"],
start=start,
end=end
))

elapsed_time = time.time() - start_time
return segments_result, elapsed_time

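The new loop replaces the former format_result helper and converts the chunks produced by the transformers ASR pipeline directly into Segment models. The None check matters because the pipeline can leave the final chunk open-ended. An illustrative shape of that output (the values are made up):

segments = {
    "text": " Hello world.",
    "chunks": [
        {"timestamp": (0.0, 2.4), "text": " Hello"},
        {"timestamp": (2.4, None), "text": " world."},  # end is None -> falls back to start
    ],
}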
41 changes: 24 additions & 17 deletions modules/whisper/whisper_Inference.py
@@ -30,7 +30,7 @@ def transcribe(self,
audio: Union[str, np.ndarray, torch.Tensor],
progress: gr.Progress = gr.Progress(),
*whisper_params,
) -> Tuple[List[dict], float]:
) -> Tuple[List[Segment], float]:
"""
transcribe method for whisper.
@@ -45,8 +45,8 @@
Returns
----------
segments_result: List[dict]
list of dicts that includes start, end timestamps and transcribed text
segments_result: List[Segment]
list of Segment objects, each with start and end timestamps and transcribed text
elapsed_time: float
elapsed time for transcription
"""
@@ -59,21 +59,28 @@
def progress_callback(progress_value):
progress(progress_value, desc="Transcribing..")

segments_result = self.model.transcribe(audio=audio,
language=params.lang,
verbose=False,
beam_size=params.beam_size,
logprob_threshold=params.log_prob_threshold,
no_speech_threshold=params.no_speech_threshold,
task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
fp16=True if params.compute_type == "float16" else False,
best_of=params.best_of,
patience=params.patience,
temperature=params.temperature,
compression_ratio_threshold=params.compression_ratio_threshold,
progress_callback=progress_callback,)["segments"]
elapsed_time = time.time() - start_time
result = self.model.transcribe(audio=audio,
language=params.lang,
verbose=False,
beam_size=params.beam_size,
logprob_threshold=params.log_prob_threshold,
no_speech_threshold=params.no_speech_threshold,
task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
fp16=True if params.compute_type == "float16" else False,
best_of=params.best_of,
patience=params.patience,
temperature=params.temperature,
compression_ratio_threshold=params.compression_ratio_threshold,
progress_callback=progress_callback,)["segments"]
segments_result = []
for segment in result:
segments_result.append(Segment(
start=segment["start"],
end=segment["end"],
text=segment["text"]
))

elapsed_time = time.time() - start_time
return segments_result, elapsed_time

def update_model(self,
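After this commit all three backends (faster-whisper, insanely-fast-whisper, and openai-whisper) normalize their native outputs into the shared Segment model before returning, so downstream consumers such as the diarizer and the subtitle writers no longer depend on which backend produced the result. The conversion loop above could equally be written as a comprehension; a style alternative, not part of the commit:

segments_result = [
    Segment(start=seg["start"], end=seg["end"], text=seg["text"])
    for seg in result
]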
