-- train.lua
require 'torch'
require 'nn'
require 'nngraph'
-- local imports
require 'pm'
local utils = require 'misc.utils'
local net_utils = require 'misc.net_utils'
require 'misc.optim_updates'
require 'misc.DataLoaderRaw'
--require 'misc.DataLoader'
-------------------------------------------------------------------------------
-- Input arguments and options
-------------------------------------------------------------------------------
cmd = torch.CmdLine()
cmd:text()
cmd:text('Generating textures & images pixel by pixel')
cmd:text()
cmd:text('Options')
-- Data input settings
cmd:option('-folder_path','imgs/','path to the preprocessed textures')
cmd:option('-image_size',256,'size to resize the input images to')
cmd:option('-color', 1, 'whether the input images are color (1) or grayscale (0)')
--cmd:option('-input_h5','coco/data.h5','path to the h5file containing the preprocessed dataset')
--cmd:option('-input_json','coco/data.json','path to the json file containing additional info and vocab')
cmd:option('-start_from', '', 'path to a model checkpoint to initialize model weights from. Empty = don\'t')
-- Model settings
cmd:option('-rnn_size',200,'size of the rnn in number of hidden nodes in each layer')
cmd:option('-num_layers',2,'number of layers in stacked RNN/LSTMs')
cmd:option('-num_mixtures',10,'number of Gaussian mixture components used to encode the output pixel')
cmd:option('-patch_size',15,'size of the neighbor patch that a pixel is conditioned on')
cmd:option('-num_neighbors',3,'number of neighbors for each pixel')
cmd:option('-border_init', 0, 'value used to initialize pixels on the border')
cmd:option('-input_shift', -0.5, 'shift the input by a constant; tends to give better performance')
-- Optimization: General
cmd:option('-max_iters', 20000, 'max number of iterations to run for (-1 = run forever)')
cmd:option('-batch_size',16,'batch size in number of image patches per batch')
cmd:option('-grad_clip',0.1,'clip gradients at this value (note should be lower than usual 5 because we normalize grads by both batch and seq_length)')
cmd:option('-drop_prob_pm', 0, 'strength of dropout in the Pixel Model')
cmd:option('-mult_in', true, 'An extension of the LSTM architecture')
cmd:option('-output_back', true, 'For 4D model, feed the output of the first sweep to the next sweep')
cmd:option('-grad_norm', false, 'whether to normalize the gradients for each direction')
cmd:option('-loss_policy', 'const', 'loss decay policy for spatial patch') -- exp for exponential decay, and linear for linear decay
cmd:option('-loss_decay', 0.9, 'loss decay rate for spatial patch')
cmd:option('-noise', 0, 'input perturbation by adding noise')
-- Optimization: for the Pixel Model
cmd:option('-optim','adam','what update to use? rmsprop|sgd|sgdm|sgdmom|adagrad|adam')
cmd:option('-learning_rate',1e-3,'learning rate')
cmd:option('-learning_rate_decay_start', -1, 'at what iteration to start decaying learning rate? (-1 = don\'t)')
cmd:option('-learning_rate_decay_every', 5000, 'every how many iterations thereafter to drop LR by half?')
cmd:option('-optim_alpha',0.9,'alpha for adagrad/rmsprop/momentum/adam')
cmd:option('-optim_beta',0.999,'beta used for adam')
cmd:option('-optim_epsilon',1e-8,'epsilon that goes into denominator for smoothing')
-- Evaluation/Checkpointing
cmd:option('-save_checkpoint_every', 1000, 'how often to save a model checkpoint?')
cmd:option('-checkpoint_path', 'models', 'folder to save checkpoints into (empty = this folder)')
cmd:option('-losses_log_every', 25, 'How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')
-- misc
cmd:option('-backend', 'cudnn', 'nn|cudnn')
cmd:option('-id', '', 'an id identifying this run/job. used in cross-val and appended when writing progress files')
cmd:option('-seed', 123, 'random number generator seed to use')
cmd:option('-gpuid', 0, 'which gpu to use. -1 = use CPU')
cmd:text()
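-- Example invocation (illustrative; the folder path is a placeholder):
--   th train.lua -folder_path imgs/ -gpuid 0 -batch_size 16 -max_iters 20000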
-------------------------------------------------------------------------------
-- Basic Torch initializations
-------------------------------------------------------------------------------
local opt = cmd:parse(arg)
torch.manualSeed(opt.seed)
torch.setdefaulttensortype('torch.FloatTensor') -- for CPU
if opt.gpuid >= 0 then
require 'cutorch'
require 'cunn'
if opt.backend == 'cudnn' then require 'cudnn' end
cutorch.manualSeed(opt.seed)
cutorch.setDevice(opt.gpuid + 1) -- note +1 because lua is 1-indexed
end
-------------------------------------------------------------------------------
-- Create the Data Loader instance
-------------------------------------------------------------------------------
local loader = DataLoaderRaw{folder_path = opt.folder_path, shift = opt.input_shift,
img_size = opt.image_size, color = opt.color}
opt.data_info = loader:getChannelScale()
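-- data_info is stashed in opt so that checkpoints (which save opt) carry the
-- per-channel shift/scale presumably needed to de-normalize generated pixels
-- at sampling time; see misc.DataLoaderRaw for what getChannelScale returns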
-------------------------------------------------------------------------------
-- Initialize the networks
-------------------------------------------------------------------------------
local protos = {}
local iter = 0
if string.len(opt.start_from) > 0 then
-- load protos from file
print('initializing weights from ' .. opt.start_from)
local loaded_checkpoint = torch.load(opt.start_from)
protos = loaded_checkpoint.protos
local pm_modules = protos.pm:getModulesList()
for k,v in pairs(pm_modules) do net_utils.unsanitize_gradients(v) end
protos.crit = nn.PixelModelCriterion(protos.pm.pixel_size, protos.pm.num_mixtures,
{policy=loaded_checkpoint.opt.loss_policy, val=loaded_checkpoint.opt.loss_decay}) -- not in checkpoints, create manually
iter = loaded_checkpoint.iter
else
-- create protos from scratch
-- initialize pixel model
local pmOpt = {}
pmOpt.pixel_size = loader:getChannelSize()
pmOpt.rnn_size = opt.rnn_size
pmOpt.num_mixtures = opt.num_mixtures
pmOpt.num_layers = opt.num_layers
pmOpt.dropout = opt.drop_prob_pm
pmOpt.batch_size = opt.batch_size
pmOpt.recurrent_stride = opt.patch_size
pmOpt.seq_length = opt.patch_size * opt.patch_size
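-- the LSTM is unrolled over seq_length = patch_size^2 steps, one per
-- pixel of the training patch (e.g. 15x15 = 225 steps per sequence)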
pmOpt.mult_in = opt.mult_in
pmOpt.num_neighbors = opt.num_neighbors
pmOpt.border_init = opt.border_init
pmOpt.output_back = opt.output_back
if opt.num_neighbors == 2 then
protos.pm = nn.PixelModel(pmOpt)
elseif opt.num_neighbors == 3 then
protos.pm = nn.PixelModel3N(pmOpt)
elseif opt.num_neighbors == 4 then
protos.pm = nn.PixelModel4N(pmOpt)
else
error('the number of neighbors must be 2, 3, or 4')
end
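-- each variant conditions a pixel on a different number of already-visited
-- spatial neighbors; the 4-neighbor model runs multiple sweeps over the
-- patch (see the -output_back option). The exact connectivity lives in pm.lua.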
-- criterion for the pixel model
protos.crit = nn.PixelModelCriterion(pmOpt.pixel_size, pmOpt.num_mixtures,
{policy=opt.loss_policy, val=opt.loss_decay})
end
-- ship everything to GPU, maybe
if opt.gpuid >= 0 then
for k,v in pairs(protos) do v:cuda() end
end
print('Training a 2D LSTM with number of layers: ', opt.num_layers)
print('Number of pixels in the neighbor: ', opt.num_neighbors)
print('Hidden nodes in each layer: ', opt.rnn_size)
print('Number of mixtures for output gaussians: ', opt.num_mixtures)
print('The input image local patch size: ', opt.patch_size)
print('Input channel dimension: ', opt.color*2+1)
print('Input pixel shift: ', opt.input_shift)
print('Border pixel init: ', opt.border_init)
print('Training batch size: ', opt.batch_size)
-- flatten and prepare all model parameters to a single vector.
local params, grad_params = protos.pm:getParameters()
print('total number of parameters in PM: ', params:nElement())
assert(params:nElement() == grad_params:nElement())
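-- note: getParameters() flattens all weights and all gradients into two
-- contiguous vectors, which is what lets the update rules below (adam, etc.)
-- operate on a single tensor instead of walking the module tree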
-- construct thin module clones that share parameters with the actual
-- modules. These thin modules will have no intermediates and will be used
-- for checkpointing to write significantly smaller checkpoint files
local thin_pm = protos.pm:clone()
thin_pm.core:share(protos.pm.core, 'weight', 'bias') -- TODO: we are assuming that PM has specific members! figure out clean way to get rid of, not modular.
-- sanitize all modules of gradient storage so that we dont save big checkpoints
local pm_modules = thin_pm:getModulesList()
for k,v in pairs(pm_modules) do net_utils.sanitize_gradients(v) end
-- create clones and ensure parameter sharing. we have to do this
-- all the way here at the end because calls such as :cuda() and
-- :getParameters() reshuffle memory around.
protos.pm:createClones()
collectgarbage() -- "yeah, sure why not"
-------------------------------------------------------------------------------
-- Validation evaluation
-------------------------------------------------------------------------------
local function eval_split(n)
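-- computes the average loss over n batches drawn from the 'val' split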
protos.pm:evaluate()
--loader:resetIterator(split) -- rewind iterator back to first datapoint in the split
local loss_sum = 0
local i = 0
while i < n do
-- fetch a batch of data
local data = loader:getBatch{batch_size = opt.batch_size, num_neighbors = opt.num_neighbors,
patch_size = opt.patch_size, gpu = opt.gpuid, split = 'val',
border = opt.border_init, noise = opt.noise}
-- forward the model to get loss
local gmms = protos.pm:forward(data.pixels)
--print(gmms)
local loss = protos.crit:forward(gmms, data.targets)
loss_sum = loss_sum + loss
i = i + 1
if i % 10 == 0 then collectgarbage() end
end
return loss_sum/n
end
-------------------------------------------------------------------------------
-- Loss function
-------------------------------------------------------------------------------
local function lossFun()
protos.pm:training()
grad_params:zero()
-----------------------------------------------------------------------------
-- Forward pass
-----------------------------------------------------------------------------
-- get batch of data
--local timer = torch.Timer()
local data = loader:getBatch{batch_size = opt.batch_size, num_neighbors = opt.num_neighbors,
patch_size = opt.patch_size, gpu = opt.gpuid, split = 'train',
border = opt.border_init, noise = opt.noise}
-- forward the pixel model
local gmms = protos.pm:forward(data.pixels)
--print('Forward time: ' .. timer:time().real .. ' seconds')
-- forward the pixel model criterion
local loss = protos.crit:forward(gmms, data.targets)
-----------------------------------------------------------------------------
-- Backward pass
-----------------------------------------------------------------------------
-- backprop criterion
local dgmms = protos.crit:backward(gmms, data.targets)
--print('Criterion time: ' .. timer:time().real .. ' seconds')
-- backprop pixel model
local dpixels = protos.pm:backward(data.pixels, dgmms)
--print('Backward time: ' .. timer:time().real .. ' seconds')
-- normalize the gradients for different directions
if opt.grad_norm then
protos.pm:norm_grad(grad_params)
end
-- clip gradients
-- print(string.format('clamping %f%% of gradients', 100*torch.mean(torch.gt(torch.abs(grad_params), opt.grad_clip))))
grad_params:clamp(-opt.grad_clip, opt.grad_clip)
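-- clamp is element-wise clipping (each gradient component is limited to
-- [-grad_clip, grad_clip]), not global norm clipping; hence the small
-- default of 0.1 compared to the usual 5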
-----------------------------------------------------------------------------
-- and let's get out!
local losses = { total_loss = loss }
return losses
end
-------------------------------------------------------------------------------
-- Main loop
-------------------------------------------------------------------------------
local loss0
local optim_state = {}
local loss_history = {}
local val_loss_history = {}
local best_score
while true do
iter = iter + 1
-- decay the learning rate
local learning_rate = opt.learning_rate
if iter > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0 then
local frac = (iter - opt.learning_rate_decay_start) / opt.learning_rate_decay_every
local decay_factor = math.pow(0.5, frac)
learning_rate = learning_rate * decay_factor -- set the decayed rate
end
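-- net effect: learning_rate = opt.learning_rate * 0.5^((iter - start)/every),
-- so the rate is halved every learning_rate_decay_every iterations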
-- eval loss/gradient
local losses = lossFun()
if iter % opt.losses_log_every == 0 then loss_history[iter] = losses.total_loss end
print(string.format('iter %d: %f. LR: %f', iter, losses.total_loss, learning_rate))
-- save checkpoint once in a while (or on final iteration)
if (iter % opt.save_checkpoint_every == 0 or iter == opt.max_iters) then
-- evaluate the validation performance
local val_loss = eval_split(10)
print('validation loss: ', val_loss)
-- val_loss_history[iter] = val_loss
local checkpoint_path = path.join(opt.checkpoint_path, 'model_id' .. opt.id .. iter)
-- write a (thin) json report
local checkpoint = {}
checkpoint.opt = opt
checkpoint.iter = iter
checkpoint.loss_history = loss_history
-- checkpoint.val_loss_history = val_loss_history
-- checkpoint.val_predictions = val_predictions -- save these too for CIDEr/METEOR/etc eval
-- include the protos (which have weights) and save to file
local save_protos = {}
save_protos.pm = thin_pm -- these are shared clones, and point to correct param storage
checkpoint.protos = save_protos
torch.save(checkpoint_path .. '.t7', checkpoint)
print('wrote checkpoint to ' .. checkpoint_path .. '.t7')
-- utils.write_json(checkpoint_path .. '.json', checkpoint)
-- print('wrote json checkpoint to ' .. checkpoint_path .. '.json')
end
-- perform a parameter update
if opt.optim == 'rmsprop' then
rmsprop(params, grad_params, learning_rate, opt.optim_alpha, opt.optim_epsilon, optim_state)
elseif opt.optim == 'adagrad' then
adagrad(params, grad_params, learning_rate, opt.optim_epsilon, optim_state)
elseif opt.optim == 'sgd' then
sgd(params, grad_params, learning_rate)
elseif opt.optim == 'sgdm' then
sgdm(params, grad_params, learning_rate, opt.optim_alpha, optim_state)
elseif opt.optim == 'sgdmom' then
sgdmom(params, grad_params, learning_rate, opt.optim_alpha, optim_state)
elseif opt.optim == 'adam' then
adam(params, grad_params, learning_rate, opt.optim_alpha, opt.optim_beta, opt.optim_epsilon, optim_state)
else
error('bad option opt.optim')
end
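-- all of these updates (from misc.optim_updates) modify params in place and
-- keep their running buffers (e.g. adam's moment estimates) in optim_state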
-- stopping criterions
if iter % 10 == 0 then collectgarbage() end -- good idea to do this once in a while, i think
if loss0 == nil then loss0 = losses.total_loss end
if losses.total_loss > 2000 then
print('loss seems to be exploding, quitting.')
break
end
if opt.max_iters > 0 and iter >= opt.max_iters then break end -- stopping criterion
end