From c5897c1b79f6d57a72cbf4c92c87a5bd85ad3e5d Mon Sep 17 00:00:00 2001 From: Jayesh Parmar <60539217+jayeshp19@users.noreply.github.com> Date: Mon, 9 Dec 2024 21:27:21 +0530 Subject: [PATCH] Improve test names for better readability (#1198) --- tests/test_llm.py | 6 ++--- tests/test_stt.py | 67 ++++++++++++++++++++++++++++------------------- tests/test_tts.py | 50 ++++++++++++++++++++++------------- 3 files changed, 75 insertions(+), 48 deletions(-) diff --git a/tests/test_llm.py b/tests/test_llm.py index f5ebe5f00..2744ba404 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -79,7 +79,7 @@ def test_hashable_typeinfo(): LLMS: list[Callable[[], llm.LLM]] = [ - lambda: openai.LLM(), + pytest.param(lambda: openai.LLM(), id="openai"), # lambda: openai.beta.AssistantLLM( # assistant_opts=openai.beta.AssistantOptions( # create_options=openai.beta.AssistantCreateOptions( @@ -89,8 +89,8 @@ def test_hashable_typeinfo(): # ) # ) # ), - lambda: anthropic.LLM(), - lambda: openai.LLM.with_vertex(), + pytest.param(lambda: anthropic.LLM(), id="anthropic"), + pytest.param(lambda: openai.LLM.with_vertex(), id="openai.with_vertex"), ] diff --git a/tests/test_stt.py b/tests/test_stt.py index 9f9377340..f9e52b3d8 100644 --- a/tests/test_stt.py +++ b/tests/test_stt.py @@ -4,34 +4,37 @@ import asyncio import time -from itertools import product +from typing import Callable import pytest from livekit import agents +from livekit.agents import stt from livekit.plugins import assemblyai, azure, deepgram, fal, google, openai, silero from .utils import make_test_speech, wer SAMPLE_RATES = [24000, 44100] # test multiple input sample rates WER_THRESHOLD = 0.2 -RECOGNIZE_STT = [ - lambda: deepgram.STT(), - lambda: google.STT(), - lambda: google.STT( - languages=["en-AU"], - model="chirp_2", - spoken_punctuation=False, - location="us-central1", +RECOGNIZE_STT: list[Callable[[], stt.STT]] = [ + pytest.param(lambda: deepgram.STT(), id="deepgram"), + pytest.param(lambda: google.STT(), id="google"), + pytest.param( + lambda: google.STT( + languages=["en-AU"], + model="chirp_2", + spoken_punctuation=False, + location="us-central1", + ), + id="google.chirp_2", ), - lambda: openai.STT(), - lambda: fal.WizperSTT(), + pytest.param(lambda: openai.STT(), id="openai"), + pytest.param(lambda: fal.WizperSTT(), id="fal"), ] @pytest.mark.usefixtures("job_process") -@pytest.mark.parametrize( - "stt_factory, sample_rate", product(RECOGNIZE_STT, SAMPLE_RATES) -) +@pytest.mark.parametrize("stt_factory", RECOGNIZE_STT) +@pytest.mark.parametrize("sample_rate", SAMPLE_RATES) async def test_recognize(stt_factory, sample_rate): async with stt_factory() as stt: frames, transcript = make_test_speech(sample_rate=sample_rate) @@ -47,24 +50,34 @@ async def test_recognize(stt_factory, sample_rate): STREAM_VAD = silero.VAD.load(min_silence_duration=0.75) -STREAM_STT = [ - lambda: assemblyai.STT(), - lambda: deepgram.STT(), - lambda: google.STT(), - lambda: agents.stt.StreamAdapter(stt=openai.STT(), vad=STREAM_VAD), - lambda: agents.stt.StreamAdapter(stt=openai.STT.with_groq(), vad=STREAM_VAD), - lambda: google.STT( - languages=["en-AU"], - model="chirp_2", - spoken_punctuation=False, - location="us-central1", +STREAM_STT: list[Callable[[], stt.STT]] = [ + pytest.param(lambda: assemblyai.STT(), id="assemblyai"), + pytest.param(lambda: deepgram.STT(), id="deepgram"), + pytest.param(lambda: google.STT(), id="google"), + pytest.param( + lambda: agents.stt.StreamAdapter(stt=openai.STT(), vad=STREAM_VAD), + id="openai.stream", ), - lambda: azure.STT(), + pytest.param( + lambda: agents.stt.StreamAdapter(stt=openai.STT.with_groq(), vad=STREAM_VAD), + id="openai.with_groq.stream", + ), + pytest.param( + lambda: google.STT( + languages=["en-AU"], + model="chirp_2", + spoken_punctuation=False, + location="us-central1", + ), + id="google.chirp_2", + ), + pytest.param(lambda: azure.STT(), id="azure"), ] @pytest.mark.usefixtures("job_process") -@pytest.mark.parametrize("stt_factory, sample_rate", product(STREAM_STT, SAMPLE_RATES)) +@pytest.mark.parametrize("stt_factory", STREAM_STT) +@pytest.mark.parametrize("sample_rate", SAMPLE_RATES) async def test_stream(stt_factory, sample_rate): stt = stt_factory() diff --git a/tests/test_tts.py b/tests/test_tts.py index 2461d0c99..cf43b5605 100644 --- a/tests/test_tts.py +++ b/tests/test_tts.py @@ -4,10 +4,11 @@ """ import dataclasses +from typing import Callable import pytest from livekit import agents -from livekit.agents import APIConnectionError, tokenize +from livekit.agents import APIConnectionError, tokenize, tts from livekit.agents.utils import AudioBuffer, merge_frames from livekit.plugins import azure, cartesia, elevenlabs, google, openai @@ -33,13 +34,15 @@ async def _assert_valid_synthesized_audio( ), "num channels should be the same" -SYNTHESIZE_TTS = [ - lambda: elevenlabs.TTS(), - lambda: elevenlabs.TTS(encoding="pcm_44100"), - lambda: openai.TTS(), - lambda: google.TTS(), - lambda: azure.TTS(), - lambda: cartesia.TTS(), +SYNTHESIZE_TTS: list[Callable[[], tts.TTS]] = [ + pytest.param(lambda: elevenlabs.TTS(), id="elevenlabs"), + pytest.param( + lambda: elevenlabs.TTS(encoding="pcm_44100"), id="elevenlabs.pcm_44100" + ), + pytest.param(lambda: openai.TTS(), id="openai"), + pytest.param(lambda: google.TTS(), id="google"), + pytest.param(lambda: azure.TTS(), id="azure"), + pytest.param(lambda: cartesia.TTS(), id="cartesia"), ] @@ -60,18 +63,29 @@ async def test_synthesize(tts_factory): STREAM_SENT_TOKENIZER = tokenize.basic.SentenceTokenizer(min_sentence_len=20) -STREAM_TTS = [ - lambda: elevenlabs.TTS(), - lambda: elevenlabs.TTS(encoding="pcm_44100"), - lambda: cartesia.TTS(), - lambda: agents.tts.StreamAdapter( - tts=openai.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER +STREAM_TTS: list[Callable[[], tts.TTS]] = [ + pytest.param(lambda: elevenlabs.TTS(), id="elevenlabs"), + pytest.param( + lambda: elevenlabs.TTS(encoding="pcm_44100"), id="elevenlabs.pcm_44100" + ), + pytest.param(lambda: cartesia.TTS(), id="cartesia"), + pytest.param( + lambda: agents.tts.StreamAdapter( + tts=openai.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER + ), + id="openai.stream", ), - lambda: agents.tts.StreamAdapter( - tts=google.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER + pytest.param( + lambda: agents.tts.StreamAdapter( + tts=google.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER + ), + id="google.stream", ), - lambda: agents.tts.StreamAdapter( - tts=azure.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER + pytest.param( + lambda: agents.tts.StreamAdapter( + tts=azure.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER + ), + id="azure.stream", ), ]