-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
91 lines (75 loc) · 2.76 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# SRGAN training script — module-level imports and global torch settings.
import torch
import config
from torch import nn
from torch import optim
from utils import load_checkpoint, save_checkpoint, plot_examples
from loss import VGGLoss
from torch.utils.data import DataLoader
from model import Generator, Discriminator
from tqdm import tqdm
from dataset import MyImageFolder
# Let cuDNN auto-tune convolution algorithms; this speeds up training
# when input tensor shapes are fixed across iterations.
torch.backends.cudnn.benchmark = True
def train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss):
    """Run one epoch of adversarial SRGAN training.

    Args:
        loader: DataLoader yielding ``(low_res, high_res)`` image batches.
        disc: discriminator network; outputs raw logits.
        gen: generator network; upscales ``low_res`` toward ``high_res``.
        opt_gen: optimizer for the generator.
        opt_disc: optimizer for the discriminator.
        mse: pixel-wise loss (currently unused — the L2 term is disabled;
            kept so the call signature stays stable).
        bce: ``BCEWithLogitsLoss`` used for both adversarial terms.
        vgg_loss: perceptual (VGG-feature) loss.
    """
    progress = tqdm(loader, leave=True)
    for step, (low_res, high_res) in enumerate(progress):
        low_res = low_res.to(config.DEVICE)
        high_res = high_res.to(config.DEVICE)

        # --- Discriminator update: max log(D(x)) + log(1 - D(G(z))) ---
        fake = gen(low_res)
        real_logits = disc(high_res)
        # Detach so the discriminator step does not backprop into the generator.
        fake_logits = disc(fake.detach())
        # One-sided label smoothing: real targets drawn from (0.9, 1.0].
        smoothed_real = torch.ones_like(real_logits) - 0.1 * torch.rand_like(real_logits)
        d_loss_real = bce(real_logits, smoothed_real)
        d_loss_fake = bce(fake_logits, torch.zeros_like(fake_logits))
        d_loss = d_loss_fake + d_loss_real
        opt_disc.zero_grad()
        d_loss.backward()
        opt_disc.step()

        # --- Generator update: min log(1 - D(G(z))) <-> max log(D(G(z))) ---
        gen_logits = disc(fake)
        adversarial_term = 1e-3 * bce(gen_logits, torch.ones_like(gen_logits))
        perceptual_term = 0.006 * vgg_loss(fake, high_res)
        g_loss = perceptual_term + adversarial_term
        opt_gen.zero_grad()
        g_loss.backward()
        opt_gen.step()

        # Periodically dump super-resolved samples for visual inspection.
        if step % 200 == 0:
            plot_examples("test_images/", gen)
def main():
    """Assemble the data pipeline, models, losses, and run the training loop."""
    dataset = MyImageFolder(root_dir="new_data/")
    loader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=config.NUM_WORKERS,
    )

    gen = Generator(in_channels=3).to(config.DEVICE)
    disc = Discriminator(in_channels=3).to(config.DEVICE)
    # NOTE(review): betas=(0.9, 0.9) differs from the Adam default
    # (0.9, 0.999) — presumably intentional; confirm against the reference
    # implementation before changing.
    opt_gen = optim.Adam(
        gen.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.9))
    opt_disc = optim.Adam(
        disc.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.9))

    mse = nn.MSELoss()
    bce = nn.BCEWithLogitsLoss()
    vgg_loss = VGGLoss()

    # Optionally resume both networks (and optimizer state) from checkpoints.
    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
        )

    for epoch in range(config.NUM_EPOCHS):
        train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss)
        # Checkpoint after every epoch so progress survives interruptions.
        if config.SAVE_MODEL:
            save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)


if __name__ == "__main__":
    main()