vae_sr.py
# -*- coding: utf-8 -*-
# Commented out IPython magic to ensure Python compatibility.
from typing import *
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import Image, display, clear_output
import numpy as np
# %matplotlib nbagg
# %matplotlib inline
import pandas as pd
import math
import torch
from torch import nn, Tensor
from torch.nn.functional import softplus
from torch.distributions import Distribution
from torch.distributions import Bernoulli
from torch.distributions import Normal
class ReparameterizedDiagonalGaussian(Distribution):
    """
    A distribution `N(y | mu, sigma I)` compatible with the reparameterization trick given `epsilon ~ N(0, 1)`.
    """
    def __init__(self, mu: Tensor, log_sigma: Tensor):
        assert mu.shape == log_sigma.shape, f"Tensors `mu` : {mu.shape} and `log_sigma` : {log_sigma.shape} must be of the same shape"
        self.mu = mu
        self.sigma = log_sigma.exp()

    def sample_epsilon(self) -> Tensor:
        r"""`\eps ~ N(0, I)`"""
        return torch.empty_like(self.mu).normal_()

    def sample(self) -> Tensor:
        """sample `z ~ N(z | mu, sigma)` (without gradients)"""
        with torch.no_grad():
            return self.rsample()

    def rsample(self) -> Tensor:
        """sample `z ~ N(z | mu, sigma)` (with the reparameterization trick)"""
        z = self.mu + self.sigma * self.sample_epsilon()
        return z

    def log_prob(self, z: Tensor) -> Tensor:
        """return the log probability: log `p(z)`"""
        log_scale = torch.log(self.sigma)
        return -((z - self.mu)**2 / (2 * self.sigma**2)) - log_scale - math.log(math.sqrt(2 * math.pi))
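
# Quick sanity check (a sketch, not part of the original notebook): compares
# `ReparameterizedDiagonalGaussian` against `torch.distributions.Normal` on dummy
# tensors; the `_mu`/`_log_sigma` names and shapes are arbitrary. Uncomment to run.
# _mu, _log_sigma = torch.zeros(4, 3), torch.zeros(4, 3)
# _p = ReparameterizedDiagonalGaussian(_mu, _log_sigma)
# _z = _p.rsample()                                   # z = mu + sigma * eps, differentiable w.r.t. mu and sigma
# _ref = Normal(_mu, _log_sigma.exp()).log_prob(_z)   # reference element-wise Gaussian log-density
# assert torch.allclose(_p.log_prob(_z), _ref, atol=1e-5)
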
class VariationalAutoencoder(nn.Module):
    r"""A Variational Autoencoder with
    * a Bernoulli observation model `p_\theta(x | z) = B(x | g_\theta(z))`
    * a Gaussian prior `p(z) = N(z | 0, I)`
    * a Gaussian posterior `q_\phi(z|x) = N(z | \mu(x), \sigma(x))`
    """
    def __init__(self, input_shape: torch.Size, latent_features: int) -> None:
        super(VariationalAutoencoder, self).__init__()
        self.input_shape = input_shape
        self.latent_features = latent_features
        self.observation_features = np.prod(input_shape)

        # Inference Network
        # Encode the observation `x` into the parameters of the posterior distribution
        # `q_\phi(z|x) = N(z | \mu(x), \sigma(x)), \mu(x),\log\sigma(x) = h_\phi(x)`
        self.encoder = nn.Sequential(
            nn.Linear(in_features=self.observation_features, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=32),
            nn.ReLU(),
            # A Gaussian is fully characterised by its mean \mu and variance \sigma**2
            nn.Linear(in_features=32, out_features=2*latent_features)  # <- note the 2*latent_features
        )
        # Generative Model
        # Decode the latent sample `z` into the parameters of the observation model
        # `p_\theta(x | z) = \prod_i B(x_i | g_\theta(z))`
        self.decoder = nn.Sequential(
            nn.Linear(in_features=latent_features, out_features=32),
            nn.ReLU(),
            nn.Linear(in_features=32, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=self.observation_features)
        )

        # Prior network for SR: maps the conditioning image `y` to the parameters of `p(z|y)`
        self.prior_nn = nn.Sequential(
            nn.Linear(in_features=self.observation_features, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=32),
            nn.ReLU(),
            # A Gaussian is fully characterised by its mean \mu and variance \sigma**2
            nn.Linear(in_features=32, out_features=2*latent_features)
        )

        # define the parameters of the unconditional prior, chosen as p(z) = N(0, I)
        self.register_buffer('prior_params', torch.zeros(torch.Size([1, 2*latent_features])))
    def posterior(self, x: Tensor) -> Distribution:
        r"""return the distribution `q(z|x) = N(z | \mu(x), \sigma(x))`"""
        # compute the parameters of the posterior
        h_x = self.encoder(x)
        mu, log_sigma = h_x.chunk(2, dim=-1)
        # return the distribution `q(z|x) = N(z | \mu(x), \sigma(x))`
        return ReparameterizedDiagonalGaussian(mu, log_sigma)

    def prior(self, batch_size: int = 1) -> Distribution:
        """return the unconditional prior `p(z)`"""
        prior_params = self.prior_params.expand(batch_size, *self.prior_params.shape[-1:])
        mu, log_sigma = prior_params.chunk(2, dim=-1)
        # return the distribution `p(z)`
        return ReparameterizedDiagonalGaussian(mu, log_sigma)

    def prior_sr(self, y: Tensor) -> Distribution:
        """return the conditional SR prior `p(z|y)`"""
        h_y = self.prior_nn(y)
        mu, log_sigma = h_y.chunk(2, dim=-1)
        # return the distribution `p(z|y)`
        return ReparameterizedDiagonalGaussian(mu, log_sigma)

    def observation_model(self, z: Tensor) -> Distribution:
        """return the distribution `p(x|z)`"""
        px_logits = self.decoder(z)
        px_logits = px_logits.view(-1, *self.input_shape)  # reshape to input_shape, batch dimension inferred
        return Bernoulli(logits=px_logits)

    def observation_model_normal(self, z: Tensor) -> Distribution:
        """return a Gaussian observation model `p(x|z)`"""
        # note: chunking into (mu, log_sigma) assumes the decoder outputs
        # 2 * observation_features values per sample
        h_z = self.decoder(z)
        mu, log_sigma = h_z.chunk(2, dim=-1)
        mu = mu.view(-1, *self.input_shape)
        log_sigma = log_sigma.view(-1, *self.input_shape)
        return ReparameterizedDiagonalGaussian(mu, log_sigma)
    def forward(self, x, y) -> Dict[str, Any]:
        """compute the posterior q(z|x) (encoder), sample z ~ q(z|x) and return the distribution p(x|z) (decoder)"""
        # flatten the inputs
        x = x.view(x.size(0), -1)
        y = y.view(y.size(0), -1)

        # define the posterior q(z|x) / encode x into q(z|x)
        qz = self.posterior(x)

        # define the prior: the unconditional p(z) is replaced by the conditional SR prior p(z|y)
        # pz = self.prior(batch_size=x.size(0))
        pz = self.prior_sr(y)
        zy = pz.rsample()

        # sample the posterior using the reparameterization trick: z ~ q(z | x)
        z = qz.rsample()

        # define the observation model p(x|z) = B(x | g(z)), conditioning on both latent samples
        px = self.observation_model(z + zy)

        return {'px': px, 'pz': pz, 'qz': qz, 'z': z}

    def sample_from_prior(self, y):
        """sample z ~ p(z|y) and return p(x|z)"""
        y = y.view(y.size(0), -1)
        # define the conditional prior p(z|y)
        pz = self.prior_sr(y)
        # sample the prior
        z = pz.rsample()
        # define the observation model p(x|z)
        px = self.observation_model_normal(z)
        return {'px': px, 'pz': pz, 'z': z}
# latent_features = 2
# vae = VariationalAutoencoder(images[0].shape, latent_features)
# print(vae)
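
# Minimal forward-pass sketch (not from the original notebook): uses a dummy
# 28x28 image pair in place of the real `images` tensor, which is assumed to be
# defined elsewhere in the notebook; the shape and batch size are hypothetical.
# Uncomment to run.
# _shape = torch.Size([28, 28])            # hypothetical input shape
# _vae = VariationalAutoencoder(_shape, latent_features=2)
# _x = torch.rand(16, *_shape)             # dummy batch encoded by the posterior q(z|x)
# _y = torch.rand(16, *_shape)             # dummy batch conditioning the SR prior p(z|y)
# _out = _vae(_x, _y)                      # {'px': Bernoulli, 'pz': p(z|y), 'qz': q(z|x), 'z': z}
# print(_out['px'].logits.shape)           # -> torch.Size([16, 28, 28])
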
def reduce(x: Tensor) -> Tensor:
    """for each datapoint: sum over all dimensions"""
    return x.view(x.size(0), -1).sum(dim=1)
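
# Example (sketch): `reduce` sums every non-batch dimension, e.g. a (16, 28, 28)
# tensor of per-pixel log-probabilities becomes a (16,) tensor of per-image sums.
# reduce(torch.ones(16, 28, 28))   # -> tensor of shape (16,), each entry 784.0
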
class VariationalInference(nn.Module):
    def __init__(self, beta: float = 0.95):
        super().__init__()
        self.beta = beta

    def forward(self, model: nn.Module, x: Tensor, y: Tensor) -> Tuple[Tensor, Dict, Dict]:
        # forward pass through the model
        # note: the arguments are passed in swapped order, so the model's posterior
        # encodes `y` while its SR prior conditions on `x`; the likelihood is evaluated at `x`
        outputs = model(y, x)

        # unpack outputs
        px, pz, qz, z = [outputs[k] for k in ["px", "pz", "qz", "z"]]

        # evaluate log probabilities
        log_px = reduce(px.log_prob(x))
        log_pz = reduce(pz.log_prob(z))
        log_qz = reduce(qz.log_prob(z))

        # compute the ELBO with and without the beta parameter:
        # `L^\beta = E_q [ log p(x|z) ] - \beta * D_KL(q(z|x) | p(z))`
        # where `D_KL(q(z|x) | p(z)) = log q(z|x) - log p(z)`
        kl = log_qz - log_pz
        elbo = log_px - kl
        beta_elbo = log_px - self.beta * kl

        # loss
        loss = -beta_elbo.mean()

        # prepare the output
        with torch.no_grad():
            diagnostics = {'elbo': elbo, 'log_px': log_px, 'kl': kl}

        return loss, diagnostics, outputs
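
# Minimal training-loop sketch (not from the original notebook): assumes a
# hypothetical `train_loader` yielding paired image batches `(x, y)` as expected
# by `VariationalInference.forward`, and a hypothetical 28x28 input shape.
# Uncomment and adapt to run.
# vae = VariationalAutoencoder(torch.Size([28, 28]), latent_features=2)
# vi = VariationalInference(beta=0.95)
# optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
# for x, y in train_loader:
#     loss, diagnostics, outputs = vi(vae, x, y)
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
#     print(f"elbo: {diagnostics['elbo'].mean():.2f}, kl: {diagnostics['kl'].mean():.2f}")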