forked from xiefan-guo/ctsdg
-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainer.py
134 lines (101 loc) · 5.62 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import torch
import numpy as np
from tqdm import tqdm
from tensorboardX import SummaryWriter
from utils.distributed import get_rank, reduce_loss_dict
from utils.misc import requires_grad, sample_data
from criteria.loss import generator_loss_func, discriminator_loss_func
def train(opts, image_data_loader, generator, discriminator, extractor, generator_optim, discriminator_optim, is_cuda):
    """Run the alternating generator / discriminator training loop.

    Every iteration draws one batch, updates the generator on a weighted sum
    of its loss terms, then updates the discriminator on its adversarial
    loss. On rank 0 the reduced losses are shown on the progress bar and
    written to TensorBoard, and a checkpoint is saved whenever the iteration
    number is a multiple of ``opts.save_interval``.

    Args:
        opts: Options object. Attributes read here: ``train_iter``,
            ``start_iter``, ``distributed``, ``log_dir``, ``save_dir``,
            ``save_interval`` and the loss weights ``HOLE_LOSS``,
            ``VALID_LOSS``, ``PERCEPTUAL_LOSS``, ``STYLE_LOSS``,
            ``ADVERSARIAL_LOSS`` and ``INTERMEDIATE_LOSS``.
        image_data_loader: Loader handed to ``sample_data``; each ``next()``
            on the result must yield ``(ground_truth, mask, edge, gray_image)``.
        generator: Called as ``generator(image, edge_and_gray, mask)`` and
            unpacked into ``(output, projected_image, projected_edge)``.
            Must expose ``.module`` when ``opts.distributed`` is truthy.
        discriminator: Called as
            ``discriminator(image, gray_image, edge, is_real=...)`` and
            unpacked into a pair of predictions. Must expose ``.module``
            when ``opts.distributed`` is truthy.
        extractor: Callable applied to the composited image, the raw output
            and the ground truth; its results are passed to
            ``generator_loss_func``.
        generator_optim: Optimizer stepped after the generator backward pass.
        discriminator_optim: Optimizer stepped after the discriminator
            backward pass.
        is_cuda: When truthy, each batch tensor is moved with ``.cuda()``.

    Returns:
        None.
    """
    image_data_loader = sample_data(image_data_loader)

    # The rank is looked up once and reused for every rank-0-only action below.
    is_main = get_rank() == 0

    pbar = range(opts.train_iter)
    if is_main:
        pbar = tqdm(pbar, initial=opts.start_iter, dynamic_ncols=True, smoothing=0.01)

    # Checkpoints store the state of the underlying modules, not the
    # distributed wrappers.
    if opts.distributed:
        generator_module, discriminator_module = generator.module, discriminator.module
    else:
        generator_module, discriminator_module = generator, discriminator

    # Only rank 0 logs, so only rank 0 opens an event file.
    writer = SummaryWriter(opts.log_dir) if is_main else None

    # Loss-dict keys logged to TensorBoard, in logging order.
    generator_log_keys = (
        'loss_hole',
        'loss_valid',
        'loss_perceptual',
        'loss_style',
        'loss_adversarial',
        'loss_intermediate',
        'loss_joint',
    )
    discriminator_log_keys = ('loss_adversarial', 'loss_joint')

    try:
        for index in pbar:

            # `index` always starts at 0; `i` is the global iteration number.
            # When opts.start_iter > 0 the loop still runs the iteration with
            # i == opts.train_iter before this guard stops it.
            i = index + opts.start_iter
            if i > opts.train_iter:
                print('Done...')
                break

            ground_truth, mask, edge, gray_image = next(image_data_loader)

            if is_cuda:
                ground_truth, mask, edge, gray_image = ground_truth.cuda(), mask.cuda(), edge.cuda(), gray_image.cuda()

            # Network inputs are the batch tensors multiplied element-wise by the mask.
            input_image, input_edge, input_gray_image = ground_truth * mask, edge * mask, gray_image * mask

            # ---------
            # Generator
            # ---------
            # presumably toggles requires_grad on the model parameters — see utils.misc
            requires_grad(generator, True)
            requires_grad(discriminator, False)

            output, projected_image, projected_edge = generator(input_image, torch.cat((input_edge, input_gray_image), dim=1), mask)
            # Composite: ground truth where mask is 1, generator output where it is 0.
            comp = ground_truth * mask + output * (1 - mask)

            output_pred, output_edge = discriminator(output, gray_image, edge, is_real=False)

            vgg_comp, vgg_output, vgg_ground_truth = extractor(comp), extractor(output), extractor(ground_truth)

            generator_loss_dict = generator_loss_func(
                mask, output, ground_truth, edge, output_pred,
                vgg_comp, vgg_output, vgg_ground_truth,
                projected_image, projected_edge,
                output_edge
            )
            generator_loss = generator_loss_dict['loss_hole'] * opts.HOLE_LOSS + \
                             generator_loss_dict['loss_valid'] * opts.VALID_LOSS + \
                             generator_loss_dict['loss_perceptual'] * opts.PERCEPTUAL_LOSS + \
                             generator_loss_dict['loss_style'] * opts.STYLE_LOSS + \
                             generator_loss_dict['loss_adversarial'] * opts.ADVERSARIAL_LOSS + \
                             generator_loss_dict['loss_intermediate'] * opts.INTERMEDIATE_LOSS
            generator_loss_dict['loss_joint'] = generator_loss

            generator_optim.zero_grad()
            generator_loss.backward()
            generator_optim.step()

            # -------------
            # Discriminator
            # -------------
            requires_grad(generator, False)
            requires_grad(discriminator, True)

            real_pred, real_pred_edge = discriminator(ground_truth, gray_image, edge, is_real=True)
            # detach() keeps the discriminator loss from back-propagating into the generator.
            fake_pred, fake_pred_edge = discriminator(output.detach(), gray_image, edge, is_real=False)

            discriminator_loss_dict = discriminator_loss_func(real_pred, fake_pred, real_pred_edge, fake_pred_edge, edge)
            discriminator_loss = discriminator_loss_dict['loss_adversarial']
            discriminator_loss_dict['loss_joint'] = discriminator_loss

            discriminator_optim.zero_grad()
            discriminator_loss.backward()
            discriminator_optim.step()

            # ---
            # log
            # ---
            # Called on every rank, as in the original loop (it may involve
            # cross-process communication — see utils.distributed).
            generator_loss_dict_reduced, discriminator_loss_dict_reduced = reduce_loss_dict(generator_loss_dict), reduce_loss_dict(discriminator_loss_dict)

            if not is_main:
                continue

            # Scalars are pulled to the host only on the rank that reports them.
            pbar_g_loss_joint = generator_loss_dict_reduced['loss_joint'].mean().item()
            pbar_d_loss_joint = discriminator_loss_dict_reduced['loss_joint'].mean().item()

            pbar.set_description((
                f'g_loss_joint: {pbar_g_loss_joint:.4f} '
                f'd_loss_joint: {pbar_d_loss_joint:.4f}'
            ))

            # Tags are 'g_<key>' / 'd_<key>', e.g. 'g_loss_hole', 'd_loss_joint'.
            for key in generator_log_keys:
                writer.add_scalar(f'g_{key}', generator_loss_dict_reduced[key].mean().item(), i)
            for key in discriminator_log_keys:
                writer.add_scalar(f'd_{key}', discriminator_loss_dict_reduced[key].mean().item(), i)

            if i % opts.save_interval == 0:
                torch.save(
                    {
                        'n_iter': i,
                        'generator': generator_module.state_dict(),
                        'discriminator': discriminator_module.state_dict()
                    },
                    f"{opts.save_dir}/{str(i).zfill(6)}.pt",
                )
    finally:
        # Flush and release the event file even if training stops on an error.
        if writer is not None:
            writer.close()