-
Notifications
You must be signed in to change notification settings - Fork 38
/
train_image_classification.py
798 lines (648 loc) · 30 KB
/
train_image_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
'''
Copyright (C) 2010-2021 Alibaba Group Holding Limited.
'''
import os, sys, copy, time, logging
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
try:
import horovod.torch as hvd
except ImportError:
print('fail to import hvd.')
try:
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
from apex import amp, optimizers
from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
print('fail to import apex.')
import ModelLoader, DataLoader
import global_utils
def save_checkpoint(checkpoint_filename, state_dict):
    """Write ``state_dict`` to ``checkpoint_filename`` with ``torch.save``.

    An existing checkpoint at that path is first renamed to
    ``checkpoint_filename + '.backup'`` (any older backup is deleted first).
    The backup file is removed only after ``torch.save`` returns, so if saving
    raises, the previous checkpoint is still recoverable from the backup.

    Args:
        checkpoint_filename: destination path; its directory is created via
            ``global_utils.mkdir``.
        state_dict: object handed to ``torch.save``.
    """
    target_dir = os.path.dirname(checkpoint_filename)
    backup_path = os.path.join(target_dir, os.path.basename(checkpoint_filename) + '.backup')
    global_utils.mkdir(target_dir)

    def discard_if_present(path):
        if os.path.isfile(path):
            os.remove(path)

    if os.path.isfile(checkpoint_filename):
        discard_if_present(backup_path)
        os.rename(checkpoint_filename, backup_path)
    torch.save(state_dict, checkpoint_filename)
    # The new checkpoint is fully written, so the backup is no longer needed.
    discard_if_present(backup_path)
class AverageMeter(object):
    """Track the latest value and the running sample-weighted average of a metric.

    Args:
        name: label printed before the value by ``__str__``.
        fmt: format spec, including the leading ``':'`` (e.g. ``':6.4g'``), used
            for both the value and the average in ``__str__``.
        disp_avg: if True, ``__str__`` appends the running average in parentheses.
    """

    def __init__(self, name, fmt=':4g', disp_avg=True):
        self.name = name
        self.fmt = fmt
        self.disp_avg = disp_avg
        self.reset()

    def reset(self):
        """Clear the latest value, running sum, sample count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A first update with n == 0 would otherwise divide by zero.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        # Build the whole template first and format it exactly once: formatting
        # already-formatted text again would re-interpret any braces in ``name``.
        template = '{name}{val' + self.fmt + '}'
        if self.disp_avg:
            template += '({avg' + self.fmt + '})'
        return template.format(name=self.name, val=self.val, avg=self.avg)
def accuracy(output, target, topk=(1, )):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: class scores; the top-k is taken along dim 1 and dim 0 indexes
            samples.
        target: ground-truth class indices, one per sample.
        topk: the k values to evaluate.

    Returns:
        A list with one single-element tensor per k: the percentage of samples
        whose target is among the k highest-scoring classes. A k larger than the
        number of classes is treated as "all classes" (always a hit) instead of
        making ``topk`` raise.
    """
    with torch.no_grad():
        # Clamp so e.g. topk=(1, 5) also works for datasets with fewer than 5 classes.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
def split_weights(net):
    """Split the learnable parameters of ``net`` into two optimizer param groups.

    The ``weight`` of every ``nn.Conv2d`` / ``nn.Linear`` layer goes into the first
    group, which uses the optimizer's weight decay. Every other learnable
    parameter (conv/linear biases, normalization weights and biases, any other
    parameter) goes into the second group with ``weight_decay=0``.

    Args:
        net: network whose parameters are split.

    Returns:
        ``[dict(params=decay), dict(params=no_decay, weight_decay=0)]``, usable as
        the ``params`` argument of a ``torch.optim`` optimizer.
    """
    decay = []
    no_decay = []
    seen = set()
    for module in net.modules():
        is_decay_layer = isinstance(module, (nn.Conv2d, nn.Linear))
        # recurse=False visits only this module's own parameters; unset ones
        # (e.g. BatchNorm2d(affine=False).weight is None) are not yielded.
        for param_name, param in module.named_parameters(recurse=False):
            # Shared/tied parameters must appear in only one group.
            if id(param) in seen:
                continue
            seen.add(id(param))
            if is_decay_layer and param_name == 'weight':
                decay.append(param)
            else:
                no_decay.append(param)
    assert len(list(net.parameters())) == len(decay) + len(no_decay)
    return [dict(params=decay), dict(params=no_decay, weight_decay=0)]
