-
Notifications
You must be signed in to change notification settings - Fork 24
/
demo.py
90 lines (78 loc) · 3.02 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
'''
Created on 29/08/2018
@author: wblgers
'''
from __future__ import print_function
import warnings
import os
from scikits.talkbox.features import mfcc
from scipy.io import wavfile
from hmmlearn import hmm
import numpy as np
# Silence every warning category for the whole process. This runs at import
# time, so it also affects any code that imports this module.
# NOTE(review): presumably meant to hide deprecation noise from the audio /
# HMM libraries -- confirm; a narrower filter (by category/module) would
# avoid hiding genuine problems.
warnings.filterwarnings('ignore')
def extract_mfcc(full_audio_path):
    """Read one WAV file and return its MFCC feature matrix.

    The audio is analysed with a window of 30 ms (expressed in samples at
    the file's own sample rate), and 12 cepstral coefficients are requested.

    Args:
        full_audio_path: Path of the ``.wav`` file to read.

    Returns:
        The first element of ``scikits.talkbox``'s ``mfcc`` result
        (presumably the cepstra, one row per frame -- confirm against the
        talkbox documentation).
    """
    rate, samples = wavfile.read(full_audio_path)
    # Window length in samples for a 0.03 s analysis window.
    window_len = int(rate * 0.03)
    ceps = mfcc(samples, nwin=window_len, fs=rate, nceps=12)[0]
    return ceps
def buildDataSet(dir):
    """Extract MFCC features for every ``.wav`` file in *dir*, grouped by label.

    File names must look like ``<prefix>_<label>[_...].wav``: the label is the
    second ``_``-separated field of the part of the name before the first
    ``.``.

    Args:
        dir: Directory to scan (non-recursively). A trailing path separator is
            optional. (The name shadows the ``dir`` builtin but is kept for
            backward compatibility with keyword callers.)

    Returns:
        dict mapping each label string to a list of feature arrays (one per
        file, as returned by ``extract_mfcc``), in sorted file-name order.

    Raises:
        ValueError: if a ``.wav`` file name has no ``_<label>`` field.
    """
    # Sort so the per-label sample order is deterministic; os.listdir order is
    # arbitrary and callers may rely on which sample comes first.
    file_names = sorted(f for f in os.listdir(dir)
                        if os.path.splitext(f)[1] == '.wav')
    dataset = {}
    for file_name in file_names:
        fields = file_name.split('.')[0].split('_')
        if len(fields) < 2:
            raise ValueError("cannot parse a label from %r; expected "
                             "'<prefix>_<label>.wav'" % (file_name,))
        label = fields[1]
        # os.path.join works whether or not `dir` ends with a separator.
        feature = extract_mfcc(os.path.join(dir, file_name))
        dataset.setdefault(label, []).append(feature)
    return dataset
def _left_to_right_priors(states_num):
    """Return ``(transmat, startprob)`` arrays for a left-to-right topology.

    Each state may stay where it is or advance by one or two states, with the
    row's mass split evenly among the reachable states; sequences may start in
    either of the first two states. For ``states_num == 5`` this reproduces
    the original hand-written matrices exactly.
    """
    if states_num < 1:
        raise ValueError("states_num must be >= 1, got %r" % (states_num,))
    transmat = np.zeros((states_num, states_num), dtype=float)
    for i in range(states_num):
        last = min(i + 2, states_num - 1)
        transmat[i, i:last + 1] = 1.0 / (last - i + 1)
    startprob = np.zeros(states_num, dtype=float)
    n_entry = min(2, states_num)
    startprob[:n_entry] = 1.0 / n_entry
    return transmat, startprob


def train_GMMHMM(dataset, states_num=5, GMM_mix_num=3, n_iter=10):
    """Train one ``hmmlearn`` GMM-HMM per label.

    Args:
        dataset: dict mapping label -> list of feature arrays (as built by
            ``buildDataSet``). The arrays of one label are stacked with
            ``np.vstack`` and their first-axis lengths are passed as
            ``lengths``, so each array is treated as one sequence
            (presumably shape ``(n_frames, n_features)`` -- confirm).
        states_num: Number of hidden states per model (default 5).
        GMM_mix_num: Gaussian mixture components per state (default 3).
        n_iter: Maximum EM iterations passed to ``hmm.GMMHMM`` (default 10).

    Returns:
        dict mapping each label to its fitted ``hmm.GMMHMM`` model.

    Raises:
        ValueError: if ``states_num`` is less than 1.
    """
    # NOTE(review): hmmlearn documents startprob_prior / transmat_prior as
    # Dirichlet prior parameters, not initial probabilities; entries below 1
    # (and zeros) may not enforce the intended left-to-right topology --
    # confirm against the hmmlearn version in use.
    # Built-in float dtype: np.float / np.int were removed in NumPy 1.24.
    transmatPrior, startprobPrior = _left_to_right_priors(states_num)
    GMMHMM_Models = {}
    for label, sequences in dataset.items():
        model = hmm.GMMHMM(n_components=states_num, n_mix=GMM_mix_num,
                           transmat_prior=transmatPrior,
                           startprob_prior=startprobPrior,
                           covariance_type='diag', n_iter=n_iter)
        # One length per sequence so hmmlearn knows where each one ends.
        lengths = [seq.shape[0] for seq in sequences]
        model.fit(np.vstack(sequences), lengths=lengths)
        GMMHMM_Models[label] = model
    return GMMHMM_Models
def main():
    """Train one GMM-HMM per label and report test-set recognition accuracy.

    Builds features from ``./train_audio/`` and ``./test_audio/`` (relative to
    the working directory), trains a model per training label, scores every
    test utterance against every model with ``model.score``, and predicts the
    label whose model gives the highest score. Accuracy is reported over all
    test utterances.
    """
    trainDir = './train_audio/'
    trainDataSet = buildDataSet(trainDir)
    print("Finish prepare the training data")
    hmmModels = train_GMMHMM(trainDataSet)
    print("Finish training of the GMM_HMM models for digits 0-9")

    testDir = './test_audio/'
    testDataSet = buildDataSet(testDir)
    correct = 0
    total = 0
    for label, features in testDataSet.items():
        # Score every test utterance of this label, not only the first one.
        for feature in features:
            scoreList = {model_label: model.score(feature)
                         for model_label, model in hmmModels.items()}
            predict = max(scoreList, key=scoreList.get)
            print("Test on true label ", label, ": predict result label is ", predict)
            if predict == label:
                correct += 1
            total += 1
    # Avoid a ZeroDivisionError when the test directory has no .wav files.
    if total == 0:
        print("No test utterances found in", testDir)
        return
    print("Final recognition rate is %.2f" % (100.0 * correct / total), "%")
if __name__ == '__main__':
    # Run the train/evaluate pipeline only when executed as a script,
    # not when this module is imported.
    main()