-
Notifications
You must be signed in to change notification settings - Fork 4
/
language_task.py
executable file
·374 lines (307 loc) · 13.3 KB
/
language_task.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
import argparse
import time
import math
import torch
import torch.nn as nn
import torch.optim as optim
from utils import select_network,select_optimizer
import numpy as np
import os
import pickle
from datetime import datetime
class Dictionary(object):
    """Bidirectional vocabulary: token <-> dense integer id.

    Ids are handed out in first-seen order starting at 0.
    """

    def __init__(self):
        # token -> id
        self.word2idx = {}
        # id -> token; the list position is the id
        self.idx2word = []

    def add_word(self, word):
        """Register ``word`` if it is new and return its id."""
        idx = self.word2idx.get(word)
        if idx is None:
            idx = len(self.idx2word)
            self.word2idx[word] = idx
            self.idx2word.append(word)
        return idx

    def __len__(self):
        """Return the number of distinct tokens seen so far."""
        return len(self.idx2word)
class Corpus(object):
    """Tokenized train/valid/test splits that share one ``Dictionary``.

    ``path`` must contain ``train.txt``, ``valid.txt`` and ``test.txt``.
    Each whitespace-separated token becomes one id, and every line is
    terminated by an ``<eos>`` token.
    """

    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenize a text file into a 1-D LongTensor of token ids.

        Unseen tokens are added to ``self.dictionary`` as they are read.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        # Explicit check rather than ``assert``: asserts vanish under
        # ``python -O`` and carry no message.
        if not os.path.exists(path):
            raise FileNotFoundError('Corpus file not found: {}'.format(path))
        ids = []
        # One pass is enough: add_word returns the (new or existing) id, so
        # the vocabulary and the id sequence are built together instead of
        # reading the file twice and writing a tensor element by element.
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                for word in line.split() + ['<eos>']:
                    ids.append(self.dictionary.add_word(word))
        return torch.tensor(ids, dtype=torch.long)
class RNNModel(nn.Module):
    """Language model: embedding -> step-wise recurrent cell -> linear decoder.

    ``rnn`` is invoked once per time step as ``rnn(x_t, hidden)`` and must
    return the new hidden state, which doubles as that step's output.

    Args:
        rnn: recurrent cell module.
        ntoken: vocabulary size.
        ninp: embedding size (the cell's input size).
        nhid: hidden size produced by the cell.
        tie_weights: share the embedding matrix with the decoder weight
            (requires ``nhid == ninp``).
    """

    def __init__(self, rnn, ntoken, ninp, nhid, tie_weights=False):
        super(RNNModel, self).__init__()
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = rnn
        self.decoder = nn.Linear(nhid, ntoken)
        print(rnn)
        # Optional weight tying, as in Press & Wolf 2016
        # (https://arxiv.org/abs/1608.05859) and Inan et al. 2016
        # (https://arxiv.org/abs/1611.01462).
        if tie_weights and nhid != ninp:
            raise ValueError('When using the tied flag, nhid must be equal to emsize')
        if tie_weights:
            self.decoder.weight = self.encoder.weight
        self.init_weights()
        self.nhid = nhid

    def init_weights(self):
        """Draw embedding/decoder weights from U(-0.1, 0.1); zero the decoder bias."""
        bound = 0.1
        with torch.no_grad():
            self.encoder.weight.uniform_(-bound, bound)
            self.decoder.bias.zero_()
            self.decoder.weight.uniform_(-bound, bound)

    def forward(self, input, hidden):
        """Run the model over a (seq_len, batch) tensor of token ids.

        Returns:
            ``(logits, hidden)``: logits of shape (seq_len, batch, ntoken) and
            the hidden state after the final step.
        """
        embedded = self.encoder(input)
        states = []
        for x_t in embedded:
            hidden = self.rnn(x_t, hidden)
            states.append(hidden)
        output = torch.stack(states)
        seq_len, batch = output.size(0), output.size(1)
        logits = self.decoder(output.view(seq_len * batch, output.size(2)))
        return logits.view(seq_len, batch, logits.size(1)), hidden
# Command-line configuration for the language-modelling run.
parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model')
parser.add_argument('--net-type', type=str, default='nnRNN',
                    help='rnn net type')
parser.add_argument('--emsize', type=int, default=200,
                    help='size of word embeddings')
parser.add_argument('--nhid', type=int, default=1024,
                    help='number of hidden units per layer')
parser.add_argument('--epochs', type=int, default=100,
                    help='upper epoch limit')
parser.add_argument('--bptt', type=int, default=150,
                    help='sequence length')
parser.add_argument('--cuda', action='store_true', default=False, help='use cuda')
parser.add_argument('--tied', action='store_true', default=False, help='For tie weights')
parser.add_argument('--random-seed', type=int, default=400,
                    help='random seed')
