-
Notifications
You must be signed in to change notification settings - Fork 39
/
colorize.py
71 lines (64 loc) · 2.19 KB
/
colorize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from model import generator
from torch.autograd import Variable
from scipy.ndimage import zoom
import cv2
import os
from PIL import Image
import argparse
import numpy as np
from skimage.color import rgb2yuv,yuv2rgb
def parse_args():
    """Parse the colorizer's command-line interface.

    Returns:
        argparse.Namespace with attributes ``input``, ``output``,
        ``model`` (all required strings) and ``gpu`` (int, -1 = CPU).
    """
    parser = argparse.ArgumentParser(description="Colorize images")
    # Required I/O paths: each may be a single image or a directory.
    parser.add_argument("-i", "--input", type=str, required=True,
                        help="input image/input dir")
    parser.add_argument("-o", "--output", type=str, required=True,
                        help="output image/output dir")
    # Path to the trained Generator checkpoint.
    parser.add_argument("-m", "--model", type=str, required=True,
                        help="location for model (Generator)")
    # Device selection; negative means run on CPU.
    parser.add_argument("--gpu", type=int, default=-1,
                        help="which GPU to use? [-1 for cpu]")
    return parser.parse_args()
# --- Script setup: parse CLI arguments and load the trained generator ---
args = parse_args()
G = generator()
if args.gpu>=0:
    # GPU path: move the model to the requested device first, then load
    # the checkpoint (weights stay on the GPU they were saved from).
    G=G.cuda(args.gpu)
    G.load_state_dict(torch.load(args.model))
else:
    # CPU path: remap tensors that were saved on 'cuda:0' onto the CPU
    # so the checkpoint loads on machines without CUDA.
    G.load_state_dict(torch.load(args.model,map_location={'cuda:0': 'cpu'}))
def inference(G, in_path, out_path):
    """Colorize one image with generator ``G`` and write it to disk.

    The input is converted to YUV; only the Y (luma) plane is fed to the
    network, which predicts the two U/V chroma planes at its own output
    resolution. The chroma is rescaled to YUV value ranges, zoomed back
    to the input size, recombined with the original luma, converted to
    RGB and saved with cv2.

    Args:
        G: trained generator; maps a (1,1,H,W) luma tensor to a
           (1,2,h,w) chroma tensor — TODO confirm against model.py.
        in_path: path to the input image (anything PIL can open).
        out_path: destination path, written via ``cv2.imwrite``.

    Side effects: reads ``args.gpu`` from the module-level CLI namespace.
    """
    p = Image.open(in_path).convert('RGB')
    img_yuv = rgb2yuv(p)
    H, W, _ = img_yuv.shape
    # (H,W) luma -> (1,1,H,W): add batch and channel dimensions.
    infimg = np.expand_dims(np.expand_dims(img_yuv[..., 0], axis=0), axis=0)
    # Shift luma by -0.5 so the network input is zero-centered.
    img_variable = Variable(torch.Tensor(infimg - 0.5))
    if args.gpu >= 0:
        img_variable = img_variable.cuda(args.gpu)
    res = G(img_variable)
    uv = res.cpu().detach().numpy()
    # Scale network outputs to the nominal YUV chroma ranges
    # (|U| <= 0.436, |V| <= 0.615); presumably the net emits [-1, 1].
    uv[:, 0, :, :] *= 0.436
    uv[:, 1, :, :] *= 0.615
    (_, _, H1, W1) = uv.shape
    # Upsample predicted chroma back to the original image resolution.
    uv = zoom(uv, (1, 1, H / H1, W / W1))
    yuv = np.concatenate([infimg, uv], axis=1)[0]
    rgb = yuv2rgb(yuv.transpose(1, 2, 0))
    # FIX: scale by 255, not 256 — after clipping to [0,1], *256 maps a
    # full-intensity pixel to 256, which overflows the 8-bit range when
    # cv2 casts. Channel reorder RGB->BGR is what cv2.imwrite expects.
    cv2.imwrite(out_path, (rgb.clip(min=0, max=1) * 255)[:, :, [2, 1, 0]])
# --- Dispatch: single-file mode vs. directory (batch) mode ---
if not os.path.isdir(args.input):
    # Input is a single image; output is treated as a file path.
    inference(G,args.input,args.output)
else:
    # Input is a directory: colorize every entry, mirroring file names
    # into the output directory (created if missing).
    if not os.path.exists(args.output):
        os.makedirs(args.output)
    for f in os.listdir(args.input):
        # NOTE(review): no filtering — non-image files or subdirectories
        # in the input dir would make Image.open raise; confirm inputs.
        inference(G,os.path.join(args.input,f),os.path.join(args.output,f))