forked from nl8590687/ASRT_SpeechRecognition
-
Notifications
You must be signed in to change notification settings - Fork 0
/
speech_model.py
262 lines (220 loc) · 11.3 KB
/
speech_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
# !/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
#
# This file is part of ASRT Speech Recognition Tool.
#
# ASRT is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# ASRT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ASRT. If not, see <https://www.gnu.org/licenses/>.
# ============================================================================
"""
@author: nl8590687
声学模型基础功能模板定义
"""
import os
import time
import random
import numpy as np
from utils.ops import get_edit_distance, read_wav_data
from utils.config import load_config_file, DEFAULT_CONFIG_FILENAME, load_pinyin_dict
from utils.thread import threadsafe_generator
class ModelSpeech:
"""
语音模型类
参数:
speech_model: 声学模型类型 (BaseModel类) 实例对象
speech_features: 声学特征类型(SpeechFeatureMeta类)实例对象
"""
def __init__(self, speech_model, speech_features, max_label_length=64):
self.data_loader = None
self.speech_model = speech_model
self.trained_model, self.base_model = speech_model.get_model()
self.speech_features = speech_features
self.max_label_length = max_label_length
@threadsafe_generator
def _data_generator(self, batch_size, data_loader):
"""
数据生成器函数,用于Keras的generator_fit训练
batch_size: 一次产生的数据量
"""
labels = np.zeros((batch_size, 1), dtype=np.float64)
data_count = data_loader.get_data_count()
index = 0
while True:
X = np.zeros((batch_size,) + self.speech_model.input_shape, dtype=np.float64)
y = np.zeros((batch_size, self.max_label_length), dtype=np.int16)
input_length = []
label_length = []
for i in range(batch_size):
wavdata, sample_rate, data_labels = data_loader.get_data(index)
data_input = self.speech_features.run(wavdata, sample_rate)
data_input = data_input.reshape(data_input.shape[0], data_input.shape[1], 1)
# 必须加上模pool_size得到的值,否则会出现inf问题,然后提示No valid path found.
# 但是直接加又可能会出现sequence_length <= xxx 的问题,因此不能让其超过时间序列长度的最大值,比如200
pool_size = self.speech_model.input_shape[0] // self.speech_model.output_shape[0]
inlen = min(data_input.shape[0] // pool_size + data_input.shape[0] % pool_size,
self.speech_model.output_shape[0])
input_length.append(inlen)
X[i, 0:len(data_input)] = data_input
y[i, 0:len(data_labels)] = data_labels
label_length.append([len(data_labels)])
index = (index + 1) % data_count
label_length = np.matrix(label_length)
input_length = np.array([input_length]).T
yield [X, y, input_length, label_length], labels
def train_model(self, optimizer, data_loader, epochs=1, save_step=1, batch_size=16, last_epoch=0, call_back=None):
"""
训练模型
参数:
optimizer:tensorflow.keras.optimizers 优化器实例对象
data_loader:数据加载器类型 (SpeechData) 实例对象
epochs: 迭代轮数
save_step: 每多少epoch保存一次模型
batch_size: mini batch大小
last_epoch: 上一次epoch的编号,可用于断点处继续训练时,epoch编号不冲突
call_back: keras call back函数
"""
save_filename = os.path.join('save_models', self.speech_model.get_model_name(),
self.speech_model.get_model_name())
self.trained_model.compile(loss=self.speech_model.get_loss_function(), optimizer=optimizer)
print('[ASRT] Compiles Model Successfully.')
yielddatas = self._data_generator(batch_size, data_loader)
data_count = data_loader.get_data_count() # 获取数据的数量
# 计算每一个epoch迭代的次数
num_iterate = data_count // batch_size
iter_start = last_epoch
iter_end = last_epoch + epochs
for epoch in range(iter_start, iter_end): # 迭代轮数
try:
epoch += 1
print('[ASRT Training] train epoch %d/%d .' % (epoch, iter_end))
data_loader.shuffle()
self.trained_model.fit_generator(yielddatas, num_iterate, callbacks=call_back)
except StopIteration:
print('[error] generator error. please check data format.')
break
if epoch % save_step == 0:
if not os.path.exists('save_models'): # 判断保存模型的目录是否存在
os.makedirs('save_models') # 如果不存在,就新建一个,避免之后保存模型的时候炸掉
if not os.path.exists(os.path.join('save_models', self.speech_model.get_model_name())): # 判断保存模型的目录是否存在
os.makedirs(
os.path.join('save_models', self.speech_model.get_model_name())) # 如果不存在,就新建一个,避免之后保存模型的时候炸掉
self.save_model(save_filename + '_epoch' + str(epoch))
print('[ASRT Info] Model training complete. ')
def load_model(self, filename):
"""
加载模型参数
"""
self.speech_model.load_weights(filename)
def save_model(self, filename):
"""
保存模型参数
"""
self.speech_model.save_weights(filename)
def evaluate_model(self, data_loader, data_count=-1, out_report=False, show_ratio=True, show_per_step=100):
"""
评估检验模型的识别效果
"""
data_nums = data_loader.get_data_count()
if data_count <= 0 or data_count > data_nums: # 当data_count为小于等于0或者大于测试数据量的值时,则使用全部数据来测试
data_count = data_nums
try:
ran_num = random.randint(0, data_nums - 1) # 获取一个随机数
words_num = 0
word_error_num = 0
nowtime = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
if out_report:
txt_obj = open('Test_Report_' + data_loader.dataset_type + '_' + nowtime + '.txt', 'w',
encoding='UTF-8') # 打开文件并读入
txt_obj.truncate((data_count + 1) * 300) # 预先分配一定数量的磁盘空间,避免后期在硬盘中文件存储位置频繁移动,以防写入速度越来越慢
txt_obj.seek(0) # 从文件首开始
txt = ''
i = 0
while i < data_count:
wavdata, fs, data_labels = data_loader.get_data((ran_num + i) % data_nums) # 从随机数开始连续向后取一定数量数据
data_input = self.speech_features.run(wavdata, fs)
data_input = data_input.reshape(data_input.shape[0], data_input.shape[1], 1)
# 数据格式出错处理 开始
# 当输入的wav文件长度过长时自动跳过该文件,转而使用下一个wav文件来运行
if data_input.shape[0] > self.speech_model.input_shape[0]:
print('*[Error]', 'wave data lenghth of num', (ran_num + i) % data_nums, 'is too long.',
'this data\'s length is', data_input.shape[0],
'expect <=', self.speech_model.input_shape[0],
'\n A Exception raise when test Speech Model.')
i += 1
continue
# 数据格式出错处理 结束
pre = self.predict(data_input)
words_n = data_labels.shape[0] # 获取每个句子的字数
words_num += words_n # 把句子的总字数加上
edit_distance = get_edit_distance(data_labels, pre) # 获取编辑距离
if edit_distance <= words_n: # 当编辑距离小于等于句子字数时
word_error_num += edit_distance # 使用编辑距离作为错误字数
else: # 否则肯定是增加了一堆乱七八糟的奇奇怪怪的字
word_error_num += words_n # 就直接加句子本来的总字数就好了
if i % show_per_step == 0 and show_ratio:
print('[ASRT Info] Testing: ', i, '/', data_count)
txt = ''
if out_report:
txt += str(i) + '\n'
txt += 'True:\t' + str(data_labels) + '\n'
txt += 'Pred:\t' + str(pre) + '\n'
txt += '\n'
txt_obj.write(txt)
i += 1
# print('*[测试结果] 语音识别 ' + str_dataset + ' 集语音单字错误率:', word_error_num / words_num * 100, '%')
print('*[ASRT Test Result] Speech Recognition ' + data_loader.dataset_type + ' set word error ratio: ',
word_error_num / words_num * 100, '%')
if out_report:
txt = '*[ASRT Test Result] Speech Recognition ' + data_loader.dataset_type + ' set word error ratio: ' + str(
word_error_num / words_num * 100) + ' %'
txt_obj.write(txt)
txt_obj.truncate() # 去除文件末尾剩余未使用的空白存储字节
txt_obj.close()
except StopIteration:
print('[ASRT Error] Model testing raise a error. Please check data format.')
def predict(self, data_input):
"""
预测结果
返回语音识别后的forward结果
"""
return self.speech_model.forward(data_input)
def recognize_speech(self, wavsignal, fs):
"""
最终做语音识别用的函数,识别一个wav序列的语音
"""
# 获取输入特征
data_input = self.speech_features.run(wavsignal, fs)
data_input = np.array(data_input, dtype=np.float64)
# print(data_input,data_input.shape)
data_input = data_input.reshape(data_input.shape[0], data_input.shape[1], 1)
r1 = self.predict(data_input)
# 获取拼音列表
list_symbol_dic, _ = load_pinyin_dict(load_config_file(DEFAULT_CONFIG_FILENAME)['dict_filename'])
r_str = []
for i in r1:
r_str.append(list_symbol_dic[i])
return r_str
def recognize_speech_from_file(self, filename):
"""
最终做语音识别用的函数,识别指定文件名的语音
"""
wavsignal, sample_rate, _, _ = read_wav_data(filename)
r = self.recognize_speech(wavsignal, sample_rate)
return r
@property
def model(self):
"""
返回tf.keras model
"""
return self.trained_model