From 1cd7597ab1aa26426f97888142deb1a21e0d7494 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Wed, 16 Oct 2024 12:04:07 +0200 Subject: [PATCH] add transposed conv for unpooling --- clinicadl/monai_networks/nn/autoencoder.py | 212 +++++++++++++----- clinicadl/monai_networks/nn/conv_encoder.py | 18 +- .../nn/layers/utils/__init__.py | 2 +- .../monai_networks/nn/layers/utils/enum.py | 5 +- clinicadl/monai_networks/nn/utils/checks.py | 2 +- clinicadl/monai_networks/nn/utils/shapes.py | 30 ++- clinicadl/monai_networks/nn/vae.py | 26 ++- .../monai_networks/nn/test_autoencoder.py | 83 ++++++- 8 files changed, 285 insertions(+), 93 deletions(-) diff --git a/clinicadl/monai_networks/nn/autoencoder.py b/clinicadl/monai_networks/nn/autoencoder.py index 933d0a6b8..d8115760e 100644 --- a/clinicadl/monai_networks/nn/autoencoder.py +++ b/clinicadl/monai_networks/nn/autoencoder.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import numpy as np import torch.nn as nn @@ -7,11 +7,19 @@ from .cnn import CNN from .conv_encoder import ConvEncoder from .generator import Generator -from .layers.utils import ActivationParameters, UnpoolingLayer, UpsamplingMode +from .layers.utils import ( + ActivationParameters, + PoolingLayer, + SingleLayerPoolingParameters, + SingleLayerUnpoolingParameters, + UnpoolingLayer, + UnpoolingMode, +) from .mlp import MLP from .utils import ( calculate_conv_out_shape, calculate_convtranspose_out_shape, + calculate_pool_out_shape, ) @@ -24,7 +32,8 @@ class AutoEncoder(nn.Sequential): symmetrical network. More precisely, to build the decoder, the order of the encoding layers is reverted, convolutions are - replaced by transposed convolutions and pooling layers are replaced by upsampling layers. + replaced by transposed convolutions and pooling layers are replaced by either upsampling or transposed + convolution layers. Please note that the order of `Activation`, `Dropout` and `Normalization`, defined with the argument `adn_ordering` in `conv_args`, is the same for the encoder and the decoder. @@ -57,9 +66,18 @@ class AutoEncoder(nn.Sequential): `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional arguments for each of them. - upsampling_mode : Union[str, UpsamplingMode] (optional, default=UpsamplingMode.NEAREST) - interpolation mode for upsampling (see: https://pytorch.org/docs/stable/generated/ - torch.nn.Upsample.html). + unpooling_mode : Union[str, UnpoolingMode] (optional, default=UnpoolingMode.NEAREST) + type of unpooling. Can be either `"nearest"`, `"linear"`, `"bilinear"`, `"bicubic"`, `"trilinear"` or + `"convtranspose"`.\n + - `nearest`: unpooling is performed by upsampling with the :italic:`nearest` algorithm (see [PyTorch's Upsample layer] + (https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html)). + - `linear`: unpooling is performed by upsampling with the :italic:`linear` algorithm. Only works with 1D images (excluding the + channel dimension). + - `bilinear`: unpooling is performed by upsampling with the :italic:`bilinear` algorithm. Only works with 2D images. + - `bicubic`: unpooling is performed by upsampling with the :italic:`bicubic` algorithm. Only works with 2D images. + - `trilinear`: unpooling is performed by upsampling with the :italic:`trilinear` algorithm. 
Only works with 3D images. + - `convtranspose`: unpooling is performed with a transposed convolution, whose parameters (kernel size, stride, etc.) are + computed to reverse the pooling operation. Examples -------- @@ -74,7 +92,7 @@ class AutoEncoder(nn.Sequential): mlp_args={"hidden_channels": [32], "output_act": "relu"}, out_channels=2, output_act="sigmoid", - upsampling_mode="bilinear", + unpooling_mode="bilinear", ) AutoEncoder( (encoder): CNN( @@ -149,13 +167,14 @@ def __init__( mlp_args: Optional[Dict[str, Any]] = None, out_channels: Optional[int] = None, output_act: Optional[ActivationParameters] = None, - upsampling_mode: Union[str, UpsamplingMode] = UpsamplingMode.NEAREST, + unpooling_mode: Union[str, UnpoolingMode] = UnpoolingMode.NEAREST, ) -> None: super().__init__() self.in_shape = in_shape - self.upsampling_mode = self._check_upsampling_mode(upsampling_mode) + self.unpooling_mode = self._check_unpooling_mode(unpooling_mode) self.out_channels = out_channels if out_channels else self.in_shape[0] self.output_act = output_act + self.spatial_dims = len(in_shape[1:]) self.encoder = CNN( in_shape=self.in_shape, @@ -194,28 +213,27 @@ def _invert_conv_args( part of the decoder. """ if len(args["channels"]) == 0: - return {"channels": []} - - args["channels"] = self._invert_list_arg(conv.channels[:-1]) + [ - self.out_channels - ] + args["channels"] = [] + else: + args["channels"] = self._invert_list_arg(conv.channels[:-1]) + [ + self.out_channels + ] args["kernel_size"] = self._invert_list_arg(conv.kernel_size) args["stride"] = self._invert_list_arg(conv.stride) - args["padding"] = self._invert_list_arg(conv.padding) args["dilation"] = self._invert_list_arg(conv.dilation) - args["output_padding"] = self._get_output_padding_list(conv) + args["padding"], args["output_padding"] = self._get_paddings_list(conv) args["unpooling_indices"] = ( conv.n_layers - np.array(conv.pooling_indices) - 2 ).astype(int) args["unpooling"] = [] - size_before_pools = [ + sizes_before_pooling = [ size for size, (layer_name, _) in zip(conv.size_details, conv.named_children()) if "pool" in layer_name ] - for size in size_before_pools[::-1]: - args["unpooling"].append(self._invert_pooling_layer(size)) + for size, pooling in zip(sizes_before_pooling[::-1], conv.pooling[::-1]): + args["unpooling"].append(self._invert_pooling_layer(size, pooling)) if "pooling" in args: del args["pooling"] @@ -234,21 +252,80 @@ def _invert_list_arg(cls, arg: Union[Any, List[Any]]) -> Union[Any, List[Any]]: return list(arg[::-1]) if isinstance(arg, Sequence) else arg def _invert_pooling_layer( - self, size_before_pool: Sequence[int] - ) -> Tuple[UnpoolingLayer, Dict[str, Any]]: + self, + size_before_pool: Sequence[int], + pooling: SingleLayerPoolingParameters, + ) -> SingleLayerUnpoolingParameters: """ - Gets the unpooling layer (always upsample). + Gets the unpooling layer. 
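+        Depending on ``unpooling_mode``, the returned parameters describe either an
+        upsampling layer targeting the size recorded before the pooling, or a transposed
+        convolution whose arguments are computed to reverse the matched pooling layer.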
""" - return ( - UnpoolingLayer.UPSAMPLE, - {"size": size_before_pool, "mode": self.upsampling_mode}, - ) + if self.unpooling_mode == UnpoolingMode.CONV_TRANS: + return ( + UnpoolingLayer.CONV_TRANS, + self._invert_pooling_with_convtranspose(size_before_pool, pooling), + ) + else: + return ( + UnpoolingLayer.UPSAMPLE, + {"size": size_before_pool, "mode": self.unpooling_mode}, + ) @classmethod - def _get_output_padding_list(cls, conv: ConvEncoder) -> List[Tuple[int, ...]]: + def _invert_pooling_with_convtranspose( + cls, + size_before_pool: Sequence[int], + pooling: SingleLayerPoolingParameters, + ) -> Dict[str, Any]: + """ + Computes the arguments of the transposed convolution, based on the pooling layer. + """ + pooling_mode, pooling_args = pooling + if ( + pooling_mode == PoolingLayer.ADAPT_AVG + or pooling_mode == PoolingLayer.ADAPT_MAX + ): + input_size_np = np.array(size_before_pool) + output_size_np = np.array(pooling_args["output_size"]) + stride_np = input_size_np // output_size_np # adaptive pooling formulas + kernel_size_np = ( + input_size_np - (output_size_np - 1) * stride_np + ) # adaptive pooling formulas + args = { + "kernel_size": tuple(int(k) for k in kernel_size_np), + "stride": tuple(int(s) for s in stride_np), + } + padding, output_padding = cls._find_convtranspose_paddings( + pooling_mode, + size_before_pool, + output_size=pooling_args["output_size"], + **args, + ) + + elif pooling_mode == PoolingLayer.MAX or pooling_mode == PoolingLayer.AVG: + if "stride" not in pooling_args: + pooling_args["stride"] = pooling_args["kernel_size"] + args = { + arg: value + for arg, value in pooling_args.items() + if arg in ["kernel_size", "stride", "padding", "dilation"] + } + padding, output_padding = cls._find_convtranspose_paddings( + pooling_mode, + size_before_pool, + **pooling_args, + ) + + args["padding"] = padding # pylint: disable=possibly-used-before-assignment + args["output_padding"] = output_padding # pylint: disable=possibly-used-before-assignment + + return args + + @classmethod + def _get_paddings_list(cls, conv: ConvEncoder) -> List[Tuple[int, ...]]: """ Finds output padding list. """ + padding = [] output_padding = [] size_before_convs = [ size @@ -262,61 +339,78 @@ def _get_output_padding_list(cls, conv: ConvEncoder) -> List[Tuple[int, ...]]: conv.padding, conv.dilation, ): - out_p = cls._find_output_padding(size, k, s, p, d) + p, out_p = cls._find_convtranspose_paddings( + "conv", size, kernel_size=k, stride=s, padding=p, dilation=d + ) + padding.append(p) output_padding.append(out_p) - return cls._invert_list_arg(output_padding) + return cls._invert_list_arg(padding), cls._invert_list_arg(output_padding) @classmethod - def _find_output_padding( + def _find_convtranspose_paddings( cls, + layer_type: Union[Literal["conv"], PoolingLayer], in_shape: Union[Sequence[int], int], - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - padding: Union[Sequence[int], int], - dilation: Union[Sequence[int], int], - ) -> Tuple[int, ...]: + padding: Union[Sequence[int], int] = 0, + **kwargs, + ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: """ - Finds output padding necessary to recover the right image size after + Finds padding and output padding necessary to recover the right image size after a transposed convolution. 
""" - in_shape_np = np.atleast_1d(in_shape) - conv_out_shape = calculate_conv_out_shape( - in_shape_np, kernel_size, stride, padding, dilation - ) - convt_out_shape = calculate_convtranspose_out_shape( - conv_out_shape, kernel_size, stride, padding, 0, dilation - ) - output_padding = in_shape_np - np.atleast_1d(convt_out_shape) + if layer_type == "conv": + layer_out_shape = calculate_conv_out_shape(in_shape, **kwargs) + elif layer_type in list(PoolingLayer): + layer_out_shape = calculate_pool_out_shape(layer_type, in_shape, **kwargs) + + convt_out_shape = calculate_convtranspose_out_shape(layer_out_shape, **kwargs) # pylint: disable=possibly-used-before-assignment + output_padding = np.atleast_1d(in_shape) - np.atleast_1d(convt_out_shape) + + if ( + output_padding < 0 + ).any(): # can happen with ceil_mode=True for maxpool. Then, add some padding + padding = np.atleast_1d(padding) * np.ones_like( + output_padding + ) # to have the same shape as output_padding + padding[output_padding < 0] += np.maximum(np.abs(output_padding) // 2, 1)[ + output_padding < 0 + ] # //2 because 2*padding pixels are removed + + convt_out_shape = calculate_convtranspose_out_shape( + layer_out_shape, padding=padding, **kwargs + ) + output_padding = np.atleast_1d(in_shape) - np.atleast_1d(convt_out_shape) + padding = tuple(int(s) for s in padding) - return tuple(int(s) for s in output_padding) + return padding, tuple(int(s) for s in output_padding) - def _check_upsampling_mode( - self, upsampling_mode: Union[str, UpsamplingMode] - ) -> UpsamplingMode: + def _check_unpooling_mode( + self, unpooling_mode: Union[str, UnpoolingMode] + ) -> UnpoolingMode: """ - Checks consistency between data shape and upsampling mode. + Checks consistency between data shape and unpooling mode. """ - upsampling_mode = UpsamplingMode(upsampling_mode) - if upsampling_mode == "linear" and len(self.in_shape) != 2: + unpooling_mode = UnpoolingMode(unpooling_mode) + if unpooling_mode == UnpoolingMode.LINEAR and len(self.in_shape) != 2: raise ValueError( - f"upsampling mode `linear` only works with 2D data (counting the channel dimension). " + f"unpooling mode `linear` only works with 2D data (counting the channel dimension). " f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data." ) - elif upsampling_mode == "bilinear" and len(self.in_shape) != 3: + elif unpooling_mode == UnpoolingMode.BILINEAR and len(self.in_shape) != 3: raise ValueError( - f"upsampling mode `bilinear` only works with 3D data (counting the channel dimension). " + f"unpooling mode `bilinear` only works with 3D data (counting the channel dimension). " f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data." ) - elif upsampling_mode == "bicubic" and len(self.in_shape) != 3: + elif unpooling_mode == UnpoolingMode.BICUBIC and len(self.in_shape) != 3: raise ValueError( - f"upsampling mode `bicubic` only works with 3D data (counting the channel dimension). " + f"unpooling mode `bicubic` only works with 3D data (counting the channel dimension). " f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data." ) - elif upsampling_mode == "trilinear" and len(self.in_shape) != 4: + elif unpooling_mode == UnpoolingMode.TRILINEAR and len(self.in_shape) != 4: raise ValueError( - f"upsampling mode `trilinear` only works with 4D data (counting the channel dimension). " + f"unpooling mode `trilinear` only works with 4D data (counting the channel dimension). 
" f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data." ) - return upsampling_mode + return unpooling_mode diff --git a/clinicadl/monai_networks/nn/conv_encoder.py b/clinicadl/monai_networks/nn/conv_encoder.py index d138d4011..a6f1537b6 100644 --- a/clinicadl/monai_networks/nn/conv_encoder.py +++ b/clinicadl/monai_networks/nn/conv_encoder.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Callable, Optional, Tuple +from typing import Callable, List, Optional, Tuple import numpy as np import torch.nn as nn @@ -305,10 +305,20 @@ def _get_pool_layer(self, pooling: SingleLayerPoolingParameters) -> nn.Module: Gets the parametrized pooling layer and updates the current output size. """ pool_layer = get_pool_layer(pooling, spatial_dims=self.spatial_dims) + old_size = self.final_size self.final_size = lambda size: calculate_pool_out_shape( pool_mode=pooling[0], in_shape=size, **pool_layer.__dict__ ) + if ( + self.final_size is not None + and (np.array(old_size) < np.array(self.final_size)).any() + ): + raise ValueError( + f"You passed {pooling} as a pooling layer. But before this layer, the size of the image " + f"was {old_size}. So, pooling can't be performed." + ) + return pool_layer def _check_size(self) -> None: @@ -354,7 +364,9 @@ def _check_single_pool_layer( f"Got {args}" ) - def _check_pool_layers(self, pooling: PoolingParameters) -> PoolingParameters: + def _check_pool_layers( + self, pooling: PoolingParameters + ) -> List[SingleLayerPoolingParameters]: """ Check argument pooling. """ @@ -371,7 +383,7 @@ def _check_pool_layers(self, pooling: PoolingParameters) -> PoolingParameters: ) elif isinstance(pooling, tuple): self._check_single_pool_layer(pooling) - pooling = (pooling,) * len(self.pooling_indices) + pooling = [pooling] * len(self.pooling_indices) else: raise ValueError( f"pooling can be either None, a double (string, dictionary) or a list of such doubles. 
Got {pooling}" diff --git a/clinicadl/monai_networks/nn/layers/utils/__init__.py b/clinicadl/monai_networks/nn/layers/utils/__init__.py index af0446e1a..5c080fffd 100644 --- a/clinicadl/monai_networks/nn/layers/utils/__init__.py +++ b/clinicadl/monai_networks/nn/layers/utils/__init__.py @@ -4,7 +4,7 @@ NormLayer, PoolingLayer, UnpoolingLayer, - UpsamplingMode, + UnpoolingMode, ) from .types import ( ActivationParameters, diff --git a/clinicadl/monai_networks/nn/layers/utils/enum.py b/clinicadl/monai_networks/nn/layers/utils/enum.py index b0af2815b..695776551 100644 --- a/clinicadl/monai_networks/nn/layers/utils/enum.py +++ b/clinicadl/monai_networks/nn/layers/utils/enum.py @@ -54,11 +54,12 @@ class ConvNormLayer(CaseInsensitiveEnum): INSTANCE = "instance" -class UpsamplingMode(CaseInsensitiveEnum): - """Supported interpolation mode for Upsampling in ClinicaDL.""" +class UnpoolingMode(CaseInsensitiveEnum): + """Supported unpooling mode for AutoEncoders in ClinicaDL.""" NEAREST = "nearest" LINEAR = "linear" BILINEAR = "bilinear" BICUBIC = "bicubic" TRILINEAR = "trilinear" + CONV_TRANS = "convtranspose" diff --git a/clinicadl/monai_networks/nn/utils/checks.py b/clinicadl/monai_networks/nn/utils/checks.py index 1cb63fec6..e428080b2 100644 --- a/clinicadl/monai_networks/nn/utils/checks.py +++ b/clinicadl/monai_networks/nn/utils/checks.py @@ -26,7 +26,7 @@ def ensure_list_of_tuples( """ parameter = _check_conv_parameter(parameter, dim, n_layers, name) if isinstance(parameter, tuple): - return [parameter] * n_layers if n_layers > 0 else [parameter] + return [parameter] * n_layers else: return parameter diff --git a/clinicadl/monai_networks/nn/utils/shapes.py b/clinicadl/monai_networks/nn/utils/shapes.py index 66465983c..a649af076 100644 --- a/clinicadl/monai_networks/nn/utils/shapes.py +++ b/clinicadl/monai_networks/nn/utils/shapes.py @@ -16,9 +16,9 @@ def calculate_conv_out_shape( in_shape: Union[Sequence[int], int], kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - padding: Union[Sequence[int], int], - dilation: Union[Sequence[int], int], + stride: Union[Sequence[int], int] = 1, + padding: Union[Sequence[int], int] = 0, + dilation: Union[Sequence[int], int] = 1, **kwargs, # for uniformization ) -> Tuple[int, ...]: """ @@ -42,10 +42,10 @@ def calculate_conv_out_shape( def calculate_convtranspose_out_shape( in_shape: Union[Sequence[int], int], kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - padding: Union[Sequence[int], int], - output_padding: Union[Sequence[int], int], - dilation: Union[Sequence[int], int], + stride: Union[Sequence[int], int] = 1, + padding: Union[Sequence[int], int] = 0, + output_padding: Union[Sequence[int], int] = 0, + dilation: Union[Sequence[int], int] = 1, **kwargs, # for uniformization ) -> Tuple[int, ...]: """ @@ -110,15 +110,18 @@ def calculate_unpool_out_shape( def _calculate_maxpool_out_shape( in_shape: Union[Sequence[int], int], kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - padding: Union[Sequence[int], int], - dilation: Union[Sequence[int], int], + stride: Optional[Union[Sequence[int], int]] = None, + padding: Union[Sequence[int], int] = 0, + dilation: Union[Sequence[int], int] = 1, ceil_mode: bool = False, **kwargs, # for uniformization ) -> Tuple[int, ...]: """ Calculates the output shape of a MaxPool layer. 
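+
+    With ``ceil_mode=False``, this follows PyTorch's MaxPool formula
+    ``out = floor((in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1``;
+    with ``ceil_mode=True``, ``floor`` is replaced by ``ceil`` (cf. PyTorch's MaxPool documentation).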
""" + if stride is None: + stride = kernel_size + in_shape_np = np.atleast_1d(in_shape) kernel_size_np = np.atleast_1d(kernel_size) stride_np = np.atleast_1d(stride) @@ -140,14 +143,17 @@ def _calculate_maxpool_out_shape( def _calculate_avgpool_out_shape( in_shape: Union[Sequence[int], int], kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - padding: Union[Sequence[int], int], + stride: Optional[Union[Sequence[int], int]] = None, + padding: Union[Sequence[int], int] = 0, ceil_mode: bool = False, **kwargs, # for uniformization ) -> Tuple[int, ...]: """ Calculates the output shape of an AvgPool layer. """ + if stride is None: + stride = kernel_size + in_shape_np = np.atleast_1d(in_shape) kernel_size_np = np.atleast_1d(kernel_size) stride_np = np.atleast_1d(stride) diff --git a/clinicadl/monai_networks/nn/vae.py b/clinicadl/monai_networks/nn/vae.py index 7f3ddfa83..5a54554fc 100644 --- a/clinicadl/monai_networks/nn/vae.py +++ b/clinicadl/monai_networks/nn/vae.py @@ -5,7 +5,7 @@ import torch.nn as nn from .autoencoder import AutoEncoder -from .layers.utils import ActivationParameters, UpsamplingMode +from .layers.utils import ActivationParameters, UnpoolingMode class VAE(nn.Module): @@ -17,7 +17,8 @@ class VAE(nn.Module): symmetrical network. More precisely, to build the decoder, the order of the encoding layers is reverted, convolutions are - replaced by transposed convolutions and pooling layers are replaced by upsampling layers. + replaced by transposed convolutions and pooling layers are replaced by either upsampling or transposed + convolution layers. Please note that the order of `Activation`, `Dropout` and `Normalization`, defined with the argument `adn_ordering` in `conv_args`, is the same for the encoder and the decoder. @@ -52,9 +53,18 @@ class VAE(nn.Module): `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional arguments for each of them. - upsampling_mode : Union[str, UpsamplingMode] (optional, default=UpsamplingMode.NEAREST) - interpolation mode for upsampling (see: https://pytorch.org/docs/stable/generated/ - torch.nn.Upsample.html). + unpooling_mode : Union[str, UnpoolingMode] (optional, default=UnpoolingMode.NEAREST) + type of unpooling. Can be either `"nearest"`, `"linear"`, `"bilinear"`, `"bicubic"`, `"trilinear"` or + `"convtranspose"`.\n + - `nearest`: unpooling is performed by upsampling with the :italic:`nearest` algorithm (see [PyTorch's Upsample layer] + (https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html)). + - `linear`: unpooling is performed by upsampling with the :italic:`linear` algorithm. Only works with 1D images (excluding the + channel dimension). + - `bilinear`: unpooling is performed by upsampling with the :italic:`bilinear` algorithm. Only works with 2D images. + - `bicubic`: unpooling is performed by upsampling with the :italic:`bicubic` algorithm. Only works with 2D images. + - `trilinear`: unpooling is performed by upsampling with the :italic:`trilinear` algorithm. Only works with 3D images. + - `convtranspose`: unpooling is performed with a transposed convolution, whose parameters (kernel size, stride, etc.) are + computed to reverse the pooling operation. 
Examples -------- @@ -65,7 +75,7 @@ class VAE(nn.Module): mlp_args={"hidden_channels": [16], "output_act": "relu"}, out_channels=2, output_act="sigmoid", - upsampling_mode="bilinear", + unpooling_mode="bilinear", ) VAE( (encoder): CNN( @@ -128,7 +138,7 @@ def __init__( mlp_args: Optional[Dict[str, Any]] = None, out_channels: Optional[int] = None, output_act: Optional[ActivationParameters] = None, - upsampling_mode: Union[str, UpsamplingMode] = UpsamplingMode.NEAREST, + unpooling_mode: Union[str, UnpoolingMode] = UnpoolingMode.NEAREST, ) -> None: super().__init__() ae = AutoEncoder( @@ -138,7 +148,7 @@ def __init__( mlp_args, out_channels, output_act, - upsampling_mode, + unpooling_mode, ) # replace last mlp layer by two parallel layers diff --git a/tests/unittests/monai_networks/nn/test_autoencoder.py b/tests/unittests/monai_networks/nn/test_autoencoder.py index 6adf217f9..c59874353 100644 --- a/tests/unittests/monai_networks/nn/test_autoencoder.py +++ b/tests/unittests/monai_networks/nn/test_autoencoder.py @@ -7,9 +7,9 @@ @pytest.mark.parametrize( - "input_tensor,kernel_size,stride,padding,dilation,pooling,pooling_indices", + "input_tensor,kernel_size,stride,padding,dilation,pooling,pooling_indices,unpooling_mode", [ - (torch.randn(2, 1, 21), 3, 1, 0, 1, ("max", {"kernel_size": 2}), [0]), + (torch.randn(2, 1, 21), 3, 1, 0, 1, ("max", {"kernel_size": 2}), [0], "linear"), ( torch.randn(2, 1, 65, 85), (3, 5), @@ -18,6 +18,7 @@ (1, 2), ("max", {"kernel_size": 2, "stride": 1}), [0], + "bilinear", ), ( torch.randn(2, 1, 64, 62, 61), # to test output padding @@ -27,6 +28,7 @@ 1, ("avg", {"kernel_size": 3, "stride": 2}), [-1], + "convtranspose", ), ( torch.randn(2, 1, 51, 55, 45), @@ -36,6 +38,7 @@ 1, ("max", {"kernel_size": 2, "ceil_mode": True}), [0, 1, 2], + "convtranspose", ), ( torch.randn(2, 1, 51, 55, 45), @@ -45,14 +48,22 @@ 1, [ ("max", {"kernel_size": 2, "ceil_mode": True}), - ("max", {"kernel_size": 2, "stride": 1, "ceil_mode": False}), + ("adaptivemax", {"output_size": (5, 4, 3)}), ], - [0, 1], + [-1, 1], + "convtranspose", ), ], ) def test_output_shape( - input_tensor, kernel_size, stride, padding, dilation, pooling, pooling_indices + input_tensor, + kernel_size, + stride, + padding, + dilation, + pooling, + pooling_indices, + unpooling_mode, ): net = AutoEncoder( in_shape=input_tensor.shape[1:], @@ -66,6 +77,7 @@ def test_output_shape( "pooling": pooling, "pooling_indices": pooling_indices, }, + unpooling_mode=unpooling_mode, ) output = net(input_tensor) assert output.shape == input_tensor.shape @@ -85,6 +97,63 @@ def test_out_channels(): assert net.decoder.convolutions.layer2.conv.out_channels == 3 +@pytest.mark.parametrize( + "pooling,unpooling_mode", + [ + (("adaptivemax", {"output_size": (17, 16, 19)}), "nearest"), + (("adaptivemax", {"output_size": (17, 16, 19)}), "convtranspose"), + (("max", {"kernel_size": 2}), "nearest"), + (("max", {"kernel_size": 2}), "convtranspose"), + ( + ("max", {"kernel_size": 2, "stride": 1, "dilation": 2, "padding": 1}), + "nearest", + ), + ( + ("max", {"kernel_size": 2, "stride": 1, "dilation": 2, "padding": 1}), + "convtranspose", + ), + (("avg", {"kernel_size": 3, "ceil_mode": True}), "nearest"), + (("avg", {"kernel_size": 3, "ceil_mode": True}), "convtranspose"), + ], +) +def test_invert_pooling(pooling, unpooling_mode): + input_tensor = torch.randn(2, 1, 20, 27, 22) + net = AutoEncoder( + in_shape=(1, 20, 27, 22), + latent_size=1, + conv_args={"channels": [], "pooling": pooling, "pooling_indices": [-1]}, + mlp_args=None, + 
unpooling_mode=unpooling_mode, + ) + output = net(input_tensor) + assert output.shape == input_tensor.shape + + +@pytest.mark.parametrize( + "kernel_size,stride,padding,dilation", + [ + ((3, 2, 1), (1, 1, 2), (1, 1, 0), 1), + ((4, 5, 2), (3, 1, 1), (0, 0, 1), (2, 1, 1)), + ], +) +def test_invert_conv(kernel_size, stride, padding, dilation): + input_tensor = torch.randn(2, 1, 20, 27, 22) + net = AutoEncoder( + in_shape=(1, 20, 27, 22), + latent_size=1, + conv_args={ + "channels": [1], + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + }, + mlp_args=None, + ) + output = net(input_tensor) + assert output.shape == input_tensor.shape + + @pytest.mark.parametrize("act", [act for act in ActFunction]) def test_out_activation(act): input_tensor = torch.randn(2, 1, 32, 32) @@ -135,12 +204,12 @@ def test_checks(in_shape, upsampling_mode, error): in_shape=in_shape, latent_size=3, conv_args={"channels": []}, - upsampling_mode=upsampling_mode, + unpooling_mode=upsampling_mode, ) else: AutoEncoder( in_shape=in_shape, latent_size=3, conv_args={"channels": []}, - upsampling_mode=upsampling_mode, + unpooling_mode=upsampling_mode, )
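
Usage sketch: a minimal example of the new `unpooling_mode="convtranspose"` option, assuming
the import path of the module touched by this diff; the channel, kernel and pooling values
below are illustrative, not taken from the code base. With this mode, every pooling layer of
the encoder is inverted by a transposed convolution whose padding and output padding are
computed (via `_find_convtranspose_paddings`) so that the decoder recovers the exact input shape.

    import torch

    from clinicadl.monai_networks.nn.autoencoder import AutoEncoder

    x = torch.randn(2, 1, 51, 55, 45)  # (batch, channels, D, H, W)

    net = AutoEncoder(
        in_shape=(1, 51, 55, 45),
        latent_size=16,
        conv_args={
            "channels": [4, 8],
            "kernel_size": 3,
            "stride": 2,
            # max-pooling with ceil_mode=True is the case where the transposed
            # convolutions may need extra padding to undo the pooling exactly
            "pooling": ("max", {"kernel_size": 2, "ceil_mode": True}),
            "pooling_indices": [0, 1],
        },
        unpooling_mode="convtranspose",
    )

    out = net(x)
    assert out.shape == x.shape  # the decoder reproduces the encoder's input size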