-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tests for whisper params and ASR outputs
- Loading branch information
Aleksandr Movchan
committed
Nov 9, 2023
1 parent
39f1893
commit 2cfa0fd
Showing
3 changed files
with
171 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import numpy as np | ||
import pytest | ||
from faster_whisper.transcribe import ( | ||
Segment as WhisperSegment, | ||
Word as WhisperWord, | ||
TranscriptionInfo as WhisperTranscriptionInfo, | ||
) | ||
from aana.models.pydantic.asr_output import ( | ||
AsrSegment, | ||
AsrTranscriptionInfo, | ||
AsrWord, | ||
Timestamp, | ||
) | ||
|
||
|
||
def test_asr_segment_from_whisper(): | ||
""" | ||
Test function for the AsrSegment class's from_whisper method. | ||
""" | ||
whisper_segment = WhisperSegment( | ||
id=0, | ||
seek=0, | ||
tokens=[], | ||
temperature=0.0, | ||
compression_ratio=0.0, | ||
start=0.0, | ||
end=1.0, | ||
avg_logprob=-0.5, | ||
no_speech_prob=0.1, | ||
words=[], | ||
text="hello world", | ||
) | ||
|
||
asr_segment = AsrSegment.from_whisper(whisper_segment) | ||
|
||
assert asr_segment.text == "hello world" | ||
assert asr_segment.timestamp == Timestamp( | ||
start=whisper_segment.start, end=whisper_segment.end | ||
) | ||
assert asr_segment.confidence == np.exp(whisper_segment.avg_logprob) | ||
assert asr_segment.no_speech_confidence == whisper_segment.no_speech_prob | ||
assert asr_segment.words == [] | ||
|
||
word = WhisperWord( | ||
word="hello", | ||
start=0.0, | ||
end=0.5, | ||
probability=0.5, | ||
) | ||
whisper_segment = WhisperSegment( | ||
id=0, | ||
seek=0, | ||
tokens=[], | ||
temperature=0.0, | ||
compression_ratio=0.0, | ||
start=0.0, | ||
end=1.0, | ||
avg_logprob=-0.5, | ||
no_speech_prob=0.1, | ||
words=[word], | ||
text="hello world", | ||
) | ||
|
||
asr_segment = AsrSegment.from_whisper(whisper_segment) | ||
assert asr_segment.words == [AsrWord.from_whisper(word)] | ||
|
||
|
||
def test_asr_word_from_whisper(): | ||
""" | ||
Test function for the AsrWord class's from_whisper method. | ||
""" | ||
word = WhisperWord( | ||
word="hello", | ||
start=0.0, | ||
end=0.5, | ||
probability=0.5, | ||
) | ||
|
||
asr_word = AsrWord.from_whisper(word) | ||
|
||
assert asr_word.word == "hello" | ||
assert asr_word.timestamp == Timestamp(start=word.start, end=word.end) | ||
assert asr_word.alignment_confidence == word.probability |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import pytest | ||
from aana.models.pydantic.whisper_params import WhisperParams | ||
|
||
|
||
def test_whisper_params_default(): | ||
""" | ||
Test the default values of WhisperParams object. | ||
Keeping the default parameters of a function or object is important | ||
in case other code relies on them. | ||
If you need to change the default parameters, think twice before doing so. | ||
""" | ||
params = WhisperParams() | ||
|
||
assert params.language is None | ||
assert params.beam_size == 5 | ||
assert params.best_of == 5 | ||
assert params.temperature == (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) | ||
assert params.word_timestamps is False | ||
assert params.vad_filter is False | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"language, beam_size, best_of, temperature, word_timestamps, vad_filter", | ||
[ | ||
("en", 5, 5, 0.5, True, True), | ||
("fr", 3, 3, 0.2, False, False), | ||
(None, 1, 1, [0.8, 0.9], True, False), | ||
], | ||
) | ||
def test_whisper_params( | ||
language, beam_size, best_of, temperature, word_timestamps, vad_filter | ||
): | ||
""" | ||
Test function for the WhisperParams class with valid parameters. | ||
""" | ||
params = WhisperParams( | ||
language=language, | ||
beam_size=beam_size, | ||
best_of=best_of, | ||
temperature=temperature, | ||
word_timestamps=word_timestamps, | ||
vad_filter=vad_filter, | ||
) | ||
|
||
assert params.language == language | ||
assert params.beam_size == beam_size | ||
assert params.best_of == best_of | ||
assert params.temperature == temperature | ||
assert params.word_timestamps == word_timestamps | ||
assert params.vad_filter == vad_filter | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"temperature", | ||
[ | ||
[-1.0, 0.5, 1.5], | ||
[0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0], | ||
"invalid_temperature", | ||
2, | ||
], | ||
) | ||
def test_whisper_params_invalid_temperature(temperature): | ||
""" | ||
Test function to check if ValueError is raised | ||
when invalid temperature is passed to WhisperParams constructor. | ||
""" | ||
with pytest.raises(ValueError): | ||
WhisperParams(temperature=temperature) |