parser.add_argument('--batch', type=int, default=128, metavar='N',
                    help='batch size')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                    help='report interval')
# The script re-saves the model whenever validation loss improves, so this
# file holds the best model, not the last one.
parser.add_argument('--save', type=str, default='model.pt',
                    help='file name for the best (lowest validation loss) model')
# NOTE(review): --lr, --lr_orth, --optimizer, --alpha and --betas are
# presumably consumed by utils.select_optimizer (not visible here) — confirm.
parser.add_argument('--lr', type=float, default=0.0008)
parser.add_argument('--lr_orth', type=float, default=8e-5)
parser.add_argument('--rinit', type=str, default="cayley",
                    choices=['random', 'cayley', 'henaff', 'xavier'],
                    help='recurrent weight matrix initialization')
parser.add_argument('--iinit', type=str, default="kaiming",
                    choices=['xavier', 'kaiming'],
                    help='input weight matrix initialization' )
parser.add_argument('--nonlin', type=str, default='modrelu',
                    choices=['none','modrelu', 'tanh', 'relu', 'sigmoid'],
                    help='non linearity none, modrelu, tanh, relu, sigmoid')
parser.add_argument('--alam', type=float, default=1, help='decay for gamma values nnRNN')
parser.add_argument('--Tdecay', type=float,
                    default=0.0001, help='weight decay on upper T')
parser.add_argument('--optimizer', type=str, default='RMSprop', choices=['RMSprop', 'Adam'])
parser.add_argument('--alpha', type=float, default=0.9)
# Exactly two values, matching the (0.0, 0.9) default; nargs=2 rejects a
# wrong count at parse time instead of failing later in the optimizer.
parser.add_argument('--betas', type=float, default=(0.0, 0.9), nargs=2)
args = parser.parse_args()
# Seed NumPy and PyTorch (and CUDA, when it is in use) for reproducibility.
np.random.seed(args.random_seed)
torch.manual_seed(args.random_seed)
if torch.cuda.is_available() and args.cuda:
    torch.cuda.manual_seed(args.random_seed)
elif torch.cuda.is_available():
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")
###############################################################################
# Load data
###############################################################################
corpus = Corpus('./data/pennchar/')
def batchify(data, bsz, cuda=None):
    """Arrange a 1-D token stream into ``bsz`` parallel columns.

    The stream is trimmed to a multiple of ``bsz`` and reshaped so that
    column ``j`` holds the ``j``-th contiguous slice of the stream. The
    result has shape ``(len(data) // bsz, bsz)``, with time along dim 0.

    Args:
        data: 1-D tensor of token ids.
        bsz: number of columns (batch size); must be positive.
        cuda: move the result to the GPU. ``None`` (default) uses the
            ``--cuda`` command-line flag.

    Raises:
        ValueError: if ``bsz`` is not positive or ``data`` has fewer than
            ``bsz`` tokens.
    """
    if bsz <= 0:
        raise ValueError('batch size must be positive, got {}'.format(bsz))
    nbatch = data.size(0) // bsz
    if nbatch == 0:
        raise ValueError('cannot split {} tokens into {} columns'.format(data.size(0), bsz))
    # Drop the tail that would not fill a complete row.
    data = data.narrow(0, 0, nbatch * bsz)
    # view(bsz, -1) cuts the stream into bsz contiguous chunks; the transpose
    # turns each chunk into a column so that dim 0 is time.
    data = data.view(bsz, -1).t().contiguous()
    if cuda is None:
        cuda = args.cuda
    if cuda:
        data = data.cuda()
    return data
# Training uses the --batch size; validation and test always use 10 columns.
eval_batch_size = 10
train_data = batchify(corpus.train, args.batch)
val_data, test_data = (batchify(split, eval_batch_size)
                       for split in (corpus.valid, corpus.test))
###############################################################################
# Build the model
###############################################################################
ntokens = len(corpus.dictionary)
# Module-level aliases for command-line options.
NET_TYPE = args.net_type
inp_size = args.emsize
hid_size = args.nhid
alam = args.alam
CUDA = args.cuda
nonlin = args.nonlin
# select_network (from utils) returns the recurrent cell wrapped by RNNModel.
rnn = select_network(args, inp_size)
model = RNNModel(rnn, ntokens, inp_size, hid_size, tie_weights=args.tied)
if CUDA:
    model.cuda()