def network_weight_MSRAPrelu_init(net: nn.Module):
    """Initialize the conv / linear / norm layers of ``net`` in place ("MSRAPrelu").

    - ``nn.Conv2d``: Xavier-normal weights with gain 3.26033; bias zeroed.
    - ``nn.Linear``: weights ~ N(0, 3.26033 * sqrt(2 / (d0 + d1))), where d0, d1
      are the two weight dimensions; bias zeroed.
    - ``nn.BatchNorm2d`` / ``nn.GroupNorm``: weight set to 1 and bias to 0 when
      those affine parameters exist (they are None with ``affine=False``).

    Other modules are left untouched.

    Returns:
        ``net`` itself (modified in place).
    """
    # The gain of xavier_normal_ is computed from gain = magnitude * sqrt(3), where
    # magnitude is 2 / (1 + 0.25**2) [mxnet implementation].
    gain = 3.26033
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight.data, gain=gain)
            if getattr(m, 'bias', None) is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            # Affine-less norm layers have weight/bias == None.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            std = gain * np.sqrt(2 / (m.weight.shape[0] + m.weight.shape[1]))
            nn.init.normal_(m.weight, 0, std)
            if getattr(m, 'bias', None) is not None:
                nn.init.zeros_(m.bias)
    return net
def network_weight_xavier_init(net: nn.Module):
    """Initialize the conv / linear / norm layers of ``net`` in place ("xavier").

    - ``nn.Conv2d``: Xavier-normal weights with the ReLU gain
      (``nn.init.calculate_gain('relu')``); bias zeroed.
    - ``nn.Linear``: weights ~ N(0, 3.26033 * sqrt(2 / (d0 + d1))), where d0, d1
      are the two weight dimensions; bias zeroed.
    - ``nn.BatchNorm2d`` / ``nn.GroupNorm``: weight set to 1 and bias to 0 when
      those affine parameters exist (they are None with ``affine=False``).

    Other modules are left untouched.

    Returns:
        ``net`` itself (modified in place).
    """
    conv_gain = nn.init.calculate_gain('relu')
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight.data, gain=conv_gain)
            if getattr(m, 'bias', None) is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            # Affine-less norm layers have weight/bias == None.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            # NOTE(review): linear layers use the same 3.26033 gain as the MSRAPrelu
            # scheme rather than the ReLU gain used for convs above - confirm intended.
            std = 3.26033 * np.sqrt(2 / (m.weight.shape[0] + m.weight.shape[1]))
            nn.init.normal_(m.weight, 0, std)
            if getattr(m, 'bias', None) is not None:
                nn.init.zeros_(m.bias)
    return net
def _network_weight_fan_scaled_init(net: nn.Module, scale: float):
    """Shared body of the 'stupid' / 'zero' / '01' initializers.

    Conv2d and Linear weights are drawn from N(0, 1), divided by sqrt(fan) and
    multiplied by ``scale``; their biases are zeroed. BatchNorm2d / GroupNorm
    weights are set to 1 and biases to 0 when those affine parameters exist.

    PyTorch lays out Conv2d weights as (out_channels, in_channels // groups, kH, kW)
    and Linear weights as (out_features, in_features), so ``fan`` below is the
    output-side fan (dim 0 times the kernel area).
    NOTE(review): earlier variable names suggested fan-in was intended; the
    numerics are kept as fan-out - confirm.
    """
    with torch.no_grad():
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                out_channels, _, k1, k2 = m.weight.shape
                fan = k1 * k2 * out_channels
            elif isinstance(m, nn.Linear):
                out_features, _ = m.weight.shape
                fan = out_features
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Affine-less norm layers have weight/bias == None.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
                continue
            else:
                continue
            m.weight[:] = torch.randn(m.weight.shape, device=m.weight.device) / np.sqrt(fan) * scale
            if getattr(m, 'bias', None) is not None:
                nn.init.zeros_(m.bias)
    return net


def network_weight_stupid_init(net: nn.Module):
    """Initialize ``net`` in place with unit-scaled fan-normalized Gaussian weights; returns ``net``."""
    return _network_weight_fan_scaled_init(net, 1.0)


def network_weight_zero_init(net: nn.Module):
    """Like ``network_weight_stupid_init`` but with weights scaled by 1e-3 (near-zero); returns ``net``."""
    return _network_weight_fan_scaled_init(net, 1e-3)


def network_weight_01_init(net: nn.Module):
    """Like ``network_weight_stupid_init`` but with weights scaled by 0.1; returns ``net``."""
    return _network_weight_fan_scaled_init(net, 0.1)
def mixup(input, target, alpha=0.2):
    """Mix each sample with a randomly chosen partner from the same batch (in place).

    One coefficient ``gamma ~ Beta(alpha, alpha)`` and one random permutation of
    dim 0 are drawn per call. Both tensors become
    ``gamma * x + (1 - gamma) * x[perm]``.

    Args:
        input: batch tensor indexed by sample along dim 0; modified in place.
        target: per-sample targets, which must be one-hot / soft labels for the
            mixing to be meaningful; modified in place.
        alpha: parameter of the symmetric Beta distribution.

    Returns:
        ``(input, target)``: the same tensor objects, now mixed.
    """
    gamma = np.random.beta(alpha, alpha)
    perm = torch.randperm(input.size(0))
    # Advanced indexing returns copies, so the in-place updates below do not
    # affect these permuted versions.
    perm_input = input[perm]
    perm_target = target[perm]
    # Keyword ``alpha=`` form: the positional add_(scalar, tensor) overload is deprecated.
    input.mul_(gamma).add_(perm_input, alpha=1 - gamma)
    target.mul_(gamma).add_(perm_target, alpha=1 - gamma)
    return input, target
def one_hot(y, num_classes, smoothing_eps=None):
    """Encode integer labels as float one-hot vectors, optionally label-smoothed.

    Args:
        y: integer class-label tensor.
        num_classes: size of the appended one-hot dimension.
        smoothing_eps: if not None, the hot entry becomes
            ``1 - eps + eps / num_classes`` and every other entry ``eps / num_classes``.

    Returns:
        Float tensor with one extra trailing dimension of size ``num_classes``.
    """
    encoded = F.one_hot(y, num_classes).float()
    if smoothing_eps is None:
        return encoded
    off_value = smoothing_eps / float(num_classes)
    on_value = 1 - smoothing_eps + off_value
    return encoded * (on_value - off_value) + off_value
def cross_entropy(logit, target):
    """Mean cross-entropy between logits and one-hot / soft targets.

    Args:
        logit: unnormalized class scores, classes along dim 1.
        target: per-class target weights matching ``logit`` (one-hot or soft
            labels, not class indices).

    Returns:
        Scalar tensor: the batch mean of ``-sum_c target[c] * log_softmax(logit)[c]``.
    """
    log_probs = F.log_softmax(logit, dim=1)
    per_sample = (target * log_probs).sum(dim=1)
    return -per_sample.mean()
def config_dist_env_and_opt(opt):
    """Return a shallow copy of ``opt`` with device / distributed settings resolved.

    ``opt.dist_mode`` decides how ``gpu``, ``world_size`` and ``rank`` are set:

    * ``'cpu'``     - no GPU, single process.
    * ``'single'``  - single process on ``opt.gpu`` (GPU 0 when unset).
    * ``'auto'``    - single process on the GPU chosen by ``global_utils.AutoGPU``.
    * ``'mpi'``     - rank / size from ``MPI.COMM_WORLD``, GPU from ``AutoGPU``.
    * ``'horovod'`` - initializes Horovod and pins the GPU to ``hvd.local_rank()``.

    Every non-CPU mode selects the CUDA device and enables cuDNN benchmark mode.
    Afterwards, ``batch_size`` (if None) becomes ``batch_size_per_gpu * world_size``,
    and ``lr`` / ``target_lr`` (if None) are scaled linearly from their
    per-256-sample values.

    Raises:
        ValueError: if ``opt.dist_mode`` is not one of the modes above.
    """
    opt = copy.copy(opt)
    mode = opt.dist_mode

    if mode == 'cpu':
        opt.gpu = None
        opt.world_size, opt.rank = 1, 0
    elif mode == 'single' or mode == 'auto':
        if mode == 'auto':
            opt.AutoGPU = global_utils.AutoGPU()
            opt.gpu = opt.AutoGPU.gpu
        elif opt.gpu is None:
            opt.gpu = 0
        opt.world_size, opt.rank = 1, 0
        torch.cuda.set_device(opt.gpu)
    elif mode == 'mpi':
        from mpi4py import MPI
        world = MPI.COMM_WORLD
        rank_in_world = world.Get_rank()
        opt.world_size = world.Get_size()
        opt.rank = rank_in_world
        opt.AutoGPU = global_utils.AutoGPU()
        opt.gpu = opt.AutoGPU.gpu
        torch.cuda.set_device(opt.gpu)
    elif mode == 'horovod':
        hvd.init()
        # Pin this process to its node-local GPU.
        opt.gpu = hvd.local_rank()
        torch.cuda.set_device(opt.gpu)
        opt.world_size = hvd.size()
        opt.rank = hvd.rank()
    else:
        raise ValueError('unknown dist_mode={}'.format(opt.dist_mode))

    if mode != 'cpu':
        torch.backends.cudnn.benchmark = True

    # Fill in only the values left unset, scaling linearly with the global batch size.
    if opt.batch_size is None:
        opt.batch_size = opt.batch_size_per_gpu * opt.world_size
    if opt.lr is None:
        opt.lr = opt.lr_per_256 * opt.batch_size / 256.0
    if opt.target_lr is None:
        opt.target_lr = opt.target_lr_per_256 * opt.batch_size / 256.0
    return opt
def init_model(model, opt, argv):
    """Initialize ``model`` per ``opt.weight_init`` and apply BatchNorm overrides.

    ``opt.weight_init`` selects one of: ``'xavier'``, ``'MSRAPrelu'``, ``'stupid'``,
    ``'zero'``, ``'01'`` (the matching ``network_weight_*_init`` function),
    ``'custom'`` (calls ``model.init_parameters()``) or the string ``'None'``
    (no initialization, only a warning is logged). If ``opt.bn_momentum`` /
    ``opt.bn_eps`` are present and not None, they are assigned to every
    ``nn.BatchNorm2d`` layer.

    Args:
        model: the network to initialize (modified in place).
        opt: options namespace.
        argv: unused here.

    Returns:
        ``model``.

    Raises:
        ValueError: if ``opt.weight_init`` is missing or unrecognised.
    """
    scheme = getattr(opt, 'weight_init', None)
    if scheme == 'xavier':
        network_weight_xavier_init(model)
    elif scheme == 'MSRAPrelu':
        network_weight_MSRAPrelu_init(model)
    elif scheme == 'stupid':
        network_weight_stupid_init(model)
    elif scheme == 'zero':
        network_weight_zero_init(model)
    elif scheme == '01':
        network_weight_01_init(model)
    elif scheme == 'custom':
        assert hasattr(model, 'init_parameters')
        model.init_parameters()
    elif scheme == 'None':
        logging.info('Warning!!! model loaded without initialization !')
    else:
        raise ValueError('Unknown weight_init')

    momentum = getattr(opt, 'bn_momentum', None)
    eps = getattr(opt, 'bn_eps', None)
    for layer in model.modules():
        if not isinstance(layer, nn.BatchNorm2d):
            continue
        if momentum is not None:
            layer.momentum = momentum
        if eps is not None:
            layer.eps = eps
    return model
def get_optimizer(model, opt):
    """Create the optimizer named by ``opt.optimizer`` for ``model``.

    Parameters are grouped with ``split_weights`` (its second group carries
    ``weight_decay=0``). Supported names:

    * ``'sgd'``      - ``opt.momentum``, ``opt.nesterov``
    * ``'adadelta'`` - rho / eps from ``opt.adadelta_rho`` / ``opt.adadelta_eps``
    * ``'adam'``
    * ``'rmsprop'``  - alpha fixed at 0.9, ``opt.momentum``

    All use ``opt.lr`` and ``opt.weight_decay``.

    Raises:
        ValueError: for an unrecognised optimizer name.
    """
    groups = split_weights(model)
    kind = opt.optimizer
    if kind == 'sgd':
        return torch.optim.SGD(groups, opt.lr, momentum=opt.momentum,
                               weight_decay=opt.weight_decay, nesterov=opt.nesterov)
    if kind == 'adadelta':
        return torch.optim.Adadelta(groups, opt.lr, rho=opt.adadelta_rho,
                                    eps=opt.adadelta_eps, weight_decay=opt.weight_decay)
    if kind == 'adam':
        return torch.optim.Adam(groups, opt.lr, weight_decay=opt.weight_decay)
    if kind == 'rmsprop':
        return torch.optim.RMSprop(groups, opt.lr, alpha=0.9, momentum=opt.momentum,
                                   weight_decay=opt.weight_decay)
    raise ValueError('Unknown optimizer: ' + opt.optimizer)
def load_model(model, load_parameters_from, strict_load=False, map_location='cpu'):
    """Load parameters from ``load_parameters_from`` into ``model`` and return it.

    The file is read with ``torch.load(map_location=...)``. If the loaded object
    contains a ``'state_dict'`` key, that entry is used; otherwise the object
    itself is treated as the state dict.
    NOTE: torch.load may unpickle arbitrary objects - only load trusted files.

    Args:
        model: module receiving the parameters.
        load_parameters_from: path of the parameter file.
        strict_load: forwarded as ``strict`` to ``load_state_dict``.
        map_location: forwarded to ``torch.load``.
    """
    logging.info('loading params from ' + load_parameters_from)
    loaded = torch.load(load_parameters_from, map_location=map_location)
    state_dict = loaded['state_dict'] if 'state_dict' in loaded else loaded
    model.load_state_dict(state_dict, strict=strict_load)
    return model
def resume_checkpoint(model, optimizer, checkpoint_filename, opt, map_location='cpu'):
    """Restore model / optimizer state from a checkpoint written by training.

    The checkpoint must contain ``'state_dict'`` (loaded strictly), ``'optimizer'``,
    ``'epoch'`` and ``'training_status_info'``. ``opt.start_epoch`` is set to the
    epoch after the saved one.

    Returns:
        ``(model, optimizer, training_status_info, opt)``.
    """
    logging.info('resuming from ' + checkpoint_filename)
    ckpt = torch.load(checkpoint_filename, map_location=map_location)
    assert 'state_dict' in ckpt
    model.load_state_dict(ckpt['state_dict'], strict=True)
    optimizer.load_state_dict(ckpt['optimizer'])
    opt.start_epoch = ckpt['epoch'] + 1
    return model, optimizer, ckpt['training_status_info'], opt
