from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
def weight_variable(shape):
    # Weights drawn from a truncated normal to break symmetry.
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    # Small positive bias to avoid dead ReLUs.
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def conv2d(x, W):
    # 2-D convolution, stride 1, zero-padded to keep the spatial size.
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    # 2x2 max pooling halves each spatial dimension.
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
sess = tf.InteractiveSession()
################ input: both models share the same input ################
x = tf.placeholder(tf.float32, [None, 784])
x_image = tf.reshape(x, [-1, 28, 28, 1])
######################### the first model ########################################
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
# After two 2x2 poolings, a 28x28 image is 7x7 with 64 channels.
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
y_ = tf.placeholder(tf.float32, [None, 10])
# Cross-entropy needs the minus sign, otherwise the loss is maximized.
cross_entropy1 = -tf.reduce_sum(y_ * tf.log(y_conv))
train_step1 = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy1)
correct_prediction1 = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy1 = tf.reduce_mean(tf.cast(correct_prediction1, tf.float32))
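# Aside (a numerically safer variant, not in the original post): taking
# tf.log of a softmax output can underflow to -inf; keeping the logits and
# using tf.nn.softmax_cross_entropy_with_logits avoids that, roughly:
#   logits1 = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
#   cross_entropy1 = tf.reduce_mean(
#       tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits1))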
# (variables are initialized once, below, after both models are defined)
################## the second parallel model, sharing the input of the first ##################
W2_conv1 = weight_variable([5, 5, 1, 32])
b2_conv1 = bias_variable([32])
# Reuse the x / x_image placeholder from above; redefining it would create a
# second, unfed input and orphan the first model's graph.
h2_conv1 = tf.nn.relu(conv2d(x_image, W2_conv1) + b2_conv1)
h2_pool1 = max_pool_2x2(h2_conv1)
W2_conv2 = weight_variable([5, 5, 32, 64])
b2_conv2 = bias_variable([64])
h2_conv2 = tf.nn.relu(conv2d(h2_pool1, W2_conv2) + b2_conv2)
h2_pool2 = max_pool_2x2(h2_conv2)
W2_fc1 = weight_variable([7*7*64, 1024])
b2_fc1 = bias_variable([1024])
h2_pool2_flat = tf.reshape(h2_pool2, [-1, 7*7*64])
h2_fc1 = tf.nn.relu(tf.matmul(h2_pool2_flat, W2_fc1) + b2_fc1)
# keep_prob is also shared, so one feed value drives both dropout layers.
h2_fc1_drop = tf.nn.dropout(h2_fc1, keep_prob)
W2_fc2 = weight_variable([1024, 10])
b2_fc2 = bias_variable([10])
y2_ = tf.placeholder(tf.float32, [None, 10])
y2_conv = tf.nn.softmax(tf.matmul(h2_fc1_drop, W2_fc2) + b2_fc2)
cross_entropy2 = -tf.reduce_sum(y2_ * tf.log(y2_conv))
train_step2 = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy2)
correct_prediction2 = tf.equal(tf.argmax(y2_conv, 1), tf.argmax(y2_, 1))
accuracy2 = tf.reduce_mean(tf.cast(correct_prediction2, tf.float32))
sess.run(tf.global_variables_initializer())
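# Aside (optional, not in the original post): wrapping each branch in its own
# tf.variable_scope, e.g.
#   with tf.variable_scope('model2'):
#       W2_conv1 = weight_variable([5, 5, 1, 32])
# namespaces the variables so the two models can be inspected or saved separately.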
##########################################################
for i in range(20000):
    batch = mnist.train.next_batch(50)
    if i % 100 == 0:
        acc1, acc2 = sess.run([accuracy1, accuracy2],
                              feed_dict={x: batch[0], y_: batch[1],
                                         y2_: batch[1], keep_prob: 1.0})
        print('step %d, train accuracy %g / %g' % (i, acc1, acc2))
    # Operation.run() takes only a feed_dict, not a fetch list; to step both
    # models at once, fetch the two train ops in a single sess.run call.
    sess.run([train_step1, train_step2],
             feed_dict={x: batch[0], y_: batch[1], y2_: batch[1], keep_prob: 0.5})
acc1, acc2 = sess.run([accuracy1, accuracy2],
                      feed_dict={x: mnist.test.images, y_: mnist.test.labels,
                                 y2_: mnist.test.labels, keep_prob: 1.0})
print('test accuracy %g / %g' % (acc1, acc2))
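For what it's worth, if the two branches should always be stepped together, a single optimizer over the summed loss does the same job in one op. A minimal sketch, assuming cross_entropy1 and cross_entropy2 as defined above (since the two branches share no variables, the gradients each branch receives are the same as with two separate optimizers):

# Sketch: one joint training op over both branches.
total_loss = cross_entropy1 + cross_entropy2
joint_train_step = tf.train.AdamOptimizer(1e-4).minimize(total_loss)
# sess.run(joint_train_step, feed_dict={x: batch[0], y_: batch[1],
#                                       y2_: batch[1], keep_prob: 0.5})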