Commit

vae
thibaultdvx committed Oct 1, 2024
1 parent 683570f commit 348754f
Showing 2 changed files with 290 additions and 0 deletions.
191 changes: 191 additions & 0 deletions clinicadl/monai_networks/nn/vae.py
@@ -0,0 +1,191 @@
from copy import deepcopy
from typing import Any, Dict, Optional, Sequence, Union

import torch
import torch.nn as nn

from .autoencoder import AutoEncoder
from .layers import UpsamplingMode
from .utils import ActivationParameters


class VAE(nn.Module):
"""
A Variational AutoEncoder with convolutional and fully connected layers.
The user must pass the arguments to build an encoder, from its convolutional and
fully connected parts, and the decoder will be automatically built by taking the
symmetrical network.
    More precisely, to build the decoder, the order of the encoding layers is reversed, convolutions are
    replaced by transposed convolutions, and pooling layers are replaced by upsampling layers.
Please note that the order of `Activation`, `Dropout` and `Normalization`, defined with the
argument `adn_ordering` in `conv_args`, is the same for the encoder and the decoder.
    Note that a `VAE` is an aggregation of a `CNN` (:py:class:`clinicadl.monai_networks.nn.
    cnn.CNN`), whose last linear layer is duplicated to infer both the mean and the log variance,
    and a `Generator` (:py:class:`clinicadl.monai_networks.nn.generator.Generator`).

Parameters
----------
in_shape : Sequence[int]
sequence of integers stating the dimension of the input tensor (minus batch dimension).
latent_size : int
size of the latent vector.
conv_args : Dict[str, Any]
the arguments for the convolutional part of the encoder. The arguments are those accepted
by :py:class:`clinicadl.monai_networks.nn.fcn_encoder.FCNEncoder`, except `in_shape` that
is specified here. So, the only mandatory argument is `channels`.
mlp_args : Optional[Dict[str, Any]] (optional, default=None)
        the arguments for the MLP part of the encoder. The arguments are those accepted by
:py:class:`clinicadl.monai_networks.nn.mlp.MLP`, except `in_channels` that is inferred
from the output of the convolutional part, and `out_channels` that is set to `latent_size`.
So, the only mandatory argument is `hidden_channels`.\n
If None, the MLP part will be reduced to a single linear layer.\n
The last linear layer will be duplicated to infer both the mean and the log variance.
out_channels : Optional[int] (optional, default=None)
number of output channels. If None, the output will have the same number of channels as the
input.
output_act : ActivationParameters (optional, default=None)
a potential activation layer applied to the output of the network, and optionally its arguments.
Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n
`activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`,
        `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activation functions]
        (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional
        arguments for each of them.
upsampling_mode : Union[str, UpsamplingMode] (optional, default=UpsamplingMode.NEAREST)
        interpolation mode for upsampling (see:
        https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html).

Examples
--------
>>> VAE(
in_shape=(1, 16, 16),
latent_size=4,
conv_args={"channels": [2]},
mlp_args={"hidden_channels": [16], "output_act": "relu"},
out_channels=2,
output_act="sigmoid",
upsampling_mode="bilinear",
)
VAE(
(encoder): CNN(
(convolutions): FCNEncoder(
(layer_0): Convolution(
(conv): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1))
)
)
(mlp): MLP(
(flatten): Flatten(start_dim=1, end_dim=-1)
(hidden_0): Sequential(
(linear): Linear(in_features=392, out_features=16, bias=True)
(adn): ADN(
(N): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(A): PReLU(num_parameters=1)
)
)
(output): Identity()
)
)
(mu): Sequential(
(linear): Linear(in_features=16, out_features=4, bias=True)
(output_act): ReLU()
)
(log_var): Sequential(
(linear): Linear(in_features=16, out_features=4, bias=True)
(output_act): ReLU()
)
(decoder): Generator(
(mlp): MLP(
(flatten): Flatten(start_dim=1, end_dim=-1)
(hidden_0): Sequential(
(linear): Linear(in_features=4, out_features=16, bias=True)
(adn): ADN(
(N): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(A): PReLU(num_parameters=1)
)
)
(output): Sequential(
(linear): Linear(in_features=16, out_features=392, bias=True)
(output_act): ReLU()
)
)
(reshape): Reshape()
(convolutions): FCNDecoder(
(layer_0): Convolution(
(conv): ConvTranspose2d(2, 2, kernel_size=(3, 3), stride=(1, 1))
)
(output_act): Sigmoid()
)
)
)
"""

def __init__(
self,
in_shape: Sequence[int],
latent_size: int,
conv_args: Dict[str, Any],
mlp_args: Optional[Dict[str, Any]] = None,
out_channels: Optional[int] = None,
output_act: ActivationParameters = None,
upsampling_mode: Union[str, UpsamplingMode] = UpsamplingMode.NEAREST,
) -> None:
super().__init__()
ae = AutoEncoder(
in_shape,
latent_size,
conv_args,
mlp_args,
out_channels,
output_act,
upsampling_mode,
)

# replace last mlp layer by two parallel layers
mu_layers = deepcopy(ae.encoder.mlp.output)
log_var_layers = deepcopy(ae.encoder.mlp.output)
self._reset_weights(
log_var_layers
) # to have different initialization for the two layers
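        # bypass the original output layer: the encoder now exposes the penultimate
        # features, on which the mu and log_var heads are applied in parallel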
ae.encoder.mlp.output = nn.Identity()

self.encoder = ae.encoder
self.mu = mu_layers
self.log_var = log_var_layers
self.decoder = ae.decoder

def forward(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Encoding, sampling and decoding.
"""
feature = self.encoder(x)
mu = self.mu(feature)
log_var = self.log_var(feature)
z = self.reparameterize(mu, log_var)

return self.decoder(z), mu, log_var

def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
"""
        Samples a random vector from a Gaussian distribution, given the mean and
        log-variance of this distribution.
"""
std = torch.exp(0.5 * log_var)

if self.training: # multiply random noise with std only during training
std = torch.randn_like(std).mul(std)

return std.add_(mu)

@classmethod
def _reset_weights(cls, layer: Union[nn.Sequential, nn.Linear]) -> None:
"""
Resets the output layer(s) of an MLP.
"""
if isinstance(layer, nn.Linear):
layer.reset_parameters()
else:
layer.linear.reset_parameters()
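
For context, the three outputs of `forward` are typically combined into an ELBO-style training
objective. The sketch below is not part of this commit: the mean-squared-error reconstruction
term and the `beta` weight are assumptions, while the closed-form KL term follows directly from
the `mu` and `log_var` returned above.

import torch
import torch.nn.functional as F

def vae_loss(recon, x, mu, log_var, beta=1.0):
    # Reconstruction term (MSE is an assumption; any likelihood term works), averaged over the batch.
    recon_loss = F.mse_loss(recon, x, reduction="sum") / x.shape[0]
    # Closed-form KL divergence between N(mu, sigma^2) and the standard normal prior.
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / x.shape[0]
    return recon_loss + beta * kl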
99 changes: 99 additions & 0 deletions tests/unittests/monai_networks/nn/test_vae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pytest
import torch
from numpy import isclose
from torch.nn import ReLU

from clinicadl.monai_networks.nn import VAE


@pytest.mark.parametrize(
"input_tensor,kernel_size,stride,padding,dilation,pooling,pooling_indices",
[
(torch.randn(2, 1, 21), 3, 1, 0, 1, ("max", {"kernel_size": 2}), [0]),
(
torch.randn(2, 1, 65, 85),
(3, 5),
(1, 2),
0,
(1, 2),
("max", {"kernel_size": 2, "stride": 1}),
[0],
),
(
torch.randn(2, 1, 64, 62, 61), # to test output padding
4,
2,
(1, 1, 0),
1,
("avg", {"kernel_size": 3, "stride": 2}),
[0],
),
(
torch.randn(2, 1, 51, 55, 45),
3,
2,
0,
1,
("max", {"kernel_size": 2, "ceil_mode": True}),
[0, 1],
),
(
torch.randn(2, 1, 51, 55, 45),
3,
2,
0,
1,
[
("max", {"kernel_size": 2, "ceil_mode": True}),
("max", {"kernel_size": 2, "stride": 1, "ceil_mode": False}),
],
[0, 1],
),
],
)
def test_output_shape(
input_tensor, kernel_size, stride, padding, dilation, pooling, pooling_indices
):
latent_size = 3
net = VAE(
in_shape=input_tensor.shape[1:],
latent_size=latent_size,
conv_args={
"channels": [2, 4, 8],
"kernel_size": kernel_size,
"stride": stride,
"padding": padding,
"dilation": dilation,
"pooling": pooling,
"pooling_indices": pooling_indices,
},
)
recon, mu, log_var = net(input_tensor)
assert recon.shape == input_tensor.shape
assert mu.shape == (input_tensor.shape[0], latent_size)
assert log_var.shape == (input_tensor.shape[0], latent_size)


def test_mu_log_var():
net = VAE(
in_shape=(1, 5, 5),
latent_size=4,
conv_args={"channels": []},
mlp_args={"hidden_channels": [12], "output_act": "relu", "act": "celu"},
)
assert net.mu.linear.in_features == 12
assert net.log_var.linear.in_features == 12
assert isinstance(net.mu.output_act, ReLU)
assert isinstance(net.log_var.output_act, ReLU)
assert net.encoder(torch.randn(2, 1, 5, 5)).shape == (2, 12)
_, mu, log_var = net(torch.randn(2, 1, 5, 5))
assert not isclose(mu.detach().numpy(), log_var.detach().numpy()).all()

net = VAE(
in_shape=(1, 5, 5),
latent_size=4,
conv_args={"channels": []},
mlp_args={"hidden_channels": [12]},
)
assert net.mu.in_features == 12
assert net.log_var.in_features == 12
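
Once trained, the decoder can also be used on its own to sample new data from the prior. A
minimal sketch, assuming the constructor arguments from the docstring example above (the batch
size of 8 is arbitrary):

import torch

from clinicadl.monai_networks.nn import VAE

net = VAE(in_shape=(1, 16, 16), latent_size=4, conv_args={"channels": [2]})
net.eval()
with torch.no_grad():
    z = torch.randn(8, 4)     # latent vectors drawn from the N(0, I) prior
    samples = net.decoder(z)  # decoded back to image space: (8, 1, 16, 16)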