def config_model_optimizer_hvd_and_apex(model, optimizer, opt):
    """Wrap ``model`` / ``optimizer`` for Horovod and/or apex AMP as configured.

    For data-parallel Horovod (``dist_mode == 'horovod'`` and not
    ``independent_training``):

    * the optimizer is wrapped in ``hvd.DistributedOptimizer``, with fp16
      compression if ``opt.fp16_allreduce`` and ``opt.batches_per_allreduce``
      backward passes per step;
    * the parameters and optimizer state are broadcast from rank 0.

    With ``opt.apex`` both are passed through ``amp.initialize`` using
    ``opt.apex_opt_level`` and ``opt.apex_loss_scale`` (a number or ``'dynamic'``).

    Returns:
        ``(model, optimizer)``, possibly wrapped.
    """
    use_horovod = opt.dist_mode == 'horovod' and not opt.independent_training
    if use_horovod:
        compression = hvd.Compression.fp16 if opt.fp16_allreduce else hvd.Compression.none
        optimizer = hvd.DistributedOptimizer(optimizer,
                                             named_parameters=model.named_parameters(),
                                             compression=compression,
                                             backward_passes_per_step=opt.batches_per_allreduce)
        # Start every worker from rank 0's parameters and optimizer state.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    if opt.apex:
        loss_scale = 'dynamic' if opt.apex_loss_scale == 'dynamic' else float(opt.apex_loss_scale)
        model, optimizer = amp.initialize(model, optimizer, opt_level=opt.apex_opt_level, loss_scale=loss_scale)
    return model, optimizer
def _train_make_targets(input, target, opt):
    """Return ``(input, training_target)`` after optional label smoothing and/or mixup.

    With neither option the integer ``target`` is returned unchanged. Label
    smoothing yields a smoothed one-hot target (eps=0.1). Mixup mixes the
    (possibly smoothed) one-hot target together with ``input``; ``mixup``
    modifies ``input`` in place.
    """
    use_label_smoothing = bool(getattr(opt, 'label_smoothing', False))
    use_mixup = bool(getattr(opt, 'mixup', False))
    if not (use_label_smoothing or use_mixup):
        return input, target
    with torch.no_grad():
        smoothing_eps = 0.1 if use_label_smoothing else None
        transformed_target = one_hot(target, num_classes=opt.num_classes, smoothing_eps=smoothing_eps)
        if use_mixup:
            input, transformed_target = mixup(input, transformed_target)
    return input, transformed_target


def _train_backward_and_step(loss, model, optimizer, opt):
    """Zero gradients, backpropagate ``loss``, optionally clip, and step the optimizer.

    For data-parallel Horovod, the gradient allreduce is finished with
    ``optimizer.synchronize()`` before apex unscales the gradients (on exit from
    ``amp.scale_loss``) and before clipping. The step then runs under
    ``optimizer.skip_synchronize()``.
    """
    # Only an hvd.DistributedOptimizer (installed unless independent_training)
    # provides synchronize()/skip_synchronize().
    hvd_sync = opt.dist_mode == 'horovod' and not opt.independent_training
    optimizer.zero_grad()
    if opt.apex:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
            if hvd_sync:
                optimizer.synchronize()
    else:
        loss.backward()
        if hvd_sync:
            optimizer.synchronize()
    if opt.grad_clip is not None:
        # Under apex, the optimizer owns the (fp32) master params; clip those.
        clip_params = amp.master_params(optimizer) if opt.apex else model.parameters()
        torch.nn.utils.clip_grad_value_(clip_params, opt.grad_clip)
    if hvd_sync:
        with optimizer.skip_synchronize():
            optimizer.step()
    else:
        optimizer.step()


