Skip to content

Commit

Permalink
Tests for whisper params and ASR outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleksandr Movchan committed Nov 9, 2023
1 parent 39f1893 commit 2cfa0fd
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 2 deletions.
20 changes: 18 additions & 2 deletions aana/models/pydantic/asr_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
class Timestamp(BaseModel):
"""
Pydantic schema for Timestamp.
Attributes:
start (float): Start time
end (float): End time
"""

start: float = Field(ge=0.0, description="Start time")
Expand All @@ -27,6 +31,11 @@ class Config:
class AsrWord(BaseModel):
"""
Pydantic schema for Word from ASR model.
Attributes:
word (str): The word text
timestamp (Timestamp): Timestamp of the word
alignment_confidence (float): Alignment confidence of the word
"""

word: str = Field(description="The word text")
Expand Down Expand Up @@ -55,6 +64,13 @@ class Config:
class AsrSegment(BaseModel):
"""
Pydantic schema for Segment from ASR model.
Attributes:
text (str): The text of the segment (transcript/translation)
timestamp (Timestamp): Timestamp of the segment
confidence (float): Confidence of the segment
no_speech_confidence (float): Chance of being a silence segment
words (Optional[List[AsrWord]]): List of words in the segment
"""

text: str = Field(description="The text of the segment (transcript/translation)")
Expand All @@ -64,7 +80,7 @@ class AsrSegment(BaseModel):
ge=0.0, le=1.0, description="Chance of being a silence segment"
)
words: Optional[List[AsrWord]] = Field(
description="List of words in the segment", default=None
description="List of words in the segment", default_factory=list
)

@classmethod
Expand All @@ -77,7 +93,7 @@ def from_whisper(cls, whisper_segment: WhisperSegment) -> "AsrSegment":
if whisper_segment.words:
words = [AsrWord.from_whisper(word) for word in whisper_segment.words]
else:
words = None
words = []

return cls(
text=whisper_segment.text,
Expand Down
83 changes: 83 additions & 0 deletions aana/tests/test_asr_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import numpy as np
import pytest
from faster_whisper.transcribe import (
Segment as WhisperSegment,
Word as WhisperWord,
TranscriptionInfo as WhisperTranscriptionInfo,
)
from aana.models.pydantic.asr_output import (
AsrSegment,
AsrTranscriptionInfo,
AsrWord,
Timestamp,
)


def test_asr_segment_from_whisper():
"""
Test function for the AsrSegment class's from_whisper method.
"""
whisper_segment = WhisperSegment(
id=0,
seek=0,
tokens=[],
temperature=0.0,
compression_ratio=0.0,
start=0.0,
end=1.0,
avg_logprob=-0.5,
no_speech_prob=0.1,
words=[],
text="hello world",
)

asr_segment = AsrSegment.from_whisper(whisper_segment)

assert asr_segment.text == "hello world"
assert asr_segment.timestamp == Timestamp(
start=whisper_segment.start, end=whisper_segment.end
)
assert asr_segment.confidence == np.exp(whisper_segment.avg_logprob)
assert asr_segment.no_speech_confidence == whisper_segment.no_speech_prob
assert asr_segment.words == []

word = WhisperWord(
word="hello",
start=0.0,
end=0.5,
probability=0.5,
)
whisper_segment = WhisperSegment(
id=0,
seek=0,
tokens=[],
temperature=0.0,
compression_ratio=0.0,
start=0.0,
end=1.0,
avg_logprob=-0.5,
no_speech_prob=0.1,
words=[word],
text="hello world",
)

asr_segment = AsrSegment.from_whisper(whisper_segment)
assert asr_segment.words == [AsrWord.from_whisper(word)]


def test_asr_word_from_whisper():
"""
Test function for the AsrWord class's from_whisper method.
"""
word = WhisperWord(
word="hello",
start=0.0,
end=0.5,
probability=0.5,
)

asr_word = AsrWord.from_whisper(word)

assert asr_word.word == "hello"
assert asr_word.timestamp == Timestamp(start=word.start, end=word.end)
assert asr_word.alignment_confidence == word.probability
70 changes: 70 additions & 0 deletions aana/tests/test_whisper_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest
from aana.models.pydantic.whisper_params import WhisperParams


def test_whisper_params_default():
"""
Test the default values of WhisperParams object.
Keeping the default parameters of a function or object is important
in case other code relies on them.
If you need to change the default parameters, think twice before doing so.
"""
params = WhisperParams()

assert params.language is None
assert params.beam_size == 5
assert params.best_of == 5
assert params.temperature == (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
assert params.word_timestamps is False
assert params.vad_filter is False


@pytest.mark.parametrize(
"language, beam_size, best_of, temperature, word_timestamps, vad_filter",
[
("en", 5, 5, 0.5, True, True),
("fr", 3, 3, 0.2, False, False),
(None, 1, 1, [0.8, 0.9], True, False),
],
)
def test_whisper_params(
language, beam_size, best_of, temperature, word_timestamps, vad_filter
):
"""
Test function for the WhisperParams class with valid parameters.
"""
params = WhisperParams(
language=language,
beam_size=beam_size,
best_of=best_of,
temperature=temperature,
word_timestamps=word_timestamps,
vad_filter=vad_filter,
)

assert params.language == language
assert params.beam_size == beam_size
assert params.best_of == best_of
assert params.temperature == temperature
assert params.word_timestamps == word_timestamps
assert params.vad_filter == vad_filter


@pytest.mark.parametrize(
"temperature",
[
[-1.0, 0.5, 1.5],
[0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0],
"invalid_temperature",
2,
],
)
def test_whisper_params_invalid_temperature(temperature):
"""
Test function to check if ValueError is raised
when invalid temperature is passed to WhisperParams constructor.
"""
with pytest.raises(ValueError):
WhisperParams(temperature=temperature)

0 comments on commit 2cfa0fd

Please sign in to comment.