diff --git a/models/iwae.py b/models/iwae.py
index 2b32356f..47ea9d2f 100644
--- a/models/iwae.py
+++ b/models/iwae.py
@@ -3,7 +3,7 @@
 from torch import nn
 from torch.nn import functional as F
 from .types_ import *
-
+import math


 class IWAE(BaseVAE):
@@ -143,21 +143,28 @@ def loss_function(self,
         eps = args[5]

         input = input.repeat(self.num_samples, 1, 1, 1, 1).permute(1, 0, 2, 3, 4) #[B x S x C x H x W]
-
+        _, _, c, h, w = input.shape
         kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset

-        log_p_x_z = ((recons - input) ** 2).flatten(2).mean(-1) # Reconstruction Loss [B x S]
-        kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=2) ## [B x S]
-        # Get importance weights
-        log_weight = (log_p_x_z + kld_weight * kld_loss) #.detach().data
+        log_p_x_z = -((recons - input) ** 2).flatten(2).mean(-1) # Reconstruction Loss [B x S]
+        log_p_z = (-0.5 * ((z - 0.0) ** 2 / 1.0) - 0.5 * math.log(2.0 * math.pi * 1.0)).mean(-1) # Prior N(0, I) [B x S]
+        log_q_z_x = (-0.5 * ((z - mu) ** 2 / log_var.exp()) - 0.5 * (math.log(2.0 * math.pi) + log_var)).mean(-1) # Variational posterior [B x S]
+        # This KL divergence is estimated from samples; writing it in a beta-VAE style is a bit awkward.
+        # It would arguably be cleaner to use \lambda * reconstruction loss + DKL rather than
+        # reconstruction loss + \beta * DKL here; this formulation just keeps the same meaning as before.
+        # In fact this is SGVB estimator 1 (Auto-Encoding Variational Bayes, Eq. 6), not SGVB estimator 2 (Eq. 7, Eq. 10).
+        kld_loss_est = kld_weight * (log_q_z_x - log_p_z) * (c * h * w) # [B x S]
+
+        # kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=2) ## [B x S]
+        # Get importance weights; they must be detached from the graph
+        elbo = (log_p_x_z - kld_loss_est)

         # Rescale the weights (along the sample dim) to lie in [0, 1] and sum to 1
-        weight = F.softmax(log_weight, dim = -1)
-        # kld_loss = torch.mean(kld_loss, dim = 0)
+        weight = F.softmax(elbo.detach(), dim = -1)

-        loss = torch.mean(torch.sum(weight * log_weight, dim=-1), dim = 0)
+        loss = -torch.mean(torch.sum(weight * elbo, dim=-1), dim = 0)

-        return {'loss': loss, 'Reconstruction_Loss':log_p_x_z.mean(), 'KLD':-kld_loss.mean()}
+        return {'loss': loss, 'Reconstruction_Loss':log_p_x_z.mean(), 'KLD':-kld_loss_est.mean()}

     def sample(self,
                num_samples:int,
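
For reference, here is a minimal standalone sketch (not part of the patch) checking why the `.detach()` on the importance weights is correct: the softmax-weighted objective in the diff has the same gradient as the textbook log-mean-exp IWAE bound, even though the two loss values differ. The shapes and names (`B`, `S`, `elbo`) mirror the diff; the random `elbo` tensor is a stand-in for `log_p_x_z - kld_loss_est`.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, S = 4, 8                                   # batch size, number of importance samples
elbo = torch.randn(B, S, requires_grad=True)  # stand-in for log_p_x_z - kld_loss_est

# Estimator used in the patch: softmax over the *detached* per-sample ELBOs,
# then a weighted sum that gradients flow through.
weight = F.softmax(elbo.detach(), dim=-1)
loss_weighted = -torch.mean(torch.sum(weight * elbo, dim=-1), dim=0)
(grad_weighted,) = torch.autograd.grad(loss_weighted, elbo)

# Textbook IWAE bound: -E[ log (1/S) * sum_s exp(elbo_s) ].
loss_lse = -torch.mean(torch.logsumexp(elbo, dim=-1) - math.log(S))
(grad_lse,) = torch.autograd.grad(loss_lse, elbo)

# The losses differ in value (the weighted sum drops the log and the log S
# constant), but their gradients w.r.t. elbo coincide, which is what
# training actually uses: d/d elbo_i of both objectives is -softmax(elbo)_i / B.
print(torch.allclose(grad_weighted, grad_lse, atol=1e-6))  # True
```

Note that without the `.detach()` the softmax itself would receive gradients and the weighted-sum objective would no longer match the IWAE gradient, which is why the patch detaches before `F.softmax`.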