-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_yolo.py
48 lines (39 loc) · 1.7 KB
/
train_yolo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os

import numpy as np
import chainer
from chainer.training import extensions

from tiny_model import YOLOv2, YOLOv2Predictor
from image_dataset import DatasetYOLO
# ---------------------------------------------------------------------------
# Training configuration (edit these values directly; there is no CLI).
# ---------------------------------------------------------------------------
n_classes_fcn = 7        # forwarded to YOLOv2(n_classes_fcn=...)
n_classes_yolo = 2       # forwarded to YOLOv2(n_classes_yolo=...)
n_boxes = 5              # forwarded to YOLOv2(n_boxes=...)
gpu = 0                  # device id; a negative value keeps the model on the CPU
epoch = 100              # training length, in epochs
batchsize = 3
out_path = 'result/yolo-x5/'                     # Trainer output directory
initial_weight_file = 'result/fcn-x4/final.npz'  # falsy value -> start from scratch
weight_decay = 1e-5
use_weight_decay = False  # the WeightDecay hook was disabled in the original script
test = False              # True -> also run an Evaluator once per epoch
snapshot_interval = 100   # epochs between model snapshots (when enabled)
use_snapshots = False     # snapshotting was disabled in the original script

# ---------------------------------------------------------------------------
# Model.
# ---------------------------------------------------------------------------
yolov2 = YOLOv2(n_classes_fcn=n_classes_fcn,
                n_classes_yolo=n_classes_yolo,
                n_boxes=n_boxes)
model = YOLOv2Predictor(yolov2, FCN=False)
if initial_weight_file:
    # Warm-start from a previous run's weights (the path suggests an FCN run
    # -- confirm that its parameter names match this predictor).
    chainer.serializers.load_npz(initial_weight_file, model)

# The `device` argument given to StandardUpdater/Evaluator below transfers the
# input batches; the model parameters have to be moved explicitly, otherwise
# the forward pass mixes CPU parameters with GPU inputs. Do it before
# optimizer.setup() so optimizer state is created on the same device.
if gpu >= 0:
    model.to_gpu(gpu)

# ---------------------------------------------------------------------------
# Optimizer.
# ---------------------------------------------------------------------------
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)
# Clear gradients with cleargrads() instead of zerograds(); newer Chainer
# releases do this by default and may report this call as deprecated.
optimizer.use_cleargrads()
if use_weight_decay:
    optimizer.add_hook(chainer.optimizer.WeightDecay(weight_decay))

# ---------------------------------------------------------------------------
# Data.
# ---------------------------------------------------------------------------
train_data = DatasetYOLO()
train_iter = chainer.iterators.SerialIterator(train_data, batchsize)
if test:
    # NOTE(review): built with the same (default) arguments as train_data, so
    # evaluation presumably runs on the training set -- confirm, or point this
    # at a held-out split.
    test_data = DatasetYOLO()
    test_iter = chainer.iterators.SerialIterator(
        test_data, batchsize, repeat=False, shuffle=False)

# ---------------------------------------------------------------------------
# Trainer and reporting.
# ---------------------------------------------------------------------------
updater = chainer.training.StandardUpdater(train_iter, optimizer, device=gpu)
trainer = chainer.training.Trainer(updater, (epoch, 'epoch'), out=out_path)

report_keys = ['epoch', 'main/loss', 'main/x_loss', 'main/y_loss',
               'main/w_loss', 'main/h_loss', 'main/c_loss', 'main/p_loss']
if test:
    trainer.extend(extensions.Evaluator(test_iter, model, device=gpu))
    # The Evaluator reports under the 'validation/' prefix; print its loss too,
    # otherwise the evaluation result never reaches the console.
    report_keys.append('validation/main/loss')
report_keys.append('elapsed_time')

trainer.extend(extensions.LogReport(), trigger=(1, 'epoch'))
trainer.extend(extensions.PrintReport(report_keys), trigger=(1, 'epoch'))
if use_snapshots:
    trainer.extend(
        extensions.snapshot_object(model, 'model_snapshot_yolo_{.updater.epoch}'),
        trigger=(snapshot_interval, 'epoch'))
trainer.extend(extensions.ProgressBar(update_interval=10))
trainer.run()

# os.path.join does not depend on out_path ending with a '/'.
chainer.serializers.save_npz(os.path.join(out_path, 'final.npz'), model)