-
Notifications
You must be signed in to change notification settings - Fork 0
/
WEREvaluator.lua
129 lines (112 loc) · 4.55 KB
/
WEREvaluator.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
require 'Loader'
require 'Util'
require 'Mapper'
require 'torch'
require 'xlua'
require 'cutorch'
local threads = require 'threads'
local Evaluator = require 'Evaluator'
local WEREvaluator = torch.class('WEREvaluator')
local _loader
--- Build a WER evaluator over the test set at `_path`.
-- @param _path path handed to Loader and indexer
-- @param mapper token <-> alphabet mapper used when decoding predictions
-- @param testBatchSize number of utterances per evaluation batch
-- @param nbOfTestIterations number of batches to evaluate per getWER call
-- @param logsPath directory where the WER_Test*.log files are written
function WEREvaluator:__init(_path, mapper, testBatchSize, nbOfTestIterations, logsPath)
    -- module-level loader, shared with the prefetch thread's closures
    _loader = Loader(_path)
    self.mapper = mapper
    self.testBatchSize = testBatchSize
    self.nbOfTestIterations = nbOfTestIterations
    self.logsPath = logsPath
    self.indexer = indexer(_path, testBatchSize)
    -- single worker thread used to prefetch the next batch during evaluation
    self.pool = threads.Threads(1, function() require 'Loader' end)
    -- timestamp suffix so each training run logs to its own file
    self.suffix = '_' .. os.date('%Y%m%d_%H%M%S')
end
--- Reshape a flat CTC network output into per-utterance predictions.
-- When the model ran on several GPUs, `src` holds testBatchSize / nGPU
-- sequences; the result is transposed to (batch, time, classes) order.
-- @param src 2-D tensor of frame predictions
-- @param nGPU optional GPU count the batch was split across (default 1)
-- @return tensor with one row of timestep predictions per utterance
function WEREvaluator:predictCTC(src, nGPU)
    local gpuCount = nGPU or 1
    local seqPerGpu = self.testBatchSize / gpuCount
    local framed = src:view(-1, seqPerGpu, src:size(2))
    return framed:transpose(1, 2)
end
--- Evaluate average word error rate over nbOfTestIterations batches.
-- Uses a one-deep prefetch pipeline: while the model evaluates the current
-- batch, the worker thread in self.pool loads the next one.
-- @param gpu if truthy, run inputs through CUDA tensors
-- @param model network exposing :forward(); may return a table of
--        per-GPU outputs or a single tensor
-- @param calSizeOfSequences function mapping raw sizes to model sizes
-- @param verbose if true, log per-utterance WER and transcripts
-- @param epoch epoch number written into the log header
-- @return average WER across all evaluated utterances (fraction, not %)
function WEREvaluator:getWER(gpu, model, calSizeOfSequences, verbose, epoch)
--[[
load test_iter*batch_size data point from test set; compute average WER
input:
verbose:if true then print WER and predicted strings for each data to log
--]]
local cumWER = 0
local inputs = torch.Tensor()
if (gpu) then
inputs = inputs:cuda()
end
-- double buffers filled by the prefetch thread, consumed each iteration
local spect_buf, label_buf, sizes_buf
-- get first batch
local inds = self.indexer:nxt_inds()
-- NOTE(review): the first fetch passes `false` as the second nxt_batch
-- argument while all later fetches pass `true` — confirm against Loader
-- whether this asymmetry is intentional.
self.pool:addjob(function()
return _loader:nxt_batch(inds, false)
end,
function(spect, label, sizes)
spect_buf = spect
label_buf = label
sizes_buf = sizes
end)
if verbose then
local f = assert(io.open(self.logsPath .. 'WER_Test' .. self.suffix .. '.log', 'a'), "Could not create validation test logs, does the folder "
.. self.logsPath .. " exist?")
f:write('======================== BEGIN WER TEST EPOCH: ' .. epoch .. ' =========================\n')
f:close()
end
local werPredictions = {} -- stores the predictions to order for log.
-- ======================= for every test iteration ==========================
for i = 1, self.nbOfTestIterations do
-- get buf and fetch next one
-- wait for the prefetch of this iteration's batch to finish
self.pool:synchronize()
local inputsCPU, targets, sizes_array = spect_buf, label_buf, sizes_buf
-- immediately kick off loading of the NEXT batch in the background
inds = self.indexer:nxt_inds()
self.pool:addjob(function()
return _loader:nxt_batch(inds, true)
end,
function(spect, label, sizes)
spect_buf = spect
label_buf = label
sizes_buf = sizes
end)
sizes_array = calSizeOfSequences(sizes_array)
-- reuse `inputs` (CPU or CUDA) to avoid reallocating every iteration
inputs:resize(inputsCPU:size()):copy(inputsCPU)
local predictions = model:forward(inputs)
-- multi-GPU models return one output tensor per GPU; concatenate the
-- reshaped per-GPU predictions along the batch dimension
if type(predictions) == 'table' then
local temp = self:predictCTC(predictions[1], #predictions)
for k = 2, #predictions do
temp = torch.cat(temp, self:predictCTC(predictions[k], #predictions), 1)
end
predictions = temp
else
predictions = self:predictCTC(predictions)
end
-- =============== for every data point in this batch ==================
for j = 1, self.testBatchSize do
local prediction_single = predictions[j]
local predict_tokens = Evaluator.predict2tokens(prediction_single, self.mapper)
local WER = Evaluator.sequenceErrorRate(targets[j], predict_tokens)
cumWER = cumWER + WER
table.insert(werPredictions, { wer = WER * 100, target = self:tokens2text(targets[j]), prediction = self:tokens2text(predict_tokens) })
end
end
-- sort logged predictions best-first (lowest WER first)
local function comp(a, b) return a.wer < b.wer end
table.sort(werPredictions, comp)
if verbose then
for index, werPrediction in ipairs(werPredictions) do
local f = assert(io.open(self.logsPath .. 'WER_Test' .. self.suffix .. '.log', 'a'))
f:write(string.format("WER = %.2f%% | Text = \"%s\" | Predict = \"%s\"\n",
werPrediction.wer, werPrediction.target, werPrediction.prediction))
f:close()
end
end
-- average over every utterance evaluated, not just per-batch
local averageWER = cumWER / (self.nbOfTestIterations * self.testBatchSize)
-- the summary line is written even when verbose is false
local f = assert(io.open(self.logsPath .. 'WER_Test' .. self.suffix .. '.log', 'a'))
f:write(string.format("Average WER = %.2f%%", averageWER * 100))
f:close()
self.pool:synchronize() -- end the last loading
return averageWER
end
--- Convert a sequence of token ids into a readable string.
-- @param tokens array of token ids (sequence, 1..n)
-- @return concatenation of the mapped characters
function WEREvaluator:tokens2text(tokens)
    -- Collect characters in a buffer and join once with table.concat:
    -- repeated '..' in a loop is O(n^2) on immutable Lua strings.
    local chars = {}
    for i, token in ipairs(tokens) do
        chars[i] = self.mapper.token2alphabet[token]
    end
    return table.concat(chars)
end