-
Notifications
You must be signed in to change notification settings - Fork 52
/
asr_evaluation.py
32 lines (30 loc) · 1.21 KB
/
asr_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os
import logging
import deepspeech
import jiwer
import soundfile as sf
import numpy as np
from unidecode import unidecode
import librosa
import tqdm
def evaluate(testset, audio_directory,
             model_path='deepspeech-0.7.0-models.pbmm',
             scorer_path='deepspeech-0.7.0-models.scorer') -> float:
    """Transcribe generated audio with DeepSpeech and report the word error rate.

    For the i-th entry of ``testset`` the file
    ``<audio_directory>/example_output_<i>.wav`` is read, converted to mono
    16-bit PCM at the sample rate the model reports, and transcribed.  The
    transcripts are compared against the (ASCII-folded) reference texts after
    punctuation removal and lower-casing.  Targets, predictions and the
    resulting WER are logged at INFO level on the root logger.

    Args:
        testset: Iterable of datapoints; each must support ``datapoint['text']``
            returning the reference transcript.
        audio_directory: Directory containing ``example_output_<i>.wav`` files,
            numbered in the iteration order of ``testset``.
        model_path: Path to the DeepSpeech acoustic model file.
        scorer_path: Path to the DeepSpeech external scorer file.

    Returns:
        The word error rate computed by ``jiwer.wer`` over the whole test set.
    """
    model = deepspeech.Model(model_path)
    model.enableExternalScorer(scorer_path)
    # Ask the model once which rate it expects instead of asserting a
    # hard-coded 16 kHz on every iteration.
    model_rate = model.sampleRate()

    int16_info = np.iinfo(np.int16)
    predictions = []
    targets = []
    progress = tqdm.tqdm(testset, 'Evaluate outputs', disable=None)
    for i, datapoint in enumerate(progress):
        wav_path = os.path.join(audio_directory, f'example_output_{i}.wav')
        audio, rate = sf.read(wav_path)

        # Multi-channel files come back as (frames, channels); the model
        # consumes a 1-D buffer, so downmix before any further processing.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)

        if rate != model_rate:
            audio = librosa.resample(audio, orig_sr=rate, target_sr=model_rate)

        # Scale float samples to int16.  Clip first: a sample at +1.0 (or a
        # resampling overshoot) would otherwise exceed the int16 range and
        # wrap around to a large negative value.
        scaled = np.clip(audio * 2 ** 15, int16_info.min, int16_info.max)
        audio_int16 = scaled.astype(np.int16)

        predictions.append(model.stt(audio_int16))
        targets.append(unidecode(datapoint['text']))

    # Normalise both sides identically so that casing and punctuation do not
    # count as word errors.
    transformation = jiwer.Compose([jiwer.RemovePunctuation(), jiwer.ToLowerCase()])
    targets = transformation(targets)
    predictions = transformation(predictions)

    wer = jiwer.wer(targets, predictions)
    logging.info('targets: %s', targets)
    logging.info('predictions: %s', predictions)
    logging.info('wer: %s', wer)
    return wer