-
Notifications
You must be signed in to change notification settings - Fork 8
/
main.py
106 lines (90 loc) · 4.57 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import numpy as np
import scipy.misc
import os
import sys
from model import ECGAN
from utils import pp
import tensorflow as tf
flags = tf.app.flags
# flags.DEFINE_integer("momentum_decay_steps", 100,
#                      "change after 100 iterations of inner loop of G")
# flags.DEFINE_float("momentum_decay_rate", 1.17, "factor of change in momentum")

# --- Training schedule and optimizer settings -------------------------------
flags.DEFINE_integer("epoch_pretrain", 1000, "Epochs to pre-train [1000]")
flags.DEFINE_integer("epoch_policy", 2000000, "Epochs for policy gradient")
flags.DEFINE_float("learning_rate_D", 0.0002,
                   "Learning rate of the D net's Adam optimizer [0.0002]")
flags.DEFINE_float("learning_rate_G", 0.0002,
                   "Learning rate of the G net's Adam optimizer [0.0002]")
flags.DEFINE_float("beta1D", 0.5,
                   "Momentum term (beta1) of the D net's Adam optimizer [0.5]")
flags.DEFINE_float("beta1G", 0.5,
                   "Momentum term (beta1) of the G net's Adam optimizer [0.5]")
flags.DEFINE_integer("decay_step", 5000000,
                     "Decay step of learning rate in epochs")
flags.DEFINE_float("decay_rate", 0.8, "Decay rate of learning rate")
flags.DEFINE_float("eps", 1e-5, "Epsilon")
flags.DEFINE_float("var", 1e-5, "Variance")
# Passed to tf.GPUOptions(per_process_gpu_memory_fraction=...) in main().
flags.DEFINE_float("gpu_frac", 0.35, "Gpu fraction")
flags.DEFINE_integer("no_of_samples", 50,
                     "no of samples for each noise vector Z during policy gradient")
flags.DEFINE_boolean("teacher_forcing", False,
                     "True if teacher forcing is enabled")
flags.DEFINE_boolean("label_to_disc", True,
                     "True if labels are passed to the discriminator")
flags.DEFINE_boolean("conditional", True,
                     "True if want to train conditional GAN")
flags.DEFINE_integer("pre_train_iters", 2000,
                     "Number of iterations to pre-train D")
flags.DEFINE_integer("num_keypoints", 68,
                     "Number of keypoints extracted in the face")
flags.DEFINE_float("lam", 0.1,
                   "lam for inpainting")

# --- Dataset selection and output paths -------------------------------------
# `dataset` picks the default data/sample/checkpoint/log locations below;
# `comment` names the run-specific sub-directory for checkpoints and logs.
dataset = "celebA"
comment = "model_weights"
flags.DEFINE_float(
    "margin", 0.3, "Threshold to judge stopping of D and G nets training")
flags.DEFINE_boolean("margin_restriction", True,
                     "whether to use margin restriction to stop D or G nets")
flags.DEFINE_boolean("policy_train", True,
                     "Whether to use PolicyGan training procedure")
flags.DEFINE_string("dataset", dataset,
                    "The name of dataset [celebA, mnist, lsun]")
# celebA data lives directly under data/; other datasets get a sub-directory.
if dataset == 'celebA':
    flags.DEFINE_string("data_dir", "data/",
                        "Directory name containing the dataset [data]")
else:
    flags.DEFINE_string("data_dir", "data/" + dataset,
                        "Directory name containing the dataset [data]")
flags.DEFINE_string("checkpoint_dir", "checkpoint/" + dataset + "/" + comment,
                    "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples/" + dataset,
                    "Directory name to save the image samples [samples]")
flags.DEFINE_string("log_dir", "logs/" + dataset + "/" + comment,
                    "Directory name to save the logs [logs]")

# --- Feature switches ---------------------------------------------------------
flags.DEFINE_boolean("load_chkpt", False, "True for loading saved checkpoint")
flags.DEFINE_boolean("inc_score", False, "True for computing inception score")
flags.DEFINE_boolean("gauss_noise", False,
                     "True for adding noise to disc input")
flags.DEFINE_boolean("flip_label", False, "True for flipping the labels")
flags.DEFINE_boolean("error_conceal", False,
                     "True for running error concealment part")
flags.DEFINE_boolean("siamese_net", False, "True for using a siamese network")
flags.DEFINE_boolean("use_tfrecords", True,
                     "True for reading input data from TFRecords")

# --- Model / image geometry and logging cadence -------------------------------
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("z_dim", 100, "Dimension of latent vector.")
flags.DEFINE_integer("sampleInterval", 500,
                     "Interval between saving image samples [500]")
flags.DEFINE_integer("saveInterval", 2500,
                     "Interval between saving checkpoints [2500]")
flags.DEFINE_integer("c_dim", 3, "Number of channels in input image")
flags.DEFINE_boolean("is_grayscale", False, "True for grayscale image")
flags.DEFINE_integer("output_size", 64, "Size of the output images [64]")
FLAGS = flags.FLAGS
def main(_):
    """Print the configuration, create output directories and train ECGAN.

    Args:
        _: Remaining command-line arguments handed over by ``tf.app.run``;
            unused.
    """
    # Newer (absl-based) flag objects expose a plain name -> value mapping;
    # the private ``__flags`` dict there holds Flag objects rather than values.
    # Fall back to the original dict when the method is unavailable.
    if hasattr(FLAGS, "flag_values_dict"):
        pp.pprint(FLAGS.flag_values_dict())
    else:
        pp.pprint(FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # One config carrying both GPU settings. Previously a second ConfigProto
    # with allow_growth=True was built but never handed to the session, so
    # allow_growth had no effect.
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=FLAGS.gpu_frac,
        allow_growth=True)
    config = tf.ConfigProto(gpu_options=gpu_options)
    with tf.Session(config=config) as sess:
        model = ECGAN(sess)
        model.train()
if __name__ == '__main__':
    # Hand main() to tf.app.run explicitly; it parses the command-line flags
    # and then calls main with the remaining argv.
    tf.app.run(main=main)