-
Notifications
You must be signed in to change notification settings - Fork 3
/
train.py
51 lines (42 loc) · 1.99 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# from hourglass import *
import tensorflow as tf
import numpy as np
from DataManager.manager import get_batch
import model
def train_model(batch_size, iterations, load=True,
                checkpoint_path='./models/chkpt', save_every=20, log_every=10):
    """Build (or restore) the hourglass model and train it with Adam.

    A single batch is fetched once from ``get_batch`` and reused for every
    step (deliberate overfitting for now).

    Args:
        batch_size: number of samples requested from ``get_batch``.
        iterations: number of optimizer steps to run.
        load: if True, take the graph tensors from ``model.load_model``;
            otherwise initialize all variables with
            ``tf.global_variables_initializer()``.
        checkpoint_path: save prefix passed to ``tf.train.Saver.save``. Its
            parent directory is created if it does not exist.
        save_every: write a checkpoint whenever ``i % save_every == 0``, plus
            once after the last step. Must be a positive int.
        log_every: print the loss whenever ``i % log_every == 0``. Must be a
            positive int.

    Returns:
        The model output tensor returned by ``model.get_model`` (or by
        ``model.load_model`` when ``load`` is True). It is used as the logits
        of the sigmoid cross-entropy loss.
    """
    import os

    # Create model
    step_size = tf.placeholder(tf.float32, name="stepsize")
    inputs = tf.placeholder(tf.float32, name="input", shape=(None, 200, 200, 3))
    labels = tf.placeholder(tf.float32, name="labels", shape=(None, 200, 200, 200))
    hourglass_model = model.get_model(inputs, name='hourglass')
    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=hourglass_model, labels=labels)
    loss = tf.reduce_mean(cross_entropy, name="cross_entropy_loss")
    adam_step = tf.train.AdamOptimizer(step_size, name="optimizer").minimize(loss)
    saver = tf.train.Saver()

    # Saver.save refuses to write when the checkpoint's parent directory is
    # missing, so make sure it exists before training starts.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir and not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # NOTE(review): this depends on `iterations`, not on the step index, so it
    # is constant for the whole run (1e-4 when iterations=5000). If a decaying
    # schedule was intended, it should be computed from the step counter.
    # Confirm intent before changing it.
    learning_rate = 10 ** (-iterations // 1000 + 1)

    # Overfit to only this batch for now
    images, voxels = get_batch(batch_size)
    with tf.Session() as sess:
        if load:
            # NOTE(review): `saver` was built from the graph created above. If
            # load_model imports a separate meta-graph, this saver may not
            # cover the restored variables. Confirm against model.load_model.
            inputs, labels, hourglass_model, loss, step_size, adam_step = model.load_model(sess=sess)
            print("Successfully loaded saved file")
        else:
            sess.run(tf.global_variables_initializer())
            print("Initialized model with random variables")

        # The batch and learning rate are fixed, so the feed is built once.
        # Move this into the loop if per-step batches are re-enabled.
        feed_dict = {inputs: images, labels: voxels, step_size: learning_rate}
        for i in range(iterations):
            try:
                err, _ = sess.run([loss, adam_step], feed_dict=feed_dict)
            except ValueError as e:
                # Best-effort: a failed step is reported and skipped, and no
                # checkpoint is written for it.
                print("Random error optimizing, don't know what's wrong. Just skipping this epoch.\n%s" % str(e))
                continue
            if i % log_every == 0:
                print("Loss: %i, %f " % (i, err))
            if i % save_every == 0:
                saver.save(sess, checkpoint_path)

        # Persist the final steps that fall after the last periodic save.
        if iterations > 0 and (iterations - 1) % save_every != 0:
            saver.save(sess, checkpoint_path)
    return hourglass_model
if __name__ == "__main__":
    # Script entry point: start from freshly initialized variables (no
    # checkpoint restore), batch_size=10, for 5000 optimizer steps.
    train_model(load=False, iterations=5000, batch_size=10)