-
Notifications
You must be signed in to change notification settings - Fork 30
/
trainer.py
213 lines (175 loc) · 9.35 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
"""
Author: Mohamed K. Eid ([email protected])
Description: trainer class for training a new generative model
"""
import logging
import os
import time
import urllib
import urllib.request
import zipfile

import numpy as np
import tensorflow as tf

import custom_vgg16 as vgg16
import helpers
class Trainer:
    """Trains a generator network against a single reference style image.

    The trainer builds three VGG16 descriptor networks (style, content and
    generated image), combines content, style and total-variation losses, and
    optimises the variables found in the ``generator`` variable scope with Adam.
    Training images are streamed from ``lib/images/train2014/``; if they are
    missing the user is asked on stdin whether to download and unzip them.
    """

    def __init__(self, session, net, train_path, train_dims, print_training_status=True, print_every_n=100):
        """Store the session, the generator network and the training settings.

        Args:
            session: TensorFlow session used for every ``run`` call.
            net: Generator network; must expose ``build(input)``, ``output``
                and an ``is_training`` attribute.
            train_path: Path of the style image to learn from.
            train_dims: Mapping with ``'height'`` and ``'width'`` entries used
                when loading the style and preview images.
            print_training_status: Whether to log the loss (and render a
                preview image) while training.
            print_every_n: Number of iterations between two status reports.
        """
        self.current_path = os.path.dirname(os.path.realpath(__file__))
        self.net = net
        self.paths = {
            'out_dir': self.current_path + '/../output/',
            'style_file': train_path,
            'trained_generators_dir': self.current_path + '/../lib/generators/',
            'training_dir': self.current_path + '/../lib/images/train2014/',
            'training_url': 'http://msvocds.blob.core.windows.net/coco2014/train2014.zip'
        }
        self.session = session
        self.train_height = train_dims['height']
        self.train_width = train_dims['width']
        self.print_training_status = print_training_status
        self.train_n = print_every_n

    def train(self, epochs, learning_rate, content_layer, content_weight, style_layers, style_weight, tv_weight,
              retrain=False):
        """Build the training graph, run the optimisation loop and save the weights.

        Args:
            epochs: Number of training iterations; each iteration consumes one
                training image (batch size of 1).
            learning_rate: Learning rate handed to the Adam optimiser.
            content_layer: Layer identifier forwarded to ``helpers.get_content_loss``.
            content_weight: Multiplier of the content loss; ``0`` disables it.
            style_layers: Layer identifiers forwarded to ``helpers.get_style_loss``.
            style_weight: Multiplier of the style loss; ``0`` disables it.
            tv_weight: Multiplier of the total-variation loss; ``0`` disables it.
            retrain: When true, continue from the checkpoint previously saved
                for this style image instead of starting from fresh weights.
        """
        # Check if there is training data available before building anything
        self.__check_for_examples()

        # Load the style image and give it a leading batch dimension of 1
        art, art_shape = helpers.load_img_to(self.paths['style_file'], height=self.train_height,
                                             width=self.train_width)
        art_shape = [1] + art_shape
        art = art.reshape(art_shape).astype(np.float32)

        # Generator Network ops
        variable_placeholder = tf.placeholder(dtype=tf.float32, shape=art_shape)
        self.net.build(variable_placeholder)
        variable_img = self.net.output

        # VGG Network ops
        with tf.name_scope('vgg_style'):
            style_model = vgg16.Vgg16()
            style_model.build(art, shape=art_shape[1:])

        with tf.name_scope('vgg_content'):
            content_placeholder = tf.placeholder(dtype=tf.float32, shape=art_shape)
            content_model = vgg16.Vgg16()
            content_model.build(content_placeholder, shape=art_shape[1:])

        with tf.name_scope('vgg_variable'):
            variable_model = vgg16.Vgg16()
            variable_model.build(variable_img, shape=art_shape[1:])

        # The saver is created here, before the optimizer adds its own variables, so it
        # covers the same variables as before. The actual restore is deferred until after
        # the initializers have run; restoring first would let the initializers overwrite
        # the loaded weights.
        restore_saver = tf.train.Saver() if retrain else None

        # Loss ops. Weights are compared with `==` so that 0.0 disables a term just like 0.
        with tf.name_scope('loss'):
            if content_weight == 0:
                content_loss = tf.constant(0.)
            else:
                content_loss = helpers.get_content_loss(variable_model, content_model, content_layer) * content_weight

            if style_weight == 0:
                style_loss = tf.constant(0.)
            else:
                style_loss = helpers.get_style_loss(variable_model, style_model, style_layers) * style_weight

            if tv_weight == 0:
                tv_loss = tf.constant(0.)
            else:
                tv_loss = helpers.get_total_variation(variable_img, art_shape) * tv_weight

            total_loss = content_loss + style_loss + tv_loss

        # Optimization ops: only variables in the "generator" scope are updated
        with tf.name_scope('optimization'):
            optimizer = tf.train.AdamOptimizer(learning_rate)
            trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="generator")
            grads = optimizer.compute_gradients(total_loss, trainable_vars)
            update_weights = optimizer.apply_gradients(grads)

        # Populate the training data
        logging.info("Initializing session and loading training images..")
        example = self.__next_example(height=art_shape[1], width=art_shape[2])
        self.session.run(tf.local_variables_initializer())
        self.session.run(tf.global_variables_initializer())

        # Continue from a pretrained model (must happen after initialization, see above)
        if restore_saver is not None:
            logging.info("Restoring previously trained weights..")
            restore_saver.restore(self.session, self._checkpoint_prefix())

        # Initialize threads and begin training
        logging.info("Begining training..")
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        start_time = time.time()

        # Fixed content image rendered as a preview alongside each status report
        preview_in = self.current_path + '/../lib/images/content/nyc.jpg'
        preview_out = self.paths['out_dir'] + str(start_time) + '.jpg'

        try:
            for i in range(epochs):
                self.net.is_training = True

                # Get next training image and reshape it to include a batch size of 1
                training_img = self.session.run(example) / 255.
                training_img = training_img.reshape([1] + list(training_img.shape)).astype(np.float32)

                # The same image feeds the generator and the content descriptor
                feed_dict = {variable_placeholder: training_img, content_placeholder: training_img}
                _, loss = self.session.run([update_weights, total_loss], feed_dict=feed_dict)

                if self.print_training_status and i % self.train_n == 0:
                    logging.info("Epoch %06d | Loss %.06f", i, loss)
                    input_img, input_shape = helpers.load_img_to(preview_in, height=self.train_height,
                                                                 width=self.train_width)
                    input_img = input_img.reshape([1] + input_shape).astype(np.float32)
                    img = self.session.run(variable_img, feed_dict={variable_placeholder: input_img})
                    helpers.render(img, path_out=preview_out)

            # Alert that training has been completed and print the run time
            elapsed = time.time() - start_time
            logging.info("Training complete. The session took %.2f seconds to complete.", elapsed)
        finally:
            # Always stop the input queue threads, even if an iteration raised
            coord.request_stop()
            coord.join(threads)

        self.__save_model(trainable_vars)

    @staticmethod
    def _prompt_choice(question, choices=('y', 'N'), read=input):
        """Ask ``question`` until the reply, with spaces removed, is one of ``choices``.

        Args:
            question: Prompt text passed to ``read``.
            choices: Accepted answers (matched case-sensitively).
            read: Callable taking the prompt and returning the raw reply;
                defaults to the builtin ``input``.

        Returns:
            The accepted answer.
        """
        while True:
            answer = read(question).replace(" ", "")
            if answer in choices:
                return answer

    def _style_name(self):
        """Return the style file's base name with ``.jpg`` removed."""
        return os.path.basename(self.paths['style_file']).replace('.jpg', '')

    def _checkpoint_prefix(self):
        """Return the checkpoint path ``<trained_generators_dir><name>/<name>`` for this style."""
        name = self._style_name()
        return self.paths['trained_generators_dir'] + name + '/' + name

    # Checks for training data to see if it's missing or not. Asks to download if missing.
    def __check_for_examples(self):
        """Make sure training images are present, offering to download/unzip them if not."""

        # Asks on stdout to download MSCOCO data. Downloads if response is 'y'
        def ask_to_download():
            logging.info("You've requested to train a new model. However, you've yet to download the training data.")
            answer = self._prompt_choice("Would you like to download the 13 GB file? [y/N] ")

            # Download the archive if yes, else exit the program
            if answer == 'y':
                logging.info("Downloading from %s. Please be patient...", self.paths['training_url'])
                zip_save_path = self.current_path + '/../lib/images/train2014.zip'
                urllib.request.urlretrieve(self.paths['training_url'], zip_save_path)
                ask_to_unzip(zip_save_path)
            else:
                self.__exit()

        # Asks on stdout to unzip a given zip file path. Unzips if response is 'y'
        def ask_to_unzip(path):
            answer = self._prompt_choice("The application requires the file to be unzipped. Unzip? [y/N] ")

            if answer == 'y':
                os.makedirs(self.paths['training_dir'], exist_ok=True)

                logging.info("Unzipping file..")
                with zipfile.ZipFile(path, 'r') as zip_ref:
                    zip_ref.extractall(self.current_path + '/../lib/')
                # The archive is no longer needed once its content is extracted
                os.remove(path)
            else:
                self.__exit(0, message="Please unzip the program manually to run the program. Exiting..")

        # Ask to unzip training data if a previous attempt was made
        zip_path = os.path.abspath(self.current_path + '/../lib/images/train2014.zip')
        if os.path.isfile(zip_path):
            ask_to_unzip(zip_path)

        # Ask to download training data if the training dir does not exist or holds at most one entry
        if not os.path.isdir(self.paths['training_dir']):
            ask_to_download()
        elif len(os.listdir(self.paths['training_dir'])) <= 1:
            ask_to_download()

    # Returns a new training example
    def __next_example(self, height, width):
        """Return a tensor that yields one decoded, resized training image per evaluation.

        Args:
            height: Target image height passed to ``tf.image.resize_images``.
            width: Target image width passed to ``tf.image.resize_images``.
        """
        filenames = tf.train.match_filenames_once(self.paths['training_dir'] + '*.jpg')
        filename_queue = tf.train.string_input_producer(filenames)
        reader = tf.WholeFileReader()
        _, files = reader.read(filename_queue)
        training_img = tf.image.decode_jpeg(files, channels=3)
        training_img = tf.image.resize_images(training_img, [height, width])
        return training_img

    # Saves the weights with the name of the referenced style so that the net may stylize future images
    def __save_model(self, variables):
        """Save ``variables`` to the checkpoint location derived from the style file name."""
        logging.info("Proceeding to save weights..")
        prefix = self._checkpoint_prefix()
        os.makedirs(os.path.dirname(prefix), exist_ok=True)
        saver = tf.train.Saver(variables)
        saver.save(self.session, prefix)

    def __exit(self, rc=0, message="Exiting the program.."):
        """Log ``message``, close the session and terminate the process with status ``rc``."""
        logging.info(message)
        self.session.close()
        # SystemExit is what exit()/sys.exit() raise; raising it directly does not
        # depend on the `site` module having installed the `exit` builtin.
        raise SystemExit(rc)