Skip to content

Commit

Permalink
Merge pull request #363 from jhj0517/feature/refactor-models
Browse files Browse the repository at this point in the history
Refactor data classes
  • Loading branch information
jhj0517 authored Oct 30, 2024
2 parents ffb268e + eec0c16 commit d4bc29b
Show file tree
Hide file tree
Showing 19 changed files with 821 additions and 687 deletions.
181 changes: 25 additions & 156 deletions app.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions modules/diarize/diarize_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speaker
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
transcript_segments = transcript_result["segments"]
for seg in transcript_segments:
seg = seg.dict()
# assign speaker to segment (if any)
diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
seg['start'])
Expand Down
24 changes: 16 additions & 8 deletions modules/diarize/diarizer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import os
import torch
from typing import List, Union, BinaryIO, Optional
from typing import List, Union, BinaryIO, Optional, Tuple
import numpy as np
import time
import logging

from modules.utils.paths import DIARIZATION_MODELS_DIR
from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
from modules.diarize.audio_loader import load_audio
from modules.whisper.data_classes import *


class Diarizer:
Expand All @@ -23,18 +24,18 @@ def __init__(self,

def run(self,
audio: Union[str, BinaryIO, np.ndarray],
transcribed_result: List[dict],
transcribed_result: List[Segment],
use_auth_token: str,
device: Optional[str] = None
):
) -> Tuple[List[Segment], float]:
"""
Diarize transcribed result as a post-processing
Parameters
----------
audio: Union[str, BinaryIO, np.ndarray]
Audio input. This can be file path or binary type.
transcribed_result: List[dict]
transcribed_result: List[Segment]
transcribed result through whisper.
use_auth_token: str
Huggingface token with READ permission. This is only needed the first time you download the model.
Expand All @@ -44,8 +45,8 @@ def run(self,
Returns
----------
segments_result: List[dict]
list of dicts that includes start, end timestamps and transcribed text
segments_result: List[Segment]
list of Segment that includes start, end timestamps and transcribed text
elapsed_time: float
elapsed time for running
"""
Expand All @@ -68,14 +69,21 @@ def run(self,
{"segments": transcribed_result}
)

segments_result = []
for segment in diarized_result["segments"]:
segment = segment.dict()
speaker = "None"
if "speaker" in segment:
speaker = segment["speaker"]
segment["text"] = speaker + "|" + segment["text"].strip()
diarized_text = speaker + "|" + segment["text"].strip()
segments_result.append(Segment(
start=segment["start"],
end=segment["end"],
text=diarized_text
))

elapsed_time = time.time() - start_time
return diarized_result["segments"], elapsed_time
return segments_result, elapsed_time

def update_pipe(self,
use_auth_token: str,
Expand Down
2 changes: 1 addition & 1 deletion modules/translation/translation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from datetime import datetime

import modules.translation.nllb_inference as nllb
from modules.whisper.whisper_parameter import *
from modules.whisper.data_classes import *
from modules.utils.subtitle_manager import *
from modules.utils.files_manager import load_yaml, save_yaml
from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
Expand Down
3 changes: 3 additions & 0 deletions modules/utils/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from gradio_i18n import Translate, gettext as _

AUTOMATIC_DETECTION = _("Automatic Detection")
GRADIO_NONE_STR = ""
GRADIO_NONE_NUMBER_MAX = 9999
GRADIO_NONE_NUMBER_MIN = 0
11 changes: 11 additions & 0 deletions modules/utils/subtitle_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re

from modules.whisper.data_classes import Segment


def timeformat_srt(time):
hours = time // 3600
Expand All @@ -23,6 +25,9 @@ def write_file(subtitle, output_file):


def get_srt(segments):
if segments and isinstance(segments[0], Segment):
segments = [seg.dict() for seg in segments]

output = ""
for i, segment in enumerate(segments):
output += f"{i + 1}\n"
Expand All @@ -34,6 +39,9 @@ def get_srt(segments):


def get_vtt(segments):
if segments and isinstance(segments[0], Segment):
segments = [seg.dict() for seg in segments]

output = "WEBVTT\n\n"
for i, segment in enumerate(segments):
output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
Expand All @@ -44,6 +52,9 @@ def get_vtt(segments):


def get_txt(segments):
if segments and isinstance(segments[0], Segment):
segments = [seg.dict() for seg in segments]

output = ""
for i, segment in enumerate(segments):
if segment['text'].startswith(' '):
Expand Down
11 changes: 6 additions & 5 deletions modules/vad/silero_vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from typing import BinaryIO, Union, List, Optional, Tuple
import warnings
import faster_whisper
from faster_whisper.transcribe import SpeechTimestampsMap, Segment
from modules.whisper.data_classes import *
from faster_whisper.transcribe import SpeechTimestampsMap
import gradio as gr


Expand Down Expand Up @@ -247,18 +248,18 @@ def format_timestamp(

def restore_speech_timestamps(
self,
segments: List[dict],
segments: List[Segment],
speech_chunks: List[dict],
sampling_rate: Optional[int] = None,
) -> List[dict]:
) -> List[Segment]:
if sampling_rate is None:
sampling_rate = self.sampling_rate

ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)

for segment in segments:
segment["start"] = ts_map.get_original_time(segment["start"])
segment["end"] = ts_map.get_original_time(segment["end"])
segment.start = ts_map.get_original_time(segment.start)
segment.start = ts_map.get_original_time(segment.start)

return segments

Loading

0 comments on commit d4bc29b

Please sign in to comment.