-
Notifications
You must be signed in to change notification settings - Fork 3
/
utils.py
37 lines (31 loc) · 1.3 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import config
from torchvision.utils import save_image
def save_some_examples(gen, val_loader, epoch, folder):
    """Save generator output for one validation batch as PNG images.

    Takes the first batch ``(x, y)`` yielded by ``val_loader``, moves both
    tensors to ``config.DEVICE``, runs ``gen`` on ``x`` in eval mode with
    gradients disabled, and writes the following files into ``folder``:

    * ``y_gen_{epoch}.png`` -- the generator output,
    * ``input_{epoch}.png`` -- the input batch,
    * ``label_{epoch}.png`` -- the target batch, written only when
      ``epoch == 1``.

    Every tensor is mapped through ``t * 0.5 + 0.5`` before saving to undo
    the input normalisation (presumably mean 0.5 / std 0.5 -- confirm
    against the dataset transforms).

    Args:
        gen: Generator module; must support ``eval()``, ``train(mode)`` and
            expose the ``training`` flag (as ``torch.nn.Module`` does).
        val_loader: Iterable yielding ``(x, y)`` tensor pairs.
        epoch: Value interpolated into the output file names.
        folder: Directory that receives the images. It is created if it
            does not exist.

    Raises:
        StopIteration: If ``val_loader`` yields no batches.
    """
    # Function-scope import keeps this block self-contained.
    import os

    # save_image does not create missing directories, so make sure the
    # target exists. An empty string means "current directory".
    if folder:
        os.makedirs(folder, exist_ok=True)

    x, y = next(iter(val_loader))
    x, y = x.to(config.DEVICE), y.to(config.DEVICE)

    # Remember the mode we were called in so it can be restored afterwards
    # instead of unconditionally forcing train mode.
    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x)
            y_fake = y_fake * 0.5 + 0.5  # remove normalization
            save_image(y_fake, os.path.join(folder, f"y_gen_{epoch}.png"))
            save_image(x * 0.5 + 0.5, os.path.join(folder, f"input_{epoch}.png"))
            # The label is only written for epoch 1.
            if epoch == 1:
                save_image(y * 0.5 + 0.5, os.path.join(folder, f"label_{epoch}.png"))
    finally:
        # Restore the original mode even if inference or saving failed.
        gen.train(was_training)
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialize the model and optimizer state to ``filename``.

    The file written by ``torch.save`` holds a dict with two entries:
    ``"state_dict"`` (from ``model.state_dict()``) and ``"optimizer"``
    (from ``optimizer.state_dict()``).

    Args:
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        filename: Destination passed to ``torch.save``.
    """
    print("!!! Saving checkpoint !!!")
    payload = dict(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(payload, filename)
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from a checkpoint file.

    Loads ``checkpoint_file`` with ``torch.load`` (tensors mapped onto
    ``config.DEVICE``), feeds its ``"state_dict"`` entry to ``model`` and
    its ``"optimizer"`` entry to ``optimizer``, then overwrites the learning
    rate of every optimizer parameter group with ``lr``.

    Args:
        checkpoint_file: Path or file object accepted by ``torch.load``.
        model: Object exposing ``load_state_dict()``.
        optimizer: Object exposing ``load_state_dict()`` and
            ``param_groups``.
        lr: Learning rate to apply after loading.
    """
    print("!!! Loading checkpoint !!!")
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    restored = torch.load(checkpoint_file, map_location=config.DEVICE)

    model.load_state_dict(restored["state_dict"])
    optimizer.load_state_dict(restored["optimizer"])

    # Loading the optimizer state also brings back the learning rate stored
    # in the checkpoint; reset it explicitly so the caller's value is the
    # one actually used (otherwise this costs hours of debugging).
    for group in optimizer.param_groups:
        group["lr"] = lr