def train_one_epoch(train_loader, model, criterion, optimizer, epoch, opt, num_train_samples, no_acc_eval=False):
    """Train ``model`` for one pass over ``train_loader``.

    Every iteration sets the learning rate of all param groups from a
    ``global_utils.LearningRateScheduler``. The scheduler is first advanced by
    ``epoch * num_train_samples`` samples so the schedule resumes at this epoch.
    It is then advanced by the global (or, with ``independent_training``, the
    local) batch size each step.

    Args:
        no_acc_eval: if True, skip the per-batch top-1 / top-5 computation.

    Returns:
        dict with ``'losses'``: the sample-weighted mean training loss of the
        epoch, summed over workers first for data-parallel Horovod training.
    """
    info = {}
    losses = AverageMeter('Loss ', ':6.4g')
    top1 = AverageMeter('Acc@1 ', ':6.2f')
    top5 = AverageMeter('Acc@5 ', ':6.2f')
    # switch to train mode
    model.train()
    lr_scheduler = global_utils.LearningRateScheduler(mode=opt.lr_mode,
                                                      lr=opt.lr,
                                                      num_training_instances=num_train_samples,
                                                      target_lr=opt.target_lr,
                                                      stop_epoch=opt.epochs,
                                                      warmup_epoch=opt.warmup,
                                                      stage_list=opt.lr_stage_list,
                                                      stage_decay=opt.lr_stage_decay)
    lr_scheduler.update_lr(batch_size=epoch * num_train_samples)
    optimizer.zero_grad()

    for i, (input, target) in enumerate(train_loader):
        if opt.independent_training:
            lr_scheduler.update_lr(batch_size=input.shape[0])
        else:
            lr_scheduler.update_lr(batch_size=input.shape[0] * opt.world_size)
        current_lr = lr_scheduler.get_lr()
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr

        if opt.dist_mode != 'cpu':
            input = input.cuda(opt.gpu, non_blocking=True)
            target = target.cuda(opt.gpu, non_blocking=True)
        input, transformed_target = _train_make_targets(input, target, opt)

        # compute output
        output = model(input)
        loss = criterion(output, transformed_target)

        # measure accuracy (against the original integer labels) and record loss
        input_size = int(input.size(0))
        if no_acc_eval:
            acc1 = [0]
            acc5 = [0]
        else:
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(float(acc1[0]), input_size)
            top5.update(float(acc5[0]), input_size)
        losses.update(float(loss), input_size)

        _train_backward_and_step(loss, model, optimizer, opt)

        if i % opt.print_freq == 0 and opt.rank == 0:
            logging.info('Train epoch={}, i={}, loss={:4g}, acc1={:4g}%, acc5={:4g}%, lr={:4g}'.format(
                epoch, i, float(loss), float(acc1[0]), float(acc5[0]), current_lr
            ))

    # if distributed, sync
    if opt.dist_mode == 'horovod' and (not opt.independent_training):
        sync_tensor = torch.tensor([losses.sum, losses.count], dtype=torch.float32)
        # hvd.allreduce returns a new tensor rather than reducing in place. The name
        # is distinct from the one validate() uses for its differently-shaped tensor.
        sync_tensor = hvd.allreduce(sync_tensor, average=False, name='sync_tensor_train_loss')
        losses_avg = (sync_tensor[0] / sync_tensor[1]).item()
    else:
        losses_avg = losses.avg
    info['losses'] = losses_avg
    return info
def validate(val_loader, model, criterion, opt, epoch='N/A'):
    """Evaluate ``model`` on ``val_loader`` and return top-1 / top-5 accuracy.

    The model is put in eval mode and run under ``torch.no_grad()``. When label
    smoothing or mixup is enabled in ``opt``, the loss target is the plain
    (unsmoothed) one-hot encoding of the labels. ``criterion`` may be None, in
    which case the logged loss is 0. For data-parallel Horovod, the accuracy
    sums and counts are summed over all workers before averaging.

    Returns:
        ``{'top1_acc': float, 'top5_acc': float}`` in percent.
    """
    losses = AverageMeter('Loss ', ':6.4g')
    top1 = AverageMeter('Acc@1 ', ':6.2f')
    top5 = AverageMeter('Acc@5 ', ':6.2f')
    wants_one_hot = (hasattr(opt, 'label_smoothing') and opt.label_smoothing) or (hasattr(opt, 'mixup') and opt.mixup)
    on_gpu = opt.dist_mode != 'cpu'

    model.eval()
    with torch.no_grad():
        for step, (images, labels) in enumerate(val_loader):
            if wants_one_hot:
                loss_target = one_hot(labels, num_classes=opt.num_classes, smoothing_eps=None)
            else:
                loss_target = labels
            if on_gpu:
                images = images.cuda(opt.gpu, non_blocking=True)
                labels = labels.cuda(opt.gpu, non_blocking=True)
                loss_target = loss_target.cuda(opt.gpu, non_blocking=True)

            output = model(images)
            loss = torch.tensor([0]) if criterion is None else criterion(output, loss_target)

            acc1, acc5 = accuracy(output, labels, topk=(1, 5))
            batch_n = int(images.size(0))
            losses.update(float(loss), batch_n)
            top1.update(float(acc1[0]), batch_n)
            top5.update(float(acc5[0]), batch_n)

            if step % opt.print_freq == 0 and opt.rank == 0:
                logging.info('Eval epoch={}, i={}, loss={:4g}, acc1={:4g}%, acc5={:4g}%'.format(
                    epoch, step, float(loss), float(acc1[0]), float(acc5[0])))

    top1_acc_avg, top5_acc_avg, total_val_count = top1.avg, top5.avg, top1.count
    if opt.dist_mode == 'horovod' and not opt.independent_training:
        local_totals = torch.tensor([top1.sum, top1.count, top5.sum, top5.count], dtype=torch.float32)
        totals = hvd.allreduce(local_totals, average=False, name='sync_tensor_topk_acc')
        top1_acc_avg = (totals[0] / totals[1]).item()
        top5_acc_avg = (totals[2] / totals[3]).item()
        total_val_count = totals[1].item()

    logging.info(' * Validate Acc@1 {:.3f} Acc@5 {:.3f}, n_val={}'.format(top1_acc_avg, top5_acc_avg, total_val_count))
    return {'top1_acc': top1_acc_avg, 'top5_acc': top5_acc_avg}
