Forked from tensorpack/benchmarks — tensorpack.cifar10.py (executable, 65 lines / 54 loc, 2.47 KB).
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: tensorpack.cifar10.py
import tensorflow as tf
from tensorpack import *
class Model(ModelDesc):
    """Small CIFAR-10 convnet expressed as a tensorpack ModelDesc.

    Architecture: two (conv, conv, max-pool, dropout) stages with 32 and
    64 filters respectively, then a 512-unit ReLU fully-connected layer,
    dropout, and a 10-way linear classifier.
    """

    def inputs(self):
        # Declares the graph inputs: NHWC float32 images (batch dim open)
        # and per-sample int32 class labels.
        return [tf.TensorSpec([None, 32, 32, 3], tf.float32, 'input'),
                tf.TensorSpec([None], tf.int32, 'label')]

    def build_graph(self, image, label):
        """Build the forward graph and return the scalar training cost."""
        # Transpose NHWC -> NCHW (matches the data_format argscope below)
        # and rescale pixel values from [0, 255] to [0, 1].
        image = tf.transpose(image, [0, 3, 1, 2])
        image = image / 255.0
        # Defaults for all layers in the chain: 3x3 ReLU convs, VALID
        # padding unless overridden, NCHW layout for conv/pool.
        with argscope(Conv2D, activation=tf.nn.relu, kernel_size=3, padding='VALID'), \
                argscope([Conv2D, MaxPooling], data_format='NCHW'):
            logits = (LinearWrap(image)
                      .Conv2D('conv0', 32, padding='SAME')
                      .Conv2D('conv1', 32)
                      .MaxPooling('pool0', 2)
                      .Dropout(rate=0.25)
                      .Conv2D('conv2', 64, padding='SAME')
                      .Conv2D('conv3', 64)
                      .MaxPooling('pool1', 2)
                      .Dropout(rate=0.25)
                      .FullyConnected('fc1', 512, activation=tf.nn.relu)
                      .Dropout(rate=0.5)
                      .FullyConnected('linear', 10, activation=tf.identity)())
        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
        cost = tf.reduce_mean(cost, name='cost')
        # 1.0 for each mis-classified sample; the named mean 'train_error'
        # is referenced by the ProgressBar callback in the main script.
        wrong = tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, 1)), tf.float32, name='wrong')
        tf.reduce_mean(wrong, name='train_error')
        # no weight decay
        return cost

    def optimizer(self):
        # Learning rate lives in a non-trainable variable so callbacks can
        # adjust it at runtime without rebuilding the graph.
        lr = tf.get_variable('learning_rate', initializer=1e-4, trainable=False)
        return tf.train.RMSPropOptimizer(lr, epsilon=1e-8)
def get_data(train_or_test):
    """Return the batched CIFAR-10 dataflow for the given split.

    ``train_or_test`` is either ``'train'`` or ``'test'``. Batches are
    fixed at 32 samples; during evaluation the final partial batch is
    kept (``remainder=True``) so every sample is seen exactly once.
    """
    is_training = (train_or_test == 'train')
    flow = dataset.Cifar10(train_or_test)
    return BatchData(flow, 32, remainder=not is_training)
if __name__ == '__main__':
    train_flow = get_data('train')
    test_flow = get_data('test')  # only used by the disabled validation callback below

    config = TrainConfig(
        model=Model(),
        data=QueueInput(train_flow,
                        queue=tf.FIFOQueue(300, [tf.float32, tf.int32])),
        # callbacks=[InferenceRunner(test_flow, ClassificationError('wrong'))],  # skip validation
        callbacks=[],
        # Keras monitors these two quantities live during training;
        # mirror that here (no real overhead).
        extra_callbacks=[ProgressBar(['cost', 'train_error']), MergeAllSummaries()],
        max_epoch=200,
    )
    launch_train_with_config(config, SimpleTrainer())