-
Notifications
You must be signed in to change notification settings - Fork 1
/
dcgan_module.py
86 lines (78 loc) · 2.71 KB
/
dcgan_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import torch
from pl_bolts.models.gans import DCGAN
from torch.utils.data import DataLoader
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from torchvision.transforms import v2
from lightning.pytorch.callbacks import RichProgressBar
import argparse
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the DCGAN training script.

    Returns:
        A parser accepting ``--rootdir``, ``-e/--epochs``, ``--batch-size``
        and ``--workers``; every option has a default, so ``parse_args([])``
        succeeds.
    """
    arg_parser = argparse.ArgumentParser(
        prog='DCGAN Training',
        description='Trains a DCGAN to generate Fake images. By default those images will be 64x64 RGB.',
    )
    # Optional: falls back to the default directory when omitted.
    # NOTE: torchvision's ImageFolder expects class sub-directories under it.
    arg_parser.add_argument('--rootdir', default='data/trump/', type=str,
                            help='directory with images to use for training.')
    arg_parser.add_argument('-e', '--epochs', default=500, type=int, required=False,
                            help='Epochs to train for')
    arg_parser.add_argument('--batch-size', default=64, type=int,
                            help='Images per training batch.')
    arg_parser.add_argument('--workers', default=16, type=int,
                            help='DataLoader worker processes (0 loads in the main process).')
    return arg_parser


def build_transform():
    """Return the preprocessing pipeline applied to every training image.

    The shorter side is resized to 64 and the image is center-cropped to 64x64,
    randomly mirrored, converted to float32 in [0, 1], normalized to [-1, 1]
    and finally perturbed with Gaussian noise.
    """
    # TODO: compute the dataset's per-channel mean/std instead of using 0.5.
    return v2.Compose([
        v2.Resize(64),
        v2.CenterCrop(64),
        v2.RandomHorizontalFlip(p=0.5),
        # Non-deprecated replacement for v2.ToTensor(): image -> float32 in [0, 1].
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # GaussianNoise clamps its output to [0, 1] by default; after Normalize
        # that would zero every negative value, so clipping is disabled here.
        v2.GaussianNoise(clip=False),
    ])


def show_samples(model, n_samples=64, latent_dim=100):
    """Generate ``n_samples`` images with ``model`` and display them as a grid.

    Args:
        model: the trained DCGAN LightningModule; it is called with a
            ``(n_samples, latent_dim)`` noise batch.
        n_samples: number of images to generate.
        latent_dim: length of each noise vector. 100 is the value this script
            has always used (presumably DCGAN's default latent size — confirm
            if the model is built with a different ``latent_dim``).
    """
    # Inference mode (affects BatchNorm/Dropout if the generator has them),
    # and no autograd graph, so the output can be converted to NumPy.
    model.eval()
    with torch.no_grad():
        # Standard-normal latent noise, the conventional DCGAN prior
        # (NOTE(review): confirm against pl_bolts DCGAN's own noise sampler);
        # torch.rand would draw from U[0, 1) instead.
        noise = torch.randn(n_samples, latent_dim, device=model.device)
        fake = model(noise)
    grid = vutils.make_grid(fake, padding=2, normalize=True)
    # make_grid yields (C, H, W) and imshow wants (H, W, C). A bare np.transpose
    # reverses all axes to (W, H, C), which transposes the picture.
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.title("Fake Images")
    plt.show()


def main(argv=None):
    """Train a DCGAN on an image folder, then show a grid of generated samples.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = parser.parse_args(argv)

    dataset = dset.ImageFolder(root=args.rootdir, transform=build_transform())
    train_dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        shuffle=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=args.workers > 0,
    )

    # Keep only the single best checkpoint. The monitored key is presumably
    # the epoch-level aggregate of the generator loss DCGAN logs — confirm.
    checkpoint_callback = ModelCheckpoint(
        dirpath=".models/",
        save_top_k=1,
        monitor="loss/gen_epoch",
    )
    logger = TensorBoardLogger(save_dir=os.getcwd(), name="lightning_logs")

    model = DCGAN(image_channels=3)
    trainer = Trainer(
        accelerator="gpu",
        devices=1,
        max_epochs=args.epochs,
        logger=logger,
        enable_checkpointing=True,
        callbacks=[checkpoint_callback],
    )
    trainer.fit(model, train_dataloader)
    print(checkpoint_callback.best_model_path)

    show_samples(model)


# Built at import (side-effect free); arguments are parsed only inside main().
parser = build_parser()

if __name__ == "__main__":
    main()