-
Notifications
You must be signed in to change notification settings - Fork 10
/
train.lua
57 lines (50 loc) · 1.48 KB
/
train.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
require 'nn';
-- Module table collecting the training helpers defined below
-- (train.accuracy, train.sgd); it is returned at the end of the file.
local train = {}
--- Classification accuracy, in percent, of `net` on the pair (Xv, Yv).
-- Puts `net` in evaluation mode and walks the data in contiguous mini-batches
-- along dimension 1. Each batch is moved to the GPU with :cuda(). The arg-max
-- over dimension 2 of the network output is taken as the predicted class, and
-- matches against the labels are counted.
-- NOTE: `net` is left in evaluation mode; a caller that keeps training must
-- call net:training() again.
-- @param Xv inputs, sliced along dimension 1 (assumed torch tensor - TODO confirm)
-- @param Yv target class indices, one per row of Xv (assumed - confirm with caller)
-- @param net nn module; batches are converted with :cuda(), so presumably a CUDA model
-- @param batch mini-batch size (default 64)
-- @return 100 * (number of correct predictions) / Xv:size(1)
function train.accuracy(Xv,Yv,net,batch)
  net:evaluate()
  batch = batch or 64
  local Nv = Xv:size(1)
  local correct = 0
  for i = 1, Nv, batch do
    -- Last batch may be shorter than `batch`.
    local j = math.min(i + batch - 1, Nv)
    local Xb = Xv[{{i,j}}]:cuda()
    local Yb = Yv[{{i,j}}]:cuda()
    local out = net:forward(Xb) -- N*k*C
    local _, pred = out:max(2)
    -- Newer cutorch returns arg-max indices as a CudaLongTensor, whereas Yb is
    -- always a CudaTensor (it comes from :cuda()). Comparing mismatched tensor
    -- types with :eq raises an error, so convert the predictions to Yb's type
    -- first. typeAs returns the tensor itself when the types already match.
    correct = correct + pred:typeAs(Yb):eq(Yb):sum()
  end
  return (100*correct/(Nv))
end
--- Train `net` with mini-batch SGD (optim.sgd) for K epochs.
-- Each epoch puts `net` in training mode and walks Xt/Yt in order, in
-- contiguous chunks of `batch` rows. For each chunk it runs forward/backward
-- through `net` and the criterion `ct`, divides the gradient by the chunk
-- size, and takes one optim.sgd step. After every epoch it prints the summed
-- per-batch loss and train.accuracy on the validation and training sets.
-- @param net nn module to train (parameters are flattened via getParameters)
-- @param ct nn criterion used for the loss and its gradient
-- @param Xt training inputs, sliced along dimension 1
-- @param Yt training targets, sliced along dimension 1
-- @param Xv validation inputs
-- @param Yv validation targets
-- @param K number of epochs
-- @param sgd_config config/state table passed straight to optim.sgd
-- @param batch mini-batch size (default 64)
function train.sgd(net,ct,Xt,Yt,Xv,Yv,K,sgd_config,batch)
  local x, dx = net:getParameters()
  local optim = require 'optim'
  batch = batch or 64
  local Nt = Xt:size(1)
  print('parameters size ..')
  print(#x)

  -- Loss of the most recent mini-batch, shared with feval through an upvalue.
  local batch_loss = 0
  -- Forward/backward are done in the loop body below, so the closure handed
  -- to optim.sgd just returns the precomputed loss and gradient. It is
  -- declared `local` and created once, so it neither leaks a global `feval`
  -- nor allocates a new closure per batch.
  local function feval()
    return batch_loss, dx
  end

  for _ = 1, K do
    local epoch_loss = 0
    -- train.accuracy (called at the end of each epoch) switches the net to
    -- evaluate mode, so training mode is re-enabled here every epoch.
    net:training()
    for i = 1, Nt, batch do
      dx:zero()
      local j = math.min(i + batch - 1, Nt)
      local Xb = Xt[{{i,j}}]:cuda()
      local Yb = Yt[{{i,j}}]:cuda()
      local out = net:forward(Xb)
      batch_loss = ct:forward(out, Yb)
      local dout = ct:backward(out, Yb)
      net:backward(Xb, dout)
      -- NOTE(review): if `ct` already averages over the batch (nn criteria
      -- with sizeAverage = true), this second division scales the effective
      -- step by 1/batch - confirm this is intended.
      dx:div(j - i + 1)
      optim.sgd(feval, x, sgd_config)
      epoch_loss = epoch_loss + batch_loss
    end
    print('loss..'..epoch_loss)
    print('valid .. '.. train.accuracy(Xv,Yv,net,batch))
    print('train .. '.. train.accuracy(Xt,Yt,net,batch))
  end
end
-- Export the module table as the result of require()-ing this file.
return train