def train_all_epochs(opt, model, optimizer, train_sampler, train_loader, criterion, val_loader, num_train_samples=None,
                     no_acc_eval=False, save_all_ranks=False, training_status_info=None, save_params=True):
    """Train for epochs ``opt.start_epoch`` .. ``opt.epochs - 1``, validating and checkpointing.

    After each epoch the model is evaluated on ``val_loader`` (accuracies count
    as 0 when it is None), and the best top-1 / top-5 accuracies and their
    epochs are tracked in ``training_status_info``. On rank 0 (or every rank if
    ``save_all_ranks``), ``latest-params_rank{rank}.pth`` is written every
    ``opt.save_freq`` epochs and after the final epoch. ``best-params_rank{rank}.pth``
    is written whenever top-1 accuracy improves.

    Args:
        num_train_samples: training samples per epoch, used by the LR schedule
            and the speed estimate. Defaults to ``len(train_loader.dataset)``,
            falling back to ``len(train_loader)`` when the loader has no sized
            dataset.
        training_status_info: state restored from a checkpoint, or None to start
            fresh.
        save_params: if False, no checkpoints are written.

    Returns:
        The updated ``training_status_info`` dict.
    """
    timer_start = time.time()
    if training_status_info is None:
        training_status_info = {
            'best_acc1': 0,
            'best_acc5': 0,
            'best_acc1_at_epoch': 0,
            'best_acc5_at_epoch': 0,
            'training_elasped_time': 0,
            'validation_elasped_time': 0,
        }
    if num_train_samples is None:
        # len(DataLoader) counts batches; the LR scheduler and speed estimate need samples.
        try:
            num_train_samples = len(getattr(train_loader, 'dataset', None))
        except TypeError:
            num_train_samples = len(train_loader)

    for epoch in range(opt.start_epoch, opt.epochs):
        logging.info('--- Start training epoch {}'.format(epoch))
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        training_timer_start = time.time()
        train_one_epoch(train_loader, model, criterion, optimizer, epoch, opt, num_train_samples,
                        no_acc_eval=no_acc_eval)
        training_status_info['training_elasped_time'] += time.time() - training_timer_start

        # evaluate on validation set
        if val_loader is not None:
            validation_timer_start = time.time()
            validate_info = validate(val_loader, model, criterion, opt, epoch=epoch)
            training_status_info['validation_elasped_time'] += time.time() - validation_timer_start
            acc1 = validate_info['top1_acc']
            acc5 = validate_info['top5_acc']
        else:
            acc1 = 0
            acc5 = 0

        # remember best accuracies
        is_best_acc1 = acc1 > training_status_info['best_acc1']
        is_best_acc5 = acc5 > training_status_info['best_acc5']
        if is_best_acc1:
            training_status_info['best_acc1'] = acc1
            training_status_info['best_acc1_at_epoch'] = epoch
        if is_best_acc5:
            training_status_info['best_acc5'] = acc5
            training_status_info['best_acc5_at_epoch'] = epoch

        elapsed_seconds = time.time() - timer_start
        elasped_hour = elapsed_seconds / 3600
        # Average wall time per epoch so far times the epochs still to run after this one.
        epochs_done = epoch - opt.start_epoch + 1
        epochs_remaining = opt.epochs - epoch - 1
        remaining_hour = elapsed_seconds / float(epochs_done) * epochs_remaining / 3600
        logging.info(
            '--- Epoch={}, Elasped hour={:8.4g}, Remaining hour={:8.4g}, Training Speed={:4g},'
            ' best_acc1={:4g}, best_acc1_at_epoch={}, best_acc5={}, best_acc5_at_epoch={}'.format(
                epoch, elasped_hour, remaining_hour,
                num_train_samples * (epoch + 1) / float(training_status_info['training_elasped_time'] + 1e-8),
                training_status_info['best_acc1'], training_status_info['best_acc1_at_epoch'],
                training_status_info['best_acc5'], training_status_info['best_acc5_at_epoch']
            ))

        may_save = save_params and (opt.rank == 0 or save_all_ranks)
        # ----- save latest epoch -----#
        if may_save and ((epoch + 1) % opt.save_freq == 0 or epoch + 1 == opt.epochs):
            checkpoint_filename = os.path.join(opt.save_dir, 'latest-params_rank{}.pth'.format(opt.rank))
            save_checkpoint(checkpoint_filename, {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'top1_acc': acc1,
                'top5_acc': acc5,
                'training_status_info': training_status_info
            })
        # ----- save best parameters -----#
        if may_save and is_best_acc1:
            checkpoint_filename = os.path.join(opt.save_dir, 'best-params_rank{}.pth'.format(opt.rank))
            save_checkpoint(checkpoint_filename, {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'top1_acc': acc1,
                'top5_acc': acc5,
                'training_status_info': training_status_info
            })
    return training_status_info
def main(opt, argv):
    """Train (or only evaluate) an image classifier as configured by ``opt``.

    Steps:

    1. Return immediately if save_dir/train_image_classification.done exists.
    2. Resolve device / distributed settings and set up logging (log file on
       rank 0 only).
    3. Build the data loaders, model, loss and optimizer, and apply the
       Horovod / apex wrappers.
    4. Optionally resume from a checkpoint.
    5. Run ``train_all_epochs``, or a single ``validate`` pass when
       ``opt.evaluate_only``.
    6. Write ``training_status_info`` to the ``.done`` file.

    Args:
        opt: parsed command-line options.
        argv: raw command-line arguments, forwarded to the data / model loaders.
    """
    assert opt.save_dir is not None
    job_done_fn = os.path.join(opt.save_dir, 'train_image_classification.done')
    if os.path.isfile(job_done_fn):
        print('skip ' + job_done_fn)
        return
    opt = config_dist_env_and_opt(opt)

    # create log: full log file on rank 0, errors only on other ranks
    if opt.rank == 0:
        log_filename = os.path.join(opt.save_dir, 'train_image_classification.log')
        global_utils.create_logging(log_filename=log_filename)
    else:
        global_utils.create_logging(log_filename=None, level=logging.ERROR)
    logging.info('argv=\n' + str(argv))
    logging.info('opt=\n' + str(opt))
    logging.info('-----')

    # load dataset
    data_loader_info = DataLoader.get_data(opt, argv)
    train_loader = data_loader_info['train_loader']
    val_loader = data_loader_info['val_loader']
    train_sampler = data_loader_info['train_sampler']
    num_train_samples = DataLoader.params_dict[opt.dataset]['num_train_samples']

    # create model
    model = ModelLoader.get_model(opt, argv)
    model = init_model(model, opt, argv)
    logging.info('loading model:')
    logging.info(str(model))
    if opt.load_parameters_from:
        model = load_model(model, opt.load_parameters_from, opt.strict_load, map_location='cpu')
    if opt.fp16:
        model = model.half()

    # set device
    if opt.gpu is not None:
        torch.cuda.set_device(opt.gpu)
        model.cuda(opt.gpu)
        logging.info('rank={}, using GPU {}'.format(opt.rank, opt.gpu))

    # define loss function (criterion). ``cross_entropy`` is a plain function with
    # no parameters, so only the nn.Module loss is moved to the GPU; calling
    # .cuda() on the function would raise AttributeError.
    if (hasattr(opt, 'label_smoothing') and opt.label_smoothing) or (hasattr(opt, 'mixup') and opt.mixup):
        criterion = cross_entropy
    else:
        criterion = nn.CrossEntropyLoss()
        if opt.dist_mode != 'cpu':
            criterion = criterion.cuda(opt.gpu)

    # get optimizer
    optimizer = get_optimizer(model, opt)
    logging.info('optimizer is :')
    logging.info(str(optimizer))

    # hvd and apex
    model, optimizer = config_model_optimizer_hvd_and_apex(model, optimizer, opt)

    training_status_info = {
        'best_acc1': 0,
        'best_acc5': 0,
        'best_acc1_at_epoch': 0,
        'best_acc5_at_epoch': 0,
        'training_elasped_time': 0,
        'validation_elasped_time': 0,
    }

    map_location = 'cpu' if opt.gpu is None else 'cuda:{}'.format(opt.gpu)
    if opt.auto_resume and opt.resume is None:
        latest_pth_fn = os.path.join(opt.save_dir, 'latest-params_rank0.pth')
        if os.path.isfile(latest_pth_fn):
            logging.info('auto-resume from ' + latest_pth_fn)
            model, optimizer, training_status_info, opt = resume_checkpoint(model, optimizer, latest_pth_fn, opt,
                                                                            map_location=map_location)
    if opt.resume:
        assert not opt.auto_resume
        logging.info('resume from ' + opt.resume)
        model, optimizer, training_status_info, opt = resume_checkpoint(model, optimizer, opt.resume, opt,
                                                                        map_location=map_location)

    if not opt.evaluate_only:
        training_status_info = train_all_epochs(opt, model, optimizer, train_sampler, train_loader, criterion, val_loader,
                                                num_train_samples=num_train_samples,
                                                no_acc_eval=False, save_all_ranks=False,
                                                training_status_info=training_status_info)
    else:
        validate(val_loader, model, criterion, opt)

    # mark job done
    global_utils.save_pyobj(job_done_fn, training_status_info)
    # NOTE: a GPU auto-assigned in dist_mode='auto' is expected to be released by
    # global_utils.AutoGPU itself, so no explicit release happens here.
if __name__ == "__main__":
    # Script entry point: parse the command line and hand the options to main().
    cli_args = sys.argv
    opt = global_utils.parse_cmd_options(cli_args)
    main(opt, cli_args)