print('Language Task')
print(NET_TYPE)
print(args)
# Cross-entropy over the decoder logits (default mean reduction).
criterion = nn.CrossEntropyLoss()
###############################################################################
# Training code
###############################################################################
def get_batch(source, i, evaluation=False):
    """Slice one BPTT window out of a batchified tensor.

    Returns ``(data, target)``: ``data`` is up to ``args.bptt`` rows starting
    at row ``i``, and ``target`` is the same rows shifted one step ahead,
    flattened to 1-D. Near the end of ``source`` the window is shortened so
    that every input row has a next-row target.

    ``evaluation`` is accepted for call-site compatibility but is unused.
    """
    window = min(args.bptt, len(source) - 1 - i)
    stop = i + window
    data = source[i:stop]
    target = source[i + 1:stop + 1].view(-1)
    return data, target
def evaluate(data_source):
    """Evaluate the global ``model`` on a batchified split.

    Runs the model over ``data_source`` one BPTT window at a time. The hidden
    state is carried across windows and detached after each one.

    Args:
        data_source: tensor produced by ``batchify`` (time along dim 0).

    Returns:
        ``(loss, accuracy)``: mean cross-entropy per predicted time step, and
        the fraction of targets whose arg-max prediction is correct.

    Raises:
        ValueError: if ``data_source`` has fewer than two time steps, so there
            is nothing to predict.
    """
    # Turn on evaluation mode (disables dropout, if the model uses any).
    model.eval()
    ntokens = len(corpus.dictionary)
    total_loss = 0.0
    steps = 0
    correct = 0
    processed = 0
    hidden = None
    # No gradients are needed here; without no_grad every window would still
    # record an autograd graph.
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, args.bptt):
            data, targets = get_batch(data_source, i, evaluation=True)
            if i == 0 and NET_TYPE == 'LSTM':
                model.rnn.init_states(data.shape[1])
            output, hidden = model(data, hidden)
            output_flat = output.view(-1, ntokens)
            # criterion averages over the window; weighting by the window
            # length gives a per-time-step mean over the whole split.
            total_loss += len(data) * criterion(output_flat, targets).item()
            steps += len(data)
            correct += torch.eq(torch.argmax(output_flat, dim=1), targets).sum().item()
            processed += targets.shape[0]
            hidden = hidden.detach()
            if NET_TYPE == 'LSTM':
                model.rnn.ct = model.rnn.ct.detach()
    if steps == 0:
        raise ValueError('data_source needs at least two time steps to evaluate')
    # Normalize by the number of predicted steps (len(data_source) - 1), not
    # by len(data_source): the last row is only ever used as a target.
    return total_loss / steps, correct / processed
def train(optimizer, orthog_optimizer):
    """Run one epoch of truncated-BPTT training over ``train_data``.

    Reads the module-level ``epoch`` for log lines only.

    Args:
        optimizer: optimizer whose ``step()`` updates the model.
        orthog_optimizer: optional second optimizer handed to
            ``model.rnn.orthogonal_step``; falsy to skip that step.

    Returns:
        Mean task loss (cross-entropy only; the nnRNN ``alpha_loss`` penalty
        is excluded) over every batch of the epoch, or NaN if ``train_data``
        yields no batches.
    """
    # Turn on training mode (enables dropout, if the model uses any).
    model.train()
    ntokens = len(corpus.dictionary)
    num_batches = len(train_data) // args.bptt
    hidden = None
    epoch_loss = 0.0
    epoch_batches = 0
    window_loss = 0.0
    window_batches = 0
    start_time = time.time()
    for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
        data, targets = get_batch(train_data, i)
        # At the start of each batch, detach the hidden state from how it was
        # produced; otherwise backprop would run to the start of the dataset.
        if hidden is not None:
            hidden = hidden.detach()
            if NET_TYPE == 'LSTM':
                model.rnn.ct = model.rnn.ct.detach()
            elif NET_TYPE == 'nnRNN':
                # Presumably rebuilds the recurrent matrix from the parameters
                # updated by the previous optimizer step — confirm in utils.
                model.rnn.calc_rec()
        if i == 0 and NET_TYPE == 'LSTM':
            model.rnn.init_states(data.shape[1])
        model.zero_grad()
        output, hidden = model(data, hidden)
        loss_act = criterion(output.view(-1, ntokens), targets)
        # Out-of-place add on purpose: ``loss += ...`` on a tensor is an
        # in-place add_ that would also modify ``loss_act`` and leak the
        # penalty into the reported loss / ppl / bpc.
        loss = loss_act
        if NET_TYPE == 'nnRNN' and alam > 0:
            loss = loss + model.rnn.alpha_loss(alam)
        loss.backward()
        if orthog_optimizer:
            model.rnn.orthogonal_step(orthog_optimizer)
        optimizer.step()

        batch_loss = loss_act.item()
        epoch_loss += batch_loss
        epoch_batches += 1
        window_loss += batch_loss
        window_batches += 1
        if batch % args.log_interval == 0 and batch > 0:
            # Average over the batches actually accumulated since the last
            # report (the first window also contains batch 0).
            cur_loss = window_loss / window_batches
            elapsed = time.time() - start_time
            # Report the optimizer's current lr, which the schedulers change.
            cur_lr = optimizer.param_groups[0]['lr']
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:.2e} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                      epoch, batch, num_batches, cur_lr,
                      elapsed * 1000 / window_batches, cur_loss, math.exp(cur_loss),
                      cur_loss / math.log(2)))
            window_loss = 0.0
            window_batches = 0
            start_time = time.time()
    if epoch_batches == 0:
        return float('nan')
    return epoch_loss / epoch_batches
