-
Notifications
You must be signed in to change notification settings - Fork 0
/
AN4CTCTrain.lua
49 lines (41 loc) · 1.28 KB
/
AN4CTCTrain.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
--[[Trains the CTC model using the AN4 audio database.]]
local Network = require 'Network'

--Training parameters
torch.setdefaulttensortype('torch.FloatTensor')

-- Seed shared by the torch RNG and, when a GPU is used, the cutorch RNGs.
-- Declared local so the script does not leak a `seed` global.
local seed = 10
local epochs = 70

local networkParams = {
    loadModel = false,
    saveModel = true,
    modelName = 'DeepSpeechModel',
    -- NOTE(review): a CPU run (nGPU = -1) presumably also needs a non-cudnn
    -- backend here -- confirm against Network:init.
    backend = 'cudnn',
    nGPU = 1, -- Number of GPUs, set -1 to use CPU
    trainingSetLMDBPath = './prepare_an4/train/', -- training data, loaded online
    validationSetLMDBPath = './prepare_an4/test/',
    logsTrainPath = './logs/TrainingLoss/',
    logsValidationPath = './logs/ValidationScores/',
    modelTrainingPath = './models/',
    fileName = 'CTCNetwork.t7',
    dictionaryPath = './dictionary',
    batchSize = 20,
    validationBatchSize = 2,
    validationIterations = 1,
    saveModelIterations = 50
}

-- Seed torch unconditionally.
-- Only load and seed cutorch when the run is configured for GPUs, so a
-- CPU-only run (nGPU = -1) does not require cutorch to be installed.
torch.manualSeed(seed)
if networkParams.nGPU > 0 then
    require 'cutorch'
    cutorch.manualSeedAll(seed)
end
-- Optimiser settings for stochastic gradient descent (optim library),
-- passed to Network:trainNetwork along with the epoch count.
local sgdParams = {
    -- Step size and its per-step decay.
    learningRate = 1e-3,
    learningRateDecay = 1e-9,
    -- Nesterov momentum; optim's Nesterov mode expects momentum > 0 and
    -- zero dampening.
    momentum = 0.9,
    nesterov = true,
    dampening = 0,
    -- No L2 regularisation.
    weightDecay = 0,
}
-- Driver: configure the network from networkParams, run `epochs` epochs of
-- training with sgdParams, then render the loss plot and report completion.
local net = Network
net:init(networkParams)
net:trainNetwork(epochs, sgdParams)
-- Plot of the loss recorded during training.
net:createLossGraph()
print("finished")