forked from daa233/generative-inpainting-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_single.py
132 lines (116 loc) · 5.41 KB
/
test_single.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import random
from argparse import ArgumentParser
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.utils as vutils
from model.networks import Generator
from utils.tools import (default_loader, get_config, get_model_list,
is_image_file, mask_image, normalize, random_bbox)
# Command-line interface for single-image inpainting. `main` reads these
# options via `parser.parse_args()`.
parser = ArgumentParser(
    description="Inpaint a single image with a trained generator.")
parser.add_argument('--config', type=str, default='configs/config.yaml',
                    help="training configuration")
parser.add_argument('--seed', type=int, help='manual seed')
# `main` passes this straight to `is_image_file`, so it must be provided.
parser.add_argument('--image', type=str, required=True,
                    help='path of the image to inpaint')
parser.add_argument('--mask', type=str, default='',
                    help='optional mask image; if omitted, a random mask '
                         'is generated')
parser.add_argument('--output', type=str, default='output.png',
                    help='path where the inpainted result is saved')
parser.add_argument('--flow', type=str, default='',
                    help='optional path where the offset-flow '
                         'visualization is saved')
parser.add_argument('--checkpoint_path', type=str, default='',
                    help='directory holding generator checkpoints '
                         '(default: derived from the config)')
parser.add_argument('--iter', type=int, default=0,
                    help='checkpoint iteration to load, passed to '
                         'get_model_list')
def _resize_and_crop(img, config):
    """Resize, then center-crop, an image loaded by `default_loader`.

    The target size is ``config['image_shape'][:-1]``. It presumably holds
    the spatial dims of an [H, W, C] shape (TODO confirm against the config).
    """
    size = config['image_shape'][:-1]
    img = transforms.Resize(size)(img)
    return transforms.CenterCrop(size)(img)


def _load_inputs(args, config):
    """Build the ``(x, mask)`` pair fed to the generator.

    With ``--mask``, the given mask image is applied to ``--image``.
    Without it, a random mask is built from ``random_bbox`` and applied by
    ``mask_image``. In the given-mask branch, both tensors get a leading
    batch dimension of size 1.

    Raises:
        TypeError: if ``--mask`` is non-empty but not an image file.
    """
    if args.mask:
        if not is_image_file(args.mask):
            raise TypeError("{} is not an image file.".format(args.mask))
        # Test a single masked image with a given mask.
        x = _resize_and_crop(default_loader(args.image), config)
        mask = _resize_and_crop(default_loader(args.mask), config)
        x = normalize(transforms.ToTensor()(x))
        # Only the first channel of the mask image is used.
        mask = transforms.ToTensor()(mask)[0].unsqueeze(dim=0)
        # Blank out the region to be inpainted.
        x = x * (1. - mask)
        return x.unsqueeze(dim=0), mask.unsqueeze(dim=0)

    # Test a single ground-truth image with a random mask.
    ground_truth = _resize_and_crop(default_loader(args.image), config)
    ground_truth = normalize(transforms.ToTensor()(ground_truth))
    ground_truth = ground_truth.unsqueeze(dim=0)
    bboxes = random_bbox(config, batch_size=ground_truth.size(0))
    return mask_image(ground_truth, bboxes, config)


def _resolve_checkpoint_path(args, config):
    """Return ``--checkpoint_path``, or the default derived from the config.

    The default is ``checkpoints/<dataset_name>/<mask_type>_<expname>``.
    """
    if args.checkpoint_path:
        return args.checkpoint_path
    return os.path.join('checkpoints',
                        config['dataset_name'],
                        config['mask_type'] + '_' + config['expname'])


def main():
    """Inpaint a single image with a trained generator and save the result.

    Options come from the module-level ``parser``; model settings come from
    the config file named by ``--config``. The generator checkpoint is chosen
    by ``get_model_list(checkpoint_path, "gen", iteration=--iter)``. The
    composited result (generator output inside the mask, input outside it)
    is written to ``--output``. If ``--flow`` is set, the offset-flow
    visualization is also written there.

    Raises:
        TypeError: if ``--image`` (or a non-empty ``--mask``) is missing or
            is not an image file according to ``is_image_file``.
    """
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        # Expose only the configured GPUs. The visible devices are then
        # renumbered from 0, so the ids used below are 0..n-1.
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            str(i) for i in device_ids)
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    print("Arguments: {}".format(args))

    # Set random seed. One is drawn when not given, so it can be reported.
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    print("Random seed: {}".format(args.seed))
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    print("Configuration: {}".format(config))

    try:  # for unexpected error logging
        with torch.no_grad():  # inference only; no autograd graph needed
            if not args.image or not is_image_file(args.image):
                raise TypeError(
                    "{} is not an image file.".format(args.image))

            x, mask = _load_inputs(args, config)
            checkpoint_path = _resolve_checkpoint_path(args, config)

            # Define the generator and resume its weights.
            netG = Generator(config['netG'], cuda, device_ids)
            last_model_name = get_model_list(
                checkpoint_path, "gen", iteration=args.iter)
            # Map to CPU on non-CUDA runs so a GPU-saved checkpoint loads.
            netG.load_state_dict(torch.load(
                last_model_name, map_location=None if cuda else 'cpu'))
            # Assumes names end in an 8-digit iteration followed by '.pt'
            # (e.g. gen_00430000.pt). TODO confirm against the trainer.
            model_iteration = int(last_model_name[-11:-3])
            print("Resume from {} at iteration {}".format(
                checkpoint_path, model_iteration))

            if cuda:
                # Wrap only after loading, so state-dict keys carry no
                # 'module.' prefix.
                netG = nn.parallel.DataParallel(netG, device_ids=device_ids)
                x = x.cuda()
                mask = mask.cuda()

            # Inference: keep the known pixels and fill the hole with the
            # refined (second-stage) output.
            x1, x2, offset_flow = netG(x, mask)
            inpainted_result = x2 * mask + x * (1. - mask)

            vutils.save_image(inpainted_result, args.output,
                              padding=0, normalize=True)
            print("Saved the inpainted result to {}".format(args.output))
            if args.flow:
                vutils.save_image(offset_flow, args.flow,
                                  padding=0, normalize=True)
                print("Saved offset flow to {}".format(args.flow))
    except Exception as e:  # for unexpected error logging
        print("Error: {}".format(e))
        raise
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()