# Optimizers and learning-rate schedules for the epoch loop.
lr, decay = args.lr, args.Tdecay
best_val_loss = None
optimizer, orthog_optimizer = select_optimizer(model, args)
# Every scheduler.step() halves the learning rate (step_size=1, gamma=0.5).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
if orthog_optimizer:
    orthog_scheduler = optim.lr_scheduler.StepLR(orthog_optimizer, step_size=1, gamma=0.5)
# Per-run output directory: ./saves/PTB/<net type>/<seed>/<timestamp>.
# It is created before the try so the Ctrl+C path below can always find it.
exp_time = "{0:%Y-%m-%d}_{0:%H-%M-%S}".format(datetime.now())
SAVEDIR = os.path.join('./saves',
                       'PTB',
                       NET_TYPE,
                       str(args.random_seed),
                       exp_time)
os.makedirs(SAVEDIR, exist_ok=True)
# Every output path is built with os.path.join so files land *inside*
# SAVEDIR (``SAVEDIR + name`` has no separator and writes beside it).
best_model_path = os.path.join(SAVEDIR, args.save)
with open(os.path.join(SAVEDIR, 'hparams.txt'), 'w') as fp:
    for key, val in args.__dict__.items():
        # One "key: value" pair per line.
        fp.write('{}: {}\n'.format(key, val))


def _load_full_model(path):
    """Load a model that was pickled with ``torch.save(model, f)``.

    torch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which
    refuses to unpickle a whole ``nn.Module``. Older releases may not accept
    the keyword, so it is passed only when ``torch.load`` supports it.
    Full unpickling can execute code: only use this on checkpoints written
    by this script.
    """
    import inspect
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    with open(path, 'rb') as f:
        return torch.load(f, **load_kwargs)


# At any point you can hit Ctrl + C to break out of training early.
try:
    tr_losses = []
    v_losses = []
    v_accs = []
    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()
        loss = train(optimizer, orthog_optimizer)
        tr_losses.append(loss)
        val_loss, val_acc = evaluate(val_data)
        v_losses.append(val_loss)
        v_accs.append(val_acc)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
              'valid ppl {:8.2f} | valid bpc {:8.3f} | valid accuracy {:5.2f} '.format(epoch, (time.time() - epoch_start_time),
                                                                                         val_loss, math.exp(val_loss), val_loss / math.log(2), val_acc))
        print('-' * 89)
        # Save the model if the validation loss is the best we've seen so far.
        if best_val_loss is None or val_loss < best_val_loss:
            with open(best_model_path, 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        else:
            # Anneal the learning rate if no improvement has been seen in the validation dataset.
            scheduler.step()
            if orthog_optimizer:
                orthog_scheduler.step()
        # Persist the curves every epoch so an interrupted run keeps them.
        with open(os.path.join(SAVEDIR, '{}_Train_Losses'.format(NET_TYPE)), 'wb') as fp:
            pickle.dump(tr_losses, fp)
        with open(os.path.join(SAVEDIR, '{}_Val_Losses'.format(NET_TYPE)), 'wb') as fp:
            pickle.dump(v_losses, fp)
        with open(os.path.join(SAVEDIR, '{}_Val_Accs'.format(NET_TYPE)), 'wb') as fp:
            pickle.dump(v_accs, fp)
except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

# Load the best saved model. If training stopped before the first checkpoint
# was written, fall back to the model currently in memory.
if os.path.exists(best_model_path):
    model = _load_full_model(best_model_path)
else:
    print('No checkpoint found at {}; testing the current model'.format(best_model_path))
# Run on test data.
test_loss, test_accuracy = evaluate(test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}, test bpc {:8.2f} | test acc {:8.2f}'.format(
    test_loss, math.exp(test_loss), test_loss / math.log(2), test_accuracy))
print('=' * 89)
with open(os.path.join(SAVEDIR, 'testdat.txt'), 'w') as fp:
    fp.write('Test loss: {} Test Accuracy: {}'.format(test_loss, test_accuracy))