This repository has been archived by the owner on Jun 16, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 230
/
policy.py
208 lines (177 loc) · 9.64 KB
/
policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
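'''
Neural network architecture.

(Excerpted from the Methods section of the AlphaGo paper: Silver et al.,
"Mastering the game of Go with deep neural networks and tree search",
Nature, 2016.)

The input to the policy network is a 19 x 19 x 48 image stack consisting of
48 feature planes. The first hidden layer zero pads the input into a 23 x 23
image, then convolves k filters of kernel size 5 x 5 with stride 1 with the
input image and applies a rectifier nonlinearity. Each of the subsequent
hidden layers 2 to 12 zero pads the respective previous hidden layer into a
21 x 21 image, then convolves k filters of kernel size 3 x 3 with stride 1,
again followed by a rectifier nonlinearity. The final layer convolves 1 filter
of kernel size 1 x 1 with stride 1, with a different bias for each position,
and applies a softmax function. The match version of AlphaGo used k = 192
filters; Fig. 2b and Extended Data Table 3 additionally show the results
of training with k = 128, 256 and 384 filters.

The input to the value network is also a 19 x 19 x 48 image stack, with an
additional binary feature plane describing the current colour to play.
Hidden layers 2 to 11 are identical to the policy network, hidden layer 12
is an additional convolution layer, hidden layer 13 convolves 1 filter of
kernel size 1 x 1 with stride 1, and hidden layer 14 is a fully connected
linear layer with 256 rectifier units. The output layer is a fully connected
linear layer with a single tanh unit.

How this module differs from the description above:
- Only the policy network is implemented; there is no value network here.
- The defaults are much smaller: k = 32 filters and 3 intermediate 3 x 3
  layers, instead of k = 192 and 11 layers.
- Padding uses TensorFlow's "SAME" mode. For these stride-1, odd-sized
  kernels it is equivalent to the explicit zero padding described above.
- The board size comes from go.N and the input depth is the sum of the
  selected feature planes. Neither is fixed at 19 or 48.
'''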
import math
import os
import sys
import tensorflow as tf
import features
import go
import utils
# Tiny positive constant (1e-35).
# NOTE(review): EPSILON is not referenced anywhere else in this module;
# it may be imported by other modules. Confirm before removing.
EPSILON = 1e-35
class PolicyNetwork(object):
    """Convolutional policy network mapping encoded board positions to move probabilities.

    Architecture:
    - All convolutions use stride 1 and "SAME" padding.
    - The hidden layers use ReLU and are:
      - one 5x5 convolution with ``k`` filters;
      - ``num_int_conv_layers`` 3x3 convolutions with ``k`` filters.
    - A final 1x1 single-filter convolution plus a separate bias per board
      point produces ``go.N ** 2`` logits, which are softmaxed.

    Training minimizes softmax cross-entropy with Adam (learning rate 1e-4).

    Every non-underscore local built in ``set_up_network`` is copied onto the
    instance under the same name (e.g. ``x``, ``y``, ``logits``, ``output``,
    ``train_step``, ``accuracy``, ``saver``).
    """

    def __init__(self, features=features.DEFAULT_FEATURES, k=32, num_int_conv_layers=3, use_cpu=False):
        """Open a TensorFlow session and build the network graph.

        Args:
            features: feature extractors used to encode positions. Each item
                must have a ``planes`` attribute; the input depth is their sum.
            k: number of filters in the 5x5 layer and in each 3x3 layer.
            num_int_conv_layers: number of intermediate 3x3 conv layers
                (0 is allowed).
            use_cpu: if True, build the graph under ``tf.device("/cpu:0")``.
        """
        self.num_input_planes = sum(f.planes for f in features)
        self.features = features
        self.k = k
        self.num_int_conv_layers = num_int_conv_layers
        # Summary writers stay None until initialize_logging() is called;
        # train()/check_accuracy() skip TensorBoard output while they are None.
        self.test_summary_writer = None
        self.training_summary_writer = None
        self.test_stats = StatisticsCollector()
        self.training_stats = StatisticsCollector()
        self.session = tf.Session()
        if use_cpu:
            with tf.device("/cpu:0"):
                self.set_up_network()
        else:
            self.set_up_network()

    def set_up_network(self):
        """Build the policy graph and store its tensors/ops as attributes of ``self``.

        Every local variable defined here whose name does not start with an
        underscore is copied onto the instance at the end. Local names in
        this method are therefore part of the object's attribute surface.
        """
        # a global_step variable allows epoch counts to persist through multiple training sessions
        global_step = tf.Variable(0, name="global_step", trainable=False)
        # Input feature planes, channels last: [batch, go.N, go.N, planes].
        x = tf.placeholder(tf.float32, [None, go.N, go.N, self.num_input_planes])
        # Target distribution over the flattened board (presumably one-hot
        # expert moves; confirm against the training-data producer).
        y = tf.placeholder(tf.float32, shape=[None, go.N ** 2])

        #convenience functions for initializing weights and biases
        def _weight_variable(shape, name):
            # If shape is [5, 5, 20, 32], then each of the 32 output planes
            # has 5 * 5 * 20 inputs.
            number_inputs_added = utils.product(shape[:-1])
            stddev = 1 / math.sqrt(number_inputs_added)
            # http://neuralnetworksanddeeplearning.com/chap3.html#weight_initialization
            return tf.Variable(tf.truncated_normal(shape, stddev=stddev), name=name)

        def _conv2d(x, W):
            return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding="SAME")

        # initial conv layer is 5x5
        W_conv_init = _weight_variable([5, 5, self.num_input_planes, self.k], name="W_conv_init")
        h_conv_init = tf.nn.relu(_conv2d(x, W_conv_init), name="h_conv_init")

        # followed by a series of 3x3 conv layers
        W_conv_intermediate = []
        h_conv_intermediate = []
        _current_h_conv = h_conv_init
        for i in range(self.num_int_conv_layers):
            with tf.name_scope("layer"+str(i)):
                W_conv_intermediate.append(_weight_variable([3, 3, self.k, self.k], name="W_conv"))
                h_conv_intermediate.append(tf.nn.relu(_conv2d(_current_h_conv, W_conv_intermediate[-1]), name="h_conv"))
                _current_h_conv = h_conv_intermediate[-1]

        # Final 1x1 conv collapses the k channels to one plane. It then adds
        # an independent bias for each of the go.N ** 2 board points.
        W_conv_final = _weight_variable([1, 1, self.k, 1], name="W_conv_final")
        b_conv_final = tf.Variable(tf.constant(0, shape=[go.N ** 2], dtype=tf.float32), name="b_conv_final")
        # Feed the most recent hidden layer. Indexing h_conv_intermediate[-1]
        # here would fail with IndexError when num_int_conv_layers == 0.
        h_conv_final = _conv2d(_current_h_conv, W_conv_final)

        logits = tf.reshape(h_conv_final, [-1, go.N ** 2]) + b_conv_final
        output = tf.nn.softmax(logits)

        log_likelihood_cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
        train_step = tf.train.AdamOptimizer(1e-4).minimize(log_likelihood_cost, global_step=global_step)
        # A prediction counts as correct when the top-scoring point equals the target's argmax.
        was_correct = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(was_correct, tf.float32))

        weight_summaries = tf.summary.merge([
            tf.summary.histogram(weight_var.name, weight_var)
            for weight_var in [W_conv_init] + W_conv_intermediate + [W_conv_final, b_conv_final]],
            name="weight_summaries"
        )
        activation_summaries = tf.summary.merge([
            tf.summary.histogram(act_var.name, act_var)
            for act_var in [h_conv_init] + h_conv_intermediate + [h_conv_final]],
            name="activation_summaries"
        )

        saver = tf.train.Saver()

        # save everything to self.
        # Underscore-prefixed locals (helpers, the running layer) are skipped.
        for name, thing in locals().items():
            if not name.startswith('_'):
                setattr(self, name, thing)

    def initialize_logging(self, tensorboard_logdir):
        """Create TensorBoard writers under ``<tensorboard_logdir>/test`` and ``/training``."""
        self.test_summary_writer = tf.summary.FileWriter(os.path.join(tensorboard_logdir, "test"), self.session.graph)
        self.training_summary_writer = tf.summary.FileWriter(os.path.join(tensorboard_logdir, "training"), self.session.graph)

    def initialize_variables(self, save_file=None):
        """Initialize all global variables, then restore them from ``save_file`` if one is given."""
        self.session.run(tf.global_variables_initializer())
        if save_file is not None:
            self.saver.restore(self.session, save_file)

    def get_global_step(self):
        """Return the current value of the ``global_step`` counter."""
        return self.session.run(self.global_step)

    def save_variables(self, save_file):
        """Write a checkpoint to ``save_file``; do nothing when it is None."""
        if save_file is not None:
            print("Saving checkpoint to %s" % save_file, file=sys.stderr)
            self.saver.save(self.session, save_file)

    def train(self, training_data, batch_size=32):
        """Run ``training_data.data_size // batch_size`` Adam updates.

        Args:
            training_data: object providing ``data_size`` and
                ``get_batch(batch_size)``, which returns ``(batch_x, batch_y)``.
            batch_size: examples per minibatch. Leftover examples that do not
                fill a whole minibatch are not drawn.

        Prints the mean accuracy and cost over the minibatches. If logging
        is initialized, it also writes activation summaries (computed on the
        last minibatch) and accuracy/cost summaries at the current global
        step. If there is not enough data for one minibatch, it prints a
        warning to stderr and returns without training.
        """
        num_minibatches = training_data.data_size // batch_size
        if num_minibatches <= 0:
            # Nothing to average (StatisticsCollector.collect would divide by
            # zero) and no batch to feed the activation summaries.
            print("Skipping training: data_size %s is smaller than batch_size %s"
                  % (training_data.data_size, batch_size), file=sys.stderr)
            return
        for _ in range(num_minibatches):
            batch_x, batch_y = training_data.get_batch(batch_size)
            _, accuracy, cost = self.session.run(
                [self.train_step, self.accuracy, self.log_likelihood_cost],
                feed_dict={self.x: batch_x, self.y: batch_y})
            self.training_stats.report(accuracy, cost)

        avg_accuracy, avg_cost, accuracy_summaries = self.training_stats.collect()
        global_step = self.get_global_step()
        print("Step %d training data accuracy: %g; cost: %g" % (global_step, avg_accuracy, avg_cost))
        if self.training_summary_writer is not None:
            activation_summaries = self.session.run(
                self.activation_summaries,
                feed_dict={self.x: batch_x, self.y: batch_y})
            self.training_summary_writer.add_summary(activation_summaries, global_step)
            self.training_summary_writer.add_summary(accuracy_summaries, global_step)

    def run(self, position):
        """Return the move probabilities for ``position`` as a ``(go.N, go.N)`` array.

        The position is encoded with this network's feature set. A leading
        batch axis of size 1 is added, and the softmax output is reshaped
        back onto the board. The result is not sorted.
        """
        processed_position = features.extract_features(position, features=self.features)
        probabilities = self.session.run(self.output, feed_dict={self.x: processed_position[None, :]})[0]
        return probabilities.reshape([go.N, go.N])

    def check_accuracy(self, test_data, batch_size=128):
        """Evaluate mean accuracy and cost on ``test_data`` without training.

        Args:
            test_data: object providing ``data_size`` and
                ``get_batch(batch_size)``, which returns ``(batch_x, batch_y)``.
            batch_size: examples per evaluation minibatch.

        Prints the results. If logging is initialized, it also writes weight
        histograms and accuracy/cost summaries at the current global step. If
        there is not enough data for one minibatch, it prints a warning to
        stderr and returns.
        """
        num_minibatches = test_data.data_size // batch_size
        if num_minibatches <= 0:
            # Nothing to average; StatisticsCollector.collect would divide by zero.
            print("Skipping evaluation: data_size %s is smaller than batch_size %s"
                  % (test_data.data_size, batch_size), file=sys.stderr)
            return
        weight_summaries = self.session.run(self.weight_summaries)
        for _ in range(num_minibatches):
            batch_x, batch_y = test_data.get_batch(batch_size)
            accuracy, cost = self.session.run(
                [self.accuracy, self.log_likelihood_cost],
                feed_dict={self.x: batch_x, self.y: batch_y})
            self.test_stats.report(accuracy, cost)

        avg_accuracy, avg_cost, accuracy_summaries = self.test_stats.collect()
        global_step = self.get_global_step()
        print("Step %s test data accuracy: %g; cost: %g" % (global_step, avg_accuracy, avg_cost))
        if self.test_summary_writer is not None:
            self.test_summary_writer.add_summary(weight_summaries, global_step)
            self.test_summary_writer.add_summary(accuracy_summaries, global_step)
class StatisticsCollector(object):
    """Average per-batch accuracy/cost in Python and emit them as TF summaries.

    A full dataset has to be evaluated over many minibatches. TensorFlow's
    summary ops only record the value from a single run, so they cannot
    aggregate across those minibatches. This class does the following:
    - buffers the per-batch numbers itself;
    - averages them;
    - pushes the means through scalar-summary placeholders to produce a
      serialized summary that a ``tf.summary.FileWriter`` can record.
    """

    # The summary ops live in their own graph with their own CPU-pinned
    # session. Both are created once, when the class body runs, and are
    # shared by every instance.
    graph = tf.Graph()
    with tf.device("/cpu:0"), graph.as_default():
        accuracy = tf.placeholder(tf.float32, [])
        cost = tf.placeholder(tf.float32, [])
        accuracy_summary = tf.summary.scalar("accuracy", accuracy)
        cost_summary = tf.summary.scalar("log_likelihood_cost", cost)
        accuracy_summaries = tf.summary.merge(
            [accuracy_summary, cost_summary], name="accuracy_summaries")
    session = tf.Session(graph=graph)

    def __init__(self):
        self._clear()

    def _clear(self):
        """Discard every buffered per-batch measurement."""
        self.accuracies = []
        self.costs = []

    def report(self, accuracy, cost):
        """Buffer one batch's accuracy and cost."""
        self.accuracies.append(accuracy)
        self.costs.append(cost)

    def collect(self):
        """Return ``(mean_accuracy, mean_cost, summary)`` and empty the buffers.

        ``summary`` is the result of running the merged accuracy/cost summary
        op with the two means fed in. This raises ZeroDivisionError if nothing
        has been reported since the previous call.
        """
        mean_accuracy = sum(self.accuracies) / len(self.accuracies)
        mean_cost = sum(self.costs) / len(self.costs)
        self._clear()
        feed = {self.accuracy: mean_accuracy, self.cost: mean_cost}
        summary = self.session.run(self.accuracy_summaries, feed_dict=feed)
        return mean_accuracy, mean_cost, summary