# coding: utf-8
# In[ ]:
""" This file is for training the Cycle-GAN network developed to transform a given RGB image to the corresponding depth image
The code has been explained in verbose for better understanding of the reader/user.
"""
"""
DataLoader part of the code has been taken from Pytorch website: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
"""
# In[2]:
import itertools
import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
from PIL import Image
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os
import numpy as np
# In[3]:
# The functions and classes imported below are customized to our needs. Their implementations can be found in their corresponding files.
# Generator and Discriminator are the network classes developed.
# LambdaLR is the function developed for applying a decaying learning rate in the optimizer.
# save_samples saves a sample every given number of epochs during training to check the performance of the network.
# weights_init_normal initializes the filter weights with numbers drawn from a normal distribution. The range is given in the file.
# scale normalizes the given image to the range -1 to 1. A range of 0 to 1 would also work; the range can be changed when calling the function.
# Look at the function definition for more information.
# The createArchitecture class builds the CycleGAN architecture from the Generator and Discriminator classes developed.
# getDataLoader takes the path of the dataset and creates an iterable dataset object. The transforms applied to the dataset
# are hardcoded in the getDataLoader class; check the class for more information. Sketches of the helpers follow below.
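# For reference, minimal sketches of what `scale` and `LambdaLR` in helper.py might
# look like. These are assumptions for illustration, not the actual implementations;
# see helper.py for the real code.
#
#   def scale(x, feature_range=(-1, 1)):
#       # Rescale a tensor assumed to be in [0, 1] to the given range (default [-1, 1]).
#       lo, hi = feature_range
#       return x * (hi - lo) + lo
#
#   class LambdaLR:
#       # Multiplier stays at 1.0 until decay_epoch, then decays linearly to 0 at num_epochs.
#       def __init__(self, num_epochs, decay_epoch):
#           self.num_epochs, self.decay_epoch = num_epochs, decay_epoch
#       def step(self, epoch):
#           return 1.0 - max(0, epoch - self.decay_epoch) / (self.num_epochs - self.decay_epoch)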
# In[4]:
from model import Generator
from model import Discriminator
from helper import LambdaLR,real_mse_loss,fake_mse_loss,cycle_consistency_loss
from helper import save_samples,checkpoint,weights_init_normal,scale
from getDataLoader import getDataLoader
from createArchitecture import createArchitecture
# In[5]:
# Since we are using RGB images as input, the number of input channels is three.
# We have taken a batch_size of 16. Keep the batch_size a multiple of 16 to utilise all the streaming multiprocessors on the GPU.
# A learning rate of 0.001 is used because it has shown good performance.
# decay_epoch is the epoch at which the learning rate starts to decay, so that the error
# doesn't overshoot as it moves towards the global minimum.
# The image size is 128 as it trains faster on our GPU. The size can be changed based on the GPU available.
# The number of epochs is high because this is a generative model and requires intensive training to give good output.
# Make sure the network doesn't overtrain.
# beta1 and beta2 are the momentum values used by the Adam optimizer.
# In[6]:
input_numChannels = 3
batch_size = 16
num_epochs = 5000
learningRate = 0.001
decay_epoch = int(round(0.6*num_epochs))
image_size = 128
beta1 = 0.55
beta2 = 0.999
# In[7]:
# The rgb and depth dataset objects are created below. The arguments to getDataLoader are the folder names of the training
# dataset. For more information, check the folder structure described in the readme.
# .load_data() loads the data from the folder, applies the transforms, and returns training and testing dataset iterables,
# roughly as in the sketch below.
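# For reference, a minimal sketch of what getDataLoader.load_data() might do. This is
# an assumption for illustration; the actual transforms, folder layout, and split are
# hardcoded in getDataLoader.py.
#
#   class getDataLoader:
#       def __init__(self, image_dir, image_size=128, batch_size=16):
#           self.image_dir, self.image_size, self.batch_size = image_dir, image_size, batch_size
#       def load_data(self):
#           transform = transforms.Compose([transforms.Resize(self.image_size),
#                                           transforms.ToTensor()])
#           train_set = datasets.ImageFolder(os.path.join(self.image_dir, 'train'), transform)
#           test_set = datasets.ImageFolder(os.path.join(self.image_dir, 'test'), transform)
#           return (DataLoader(train_set, batch_size=self.batch_size, shuffle=True),
#                   DataLoader(test_set, batch_size=self.batch_size, shuffle=False))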
# In[8]:
rgb = getDataLoader('rgb')
depth = getDataLoader('depth')
dataloader_A, test_dataloader_A = rgb.load_data()
dataloader_B, test_dataloader_B = depth.load_data()
# In[9]:
# createArchitecture creates an instance of the CycleGAN architecture as mentioned earlier. The .create_model() function creates the CycleGAN model
# and moves it to the GPU if available; a sketch of what that might involve follows.
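# For reference, a minimal sketch of what createArchitecture.create_model() might do.
# This is an assumption for illustration; see createArchitecture.py for the actual
# implementation and any constructor arguments of Generator and Discriminator.
#
#   class createArchitecture:
#       def create_model(self):
#           gen_A2B, gen_B2A = Generator(), Generator()
#           disc_A, disc_B = Discriminator(), Discriminator()
#           for net in (gen_A2B, gen_B2A, disc_A, disc_B):
#               net.apply(weights_init_normal)   # normal-distribution weight init
#               if torch.cuda.is_available():
#                   net.cuda()                   # move to GPU when available
#           return gen_A2B, gen_B2A, disc_A, disc_B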
# In[10]:
mod = createArchitecture()
gen_A2B,gen_B2A,disc_A,disc_B = mod.create_model()
# In[12]:
print(disc_A)
# In[ ]:
# Network parameter lists are created because they are easy to operate on.
# The Adam optimizer is used, as suggested in the GAN paper; one optimizer is created for the generators and one each for
# discriminator_A and discriminator_B. Separate optimizers are not created for the generators because their parameters are
# combined into one list (separate ones could be created as well). This cannot be done for the discriminators because each
# has to identify its own separate output.
# A learning rate scheduler is created for the generators and discriminators to provide learning rate decay as the
# training approaches the global minimum.
# In[6]:
generator_params = list(gen_A2B.parameters())+ list(gen_B2A.parameters())
# In[7]:
optimizer_gen = optim.Adam(generator_params, learningRate, [beta1, beta2])
optimizer_disc_A = torch.optim.Adam(disc_A.parameters(), learningRate, [beta1, beta2])
optimizer_disc_B = torch.optim.Adam(disc_B.parameters(), learningRate, [beta1, beta2])
# In[8]:
learningRate_scheduler_gen = torch.optim.lr_scheduler.LambdaLR(optimizer_gen, lr_lambda=LambdaLR(num_epochs, decay_epoch).step)
learningRate_scheduler_disc_A = torch.optim.lr_scheduler.LambdaLR(optimizer_disc_A, lr_lambda=LambdaLR(num_epochs, decay_epoch).step)
learningRate_scheduler_disc_B = torch.optim.lr_scheduler.LambdaLR(optimizer_disc_B, lr_lambda=LambdaLR(num_epochs, decay_epoch).step)
# In[ ]:
# Training the CycleGAN proceeds in the following way:
# We first train the discriminators-
# 1) Check the output of the discriminator when a real image is fed
# 2) Using the output, compute the MSE loss to quantify the error the discriminator produces for a real image
# 3) Now get a fake image from the generator [from B-->A]
# 4) Check how the discriminator judges the fake image
# 5) Compute the MSE loss produced for the fake image
# 6) Sum the losses from step (2) and step (5) and backpropagate
# 7) Repeat the same method for the second discriminator
# Training the generators-
# 1) A sample from domain A (RGB images in our case) is given to the generator to convert it from domain A-->B
# 2) The generated output is then fed to the discriminator to check the level of fakeness, and the MSE loss is computed
# 3) The generated output is then fed to the other generator (B-->A) to check if the correct mappings are learnt
# 4) The output of step (3) and the input of step (1) are checked for consistency, because the output of step (3) should
# match the input of step (1) (use cycle_consistency_loss)
# 5) Steps (1) to (4) are repeated for the second generator
# 6) Both the losses are added and propagated backward; sketches of the loss helpers used are given below
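# For reference, hedged sketches of the loss helpers imported from helper.py and used
# in the training loop below (assumptions for illustration; the real definitions live
# in helper.py):
#
#   def real_mse_loss(D_out):
#       # How far is the discriminator from classifying this batch as real (label 1)?
#       return torch.mean((D_out - 1) ** 2)
#
#   def fake_mse_loss(D_out):
#       # How far is the discriminator from classifying this batch as fake (label 0)?
#       return torch.mean(D_out ** 2)
#
#   def cycle_consistency_loss(real_im, reconstructed_im, lambda_weight):
#       # L1 distance between the original and its reconstruction, scaled by lambda_weight.
#       return lambda_weight * torch.mean(torch.abs(real_im - reconstructed_im))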
# In[9]:
def training_loop(dataloader_RGB, dataloader_Depth, test_dataloader_RGB, test_dataloader_Depth, n_epochs=1000):
    print_every = 10
    losses = []
    # Fixed test samples are held out so the network's progress can be inspected on the same images every time
    test_iter_RGB = iter(test_dataloader_RGB)
    test_iter_Depth = iter(test_dataloader_Depth)
    fixed_RGB = next(test_iter_RGB)[0]
    fixed_Depth = next(test_iter_Depth)[0]
    fixed_RGB = scale(fixed_RGB)
    fixed_Depth = scale(fixed_Depth)
    iter_RGB = iter(dataloader_RGB)
    iter_Depth = iter(dataloader_Depth)
    batches_per_epoch = min(len(iter_RGB), len(iter_Depth))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    for epoch in range(1, n_epochs+1):
        # Reset the iterators once all batches have been consumed
        if epoch % batches_per_epoch == 0:
            iter_RGB = iter(dataloader_RGB)
            iter_Depth = iter(dataloader_Depth)
        images_RGB, _ = next(iter_RGB)
        images_RGB = scale(images_RGB)
        images_Depth, _ = next(iter_Depth)
        images_Depth = scale(images_Depth)
        images_RGB = images_RGB.to(device)
        images_Depth = images_Depth.to(device)
        #-------------------------------------Discriminator Training-----------------------------------------------------#
        # Gradients are zeroed so that they don't carry over errors from previous iterations, i.e. to avoid any accumulation
        optimizer_disc_A.zero_grad()
        # Discriminator-A is first checked for its performance on a real image
        out_RGB = disc_A(images_RGB)
        disc_A_real_loss = real_mse_loss(out_RGB)
        # A fake image is generated by the generator to check the discriminator's performance on a fake image
        fake_RGB = gen_B2A(images_Depth)
        out_RGB = disc_A(fake_RGB)
        disc_A_fake_loss = fake_mse_loss(out_RGB)
        # The losses from the above operations are added to give the total loss of Discriminator-A, which is backpropagated
        disc_A_loss = disc_A_real_loss + disc_A_fake_loss
        disc_A_loss.backward()
        optimizer_disc_A.step()
        # The same operations as mentioned above are now performed on Discriminator-B
        optimizer_disc_B.zero_grad()
        out_Depth = disc_B(images_Depth)
        disc_B_real_loss = real_mse_loss(out_Depth)
        fake_Depth = gen_A2B(images_RGB)
        out_Depth = disc_B(fake_Depth)
        disc_B_fake_loss = fake_mse_loss(out_Depth)
        disc_B_loss = disc_B_real_loss + disc_B_fake_loss
        disc_B_loss.backward()
        optimizer_disc_B.step()
        #------------------------------------------Generator Training---------------------------------------------------#
        optimizer_gen.zero_grad()
        # An image is produced by the generator (in this case gen_B2A is being trained) to check how well it can fool the discriminator
        fake_RGB = gen_B2A(images_Depth)
        # The generated image is then fed to the discriminator to check the fakeness, and the MSE loss is computed
        out_RGB = disc_A(fake_RGB)
        gen_B2A_loss = real_mse_loss(out_RGB)
        # To check consistency, the generated image is fed to the other generator (A2B in this case) and a reconstructed image is received
        reconstructed_Depth = gen_A2B(fake_RGB)
        # The reconstruction loss is computed by comparing the two images, as mentioned in step (4) of training the generators
        reconstructed_Depth_loss = cycle_consistency_loss(images_Depth, reconstructed_Depth, lambda_weight=10)
        # The same operations are done, but this time to train gen_A2B
        fake_Depth = gen_A2B(images_RGB)
        out_Depth = disc_B(fake_Depth)
        gen_A2B_loss = real_mse_loss(out_Depth)
        reconstructed_RGB = gen_B2A(fake_Depth)
        reconstructed_RGB_loss = cycle_consistency_loss(images_RGB, reconstructed_RGB, lambda_weight=10)
        # The losses from both generators are added and propagated backwards
        g_total_loss = gen_B2A_loss + gen_A2B_loss + reconstructed_Depth_loss + reconstructed_RGB_loss
        g_total_loss.backward()
        optimizer_gen.step()
        # Learning rate decay, as mentioned earlier
        learningRate_scheduler_gen.step()
        learningRate_scheduler_disc_A.step()
        learningRate_scheduler_disc_B.step()
        if epoch % print_every == 0:
            losses.append((disc_A_loss.item(), disc_B_loss.item(), g_total_loss.item()))
            print('Epoch [{:5d}/{:5d}] | d_X_loss: {:6.4f} | d_Y_loss: {:6.4f} | g_total_loss: {:6.4f}'.format(
                epoch, n_epochs, disc_A_loss.item(), disc_B_loss.item(), g_total_loss.item()))
        # A training sample is saved every sample_every epochs to check the performance of the network
        sample_every = 100
        if epoch % sample_every == 0:
            gen_B2A.eval()
            gen_A2B.eval()
            save_samples(epoch, fixed_Depth, fixed_RGB, gen_B2A, gen_A2B, batch_size=16)
            gen_B2A.train()
            gen_A2B.train()
        # Model parameters are saved every checkpoint_every epochs
        checkpoint_every = 1000
        if epoch % checkpoint_every == 0:
            checkpoint(epoch, gen_A2B, gen_B2A, disc_A, disc_B)
    return losses
# In[10]:
losses = training_loop(dataloader_A, dataloader_B, test_dataloader_A, test_dataloader_B, n_epochs=num_epochs)
# In[11]:
# gen_A2B,gen_B2A,disc_A,disc_B
# In[12]:
# A plot to track the training losses
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(12,8))
losses = np.array(losses)
plt.plot(losses.T[0], label='Discriminator, RGB', alpha=0.5)
plt.plot(losses.T[1], label='Discriminator, Depth', alpha=0.5)
plt.plot(losses.T[2], label='Generators', alpha=0.5)
plt.title("Training Losses")
plt.legend()
plt.savefig('Training_Loss.png')
# In[13]:
plt.show()
# In[ ]: