forked from yang-song/score_sde_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
controllable_generation.py
181 lines (150 loc) · 7.99 KB
/
controllable_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from models import utils as mutils
import torch
import numpy as np
from sampling import NoneCorrector, NonePredictor, shared_corrector_update_fn, shared_predictor_update_fn
import functools
def get_pc_inpainter(sde, predictor, corrector, inverse_scaler, snr,
n_steps=1, probability_flow=False, continuous=False,
denoise=True, eps=1e-5):
"""Create an image inpainting function that uses PC samplers.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
predictor: A subclass of `sampling.Predictor` that represents a predictor algorithm.
corrector: A subclass of `sampling.Corrector` that represents a corrector algorithm.
inverse_scaler: The inverse data normalizer.
snr: A `float` number. The signal-to-noise ratio for the corrector.
n_steps: An integer. The number of corrector steps per update of the corrector.
probability_flow: If `True`, predictor solves the probability flow ODE for sampling.
continuous: `True` indicates that the score-based model was trained with continuous time.
denoise: If `True`, add one-step denoising to final samples.
eps: A `float` number. The reverse-time SDE/ODE is integrated to `eps` for numerical stability.
Returns:
An inpainting function.
"""
# Define predictor & corrector
predictor_update_fn = functools.partial(shared_predictor_update_fn,
sde=sde,
predictor=predictor,
probability_flow=probability_flow,
continuous=continuous)
corrector_update_fn = functools.partial(shared_corrector_update_fn,
sde=sde,
corrector=corrector,
continuous=continuous,
snr=snr,
n_steps=n_steps)
def get_inpaint_update_fn(update_fn):
"""Modify the update function of predictor & corrector to incorporate data information."""
def inpaint_update_fn(model, data, mask, x, t):
with torch.no_grad():
vec_t = torch.ones(data.shape[0], device=data.device) * t
x, x_mean = update_fn(x, vec_t, model=model)
masked_data_mean, std = sde.marginal_prob(data, vec_t)
masked_data = masked_data_mean + torch.randn_like(x) * std[:, None, None, None]
x = x * (1. - mask) + masked_data * mask
x_mean = x * (1. - mask) + masked_data_mean * mask
return x, x_mean
return inpaint_update_fn
projector_inpaint_update_fn = get_inpaint_update_fn(predictor_update_fn)
corrector_inpaint_update_fn = get_inpaint_update_fn(corrector_update_fn)
def pc_inpainter(model, data, mask):
"""Predictor-Corrector (PC) sampler for image inpainting.
Args:
model: A score model.
data: A PyTorch tensor that represents a mini-batch of images to inpaint.
mask: A 0-1 tensor with the same shape of `data`. Value `1` marks known pixels,
and value `0` marks pixels that require inpainting.
Returns:
Inpainted (complete) images.
"""
with torch.no_grad():
# Initial sample
x = data * mask + sde.prior_sampling(data.shape).to(data.device) * (1. - mask)
timesteps = torch.linspace(sde.T, eps, sde.N)
for i in range(sde.N):
t = timesteps[i]
x, x_mean = corrector_inpaint_update_fn(model, data, mask, x, t)
x, x_mean = projector_inpaint_update_fn(model, data, mask, x, t)
return inverse_scaler(x_mean if denoise else x)
return pc_inpainter
def get_pc_colorizer(sde, predictor, corrector, inverse_scaler,
snr, n_steps=1, probability_flow=False, continuous=False,
denoise=True, eps=1e-5):
"""Create a image colorization function based on Predictor-Corrector (PC) sampling.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
predictor: A subclass of `sampling.Predictor` that represents a predictor algorithm.
corrector: A subclass of `sampling.Corrector` that represents a corrector algorithm.
inverse_scaler: The inverse data normalizer.
snr: A `float` number. The signal-to-noise ratio for correctors.
n_steps: An integer. The number of corrector steps per update of the predictor.
probability_flow: If `True`, solve the probability flow ODE for sampling with the predictor.
continuous: `True` indicates that the score-based model was trained with continuous time steps.
denoise: If `True`, add one-step denoising to final samples.
eps: A `float` number. The SDE/ODE will start from `eps` to avoid numerical stabilities.
Returns: A colorization function.
"""
# `M` is an orthonormal matrix to decouple image space to a latent space where the gray-scale image
# occupies a separate channel
M = torch.tensor([[5.7735014e-01, -8.1649649e-01, 4.7008697e-08],
[5.7735026e-01, 4.0824834e-01, 7.0710671e-01],
[5.7735026e-01, 4.0824822e-01, -7.0710683e-01]])
# `invM` is the inverse transformation of `M`
invM = torch.inverse(M)
# Decouple a gray-scale image with `M`
def decouple(inputs):
return torch.einsum('bihw,ij->bjhw', inputs, M.to(inputs.device))
# The inverse function to `decouple`.
def couple(inputs):
return torch.einsum('bihw,ij->bjhw', inputs, invM.to(inputs.device))
predictor_update_fn = functools.partial(shared_predictor_update_fn,
sde=sde,
predictor=predictor,
probability_flow=probability_flow,
continuous=continuous)
corrector_update_fn = functools.partial(shared_corrector_update_fn,
sde=sde,
corrector=corrector,
continuous=continuous,
snr=snr,
n_steps=n_steps)
def get_colorization_update_fn(update_fn):
"""Modify update functions of predictor & corrector to incorporate information of gray-scale images."""
def colorization_update_fn(model, gray_scale_img, x, t):
mask = get_mask(x)
vec_t = torch.ones(x.shape[0], device=x.device) * t
x, x_mean = update_fn(x, vec_t, model=model)
masked_data_mean, std = sde.marginal_prob(decouple(gray_scale_img), vec_t)
masked_data = masked_data_mean + torch.randn_like(x) * std[:, None, None, None]
x = couple(decouple(x) * (1. - mask) + masked_data * mask)
x_mean = couple(decouple(x) * (1. - mask) + masked_data_mean * mask)
return x, x_mean
return colorization_update_fn
def get_mask(image):
mask = torch.cat([torch.ones_like(image[:, :1, ...]),
torch.zeros_like(image[:, 1:, ...])], dim=1)
return mask
predictor_colorize_update_fn = get_colorization_update_fn(predictor_update_fn)
corrector_colorize_update_fn = get_colorization_update_fn(corrector_update_fn)
def pc_colorizer(model, gray_scale_img):
"""Colorize gray-scale images using Predictor-Corrector (PC) sampler.
Args:
model: A score model.
gray_scale_img: A minibatch of gray-scale images. Their R,G,B channels have same values.
Returns:
Colorized images.
"""
with torch.no_grad():
shape = gray_scale_img.shape
mask = get_mask(gray_scale_img)
# Initial sample
x = couple(decouple(gray_scale_img) * mask + \
decouple(sde.prior_sampling(shape).to(gray_scale_img.device)
* (1. - mask)))
timesteps = torch.linspace(sde.T, eps, sde.N)
for i in range(sde.N):
t = timesteps[i]
x, x_mean = corrector_colorize_update_fn(model, gray_scale_img, x, t)
x, x_mean = predictor_colorize_update_fn(model, gray_scale_img, x, t)
return inverse_scaler(x_mean if denoise else x)
return pc_colorizer