diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 076f0fc6..571558c3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,7 +37,7 @@ jobs: run: sudo apt-get update && sudo apt-get install -y git ffmpeg - name: Install dependencies - run: pip install -r requirements.txt pytest + run: pip install -r requirements.txt pytest jiwer - name: Run test run: python -m pytest -rs tests \ No newline at end of file diff --git a/modules/whisper/base_transcription_pipeline.py b/modules/whisper/base_transcription_pipeline.py index 35605f54..4882abd8 100644 --- a/modules/whisper/base_transcription_pipeline.py +++ b/modules/whisper/base_transcription_pipeline.py @@ -179,7 +179,7 @@ def transcribe_file(self, add_timestamp: bool = True, progress=gr.Progress(), *pipeline_params, - ) -> list: + ) -> Tuple[str, List]: """ Write subtitle file from Files @@ -250,7 +250,7 @@ def transcribe_file(self, result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}" result_file_path = [info['path'] for info in files_info.values()] - return [result_str, result_file_path] + return result_str, result_file_path except Exception as e: print(f"Error transcribing file: {e}") @@ -264,7 +264,7 @@ def transcribe_mic(self, add_timestamp: bool = True, progress=gr.Progress(), *pipeline_params, - ) -> list: + ) -> Tuple[str, str]: """ Write subtitle file from microphone @@ -314,7 +314,7 @@ def transcribe_mic(self, ) result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}" - return [result_str, file_path] + return result_str, file_path except Exception as e: print(f"Error transcribing mic: {e}") raise @@ -327,7 +327,7 @@ def transcribe_youtube(self, add_timestamp: bool = True, progress=gr.Progress(), *pipeline_params, - ) -> list: + ) -> Tuple[str, str]: """ Write subtitle file from Youtube @@ -385,7 +385,7 @@ def transcribe_youtube(self, if os.path.exists(audio): os.remove(audio) - return [result_str, file_path] + return result_str, file_path except Exception as e: print(f"Error transcribing youtube: {e}") diff --git a/tests/test_config.py b/tests/test_config.py index ba52892d..f82e4f1a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,15 +1,16 @@ import functools +import jiwer +import os +import torch from modules.utils.paths import * from modules.utils.youtube_manager import * -import os -import torch - TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav" TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav") +TEST_ANSWER = "And so my fellow Americans ask not what your country can do for you ask what you can do for your country" TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer" -TEST_WHISPER_MODEL = "tiny.en" +TEST_WHISPER_MODEL = "tiny" TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4" TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M" TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt") @@ -34,3 +35,6 @@ def is_pytube_detected_bot(url: str = TEST_YOUTUBE_URL): print(f"Pytube has detected as a bot: {e}") return True + +def calculate_wer(answer, prediction): + return jiwer.wer(answer, prediction) diff --git a/tests/test_transcription.py b/tests/test_transcription.py index e9591d03..bc5267c4 100644 --- a/tests/test_transcription.py +++ b/tests/test_transcription.py @@ -1,5 +1,6 @@ from modules.whisper.whisper_factory import WhisperFactory from modules.whisper.data_classes import * +from modules.utils.subtitle_manager import read_file from modules.utils.paths import WEBUI_DIR from test_config import * @@ -28,6 +29,10 @@ def test_transcribe( if not os.path.exists(audio_path): download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir) + answer = TEST_ANSWER + if diarization: + answer = "SPEAKER_00|"+TEST_ANSWER + whisper_inferencer = WhisperFactory.create_whisper_inference( whisper_type=whisper_type, ) @@ -54,7 +59,7 @@ def test_transcribe( ), ).to_list() - subtitle_str, file_path = whisper_inferencer.transcribe_file( + subtitle_str, file_paths = whisper_inferencer.transcribe_file( [audio_path], None, "SRT", @@ -62,12 +67,11 @@ def test_transcribe( gr.Progress(), *hparams, ) - - assert isinstance(subtitle_str, str) and subtitle_str - assert isinstance(file_path[0], str) and file_path + subtitle = read_file(file_paths[0]).split("\n") + assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1 if not is_pytube_detected_bot(): - whisper_inferencer.transcribe_youtube( + subtitle_str, file_path = whisper_inferencer.transcribe_youtube( TEST_YOUTUBE_URL, "SRT", False, @@ -75,17 +79,17 @@ def test_transcribe( *hparams, ) assert isinstance(subtitle_str, str) and subtitle_str - assert isinstance(file_path[0], str) and file_path + assert os.path.exists(file_path) - whisper_inferencer.transcribe_mic( + subtitle_str, file_path = whisper_inferencer.transcribe_mic( audio_path, "SRT", False, gr.Progress(), *hparams, ) - assert isinstance(subtitle_str, str) and subtitle_str - assert isinstance(file_path[0], str) and file_path + subtitle = read_file(file_path).split("\n") + assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1 def download_file(url, save_dir):