Skip to content

Commit

Permalink
Improve test names for better readability (#1198)
Browse files Browse the repository at this point in the history
  • Loading branch information
jayeshp19 authored Dec 9, 2024
1 parent 40c1879 commit c5897c1
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 48 deletions.
6 changes: 3 additions & 3 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_hashable_typeinfo():


LLMS: list[Callable[[], llm.LLM]] = [
lambda: openai.LLM(),
pytest.param(lambda: openai.LLM(), id="openai"),
# lambda: openai.beta.AssistantLLM(
# assistant_opts=openai.beta.AssistantOptions(
# create_options=openai.beta.AssistantCreateOptions(
Expand All @@ -89,8 +89,8 @@ def test_hashable_typeinfo():
# )
# )
# ),
lambda: anthropic.LLM(),
lambda: openai.LLM.with_vertex(),
pytest.param(lambda: anthropic.LLM(), id="anthropic"),
pytest.param(lambda: openai.LLM.with_vertex(), id="openai.with_vertex"),
]


Expand Down
67 changes: 40 additions & 27 deletions tests/test_stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,37 @@

import asyncio
import time
from itertools import product
from typing import Callable

import pytest
from livekit import agents
from livekit.agents import stt
from livekit.plugins import assemblyai, azure, deepgram, fal, google, openai, silero

from .utils import make_test_speech, wer

SAMPLE_RATES = [24000, 44100] # test multiple input sample rates
WER_THRESHOLD = 0.2
RECOGNIZE_STT = [
lambda: deepgram.STT(),
lambda: google.STT(),
lambda: google.STT(
languages=["en-AU"],
model="chirp_2",
spoken_punctuation=False,
location="us-central1",
RECOGNIZE_STT: list[Callable[[], stt.STT]] = [
pytest.param(lambda: deepgram.STT(), id="deepgram"),
pytest.param(lambda: google.STT(), id="google"),
pytest.param(
lambda: google.STT(
languages=["en-AU"],
model="chirp_2",
spoken_punctuation=False,
location="us-central1",
),
id="google.chirp_2",
),
lambda: openai.STT(),
lambda: fal.WizperSTT(),
pytest.param(lambda: openai.STT(), id="openai"),
pytest.param(lambda: fal.WizperSTT(), id="fal"),
]


@pytest.mark.usefixtures("job_process")
@pytest.mark.parametrize(
"stt_factory, sample_rate", product(RECOGNIZE_STT, SAMPLE_RATES)
)
@pytest.mark.parametrize("stt_factory", RECOGNIZE_STT)
@pytest.mark.parametrize("sample_rate", SAMPLE_RATES)
async def test_recognize(stt_factory, sample_rate):
async with stt_factory() as stt:
frames, transcript = make_test_speech(sample_rate=sample_rate)
Expand All @@ -47,24 +50,34 @@ async def test_recognize(stt_factory, sample_rate):


STREAM_VAD = silero.VAD.load(min_silence_duration=0.75)
STREAM_STT = [
lambda: assemblyai.STT(),
lambda: deepgram.STT(),
lambda: google.STT(),
lambda: agents.stt.StreamAdapter(stt=openai.STT(), vad=STREAM_VAD),
lambda: agents.stt.StreamAdapter(stt=openai.STT.with_groq(), vad=STREAM_VAD),
lambda: google.STT(
languages=["en-AU"],
model="chirp_2",
spoken_punctuation=False,
location="us-central1",
STREAM_STT: list[Callable[[], stt.STT]] = [
pytest.param(lambda: assemblyai.STT(), id="assemblyai"),
pytest.param(lambda: deepgram.STT(), id="deepgram"),
pytest.param(lambda: google.STT(), id="google"),
pytest.param(
lambda: agents.stt.StreamAdapter(stt=openai.STT(), vad=STREAM_VAD),
id="openai.stream",
),
lambda: azure.STT(),
pytest.param(
lambda: agents.stt.StreamAdapter(stt=openai.STT.with_groq(), vad=STREAM_VAD),
id="openai.with_groq.stream",
),
pytest.param(
lambda: google.STT(
languages=["en-AU"],
model="chirp_2",
spoken_punctuation=False,
location="us-central1",
),
id="google.chirp_2",
),
pytest.param(lambda: azure.STT(), id="azure"),
]


@pytest.mark.usefixtures("job_process")
@pytest.mark.parametrize("stt_factory, sample_rate", product(STREAM_STT, SAMPLE_RATES))
@pytest.mark.parametrize("stt_factory", STREAM_STT)
@pytest.mark.parametrize("sample_rate", SAMPLE_RATES)
async def test_stream(stt_factory, sample_rate):
stt = stt_factory()

Expand Down
50 changes: 32 additions & 18 deletions tests/test_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
"""

import dataclasses
from typing import Callable

import pytest
from livekit import agents
from livekit.agents import APIConnectionError, tokenize
from livekit.agents import APIConnectionError, tokenize, tts
from livekit.agents.utils import AudioBuffer, merge_frames
from livekit.plugins import azure, cartesia, elevenlabs, google, openai

Expand All @@ -33,13 +34,15 @@ async def _assert_valid_synthesized_audio(
), "num channels should be the same"


SYNTHESIZE_TTS = [
lambda: elevenlabs.TTS(),
lambda: elevenlabs.TTS(encoding="pcm_44100"),
lambda: openai.TTS(),
lambda: google.TTS(),
lambda: azure.TTS(),
lambda: cartesia.TTS(),
SYNTHESIZE_TTS: list[Callable[[], tts.TTS]] = [
pytest.param(lambda: elevenlabs.TTS(), id="elevenlabs"),
pytest.param(
lambda: elevenlabs.TTS(encoding="pcm_44100"), id="elevenlabs.pcm_44100"
),
pytest.param(lambda: openai.TTS(), id="openai"),
pytest.param(lambda: google.TTS(), id="google"),
pytest.param(lambda: azure.TTS(), id="azure"),
pytest.param(lambda: cartesia.TTS(), id="cartesia"),
]


Expand All @@ -60,18 +63,29 @@ async def test_synthesize(tts_factory):


STREAM_SENT_TOKENIZER = tokenize.basic.SentenceTokenizer(min_sentence_len=20)
STREAM_TTS = [
lambda: elevenlabs.TTS(),
lambda: elevenlabs.TTS(encoding="pcm_44100"),
lambda: cartesia.TTS(),
lambda: agents.tts.StreamAdapter(
tts=openai.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER
STREAM_TTS: list[Callable[[], tts.TTS]] = [
pytest.param(lambda: elevenlabs.TTS(), id="elevenlabs"),
pytest.param(
lambda: elevenlabs.TTS(encoding="pcm_44100"), id="elevenlabs.pcm_44100"
),
pytest.param(lambda: cartesia.TTS(), id="cartesia"),
pytest.param(
lambda: agents.tts.StreamAdapter(
tts=openai.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER
),
id="openai.stream",
),
lambda: agents.tts.StreamAdapter(
tts=google.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER
pytest.param(
lambda: agents.tts.StreamAdapter(
tts=google.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER
),
id="google.stream",
),
lambda: agents.tts.StreamAdapter(
tts=azure.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER
pytest.param(
lambda: agents.tts.StreamAdapter(
tts=azure.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER
),
id="azure.stream",
),
]

Expand Down

0 comments on commit c5897c1

Please sign in to comment.