diff --git a/tests/test_vad.py b/tests/test_vad.py
index cb3dc05..eb0e35f 100644
--- a/tests/test_vad.py
+++ b/tests/test_vad.py
@@ -1,13 +1,13 @@
-from modules.utils.paths import *
-from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.data_classes import *
-from test_config import *
-from test_transcription import download_file, test_transcribe
-
 import gradio as gr
 import pytest
 import os
 
+from modules.whisper.data_classes import *
+from modules.vad.silero_vad import SileroVAD
+from test_config import *
+from test_transcription import download_file, test_transcribe
+from faster_whisper.vad import VadOptions, get_speech_timestamps
+
 
 @pytest.mark.parametrize(
     "whisper_type,vad_filter,bgm_separation,diarization",
@@ -24,3 +24,35 @@ def test_vad_pipeline(
     diarization: bool,
 ):
     test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
+
+
+@pytest.mark.parametrize(
+    "threshold,min_speech_duration_ms,min_silence_duration_ms",
+    [
+        (0.5, 250, 2000),
+    ]
+)
+def test_vad(
+    threshold: float,
+    min_speech_duration_ms: int,
+    min_silence_duration_ms: int
+):
+    audio_path_dir = os.path.join(WEBUI_DIR, "tests")
+    audio_path = os.path.join(audio_path_dir, "jfk.wav")
+
+    if not os.path.exists(audio_path):
+        download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
+
+    vad_model = SileroVAD()
+    vad_model.update_model()
+
+    audio, speech_chunks = vad_model.run(
+        audio=audio_path,
+        vad_parameters=VadOptions(
+            threshold=threshold,
+            min_silence_duration_ms=min_silence_duration_ms,
+            min_speech_duration_ms=min_speech_duration_ms
+        )
+    )
+
+    assert speech_chunks