train_rnn.py
# encoding: utf-8
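"""Train an RNN (LSTM) sentiment classifier on the IMDB dataset with
TensorFlow eager execution, using the pytfeager helper library for the
model, data loading, and training callbacks."""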
import os
import numpy as np
import tensorflow as tf
import gc
from tensorflow.python.keras.preprocessing import sequence
from pytfeager.io.io import load_imdb
from pytfeager.model.rnn import RNN
from pytfeager.callbacks.model_callbacks import WriteSummary
from pytfeager.callbacks.model_callbacks import EarlyStopping
from pytfeager.callbacks.model_callbacks import ModelCheckpoint
from pytfeager.callbacks.model_callbacks import ProgressBar
from pytfeager.config import rnn_config as config
tf.enable_eager_execution()

# Fix random seeds for reproducibility
tf.set_random_seed(2018)
np.random.seed(2018)

def main(_):
    log_path = os.path.join(config.PATH, config.LOG_PATH)
    checkpoint_dir = os.path.join(config.PATH, config.CHECKPOINT_PATH)

    # Load the IMDB dataset, keeping only the num_words most frequent words
    (x_train, y_train), (x_test, y_test) = load_imdb(path=os.path.join(config.PATH, config.DATA_PATH),
                                                     maxlen=FLAGS.maxlen,
                                                     num_words=FLAGS.num_words)

    # Pad (or truncate) every review to a fixed length of maxlen tokens
    train_sequence = sequence.pad_sequences(x_train, maxlen=FLAGS.maxlen, dtype='int32',
                                            padding='post', truncating='post', value=0)
    test_sequence = sequence.pad_sequences(x_test, maxlen=FLAGS.maxlen, dtype='int32',
                                           padding='post', truncating='post', value=0)

    # Free the raw, un-padded sequences
    del x_train
    del x_test
    gc.collect()

    buffer_size = len(train_sequence)
    n_batch = buffer_size // FLAGS.batch_size
    residual = buffer_size % FLAGS.batch_size  # examples left over in the final partial batch

    # Shuffle and batch the training data; the last partial batch is kept
    train_dataset = tf.data.Dataset.from_tensor_slices((train_sequence, y_train)).shuffle(10000)
    train_dataset = train_dataset.batch(FLAGS.batch_size, drop_remainder=False)

    # The test split doubles as the validation set here
    val = [tf.convert_to_tensor(test_sequence), tf.convert_to_tensor(y_test)]
    # Optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)

    # Initialize the model
    model = RNN(units=FLAGS.lstm_units,
                output_dim=FLAGS.output_dim,
                input_length=FLAGS.maxlen,
                input_dim=FLAGS.num_words,
                trainable=True)

    # Objects the checkpoint callback needs to save
    save_kwargs = {
        'model': model,
        'optimizer': optimizer,
        'global_step': tf.train.get_or_create_global_step()
    }
    # Define callbacks
    write_summary = WriteSummary(log_path=log_path)
    model_checkpoint = ModelCheckpoint(checkpoint_dir=checkpoint_dir, mode='min', monitor='val_loss',
                                       save_best_only=False, epoch_per=3,
                                       **save_kwargs)
    early_stop = EarlyStopping(mode='min', patience=FLAGS.early_stopping_rounds)
    progressbar = ProgressBar(data_size=buffer_size, n_batch=n_batch, batch_size=FLAGS.batch_size,
                              resiud=residual,  # keyword is spelled 'resiud' in pytfeager's ProgressBar
                              eval_name='acc', loss_name='loss')

    # Start training
    model.fit(trainDataset=train_dataset,
              valDataset=val,
              epochs=FLAGS.epochs,
              optimizer=optimizer,
              ModelCheckPoint=model_checkpoint,
              progressbar=progressbar,
              write_summary=write_summary,
              early_stopping=early_stop)
if __name__ == '__main__':
    tf.app.flags.DEFINE_float('learning_rate', 0.01,
                              'Learning rate for the optimizer.')
    tf.app.flags.DEFINE_integer('batch_size', 128,
                                'Number of training examples per batch.')
    tf.app.flags.DEFINE_integer('epochs', 20,
                                'Number of training epochs.')
    tf.app.flags.DEFINE_integer('early_stopping_rounds', 10,
                                'Epochs without improvement before early stopping.')
    tf.app.flags.DEFINE_integer('lstm_units', 128,
                                'Number of units in the LSTM layer.')
    tf.app.flags.DEFINE_integer('maxlen', 400,
                                'Maximum sequence length.')
    tf.app.flags.DEFINE_integer('output_dim', 64,
                                'Output dimension of the embedding layer.')
    tf.app.flags.DEFINE_integer('num_words', 10000,
                                'Vocabulary size (keep the num_words most frequent words).')
    FLAGS = tf.app.flags.FLAGS
    tf.app.run()
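# Example invocation (a sketch: assumes pytfeager is on PYTHONPATH and the
# IMDB data file sits under config.PATH/config.DATA_PATH; every flag is
# optional and falls back to the defaults defined above):
#
#   python train_rnn.py --learning_rate=0.001 --batch_size=64 --epochs=10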