Skip to content

Commit

Permalink
Merge pull request #375 from jhj0517/fix/improve-test
Browse files Browse the repository at this point in the history
Improve test based on WER
  • Loading branch information
jhj0517 authored Nov 2, 2024
2 parents f12a40c + c7bfcf2 commit 50380bc
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y git ffmpeg

- name: Install dependencies
run: pip install -r requirements.txt pytest
run: pip install -r requirements.txt pytest jiwer

- name: Run test
run: python -m pytest -rs tests
12 changes: 6 additions & 6 deletions modules/whisper/base_transcription_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def transcribe_file(self,
add_timestamp: bool = True,
progress=gr.Progress(),
*pipeline_params,
) -> list:
) -> Tuple[str, List]:
"""
Write subtitle file from Files
Expand Down Expand Up @@ -250,7 +250,7 @@ def transcribe_file(self,
result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
result_file_path = [info['path'] for info in files_info.values()]

return [result_str, result_file_path]
return result_str, result_file_path

except Exception as e:
print(f"Error transcribing file: {e}")
Expand All @@ -264,7 +264,7 @@ def transcribe_mic(self,
add_timestamp: bool = True,
progress=gr.Progress(),
*pipeline_params,
) -> list:
) -> Tuple[str, str]:
"""
Write subtitle file from microphone
Expand Down Expand Up @@ -314,7 +314,7 @@ def transcribe_mic(self,
)

result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
return [result_str, file_path]
return result_str, file_path
except Exception as e:
print(f"Error transcribing mic: {e}")
raise
Expand All @@ -327,7 +327,7 @@ def transcribe_youtube(self,
add_timestamp: bool = True,
progress=gr.Progress(),
*pipeline_params,
) -> list:
) -> Tuple[str, str]:
"""
Write subtitle file from Youtube
Expand Down Expand Up @@ -385,7 +385,7 @@ def transcribe_youtube(self,
if os.path.exists(audio):
os.remove(audio)

return [result_str, file_path]
return result_str, file_path

except Exception as e:
print(f"Error transcribing youtube: {e}")
Expand Down
12 changes: 8 additions & 4 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import functools
import jiwer
import os
import torch

from modules.utils.paths import *
from modules.utils.youtube_manager import *

import os
import torch

TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
TEST_ANSWER = "And so my fellow Americans ask not what your country can do for you ask what you can do for your country"
TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer"
TEST_WHISPER_MODEL = "tiny.en"
TEST_WHISPER_MODEL = "tiny"
TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4"
TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M"
TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
Expand All @@ -34,3 +35,6 @@ def is_pytube_detected_bot(url: str = TEST_YOUTUBE_URL):
print(f"Pytube has detected as a bot: {e}")
return True


def calculate_wer(answer, prediction):
    """Return the word error rate of a prediction against the expected answer.

    Thin wrapper around ``jiwer.wer``; ``answer`` is passed as the first
    positional argument and ``prediction`` as the second.

    Args:
        answer: The expected (ground-truth) transcript.
        prediction: The transcript produced by the model under test.

    Returns:
        Whatever ``jiwer.wer`` returns for the given pair.
    """
    word_error_rate = jiwer.wer(answer, prediction)
    return word_error_rate
22 changes: 13 additions & 9 deletions tests/test_transcription.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.data_classes import *
from modules.utils.subtitle_manager import read_file
from modules.utils.paths import WEBUI_DIR
from test_config import *

Expand Down Expand Up @@ -28,6 +29,10 @@ def test_transcribe(
if not os.path.exists(audio_path):
download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)

answer = TEST_ANSWER
if diarization:
answer = "SPEAKER_00|"+TEST_ANSWER

whisper_inferencer = WhisperFactory.create_whisper_inference(
whisper_type=whisper_type,
)
Expand All @@ -54,38 +59,37 @@ def test_transcribe(
),
).to_list()

subtitle_str, file_path = whisper_inferencer.transcribe_file(
subtitle_str, file_paths = whisper_inferencer.transcribe_file(
[audio_path],
None,
"SRT",
False,
gr.Progress(),
*hparams,
)

assert isinstance(subtitle_str, str) and subtitle_str
assert isinstance(file_path[0], str) and file_path
subtitle = read_file(file_paths[0]).split("\n")
assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1

if not is_pytube_detected_bot():
whisper_inferencer.transcribe_youtube(
subtitle_str, file_path = whisper_inferencer.transcribe_youtube(
TEST_YOUTUBE_URL,
"SRT",
False,
gr.Progress(),
*hparams,
)
assert isinstance(subtitle_str, str) and subtitle_str
assert isinstance(file_path[0], str) and file_path
assert os.path.exists(file_path)

whisper_inferencer.transcribe_mic(
subtitle_str, file_path = whisper_inferencer.transcribe_mic(
audio_path,
"SRT",
False,
gr.Progress(),
*hparams,
)
assert isinstance(subtitle_str, str) and subtitle_str
assert isinstance(file_path[0], str) and file_path
subtitle = read_file(file_path).split("\n")
assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1


def download_file(url, save_dir):
Expand Down

0 comments on commit 50380bc

Please sign in to comment.