diff --git a/aana/models/pydantic/asr_output.py b/aana/models/pydantic/asr_output.py
index 600476d6..b5ca6b1c 100644
--- a/aana/models/pydantic/asr_output.py
+++ b/aana/models/pydantic/asr_output.py
@@ -13,6 +13,10 @@
 class Timestamp(BaseModel):
     """
     Pydantic schema for Timestamp.
+
+    Attributes:
+        start (float): Start time
+        end (float): End time
     """
 
     start: float = Field(ge=0.0, description="Start time")
@@ -27,6 +31,11 @@ class Config:
 class AsrWord(BaseModel):
     """
     Pydantic schema for Word from ASR model.
+
+    Attributes:
+        word (str): The word text
+        timestamp (Timestamp): Timestamp of the word
+        alignment_confidence (float): Alignment confidence of the word
     """
 
     word: str = Field(description="The word text")
@@ -55,6 +64,13 @@ class Config:
 class AsrSegment(BaseModel):
     """
     Pydantic schema for Segment from ASR model.
+
+    Attributes:
+        text (str): The text of the segment (transcript/translation)
+        timestamp (Timestamp): Timestamp of the segment
+        confidence (float): Confidence of the segment
+        no_speech_confidence (float): Chance of being a silence segment
+        words (Optional[List[AsrWord]]): List of words in the segment
     """
 
     text: str = Field(description="The text of the segment (transcript/translation)")
@@ -64,7 +80,7 @@ class AsrSegment(BaseModel):
         ge=0.0, le=1.0, description="Chance of being a silence segment"
     )
     words: Optional[List[AsrWord]] = Field(
-        description="List of words in the segment", default=None
+        description="List of words in the segment", default_factory=list
     )
 
     @classmethod
@@ -77,7 +93,7 @@ def from_whisper(cls, whisper_segment: WhisperSegment) -> "AsrSegment":
         if whisper_segment.words:
             words = [AsrWord.from_whisper(word) for word in whisper_segment.words]
         else:
-            words = None
+            words = []
 
         return cls(
             text=whisper_segment.text,
diff --git a/aana/tests/test_asr_output.py b/aana/tests/test_asr_output.py
new file mode 100644
index 00000000..593bf88d
--- /dev/null
+++ b/aana/tests/test_asr_output.py
@@ -0,0 +1,83 @@
+import numpy as np
+import pytest
+from faster_whisper.transcribe import (
+    Segment as WhisperSegment,
+    Word as WhisperWord,
+    TranscriptionInfo as WhisperTranscriptionInfo,
+)
+from aana.models.pydantic.asr_output import (
+    AsrSegment,
+    AsrTranscriptionInfo,
+    AsrWord,
+    Timestamp,
+)
+
+
+def test_asr_segment_from_whisper():
+    """
+    Test function for the AsrSegment class's from_whisper method.
+    """
+    whisper_segment = WhisperSegment(
+        id=0,
+        seek=0,
+        tokens=[],
+        temperature=0.0,
+        compression_ratio=0.0,
+        start=0.0,
+        end=1.0,
+        avg_logprob=-0.5,
+        no_speech_prob=0.1,
+        words=[],
+        text="hello world",
+    )
+
+    asr_segment = AsrSegment.from_whisper(whisper_segment)
+
+    assert asr_segment.text == "hello world"
+    assert asr_segment.timestamp == Timestamp(
+        start=whisper_segment.start, end=whisper_segment.end
+    )
+    assert asr_segment.confidence == np.exp(whisper_segment.avg_logprob)
+    assert asr_segment.no_speech_confidence == whisper_segment.no_speech_prob
+    assert asr_segment.words == []
+
+    word = WhisperWord(
+        word="hello",
+        start=0.0,
+        end=0.5,
+        probability=0.5,
+    )
+    whisper_segment = WhisperSegment(
+        id=0,
+        seek=0,
+        tokens=[],
+        temperature=0.0,
+        compression_ratio=0.0,
+        start=0.0,
+        end=1.0,
+        avg_logprob=-0.5,
+        no_speech_prob=0.1,
+        words=[word],
+        text="hello world",
+    )
+
+    asr_segment = AsrSegment.from_whisper(whisper_segment)
+    assert asr_segment.words == [AsrWord.from_whisper(word)]
+
+
+def test_asr_word_from_whisper():
+    """
+    Test function for the AsrWord class's from_whisper method.
+    """
+    word = WhisperWord(
+        word="hello",
+        start=0.0,
+        end=0.5,
+        probability=0.5,
+    )
+
+    asr_word = AsrWord.from_whisper(word)
+
+    assert asr_word.word == "hello"
+    assert asr_word.timestamp == Timestamp(start=word.start, end=word.end)
+    assert asr_word.alignment_confidence == word.probability
diff --git a/aana/tests/test_whisper_params.py b/aana/tests/test_whisper_params.py
new file mode 100644
index 00000000..84f1c8c6
--- /dev/null
+++ b/aana/tests/test_whisper_params.py
@@ -0,0 +1,70 @@
+import pytest
+from aana.models.pydantic.whisper_params import WhisperParams
+
+
+def test_whisper_params_default():
+    """
+    Test the default values of WhisperParams object.
+
+    Keeping the default parameters of a function or object is important
+    in case other code relies on them.
+
+    If you need to change the default parameters, think twice before doing so.
+    """
+    params = WhisperParams()
+
+    assert params.language is None
+    assert params.beam_size == 5
+    assert params.best_of == 5
+    assert params.temperature == (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
+    assert params.word_timestamps is False
+    assert params.vad_filter is False
+
+
+@pytest.mark.parametrize(
+    "language, beam_size, best_of, temperature, word_timestamps, vad_filter",
+    [
+        ("en", 5, 5, 0.5, True, True),
+        ("fr", 3, 3, 0.2, False, False),
+        (None, 1, 1, [0.8, 0.9], True, False),
+    ],
+)
+def test_whisper_params(
+    language, beam_size, best_of, temperature, word_timestamps, vad_filter
+):
+    """
+    Test function for the WhisperParams class with valid parameters.
+    """
+    params = WhisperParams(
+        language=language,
+        beam_size=beam_size,
+        best_of=best_of,
+        temperature=temperature,
+        word_timestamps=word_timestamps,
+        vad_filter=vad_filter,
+    )
+
+    assert params.language == language
+    assert params.beam_size == beam_size
+    assert params.best_of == best_of
+    assert params.temperature == temperature
+    assert params.word_timestamps == word_timestamps
+    assert params.vad_filter == vad_filter
+
+
+@pytest.mark.parametrize(
+    "temperature",
+    [
+        [-1.0, 0.5, 1.5],
+        [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0],
+        "invalid_temperature",
+        2,
+    ],
+)
+def test_whisper_params_invalid_temperature(temperature):
+    """
+    Test function to check if ValueError is raised
+    when invalid temperature is passed to WhisperParams constructor.
+    """
+    with pytest.raises(ValueError):
+        WhisperParams(temperature=temperature)