-
Notifications
You must be signed in to change notification settings - Fork 277
/
train_vqvae.py
executable file
·152 lines (113 loc) · 4.2 KB
/
train_vqvae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import argparse
import sys
import os
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from tqdm import tqdm
from vqvae import VQVAE
from scheduler import CycleScheduler
import distributed as dist
def _save_sample(model, sample, epoch, step):
    """Write a PNG grid: input images in the first row, reconstructions in the second.

    Args:
        model: the (possibly DistributedDataParallel-wrapped) VQ-VAE.
        sample: batch of input images, already on the model's device.
        epoch: 0-based epoch index (written 1-based into the file name).
        step: iteration index within the epoch (written into the file name).
    """
    # Run the bare network rather than the DDP wrapper: this is only called on
    # the primary rank, and a wrapped forward may try to sync buffers with ranks
    # that are not taking part in this call.
    net = model.module if isinstance(model, nn.parallel.DistributedDataParallel) else model
    net.eval()
    try:
        with torch.no_grad():
            out, _ = net(sample)
        os.makedirs("sample", exist_ok=True)
        utils.save_image(
            torch.cat([sample, out], 0),
            f"sample/{str(epoch + 1).zfill(5)}_{str(step).zfill(5)}.png",
            # One row per half, so inputs and reconstructions line up even
            # when the batch holds fewer images than the requested sample size.
            nrow=sample.shape[0],
            normalize=True,
            # `range=` was replaced by `value_range=` in torchvision; inputs are
            # normalized to [-1, 1] by the data transform in main().
            value_range=(-1, 1),
        )
    finally:
        net.train()


def train(epoch, loader, model, optimizer, scheduler, device):
    """Run one training epoch of the VQ-VAE.

    Each step minimizes ``MSE(reconstruction, input) + 0.25 * latent_loss``.
    Per-batch MSE sums are gathered from all ranks so that the progress bar on
    the primary process shows a global running average.

    Args:
        epoch: 0-based epoch index (displayed as ``epoch + 1``).
        loader: iterable of ``(images, labels)`` batches; labels are ignored.
        model: network returning ``(reconstruction, latent_loss)``.
        optimizer: torch optimizer over ``model``'s parameters.
        scheduler: optional learning-rate scheduler with a ``step()`` method,
            or ``None``.
        device: device the image batches are moved to.
    """
    if dist.is_primary():
        loader = tqdm(loader)

    criterion = nn.MSELoss()

    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

    for i, (img, _label) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)

        out, latent_loss = model(img)
        recon_loss = criterion(out, img)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
        loss.backward()

        # The scheduler is stepped before the optimizer, presumably so the
        # learning rate it writes is used for this update (CycleScheduler lives
        # in scheduler.py, not visible here) -- kept as in the original.
        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        # Share this rank's summed MSE and sample count so every rank can
        # compute the global average.
        part_mse_sum = recon_loss.item() * img.shape[0]
        part_mse_n = img.shape[0]
        comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
        comm = dist.all_gather(comm)

        for part in comm:
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description(
                (
                    f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                    f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                    f"lr: {lr:.5f}"
                )
            )

            if i % 100 == 0:
                _save_sample(model, img[:sample_size], epoch, i)
def main(args):
    """Build the data pipeline, model and optimizer, then train for ``args.epoch`` epochs.

    Reads ``args.size``, ``args.path``, ``args.n_gpu``, ``args.lr``,
    ``args.sched`` and ``args.epoch``, and sets ``args.distributed``. After
    every epoch the primary process writes ``checkpoint/vqvae_XXX.pt``. The
    checkpoint holds the unwrapped network's ``state_dict``, so it loads
    directly into a plain ``VQVAE`` whether or not training was distributed.
    """
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    # ToTensor yields [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1].
    transform = transforms.Compose(
        [
            transforms.Resize(args.size),
            transforms.CenterCrop(args.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dataset = datasets.ImageFolder(args.path, transform=transform)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    # 128 is the global batch size, split evenly across GPUs.
    loader = DataLoader(
        dataset, batch_size=128 // args.n_gpu, sampler=sampler, num_workers=2
    )

    model = VQVAE().to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        # A distributed sampler reuses the same permutation every epoch unless
        # it is told the epoch number. The guard covers non-distributed
        # samplers, which presumably lack set_epoch (dist.data_sampler is not
        # visible here).
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(i)

        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            os.makedirs("checkpoint", exist_ok=True)
            # Strip the DDP wrapper so the saved keys have no "module." prefix.
            net = model.module if args.distributed else model
            torch.save(net.state_dict(), f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
if __name__ == "__main__":
    # Default rendezvous port: 49152 (= 2**15 + 2**14) plus a per-user offset
    # in [0, 16383], so concurrent users on one host are unlikely to collide.
    # os.getuid is not available on Windows, so a constant seed is used there.
    uid_seed = 1 if sys.platform == "win32" else os.getuid()
    default_port = 49152 + hash(uid_seed) % 16384

    # Arguments are registered in a fixed order (this order controls --help).
    cli_options = (
        ("--n_gpu", {"type": int, "default": 1}),
        ("--dist_url", {"default": f"tcp://127.0.0.1:{default_port}"}),
        ("--size", {"type": int, "default": 256}),
        ("--epoch", {"type": int, "default": 560}),
        ("--lr", {"type": float, "default": 3e-4}),
        ("--sched", {"type": str}),
        ("path", {"type": str}),
    )

    cli_parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        cli_parser.add_argument(flag, **options)

    args = cli_parser.parse_args()
    print(args)

    # Start n_gpu worker processes on this single machine (1 machine, rank 0).
    dist.launch(main, args.n_gpu, 1, 0, args.dist_url, args=(args,))