forked from shirgur/PointerNet
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Train.py
107 lines (81 loc) · 3.42 KB
/
Train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python3
"""
Pytorch implementation of Pointer Network.
http://arxiv.org/pdf/1506.03134v1.pdf.
"""
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import numpy as np
import argparse
from tqdm import tqdm
from PointerNet import PointerNet
from Data_Generator import TSPDataset
# Command-line configuration for training. All options are parsed from
# sys.argv when the script starts; the resulting namespace is `params`.
parser = argparse.ArgumentParser(description="Pytorch implementation of Pointer-Net")

# Data
# NOTE(review): --val_size and --test_size are parsed but not used anywhere
# in this script.
parser.add_argument('--train_size', default=1000000, type=int, help='Training data size')
parser.add_argument('--val_size', default=10000, type=int, help='Validation data size')
parser.add_argument('--test_size', default=10000, type=int, help='Test data size')
parser.add_argument('--batch_size', default=256, type=int, help='Batch size')

# Train
parser.add_argument('--nof_epoch', default=50000, type=int, help='Number of epochs')
parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')

# GPU
# `--gpu` is a store_true flag whose default is already True, so by itself it
# can never be turned off. `--no_gpu` writes False into the same destination;
# `--gpu` is still accepted so existing command lines keep working.
parser.add_argument('--gpu', default=True, action='store_true', help='Enable gpu')
parser.add_argument('--no_gpu', dest='gpu', action='store_false',
                    help='Disable gpu (train on CPU)')

# TSP
parser.add_argument('--nof_points', type=int, default=5, help='Number of points in TSP')

# Network
parser.add_argument('--embedding_size', type=int, default=128, help='Embedding size')
parser.add_argument('--hiddens', type=int, default=512, help='Number of hidden units')
parser.add_argument('--nof_lstms', type=int, default=2, help='Number of LSTM layers')
parser.add_argument('--dropout', type=float, default=0., help='Dropout value')
# Same problem as --gpu: `--bidir` defaults to True, so provide an explicit
# off switch that writes False into the same destination.
parser.add_argument('--bidir', default=True, action='store_true', help='Bidirectional')
parser.add_argument('--no_bidir', dest='bidir', action='store_false',
                    help='Disable bidirectional mode')

params = parser.parse_args()
# Training runs on CUDA only when the user asked for it (params.gpu) AND a
# CUDA device is actually available; otherwise everything stays on the CPU.
USE_CUDA = bool(params.gpu and torch.cuda.is_available())
if USE_CUDA:
    print('Using GPU, %i devices.' % torch.cuda.device_count())
# --- Model and data ---------------------------------------------------------

model = PointerNet(params.embedding_size,
                   params.hiddens,
                   params.nof_lstms,
                   params.dropout,
                   params.bidir)

dataset = TSPDataset(params.train_size,
                     params.nof_points)

dataloader = DataLoader(dataset,
                        batch_size=params.batch_size,
                        shuffle=True,
                        num_workers=4)

# `net` is the callable the training loop uses for the forward pass. On CPU it
# is the bare model. On CUDA it is a DataParallel wrapper, which splits each
# batch across all visible devices. The loop must call `net` rather than
# `model`, otherwise the wrapper is bypassed and only one GPU does any work.
net = model
if USE_CUDA:
    model.cuda()
    net = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
    cudnn.benchmark = True

# --- Loss and optimiser -----------------------------------------------------

CCE = torch.nn.CrossEntropyLoss()
# Hand Adam only the parameters that require gradients. Any parameters the
# model registers with requires_grad=False are skipped.
model_optim = optim.Adam((p for p in model.parameters() if p.requires_grad),
                         lr=params.lr)

# --- Training ---------------------------------------------------------------

# Mean training loss of each completed epoch. Storing one float per epoch,
# rather than one per batch, keeps memory bounded: with the default sizes a
# run is roughly 3.9k batches per epoch for 50k epochs.
losses = []

for epoch in range(params.nof_epoch):
    batch_loss = []
    iterator = tqdm(dataloader, unit='Batch',
                    desc='Epoch %i/%i' % (epoch + 1, params.nof_epoch))

    for sample_batched in iterator:
        # Plain tensors suffice. Variable has been a no-op wrapper since
        # PyTorch 0.4, and this script already requires 0.4+ (loss.item()).
        train_batch = sample_batched['Points']
        target_batch = sample_batched['Solution']

        if USE_CUDA:
            train_batch = train_batch.cuda()
            target_batch = target_batch.cuda()

        o, p = net(train_batch)
        # Collapse every leading dimension so each row of `o` is scored
        # against one target index; the last dimension of `o` is the class
        # axis. NOTE(review): CrossEntropyLoss applies log-softmax itself and
        # expects unnormalised scores -- confirm that PointerNet's first
        # output is not already a softmax distribution.
        o = o.reshape(-1, o.size(-1))
        target_batch = target_batch.view(-1)

        loss = CCE(o, target_batch)

        model_optim.zero_grad()
        loss.backward()
        model_optim.step()

        # .item() copies the scalar to the host; do it once per batch.
        loss_value = loss.item()
        batch_loss.append(loss_value)
        iterator.set_postfix(loss='{}'.format(loss_value))

    # tqdm closes the bar as soon as the loop above is exhausted, so a
    # set_postfix() here would never be rendered. Report the epoch mean as a
    # separate line instead. Skip empty epochs, where the mean is undefined.
    if batch_loss:
        epoch_loss = np.average(batch_loss)
        losses.append(epoch_loss)
        tqdm.write('Epoch %i/%i average loss: %s' % (epoch + 1, params.nof_epoch, epoch_loss))