diff --git a/clinicadl/monai_networks/__init__.py b/clinicadl/monai_networks/__init__.py index 1d74473d4..ea44f7516 100644 --- a/clinicadl/monai_networks/__init__.py +++ b/clinicadl/monai_networks/__init__.py @@ -1,2 +1,2 @@ -from .config import ImplementedNetworks, NetworkConfig, create_network_config -from .factory import get_network +from .config import ImplementedNetworks, NetworkConfig +from .factory import get_network, get_network_from_config diff --git a/clinicadl/monai_networks/config/__init__.py b/clinicadl/monai_networks/config/__init__.py index 10b8795dc..1c39fa4fa 100644 --- a/clinicadl/monai_networks/config/__init__.py +++ b/clinicadl/monai_networks/config/__init__.py @@ -1,3 +1,2 @@ -from .base import NetworkConfig +from .base import ImplementedNetworks, NetworkConfig, NetworkType from .factory import create_network_config -from .utils.enum import ImplementedNetworks diff --git a/clinicadl/monai_networks/config/autoencoder.py b/clinicadl/monai_networks/config/autoencoder.py index a6df1a20c..b19108573 100644 --- a/clinicadl/monai_networks/config/autoencoder.py +++ b/clinicadl/monai_networks/config/autoencoder.py @@ -1,89 +1,45 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Sequence, Union -from pydantic import ( - NonNegativeInt, - PositiveInt, - computed_field, - model_validator, -) +from pydantic import PositiveInt, computed_field +from clinicadl.monai_networks.nn.layers.utils import ( + ActivationParameters, + UnpoolingMode, +) from clinicadl.utils.factories import DefaultFromLibrary -from .base import VaryingDepthNetworkConfig -from .utils.enum import ImplementedNetworks - -__all__ = ["AutoEncoderConfig", "VarAutoEncoderConfig"] - +from .base import ImplementedNetworks, NetworkConfig +from .conv_encoder import ConvEncoderOptions +from .mlp import MLPOptions -class AutoEncoderConfig(VaryingDepthNetworkConfig): - """Config class for autoencoders.""" - spatial_dims: PositiveInt - in_channels: PositiveInt - out_channels: PositiveInt +class AutoEncoderConfig(NetworkConfig): + """Config class for AutoEncoder.""" - inter_channels: Union[ - Optional[Tuple[PositiveInt, ...]], DefaultFromLibrary - ] = DefaultFromLibrary.YES - inter_dilations: Union[ - Optional[Tuple[PositiveInt, ...]], DefaultFromLibrary + in_shape: Sequence[PositiveInt] + latent_size: PositiveInt + conv_args: ConvEncoderOptions + mlp_args: Union[Optional[MLPOptions], DefaultFromLibrary] = DefaultFromLibrary.YES + out_channels: Union[ + Optional[PositiveInt], DefaultFromLibrary ] = DefaultFromLibrary.YES - num_inter_units: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES - padding: Union[ - Optional[Union[PositiveInt, Tuple[PositiveInt, ...]]], DefaultFromLibrary + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary ] = DefaultFromLibrary.YES + unpooling_mode: Union[UnpoolingMode, DefaultFromLibrary] = DefaultFromLibrary.YES @computed_field @property - def network(self) -> ImplementedNetworks: + def name(self) -> ImplementedNetworks: """The name of the network.""" return ImplementedNetworks.AE - @computed_field - @property - def dim(self) -> int: - """Dimension of the images.""" - return self.spatial_dims - - @model_validator(mode="after") - def model_validator(self): - """Checks coherence between parameters.""" - if self.padding != DefaultFromLibrary.YES: - assert self._check_dimensions( - self.padding - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for padding. You passed {self.padding}." 
- if isinstance(self.inter_channels, tuple) and isinstance( - self.inter_dilations, tuple - ): - assert len(self.inter_channels) == len( - self.inter_dilations - ), "inter_channels and inter_dilations muust have the same size." - elif isinstance(self.inter_dilations, tuple) and not isinstance( - self.inter_channels, tuple - ): - raise ValueError( - "You passed inter_dilations but didn't pass inter_channels." - ) - return self - -class VarAutoEncoderConfig(AutoEncoderConfig): - """Config class for variational autoencoders.""" - - in_shape: Tuple[PositiveInt, ...] - in_channels: Optional[int] = None - latent_size: PositiveInt - use_sigmoid: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES +class VAEConfig(AutoEncoderConfig): + """Config class for Variational AutoEncoder.""" @computed_field @property - def network(self) -> ImplementedNetworks: + def name(self) -> ImplementedNetworks: """The name of the network.""" return ImplementedNetworks.VAE - - @model_validator(mode="after") - def model_validator_bis(self): - """Checks coherence between parameters.""" - assert ( - len(self.in_shape[1:]) == self.spatial_dims - ), f"You passed {self.spatial_dims} for spatial_dims, but in_shape suggests {len(self.in_shape[1:])} spatial dimensions." diff --git a/clinicadl/monai_networks/config/base.py b/clinicadl/monai_networks/config/base.py index 6e0ff1b6b..d5c0a6f9b 100644 --- a/clinicadl/monai_networks/config/base.py +++ b/clinicadl/monai_networks/config/base.py @@ -1,168 +1,98 @@ -from __future__ import annotations - from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, Optional, Tuple, Union - -from pydantic import ( - BaseModel, - ConfigDict, - NonNegativeFloat, - NonNegativeInt, - PositiveInt, - computed_field, - field_validator, - model_validator, -) +from typing import Optional, Union + +from pydantic import BaseModel, ConfigDict, PositiveInt, computed_field +from clinicadl.monai_networks.nn.layers.utils import ActivationParameters from clinicadl.utils.factories import DefaultFromLibrary -from .utils.enum import ( - ImplementedActFunctions, - ImplementedNetworks, - ImplementedNormLayers, -) + +class ImplementedNetworks(str, Enum): + """Implemented neural networks in ClinicaDL.""" + + MLP = "MLP" + CONV_ENCODER = "ConvEncoder" + CONV_DECODER = "ConvDecoder" + CNN = "CNN" + GENERATOR = "Generator" + AE = "AutoEncoder" + VAE = "VAE" + DENSENET = "DenseNet" + DENSENET_121 = "DenseNet-121" + DENSENET_161 = "DenseNet-161" + DENSENET_169 = "DenseNet-169" + DENSENET_201 = "DenseNet-201" + RESNET = "ResNet" + RESNET_18 = "ResNet-18" + RESNET_34 = "ResNet-34" + RESNET_50 = "ResNet-50" + RESNET_101 = "ResNet-101" + RESNET_152 = "ResNet-152" + SE_RESNET = "SEResNet" + SE_RESNET_50 = "SEResNet-50" + SE_RESNET_101 = "SEResNet-101" + SE_RESNET_152 = "SEResNet-152" + UNET = "UNet" + ATT_UNET = "AttentionUNet" + VIT = "ViT" + VIT_B_16 = "ViT-B/16" + VIT_B_32 = "ViT-B/32" + VIT_L_16 = "ViT-L/16" + VIT_L_32 = "ViT-L/32" + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not implemented. Implemented neural networks are: " + + ", ".join([repr(m.value) for m in cls]) + ) + + +class NetworkType(str, Enum): + """ + Useful to know where to look for the network. 
+ See :py:func:`clinicadl.monai_networks.factory.get_network` + """ + + CUSTOM = "custom" # our own networks + RESNET = "sota-ResNet" + DENSENET = "sota-DenseNet" + SE_RESNET = "sota-SEResNet" + VIT = "sota-ViT" class NetworkConfig(BaseModel, ABC): """Base config class to configure neural networks.""" - kernel_size: Union[ - PositiveInt, Tuple[PositiveInt, ...], DefaultFromLibrary - ] = DefaultFromLibrary.YES - up_kernel_size: Union[ - PositiveInt, Tuple[PositiveInt, ...], DefaultFromLibrary - ] = DefaultFromLibrary.YES - num_res_units: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES - act: Union[ - ImplementedActFunctions, - Tuple[ImplementedActFunctions, Dict[str, Any]], - DefaultFromLibrary, - ] = DefaultFromLibrary.YES - norm: Union[ - ImplementedNormLayers, - Tuple[ImplementedNormLayers, Dict[str, Any]], - DefaultFromLibrary, - ] = DefaultFromLibrary.YES - bias: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - adn_ordering: Union[Optional[str], DefaultFromLibrary] = DefaultFromLibrary.YES # pydantic config model_config = ConfigDict( validate_assignment=True, use_enum_values=True, validate_default=True, - protected_namespaces=(), ) @computed_field @property @abstractmethod - def network(self) -> ImplementedNetworks: + def name(self) -> ImplementedNetworks: """The name of the network.""" @computed_field @property - @abstractmethod - def dim(self) -> int: - """Dimension of the images.""" + def _type(self) -> NetworkType: + """ + To know where to look for the network. + Default to 'custom'. + """ + return NetworkType.CUSTOM - @classmethod - def base_validator_dropout(cls, v): - """Checks that dropout is between 0 and 1.""" - if isinstance(v, float): - assert ( - 0 <= v <= 1 - ), f"dropout must be between 0 and 1 but it has been set to {v}." - return v - - @field_validator("kernel_size", "up_kernel_size") - @classmethod - def base_is_odd(cls, value, field): - """Checks if a field is odd.""" - if value != DefaultFromLibrary.YES: - if isinstance(value, int): - value_ = (value,) - else: - value_ = value - for v in value_: - assert v % 2 == 1, f"{field.field_name} must be odd." - return value - - @field_validator("adn_ordering", mode="after") - @classmethod - def base_adn_validator(cls, v): - """Checks ADN sequence.""" - if v != DefaultFromLibrary.YES: - for letter in v: - assert ( - letter in {"A", "D", "N"} - ), f"adn_ordering must be composed by 'A', 'D' or/and 'N'. You passed {letter}." - assert len(v) == len( - set(v) - ), "adn_ordering cannot contain duplicated letter." - - return v - @classmethod - def base_at_least_2d(cls, v, ctx): - """Checks that a tuple has at least a length of two.""" - if isinstance(v, tuple): - assert ( - len(v) >= 2 - ), f"{ctx.field_name} should have at least two dimensions (with the first one for the channel)." - return v - - @model_validator(mode="after") - def base_model_validator(self): - """Checks coherence between parameters.""" - if self.kernel_size != DefaultFromLibrary.YES: - assert self._check_dimensions( - self.kernel_size - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for kernel_size. You passed {self.kernel_size}." - if self.up_kernel_size != DefaultFromLibrary.YES: - assert self._check_dimensions( - self.up_kernel_size - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for up_kernel_size. You passed {self.up_kernel_size}." 
- return self - - def _check_dimensions( - self, - value: Union[float, Tuple[float, ...]], - ) -> bool: - """Checks if a tuple has the right dimension.""" - if isinstance(value, tuple): - return len(value) == self.dim - return True - - -class VaryingDepthNetworkConfig(NetworkConfig, ABC): - """ - Base config class to configure neural networks. - More precisely, we refer to MONAI's networks with 'channels' and 'strides' parameters. - """ +class PreTrainedConfig(NetworkConfig): + """Base config class for SOTA networks.""" - channels: Tuple[PositiveInt, ...] - strides: Tuple[Union[PositiveInt, Tuple[PositiveInt, ...]], ...] - dropout: Union[ - Optional[NonNegativeFloat], DefaultFromLibrary + num_outputs: Optional[PositiveInt] + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary ] = DefaultFromLibrary.YES - - @field_validator("dropout") - @classmethod - def validator_dropout(cls, v): - """Checks that dropout is between 0 and 1.""" - return cls.base_validator_dropout(v) - - @model_validator(mode="after") - def channels_strides_validator(self): - """Checks coherence between parameters.""" - n_layers = len(self.channels) - assert ( - len(self.strides) == n_layers - ), f"There are {n_layers} layers but you passed {len(self.strides)} strides." - for s in self.strides: - assert self._check_dimensions( - s - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for strides. You passed {s}." - - return self + pretrained: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES diff --git a/clinicadl/monai_networks/config/classifier.py b/clinicadl/monai_networks/config/classifier.py deleted file mode 100644 index a01bd0efc..000000000 --- a/clinicadl/monai_networks/config/classifier.py +++ /dev/null @@ -1,63 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, Optional, Tuple, Union - -from pydantic import PositiveInt, computed_field - -from clinicadl.utils.factories import DefaultFromLibrary - -from .regressor import RegressorConfig -from .utils.enum import ImplementedActFunctions, ImplementedNetworks - -__all__ = ["ClassifierConfig", "DiscriminatorConfig", "CriticConfig"] - - -class ClassifierConfig(RegressorConfig): - """Config class for classifiers.""" - - classes: PositiveInt - out_shape: Optional[Tuple[PositiveInt, ...]] = None - last_act: Optional[ - Union[ - ImplementedActFunctions, - Tuple[ImplementedActFunctions, Dict[str, Any]], - DefaultFromLibrary, - ] - ] = DefaultFromLibrary.YES - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.CLASSIFIER - - -class DiscriminatorConfig(ClassifierConfig): - """Config class for discriminators.""" - - classes: Optional[PositiveInt] = None - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.DISCRIMINATOR - - -class CriticConfig(ClassifierConfig): - """Config class for discriminators.""" - - classes: Optional[PositiveInt] = None - last_act: Optional[ - Union[ - ImplementedActFunctions, - Tuple[ImplementedActFunctions, Dict[str, Any]], - DefaultFromLibrary, - ] - ] = None - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.CRITIC diff --git a/clinicadl/monai_networks/config/cnn.py b/clinicadl/monai_networks/config/cnn.py new file mode 100644 index 000000000..a7d2043db --- /dev/null +++ 
b/clinicadl/monai_networks/config/cnn.py @@ -0,0 +1,24 @@ +from typing import Optional, Sequence, Union + +from pydantic import PositiveInt, computed_field + +from clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkConfig +from .conv_encoder import ConvEncoderOptions +from .mlp import MLPOptions + + +class CNNConfig(NetworkConfig): + """Config class for CNN.""" + + in_shape: Sequence[PositiveInt] + num_outputs: PositiveInt + conv_args: ConvEncoderOptions + mlp_args: Union[Optional[MLPOptions], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.CNN diff --git a/clinicadl/monai_networks/config/conv_decoder.py b/clinicadl/monai_networks/config/conv_decoder.py new file mode 100644 index 000000000..5dc78dfec --- /dev/null +++ b/clinicadl/monai_networks/config/conv_decoder.py @@ -0,0 +1,65 @@ +from typing import Optional, Sequence, Union + +from pydantic import BaseModel, ConfigDict, PositiveFloat, PositiveInt, computed_field + +from clinicadl.monai_networks.nn.layers.utils import ( + ActivationParameters, + ConvNormalizationParameters, + ConvParameters, + UnpoolingParameters, +) +from clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkConfig + + +class ConvDecoderOptions(BaseModel): + """ + Config class for ConvDecoder when it is a submodule. + See for example: :py:class:`clinicadl.monai_networks.nn.generator.Generator` + """ + + channels: Sequence[PositiveInt] + kernel_size: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + stride: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + padding: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + output_padding: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + dilation: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + unpooling: Union[ + Optional[UnpoolingParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + unpooling_indices: Union[ + Optional[Sequence[int]], DefaultFromLibrary + ] = DefaultFromLibrary.YES + act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + norm: Union[ + Optional[ConvNormalizationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + dropout: Union[Optional[PositiveFloat], DefaultFromLibrary] = DefaultFromLibrary.YES + bias: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES + adn_ordering: Union[str, DefaultFromLibrary] = DefaultFromLibrary.YES + + # pydantic config + model_config = ConfigDict( + validate_assignment=True, + use_enum_values=True, + validate_default=True, + ) + + +class ConvDecoderConfig(NetworkConfig, ConvDecoderOptions): + """Config class for ConvDecoder.""" + + spatial_dims: PositiveInt + in_channels: PositiveInt + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.CONV_DECODER diff --git a/clinicadl/monai_networks/config/conv_encoder.py b/clinicadl/monai_networks/config/conv_encoder.py new file mode 100644 index 000000000..499f69b19 --- /dev/null +++ b/clinicadl/monai_networks/config/conv_encoder.py @@ -0,0 +1,64 @@ +from typing import Optional, Sequence, Union + +from pydantic import BaseModel, ConfigDict, PositiveFloat, PositiveInt, 
computed_field + +from clinicadl.monai_networks.nn.layers.utils import ( + ActivationParameters, + ConvNormalizationParameters, + ConvParameters, + PoolingParameters, +) +from clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkConfig + + +class ConvEncoderOptions(BaseModel): + """ + Config class for ConvEncoder when it is a submodule. + See for example: :py:class:`clinicadl.monai_networks.nn.cnn.CNN` + """ + + channels: Sequence[PositiveInt] + kernel_size: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + stride: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + padding: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + dilation: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + pooling: Union[ + Optional[PoolingParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + pooling_indices: Union[ + Optional[Sequence[int]], DefaultFromLibrary + ] = DefaultFromLibrary.YES + act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + norm: Union[ + Optional[ConvNormalizationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + dropout: Union[Optional[PositiveFloat], DefaultFromLibrary] = DefaultFromLibrary.YES + bias: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES + adn_ordering: Union[str, DefaultFromLibrary] = DefaultFromLibrary.YES + + # pydantic config + model_config = ConfigDict( + validate_assignment=True, + use_enum_values=True, + validate_default=True, + ) + + +class ConvEncoderConfig(NetworkConfig, ConvEncoderOptions): + """Config class for ConvEncoder.""" + + spatial_dims: PositiveInt + in_channels: PositiveInt + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.CONV_ENCODER diff --git a/clinicadl/monai_networks/config/densenet.py b/clinicadl/monai_networks/config/densenet.py index 796d82203..4984d010b 100644 --- a/clinicadl/monai_networks/config/densenet.py +++ b/clinicadl/monai_networks/config/densenet.py @@ -1,20 +1,11 @@ -from __future__ import annotations +from typing import Optional, Sequence, Union -from typing import Tuple, Union - -from pydantic import ( - NonNegativeFloat, - PositiveInt, - computed_field, - field_validator, -) +from pydantic import PositiveFloat, PositiveInt, computed_field +from clinicadl.monai_networks.nn.layers.utils import ActivationParameters from clinicadl.utils.factories import DefaultFromLibrary -from .base import NetworkConfig -from .utils.enum import ImplementedNetworks - -__all__ = ["DenseNetConfig"] +from .base import ImplementedNetworks, NetworkConfig, NetworkType, PreTrainedConfig class DenseNetConfig(NetworkConfig): @@ -22,29 +13,71 @@ class DenseNetConfig(NetworkConfig): spatial_dims: PositiveInt in_channels: PositiveInt - out_channels: PositiveInt + num_outputs: Optional[PositiveInt] + n_dense_layers: Union[ + Sequence[PositiveInt], DefaultFromLibrary + ] = DefaultFromLibrary.YES init_features: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES growth_rate: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - block_config: Union[ - Tuple[PositiveInt, ...], DefaultFromLibrary + bottleneck_factor: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES + act: Union[ActivationParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + output_act: 
Union[ + Optional[ActivationParameters], DefaultFromLibrary ] = DefaultFromLibrary.YES - bn_size: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - dropout_prob: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES + dropout: Union[Optional[PositiveFloat], DefaultFromLibrary] = DefaultFromLibrary.YES @computed_field @property - def network(self) -> ImplementedNetworks: + def name(self) -> ImplementedNetworks: """The name of the network.""" - return ImplementedNetworks.DENSE_NET + return ImplementedNetworks.DENSENET + + +class PreTrainedDenseNetConfig(PreTrainedConfig): + """Base config class for SOTA DenseNets.""" @computed_field @property - def dim(self) -> int: - """Dimension of the images.""" - return self.spatial_dims - - @field_validator("dropout_prob") - @classmethod - def validator_dropout(cls, v): - """Checks that dropout is between 0 and 1.""" - return cls.base_validator_dropout(v) + def _type(self) -> NetworkType: + """To know where to look for the network.""" + return NetworkType.DENSENET + + +class DenseNet121Config(PreTrainedDenseNetConfig): + """Config class for DenseNet-121.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.DENSENET_121 + + +class DenseNet161Config(PreTrainedDenseNetConfig): + """Config class for DenseNet-161.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.DENSENET_161 + + +class DenseNet169Config(PreTrainedDenseNetConfig): + """Config class for DenseNet-169.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.DENSENET_169 + + +class DenseNet201Config(PreTrainedDenseNetConfig): + """Config class for DenseNet-201.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.DENSENET_201 diff --git a/clinicadl/monai_networks/config/factory.py b/clinicadl/monai_networks/config/factory.py index 55e0fad39..2b7e5bdc1 100644 --- a/clinicadl/monai_networks/config/factory.py +++ b/clinicadl/monai_networks/config/factory.py @@ -1,16 +1,36 @@ from typing import Type, Union -from .autoencoder import * -from .base import NetworkConfig -from .classifier import * -from .densenet import * -from .fcn import * -from .generator import * -from .regressor import * -from .resnet import * -from .unet import * -from .utils.enum import ImplementedNetworks -from .vit import * +# pylint: disable=unused-import +from .autoencoder import AutoEncoderConfig, VAEConfig +from .base import ImplementedNetworks, NetworkConfig +from .cnn import CNNConfig +from .conv_decoder import ConvDecoderConfig +from .conv_encoder import ConvEncoderConfig +from .densenet import ( + DenseNet121Config, + DenseNet161Config, + DenseNet169Config, + DenseNet201Config, + DenseNetConfig, +) +from .generator import GeneratorConfig +from .mlp import MLPConfig +from .resnet import ( + ResNet18Config, + ResNet34Config, + ResNet50Config, + ResNet101Config, + ResNet152Config, + ResNetConfig, +) +from .senet import ( + SEResNet50Config, + SEResNet101Config, + SEResNet152Config, + SEResNetConfig, +) +from .unet import AttentionUNetConfig, UNetConfig +from .vit import ViTB16Config, ViTB32Config, ViTConfig, ViTL16Config, ViTL32Config def create_network_config( @@ -29,7 +49,7 @@ def create_network_config( Type[NetworkConfig] The config class. 
""" - network = ImplementedNetworks(network) + network = ImplementedNetworks(network).value.replace("-", "").replace("/", "") config_name = "".join([network, "Config"]) config = globals()[config_name] diff --git a/clinicadl/monai_networks/config/fcn.py b/clinicadl/monai_networks/config/fcn.py deleted file mode 100644 index 3bb23d6cb..000000000 --- a/clinicadl/monai_networks/config/fcn.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import Optional, Tuple, Union - -from pydantic import ( - NonNegativeFloat, - PositiveInt, - computed_field, - field_validator, -) - -from clinicadl.utils.factories import DefaultFromLibrary - -from .base import NetworkConfig -from .utils.enum import ImplementedNetworks - -__all__ = ["FullyConnectedNetConfig", "VarFullyConnectedNetConfig"] - - -class FullyConnectedNetConfig(NetworkConfig): - """Config class for fully connected networks.""" - - in_channels: PositiveInt - out_channels: PositiveInt - hidden_channels: Tuple[PositiveInt, ...] - - dropout: Union[ - Optional[NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.FCN - - @computed_field - @property - def dim(self) -> Optional[int]: - """Dimension of the images.""" - return None - - @field_validator("dropout") - @classmethod - def validator_dropout(cls, v): - """Checks that dropout is between 0 and 1.""" - return cls.base_validator_dropout(v) - - -class VarFullyConnectedNetConfig(NetworkConfig): - """Config class for fully connected networks.""" - - in_channels: PositiveInt - out_channels: PositiveInt - latent_size: PositiveInt - encode_channels: Tuple[PositiveInt, ...] - decode_channels: Tuple[PositiveInt, ...] - - dropout: Union[ - Optional[NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.VAR_FCN - - @computed_field - @property - def dim(self) -> Optional[int]: - """Dimension of the images.""" - return None - - @field_validator("dropout") - @classmethod - def validator_dropout(cls, v): - """Checks that dropout is between 0 and 1.""" - return cls.base_validator_dropout(v) diff --git a/clinicadl/monai_networks/config/generator.py b/clinicadl/monai_networks/config/generator.py index b864d371d..6c7836474 100644 --- a/clinicadl/monai_networks/config/generator.py +++ b/clinicadl/monai_networks/config/generator.py @@ -1,38 +1,24 @@ -from __future__ import annotations +from typing import Optional, Sequence, Union -from typing import Tuple +from pydantic import PositiveInt, computed_field -from pydantic import ( - PositiveInt, - computed_field, - field_validator, -) +from clinicadl.utils.factories import DefaultFromLibrary -from .base import VaryingDepthNetworkConfig -from .utils.enum import ImplementedNetworks +from .base import ImplementedNetworks, NetworkConfig +from .conv_decoder import ConvDecoderOptions +from .mlp import MLPOptions -__all__ = ["GeneratorConfig"] +class GeneratorConfig(NetworkConfig): + """Config class for Generator.""" -class GeneratorConfig(VaryingDepthNetworkConfig): - """Config class for generators.""" - - latent_shape: Tuple[PositiveInt, ...] - start_shape: Tuple[PositiveInt, ...] 
+ latent_size: PositiveInt + start_shape: Sequence[PositiveInt] + conv_args: ConvDecoderOptions + mlp_args: Union[Optional[MLPOptions], DefaultFromLibrary] = DefaultFromLibrary.YES @computed_field @property - def network(self) -> ImplementedNetworks: + def name(self) -> ImplementedNetworks: """The name of the network.""" return ImplementedNetworks.GENERATOR - - @computed_field - @property - def dim(self) -> int: - """Dimension of the images.""" - return len(self.start_shape[1:]) - - @field_validator("start_shape") - def at_least_2d(cls, v, field): - """Checks that a tuple has at least a length of two.""" - return cls.base_at_least_2d(v, field) diff --git a/clinicadl/monai_networks/config/mlp.py b/clinicadl/monai_networks/config/mlp.py new file mode 100644 index 000000000..5d12f303f --- /dev/null +++ b/clinicadl/monai_networks/config/mlp.py @@ -0,0 +1,52 @@ +from typing import Optional, Sequence, Union + +from pydantic import BaseModel, ConfigDict, PositiveFloat, PositiveInt, computed_field + +from clinicadl.monai_networks.nn.layers.utils import ( + ActivationParameters, + NormalizationParameters, +) +from clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkConfig + + +class MLPOptions(BaseModel): + """ + Config class for MLP when it is a submodule. + See for example: :py:class:`clinicadl.monai_networks.nn.cnn.CNN` + """ + + hidden_channels: Sequence[PositiveInt] + act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + norm: Union[ + Optional[NormalizationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + dropout: Union[Optional[PositiveFloat], DefaultFromLibrary] = DefaultFromLibrary.YES + bias: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES + adn_ordering: Union[str, DefaultFromLibrary] = DefaultFromLibrary.YES + + # pydantic config + model_config = ConfigDict( + validate_assignment=True, + use_enum_values=True, + validate_default=True, + ) + + +class MLPConfig(NetworkConfig, MLPOptions): + """Config class for Multi Layer Perceptron.""" + + in_channels: PositiveInt + out_channels: PositiveInt + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.MLP diff --git a/clinicadl/monai_networks/config/regressor.py b/clinicadl/monai_networks/config/regressor.py deleted file mode 100644 index 5410e31fa..000000000 --- a/clinicadl/monai_networks/config/regressor.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -from typing import Tuple - -from pydantic import PositiveInt, computed_field, field_validator - -from .base import VaryingDepthNetworkConfig -from .utils.enum import ImplementedNetworks - -__all__ = ["RegressorConfig"] - - -class RegressorConfig(VaryingDepthNetworkConfig): - """Config class for regressors.""" - - in_shape: Tuple[PositiveInt, ...] - out_shape: Tuple[PositiveInt, ...] 
- - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.REGRESSOR - - @computed_field - @property - def dim(self) -> int: - """Dimension of the images.""" - return len(self.in_shape[1:]) - - @field_validator("in_shape") - def at_least_2d(cls, v, ctx): - """Checks that a tuple has at least a length of two.""" - return cls.base_at_least_2d(v, ctx) diff --git a/clinicadl/monai_networks/config/resnet.py b/clinicadl/monai_networks/config/resnet.py index 96bb6e193..0f3141dd8 100644 --- a/clinicadl/monai_networks/config/resnet.py +++ b/clinicadl/monai_networks/config/resnet.py @@ -1,148 +1,103 @@ -from __future__ import annotations +from typing import Optional, Sequence, Union -from enum import Enum -from typing import Optional, Tuple, Union - -from pydantic import ( - NonNegativeFloat, - PositiveFloat, - PositiveInt, - computed_field, - field_validator, - model_validator, -) +from pydantic import PositiveInt, computed_field +from clinicadl.monai_networks.nn.layers.utils import ActivationParameters +from clinicadl.monai_networks.nn.resnet import ResNetBlockType from clinicadl.utils.factories import DefaultFromLibrary -from .base import NetworkConfig -from .utils.enum import ( - ImplementedNetworks, - ResNetBlocks, - ResNets, - ShortcutTypes, - UpsampleModes, -) - -__all__ = ["ResNetConfig", "ResNetFeaturesConfig", "SegResNetConfig"] +from .base import ImplementedNetworks, NetworkConfig, NetworkType, PreTrainedConfig class ResNetConfig(NetworkConfig): """Config class for ResNet.""" - block: ResNetBlocks - layers: Tuple[PositiveInt, PositiveInt, PositiveInt, PositiveInt] - block_inplanes: Tuple[PositiveInt, PositiveInt, PositiveInt, PositiveInt] - - spatial_dims: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - n_input_channels: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - conv1_t_size: Union[ - PositiveInt, Tuple[PositiveInt, ...], DefaultFromLibrary + spatial_dims: PositiveInt + in_channels: PositiveInt + num_outputs: Optional[PositiveInt] + block_type: Union[str, ResNetBlockType, DefaultFromLibrary] = DefaultFromLibrary.YES + n_res_blocks: Union[ + Sequence[PositiveInt], DefaultFromLibrary + ] = DefaultFromLibrary.YES + n_features: Union[ + Sequence[PositiveInt], DefaultFromLibrary + ] = DefaultFromLibrary.YES + init_conv_size: Union[ + Sequence[PositiveInt], PositiveInt, DefaultFromLibrary + ] = DefaultFromLibrary.YES + init_conv_stride: Union[ + Sequence[PositiveInt], PositiveInt, DefaultFromLibrary ] = DefaultFromLibrary.YES - conv1_t_stride: Union[ - PositiveInt, Tuple[PositiveInt, ...], DefaultFromLibrary + bottleneck_reduction: Union[ + PositiveInt, DefaultFromLibrary + ] = DefaultFromLibrary.YES + act: Union[ActivationParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary ] = DefaultFromLibrary.YES - no_max_pool: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - shortcut_type: Union[ShortcutTypes, DefaultFromLibrary] = DefaultFromLibrary.YES - widen_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - num_classes: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - feed_forward: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - bias_downsample: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES @computed_field @property - def network(self) -> ImplementedNetworks: + def name(self) -> ImplementedNetworks: """The name of the network.""" - 
return ImplementedNetworks.RES_NET + return ImplementedNetworks.RESNET + + +class PreTrainedResNetConfig(PreTrainedConfig): + """Base config class for SOTA ResNets.""" @computed_field @property - def dim(self) -> int: - """Dimension of the images.""" - return self.spatial_dims if self.spatial_dims != DefaultFromLibrary.YES else 3 + def _type(self) -> NetworkType: + """To know where to look for the network.""" + return NetworkType.RESNET - @model_validator(mode="after") - def model_validator(self): - """Checks coherence between parameters.""" - if self.conv1_t_size != DefaultFromLibrary.YES: - assert self._check_dimensions( - self.conv1_t_size - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for conv1_t_size. You passed {self.conv1_t_size}." - if self.conv1_t_stride != DefaultFromLibrary.YES: - assert self._check_dimensions( - self.conv1_t_stride - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for conv1_t_stride. You passed {self.conv1_t_stride}." - return self +class ResNet18Config(PreTrainedResNetConfig): + """Config class for ResNet-18.""" + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.RESNET_18 -class ResNetFeaturesConfig(NetworkConfig): - """Config class for ResNet backbones.""" - - model_name: ResNets - pretrained: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - spatial_dims: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - in_channels: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES +class ResNet34Config(PreTrainedResNetConfig): + """Config class for ResNet-34.""" @computed_field @property - def network(self) -> ImplementedNetworks: + def name(self) -> ImplementedNetworks: """The name of the network.""" - return ImplementedNetworks.RES_NET_FEATURES + return ImplementedNetworks.RESNET_34 + + +class ResNet50Config(PreTrainedResNetConfig): + """Config class for ResNet-50.""" @computed_field @property - def dim(self) -> int: - """Dimension of the images.""" - return self.spatial_dims if self.spatial_dims != DefaultFromLibrary.YES else 3 - - @model_validator(mode="after") - def model_validator(self): - """Checks coherence between parameters.""" - if self.pretrained == DefaultFromLibrary.YES or self.pretrained: - assert ( - self.spatial_dims == DefaultFromLibrary.YES or self.spatial_dims == 3 - ), "Pretrained weights are only available with spatial_dims=3. Otherwise, set pretrained to False." - assert ( - self.in_channels == DefaultFromLibrary.YES or self.in_channels == 1 - ), "Pretrained weights are only available with in_channels=1. Otherwise, set pretrained to False." 
- - return self - - -class SegResNetConfig(NetworkConfig): - """Config class for SegResNet.""" - - spatial_dims: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - init_filters: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - in_channels: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - out_channels: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - dropout_prob: Union[ - Optional[NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - use_conv_final: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - blocks_down: Union[ - Tuple[PositiveInt, ...], DefaultFromLibrary - ] = DefaultFromLibrary.YES - blocks_up: Union[ - Tuple[PositiveInt, ...], DefaultFromLibrary - ] = DefaultFromLibrary.YES - upsample_mode: Union[UpsampleModes, DefaultFromLibrary] = DefaultFromLibrary.YES + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.RESNET_50 + + +class ResNet101Config(PreTrainedResNetConfig): + """Config class for ResNet-101.""" @computed_field @property - def network(self) -> ImplementedNetworks: + def name(self) -> ImplementedNetworks: """The name of the network.""" - return ImplementedNetworks.SEG_RES_NET + return ImplementedNetworks.RESNET_101 + + +class ResNet152Config(PreTrainedResNetConfig): + """Config class for ResNet-152.""" @computed_field @property - def dim(self) -> int: - """Dimension of the images.""" - return self.spatial_dims if self.spatial_dims != DefaultFromLibrary.YES else 3 - - @field_validator("dropout_prob") - @classmethod - def validator_dropout(cls, v): - """Checks that dropout is between 0 and 1.""" - return cls.base_validator_dropout(v) + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.RESNET_152 diff --git a/clinicadl/monai_networks/config/senet.py b/clinicadl/monai_networks/config/senet.py new file mode 100644 index 000000000..79a356726 --- /dev/null +++ b/clinicadl/monai_networks/config/senet.py @@ -0,0 +1,60 @@ +from typing import Union + +from pydantic import PositiveInt, computed_field + +from clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkType, PreTrainedConfig +from .resnet import ResNetConfig + + +class SEResNetConfig(ResNetConfig): + """Config class for Squeeze-and-Excitation ResNet.""" + + se_reduction: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.SE_RESNET + + +class PreTrainedSEResNetConfig(PreTrainedConfig): + """Base config class for SOTA SE-ResNets.""" + + @computed_field + @property + def _type(self) -> NetworkType: + """To know where to look for the network.""" + return NetworkType.SE_RESNET + + +class SEResNet50Config(PreTrainedSEResNetConfig): + """Config class for SE-ResNet-50.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.SE_RESNET_50 + + +class SEResNet101Config(PreTrainedSEResNetConfig): + """Config class for SE-ResNet-101.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.SE_RESNET_101 + + +class SEResNet152Config(PreTrainedSEResNetConfig): + """Config class for SE-ResNet-152.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the 
network.""" + return ImplementedNetworks.SE_RESNET_152 diff --git a/clinicadl/monai_networks/config/unet.py b/clinicadl/monai_networks/config/unet.py index e7fd3498b..abd87817e 100644 --- a/clinicadl/monai_networks/config/unet.py +++ b/clinicadl/monai_networks/config/unet.py @@ -1,64 +1,38 @@ -from __future__ import annotations +from typing import Optional, Sequence, Union -from typing import Union - -from pydantic import ( - PositiveInt, - computed_field, - model_validator, -) +from pydantic import PositiveFloat, PositiveInt, computed_field +from clinicadl.monai_networks.nn.layers.utils import ActivationParameters from clinicadl.utils.factories import DefaultFromLibrary -from .base import VaryingDepthNetworkConfig -from .utils.enum import ImplementedNetworks - -__all__ = ["UNetConfig", "AttentionUnetConfig"] +from .base import ImplementedNetworks, NetworkConfig -class UNetConfig(VaryingDepthNetworkConfig): +class UNetConfig(NetworkConfig): """Config class for UNet.""" spatial_dims: PositiveInt in_channels: PositiveInt out_channels: PositiveInt - adn_ordering: Union[str, DefaultFromLibrary] = DefaultFromLibrary.YES + channels: Union[Sequence[PositiveInt], DefaultFromLibrary] = DefaultFromLibrary.YES + act: Union[ActivationParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + dropout: Union[Optional[PositiveFloat], DefaultFromLibrary] = DefaultFromLibrary.YES @computed_field @property - def network(self) -> ImplementedNetworks: + def name(self) -> ImplementedNetworks: """The name of the network.""" return ImplementedNetworks.UNET - @computed_field - @property - def dim(self) -> int: - """Dimension of the images.""" - return self.spatial_dims - - @model_validator(mode="after") - def channels_strides_validator(self): - """Checks coherence between parameters.""" - n_layers = len(self.channels) - assert ( - n_layers >= 2 - ), f"Channels must be at least of length 2. You passed {self.channels}." - assert ( - len(self.strides) == n_layers - 1 - ), f"Length of strides must be equal to len(channels)-1. You passed channels={self.channels} and strides={self.strides}." - for s in self.strides: - assert self._check_dimensions( - s - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for strides. You passed {s}." 
- - return self - -class AttentionUnetConfig(UNetConfig): - """Config class for Attention UNet.""" +class AttentionUNetConfig(UNetConfig): + """Config class for AttentionUNet.""" @computed_field @property - def network(self) -> ImplementedNetworks: + def name(self) -> ImplementedNetworks: """The name of the network.""" return ImplementedNetworks.ATT_UNET diff --git a/clinicadl/monai_networks/config/utils/enum.py b/clinicadl/monai_networks/config/utils/enum.py deleted file mode 100644 index 941e34972..000000000 --- a/clinicadl/monai_networks/config/utils/enum.py +++ /dev/null @@ -1,129 +0,0 @@ -from enum import Enum - - -class ImplementedNetworks(str, Enum): - """Implemented neural networks in ClinicaDL.""" - - REGRESSOR = "Regressor" - CLASSIFIER = "Classifier" - DISCRIMINATOR = "Discriminator" - CRITIC = "Critic" - AE = "AutoEncoder" - VAE = "VarAutoEncoder" - DENSE_NET = "DenseNet" - FCN = "FullyConnectedNet" - VAR_FCN = "VarFullyConnectedNet" - GENERATOR = "Generator" - RES_NET = "ResNet" - RES_NET_FEATURES = "ResNetFeatures" - SEG_RES_NET = "SegResNet" - UNET = "UNet" - ATT_UNET = "AttentionUnet" - VIT = "ViT" - VIT_AE = "ViTAutoEnc" - - @classmethod - def _missing_(cls, value): - raise ValueError( - f"{value} is not implemented. Implemented neural networks are: " - + ", ".join([repr(m.value) for m in cls]) - ) - - -class CaseInsensitiveEnum(str, Enum): - @classmethod - def _missing_(cls, value): - if isinstance(value, str): - value = value.lower() - for member in cls: - if member.lower() == value: - return member - return None - - -class ImplementedActFunctions(CaseInsensitiveEnum): - """Supported activation functions in ClinicaDL.""" - - ELU = "elu" - RELU = "relu" - LEAKY_RELU = "leakyrelu" - PRELU = "prelu" - RELU6 = "relu6" - SELU = "selu" - CELU = "celu" - GELU = "gelu" - SIGMOID = "sigmoid" - TANH = "tanh" - SOFTMAX = "softmax" - LOGSOFTMAX = "logsoftmax" - SWISH = "swish" - MEMSWISH = "memswish" - MISH = "mish" - GEGLU = "geglu" - - -class ImplementedNormLayers(CaseInsensitiveEnum): - """Supported normalization layers in ClinicaDL.""" - - GROUP = "group" - LAYER = "layer" - LOCAL_RESPONSE = "localresponse" - SYNCBATCH = "syncbatch" - INSTANCE_NVFUSER = "instance_nvfuser" - BATCH = "batch" - INSTANCE = "instance" - - -class ResNetBlocks(str, Enum): - """Supported ResNet blocks.""" - - BASIC = "basic" - BOTTLENECK = "bottleneck" - - -class ShortcutTypes(str, Enum): - """Supported shortcut types for ResNets.""" - - A = "A" - B = "B" - - -class ResNets(str, Enum): - """Supported ResNet networks.""" - - RESNET_10 = "resnet10" - RESNET_18 = "resnet18" - RESNET_34 = "resnet34" - RESNET_50 = "resnet50" - RESNET_101 = "resnet101" - RESNET_152 = "resnet152" - RESNET_200 = "resnet200" - - -class UpsampleModes(str, Enum): - """Supported upsampling modes for ResNets.""" - - DECONV = "deconv" - NON_TRAINABLE = "nontrainable" - PIXEL_SHUFFLE = "pixelshuffle" - - -class PatchEmbeddingTypes(str, Enum): - """Supported patch embedding types for VITs.""" - - CONV = "conv" - PERCEPTRON = "perceptron" - - -class PosEmbeddingTypes(str, Enum): - """Supported positional embedding types for VITs.""" - - NONE = "none" - LEARNABLE = "learnable" - SINCOS = "sincos" - - -class ClassificationActivation(str, Enum): - """Supported activation layer for classification in ViT.""" - - TANH = "Tanh" diff --git a/clinicadl/monai_networks/config/vit.py b/clinicadl/monai_networks/config/vit.py index 028537612..5059df790 100644 --- a/clinicadl/monai_networks/config/vit.py +++ b/clinicadl/monai_networks/config/vit.py @@ 
-1,153 +1,84 @@ -from enum import Enum -from typing import Optional, Tuple, Union +from typing import Optional, Sequence, Union -from pydantic import ( - NonNegativeFloat, - PositiveInt, - computed_field, - field_validator, - model_validator, -) +from pydantic import PositiveFloat, PositiveInt, computed_field +from clinicadl.monai_networks.nn.layers.utils import ActivationParameters +from clinicadl.monai_networks.nn.vit import PosEmbedType from clinicadl.utils.factories import DefaultFromLibrary -from .base import NetworkConfig -from .utils.enum import ( - ClassificationActivation, - ImplementedNetworks, - PatchEmbeddingTypes, - PosEmbeddingTypes, -) - -__all__ = ["ViTConfig", "ViTAutoEncConfig"] +from .base import ImplementedNetworks, NetworkConfig, NetworkType, PreTrainedConfig class ViTConfig(NetworkConfig): """Config class for ViT networks.""" - in_channels: PositiveInt - img_size: Union[PositiveInt, Tuple[PositiveInt, ...]] - patch_size: Union[PositiveInt, Tuple[PositiveInt, ...]] - - hidden_size: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - mlp_dim: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES + in_shape: Sequence[PositiveInt] + patch_size: Union[Sequence[PositiveInt], PositiveInt] + num_outputs: Optional[PositiveInt] + embedding_dim: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES num_layers: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES num_heads: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - proj_type: Union[PatchEmbeddingTypes, DefaultFromLibrary] = DefaultFromLibrary.YES + mlp_dim: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES pos_embed_type: Union[ - PosEmbeddingTypes, DefaultFromLibrary + Optional[Union[str, PosEmbedType]], DefaultFromLibrary ] = DefaultFromLibrary.YES - classification: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - num_classes: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - dropout_rate: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - spatial_dims: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - post_activation: Union[ - Optional[ClassificationActivation], DefaultFromLibrary + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary ] = DefaultFromLibrary.YES - qkv_bias: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - save_attn: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES + dropout: Union[Optional[PositiveFloat], DefaultFromLibrary] = DefaultFromLibrary.YES @computed_field @property - def network(self) -> ImplementedNetworks: + def name(self) -> ImplementedNetworks: """The name of the network.""" return ImplementedNetworks.VIT + +class PreTrainedViTConfig(PreTrainedConfig): + """Base config class for SOTA ViTs.""" + + @computed_field + @property + def _type(self) -> NetworkType: + """To know where to look for the network.""" + return NetworkType.VIT + + +class ViTB16Config(PreTrainedViTConfig): + """Config class for ViT-B/16.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.VIT_B_16 + + +class ViTB32Config(PreTrainedViTConfig): + """Config class for ViT-B/32.""" + @computed_field @property - def dim(self) -> int: - """Dimension of the images.""" - return self.spatial_dims if self.spatial_dims != DefaultFromLibrary.YES else 3 - - @field_validator("dropout_rate") - def validator_dropout(cls, v): - """Checks that dropout is between 0 and 1.""" - return
cls.base_validator_dropout(v) - - @model_validator(mode="before") - def check_einops(self): - """Checks if the library einops is installed.""" - from importlib import util - - spec = util.find_spec("einops") - if spec is None: - raise ModuleNotFoundError("einops is not installed") - return self - - @model_validator(mode="after") - def model_validator(self): - """Checks coherence between parameters.""" - assert self._check_dimensions( - self.img_size - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for img_size. You passed {self.img_size}." - assert self._check_dimensions( - self.patch_size - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for patch_size. You passed {self.patch_size}." - - if ( - self.hidden_size != DefaultFromLibrary.YES - and self.num_heads != DefaultFromLibrary.YES - ): - assert self._divide( - self.hidden_size, self.num_heads - ), f"hidden_size must be divisible by num_heads. You passed hidden_size={self.hidden_size} and num_heads={self.num_heads}." - elif ( - self.hidden_size != DefaultFromLibrary.YES - and self.num_heads == DefaultFromLibrary.YES - ): - raise ValueError("If you pass hidden_size, please also pass num_heads.") - elif ( - self.hidden_size == DefaultFromLibrary.YES - and self.num_heads != DefaultFromLibrary.YES - ): - raise ValueError("If you pass num_head, please also pass hidden_size.") - - return self - - def _divide( - self, - numerator: Union[int, Tuple[int, ...]], - denominator: Union[int, Tuple[int, ...]], - ) -> bool: - """Checks if numerator is divisible by denominator.""" - if isinstance(numerator, int): - numerator = (numerator,) * self.dim - if isinstance(denominator, int): - denominator = (denominator,) * self.dim - for n, d in zip(numerator, denominator): - if n % d != 0: - return False - return True - - -class ViTAutoEncConfig(ViTConfig): - """Config class for ViT autoencoders.""" - - out_channels: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - deconv_chns: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.VIT_B_32 + + +class ViTL16Config(PreTrainedViTConfig): + """Config class for ViT-L/16.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.VIT_L_16 + + +class ViTL32Config(PreTrainedViTConfig): + """Config class for ViT-L/32.""" @computed_field @property - def network(self) -> ImplementedNetworks: + def name(self) -> ImplementedNetworks: """The name of the network.""" - return ImplementedNetworks.VIT_AE - - @model_validator(mode="after") - def model_validator_bis(self): - """Checks coherence between parameters.""" - assert self._divide( - self.img_size, self.patch_size - ), f"img_size must be divisible by patch_size. You passed hidden_size={self.img_size} and num_heads={self.patch_size}." - assert self._is_sqrt( - self.patch_size - ), f"patch_size must be square number(s). You passed {self.patch_size}." 
- - return self - - def _is_sqrt(self, value: Union[int, Tuple[int, ...]]) -> bool: - """Checks if value is a square number.""" - import math - - if isinstance(value, int): - value = (value,) * self.dim - return all([int(math.sqrt(v)) == math.sqrt(v) for v in value]) + return ImplementedNetworks.VIT_L_32 diff --git a/clinicadl/monai_networks/factory.py b/clinicadl/monai_networks/factory.py index 1e509f3d1..36a4d1d46 100644 --- a/clinicadl/monai_networks/factory.py +++ b/clinicadl/monai_networks/factory.py @@ -1,38 +1,123 @@ -from typing import Tuple +from copy import deepcopy +from typing import Any, Callable, Tuple, Union -import monai.networks.nets as networks import torch.nn as nn +from pydantic import BaseModel +import clinicadl.monai_networks.nn as nets from clinicadl.utils.factories import DefaultFromLibrary, get_args_and_defaults -from .config.base import NetworkConfig +from .config import ( + ImplementedNetworks, + NetworkConfig, + NetworkType, + create_network_config, +) +from .config.conv_decoder import ConvDecoderOptions +from .config.conv_encoder import ConvEncoderOptions +from .config.mlp import MLPOptions +from .nn import MLP, ConvDecoder, ConvEncoder -def get_network(config: NetworkConfig) -> Tuple[nn.Module, NetworkConfig]: +def get_network( + name: Union[str, ImplementedNetworks], return_config: bool = False, **kwargs: Any +) -> Union[nn.Module, Tuple[nn.Module, NetworkConfig]]: """ - Factory function to get a Neural Network from MONAI. + Factory function to get a neural network from its name and parameters. + + Parameters + ---------- + name : Union[str, ImplementedNetworks] + the name of the neural network. Check our documentation for the + available networks. + return_config : bool (optional, default=False) + whether the function should also return the config class gathering the parameters + of the neural network. Useful to keep track of the hyperparameters. + kwargs : Any + the parameters of the neural network. Check our documentation on networks for + details on these parameters. + + Returns + ------- + nn.Module + the neural network. + NetworkConfig + the associated config class. Only returned if `return_config` is True. + """ + config = create_network_config(name)(**kwargs) + network, updated_config = get_network_from_config(config) + + return network if not return_config else (network, updated_config) + + +def get_network_from_config(config: NetworkConfig) -> Tuple[nn.Module, NetworkConfig]: + """ + Factory function to get a neural network from a NetworkConfig instance. Parameters ---------- config : NetworkConfig - The config class with the parameters of the network. + the configuration object. Returns ------- nn.Module - The neural network. + the neural network. NetworkConfig - The updated config class: the arguments set to default will be updated - with their effective values (the default values from the library). + the updated config class: the arguments set to default will be updated + with their effective values (the default values from the network). Useful for reproducibility.
""" - network_class = getattr(networks, config.network) - expected_args, config_dict = get_args_and_defaults(network_class.__init__) - for arg, value in config.model_dump().items(): - if arg in expected_args and value != DefaultFromLibrary.YES: - config_dict[arg] = value + config = deepcopy(config) + network_type = config._type # pylint: disable=protected-access + + if network_type == NetworkType.CUSTOM: + network_class: type[nn.Module] = getattr(nets, config.name) + if config.name == ImplementedNetworks.SE_RESNET: + _update_config_with_defaults( + config, getattr(nets, ImplementedNetworks.RESNET.value).__init__ + ) # SEResNet has some default values in ResNet + elif config.name == ImplementedNetworks.ATT_UNET: + _update_config_with_defaults( + config, getattr(nets, ImplementedNetworks.UNET.value).__init__ + ) + _update_config_with_defaults(config, network_class.__init__) + + config_dict = config.model_dump(exclude={"name", "_type"}) + network = network_class(**config_dict) + + else: # sota networks + if network_type == NetworkType.RESNET: + getter: Callable[..., nn.Module] = nets.get_resnet + elif network_type == NetworkType.DENSENET: + getter: Callable[..., nn.Module] = nets.get_densenet + elif network_type == NetworkType.SE_RESNET: + getter: Callable[..., nn.Module] = nets.get_seresnet + elif network_type == NetworkType.VIT: + getter: Callable[..., nn.Module] = nets.get_vit + _update_config_with_defaults(config, getter) # pylint: disable=possibly-used-before-assignment - network = network_class(**config_dict) - updated_config = config.model_copy(update=config_dict) + config_dict = config.model_dump(exclude={"_type"}) + network = getter(**config_dict) + + return network, config + + +def _update_config_with_defaults(config: BaseModel, function: Callable) -> BaseModel: + """ + Updates a config object by setting the parameters left to 'default' to their actual + default values, extracted from 'function'. 
+ """ + _, defaults = get_args_and_defaults(function) - return network, updated_config + for arg, value in config: + if isinstance(value, MLPOptions): + _update_config_with_defaults( + value, MLP.__init__ + ) # we need to update the sub config object + elif isinstance(value, ConvEncoderOptions): + _update_config_with_defaults(value, ConvEncoder.__init__) + elif isinstance(value, ConvDecoderOptions): + _update_config_with_defaults(value, ConvDecoder.__init__) + elif value == DefaultFromLibrary.YES and arg in defaults: + setattr(config, arg, defaults[arg]) diff --git a/clinicadl/monai_networks/nn/__init__.py b/clinicadl/monai_networks/nn/__init__.py new file mode 100644 index 000000000..0e1c7054a --- /dev/null +++ b/clinicadl/monai_networks/nn/__init__.py @@ -0,0 +1,13 @@ +from .att_unet import AttentionUNet +from .autoencoder import AutoEncoder +from .cnn import CNN +from .conv_decoder import ConvDecoder +from .conv_encoder import ConvEncoder +from .densenet import DenseNet, get_densenet +from .generator import Generator +from .mlp import MLP +from .resnet import ResNet, get_resnet +from .senet import SEResNet, get_seresnet +from .unet import UNet +from .vae import VAE +from .vit import ViT, get_vit diff --git a/clinicadl/monai_networks/nn/att_unet.py b/clinicadl/monai_networks/nn/att_unet.py new file mode 100644 index 000000000..77ef02081 --- /dev/null +++ b/clinicadl/monai_networks/nn/att_unet.py @@ -0,0 +1,207 @@ +from typing import Any + +import torch +from monai.networks.nets.attentionunet import AttentionBlock + +from .layers.unet import ConvBlock, UpSample +from .unet import BaseUNet + + +class AttentionUNet(BaseUNet): + """ + Attention-UNet based on [Attention U-Net: Learning Where to Look for the Pancreas](https://arxiv.org/pdf/1804.03999). + + The user can customize the number of encoding blocks, the number of channels in each block, as well as other parameters + like the activation function. + + .. warning:: AttentionUNet works only with images whose dimensions are high enough powers of 2. More precisely, if n is the + number of max pooling operation in your AttentionUNet (which is equal to `len(channels)-1`), the image must have :math:`2^{k}` + pixels in each dimension, with :math:`k \\geq n` (e.g. shape (:math:`2^{n}`, :math:`2^{n+3}`) for a 2D image). + + Parameters + ---------- + spatial_dims : int + number of spatial dimensions of the input image. + in_channels : int + number of channels in the input image. + out_channels : int + number of output channels. + kwargs : Any + any optional argument accepted by (:py:class:`clinicadl.monai_networks.nn.unet.UNet`). 
+ + Examples + -------- + >>> AttentionUNet( + spatial_dims=2, + in_channels=1, + out_channels=2, + channels=(4, 8), + act="elu", + output_act=("softmax", {"dim": 1}), + dropout=0.1, + ) + AttentionUNet( + (doubleconv): ConvBlock( + (0): Convolution( + (conv): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (1): Convolution( + (conv): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + ) + (down1): DownBlock( + (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) + (doubleconv): ConvBlock( + (0): Convolution( + (conv): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (1): Convolution( + (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + ) + ) + (upsample1): UpSample( + (0): Upsample(scale_factor=2.0, mode='nearest') + (1): Convolution( + (conv): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + ) + (attention1): AttentionBlock( + (W_g): Sequential( + (0): Convolution( + (conv): Conv2d(4, 2, kernel_size=(1, 1), stride=(1, 1)) + ) + (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (W_x): Sequential( + (0): Convolution( + (conv): Conv2d(4, 2, kernel_size=(1, 1), stride=(1, 1)) + ) + (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (psi): Sequential( + (0): Convolution( + (conv): Conv2d(2, 1, kernel_size=(1, 1), stride=(1, 1)) + ) + (1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (2): Sigmoid() + ) + (relu): ReLU() + ) + (doubleconv1): ConvBlock( + (0): Convolution( + (conv): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (1): Convolution( + (conv): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + ) + (reduce_channels): Convolution( + (conv): Conv2d(4, 2, kernel_size=(1, 1), stride=(1, 1)) + ) + (output_act): Softmax(dim=1) + ) + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + **kwargs: Any, + ): + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + **kwargs, + ) + + def _build_decoder(self): + for i in range(len(self.channels) - 1, 0, -1): + self.add_module( + f"upsample{i}", + UpSample( + spatial_dims=self.spatial_dims, + in_channels=self.channels[i], + 
out_channels=self.channels[i - 1],
+                    act=self.act,
+                    dropout=self.dropout,
+                ),
+            )
+            self.add_module(
+                f"attention{i}",
+                AttentionBlock(
+                    spatial_dims=self.spatial_dims,
+                    f_l=self.channels[i - 1],
+                    f_g=self.channels[i - 1],
+                    f_int=self.channels[i - 1] // 2,
+                    dropout=self.dropout,
+                ),
+            )
+            self.add_module(
+                f"doubleconv{i}",
+                ConvBlock(
+                    spatial_dims=self.spatial_dims,
+                    in_channels=self.channels[i - 1] * 2,
+                    out_channels=self.channels[i - 1],
+                    act=self.act,
+                    dropout=self.dropout,
+                ),
+            )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_history = [self.doubleconv(x)]
+
+        for i in range(1, len(self.channels)):
+            x = self.get_submodule(f"down{i}")(x_history[-1])
+            x_history.append(x)
+
+        x_history.pop()  # the output of the bottleneck is not used as a gating signal
+        for i in range(len(self.channels) - 1, 0, -1):
+            up = self.get_submodule(f"upsample{i}")(x)
+            att_res = self.get_submodule(f"attention{i}")(g=x_history.pop(), x=up)
+            merged = torch.cat((att_res, up), dim=1)
+            x = self.get_submodule(f"doubleconv{i}")(merged)
+
+        out = self.reduce_channels(x)
+
+        if self.output_act is not None:
+            out = self.output_act(out)
+
+        return out
diff --git a/clinicadl/monai_networks/nn/autoencoder.py b/clinicadl/monai_networks/nn/autoencoder.py
new file mode 100644
index 000000000..5cf823eeb
--- /dev/null
+++ b/clinicadl/monai_networks/nn/autoencoder.py
@@ -0,0 +1,416 @@
+from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch.nn as nn
+
+from .cnn import CNN
+from .conv_encoder import ConvEncoder
+from .generator import Generator
+from .layers.utils import (
+    ActivationParameters,
+    PoolingLayer,
+    SingleLayerPoolingParameters,
+    SingleLayerUnpoolingParameters,
+    UnpoolingLayer,
+    UnpoolingMode,
+)
+from .mlp import MLP
+from .utils import (
+    calculate_conv_out_shape,
+    calculate_convtranspose_out_shape,
+    calculate_pool_out_shape,
+)
+
+
+class AutoEncoder(nn.Sequential):
+    """
+    An autoencoder with convolutional and fully connected layers.
+
+    The user must pass the arguments to build an encoder, from its convolutional and
+    fully connected parts, and the decoder will be automatically built by taking the
+    symmetrical network.
+
+    More precisely, to build the decoder, the order of the encoding layers is reversed, convolutions are
+    replaced by transposed convolutions and pooling layers are replaced by either upsampling or transposed
+    convolution layers.
+    Please note that the order of `Activation`, `Dropout` and `Normalization`, defined with the
+    argument `adn_ordering` in `conv_args`, is the same for the encoder and the decoder.
+
+    Note that an `AutoEncoder` is an aggregation of a `CNN` (:py:class:`clinicadl.monai_networks.nn.
+    cnn.CNN`) and a `Generator` (:py:class:`clinicadl.monai_networks.nn.generator.Generator`).
+
+    Parameters
+    ----------
+    in_shape : Sequence[int]
+        sequence of integers stating the dimension of the input tensor (minus batch dimension).
+    latent_size : int
+        size of the latent vector.
+    conv_args : Dict[str, Any]
+        the arguments for the convolutional part of the encoder. The arguments are those accepted
+        by :py:class:`clinicadl.monai_networks.nn.conv_encoder.ConvEncoder`, except `in_shape` that
+        is specified here. So, the only mandatory argument is `channels`.
+    mlp_args : Optional[Dict[str, Any]] (optional, default=None)
+        the arguments for the MLP part of the encoder.
The arguments are those accepted by
+        :py:class:`clinicadl.monai_networks.nn.mlp.MLP`, except `in_channels` that is inferred
+        from the output of the convolutional part, and `out_channels` that is set to `latent_size`.
+        So, the only mandatory argument is `hidden_channels`.\n
+        If None, the MLP part will be reduced to a single linear layer.
+    out_channels : Optional[int] (optional, default=None)
+        number of output channels. If None, the output will have the same number of channels as the
+        input.
+    output_act : Optional[ActivationParameters] (optional, default=None)
+        a potential activation layer applied to the output of the network, and optionally its arguments.
+        Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n
+        `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`,
+        `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activation functions]
+        (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional
+        arguments for each of them.
+    unpooling_mode : Union[str, UnpoolingMode] (optional, default=UnpoolingMode.NEAREST)
+        type of unpooling. Can be either `"nearest"`, `"linear"`, `"bilinear"`, `"bicubic"`, `"trilinear"` or
+        `"convtranspose"`.\n
+        - `nearest`: unpooling is performed by upsampling with the :italic:`nearest` algorithm (see [PyTorch's Upsample layer]
+        (https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html)).
+        - `linear`: unpooling is performed by upsampling with the :italic:`linear` algorithm. Only works with 1D images (excluding the
+        channel dimension).
+        - `bilinear`: unpooling is performed by upsampling with the :italic:`bilinear` algorithm. Only works with 2D images.
+        - `bicubic`: unpooling is performed by upsampling with the :italic:`bicubic` algorithm. Only works with 2D images.
+        - `trilinear`: unpooling is performed by upsampling with the :italic:`trilinear` algorithm. Only works with 3D images.
+        - `convtranspose`: unpooling is performed with a transposed convolution, whose parameters (kernel size, stride, etc.) are
+        computed to reverse the pooling operation.
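+          For example, an encoder pooling `("avg", {"kernel_size": 2})` is reversed with a transposed
+          convolution with `kernel_size=2` and `stride=2`, plus paddings computed to recover the exact
+          input size (see `_invert_pooling_with_convtranspose` below).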
+
+    Examples
+    --------
+    >>> AutoEncoder(
+            in_shape=(1, 16, 16),
+            latent_size=8,
+            conv_args={
+                "channels": [2, 4],
+                "pooling_indices": [0],
+                "pooling": ("avg", {"kernel_size": 2}),
+            },
+            mlp_args={"hidden_channels": [32], "output_act": "relu"},
+            out_channels=2,
+            output_act="sigmoid",
+            unpooling_mode="bilinear",
+        )
+    AutoEncoder(
+        (encoder): CNN(
+            (convolutions): ConvEncoder(
+                (layer0): Convolution(
+                    (conv): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1))
+                    (adn): ADN(
+                        (N): InstanceNorm2d(2, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
+                        (A): PReLU(num_parameters=1)
+                    )
+                )
+                (pool0): AvgPool2d(kernel_size=2, stride=2, padding=0)
+                (layer1): Convolution(
+                    (conv): Conv2d(2, 4, kernel_size=(3, 3), stride=(1, 1))
+                )
+            )
+            (mlp): MLP(
+                (flatten): Flatten(start_dim=1, end_dim=-1)
+                (hidden0): Sequential(
+                    (linear): Linear(in_features=100, out_features=32, bias=True)
+                    (adn): ADN(
+                        (N): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+                        (A): PReLU(num_parameters=1)
+                    )
+                )
+                (output): Sequential(
+                    (linear): Linear(in_features=32, out_features=8, bias=True)
+                    (output_act): ReLU()
+                )
+            )
+        )
+        (decoder): Generator(
+            (mlp): MLP(
+                (flatten): Flatten(start_dim=1, end_dim=-1)
+                (hidden0): Sequential(
+                    (linear): Linear(in_features=8, out_features=32, bias=True)
+                    (adn): ADN(
+                        (N): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+                        (A): PReLU(num_parameters=1)
+                    )
+                )
+                (output): Sequential(
+                    (linear): Linear(in_features=32, out_features=100, bias=True)
+                    (output_act): ReLU()
+                )
+            )
+            (reshape): Reshape()
+            (convolutions): ConvDecoder(
+                (layer0): Convolution(
+                    (conv): ConvTranspose2d(4, 4, kernel_size=(3, 3), stride=(1, 1))
+                    (adn): ADN(
+                        (N): InstanceNorm2d(4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
+                        (A): PReLU(num_parameters=1)
+                    )
+                )
+                (unpool0): Upsample(size=(14, 14), mode=<UnpoolingMode.BILINEAR: 'bilinear'>)
+                (layer1): Convolution(
+                    (conv): ConvTranspose2d(4, 2, kernel_size=(3, 3), stride=(1, 1))
+                )
+                (output_act): Sigmoid()
+            )
+        )
+    )
+
+    """
+
+    def __init__(
+        self,
+        in_shape: Sequence[int],
+        latent_size: int,
+        conv_args: Dict[str, Any],
+        mlp_args: Optional[Dict[str, Any]] = None,
+        out_channels: Optional[int] = None,
+        output_act: Optional[ActivationParameters] = None,
+        unpooling_mode: Union[str, UnpoolingMode] = UnpoolingMode.NEAREST,
+    ) -> None:
+        super().__init__()
+        self.in_shape = in_shape
+        self.latent_size = latent_size
+        self.out_channels = out_channels if out_channels else self.in_shape[0]
+        self._output_act = output_act
+        self.unpooling_mode = self._check_unpooling_mode(unpooling_mode)
+        self.spatial_dims = len(in_shape[1:])
+
+        self.encoder = CNN(
+            in_shape=self.in_shape,
+            num_outputs=latent_size,
+            conv_args=conv_args,
+            mlp_args=mlp_args,
+        )
+        inter_channels = (
+            conv_args["channels"][-1] if len(conv_args["channels"]) > 0 else in_shape[0]
+        )
+        inter_shape = (inter_channels, *self.encoder.convolutions.final_size)
+        self.decoder = Generator(
+            latent_size=latent_size,
+            start_shape=inter_shape,
+            conv_args=self._invert_conv_args(conv_args, self.encoder.convolutions),
+            mlp_args=self._invert_mlp_args(mlp_args, self.encoder.mlp),
+        )
+
+    @classmethod
+    def _invert_mlp_args(cls, args: Dict[str, Any], mlp: MLP) -> Dict[str, Any]:
+        """
+        Inverts arguments passed for the MLP part of the encoder, to get the MLP part of
+        the decoder.
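+        For example, encoder hidden channels `[32, 16]` become `[16, 32]` in the decoder
+        (`in_channels` and `out_channels` are inferred by `Generator`).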
+ """ + if args is None: + args = {} + args["hidden_channels"] = cls._invert_list_arg(mlp.hidden_channels) + + return args + + def _invert_conv_args( + self, args: Dict[str, Any], conv: ConvEncoder + ) -> Dict[str, Any]: + """ + Inverts arguments passed for the convolutional part of the encoder, to get the convolutional + part of the decoder. + """ + if len(args["channels"]) == 0: + args["channels"] = [] + else: + args["channels"] = self._invert_list_arg(conv.channels[:-1]) + [ + self.out_channels + ] + args["kernel_size"] = self._invert_list_arg(conv.kernel_size) + args["stride"] = self._invert_list_arg(conv.stride) + args["dilation"] = self._invert_list_arg(conv.dilation) + args["padding"], args["output_padding"] = self._get_paddings_list(conv) + + args["unpooling_indices"] = ( + conv.n_layers - np.array(conv.pooling_indices) - 2 + ).astype(int) + args["unpooling"] = [] + sizes_before_pooling = [ + size + for size, (layer_name, _) in zip(conv.size_details, conv.named_children()) + if "pool" in layer_name + ] + for size, pooling in zip(sizes_before_pooling[::-1], conv.pooling[::-1]): + args["unpooling"].append(self._invert_pooling_layer(size, pooling)) + + if "pooling" in args: + del args["pooling"] + if "pooling_indices" in args: + del args["pooling_indices"] + + args["output_act"] = self._output_act if self._output_act else None + + return args + + @classmethod + def _invert_list_arg(cls, arg: Union[Any, List[Any]]) -> Union[Any, List[Any]]: + """ + Reverses lists. + """ + return list(arg[::-1]) if isinstance(arg, Sequence) else arg + + def _invert_pooling_layer( + self, + size_before_pool: Sequence[int], + pooling: SingleLayerPoolingParameters, + ) -> SingleLayerUnpoolingParameters: + """ + Gets the unpooling layer. + """ + if self.unpooling_mode == UnpoolingMode.CONV_TRANS: + return ( + UnpoolingLayer.CONV_TRANS, + self._invert_pooling_with_convtranspose(size_before_pool, pooling), + ) + else: + return ( + UnpoolingLayer.UPSAMPLE, + {"size": size_before_pool, "mode": self.unpooling_mode}, + ) + + @classmethod + def _invert_pooling_with_convtranspose( + cls, + size_before_pool: Sequence[int], + pooling: SingleLayerPoolingParameters, + ) -> Dict[str, Any]: + """ + Computes the arguments of the transposed convolution, based on the pooling layer. 
+ """ + pooling_mode, pooling_args = pooling + if ( + pooling_mode == PoolingLayer.ADAPT_AVG + or pooling_mode == PoolingLayer.ADAPT_MAX + ): + input_size_np = np.array(size_before_pool) + output_size_np = np.array(pooling_args["output_size"]) + stride_np = input_size_np // output_size_np # adaptive pooling formulas + kernel_size_np = ( + input_size_np - (output_size_np - 1) * stride_np + ) # adaptive pooling formulas + args = { + "kernel_size": tuple(int(k) for k in kernel_size_np), + "stride": tuple(int(s) for s in stride_np), + } + padding, output_padding = cls._find_convtranspose_paddings( + pooling_mode, + size_before_pool, + output_size=pooling_args["output_size"], + **args, + ) + + elif pooling_mode == PoolingLayer.MAX or pooling_mode == PoolingLayer.AVG: + if "stride" not in pooling_args: + pooling_args["stride"] = pooling_args["kernel_size"] + args = { + arg: value + for arg, value in pooling_args.items() + if arg in ["kernel_size", "stride", "padding", "dilation"] + } + padding, output_padding = cls._find_convtranspose_paddings( + pooling_mode, + size_before_pool, + **pooling_args, + ) + + args["padding"] = padding # pylint: disable=possibly-used-before-assignment + args["output_padding"] = output_padding # pylint: disable=possibly-used-before-assignment + + return args + + @classmethod + def _get_paddings_list(cls, conv: ConvEncoder) -> List[Tuple[int, ...]]: + """ + Finds output padding list. + """ + padding = [] + output_padding = [] + size_before_convs = [ + size + for size, (layer_name, _) in zip(conv.size_details, conv.named_children()) + if "layer" in layer_name + ] + for size, k, s, p, d in zip( + size_before_convs, + conv.kernel_size, + conv.stride, + conv.padding, + conv.dilation, + ): + p, out_p = cls._find_convtranspose_paddings( + "conv", size, kernel_size=k, stride=s, padding=p, dilation=d + ) + padding.append(p) + output_padding.append(out_p) + + return cls._invert_list_arg(padding), cls._invert_list_arg(output_padding) + + @classmethod + def _find_convtranspose_paddings( + cls, + layer_type: Union[Literal["conv"], PoolingLayer], + in_shape: Union[Sequence[int], int], + padding: Union[Sequence[int], int] = 0, + **kwargs, + ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """ + Finds padding and output padding necessary to recover the right image size after + a transposed convolution. + """ + if layer_type == "conv": + layer_out_shape = calculate_conv_out_shape(in_shape, **kwargs) + elif layer_type in list(PoolingLayer): + layer_out_shape = calculate_pool_out_shape(layer_type, in_shape, **kwargs) + + convt_out_shape = calculate_convtranspose_out_shape(layer_out_shape, **kwargs) # pylint: disable=possibly-used-before-assignment + output_padding = np.atleast_1d(in_shape) - np.atleast_1d(convt_out_shape) + + if ( + output_padding < 0 + ).any(): # can happen with ceil_mode=True for maxpool. 
Then, add some padding + padding = np.atleast_1d(padding) * np.ones_like( + output_padding + ) # to have the same shape as output_padding + padding[output_padding < 0] += np.maximum(np.abs(output_padding) // 2, 1)[ + output_padding < 0 + ] # //2 because 2*padding pixels are removed + + convt_out_shape = calculate_convtranspose_out_shape( + layer_out_shape, padding=padding, **kwargs + ) + output_padding = np.atleast_1d(in_shape) - np.atleast_1d(convt_out_shape) + padding = tuple(int(s) for s in padding) + + return padding, tuple(int(s) for s in output_padding) + + def _check_unpooling_mode( + self, unpooling_mode: Union[str, UnpoolingMode] + ) -> UnpoolingMode: + """ + Checks consistency between data shape and unpooling mode. + """ + unpooling_mode = UnpoolingMode(unpooling_mode) + if unpooling_mode == UnpoolingMode.LINEAR and len(self.in_shape) != 2: + raise ValueError( + f"unpooling mode `linear` only works with 2D data (counting the channel dimension). " + f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data." + ) + elif unpooling_mode == UnpoolingMode.BILINEAR and len(self.in_shape) != 3: + raise ValueError( + f"unpooling mode `bilinear` only works with 3D data (counting the channel dimension). " + f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data." + ) + elif unpooling_mode == UnpoolingMode.BICUBIC and len(self.in_shape) != 3: + raise ValueError( + f"unpooling mode `bicubic` only works with 3D data (counting the channel dimension). " + f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data." + ) + elif unpooling_mode == UnpoolingMode.TRILINEAR and len(self.in_shape) != 4: + raise ValueError( + f"unpooling mode `trilinear` only works with 4D data (counting the channel dimension). " + f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data." + ) + + return unpooling_mode diff --git a/clinicadl/monai_networks/nn/cnn.py b/clinicadl/monai_networks/nn/cnn.py new file mode 100644 index 000000000..1479ecaea --- /dev/null +++ b/clinicadl/monai_networks/nn/cnn.py @@ -0,0 +1,124 @@ +from typing import Any, Dict, Optional, Sequence + +import numpy as np +import torch.nn as nn + +from .conv_encoder import ConvEncoder +from .mlp import MLP +from .utils import check_conv_args, check_mlp_args + + +class CNN(nn.Sequential): + """ + A regressor/classifier with first convolutional layers and then fully connected layers. + + This network is a simple aggregation of a Fully Convolutional Network (:py:class:`clinicadl. + monai_networks.nn.conv_encoder.ConvEncoder`) and a Multi Layer Perceptron (:py:class:`clinicadl. + monai_networks.nn.mlp.MLP`). + + Parameters + ---------- + in_shape : Sequence[int] + sequence of integers stating the dimension of the input tensor (minus batch dimension). + num_outputs : int + number of variables to predict. + conv_args : Dict[str, Any] + the arguments for the convolutional part. The arguments are those accepted by + :py:class:`clinicadl.monai_networks.nn.conv_encoder.ConvEncoder`, except `in_shape` + that is specified here. So, the only mandatory argument is `channels`. + mlp_args : Optional[Dict[str, Any]] (optional, default=None) + the arguments for the MLP part. The arguments are those accepted by + :py:class:`clinicadl.monai_networks.nn.mlp.MLP`, except `in_channels` that is inferred + from the output of the convolutional part, and `out_channels` that is set to `num_outputs`. 
+ So, the only mandatory argument is `hidden_channels`.\n + If None, the MLP part will be reduced to a single linear layer. + + Examples + -------- + # a classifier + >>> CNN( + in_shape=(1, 10, 10), + num_outputs=2, + conv_args={"channels": [2, 4], "norm": None, "act": None}, + mlp_args={"hidden_channels": [5], "act": "elu", "norm": None, "output_act": "softmax"}, + ) + CNN( + (convolutions): ConvEncoder( + (layer0): Convolution( + (conv): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)) + ) + (layer1): Convolution( + (conv): Conv2d(2, 4, kernel_size=(3, 3), stride=(1, 1)) + ) + ) + (mlp): MLP( + (flatten): Flatten(start_dim=1, end_dim=-1) + (hidden0): Sequential( + (linear): Linear(in_features=144, out_features=5, bias=True) + (adn): ADN( + (A): ELU(alpha=1.0) + ) + ) + (output): Sequential( + (linear): Linear(in_features=5, out_features=2, bias=True) + (output_act): Softmax(dim=None) + ) + ) + ) + + # a regressor + >>> CNN( + in_shape=(1, 10, 10), + num_outputs=2, + conv_args={"channels": [2, 4], "norm": None, "act": None}, + ) + CNN( + (convolutions): ConvEncoder( + (layer0): Convolution( + (conv): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)) + ) + (layer1): Convolution( + (conv): Conv2d(2, 4, kernel_size=(3, 3), stride=(1, 1)) + ) + ) + (mlp): MLP( + (flatten): Flatten(start_dim=1, end_dim=-1) + (output): Linear(in_features=144, out_features=2, bias=True) + ) + ) + """ + + def __init__( + self, + in_shape: Sequence[int], + num_outputs: int, + conv_args: Dict[str, Any], + mlp_args: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + check_conv_args(conv_args) + check_mlp_args(mlp_args) + self.in_shape = in_shape + self.num_outputs = num_outputs + + in_channels, *input_size = in_shape + spatial_dims = len(input_size) + + self.convolutions = ConvEncoder( + in_channels=in_channels, + spatial_dims=spatial_dims, + _input_size=tuple(input_size), + **conv_args, + ) + + n_channels = ( + conv_args["channels"][-1] if len(conv_args["channels"]) > 0 else in_shape[0] + ) + flatten_shape = int(np.prod(self.convolutions.final_size) * n_channels) + if mlp_args is None: + mlp_args = {"hidden_channels": []} + self.mlp = MLP( + in_channels=flatten_shape, + out_channels=num_outputs, + **mlp_args, + ) diff --git a/clinicadl/monai_networks/nn/conv_decoder.py b/clinicadl/monai_networks/nn/conv_decoder.py new file mode 100644 index 000000000..28c9be96f --- /dev/null +++ b/clinicadl/monai_networks/nn/conv_decoder.py @@ -0,0 +1,388 @@ +from typing import Callable, Optional, Sequence, Tuple + +import torch.nn as nn +from monai.networks.blocks import Convolution +from monai.networks.layers.utils import get_act_layer +from monai.utils.misc import ensure_tuple + +from .layers.unpool import get_unpool_layer +from .layers.utils import ( + ActFunction, + ActivationParameters, + ConvNormalizationParameters, + ConvNormLayer, + ConvParameters, + NormLayer, + SingleLayerUnpoolingParameters, + UnpoolingLayer, + UnpoolingParameters, +) +from .utils import ( + calculate_convtranspose_out_shape, + calculate_unpool_out_shape, + check_adn_ordering, + check_norm_layer, + check_pool_indices, + ensure_list_of_tuples, +) + + +class ConvDecoder(nn.Sequential): + """ + Fully convolutional decoder network with transposed convolutions, unpooling, normalization, activation + and dropout layers. + + Parameters + ---------- + spatial_dims : int + number of spatial dimensions of the input image. + in_channels : int + number of channels in the input image. 
+ channels : Sequence[int] + sequence of integers stating the output channels of each transposed convolution. Thus, this + parameter also controls the number of transposed convolutions. + kernel_size : ConvParameters (optional, default=3) + the kernel size of the transposed convolutions. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the kernel sizes for each layer. + The length of the list must be equal to the number of transposed convolution layers (i.e. + `len(channels)`). + stride : ConvParameters (optional, default=1) + the stride of the transposed convolutions. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the strides for each layer. + The length of the list must be equal to the number of transposed convolution layers (i.e. + `len(channels)`). + padding : ConvParameters (optional, default=0) + the padding of the transposed convolutions. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the paddings for each layer. + The length of the list must be equal to the number of transposed convolution layers (i.e. + `len(channels)`). + output_padding : ConvParameters (optional, default=0) + the output padding of the transposed convolutions. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the output paddings for each layer. + The length of the list must be equal to the number of transposed convolution layers (i.e. + `len(channels)`). + dilation : ConvParameters (optional, default=1) + the dilation factor of the transposed convolutions. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the dilations for each layer. + The length of the list must be equal to the number of transposed convolution layers (i.e. + `len(channels)`). + unpooling : Optional[UnpoolingParameters] (optional, default=(UnpoolingLayer.UPSAMPLE, {"scale_factor": 2})) + the unpooling mode and the arguments of the unpooling layer, passed as `(unpooling_mode, arguments)`. + If None, no unpooling will be performed in the network.\n + `unpooling_mode` can be either `upsample` or `convtranspose`. 
Please refer to PyTorch's [Upsample]
+        (https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html) or [ConvTranspose](https://
+        pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html) to know the mandatory and optional
+        arguments.\n
+        If a list is passed, it will be understood as `(unpooling_mode, arguments)` for each unpooling layer.\n
+        Note: no need to pass `in_channels` and `out_channels` for `convtranspose` because the unpooling
+        layers are not intended to modify the number of channels.
+    unpooling_indices : Optional[Sequence[int]] (optional, default=None)
+        indices of the transposed convolution layers after which unpooling should be performed.
+        If None, no unpooling will be performed. An index equal to -1 will be understood as an unpooling layer before
+        the first transposed convolution.
+    act : Optional[ActivationParameters] (optional, default=ActFunction.PRELU)
+        the activation function used after a transposed convolution layer, and optionally its arguments.
+        Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n
+        `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`,
+        `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activation functions]
+        (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional
+        arguments for each of them.
+    output_act : Optional[ActivationParameters] (optional, default=None)
+        a potential activation layer applied to the output of the network. Should be passed in the same way as `act`.
+        If None, no last activation will be applied.
+    norm : Optional[ConvNormalizationParameters] (optional, default=NormLayer.INSTANCE)
+        the normalization type used after a transposed convolution layer, and optionally the arguments of the normalization
+        layer. Should be passed as `norm_type` or `(norm_type, parameters)`. If None, no normalization will be
+        performed.\n
+        `norm_type` can be any value in {`batch`, `group`, `instance`, `syncbatch`}. Please refer to PyTorch's
+        [normalization layers](https://pytorch.org/docs/stable/nn.html#normalization-layers) to know the mandatory and
+        optional arguments for each of them.\n
+        Please note that arguments `num_channels`, `num_features` of the normalization layer
+        should not be passed, as they are automatically inferred from the output of the previous layer in the network.
+    dropout : Optional[float] (optional, default=None)
+        dropout ratio. If None, no dropout.
+    bias : bool (optional, default=True)
+        whether to have a bias term in transposed convolutions.
+    adn_ordering : str (optional, default="NDA")
+        order of operations `Activation`, `Dropout` and `Normalization` after a transposed convolutional layer (except the
+        last one).\n
+        For example if "ND" is passed, `Normalization` and then `Dropout` will be performed (without `Activation`).\n
+        Note: ADN will not be applied after the last convolution.
+ + Examples + -------- + >>> ConvDecoder( + in_channels=16, + spatial_dims=2, + channels=[8, 4, 1], + kernel_size=(3, 5), + stride=2, + padding=[1, 0, 0], + output_padding=[0, 0, (1, 2)], + dilation=1, + unpooling=[("upsample", {"scale_factor": 2}), ("upsample", {"size": (32, 32)})], + unpooling_indices=[0, 1], + act="elu", + output_act="relu", + norm=("batch", {"eps": 1e-05}), + dropout=0.1, + bias=True, + adn_ordering="NDA", + ) + ConvDecoder( + (layer0): Convolution( + (conv): ConvTranspose2d(16, 8, kernel_size=(3, 5), stride=(2, 2), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (unpool0): Upsample(scale_factor=2.0, mode='nearest') + (layer1): Convolution( + (conv): ConvTranspose2d(8, 4, kernel_size=(3, 5), stride=(2, 2)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (unpool1): Upsample(size=(32, 32), mode='nearest') + (layer2): Convolution( + (conv): ConvTranspose2d(4, 1, kernel_size=(3, 5), stride=(2, 2), output_padding=(1, 2)) + ) + (output_act): ReLU() + ) + + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + channels: Sequence[int], + kernel_size: ConvParameters = 3, + stride: ConvParameters = 1, + padding: ConvParameters = 0, + output_padding: ConvParameters = 0, + dilation: ConvParameters = 1, + unpooling: Optional[UnpoolingParameters] = ( + UnpoolingLayer.UPSAMPLE, + {"scale_factor": 2}, + ), + unpooling_indices: Optional[Sequence[int]] = None, + act: Optional[ActivationParameters] = ActFunction.PRELU, + output_act: Optional[ActivationParameters] = None, + norm: Optional[ConvNormalizationParameters] = ConvNormLayer.INSTANCE, + dropout: Optional[float] = None, + bias: bool = True, + adn_ordering: str = "NDA", + _input_size: Optional[Sequence[int]] = None, + ) -> None: + super().__init__() + + self._current_size = _input_size if _input_size else None + + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.channels = ensure_tuple(channels) + self.n_layers = len(self.channels) + + self.kernel_size = ensure_list_of_tuples( + kernel_size, self.spatial_dims, self.n_layers, "kernel_size" + ) + self.stride = ensure_list_of_tuples( + stride, self.spatial_dims, self.n_layers, "stride" + ) + self.padding = ensure_list_of_tuples( + padding, self.spatial_dims, self.n_layers, "padding" + ) + self.output_padding = ensure_list_of_tuples( + output_padding, self.spatial_dims, self.n_layers, "output_padding" + ) + self.dilation = ensure_list_of_tuples( + dilation, self.spatial_dims, self.n_layers, "dilation" + ) + + self.unpooling_indices = check_pool_indices(unpooling_indices, self.n_layers) + self.unpooling = self._check_unpool_layers(unpooling) + self.act = act + self.norm = check_norm_layer(norm) + if self.norm == NormLayer.LAYER: + raise ValueError("Layer normalization not implemented in ConvDecoder.") + self.dropout = dropout + self.bias = bias + self.adn_ordering = check_adn_ordering(adn_ordering) + + n_unpoolings = 0 + if self.unpooling and -1 in self.unpooling_indices: + unpooling_layer = self._get_unpool_layer( + self.unpooling[n_unpoolings], n_channels=self.in_channels + ) + self.add_module("init_unpool", unpooling_layer) + n_unpoolings += 1 + + echannel = self.in_channels + for i, (c, k, s, p, o_p, d) in enumerate( + zip( + self.channels, + self.kernel_size, + self.stride, + 
self.padding,
+                self.output_padding,
+                self.dilation,
+            )
+        ):
+            conv_layer = self._get_convtranspose_layer(
+                in_channels=echannel,
+                out_channels=c,
+                kernel_size=k,
+                stride=s,
+                padding=p,
+                output_padding=o_p,
+                dilation=d,
+                is_last=(i == len(channels) - 1),
+            )
+            self.add_module(f"layer{i}", conv_layer)
+            echannel = c  # use the output channel number as the input for the next loop
+            if self.unpooling and i in self.unpooling_indices:
+                unpooling_layer = self._get_unpool_layer(
+                    self.unpooling[n_unpoolings], n_channels=c
+                )
+                self.add_module(f"unpool{i}", unpooling_layer)
+                n_unpoolings += 1
+
+        self.output_act = get_act_layer(output_act) if output_act else None
+
+    @property
+    def final_size(self):
+        """
+        To know the size of an image at the end of the network.
+        """
+        return self._current_size
+
+    @final_size.setter
+    def final_size(self, fct: Callable[[Tuple[int, ...]], Tuple[int, ...]]):
+        """
+        Takes as input the function used to update the current image size.
+        """
+        if self._current_size is not None:
+            self._current_size = fct(self._current_size)
+
+    def _get_convtranspose_layer(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Tuple[int, ...],
+        stride: Tuple[int, ...],
+        padding: Tuple[int, ...],
+        output_padding: Tuple[int, ...],
+        dilation: Tuple[int, ...],
+        is_last: bool,
+    ) -> Convolution:
+        """
+        Gets the parametrized TransposedConvolution-ADN block and updates the current output size.
+        """
+        self.final_size = lambda size: calculate_convtranspose_out_shape(
+            size, kernel_size, stride, padding, output_padding, dilation
+        )
+
+        return Convolution(
+            is_transposed=True,
+            conv_only=is_last,
+            spatial_dims=self.spatial_dims,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            strides=stride,
+            kernel_size=kernel_size,
+            padding=padding,
+            output_padding=output_padding,
+            dilation=dilation,
+            act=self.act,
+            norm=self.norm,
+            dropout=self.dropout,
+            bias=self.bias,
+            adn_ordering=self.adn_ordering,
+        )
+
+    def _get_unpool_layer(
+        self, unpooling: SingleLayerUnpoolingParameters, n_channels: int
+    ) -> nn.Module:
+        """
+        Gets the parametrized unpooling layer and updates the current output size.
+        """
+        unpool_layer = get_unpool_layer(
+            unpooling,
+            spatial_dims=self.spatial_dims,
+            in_channels=n_channels,
+            out_channels=n_channels,
+        )
+        self.final_size = lambda size: calculate_unpool_out_shape(
+            unpool_mode=unpooling[0],
+            in_shape=size,
+            **unpool_layer.__dict__,
+        )
+        return unpool_layer
+
+    @classmethod
+    def _check_single_unpool_layer(
+        cls, unpooling: SingleLayerUnpoolingParameters
+    ) -> SingleLayerUnpoolingParameters:
+        """
+        Checks unpooling arguments for a single unpooling layer.
+        """
+        if not isinstance(unpooling, tuple) or len(unpooling) != 2:
+            raise ValueError(
+                "unpooling must be a double (or a list of doubles) with first the type of unpooling and then the parameters of "
+                f"the unpooling layer in a dict. Got {unpooling}"
+            )
+        _ = UnpoolingLayer(unpooling[0])  # check unpooling mode
+        args = unpooling[1]
+        if not isinstance(args, dict):
+            raise ValueError(
+                f"The arguments of the unpooling layer must be passed in a dict. Got {args}"
+            )
+
+        return unpooling
+
+    def _check_unpool_layers(
+        self, unpooling: UnpoolingParameters
+    ) -> UnpoolingParameters:
+        """
+        Checks argument unpooling.
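+        Accepts None, a single `(unpooling_mode, args)` double (then used for every
+        unpooling index) or a list of such doubles, one per unpooling index.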
+ """ + if unpooling is None: + return unpooling + if isinstance(unpooling, list): + for unpool_layer in unpooling: + self._check_single_unpool_layer(unpool_layer) + if len(unpooling) != len(self.unpooling_indices): + raise ValueError( + "If you pass a list for unpooling, the size of that list must match " + f"the size of unpooling_indices. Got: unpooling={unpooling} and " + f"unpooling_indices={self.unpooling_indices}" + ) + elif isinstance(unpooling, tuple): + self._check_single_unpool_layer(unpooling) + unpooling = (unpooling,) * len(self.unpooling_indices) + else: + raise ValueError( + f"unpooling can be either None, a double (string, dictionary) or a list of such doubles. Got {unpooling}" + ) + + return unpooling diff --git a/clinicadl/monai_networks/nn/conv_encoder.py b/clinicadl/monai_networks/nn/conv_encoder.py new file mode 100644 index 000000000..f3ec66484 --- /dev/null +++ b/clinicadl/monai_networks/nn/conv_encoder.py @@ -0,0 +1,392 @@ +from typing import Callable, List, Optional, Sequence, Tuple + +import numpy as np +import torch.nn as nn +from monai.networks.blocks import Convolution +from monai.networks.layers.utils import get_act_layer, get_pool_layer +from monai.utils.misc import ensure_tuple + +from .layers.utils import ( + ActFunction, + ActivationParameters, + ConvNormalizationParameters, + ConvNormLayer, + ConvParameters, + NormLayer, + PoolingLayer, + PoolingParameters, + SingleLayerPoolingParameters, +) +from .utils import ( + calculate_conv_out_shape, + calculate_pool_out_shape, + check_adn_ordering, + check_norm_layer, + check_pool_indices, + ensure_list_of_tuples, +) + + +class ConvEncoder(nn.Sequential): + """ + Fully convolutional encoder network with convolutional, pooling, normalization, activation + and dropout layers. + + Parameters + ---------- + spatial_dims : int + number of spatial dimensions of the input image. + in_channels : int + number of channels in the input image. + channels : Sequence[int] + sequence of integers stating the output channels of each convolutional layer. Thus, this + parameter also controls the number of convolutional layers. + kernel_size : ConvParameters (optional, default=3) + the kernel size of the convolutional layers. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the kernel sizes for each layer. + The length of the list must be equal to the number of convolutional layers (i.e. `len(channels)`). + stride : ConvParameters (optional, default=1) + the stride of the convolutional layers. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the strides for each layer. + The length of the list must be equal to the number of convolutional layers (i.e. `len(channels)`). + padding : ConvParameters (optional, default=0) + the padding of the convolutional layers. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. 
These values
+        will be used for all the layers.\n
+        If list (of tuples or integers), it will be interpreted as the paddings for each layer.
+        The length of the list must be equal to the number of convolutional layers (i.e. `len(channels)`).
+    dilation : ConvParameters (optional, default=1)
+        the dilation factor of the convolutional layers. Can be an integer, a tuple or a list.\n
+        If integer, the value will be used for all layers and all dimensions.\n
+        If tuple (of integers), it will be interpreted as the values for each dimension. These values
+        will be used for all the layers.\n
+        If list (of tuples or integers), it will be interpreted as the dilations for each layer.
+        The length of the list must be equal to the number of convolutional layers (i.e. `len(channels)`).
+    pooling : Optional[PoolingParameters] (optional, default=(PoolingLayer.MAX, {"kernel_size": 2}))
+        the pooling mode and the arguments of the pooling layer, passed as `(pooling_mode, arguments)`.
+        If None, no pooling will be performed in the network.\n
+        `pooling_mode` can be either `max`, `avg`, `adaptivemax` or `adaptiveavg`. Please refer to PyTorch's [documentation]
+        (https://pytorch.org/docs/stable/nn.html#pooling-layers) to know the mandatory and optional arguments.\n
+        If a list is passed, it will be understood as `(pooling_mode, arguments)` for each pooling layer.
+    pooling_indices : Optional[Sequence[int]] (optional, default=None)
+        indices of the convolutional layers after which pooling should be performed.
+        If None, no pooling will be performed. An index equal to -1 will be understood as a pooling layer before
+        the first convolution.
+    act : Optional[ActivationParameters] (optional, default=ActFunction.PRELU)
+        the activation function used after a convolutional layer, and optionally its arguments.
+        Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n
+        `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`,
+        `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activation functions]
+        (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional
+        arguments for each of them.
+    output_act : Optional[ActivationParameters] (optional, default=None)
+        a potential activation layer applied to the output of the network. Should be passed in the same way as `act`.
+        If None, no last activation will be applied.
+    norm : Optional[ConvNormalizationParameters] (optional, default=NormLayer.INSTANCE)
+        the normalization type used after a convolutional layer, and optionally the arguments of the normalization
+        layer. Should be passed as `norm_type` or `(norm_type, parameters)`. If None, no normalization will be
+        performed.\n
+        `norm_type` can be any value in {`batch`, `group`, `instance`, `syncbatch`}. Please refer to PyTorch's
+        [normalization layers](https://pytorch.org/docs/stable/nn.html#normalization-layers) to know the mandatory and
+        optional arguments for each of them.\n
+        Please note that arguments `num_channels`, `num_features` of the normalization layer
+        should not be passed, as they are automatically inferred from the output of the previous layer in the network.
+    dropout : Optional[float] (optional, default=None)
+        dropout ratio. If None, no dropout.
+    bias : bool (optional, default=True)
+        whether to have a bias term in convolutions.
+ adn_ordering : str (optional, default="NDA") + order of operations `Activation`, `Dropout` and `Normalization` after a convolutional layer (except the last + one). + For example if "ND" is passed, `Normalization` and then `Dropout` will be performed (without `Activation`).\n + Note: ADN will not be applied after the last convolution. + + Examples + -------- + >>> ConvEncoder( + spatial_dims=2, + in_channels=1, + channels=[2, 4, 8], + kernel_size=(3, 5), + stride=1, + padding=[1, (0, 1), 0], + dilation=1, + pooling=[("max", {"kernel_size": 2}), ("avg", {"kernel_size": 2})], + pooling_indices=[0, 1], + act="elu", + output_act="relu", + norm=("batch", {"eps": 1e-05}), + dropout=0.1, + bias=True, + adn_ordering="NDA", + ) + ConvEncoder( + (layer0): Convolution( + (conv): Conv2d(1, 2, kernel_size=(3, 5), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (pool0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) + (layer1): Convolution( + (conv): Conv2d(2, 4, kernel_size=(3, 5), stride=(1, 1), padding=(0, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0) + (layer2): Convolution( + (conv): Conv2d(4, 8, kernel_size=(3, 5), stride=(1, 1)) + ) + (output_act): ReLU() + ) + + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + channels: Sequence[int], + kernel_size: ConvParameters = 3, + stride: ConvParameters = 1, + padding: ConvParameters = 0, + dilation: ConvParameters = 1, + pooling: Optional[PoolingParameters] = ( + PoolingLayer.MAX, + {"kernel_size": 2}, + ), + pooling_indices: Optional[Sequence[int]] = None, + act: Optional[ActivationParameters] = ActFunction.PRELU, + output_act: Optional[ActivationParameters] = None, + norm: Optional[ConvNormalizationParameters] = ConvNormLayer.INSTANCE, + dropout: Optional[float] = None, + bias: bool = True, + adn_ordering: str = "NDA", + _input_size: Optional[Sequence[int]] = None, + ) -> None: + super().__init__() + + self._current_size = _input_size if _input_size else None + self._size_details = [self._current_size] if _input_size else None + + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.channels = ensure_tuple(channels) + self.n_layers = len(self.channels) + + self.kernel_size = ensure_list_of_tuples( + kernel_size, self.spatial_dims, self.n_layers, "kernel_size" + ) + self.stride = ensure_list_of_tuples( + stride, self.spatial_dims, self.n_layers, "stride" + ) + self.padding = ensure_list_of_tuples( + padding, self.spatial_dims, self.n_layers, "padding" + ) + self.dilation = ensure_list_of_tuples( + dilation, self.spatial_dims, self.n_layers, "dilation" + ) + + self.pooling_indices = check_pool_indices(pooling_indices, self.n_layers) + self.pooling = self._check_pool_layers(pooling) + self.act = act + self.norm = check_norm_layer(norm) + if self.norm == NormLayer.LAYER: + raise ValueError("Layer normalization not implemented in ConvEncoder.") + self.dropout = dropout + self.bias = bias + self.adn_ordering = check_adn_ordering(adn_ordering) + + n_poolings = 0 + if self.pooling and -1 in self.pooling_indices: + pooling_layer = self._get_pool_layer(self.pooling[n_poolings]) + self.add_module("init_pool", pooling_layer) + n_poolings += 1 + + echannel = 
self.in_channels + for i, (c, k, s, p, d) in enumerate( + zip( + self.channels, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + ) + ): + conv_layer = self._get_conv_layer( + in_channels=echannel, + out_channels=c, + kernel_size=k, + stride=s, + padding=p, + dilation=d, + is_last=(i == len(channels) - 1), + ) + self.add_module(f"layer{i}", conv_layer) + echannel = c # use the output channel number as the input for the next loop + if self.pooling and i in self.pooling_indices: + pooling_layer = self._get_pool_layer(self.pooling[n_poolings]) + self.add_module(f"pool{i}", pooling_layer) + n_poolings += 1 + + self.output_act = get_act_layer(output_act) if output_act else None + + @property + def final_size(self): + """ + To know the size of an image at the end of the network. + """ + return self._current_size + + @property + def size_details(self): + """ + To know the sizes of intermediate images. + """ + return self._size_details + + @final_size.setter + def final_size(self, fct: Callable[[Tuple[int, ...]], Tuple[int, ...]]): + """ + Takes as input the function used to update the current image size. + """ + if self._current_size is not None: + self._current_size = fct(self._current_size) + self._size_details.append(self._current_size) + self._check_size() + + def _get_conv_layer( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, ...], + stride: Tuple[int, ...], + padding: Tuple[int, ...], + dilation: Tuple[int, ...], + is_last: bool, + ) -> Convolution: + """ + Gets the parametrized Convolution-ADN block and updates the current output size. + """ + self.final_size = lambda size: calculate_conv_out_shape( + size, kernel_size, stride, padding, dilation + ) + + return Convolution( + conv_only=is_last, + spatial_dims=self.spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + strides=stride, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + act=self.act, + norm=self.norm, + dropout=self.dropout, + bias=self.bias, + adn_ordering=self.adn_ordering, + ) + + def _get_pool_layer(self, pooling: SingleLayerPoolingParameters) -> nn.Module: + """ + Gets the parametrized pooling layer and updates the current output size. + """ + pool_layer = get_pool_layer(pooling, spatial_dims=self.spatial_dims) + old_size = self.final_size + self.final_size = lambda size: calculate_pool_out_shape( + pool_mode=pooling[0], in_shape=size, **pool_layer.__dict__ + ) + + if ( + self.final_size is not None + and (np.array(old_size) < np.array(self.final_size)).any() + ): + raise ValueError( + f"You passed {pooling} as a pooling layer. But before this layer, the size of the image " + f"was {old_size}. So, pooling can't be performed." + ) + + return pool_layer + + def _check_size(self) -> None: + """ + Checks that image size never reaches 0. + """ + if self._current_size is not None and (np.array(self._current_size) <= 0).any(): + raise ValueError( + f"Failed to build the network. An image of size 0 or less has been reached. Stopped at:\n {self}" + ) + + @classmethod + def _check_single_pool_layer( + cls, pooling: SingleLayerPoolingParameters + ) -> SingleLayerPoolingParameters: + """ + Checks pooling arguments for a single pooling layer. + """ + if not isinstance(pooling, tuple) or len(pooling) != 2: + raise ValueError( + "pooling must be a double (or a list of doubles) with first the type of pooling and then the parameters " + f"of the pooling layer in a dict. 
Got {pooling}" + ) + pooling_type = PoolingLayer(pooling[0]) + args = pooling[1] + if not isinstance(args, dict): + raise ValueError( + f"The arguments of the pooling layer must be passed in a dict. Got {args}" + ) + if ( + pooling_type == PoolingLayer.MAX or pooling_type == PoolingLayer.AVG + ) and "kernel_size" not in args: + raise ValueError( + f"For {pooling_type} pooling mode, `kernel_size` argument must be passed. " + f"Got {args}" + ) + elif ( + pooling_type == PoolingLayer.ADAPT_AVG + or pooling_type == PoolingLayer.ADAPT_MAX + ) and "output_size" not in args: + raise ValueError( + f"For {pooling_type} pooling mode, `output_size` argument must be passed. " + f"Got {args}" + ) + + def _check_pool_layers( + self, pooling: PoolingParameters + ) -> List[SingleLayerPoolingParameters]: + """ + Check argument pooling. + """ + if pooling is None: + return pooling + if isinstance(pooling, list): + for pool_layer in pooling: + self._check_single_pool_layer(pool_layer) + if len(pooling) != len(self.pooling_indices): + raise ValueError( + "If you pass a list for pooling, the size of that list must match " + f"the size of pooling_indices. Got: pooling={pooling} and " + f"pooling_indices={self.pooling_indices}" + ) + elif isinstance(pooling, tuple): + self._check_single_pool_layer(pooling) + pooling = [pooling] * len(self.pooling_indices) + else: + raise ValueError( + f"pooling can be either None, a double (string, dictionary) or a list of such doubles. Got {pooling}" + ) + + return pooling diff --git a/clinicadl/monai_networks/nn/densenet.py b/clinicadl/monai_networks/nn/densenet.py new file mode 100644 index 000000000..45d99cc71 --- /dev/null +++ b/clinicadl/monai_networks/nn/densenet.py @@ -0,0 +1,312 @@ +import re +from collections import OrderedDict +from enum import Enum +from typing import Any, Mapping, Optional, Sequence, Union + +import torch.nn as nn +from monai.networks.layers.utils import get_act_layer +from monai.networks.nets import DenseNet as BaseDenseNet +from torch.hub import load_state_dict_from_url +from torchvision.models.densenet import ( + DenseNet121_Weights, + DenseNet161_Weights, + DenseNet169_Weights, + DenseNet201_Weights, +) + +from .layers.utils import ActivationParameters + + +class DenseNet(nn.Sequential): + """ + DenseNet based on the [Densely Connected Convolutional Networks](https://arxiv.org/pdf/1608.06993) paper. + Adapted from [MONAI's implementation](https://docs.monai.io/en/stable/networks.html#densenet). + + The user can customize the number of dense blocks, the number of dense layers in each block, as well as + other parameters like the growth rate. + + DenseNet is a fully convolutional network that can work with input of any size, provided that is it large + enough not to be reduced to a 1-pixel image (before the adaptative average pooling). + + Parameters + ---------- + spatial_dims : int + number of spatial dimensions of the input image. + in_channels : int + number of channels in the input image. + num_outputs : Optional[int] + number of output variables after the last linear layer.\n + If None, the features before the last fully connected layer will be returned. + n_dense_layers : Sequence[int] (optional, default=(6, 12, 24, 16)) + number of dense layers in each dense block. Thus, this parameter also defines the number of dense blocks. + Default is set to DenseNet-121 parameter. + init_features : int (optional, default=64) + number of feature maps after the initial convolution. Default is set to 64, as in the original paper. 
+ growth_rate : int (optional, default=32) + how many feature maps to add at each dense layer. Default is set to 32, as in the original paper. + bottleneck_factor : int (optional, default=4) + multiplicative factor for bottleneck layers (1x1 convolutions). The output of of these bottleneck layers will + have `bottleneck_factor * growth_rate` feature maps. Default is 4, as in the original paper. + act : ActivationParameters (optional, default=("relu", {"inplace": True})) + the activation function used in the convolutional part, and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them.\n + Default is "relu", as in the original paper. + output_act : Optional[ActivationParameters] (optional, default=None) + if `num_outputs` is not None, a potential activation layer applied to the outputs of the network. + Should be pass in the same way as `act`. + If None, no last activation will be applied. + dropout : Optional[float] (optional, default=None) + dropout ratio. If None, no dropout. + + Examples + -------- + >>> DenseNet(spatial_dims=2, in_channels=1, num_outputs=2, output_act="softmax", n_dense_layers=(2, 2)) + DenseNet( + (features): Sequential( + (conv0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act0): ReLU(inplace=True) + (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + (denseblock1): _DenseBlock( + (denselayer1): _DenseLayer( + (layers): Sequential( + (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act1): ReLU(inplace=True) + (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act2): ReLU(inplace=True) + (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + ) + ) + (denselayer2): _DenseLayer( + (layers): Sequential( + (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act1): ReLU(inplace=True) + (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act2): ReLU(inplace=True) + (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + ) + ) + ) + (transition1): _Transition( + (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act): ReLU(inplace=True) + (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False) + (pool): AvgPool2d(kernel_size=2, stride=2, padding=0) + ) + (denseblock2): _DenseBlock( + (denselayer1): _DenseLayer( + (layers): Sequential( + (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act1): ReLU(inplace=True) + (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act2): ReLU(inplace=True) + (conv2): Conv2d(128, 32, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1), bias=False) + ) + ) + (denselayer2): _DenseLayer( + (layers): Sequential( + (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act1): ReLU(inplace=True) + (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act2): ReLU(inplace=True) + (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + ) + ) + ) + (norm5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (fc): Sequential( + (act): ReLU(inplace=True) + (pool): AdaptiveAvgPool2d(output_size=1) + (flatten): Flatten(start_dim=1, end_dim=-1) + (out): Linear(in_features=128, out_features=2, bias=True) + (output_act): Softmax(dim=None) + ) + ) + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_outputs: Optional[int], + n_dense_layers: Sequence[int] = (6, 12, 24, 16), + init_features: int = 64, + growth_rate: int = 32, + bottleneck_factor: int = 4, + act: ActivationParameters = ("relu", {"inplace": True}), + output_act: Optional[ActivationParameters] = None, + dropout: Optional[float] = None, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.num_outputs = num_outputs + self.n_dense_layers = n_dense_layers + self.init_features = init_features + self.growth_rate = growth_rate + self.bottleneck_factor = bottleneck_factor + self.act = act + self.dropout = dropout + + base_densenet = BaseDenseNet( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_outputs if num_outputs else 1, + init_features=init_features, + growth_rate=growth_rate, + block_config=n_dense_layers, + bn_size=bottleneck_factor, + act=act, + dropout_prob=dropout if dropout else 0.0, + ) + self.features = base_densenet.features + self.fc = base_densenet.class_layers if num_outputs else None + if self.fc: + self.fc.output_act = get_act_layer(output_act) if output_act else None + + self._rename_act(self) + + @classmethod + def _rename_act(cls, module: nn.Module) -> None: + """ + Rename activation layers from 'relu' to 'act'. + """ + for name, layer in list(module.named_children()): + if "relu" in name: + module._modules = OrderedDict( # pylint: disable=protected-access + [ + (key.replace("relu", "act"), sub_m) + for key, sub_m in module._modules.items() # pylint: disable=protected-access + ] + ) + else: + cls._rename_act(layer) + + +class SOTADenseNet(str, Enum): + """Supported DenseNet networks.""" + + DENSENET_121 = "DenseNet-121" + DENSENET_161 = "DenseNet-161" + DENSENET_169 = "DenseNet-169" + DENSENET_201 = "DenseNet-201" + + +def get_densenet( + name: Union[str, SOTADenseNet], + num_outputs: Optional[int], + output_act: ActivationParameters = None, + pretrained: bool = False, +) -> DenseNet: + """ + To get a DenseNet implemented in the [Densely Connected Convolutional Networks](https://arxiv.org/pdf/1608.06993) + paper. + + Only the last fully connected layer will be changed to match `num_outputs`. + + The user can also use the pretrained models from `torchvision`. Note that the last fully connected layer will not + used pretrained weights, as it is task specific. + + .. warning:: `DenseNet-121`, `DenseNet-161`, `DenseNet-169` and `DenseNet-201` only works with 2D images with 3 channels. 
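Before moving on to the pretrained variants, a minimal usage sketch of the `DenseNet` class defined above (the hyperparameters are illustrative; shapes in the comments are indicative):

    import torch
    from clinicadl.monai_networks.nn.densenet import DenseNet

    # a small 3D classifier: two dense blocks with two dense layers each
    net = DenseNet(spatial_dims=3, in_channels=1, num_outputs=2,
                   n_dense_layers=(2, 2), init_features=8, growth_rate=4)
    x = torch.randn(1, 1, 32, 32, 32)
    print(net(x).shape)  # expected: torch.Size([1, 2])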
+ + Notes: `torchvision` does not provide an implementation for `DenseNet-264` but provides a `DenseNet-161` that is not + mentioned in the paper. + + Parameters + ---------- + name : Union[str, SOTADenseNet] + The name of the DenseNet. Available networks are `DenseNet-121`, `DenseNet-161`, `DenseNet-169` and `DenseNet-201`. + num_outputs : Optional[int] + number of output variables after the last linear layer.\n + If None, the features before the last fully connected layer will be returned. + output_act : ActivationParameters (optional, default=None) + if `num_outputs` is not None, a potential activation layer applied to the outputs of the network, + and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them. + pretrained : bool (optional, default=False) + whether to use pretrained weights. The pretrained weights used are the default ones from [torchvision](https:// + pytorch.org/vision/main/models/densenet.html). + + Returns + ------- + DenseNet + The network, with potentially pretrained weights. + """ + name = SOTADenseNet(name) + if name == SOTADenseNet.DENSENET_121: + n_dense_layers = (6, 12, 24, 16) + growth_rate = 32 + init_features = 64 + model_url = DenseNet121_Weights.DEFAULT.url + elif name == SOTADenseNet.DENSENET_161: + n_dense_layers = (6, 12, 36, 24) + growth_rate = 48 + init_features = 96 + model_url = DenseNet161_Weights.DEFAULT.url + elif name == SOTADenseNet.DENSENET_169: + n_dense_layers = (6, 12, 32, 32) + growth_rate = 32 + init_features = 64 + model_url = DenseNet169_Weights.DEFAULT.url + elif name == SOTADenseNet.DENSENET_201: + n_dense_layers = (6, 12, 48, 32) + growth_rate = 32 + init_features = 64 + model_url = DenseNet201_Weights.DEFAULT.url + + # pylint: disable=possibly-used-before-assignment + densenet = DenseNet( + spatial_dims=2, + in_channels=3, + num_outputs=num_outputs, + n_dense_layers=n_dense_layers, + growth_rate=growth_rate, + init_features=init_features, + output_act=output_act, + ) + if not pretrained: + return densenet + + pretrained_dict = load_state_dict_from_url(model_url, progress=True) + features_state_dict = { + k.replace("features.", ""): v + for k, v in pretrained_dict.items() + if "classifier" not in k + } + densenet.features.load_state_dict(_state_dict_adapter(features_state_dict)) + + return densenet + + +def _state_dict_adapter(state_dict: Mapping[str, Any]) -> Mapping[str, Any]: + """ + To update the old nomenclature in the pretrained state dict. + Adapted from `_load_state_dict` in [torchvision.models.densenet](https://pytorch.org/vision/main + /_modules/torchvision/models/densenet.html). 
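A usage sketch for `get_densenet` (with `pretrained=True`, torchvision's ImageNet weights are downloaded for the convolutional part, while the fully connected head is freshly initialized):

    import torch
    from clinicadl.monai_networks.nn.densenet import get_densenet

    model = get_densenet("DenseNet-121", num_outputs=2, output_act="softmax", pretrained=True)
    x = torch.randn(1, 3, 224, 224)  # pretrained variants expect 2D, 3-channel inputs
    print(model(x).shape)  # expected: torch.Size([1, 2])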
+ """ + pattern = re.compile( + r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + ) + + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + new_key = re.sub(r"^(.*denselayer\d+)\.", r"\1.layers.", new_key) + state_dict[new_key] = state_dict[key] + del state_dict[key] + + return state_dict diff --git a/clinicadl/monai_networks/nn/generator.py b/clinicadl/monai_networks/nn/generator.py new file mode 100644 index 000000000..5f68a2e58 --- /dev/null +++ b/clinicadl/monai_networks/nn/generator.py @@ -0,0 +1,131 @@ +from typing import Any, Dict, Optional, Sequence + +import numpy as np +import torch.nn as nn +from monai.networks.layers.simplelayers import Reshape + +from .conv_decoder import ConvDecoder +from .mlp import MLP +from .utils import check_conv_args, check_mlp_args + + +class Generator(nn.Sequential): + """ + A generator with first fully connected layers and then convolutional layers. + + This network is a simple aggregation of a Multi Layer Perceptron (:py:class: + `clinicadl.monai_networks.nn.mlp.MLP`) and a Fully Convolutional Network + (:py:class:`clinicadl.monai_networks.nn.conv_decoder.ConvDecoder`). + + Parameters + ---------- + latent_size : int + size of the latent vector. + start_shape : Sequence[int] + sequence of integers stating the initial shape of the image, i.e. the shape at the + beginning of the convolutional part (minus batch dimension, but including the number + of channels).\n + Thus, `start_shape` determines the dimension of the output of the generator (the exact + shape depends on the convolutional part and can be accessed via the class attribute + `output_shape`). + conv_args : Dict[str, Any] + the arguments for the convolutional part. The arguments are those accepted by + :py:class:`clinicadl.monai_networks.nn.conv_decoder.ConvDecoder`, except `in_shape` that + is specified here via `start_shape`. So, the only mandatory argument is `channels`. + mlp_args : Optional[Dict[str, Any]] (optional, default=None) + the arguments for the MLP part. The arguments are those accepted by + :py:class:`clinicadl.monai_networks.nn.mlp.MLP`, except `in_channels` that is specified + here via `latent_size`, and `out_channels` that is inferred from `start_shape`. + So, the only mandatory argument is `hidden_channels`.\n + If None, the MLP part will be reduced to a single linear layer. 
+ + Examples + -------- + >>> Generator( + latent_size=8, + start_shape=(8, 2, 2), + conv_args={"channels": [4, 2], "norm": None, "act": None}, + mlp_args={"hidden_channels": [16], "act": "elu", "norm": None}, + ) + Generator( + (mlp): MLP( + (flatten): Flatten(start_dim=1, end_dim=-1) + (hidden0): Sequential( + (linear): Linear(in_features=8, out_features=16, bias=True) + (adn): ADN( + (A): ELU(alpha=1.0) + ) + ) + (output): Linear(in_features=16, out_features=32, bias=True) + ) + (reshape): Reshape() + (convolutions): ConvDecoder( + (layer0): Convolution( + (conv): ConvTranspose2d(8, 4, kernel_size=(3, 3), stride=(1, 1)) + ) + (layer1): Convolution( + (conv): ConvTranspose2d(4, 2, kernel_size=(3, 3), stride=(1, 1)) + ) + ) + ) + + >>> Generator( + latent_size=8, + start_shape=(8, 2, 2), + conv_args={"channels": [4, 2], "norm": None, "act": None, "output_act": "relu"}, + ) + Generator( + (mlp): MLP( + (flatten): Flatten(start_dim=1, end_dim=-1) + (output): Linear(in_features=8, out_features=32, bias=True) + ) + (reshape): Reshape() + (convolutions): ConvDecoder( + (layer0): Convolution( + (conv): ConvTranspose2d(8, 4, kernel_size=(3, 3), stride=(1, 1)) + ) + (layer1): Convolution( + (conv): ConvTranspose2d(4, 2, kernel_size=(3, 3), stride=(1, 1)) + ) + (output_act): ReLU() + ) + ) + """ + + def __init__( + self, + latent_size: int, + start_shape: Sequence[int], + conv_args: Dict[str, Any], + mlp_args: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + check_conv_args(conv_args) + check_mlp_args(mlp_args) + self.latent_size = latent_size + self.start_shape = start_shape + + flatten_shape = int(np.prod(start_shape)) + if mlp_args is None: + mlp_args = {"hidden_channels": []} + self.mlp = MLP( + in_channels=latent_size, + out_channels=flatten_shape, + **mlp_args, + ) + + self.reshape = Reshape(*start_shape) + inter_channels, *inter_size = start_shape + self.convolutions = ConvDecoder( + in_channels=inter_channels, + spatial_dims=len(inter_size), + _input_size=inter_size, + **conv_args, + ) + + n_channels = ( + conv_args["channels"][-1] + if len(conv_args["channels"]) > 0 + else start_shape[0] + ) + self.output_shape = (n_channels, *self.convolutions.final_size) diff --git a/clinicadl/monai_networks/nn/layers/__init__.py b/clinicadl/monai_networks/nn/layers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/clinicadl/monai_networks/nn/layers/resnet.py b/clinicadl/monai_networks/nn/layers/resnet.py new file mode 100644 index 000000000..c115da512 --- /dev/null +++ b/clinicadl/monai_networks/nn/layers/resnet.py @@ -0,0 +1,124 @@ +from collections.abc import Callable +from typing import Optional + +import torch +import torch.nn as nn +from monai.networks.layers.factories import Conv, Norm +from monai.networks.layers.utils import get_act_layer + +from .utils import ActivationParameters + + +class ResNetBlock(nn.Module): + """ + ResNet basic block. 
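A minimal sketch of the shape bookkeeping done by `Generator` (the output sizes assume `ConvDecoder`'s default kernel size and stride, as in the docstring examples above):

    import torch
    from clinicadl.monai_networks.nn.generator import Generator

    gen = Generator(
        latent_size=8,
        start_shape=(8, 2, 2),
        conv_args={"channels": [4, 2], "norm": None, "act": None},
    )
    print(gen.output_shape)  # (2, 6, 6): two transposed 3x3 convolutions grow 2 -> 4 -> 6
    z = torch.randn(3, 8)
    print(gen(z).shape)      # torch.Size([3, 2, 6, 6])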
Adapted from MONAI's implementation: + https://github.com/Project-MONAI/MONAI/blob/59a7211070538586369afd4a01eca0a7fe2e742e/ + monai/networks/nets/resnet.py#L71 + """ + + expansion = 1 + + def __init__( + self, + in_planes: int, + planes: int, + spatial_dims: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + act: ActivationParameters = ("relu", {"inplace": True}), + ) -> None: + super().__init__() + + conv_type: Callable = Conv[Conv.CONV, spatial_dims] + norm_type: Callable = Norm[Norm.BATCH, spatial_dims] + + self.conv1 = conv_type( # pylint: disable=not-callable + in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False + ) + self.norm1 = norm_type(planes) # pylint: disable=not-callable + self.act1 = get_act_layer(name=act) + self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False) # pylint: disable=not-callable + self.norm2 = norm_type(planes) # pylint: disable=not-callable + self.downsample = downsample + self.act2 = get_act_layer(name=act) + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + out: torch.Tensor = self.conv1(x) + out = self.norm1(out) + out = self.act1(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.act2(out) + + return out + + +class ResNetBottleneck(nn.Module): + """ + ResNet bottleneck block. Adapted from MONAI's implementation: + https://github.com/Project-MONAI/MONAI/blob/59a7211070538586369afd4a01eca0a7fe2e742e/ + monai/networks/nets/resnet.py#L124 + """ + + expansion = 4 + + def __init__( + self, + in_planes: int, + planes: int, + spatial_dims: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + act: ActivationParameters = ("relu", {"inplace": True}), + ) -> None: + super().__init__() + + conv_type: Callable = Conv[Conv.CONV, spatial_dims] + norm_type: Callable = Norm[Norm.BATCH, spatial_dims] + + self.conv1 = conv_type(in_planes, planes, kernel_size=1, bias=False) # pylint: disable=not-callable + self.norm1 = norm_type(planes) # pylint: disable=not-callable + self.act1 = get_act_layer(name=act) + self.conv2 = conv_type( # pylint: disable=not-callable + planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) + self.norm2 = norm_type(planes) # pylint: disable=not-callable + self.act2 = get_act_layer(name=act) + self.conv3 = conv_type( # pylint: disable=not-callable + planes, planes * self.expansion, kernel_size=1, bias=False + ) + self.norm3 = norm_type(planes * self.expansion) # pylint: disable=not-callable + self.downsample = downsample + self.act3 = get_act_layer(name=act) + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + out: torch.Tensor = self.conv1(x) + out = self.norm1(out) + out = self.act1(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.act2(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.act3(out) + + return out diff --git a/clinicadl/monai_networks/nn/layers/senet.py b/clinicadl/monai_networks/nn/layers/senet.py new file mode 100644 index 000000000..8847ef577 --- /dev/null +++ b/clinicadl/monai_networks/nn/layers/senet.py @@ -0,0 +1,142 @@ +from typing import Callable, Optional + +import torch +import torch.nn as nn +from monai.networks.blocks.squeeze_and_excitation import ChannelSELayer +from monai.networks.layers.factories import Conv, Norm +from 
monai.networks.layers.utils import get_act_layer + +from .utils import ActivationParameters + + +class SEResNetBlock(nn.Module): + """ + ResNet basic block. Adapted from MONAI's ResNetBlock: + https://github.com/Project-MONAI/MONAI/blob/59a7211070538586369afd4a01eca0a7fe2e742e/ + monai/networks/nets/resnet.py#L71 + """ + + expansion = 1 + reduction = 16 + + def __init__( + self, + in_planes: int, + planes: int, + spatial_dims: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + act: ActivationParameters = ("relu", {"inplace": True}), + ) -> None: + super().__init__() + + conv_type: Callable = Conv[Conv.CONV, spatial_dims] + norm_type: Callable = Norm[Norm.BATCH, spatial_dims] + + self.conv1 = conv_type( # pylint: disable=not-callable + in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False + ) + self.norm1 = norm_type(planes) # pylint: disable=not-callable + self.act1 = get_act_layer(name=act) + self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False) # pylint: disable=not-callable + self.norm2 = norm_type(planes) # pylint: disable=not-callable + self.se_layer = ChannelSELayer( + spatial_dims=spatial_dims, + in_channels=planes, + r=self.reduction, + acti_type_1=("relu", {"inplace": True}), + acti_type_2="sigmoid", + ) + self.downsample = downsample + self.act2 = get_act_layer(name=act) + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.act1(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = self.se_layer(out) + out += residual + out = self.act2(out) + + return out + + +class SEResNetBottleneck(nn.Module): + """ + ResNet bottleneck block. 
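A minimal sketch of how these residual blocks are used: when `stride != 1` or the channel counts differ, the caller provides the projection that keeps the residual addition shape-consistent (values below are illustrative):

    import torch
    import torch.nn as nn
    from clinicadl.monai_networks.nn.layers.resnet import ResNetBlock

    downsample = nn.Sequential(
        nn.Conv2d(8, 16, kernel_size=1, stride=2, bias=False),
        nn.BatchNorm2d(16),
    )
    block = ResNetBlock(in_planes=8, planes=16, spatial_dims=2, stride=2, downsample=downsample)
    x = torch.randn(1, 8, 32, 32)
    print(block(x).shape)  # torch.Size([1, 16, 16, 16])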
Adapted from MONAI's ResNetBottleneck:
+    https://github.com/Project-MONAI/MONAI/blob/59a7211070538586369afd4a01eca0a7fe2e742e/
+    monai/networks/nets/resnet.py#L124
+    """
+
+    expansion = 4
+    reduction = 16
+
+    def __init__(
+        self,
+        in_planes: int,
+        planes: int,
+        spatial_dims: int,
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+        act: ActivationParameters = ("relu", {"inplace": True}),
+    ) -> None:
+        super().__init__()
+
+        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
+        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
+
+        self.conv1 = conv_type(in_planes, planes, kernel_size=1, bias=False)  # pylint: disable=not-callable
+        self.norm1 = norm_type(planes)  # pylint: disable=not-callable
+        self.act1 = get_act_layer(name=act)
+        self.conv2 = conv_type(  # pylint: disable=not-callable
+            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
+        )
+        self.norm2 = norm_type(planes)  # pylint: disable=not-callable
+        self.act2 = get_act_layer(name=act)
+        self.conv3 = conv_type(  # pylint: disable=not-callable
+            planes, planes * self.expansion, kernel_size=1, bias=False
+        )
+        self.norm3 = norm_type(planes * self.expansion)  # pylint: disable=not-callable
+        self.se_layer = ChannelSELayer(
+            spatial_dims=spatial_dims,
+            in_channels=planes * self.expansion,
+            r=self.reduction,
+            acti_type_1=("relu", {"inplace": True}),
+            acti_type_2="sigmoid",
+        )
+        self.downsample = downsample
+        self.act3 = get_act_layer(name=act)
+        self.stride = stride
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        residual = x
+
+        out = self.conv1(x)
+        out = self.norm1(out)
+        out = self.act1(out)
+
+        out = self.conv2(out)
+        out = self.norm2(out)
+        out = self.act2(out)
+
+        out = self.conv3(out)
+        out = self.norm3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out = self.se_layer(out)
+        out += residual
+        out = self.act3(out)
+
+        return out
diff --git a/clinicadl/monai_networks/nn/layers/unet.py b/clinicadl/monai_networks/nn/layers/unet.py
new file mode 100644
index 000000000..2186425be
--- /dev/null
+++ b/clinicadl/monai_networks/nn/layers/unet.py
@@ -0,0 +1,102 @@
+from typing import Optional
+
+import torch.nn as nn
+from monai.networks.blocks.convolutions import Convolution
+from monai.networks.layers.utils import get_pool_layer
+
+from .utils import ActFunction, ActivationParameters, NormLayer
+
+
+class ConvBlock(nn.Sequential):
+    """UNet double convolution block."""
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        act: ActivationParameters = ActFunction.RELU,
+        dropout: Optional[float] = None,
+    ):
+        super().__init__()
+        self.add_module(
+            "0",
+            Convolution(
+                spatial_dims=spatial_dims,
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                strides=1,
+                padding=None,
+                adn_ordering="NDA",
+                act=act,
+                norm=NormLayer.BATCH,
+                dropout=dropout,
+            ),
+        )
+        self.add_module(
+            "1",
+            Convolution(
+                spatial_dims=spatial_dims,
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                strides=1,
+                padding=None,
+                adn_ordering="NDA",
+                act=act,
+                norm=NormLayer.BATCH,
+                dropout=dropout,
+            ),
+        )
+
+
+class UpSample(nn.Sequential):
+    """UNet up-conv block with first upsampling and then a convolution."""
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        act: ActivationParameters = ActFunction.RELU,
+        dropout: Optional[float] = None,
+    ):
+        super().__init__()
+        self.add_module("0", nn.Upsample(scale_factor=2))
+        self.add_module(
+            "1",
+            Convolution(
spatial_dims,
+                in_channels,
+                out_channels,
+                strides=1,
+                kernel_size=3,
+                act=act,
+                adn_ordering="NDA",
+                norm=NormLayer.BATCH,
+                dropout=dropout,
+            ),
+        )
+
+
+class DownBlock(nn.Sequential):
+    """UNet down block with first max pooling and then two convolutions."""
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        act: ActivationParameters = ActFunction.RELU,
+        dropout: Optional[float] = None,
+    ):
+        super().__init__()
+        self.pool = get_pool_layer(("max", {"kernel_size": 2}), spatial_dims)
+        self.doubleconv = ConvBlock(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            act=act,
+            dropout=dropout,
+        )
diff --git a/clinicadl/monai_networks/nn/layers/unpool.py b/clinicadl/monai_networks/nn/layers/unpool.py
new file mode 100644
index 000000000..1c90fde90
--- /dev/null
+++ b/clinicadl/monai_networks/nn/layers/unpool.py
@@ -0,0 +1,87 @@
+from typing import Any, Dict, Optional, Tuple, Type, Union
+
+import torch.nn as nn
+from monai.networks.layers.factories import LayerFactory, split_args
+from monai.utils import has_option
+
+from .utils import UnpoolingLayer
+
+Unpool = LayerFactory(
+    name="Unpooling layers", description="Factory for creating unpooling layers."
+)
+
+
+@Unpool.factory_function("upsample")
+def upsample_factory(dim: int) -> Type[nn.Upsample]:
+    """
+    Upsample layer.
+    """
+    return nn.Upsample
+
+
+@Unpool.factory_function("convtranspose")
+def convtranspose_factory(
+    dim: int,
+) -> Type[Union[nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d]]:
+    """
+    Transposed convolutional layers in 1,2,3 dimensions.
+
+    Parameters
+    ----------
+    dim : int
+        desired dimension of the transposed convolutional layer.
+
+    Returns
+    -------
+    type[Union[nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d]]
+        ConvTranspose[dim]d
+    """
+    types = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
+    return types[dim - 1]
+
+
+def get_unpool_layer(
+    name: Union[UnpoolingLayer, Tuple[UnpoolingLayer, Dict[str, Any]]],
+    spatial_dims: int,
+    in_channels: Optional[int] = None,
+    out_channels: Optional[int] = None,
+) -> nn.Module:
+    """
+    Creates an unpooling layer instance.
+
+    Parameters
+    ----------
+    name : Union[UnpoolingLayer, Tuple[UnpoolingLayer, Dict[str, Any]]]
+        the unpooling type, potentially with arguments in a dict.
+    spatial_dims : int
+        number of spatial dimensions of the input.
+    in_channels : Optional[int] (optional, default=None)
+        number of input channels if the unpool layer requires this parameter.
+    out_channels : Optional[int] (optional, default=None)
+        number of output channels if the unpool layer requires this parameter.
+
+    Returns
+    -------
+    nn.Module
+        the parametrized unpooling layer.
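A usage sketch for `get_unpool_layer` (note that `in_channels`/`out_channels` are only forwarded to layer types that accept them, such as transposed convolutions; values below are illustrative):

    import torch
    from clinicadl.monai_networks.nn.layers.unpool import get_unpool_layer

    up = get_unpool_layer(("upsample", {"scale_factor": 2}), spatial_dims=2)
    convt = get_unpool_layer(
        ("convtranspose", {"kernel_size": 2, "stride": 2}),
        spatial_dims=2, in_channels=4, out_channels=2,
    )
    x = torch.randn(1, 4, 8, 8)
    print(up(x).shape)     # torch.Size([1, 4, 16, 16])
    print(convt(x).shape)  # torch.Size([1, 2, 16, 16])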
+ """ + unpool_name, unpool_args = split_args(name) + unpool_name = UnpoolingLayer(unpool_name) + unpool_type = Unpool[unpool_name, spatial_dims] + kw_args = dict(unpool_args) + if has_option(unpool_type, "in_channels") and "in_channels" not in kw_args: + kw_args["in_channels"] = in_channels + if has_option(unpool_type, "out_channels") and "out_channels" not in kw_args: + kw_args["out_channels"] = out_channels + + return unpool_type(**kw_args) # pylint: disable=not-callable diff --git a/clinicadl/monai_networks/nn/layers/utils/__init__.py b/clinicadl/monai_networks/nn/layers/utils/__init__.py new file mode 100644 index 000000000..5c080fffd --- /dev/null +++ b/clinicadl/monai_networks/nn/layers/utils/__init__.py @@ -0,0 +1,19 @@ +from .enum import ( + ActFunction, + ConvNormLayer, + NormLayer, + PoolingLayer, + UnpoolingLayer, + UnpoolingMode, +) +from .types import ( + ActivationParameters, + ConvNormalizationParameters, + ConvParameters, + NormalizationParameters, + PoolingParameters, + SingleLayerConvParameter, + SingleLayerPoolingParameters, + SingleLayerUnpoolingParameters, + UnpoolingParameters, +) diff --git a/clinicadl/monai_networks/nn/layers/utils/enum.py b/clinicadl/monai_networks/nn/layers/utils/enum.py new file mode 100644 index 000000000..695776551 --- /dev/null +++ b/clinicadl/monai_networks/nn/layers/utils/enum.py @@ -0,0 +1,65 @@ +from clinicadl.utils.enum import CaseInsensitiveEnum + + +class UnpoolingLayer(CaseInsensitiveEnum): + """Supported unpooling layers in ClinicaDL.""" + + CONV_TRANS = "convtranspose" + UPSAMPLE = "upsample" + + +class ActFunction(CaseInsensitiveEnum): + """Supported activation functions in ClinicaDL.""" + + ELU = "elu" + RELU = "relu" + LEAKY_RELU = "leakyrelu" + PRELU = "prelu" + RELU6 = "relu6" + SELU = "selu" + CELU = "celu" + GELU = "gelu" + SIGMOID = "sigmoid" + TANH = "tanh" + SOFTMAX = "softmax" + LOGSOFTMAX = "logsoftmax" + MISH = "mish" + + +class PoolingLayer(CaseInsensitiveEnum): + """Supported pooling layers in ClinicaDL.""" + + MAX = "max" + AVG = "avg" + ADAPT_AVG = "adaptiveavg" + ADAPT_MAX = "adaptivemax" + + +class NormLayer(CaseInsensitiveEnum): + """Supported normalization layers in ClinicaDL.""" + + GROUP = "group" + LAYER = "layer" + SYNCBATCH = "syncbatch" + BATCH = "batch" + INSTANCE = "instance" + + +class ConvNormLayer(CaseInsensitiveEnum): + """Supported normalization layers with convolutions in ClinicaDL.""" + + GROUP = "group" + SYNCBATCH = "syncbatch" + BATCH = "batch" + INSTANCE = "instance" + + +class UnpoolingMode(CaseInsensitiveEnum): + """Supported unpooling mode for AutoEncoders in ClinicaDL.""" + + NEAREST = "nearest" + LINEAR = "linear" + BILINEAR = "bilinear" + BICUBIC = "bicubic" + TRILINEAR = "trilinear" + CONV_TRANS = "convtranspose" diff --git a/clinicadl/monai_networks/nn/layers/utils/types.py b/clinicadl/monai_networks/nn/layers/utils/types.py new file mode 100644 index 000000000..f5ef18847 --- /dev/null +++ b/clinicadl/monai_networks/nn/layers/utils/types.py @@ -0,0 +1,37 @@ +from typing import Any, Dict, List, Tuple, Union + +from .enum import ( + ActFunction, + ConvNormLayer, + NormLayer, + PoolingLayer, + UnpoolingLayer, +) + +SingleLayerConvParameter = Union[int, Tuple[int, ...]] +ConvParameters = Union[SingleLayerConvParameter, List[SingleLayerConvParameter]] + +PoolingType = Union[str, PoolingLayer] +SingleLayerPoolingParameters = Tuple[PoolingType, Dict[str, Any]] +PoolingParameters = Union[ + SingleLayerPoolingParameters, List[SingleLayerPoolingParameters] +] + +UnpoolingType = Union[str, 
UnpoolingLayer] +SingleLayerUnpoolingParameters = Tuple[UnpoolingType, Dict[str, Any]] +UnpoolingParameters = Union[ + SingleLayerUnpoolingParameters, List[SingleLayerUnpoolingParameters] +] + +NormalizationType = Union[str, NormLayer] +NormalizationParameters = Union[ + NormalizationType, Tuple[NormalizationType, Dict[str, Any]] +] + +ConvNormalizationType = Union[str, ConvNormLayer] +ConvNormalizationParameters = Union[ + ConvNormalizationType, Tuple[ConvNormalizationType, Dict[str, Any]] +] + +ActivationType = Union[str, ActFunction] +ActivationParameters = Union[ActivationType, Tuple[ActivationType, Dict[str, Any]]] diff --git a/clinicadl/monai_networks/nn/layers/vit.py b/clinicadl/monai_networks/nn/layers/vit.py new file mode 100644 index 000000000..e485d6c6b --- /dev/null +++ b/clinicadl/monai_networks/nn/layers/vit.py @@ -0,0 +1,94 @@ +from functools import partial +from typing import Callable, Optional + +import torch +import torch.nn as nn +from torchvision.models.vision_transformer import MLPBlock + + +class EncoderBlock(nn.Module): + """Transformer encoder block.""" + + def __init__( + self, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ) -> None: + super().__init__() + self.num_heads = num_heads + + # Attention block + self.norm1 = norm_layer(hidden_dim) + self.self_attention = nn.MultiheadAttention( + hidden_dim, num_heads, dropout=attention_dropout, batch_first=True + ) + self.dropout = nn.Dropout(dropout) + + # MLP block + self.norm2 = norm_layer(hidden_dim) + self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + x = self.norm1(x) + x, _ = self.self_attention(x, x, x, need_weights=False) + x = self.dropout(x) + x += residual + + y = self.norm2(x) + y = self.mlp(y) + return x + y + + +class Encoder(nn.Module): + """Encoder with multiple transformer blocks.""" + + def __init__( + self, + seq_length: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + pos_embedding: Optional[nn.Parameter] = None, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ) -> None: + super().__init__() + + if pos_embedding is not None: + self.pos_embedding = pos_embedding + else: + self.pos_embedding = nn.Parameter( + torch.empty(1, seq_length, hidden_dim).normal_(std=0.02) + ) # from BERT + self.dropout = nn.Dropout(dropout) + self.layers = nn.ModuleList( + [ + EncoderBlock( + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + norm_layer, + ) + for _ in range(num_layers) + ] + ) + self.norm = norm_layer(hidden_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.pos_embedding + + x = self.dropout(x) + for layer in self.layers: + x = layer(x) + + return self.norm(x) diff --git a/clinicadl/monai_networks/nn/mlp.py b/clinicadl/monai_networks/nn/mlp.py new file mode 100644 index 000000000..a27b2ad4e --- /dev/null +++ b/clinicadl/monai_networks/nn/mlp.py @@ -0,0 +1,146 @@ +from collections import OrderedDict +from typing import Optional, Sequence + +import torch.nn as nn +from monai.networks.blocks import ADN +from monai.networks.layers.utils import get_act_layer +from monai.networks.nets import FullyConnectedNet as BaseMLP + +from .layers.utils import ( + ActFunction, + ActivationParameters, + NormalizationParameters, + NormLayer, +) +from .utils import 
check_adn_ordering, check_norm_layer


class MLP(BaseMLP):
    """Simple fully connected neural network (or Multi-Layer Perceptron) with linear, normalization, activation
    and dropout layers.

    Parameters
    ----------
    in_channels : int
        number of input channels (i.e. number of features).
    out_channels : int
        number of output channels.
    hidden_channels : Sequence[int]
        number of output channels for each hidden layer. Thus, this parameter also controls the number of hidden layers.
    act : Optional[ActivationParameters] (optional, default=ActFunction.PRELU)
        the activation function used after a linear layer, and optionally its arguments.
        Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n
        `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`,
        `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activation functions]
        (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional
        arguments for each of them.
    output_act : Optional[ActivationParameters] (optional, default=None)
        a potential activation layer applied to the output of the network. Should be passed in the same way as `act`.
        If None, no last activation will be applied.
    norm : Optional[NormalizationParameters] (optional, default=NormLayer.BATCH)
        the normalization type used after a linear layer, and optionally the arguments of the normalization
        layer. Should be passed as `norm_type` or `(norm_type, parameters)`. If None, no normalization will be
        performed.\n
        `norm_type` can be any value in {`batch`, `group`, `instance`, `layer`, `syncbatch`}. Please refer to PyTorch's
        [normalization layers](https://pytorch.org/docs/stable/nn.html#normalization-layers) to know the mandatory and
        optional arguments for each of them.\n
        Please note that arguments `num_channels`, `num_features` and `normalized_shape` of the normalization layer
        should not be passed, as they are automatically inferred from the output of the previous layer in the network.
    dropout : Optional[float] (optional, default=None)
        dropout ratio. If None, no dropout.
    bias : bool (optional, default=True)
        whether to have a bias term in linear layers.
    adn_ordering : str (optional, default="NDA")
        order of operations `Activation`, `Dropout` and `Normalization` after a linear layer (except the last one).
        For example, if "ND" is passed, `Normalization` and then `Dropout` will be performed (without `Activation`).\n
        Note: ADN will not be applied after the last linear layer.
+ + Examples + -------- + >>> MLP(in_channels=12, out_channels=2, hidden_channels=[8, 4], dropout=0.1, act=("elu", {"alpha": 0.5}), + norm=("group", {"num_groups": 2}), bias=True, adn_ordering="ADN", output_act="softmax") + MLP( + (flatten): Flatten(start_dim=1, end_dim=-1) + (hidden0): Sequential( + (linear): Linear(in_features=12, out_features=8, bias=True) + (adn): ADN( + (A): ELU(alpha=0.5) + (D): Dropout(p=0.1, inplace=False) + (N): GroupNorm(2, 8, eps=1e-05, affine=True) + ) + ) + (hidden1): Sequential( + (linear): Linear(in_features=8, out_features=4, bias=True) + (adn): ADN( + (A): ELU(alpha=0.5) + (D): Dropout(p=0.1, inplace=False) + (N): GroupNorm(2, 4, eps=1e-05, affine=True) + ) + ) + (output): Sequential( + (linear): Linear(in_features=4, out_features=2, bias=True) + (output_act): Softmax(dim=None) + ) + ) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: Sequence[int], + act: Optional[ActivationParameters] = ActFunction.PRELU, + output_act: Optional[ActivationParameters] = None, + norm: Optional[NormalizationParameters] = NormLayer.BATCH, + dropout: Optional[float] = None, + bias: bool = True, + adn_ordering: str = "NDA", + ) -> None: + self.norm = check_norm_layer(norm) + super().__init__( + in_channels, + out_channels, + hidden_channels, + dropout, + act, + bias, + check_adn_ordering(adn_ordering), + ) + self.output = nn.Sequential(OrderedDict([("linear", self.output)])) + self.output.output_act = get_act_layer(output_act) if output_act else None + # renaming + self._modules = OrderedDict( + [ + (key.replace("hidden_", "hidden"), sub_m) + for key, sub_m in self._modules.items() + ] + ) + + def _get_layer(self, in_channels: int, out_channels: int, bias: bool) -> nn.Module: + """ + Gets the parametrized Linear layer + ADN block. 
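A minimal usage sketch for `MLP` (with `norm="layer"`, the `_get_layer` method below automatically passes `normalized_shape` to each `LayerNorm`; values are illustrative):

    import torch
    from clinicadl.monai_networks.nn.mlp import MLP

    mlp = MLP(in_channels=10, out_channels=3, hidden_channels=[16, 8],
              act="relu", norm="layer", dropout=0.1, output_act=("softmax", {"dim": 1}))
    x = torch.randn(4, 10)
    print(mlp(x).shape)  # torch.Size([4, 3])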
+ """ + if self.norm == NormLayer.LAYER: + norm = ("layer", {"normalized_shape": out_channels}) + else: + norm = self.norm + seq = nn.Sequential( + OrderedDict( + [ + ("linear", nn.Linear(in_channels, out_channels, bias)), + ( + "adn", + ADN( + ordering=self.adn_ordering, + act=self.act, + norm=norm, + dropout=self.dropout, + dropout_dim=1, + in_channels=out_channels, + ), + ), + ] + ) + ) + return seq diff --git a/clinicadl/monai_networks/nn/resnet.py b/clinicadl/monai_networks/nn/resnet.py new file mode 100644 index 000000000..1ba90b30c --- /dev/null +++ b/clinicadl/monai_networks/nn/resnet.py @@ -0,0 +1,566 @@ +import re +from collections import OrderedDict +from copy import deepcopy +from enum import Enum +from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union + +import torch +import torch.nn as nn +from monai.networks.layers.factories import Conv, Norm, Pool +from monai.networks.layers.utils import get_act_layer +from monai.utils import ensure_tuple_rep +from torch.hub import load_state_dict_from_url +from torchvision.models.resnet import ( + ResNet18_Weights, + ResNet34_Weights, + ResNet50_Weights, + ResNet101_Weights, + ResNet152_Weights, +) + +from .layers.resnet import ResNetBlock, ResNetBottleneck +from .layers.senet import SEResNetBlock, SEResNetBottleneck +from .layers.utils import ActivationParameters + + +class ResNetBlockType(str, Enum): + """Supported ResNet blocks.""" + + BASIC = "basic" + BOTTLENECK = "bottleneck" + + +class GeneralResNet(nn.Module): + """Common base class for ResNet and SEResNet.""" + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_outputs: Optional[int], + block_type: Union[str, ResNetBlockType], + n_res_blocks: Sequence[int], + n_features: Sequence[int], + init_conv_size: Union[Sequence[int], int], + init_conv_stride: Union[Sequence[int], int], + bottleneck_reduction: int, + se_reduction: Optional[int], + act: ActivationParameters, + output_act: ActivationParameters, + ) -> None: + super().__init__() + + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.num_outputs = num_outputs + self.block_type = block_type + self._check_args_consistency(n_res_blocks, n_features) + self.n_res_blocks = n_res_blocks + self.n_features = n_features + self.bottleneck_reduction = bottleneck_reduction + self.se_reduction = se_reduction + self.act = act + self.squeeze_excitation = True if se_reduction else False + + self.init_conv_size = ensure_tuple_rep(init_conv_size, spatial_dims) + self.init_conv_stride = ensure_tuple_rep(init_conv_stride, spatial_dims) + + block, in_planes = self._get_block(block_type) + + conv_type, norm_type, pool_type, avgp_type = self._get_layers() + + block_avgpool = [0, 1, (1, 1), (1, 1, 1)] + + self.in_planes = in_planes[0] + self.n_layers = len(in_planes) + self.bias_downsample = False + + self.conv0 = conv_type( # pylint: disable=not-callable + in_channels, + self.in_planes, + kernel_size=self.init_conv_size, + stride=self.init_conv_stride, + padding=tuple(k // 2 for k in self.init_conv_size), + bias=False, + ) + self.norm0 = norm_type(self.in_planes) # pylint: disable=not-callable + self.act0 = get_act_layer(name=act) + self.pool0 = pool_type(kernel_size=3, stride=2, padding=1) # pylint: disable=not-callable + self.layer1 = self._make_resnet_layer( + block, in_planes[0], n_res_blocks[0], spatial_dims, act + ) + for i, (n_blocks, n_feats) in enumerate( + zip(n_res_blocks[1:], in_planes[1:]), start=2 + ): + self.add_module( + f"layer{i}", + self._make_resnet_layer( + block, + 
planes=n_feats, + blocks=n_blocks, + spatial_dims=spatial_dims, + stride=2, + act=act, + ), + ) + self.fc = ( + nn.Sequential( + OrderedDict( + [ + ("pool", avgp_type(block_avgpool[spatial_dims])), # pylint: disable=not-callable + ("flatten", nn.Flatten(1)), + ("out", nn.Linear(n_features[-1], num_outputs)), + ] + ) + ) + if num_outputs + else None + ) + if self.fc: + self.fc.output_act = get_act_layer(output_act) if output_act else None + + self._init_module(conv_type, norm_type) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv0(x) + x = self.norm0(x) + x = self.act0(x) + x = self.pool0(x) + + for i in range(1, self.n_layers + 1): + x = self.get_submodule(f"layer{i}")(x) + + if self.fc is not None: + x = self.fc(x) + + return x + + def _get_block(self, block_type: Union[str, ResNetBlockType]) -> nn.Module: + """ + Gets the residual block, depending on the block choice made by the user and depending + on whether squeeze-excitation mode or not. + """ + block_type = ResNetBlockType(block_type) + if block_type == ResNetBlockType.BASIC: + in_planes = self.n_features + if self.squeeze_excitation: + block = SEResNetBlock + block.reduction = self.se_reduction + else: + block = ResNetBlock + elif block_type == ResNetBlockType.BOTTLENECK: + in_planes = self._bottleneck_reduce( + self.n_features, self.bottleneck_reduction + ) + if self.squeeze_excitation: + block = SEResNetBottleneck + block.reduction = self.se_reduction + else: + block = ResNetBottleneck + block.expansion = self.bottleneck_reduction + + return block, in_planes + + def _get_layers(self): + """ + Gets convolution, normalization, pooling and adaptative average pooling layers. + """ + conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[ + Conv.CONV, self.spatial_dims + ] + norm_type: Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[ + Norm.BATCH, self.spatial_dims + ] + pool_type: Type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]] = Pool[ + Pool.MAX, self.spatial_dims + ] + avgp_type: Type[ + Union[nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d] + ] = Pool[Pool.ADAPTIVEAVG, self.spatial_dims] + + return conv_type, norm_type, pool_type, avgp_type + + def _make_resnet_layer( + self, + block: Type[Union[ResNetBlock, ResNetBottleneck]], + planes: int, + blocks: int, + spatial_dims: int, + act: ActivationParameters, + stride: int = 1, + ) -> nn.Sequential: + """ + Builds a ResNet layer. + """ + conv_type: Callable = Conv[Conv.CONV, spatial_dims] + norm_type: Callable = Norm[Norm.BATCH, spatial_dims] + + downsample = None + if stride != 1 or self.in_planes != planes * block.expansion: + downsample = nn.Sequential( + conv_type( # pylint: disable=not-callable + self.in_planes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=self.bias_downsample, + ), + norm_type(planes * block.expansion), # pylint: disable=not-callable + ) + + layers = [ + block( + in_planes=self.in_planes, + planes=planes, + spatial_dims=spatial_dims, + stride=stride, + downsample=downsample, + act=act, + ) + ] + + self.in_planes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.in_planes, planes, spatial_dims=spatial_dims, act=act) + ) + + return nn.Sequential(*layers) + + def _init_module( + self, conv_type: Type[nn.Module], norm_type: Type[nn.Module] + ) -> None: + """ + Initializes the parameters. 
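To illustrate the bookkeeping above, a small sketch using the `ResNet` subclass defined further below (with `num_outputs=None`, `forward` returns the convolutional feature maps; the configuration is illustrative):

    import torch
    from clinicadl.monai_networks.nn.resnet import ResNet

    net = ResNet(spatial_dims=2, in_channels=1, num_outputs=None,
                 block_type="bottleneck", n_res_blocks=(2, 2), n_features=(8, 16))
    x = torch.randn(1, 1, 64, 64)
    # conv0 and pool0 each halve the spatial size, layer2 halves it once more
    print(net(x).shape)  # torch.Size([1, 16, 8, 8])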
+ """ + for m in self.modules(): + if isinstance(m, conv_type): + nn.init.kaiming_normal_( + torch.as_tensor(m.weight), mode="fan_out", nonlinearity="relu" + ) + elif isinstance(m, norm_type): + nn.init.constant_(torch.as_tensor(m.weight), 1) + nn.init.constant_(torch.as_tensor(m.bias), 0) + elif isinstance(m, nn.Linear): + nn.init.constant_(torch.as_tensor(m.bias), 0) + + @classmethod + def _bottleneck_reduce( + cls, n_features: Sequence[int], bottleneck_reduction: int + ) -> Sequence[int]: + """ + Finds number of feature maps for the bottleneck layers. + """ + reduced_features = [] + for n in n_features: + if n % bottleneck_reduction != 0: + raise ValueError( + "All elements of n_features must be divisible by bottleneck_reduction. " + f"Got {n} in n_features and bottleneck_reduction={bottleneck_reduction}" + ) + reduced_features.append(n // bottleneck_reduction) + + return reduced_features + + @classmethod + def _check_args_consistency( + cls, n_res_blocks: Sequence[int], n_features: Sequence[int] + ) -> None: + """ + Checks consistency between `n_res_blocks` and `n_features`. + """ + if not isinstance(n_res_blocks, Sequence): + raise ValueError(f"n_res_blocks must be a sequence, got {n_res_blocks}") + if not isinstance(n_features, Sequence): + raise ValueError(f"n_features must be a sequence, got {n_features}") + if len(n_features) != len(n_res_blocks): + raise ValueError( + f"n_features and n_res_blocks must have the same length, got n_features={n_features} " + f"and n_res_blocks={n_res_blocks}" + ) + + +class ResNet(GeneralResNet): + """ + ResNet based on the [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) paper. + Adapted from [MONAI's implementation](https://docs.monai.io/en/stable/networks.html#resnet). + + The user can customize the number of residual blocks, the number of downsampling blocks, the number of channels + in each block, as well as other parameters like the type of residual block used. + + ResNet is a fully convolutional network that can work with input of any size, provided that is it large + enough not to be reduced to a 1-pixel image (before the adaptative average pooling). + + Parameters + ---------- + spatial_dims : int + number of spatial dimensions of the input image. + in_channels : int + number of channels in the input image. + num_outputs : Optional[int] + number of output variables after the last linear layer.\n + If None, the features before the last fully connected layer (including average pooling) will be returned. + block_type : Union[str, ResNetBlockType] (optional, default=ResNetBlockType.BASIC) + type of residual block. Either `basic` or `bottleneck`. Default to `basic`, as in `ResNet-18`. + n_res_blocks : Sequence[int] (optional, default=(2, 2, 2, 2)) + number of residual block in each ResNet layer. A ResNet layer refers here to the set of residual blocks + between two downsamplings. The length of `n_res_blocks` thus determines the number of ResNet layers. + Default to `(2, 2, 2, 2)`, as in `ResNet-18`. + n_features : Sequence[int] (optional, default=(64, 128, 256, 512)) + number of output feature maps for each ResNet layer. The length of `n_features` must be equal to the length + of `n_res_blocks`. Default to `(64, 128, 256, 512)`, as in `ResNet-18`. + init_conv_size : Union[Sequence[int], int] (optional, default=7) + kernel_size for the first convolution. + If tuple, it will be understood as the values for each dimension. + Default to 7, as in the original paper. 
+ init_conv_stride : Union[Sequence[int], int] (optional, default=2) + stride for the first convolution. + If tuple, it will be understood as the values for each dimension. + Default to 2, as in the original paper. + bottleneck_reduction : int (optional, default=4) + if `block_type='bottleneck'`, `bottleneck_reduction` determines the reduction factor for the number + of feature maps in bottleneck layers (1x1 convolutions). Default to 4, as in the original paper. + act : ActivationParameters (optional, default=("relu", {"inplace": True})) + the activation function used in the convolutional part, and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them.\n + Default is "relu", as in the original paper. + output_act : Optional[ActivationParameters] (optional, default=None) + if `num_outputs` is not None, a potential activation layer applied to the outputs of the network. + Should be pass in the same way as `act`. + If None, no last activation will be applied. + + Examples + -------- + >>> ResNet( + spatial_dims=2, + in_channels=1, + num_outputs=2, + block_type="bottleneck", + bottleneck_reduction=4, + n_features=(8, 16), + n_res_blocks=(2, 2), + output_act="softmax", + init_conv_size=5, + ) + ResNet( + (conv0): Conv2d(1, 2, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False) + (norm0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act0): ReLU(inplace=True) + (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + (layer1): Sequential( + (0): ResNetBottleneck( + (conv1): Conv2d(2, 2, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act1): ReLU(inplace=True) + (conv2): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (norm2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act2): ReLU(inplace=True) + (conv3): Conv2d(2, 8, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (downsample): Sequential( + (0): Conv2d(2, 8, kernel_size=(1, 1), stride=(1, 1), bias=False) + (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (act3): ReLU(inplace=True) + ) + (1): ResNetBottleneck( + (conv1): Conv2d(8, 2, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act1): ReLU(inplace=True) + (conv2): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (norm2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act2): ReLU(inplace=True) + (conv3): Conv2d(2, 8, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act3): ReLU(inplace=True) + ) + ) + (layer2): Sequential( + (0): ResNetBottleneck( + (conv1): Conv2d(8, 4, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True) + (act1): ReLU(inplace=True) + (conv2): Conv2d(4, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + (norm2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act2): ReLU(inplace=True) + (conv3): Conv2d(4, 16, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (downsample): Sequential( + (0): Conv2d(8, 16, kernel_size=(1, 1), stride=(2, 2), bias=False) + (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (act3): ReLU(inplace=True) + ) + (1): ResNetBottleneck( + (conv1): Conv2d(16, 4, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act1): ReLU(inplace=True) + (conv2): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (norm2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act2): ReLU(inplace=True) + (conv3): Conv2d(4, 16, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act3): ReLU(inplace=True) + ) + ) + (fc): Sequential( + (pool): AdaptiveAvgPool2d(output_size=(1, 1)) + (flatten): Flatten(start_dim=1, end_dim=-1) + (out): Linear(in_features=16, out_features=2, bias=True) + (output_act): Softmax(dim=None) + ) + ) + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_outputs: Optional[int], + block_type: Union[str, ResNetBlockType] = ResNetBlockType.BASIC, + n_res_blocks: Sequence[int] = (2, 2, 2, 2), + n_features: Sequence[int] = (64, 128, 256, 512), + init_conv_size: Union[Sequence[int], int] = 7, + init_conv_stride: Union[Sequence[int], int] = 2, + bottleneck_reduction: int = 4, + act: ActivationParameters = ("relu", {"inplace": True}), + output_act: Optional[ActivationParameters] = None, + ) -> None: + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_outputs=num_outputs, + block_type=block_type, + n_res_blocks=n_res_blocks, + n_features=n_features, + init_conv_size=init_conv_size, + init_conv_stride=init_conv_stride, + bottleneck_reduction=bottleneck_reduction, + se_reduction=None, + act=act, + output_act=output_act, + ) + + +class SOTAResNet(str, Enum): + """Supported ResNet networks.""" + + RESNET_18 = "ResNet-18" + RESNET_34 = "ResNet-34" + RESNET_50 = "ResNet-50" + RESNET_101 = "ResNet-101" + RESNET_152 = "ResNet-152" + + +def get_resnet( + name: Union[str, SOTAResNet], + num_outputs: Optional[int], + output_act: ActivationParameters = None, + pretrained: bool = False, +) -> ResNet: + """ + To get a ResNet implemented in the [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) + paper. + + Only the last fully connected layer will be changed to match `num_outputs`. + + The user can also use the pretrained models from `torchvision`. Note that the last fully connected layer will not + used pretrained weights, as it is task specific. + + .. warning:: `ResNet-18`, `ResNet-34`, `ResNet-50`, `ResNet-101` and `ResNet-152` only works with 2D images with 3 + channels. + + Parameters + ---------- + model : Union[str, SOTAResNet] + The name of the ResNet. Available networks are `ResNet-18`, `ResNet-34`, `ResNet-50`, `ResNet-101` and `ResNet-152`. 
+ num_outputs : Optional[int] + number of output variables after the last linear layer.\n + If None, the features before the last fully connected layer will be returned. + output_act : ActivationParameters (optional, default=None) + if `num_outputs` is not None, a potential activation layer applied to the outputs of the network, + and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them. + pretrained : bool (optional, default=False) + whether to use pretrained weights. The pretrained weights used are the default ones from [torchvision](https:// + pytorch.org/vision/main/models/resnet.html). + + Returns + ------- + ResNet + The network, with potentially pretrained weights. + """ + name = SOTAResNet(name) + if name == SOTAResNet.RESNET_18: + block_type = ResNetBlockType.BASIC + n_res_blocks = (2, 2, 2, 2) + n_features = (64, 128, 256, 512) + model_url = ResNet18_Weights.DEFAULT.url + elif name == SOTAResNet.RESNET_34: + block_type = ResNetBlockType.BASIC + n_res_blocks = (3, 4, 6, 3) + n_features = (64, 128, 256, 512) + model_url = ResNet34_Weights.DEFAULT.url + elif name == SOTAResNet.RESNET_50: + block_type = ResNetBlockType.BOTTLENECK + n_res_blocks = (3, 4, 6, 3) + n_features = (256, 512, 1024, 2048) + model_url = ResNet50_Weights.DEFAULT.url + elif name == SOTAResNet.RESNET_101: + block_type = ResNetBlockType.BOTTLENECK + n_res_blocks = (3, 4, 23, 3) + n_features = (256, 512, 1024, 2048) + model_url = ResNet101_Weights.DEFAULT.url + elif name == SOTAResNet.RESNET_152: + block_type = ResNetBlockType.BOTTLENECK + n_res_blocks = (3, 8, 36, 3) + n_features = (256, 512, 1024, 2048) + model_url = ResNet152_Weights.DEFAULT.url + + # pylint: disable=possibly-used-before-assignment + resnet = ResNet( + spatial_dims=2, + in_channels=3, + num_outputs=num_outputs, + n_res_blocks=n_res_blocks, + block_type=block_type, + n_features=n_features, + output_act=output_act, + ) + if pretrained: + fc_layers = deepcopy(resnet.fc) + resnet.fc = None + pretrained_dict = load_state_dict_from_url(model_url, progress=True) + resnet.load_state_dict(_state_dict_adapter(pretrained_dict)) + resnet.fc = fc_layers + + return resnet + + +def _state_dict_adapter(state_dict: Mapping[str, Any]) -> Mapping[str, Any]: + """ + A mapping between torchvision's layer names and ours. 
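A usage sketch for `get_resnet` (as for the DenseNets, `pretrained=True` downloads torchvision's ImageNet weights for the convolutional part only):

    import torch
    from clinicadl.monai_networks.nn.resnet import get_resnet

    model = get_resnet("ResNet-18", num_outputs=2, output_act="softmax", pretrained=True)
    x = torch.randn(1, 3, 224, 224)  # pretrained variants expect 2D, 3-channel inputs
    print(model(x).shape)  # expected: torch.Size([1, 2])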
+ """ + state_dict = {k: v for k, v in state_dict.items() if "fc" not in k} + + mappings = [ + (r"(?>> SEResNet( + spatial_dims=2, + in_channels=1, + num_outputs=2, + block_type="basic", + se_reduction=2, + n_features=(8,), + n_res_blocks=(2,), + output_act="softmax", + init_conv_size=5, + ) + SEResNet( + (conv0): Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False) + (norm0): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act0): ReLU(inplace=True) + (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + (layer1): Sequential( + (0): SEResNetBlock( + (conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (norm1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act1): ReLU(inplace=True) + (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (norm2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (se_layer): ChannelSELayer( + (avg_pool): AdaptiveAvgPool2d(output_size=1) + (fc): Sequential( + (0): Linear(in_features=8, out_features=4, bias=True) + (1): ReLU(inplace=True) + (2): Linear(in_features=4, out_features=8, bias=True) + (3): Sigmoid() + ) + ) + (act2): ReLU(inplace=True) + ) + (1): SEResNetBlock( + (conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (norm1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act1): ReLU(inplace=True) + (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (norm2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (se_layer): ChannelSELayer( + (avg_pool): AdaptiveAvgPool2d(output_size=1) + (fc): Sequential( + (0): Linear(in_features=8, out_features=4, bias=True) + (1): ReLU(inplace=True) + (2): Linear(in_features=4, out_features=8, bias=True) + (3): Sigmoid() + ) + ) + (act2): ReLU(inplace=True) + ) + ) + (fc): Sequential( + (pool): AdaptiveAvgPool2d(output_size=(1, 1)) + (flatten): Flatten(start_dim=1, end_dim=-1) + (out): Linear(in_features=8, out_features=2, bias=True) + (output_act): Softmax(dim=None) + ) + ) + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_outputs: Optional[int], + se_reduction: int = 16, + **kwargs: Any, + ) -> None: + # get defaults from resnet + _, default_resnet_args = get_args_and_defaults(ResNet.__init__) + for arg, value in default_resnet_args.items(): + if arg not in kwargs: + kwargs[arg] = value + + self._check_se_channels(kwargs["n_features"], se_reduction) + + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_outputs=num_outputs, + se_reduction=se_reduction, + **kwargs, + ) + + @classmethod + def _check_se_channels(cls, n_features: Sequence[int], se_reduction: int) -> None: + """ + Checks that the output of residual blocks always have a number of channels greater + than squeeze-excitation bottleneck reduction factor. + """ + if not isinstance(n_features, Sequence): + raise ValueError(f"n_features must be a sequence. Got {n_features}") + for n in n_features: + if n < se_reduction: + raise ValueError( + f"elements of n_features must be greater or equal to se_reduction. 
se_reduction. Got {n} in n_features "
+                    f"and se_reduction={se_reduction}"
+                )
+
+
+class SOTAResNet(str, Enum):
+    """Supported SEResNet networks."""
+
+    SE_RESNET_50 = "SEResNet-50"
+    SE_RESNET_101 = "SEResNet-101"
+    SE_RESNET_152 = "SEResNet-152"
+
+
+def get_seresnet(
+    name: Union[str, SOTAResNet],
+    num_outputs: Optional[int],
+    output_act: ActivationParameters = None,
+    pretrained: bool = False,
+) -> SEResNet:
+    """
+    To get a Squeeze-and-Excitation ResNet implemented in the [Squeeze-and-Excitation Networks](https://arxiv.org/pdf/
+    1709.01507) paper.
+
+    Only the last fully connected layer will be changed to match `num_outputs`.
+
+    .. warning:: `SEResNet-50`, `SEResNet-101` and `SEResNet-152` only work with 2D images with 3 channels.
+
+    Note: pretrained weights are not yet available for these networks.
+
+    Parameters
+    ----------
+    name : Union[str, SOTAResNet]
+        the name of the SEResNet. Available networks are `SEResNet-50`, `SEResNet-101` and `SEResNet-152`.
+    num_outputs : Optional[int]
+        number of output variables after the last linear layer.\n
+        If None, the features before the last fully connected layer will be returned.
+    output_act : ActivationParameters (optional, default=None)
+        if `num_outputs` is not None, a potential activation layer applied to the outputs of the network,
+        and optionally its arguments.
+        Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n
+        `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`,
+        `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activation functions]
+        (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional
+        arguments for each of them.
+    pretrained : bool (optional, default=False)
+        pretrained networks are not yet available for SE-ResNets. Leave this argument as False.
+
+    Returns
+    -------
+    SEResNet
+        the network.
+    """
+    if pretrained is not False:
+        raise ValueError(
+            "Pretrained networks are not yet available for SE-ResNets. Please leave "
+            "'pretrained' as False."
+ ) + + name = SOTAResNet(name) + if name == SOTAResNet.SE_RESNET_50: + block_type = ResNetBlockType.BOTTLENECK + n_res_blocks = (3, 4, 6, 3) + n_features = (256, 512, 1024, 2048) + elif name == SOTAResNet.SE_RESNET_101: + block_type = ResNetBlockType.BOTTLENECK + n_res_blocks = (3, 4, 23, 3) + n_features = (256, 512, 1024, 2048) + elif name == SOTAResNet.SE_RESNET_152: + block_type = ResNetBlockType.BOTTLENECK + n_res_blocks = (3, 8, 36, 3) + n_features = (256, 512, 1024, 2048) + + # pylint: disable=possibly-used-before-assignment + resnet = SEResNet( + spatial_dims=2, + in_channels=3, + num_outputs=num_outputs, + n_res_blocks=n_res_blocks, + block_type=block_type, + n_features=n_features, + output_act=output_act, + ) + + return resnet diff --git a/clinicadl/monai_networks/nn/unet.py b/clinicadl/monai_networks/nn/unet.py new file mode 100644 index 000000000..dd1e59141 --- /dev/null +++ b/clinicadl/monai_networks/nn/unet.py @@ -0,0 +1,250 @@ +from abc import ABC, abstractmethod +from typing import Optional, Sequence + +import torch +import torch.nn as nn +from monai.networks.blocks.convolutions import Convolution +from monai.networks.layers.utils import get_act_layer + +from .layers.unet import ConvBlock, DownBlock, UpSample +from .layers.utils import ActFunction, ActivationParameters + + +class BaseUNet(nn.Module, ABC): + """Base class for UNet and AttentionUNet.""" + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int] = (64, 128, 256, 512, 1024), + act: ActivationParameters = ActFunction.RELU, + output_act: Optional[ActivationParameters] = None, + dropout: Optional[float] = None, + ): + super().__init__() + if not isinstance(channels, Sequence) or len(channels) < 2: + raise ValueError( + f"channels should be a sequence, whose length is no less than 2. Got {channels}" + ) + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.act = act + self.dropout = dropout + + self.doubleconv = ConvBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + act=act, + dropout=dropout, + ) + self._build_encoder() + self._build_decoder() + self.reduce_channels = Convolution( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=out_channels, + kernel_size=1, + strides=1, + padding=0, + conv_only=True, + ) + self.output_act = get_act_layer(output_act) if output_act else None + + @abstractmethod + def forward(self, x: torch.Tensor) -> torch.Tensor: + pass + + def _build_encoder(self) -> None: + for i in range(1, len(self.channels)): + self.add_module( + f"down{i}", + DownBlock( + spatial_dims=self.spatial_dims, + in_channels=self.channels[i - 1], + out_channels=self.channels[i], + act=self.act, + dropout=self.dropout, + ), + ) + + @abstractmethod + def _build_decoder(self) -> None: + pass + + +class UNet(BaseUNet): + """ + UNet based on [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/pdf/1505.04597). + + The user can customize the number of encoding blocks, the number of channels in each block, as well as other parameters + like the activation function. + + .. warning:: UNet works only with images whose dimensions are high enough powers of 2. More precisely, if n is the number + of max pooling operation in your UNet (which is equal to `len(channels)-1`), the image must have :math:`2^{k}` + pixels in each dimension, with :math:`k \\geq n` (e.g. 
shape (:math:`2^{n}`, :math:`2^{n+3}`) for a 2D image).
+
+    Note: the implementation proposed here is not exactly the one described in the original paper. Padding is added to
+    convolutions so that the feature maps keep a constant size (except when they are passed to `max pool` or `up-sample`
+    layers), batch normalization is used, and `up-conv` layers are here made with an [Upsample](https://pytorch.org/docs/
+    stable/generated/torch.nn.Upsample.html) layer followed by a 3x3 convolution.
+
+    Parameters
+    ----------
+    spatial_dims : int
+        number of spatial dimensions of the input image.
+    in_channels : int
+        number of channels in the input image.
+    out_channels : int
+        number of output channels.
+    channels : Sequence[int] (optional, default=(64, 128, 256, 512, 1024))
+        sequence of integers stating the number of channels in each UNet block. Thus, this parameter also controls
+        the number of UNet blocks. The length of `channels` should be no less than 2.\n
+        Default to `(64, 128, 256, 512, 1024)`, as in the original paper.
+    act : ActivationParameters (optional, default=ActFunction.RELU)
+        the activation function used in the convolutional part, and optionally its arguments.
+        Should be passed as `activation_name` or `(activation_name, arguments)`.
+        `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`,
+        `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activation functions]
+        (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional
+        arguments for each of them.\n
+        Default is "relu", as in the original paper.
+    output_act : Optional[ActivationParameters] (optional, default=None)
+        a potential activation layer applied to the output of the network. Should be passed in the same way as `act`.
+        If None, no last activation will be applied.
+    dropout : Optional[float] (optional, default=None)
+        dropout ratio. If None, no dropout.
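Before the class's own example below, a small sketch of the size constraint from the warning above (the helper name is hypothetical, not part of the patch):

    def _is_valid_unet_input(spatial_shape, n_pool):
        # every spatial dimension must be a power of two, at least 2**n_pool
        return all(s >= 2**n_pool and s & (s - 1) == 0 for s in spatial_shape)

    assert _is_valid_unet_input((64, 128), n_pool=1)      # fine for channels=(4, 8)
    assert not _is_valid_unet_input((60, 64), n_pool=1)   # 60 is not a power of two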
+ + Examples + -------- + >>> UNet( + spatial_dims=2, + in_channels=1, + out_channels=2, + channels=(4, 8), + act="elu", + output_act=("softmax", {"dim": 1}), + dropout=0.1, + ) + UNet( + (doubleconv): ConvBlock( + (0): Convolution( + (conv): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (1): Convolution( + (conv): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + ) + (down1): DownBlock( + (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) + (doubleconv): ConvBlock( + (0): Convolution( + (conv): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (1): Convolution( + (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + ) + ) + (upsample1): UpSample( + (0): Upsample(scale_factor=2.0, mode='nearest') + (1): Convolution( + (conv): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + ) + (doubleconv1): ConvBlock( + (0): Convolution( + (conv): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (1): Convolution( + (conv): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + ) + (reduce_channels): Convolution( + (conv): Conv2d(4, 2, kernel_size=(1, 1), stride=(1, 1)) + ) + (output_act): Softmax(dim=1) + ) + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_history = [self.doubleconv(x)] + + for i in range(1, len(self.channels)): + x = self.get_submodule(f"down{i}")(x_history[-1]) + x_history.append(x) + + x_history.pop() # the output of bottelneck is not used as a residual + for i in range(len(self.channels) - 1, 0, -1): + up = self.get_submodule(f"upsample{i}")(x) + merged = torch.cat((x_history.pop(), up), dim=1) + x = self.get_submodule(f"doubleconv{i}")(merged) + + out = self.reduce_channels(x) + + if self.output_act is not None: + out = self.output_act(out) + + return out + + def _build_decoder(self): + for i in range(len(self.channels) - 1, 0, -1): + self.add_module( + f"upsample{i}", + UpSample( + spatial_dims=self.spatial_dims, + in_channels=self.channels[i], + out_channels=self.channels[i - 1], + act=self.act, + dropout=self.dropout, + ), + ) + self.add_module( + f"doubleconv{i}", + ConvBlock( + spatial_dims=self.spatial_dims, + in_channels=self.channels[i - 1] * 2, + out_channels=self.channels[i - 1], + act=self.act, + dropout=self.dropout, + ), + ) diff --git 
a/clinicadl/monai_networks/nn/utils/__init__.py b/clinicadl/monai_networks/nn/utils/__init__.py
new file mode 100644
index 000000000..ce603f205
--- /dev/null
+++ b/clinicadl/monai_networks/nn/utils/__init__.py
@@ -0,0 +1,14 @@
+from .checks import (
+    check_adn_ordering,
+    check_conv_args,
+    check_mlp_args,
+    check_norm_layer,
+    check_pool_indices,
+    ensure_list_of_tuples,
+)
+from .shapes import (
+    calculate_conv_out_shape,
+    calculate_convtranspose_out_shape,
+    calculate_pool_out_shape,
+    calculate_unpool_out_shape,
+)
diff --git a/clinicadl/monai_networks/nn/utils/checks.py b/clinicadl/monai_networks/nn/utils/checks.py
new file mode 100644
index 000000000..1917a2894
--- /dev/null
+++ b/clinicadl/monai_networks/nn/utils/checks.py
@@ -0,0 +1,167 @@
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+from ..layers.utils import (
+    ConvParameters,
+    NormalizationParameters,
+    NormLayer,
+    PoolingLayer,
+)
+
+__all__ = [
+    "ensure_list_of_tuples",
+    "check_norm_layer",
+    "check_conv_args",
+    "check_mlp_args",
+    "check_pool_indices",
+]
+
+
+def ensure_list_of_tuples(
+    parameter: ConvParameters, dim: int, n_layers: int, name: str
+) -> List[Tuple[int, ...]]:
+    """
+    Checks spatial parameters (e.g. kernel_size) and returns a list of tuples.
+    Each element of the list corresponds to the parameters of one layer, and
+    each element of the tuple corresponds to the parameters for one dimension.
+    """
+    parameter = _check_conv_parameter(parameter, dim, n_layers, name)
+    if isinstance(parameter, tuple):
+        return [parameter] * n_layers
+    else:
+        return parameter
+
+
+def check_norm_layer(
+    norm: Optional[NormalizationParameters],
+) -> Optional[NormalizationParameters]:
+    """
+    Checks that the argument for normalization layers has the right format (i.e.
+    `norm_type` or (`norm_type`, `norm_layer_parameters`)) and checks potential
+    mandatory arguments in `norm_layer_parameters`.
+    """
+    if norm is None:
+        return norm
+
+    if not isinstance(norm, str) and not isinstance(norm, NormLayer):
+        if (
+            not isinstance(norm, tuple)
+            or len(norm) != 2
+            or not isinstance(norm[1], dict)
+        ):
+            raise ValueError(
+                "norm must be either the name of the normalization layer or a tuple with first the name and then the "
+                f"arguments of the layer in a dict. Got {norm}"
+            )
+        norm_mode = NormLayer(norm[0])
+        args = norm[1]
+    else:
+        norm_mode = NormLayer(norm)
+        args = {}
+    if norm_mode == NormLayer.GROUP and "num_groups" not in args:
+        raise ValueError(
+            f"num_groups is a mandatory argument for GroupNorm and must be passed in `norm`. Got `norm`={norm}"
+        )
+
+    return norm
+
+
+def check_adn_ordering(adn: str) -> str:
+    """
+    Checks ADN sequence.
+    """
+    if not isinstance(adn, str):
+        raise ValueError(f"adn_ordering must be a string. Got {adn}")
+
+    for letter in adn:
+        if letter not in {
+            "A",
+            "D",
+            "N",
+        }:
+            raise ValueError(
+                f"adn_ordering must be composed of 'A', 'D' and/or 'N'. Got {letter}"
+            )
+    if len(adn) != len(set(adn)):
+        raise ValueError(f"adn_ordering cannot contain duplicated letters. Got {adn}")
+
+    return adn
+
+
+def check_conv_args(conv_args: Dict[str, Any]) -> None:
+    """
+    Checks that `conv_args` is a dict with at least the mandatory argument `channels`.
+    """
+    if not isinstance(conv_args, dict):
+        raise ValueError(
+            f"conv_args must be a dict with the arguments for the convolutional part. Got: {conv_args}"
+        )
+    if "channels" not in conv_args:
+        raise ValueError(
+            "channels is a mandatory argument for the convolutional part and must therefore be "
+            f"passed in conv_args. Got conv_args={conv_args}"
+        )
+
+
+def check_mlp_args(mlp_args: Optional[Dict[str, Any]]) -> None:
+    """
+    Checks that `mlp_args` is a dict with at least the mandatory argument `hidden_channels`.
+    """
+    if mlp_args is not None:
+        if not isinstance(mlp_args, dict):
+            raise ValueError(
+                f"mlp_args must be a dict with the arguments for the MLP part. Got: {mlp_args}"
+            )
+        if "hidden_channels" not in mlp_args:
+            raise ValueError(
+                "hidden_channels is a mandatory argument for the MLP part and must therefore be "
+                f"passed in mlp_args. Got mlp_args={mlp_args}"
+            )
+
+
+def check_pool_indices(
+    pooling_indices: Optional[Sequence[int]], n_layers: int
+) -> Sequence[int]:
+    """
+    Checks that the (un)pooling indices are consistent with the number of layers.
+    """
+    if pooling_indices is not None:
+        for idx in pooling_indices:
+            if idx > n_layers - 1:
+                raise ValueError(
+                    f"indices in (un)pooling_indices must be smaller than or equal to len(channels)-1, got (un)pooling_indices={pooling_indices} and len(channels)={n_layers}"
+                )
+            elif idx < -1:
+                raise ValueError(
+                    f"indices in (un)pooling_indices must be greater than or equal to -1, got (un)pooling_indices={pooling_indices}"
+                )
+        return sorted(pooling_indices)
+    else:
+        return []
+
+
+def _check_conv_parameter(
+    parameter: ConvParameters, dim: int, n_layers: int, name: str
+) -> Union[Tuple[int, ...], List[Tuple[int, ...]]]:
+    """
+    Checks spatial parameters (e.g. kernel_size).
+    """
+    if isinstance(parameter, int):
+        return (parameter,) * dim
+    elif isinstance(parameter, tuple):
+        if len(parameter) != dim:
+            raise ValueError(
+                f"If a tuple is passed for {name}, its dimension must be {dim}. Got {parameter}"
+            )
+        return parameter
+    elif isinstance(parameter, list):
+        if len(parameter) != n_layers:
+            raise ValueError(
+                f"If a list is passed, {name} must contain as many elements as there are layers. "
+                f"There are {n_layers} layers, but got {parameter}"
+            )
+        checked_params = []
+        for param in parameter:
+            checked_params.append(_check_conv_parameter(param, dim, n_layers, name))
+        return checked_params
+    else:
+        raise ValueError(f"{name} must be an int, a tuple or a list. Got {parameter}")
diff --git a/clinicadl/monai_networks/nn/utils/shapes.py b/clinicadl/monai_networks/nn/utils/shapes.py
new file mode 100644
index 000000000..a649af076
--- /dev/null
+++ b/clinicadl/monai_networks/nn/utils/shapes.py
@@ -0,0 +1,203 @@
+from math import ceil
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+
+from ..layers.utils import PoolingLayer, UnpoolingLayer
+
+__all__ = [
+    "calculate_conv_out_shape",
+    "calculate_convtranspose_out_shape",
+    "calculate_pool_out_shape",
+    "calculate_unpool_out_shape",
+]
+
+
+def calculate_conv_out_shape(
+    in_shape: Union[Sequence[int], int],
+    kernel_size: Union[Sequence[int], int],
+    stride: Union[Sequence[int], int] = 1,
+    padding: Union[Sequence[int], int] = 0,
+    dilation: Union[Sequence[int], int] = 1,
+    **kwargs,  # for uniformization
+) -> Tuple[int, ...]:
+    """
+    Calculates the output shape of a convolution layer. All arguments can be scalars or multiple
+    values. Always returns a tuple.
+ """ + in_shape_np = np.atleast_1d(in_shape) + kernel_size_np = np.atleast_1d(kernel_size) + stride_np = np.atleast_1d(stride) + padding_np = np.atleast_1d(padding) + dilation_np = np.atleast_1d(dilation) + + out_shape_np = ( + (in_shape_np + 2 * padding_np - dilation_np * (kernel_size_np - 1) - 1) + / stride_np + ) + 1 + + return tuple(int(s) for s in out_shape_np) + + +def calculate_convtranspose_out_shape( + in_shape: Union[Sequence[int], int], + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int] = 1, + padding: Union[Sequence[int], int] = 0, + output_padding: Union[Sequence[int], int] = 0, + dilation: Union[Sequence[int], int] = 1, + **kwargs, # for uniformization +) -> Tuple[int, ...]: + """ + Calculates the output shape of a transposed convolution layer. All arguments can be scalars or + multiple values. Always return a tuple. + """ + in_shape_np = np.atleast_1d(in_shape) + kernel_size_np = np.atleast_1d(kernel_size) + stride_np = np.atleast_1d(stride) + padding_np = np.atleast_1d(padding) + dilation_np = np.atleast_1d(dilation) + output_padding_np = np.atleast_1d(output_padding) + + out_shape_np = ( + (in_shape_np - 1) * stride_np + - 2 * padding_np + + dilation_np * (kernel_size_np - 1) + + output_padding_np + + 1 + ) + + return tuple(int(s) for s in out_shape_np) + + +def calculate_pool_out_shape( + pool_mode: Union[str, PoolingLayer], + in_shape: Union[Sequence[int], int], + **kwargs, +) -> Tuple[int, ...]: + """ + Calculates the output shape of a pooling layer. The first argument is the type of pooling + performed (`max` or `avg`). All other arguments can be scalars or multiple values, except + `ceil_mode`. + Always return a tuple. + """ + pool_mode = PoolingLayer(pool_mode) + if pool_mode == PoolingLayer.MAX: + return _calculate_maxpool_out_shape(in_shape, **kwargs) + elif pool_mode == PoolingLayer.AVG: + return _calculate_avgpool_out_shape(in_shape, **kwargs) + elif pool_mode == PoolingLayer.ADAPT_MAX or pool_mode == PoolingLayer.ADAPT_AVG: + return _calculate_adaptivepool_out_shape(in_shape, **kwargs) + + +def calculate_unpool_out_shape( + unpool_mode: Union[str, UnpoolingLayer], + in_shape: Union[Sequence[int], int], + **kwargs, +) -> Tuple[int, ...]: + """ + Calculates the output shape of an unpooling layer. The first argument is the type of unpooling + performed (`upsample` or `convtranspose`). + Always return a tuple. + """ + unpool_mode = UnpoolingLayer(unpool_mode) + if unpool_mode == UnpoolingLayer.UPSAMPLE: + return _calculate_upsample_out_shape(in_shape, **kwargs) + elif unpool_mode == UnpoolingLayer.CONV_TRANS: + return calculate_convtranspose_out_shape(in_shape, **kwargs) + + +def _calculate_maxpool_out_shape( + in_shape: Union[Sequence[int], int], + kernel_size: Union[Sequence[int], int], + stride: Optional[Union[Sequence[int], int]] = None, + padding: Union[Sequence[int], int] = 0, + dilation: Union[Sequence[int], int] = 1, + ceil_mode: bool = False, + **kwargs, # for uniformization +) -> Tuple[int, ...]: + """ + Calculates the output shape of a MaxPool layer. 
+ """ + if stride is None: + stride = kernel_size + + in_shape_np = np.atleast_1d(in_shape) + kernel_size_np = np.atleast_1d(kernel_size) + stride_np = np.atleast_1d(stride) + padding_np = np.atleast_1d(padding) + dilation_np = np.atleast_1d(dilation) + + out_shape_np = ( + (in_shape_np + 2 * padding_np - dilation_np * (kernel_size_np - 1) - 1) + / stride_np + ) + 1 + if ceil_mode: + out_shape = tuple(ceil(s) for s in out_shape_np) + else: + out_shape = tuple(int(s) for s in out_shape_np) + + return out_shape + + +def _calculate_avgpool_out_shape( + in_shape: Union[Sequence[int], int], + kernel_size: Union[Sequence[int], int], + stride: Optional[Union[Sequence[int], int]] = None, + padding: Union[Sequence[int], int] = 0, + ceil_mode: bool = False, + **kwargs, # for uniformization +) -> Tuple[int, ...]: + """ + Calculates the output shape of an AvgPool layer. + """ + if stride is None: + stride = kernel_size + + in_shape_np = np.atleast_1d(in_shape) + kernel_size_np = np.atleast_1d(kernel_size) + stride_np = np.atleast_1d(stride) + padding_np = np.atleast_1d(padding) + + out_shape_np = ((in_shape_np + 2 * padding_np - kernel_size_np) / stride_np) + 1 + if ceil_mode: + out_shape_np = np.ceil(out_shape_np) + out_shape_np[(out_shape_np - 1) * stride_np >= in_shape_np + padding_np] -= 1 + + return tuple(int(s) for s in out_shape_np) + + +def _calculate_adaptivepool_out_shape( + in_shape: Union[Sequence[int], int], + output_size: Union[Sequence[int], int], + **kwargs, # for uniformization +) -> Tuple[int, ...]: + """ + Calculates the output shape of an AdaptiveMaxPool or AdaptiveAvgPool layer. + """ + in_shape_np = np.atleast_1d(in_shape) + out_shape_np = np.ones_like(in_shape_np) * np.atleast_1d(output_size) + + return tuple(int(s) for s in out_shape_np) + + +def _calculate_upsample_out_shape( + in_shape: Union[Sequence[int], int], + scale_factor: Optional[Union[Sequence[int], int]] = None, + size: Optional[Union[Sequence[int], int]] = None, + **kwargs, # for uniformization +) -> Tuple[int, ...]: + """ + Calculates the output shape of an Upsample layer. + """ + in_shape_np = np.atleast_1d(in_shape) + if size and scale_factor: + raise ValueError("Pass either size or scale_factor, not both.") + elif size: + out_shape_np = np.ones_like(in_shape_np) * np.atleast_1d(size) + elif scale_factor: + out_shape_np = in_shape_np * scale_factor + else: + raise ValueError("Pass one of size or scale_factor.") + + return tuple(int(s) for s in out_shape_np) diff --git a/clinicadl/monai_networks/nn/vae.py b/clinicadl/monai_networks/nn/vae.py new file mode 100644 index 000000000..9dac6b43b --- /dev/null +++ b/clinicadl/monai_networks/nn/vae.py @@ -0,0 +1,200 @@ +from copy import deepcopy +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn + +from .autoencoder import AutoEncoder +from .layers.utils import ActivationParameters, UnpoolingMode + + +class VAE(nn.Module): + """ + A Variational AutoEncoder with convolutional and fully connected layers. + + The user must pass the arguments to build an encoder, from its convolutional and + fully connected parts, and the decoder will be automatically built by taking the + symmetrical network. + + More precisely, to build the decoder, the order of the encoding layers is reverted, convolutions are + replaced by transposed convolutions and pooling layers are replaced by either upsampling or transposed + convolution layers. 
+    Please note that the order of `Activation`, `Dropout` and `Normalization`, defined with the
+    argument `adn_ordering` in `conv_args`, is the same for the encoder and the decoder.
+
+    Note that a `VAE` is an aggregation of a `CNN` (:py:class:`clinicadl.monai_networks.nn.
+    cnn.CNN`), whose last linear layer is duplicated to infer both the mean and the log variance,
+    and a `Generator` (:py:class:`clinicadl.monai_networks.nn.generator.Generator`).
+
+    Parameters
+    ----------
+    in_shape : Sequence[int]
+        sequence of integers stating the dimension of the input tensor (minus batch dimension).
+    latent_size : int
+        size of the latent vector.
+    conv_args : Dict[str, Any]
+        the arguments for the convolutional part of the encoder. The arguments are those accepted
+        by :py:class:`clinicadl.monai_networks.nn.conv_encoder.ConvEncoder`, except `in_shape` that
+        is specified here. So, the only mandatory argument is `channels`.
+    mlp_args : Optional[Dict[str, Any]] (optional, default=None)
+        the arguments for the MLP part of the encoder. The arguments are those accepted by
+        :py:class:`clinicadl.monai_networks.nn.mlp.MLP`, except `in_channels` that is inferred
+        from the output of the convolutional part, and `out_channels` that is set to `latent_size`.
+        So, the only mandatory argument is `hidden_channels`.\n
+        If None, the MLP part will be reduced to a single linear layer.\n
+        The last linear layer will be duplicated to infer both the mean and the log variance.
+    out_channels : Optional[int] (optional, default=None)
+        number of output channels. If None, the output will have the same number of channels as the
+        input.
+    output_act : Optional[ActivationParameters] (optional, default=None)
+        a potential activation layer applied to the output of the network, and optionally its arguments.
+        Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n
+        `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`,
+        `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activation functions]
+        (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional
+        arguments for each of them.
+    unpooling_mode : Union[str, UnpoolingMode] (optional, default=UnpoolingMode.NEAREST)
+        type of unpooling. Can be either `"nearest"`, `"linear"`, `"bilinear"`, `"bicubic"`, `"trilinear"` or
+        `"convtranspose"`.\n
+        - `nearest`: unpooling is performed by upsampling with the :italic:`nearest` algorithm (see [PyTorch's Upsample layer]
+        (https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html)).
+        - `linear`: unpooling is performed by upsampling with the :italic:`linear` algorithm. Only works with 1D images (excluding the
+        channel dimension).
+        - `bilinear`: unpooling is performed by upsampling with the :italic:`bilinear` algorithm. Only works with 2D images.
+        - `bicubic`: unpooling is performed by upsampling with the :italic:`bicubic` algorithm. Only works with 2D images.
+        - `trilinear`: unpooling is performed by upsampling with the :italic:`trilinear` algorithm. Only works with 3D images.
+        - `convtranspose`: unpooling is performed with a transposed convolution, whose parameters (kernel size, stride, etc.) are
+        computed to reverse the pooling operation.
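Since `forward` returns `(reconstruction, mu, log_var)` (see the code further below), a typical training step pairs a reconstruction term with the closed-form KL divergence. A hedged sketch, not part of this patch, using the class as in the example that follows:

    import torch
    import torch.nn.functional as F

    vae = VAE(in_shape=(1, 16, 16), latent_size=4, conv_args={"channels": [2]})
    x = torch.randn(8, 1, 16, 16)
    recon, mu, log_var = vae(x)

    # closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1).mean()
    loss = F.mse_loss(recon, x) + kl
    loss.backward()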
+ + Examples + -------- + >>> VAE( + in_shape=(1, 16, 16), + latent_size=4, + conv_args={"channels": [2]}, + mlp_args={"hidden_channels": [16], "output_act": "relu"}, + out_channels=2, + output_act="sigmoid", + unpooling_mode="bilinear", + ) + VAE( + (encoder): CNN( + (convolutions): ConvEncoder( + (layer0): Convolution( + (conv): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)) + ) + ) + (mlp): MLP( + (flatten): Flatten(start_dim=1, end_dim=-1) + (hidden0): Sequential( + (linear): Linear(in_features=392, out_features=16, bias=True) + (adn): ADN( + (N): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (A): PReLU(num_parameters=1) + ) + ) + (output): Identity() + ) + ) + (mu): Sequential( + (linear): Linear(in_features=16, out_features=4, bias=True) + (output_act): ReLU() + ) + (log_var): Sequential( + (linear): Linear(in_features=16, out_features=4, bias=True) + (output_act): ReLU() + ) + (decoder): Generator( + (mlp): MLP( + (flatten): Flatten(start_dim=1, end_dim=-1) + (hidden0): Sequential( + (linear): Linear(in_features=4, out_features=16, bias=True) + (adn): ADN( + (N): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (A): PReLU(num_parameters=1) + ) + ) + (output): Sequential( + (linear): Linear(in_features=16, out_features=392, bias=True) + (output_act): ReLU() + ) + ) + (reshape): Reshape() + (convolutions): ConvDecoder( + (layer0): Convolution( + (conv): ConvTranspose2d(2, 2, kernel_size=(3, 3), stride=(1, 1)) + ) + (output_act): Sigmoid() + ) + ) + ) + """ + + def __init__( + self, + in_shape: Sequence[int], + latent_size: int, + conv_args: Dict[str, Any], + mlp_args: Optional[Dict[str, Any]] = None, + out_channels: Optional[int] = None, + output_act: Optional[ActivationParameters] = None, + unpooling_mode: Union[str, UnpoolingMode] = UnpoolingMode.NEAREST, + ) -> None: + super().__init__() + ae = AutoEncoder( + in_shape, + latent_size, + conv_args, + mlp_args, + out_channels, + output_act, + unpooling_mode, + ) + + # replace last mlp layer by two parallel layers + mu_layers = deepcopy(ae.encoder.mlp.output) + log_var_layers = deepcopy(ae.encoder.mlp.output) + self._reset_weights( + log_var_layers + ) # to have different initialization for the two layers + ae.encoder.mlp.output = nn.Identity() + + self.encoder = ae.encoder + self.mu = mu_layers + self.log_var = log_var_layers + self.decoder = ae.decoder + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Encoding, sampling and decoding. + """ + feature = self.encoder(x) + mu = self.mu(feature) + log_var = self.log_var(feature) + z = self.reparameterize(mu, log_var) + + return self.decoder(z), mu, log_var + + def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor: + """ + Samples a random vector from a gaussian distribution, given the mean and log-variance + of this distribution. + """ + std = torch.exp(0.5 * log_var) + + if self.training: # multiply random noise with std only during training + std = torch.randn_like(std).mul(std) + + return std.add_(mu) + + @classmethod + def _reset_weights(cls, layer: Union[nn.Sequential, nn.Linear]) -> None: + """ + Resets the output layer(s) of an MLP. 
+ """ + if isinstance(layer, nn.Linear): + layer.reset_parameters() + else: + layer.linear.reset_parameters() diff --git a/clinicadl/monai_networks/nn/vit.py b/clinicadl/monai_networks/nn/vit.py new file mode 100644 index 000000000..372e1728a --- /dev/null +++ b/clinicadl/monai_networks/nn/vit.py @@ -0,0 +1,420 @@ +import math +import re +from collections import OrderedDict +from copy import deepcopy +from enum import Enum +from typing import Any, Mapping, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding +from monai.networks.layers import Conv +from monai.networks.layers.utils import get_act_layer +from monai.utils import ensure_tuple_rep +from torch.hub import load_state_dict_from_url +from torchvision.models.vision_transformer import ( + ViT_B_16_Weights, + ViT_B_32_Weights, + ViT_L_16_Weights, + ViT_L_32_Weights, +) + +from .layers.utils import ActFunction, ActivationParameters +from .layers.vit import Encoder + + +class PosEmbedType(str, Enum): + """Available position embedding types for ViT.""" + + LEARN = "learnable" + SINCOS = "sincos" + + +class ViT(nn.Module): + """ + Vision Transformer based on the [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale] + (https://arxiv.org/pdf/2010.11929) paper. + Adapted from [torchvision's implementation](https://pytorch.org/vision/main/models/vision_transformer.html). + + The user can customize the patch size, the embedding dimension, the number of transformer blocks, the number of + attention heads, as well as other parameters like the type of position embedding. + + Parameters + ---------- + in_shape : Sequence[int] + sequence of integers stating the dimension of the input tensor (minus batch dimension). + patch_size : Union[Sequence[int], int] + sequence of integers stating the patch size (minus batch and channel dimensions). If int, the same + patch size will be used for all dimensions. + Patch size must divide image size in all dimensions. + num_outputs : Optional[int] + number of output variables after the last linear layer.\n + If None, the patch embeddings after the last transformer block will be returned. + embedding_dim : int (optional, default=768) + size of the embedding vectors. Must be divisible by `num_heads` as each head will be responsible for + a part of the embedding vectors. Default to 768, as for 'ViT-Base' in the original paper. + num_layers : int (optional, default=12) + number of consecutive transformer blocks. Default to 12, as for 'ViT-Base' in the original paper. + num_heads : int (optional, default=12) + number of heads in the self-attention block. Must divide `embedding_size`. + Default to 12, as for 'ViT-Base' in the original paper. + mlp_dim : int (optional, default=3072) + size of the hidden layer in the MLP part of the transformer block. Default to 3072, as for 'ViT-Base' + in the original paper. + pos_embed_type : Optional[Union[str, PosEmbedType]] (optional, default="learnable") + type of position embedding. Can be either `"learnable"`, `"sincos"` or `None`.\n + - `learnable`: the position embeddings are parameters that will be learned during the training + process. + - `sincos`: the position embeddings are fixed and determined with sinus and cosinus formulas (based on Dosovitskiy et al., + 'Attention Is All You Need, https://arxiv.org/pdf/1706.03762). Only implemented for 2D and 3D images. 
With `sincos` + position embedding, `embedding_dim` must be divisible by 4 for 2D images and by 6 for 3D images. + - `None`: no position embeddings are used.\n + Default to `"learnable"`, as in the original paper. + output_act : Optional[ActivationParameters] (optional, default=ActFunction.TANH) + if `num_outputs` is not None, a potential activation layer applied to the outputs of the network, + and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them.\n + Default to `"tanh"`, as in the original paper. + dropout : Optional[float] (optional, default=None) + dropout ratio. If None, no dropout. + + Examples + -------- + >>> ViT( + in_shape=(3, 60, 64), + patch_size=4, + num_outputs=2, + embedding_dim=32, + num_layers=2, + num_heads=4, + mlp_dim=128, + output_act="softmax", + ) + ViT( + (conv_proj): Conv2d(3, 32, kernel_size=(4, 4), stride=(4, 4)) + (encoder): Encoder( + (dropout): Dropout(p=0.0, inplace=False) + (layers): ModuleList( + (0-1): 2 x EncoderBlock( + (norm1): LayerNorm((32,), eps=1e-06, elementwise_affine=True) + (self_attention): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True) + ) + (dropout): Dropout(p=0.0, inplace=False) + (norm2): LayerNorm((32,), eps=1e-06, elementwise_affine=True) + (mlp): MLPBlock( + (0): Linear(in_features=32, out_features=128, bias=True) + (1): GELU(approximate='none') + (2): Dropout(p=0.0, inplace=False) + (3): Linear(in_features=128, out_features=32, bias=True) + (4): Dropout(p=0.0, inplace=False) + ) + ) + ) + (norm): LayerNorm((32,), eps=1e-06, elementwise_affine=True) + ) + (fc): Sequential( + (out): Linear(in_features=32, out_features=2, bias=True) + (output_act): Softmax(dim=None) + ) + ) + """ + + def __init__( + self, + in_shape: Sequence[int], + patch_size: Union[Sequence[int], int], + num_outputs: Optional[int], + embedding_dim: int = 768, + num_layers: int = 12, + num_heads: int = 12, + mlp_dim: int = 3072, + pos_embed_type: Optional[Union[str, PosEmbedType]] = PosEmbedType.LEARN, + output_act: Optional[ActivationParameters] = ActFunction.TANH, + dropout: Optional[float] = None, + ) -> None: + super().__init__() + + self.in_channels, *self.img_size = in_shape + self.spatial_dims = len(self.img_size) + self.patch_size = ensure_tuple_rep(patch_size, self.spatial_dims) + + self._check_embedding_dim(embedding_dim, num_heads) + self._check_patch_size(self.img_size, self.patch_size) + self.embedding_dim = embedding_dim + self.classification = True if num_outputs else False + dropout = dropout if dropout else 0.0 + + self.conv_proj = Conv[Conv.CONV, self.spatial_dims]( # pylint: disable=not-callable + in_channels=self.in_channels, + out_channels=self.embedding_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + ) + self.seq_length = int( + np.prod(np.array(self.img_size) // np.array(self.patch_size)) + ) + + # Add a class token + if self.classification: + self.class_token = nn.Parameter(torch.zeros(1, 1, self.embedding_dim)) + self.seq_length += 1 + + pos_embedding = self._get_pos_embedding(pos_embed_type) + self.encoder = Encoder( + 
self.seq_length, + num_layers, + num_heads, + self.embedding_dim, + mlp_dim, + dropout=dropout, + attention_dropout=dropout, + pos_embedding=pos_embedding, + ) + + if self.classification: + self.class_token = nn.Parameter(torch.zeros(1, 1, embedding_dim)) + self.fc = nn.Sequential( + OrderedDict([("out", nn.Linear(embedding_dim, num_outputs))]) + ) + self.fc.output_act = get_act_layer(output_act) if output_act else None + else: + self.fc = None + + self._init_layers() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv_proj(x) + # (n, hidden_dim, n_h, n_w) -> (n, (h * w * d), hidden_dim) + x = x.flatten(2).transpose(-1, -2) + n = x.shape[0] + + # Expand the class token to the full batch + if self.fc: + batch_class_token = self.class_token.expand(n, -1, -1) + x = torch.cat([batch_class_token, x], dim=1) + + x = self.encoder(x) + + # Classifier "token" as used by standard language architectures + if self.fc: + x = x[:, 0] + x = self.fc(x) + + return x + + def _get_pos_embedding( + self, pos_embed_type: Optional[Union[str, PosEmbedType]] + ) -> Optional[nn.Parameter]: + """ + Gets position embeddings. If `pos_embed_type` is "learnable", will return None as it will be handled + by the encoder module. + """ + if pos_embed_type is None: + pos_embed = nn.Parameter( + torch.zeros(1, self.seq_length, self.embedding_dim) + ) + pos_embed.requires_grad = False + return pos_embed + + pos_embed_type = PosEmbedType(pos_embed_type) + + if pos_embed_type == PosEmbedType.LEARN: + return None # will be initialized inside the Encoder + + elif pos_embed_type == PosEmbedType.SINCOS: + if self.spatial_dims != 2 and self.spatial_dims != 3: + raise ValueError( + f"{self.spatial_dims}D sincos position embedding not implemented" + ) + elif self.spatial_dims == 2 and self.embedding_dim % 4: + raise ValueError( + f"embedding_dim must be divisible by 4 for 2D sincos position embedding. Got embedding_dim={self.embedding_dim}" + ) + elif self.spatial_dims == 3 and self.embedding_dim % 6: + raise ValueError( + f"embedding_dim must be divisible by 6 for 3D sincos position embedding. Got embedding_dim={self.embedding_dim}" + ) + grid_size = [] + for in_size, pa_size in zip(self.img_size, self.patch_size): + grid_size.append(in_size // pa_size) + pos_embed = build_sincos_position_embedding( + grid_size, self.embedding_dim, self.spatial_dims + ) + if self.classification: + pos_embed = torch.nn.Parameter( + torch.cat([torch.zeros(1, 1, self.embedding_dim), pos_embed], dim=1) + ) # add 0 for class token pos embedding + pos_embed.requires_grad = False + return pos_embed + + def _init_layers(self): + """ + Initializes some layers, based on torchvision's implementation: https://pytorch.org/vision/main/ + _modules/torchvision/models/vision_transformer.html + """ + fan_in = self.conv_proj.in_channels * np.prod(self.conv_proj.kernel_size) + nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in)) + nn.init.zeros_(self.conv_proj.bias) + + @classmethod + def _check_embedding_dim(cls, embedding_dim: int, num_heads: int) -> None: + """ + Checks consistency between embedding dimension and number of heads. + """ + if embedding_dim % num_heads != 0: + raise ValueError( + f"embedding_dim should be divisible by num_heads. Got embedding_dim={embedding_dim} " + f" and num_heads={num_heads}" + ) + + @classmethod + def _check_patch_size( + cls, img_size: Tuple[int, ...], patch_size: Tuple[int, ...] + ) -> None: + """ + Checks consistency between image size and patch size. 
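For the docstring example above (in_shape=(3, 60, 64), patch_size=4), the sequence length computed in `__init__` works out as follows; a quick arithmetic check, not part of the patch:

    import numpy as np

    grid = np.array((60, 64)) // np.array((4, 4))   # (15, 16) patches per axis
    assert int(np.prod(grid)) == 240                # plus 1 class token when num_outputs is set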
+ """ + for i, p in zip(img_size, patch_size): + if i % p != 0: + raise ValueError( + f"img_size should be divisible by patch_size. Got img_size={img_size} " + f" and patch_size={patch_size}" + ) + + +class SOTAViT(str, Enum): + """Supported ViT networks.""" + + B_16 = "ViT-B/16" + B_32 = "ViT-B/32" + L_16 = "ViT-L/16" + L_32 = "ViT-L/32" + + +def get_vit( + name: Union[str, SOTAViT], + num_outputs: Optional[int], + output_act: ActivationParameters = None, + pretrained: bool = False, +) -> ViT: + """ + To get a Vision Transformer implemented in the [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale] + (https://arxiv.org/pdf/2010.11929) paper. + + Only the last fully connected layer will be changed to match `num_outputs`. + + The user can also use the pretrained models from `torchvision`. Note that the last fully connected layer will not + used pretrained weights, as it is task specific. + + .. warning:: `ViT-B/16`, `ViT-B/32`, `ViT-L/16` and `ViT-L/32` work with 2D images of size (224, 224), with 3 channels. + + Parameters + ---------- + model : Union[str, SOTAViT] + The name of the Vision Transformer. Available networks are `ViT-B/16`, `ViT-B/32`, `ViT-L/16` and `ViT-L/32`. + num_outputs : Optional[int] + number of output variables after the last linear layer.\n + If None, the features before the last fully connected layer will be returned. + output_act : ActivationParameters (optional, default=None) + if `num_outputs` is not None, a potential activation layer applied to the outputs of the network, + and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them. + pretrained : bool (optional, default=False) + whether to use pretrained weights. The pretrained weights used are the default ones from [torchvision](https:// + pytorch.org/vision/main/models/vision_transformer.html). + + Returns + ------- + ViT + The network, with potentially pretrained weights. 
+ """ + name = SOTAViT(name) + if name == SOTAViT.B_16: + in_shape = (3, 224, 224) + patch_size = 16 + embedding_dim = 768 + mlp_dim = 3072 + num_layers = 12 + num_heads = 12 + model_url = ViT_B_16_Weights.DEFAULT.url + elif name == SOTAViT.B_32: + in_shape = (3, 224, 224) + patch_size = 32 + embedding_dim = 768 + mlp_dim = 3072 + num_layers = 12 + num_heads = 12 + model_url = ViT_B_32_Weights.DEFAULT.url + elif name == SOTAViT.L_16: + in_shape = (3, 224, 224) + patch_size = 16 + embedding_dim = 1024 + mlp_dim = 4096 + num_layers = 24 + num_heads = 16 + model_url = ViT_L_16_Weights.DEFAULT.url + elif name == SOTAViT.L_32: + in_shape = (3, 224, 224) + patch_size = 32 + embedding_dim = 1024 + mlp_dim = 4096 + num_layers = 24 + num_heads = 16 + model_url = ViT_L_32_Weights.DEFAULT.url + + # pylint: disable=possibly-used-before-assignment + vit = ViT( + in_shape=in_shape, + patch_size=patch_size, + num_outputs=num_outputs, + embedding_dim=embedding_dim, + mlp_dim=mlp_dim, + num_heads=num_heads, + num_layers=num_layers, + output_act=output_act, + ) + + if pretrained: + pretrained_dict = load_state_dict_from_url(model_url, progress=True) + if num_outputs is None: + del pretrained_dict["class_token"] + pretrained_dict["encoder.pos_embedding"] = pretrained_dict[ + "encoder.pos_embedding" + ][:, 1:] # remove class token position embedding + fc_layers = deepcopy(vit.fc) + vit.fc = None + vit.load_state_dict(_state_dict_adapter(pretrained_dict)) + vit.fc = fc_layers + + return vit + + +def _state_dict_adapter(state_dict: Mapping[str, Any]) -> Mapping[str, Any]: + """ + A mapping between torchvision's layer names and ours. + """ + state_dict = {k: v for k, v in state_dict.items() if "heads" not in k} + + mappings = [ + ("ln_", "norm"), + ("ln", "norm"), + (r"encoder_layer_(\d+)", r"\1"), + ] + + for key in list(state_dict.keys()): + new_key = key + for transform in mappings: + new_key = re.sub(transform[0], transform[1], new_key) + state_dict[new_key] = state_dict.pop(key) + + return state_dict diff --git a/clinicadl/utils/enum.py b/clinicadl/utils/enum.py index 3e9031534..4e5c7721c 100644 --- a/clinicadl/utils/enum.py +++ b/clinicadl/utils/enum.py @@ -1,6 +1,17 @@ from enum import Enum +class CaseInsensitiveEnum(str, Enum): + @classmethod + def _missing_(cls, value): + if isinstance(value, str): + value = value.lower() + for member in cls: + if member.lower() == value: + return member + return None + + class BaseEnum(Enum): """Base Enum object that will print valid inputs if the value passed is not valid.""" diff --git a/tests/unittests/monai_networks/config/__init__.py b/tests/unittests/monai_networks/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unittests/monai_networks/config/test_autoencoder.py b/tests/unittests/monai_networks/config/test_autoencoder.py deleted file mode 100644 index 707695434..000000000 --- a/tests/unittests/monai_networks/config/test_autoencoder.py +++ /dev/null @@ -1,171 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.autoencoder import ( - AutoEncoderConfig, - VarAutoEncoderConfig, -) - - -@pytest.fixture -def dummy_arguments(): - args = { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "channels": [2, 4], - "latent_size": 16, - } - return args - - -@pytest.fixture( - params=[ - {"in_shape": (1, 10, 10), "strides": (1, 1), "dropout": 1.1}, - {"in_shape": (1, 10, 10), "strides": (1, 1), "kernel_size": 4}, - {"in_shape": (1, 10, 10), "strides": (1, 1), "kernel_size": 
(3,)}, - {"in_shape": (1, 10, 10), "strides": (1, 1), "kernel_size": (3, 3, 3)}, - {"in_shape": (1, 10, 10), "strides": (1, 1), "up_kernel_size": 4}, - {"in_shape": (1, 10, 10), "strides": (1, 1), "up_kernel_size": (3,)}, - {"in_shape": (1, 10, 10), "strides": (1, 1), "up_kernel_size": (3, 3, 3)}, - { - "in_shape": (1, 10, 10), - "strides": (1, 1), - "inter_channels": (2, 2), - "inter_dilations": (2,), - }, - {"in_shape": (1, 10, 10), "strides": (1, 1), "inter_dilations": (2, 2)}, - {"in_shape": (1, 10, 10), "strides": (1, 1), "padding": (1, 1, 1)}, - {"in_shape": (1, 10, 10), "strides": (1, 2, 3)}, - {"in_shape": (1, 10, 10), "strides": (1, (1, 2, 3))}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -@pytest.fixture( - params=[ - {"in_shape": (1,), "strides": (1, 1)}, - {"in_shape": (1, 10), "strides": (1, 1)}, - ] -) -def bad_inputs_vae(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - AutoEncoderConfig(**bad_inputs) - with pytest.raises(ValidationError): - VarAutoEncoderConfig(**bad_inputs) - - -def test_fails_validations_vae(bad_inputs_vae): - with pytest.raises(ValidationError): - VarAutoEncoderConfig(**bad_inputs_vae) - - -@pytest.fixture( - params=[ - { - "in_shape": (1, 10, 10), - "strides": (1, 1), - "dropout": 0.5, - "kernel_size": 5, - "inter_channels": (2, 2), - "inter_dilations": (3, 3), - "padding": (2, 2), - }, - { - "in_shape": (1, 10, 10), - "strides": ((1, 2), 1), - "kernel_size": (3, 3), - "padding": 2, - "up_kernel_size": 5, - }, - ] -) -def good_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - AutoEncoderConfig(**good_inputs) - VarAutoEncoderConfig(**good_inputs) - - -def test_AutoEncoderConfig(): - config = AutoEncoderConfig( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=[2, 4], - strides=[1, 1], - kernel_size=(3, 5), - up_kernel_size=(3, 3), - num_res_units=1, - inter_channels=(2, 2), - inter_dilations=(3, 3), - num_inter_units=1, - norm=("BATCh", {"eps": 0.1}), - dropout=0.1, - bias=False, - padding=1, - ) - assert config.network == "AutoEncoder" - assert config.spatial_dims == 2 - assert config.in_channels == 1 - assert config.out_channels == 1 - assert config.channels == (2, 4) - assert config.strides == (1, 1) - assert config.kernel_size == (3, 5) - assert config.num_res_units == 1 - assert config.inter_channels == (2, 2) - assert config.inter_dilations == (3, 3) - assert config.num_inter_units == 1 - assert config.norm == ("batch", {"eps": 0.1}) - assert config.act == "DefaultFromLibrary" - assert config.dropout == 0.1 - assert not config.bias - assert config.padding == 1 - - -def test_VarAutoEncoderConfig(): - config = VarAutoEncoderConfig( - spatial_dims=2, - in_shape=(1, 10, 10), - out_channels=1, - latent_size=16, - channels=[2, 4], - strides=[1, 1], - kernel_size=(3, 5), - up_kernel_size=(3, 3), - num_res_units=1, - inter_channels=(2, 2), - inter_dilations=(3, 3), - num_inter_units=1, - norm=("BATCh", {"eps": 0.1}), - dropout=0.1, - bias=False, - padding=1, - use_sigmoid=False, - ) - assert config.network == "VarAutoEncoder" - assert config.spatial_dims == 2 - assert config.in_shape == (1, 10, 10) - assert config.out_channels == 1 - assert config.latent_size == 16 - assert config.channels == (2, 4) - assert config.strides == (1, 1) - assert config.kernel_size == (3, 5) - assert config.num_res_units == 
1 - assert config.inter_channels == (2, 2) - assert config.inter_dilations == (3, 3) - assert config.num_inter_units == 1 - assert config.norm == ("batch", {"eps": 0.1}) - assert config.act == "DefaultFromLibrary" - assert config.dropout == 0.1 - assert not config.bias - assert config.padding == 1 - assert not config.use_sigmoid diff --git a/tests/unittests/monai_networks/config/test_classifier.py b/tests/unittests/monai_networks/config/test_classifier.py deleted file mode 100644 index f63b774d5..000000000 --- a/tests/unittests/monai_networks/config/test_classifier.py +++ /dev/null @@ -1,132 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.classifier import ( - ClassifierConfig, - CriticConfig, - DiscriminatorConfig, -) - - -@pytest.fixture -def dummy_arguments(): - args = { - "classes": 2, - "channels": [2, 4], - } - return args - - -@pytest.fixture( - params=[ - {"in_shape": (3,), "strides": (1, 1)}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "dropout": 1.1}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "kernel_size": 4}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "kernel_size": (3,)}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "kernel_size": (3, 3, 3)}, - {"in_shape": (1, 3, 3), "strides": (1, 2, 3)}, - {"in_shape": (1, 3, 3), "strides": (1, (1, 2, 3))}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - ClassifierConfig(**bad_inputs) - with pytest.raises(ValidationError): - CriticConfig(**bad_inputs) - with pytest.raises(ValidationError): - DiscriminatorConfig(**bad_inputs) - - -@pytest.fixture( - params=[ - {"in_shape": (1, 3, 3), "strides": (1, 1), "dropout": 0.5, "kernel_size": 5}, - {"in_shape": (1, 3, 3), "strides": ((1, 2), 1), "kernel_size": (3, 3)}, - ] -) -def good_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - ClassifierConfig(**good_inputs) - CriticConfig(**good_inputs) - DiscriminatorConfig(**good_inputs) - - -def test_ClassifierConfig(): - config = ClassifierConfig( - in_shape=(1, 3, 3), - classes=2, - channels=[2, 4], - strides=[1, 1], - kernel_size=(3, 5), - num_res_units=1, - act=("ELU", {"alpha": 2.0}), - dropout=0.1, - bias=False, - last_act=None, - ) - assert config.network == "Classifier" - assert config.in_shape == (1, 3, 3) - assert config.classes == 2 - assert config.channels == (2, 4) - assert config.strides == (1, 1) - assert config.kernel_size == (3, 5) - assert config.num_res_units == 1 - assert config.act == ("elu", {"alpha": 2.0}) - assert config.norm == "DefaultFromLibrary" - assert config.dropout == 0.1 - assert not config.bias - assert config.last_act is None - - -def test_CriticConfig(): - config = CriticConfig( - in_shape=(1, 3, 3), - channels=[2, 4], - strides=[1, 1], - kernel_size=(3, 5), - num_res_units=1, - act=("ELU", {"alpha": 2.0}), - dropout=0.1, - bias=False, - ) - assert config.network == "Critic" - assert config.in_shape == (1, 3, 3) - assert config.channels == (2, 4) - assert config.strides == (1, 1) - assert config.kernel_size == (3, 5) - assert config.num_res_units == 1 - assert config.act == ("elu", {"alpha": 2.0}) - assert config.norm == "DefaultFromLibrary" - assert config.dropout == 0.1 - assert not config.bias - - -def test_DiscriminatorConfig(): - config = DiscriminatorConfig( - in_shape=(1, 3, 3), - channels=[2, 4], - strides=[1, 1], - kernel_size=(3, 5), - 
num_res_units=1, - act=("ELU", {"alpha": 2.0}), - dropout=0.1, - bias=False, - last_act=("eLu", {"alpha": 0.5}), - ) - assert config.network == "Discriminator" - assert config.in_shape == (1, 3, 3) - assert config.channels == (2, 4) - assert config.strides == (1, 1) - assert config.kernel_size == (3, 5) - assert config.num_res_units == 1 - assert config.act == ("elu", {"alpha": 2.0}) - assert config.norm == "DefaultFromLibrary" - assert config.dropout == 0.1 - assert not config.bias - assert config.last_act == ("elu", {"alpha": 0.5}) diff --git a/tests/unittests/monai_networks/config/test_config.py b/tests/unittests/monai_networks/config/test_config.py new file mode 100644 index 000000000..9da6756f5 --- /dev/null +++ b/tests/unittests/monai_networks/config/test_config.py @@ -0,0 +1,232 @@ +import pytest + +from clinicadl.monai_networks.config.densenet import ( + DenseNet121Config, + DenseNet161Config, + DenseNet169Config, + DenseNet201Config, +) +from clinicadl.monai_networks.config.resnet import ( + ResNet18Config, + ResNet34Config, + ResNet50Config, + ResNet101Config, + ResNet152Config, +) +from clinicadl.monai_networks.config.senet import ( + SEResNet50Config, + SEResNet101Config, + SEResNet152Config, +) +from clinicadl.monai_networks.config.vit import ( + ViTB16Config, + ViTB32Config, + ViTL16Config, + ViTL32Config, +) + + +@pytest.mark.parametrize( + "config_class", + [DenseNet121Config, DenseNet161Config, DenseNet169Config, DenseNet201Config], +) +def test_sota_densenet_config(config_class): + config = config_class(pretrained=True, num_outputs=None) + + assert config.num_outputs is None + assert config.pretrained + assert config.output_act == "DefaultFromLibrary" + assert config._type == "sota-DenseNet" + + +@pytest.mark.parametrize( + "config_class", + [ResNet18Config, ResNet34Config, ResNet50Config, ResNet101Config, ResNet152Config], +) +def test_sota_resnet_config(config_class): + config = config_class(pretrained=False, num_outputs=None) + + assert config.num_outputs is None + assert not config.pretrained + assert config.output_act == "DefaultFromLibrary" + assert config._type == "sota-ResNet" + + +@pytest.mark.parametrize( + "config_class", [SEResNet50Config, SEResNet101Config, SEResNet152Config] +) +def test_sota_senet_config(config_class): + config = config_class(output_act="relu", num_outputs=1) + + assert config.num_outputs == 1 + assert config.pretrained == "DefaultFromLibrary" + assert config.output_act == "relu" + assert config._type == "sota-SEResNet" + + +@pytest.mark.parametrize( + "config_class", [ViTB16Config, ViTB32Config, ViTL16Config, ViTL32Config] +) +def test_sota_vit_config(config_class): + config = config_class(output_act="relu", num_outputs=1) + + assert config.num_outputs == 1 + assert config.pretrained == "DefaultFromLibrary" + assert config.output_act == "relu" + assert config._type == "sota-ViT" + + +def test_autoencoder_config(): + from clinicadl.monai_networks.config.autoencoder import AutoEncoderConfig + + config = AutoEncoderConfig( + in_shape=(1, 10, 10), + latent_size=1, + conv_args={"channels": [1]}, + output_act="softmax", + ) + assert config.in_shape == (1, 10, 10) + assert config.conv_args.channels == [1] + assert config.output_act == "softmax" + assert config.out_channels == "DefaultFromLibrary" + + +def test_vae_config(): + from clinicadl.monai_networks.config.autoencoder import VAEConfig + + config = VAEConfig( + in_shape=(1, 10), + latent_size=1, + conv_args={"channels": [1], "adn_ordering": "NA"}, + output_act=("elu", {"alpha": 0.1}), + ) + 
assert config.in_shape == (1, 10) + assert config.conv_args.adn_ordering == "NA" + assert config.output_act == ("elu", {"alpha": 0.1}) + assert config.mlp_args == "DefaultFromLibrary" + + +def test_cnn_config(): + from clinicadl.monai_networks.config.cnn import CNNConfig + + config = CNNConfig( + in_shape=(2, 10, 10, 10), num_outputs=1, conv_args={"channels": [1]} + ) + assert config.in_shape == (2, 10, 10, 10) + assert config.conv_args.channels == [1] + assert config.mlp_args == "DefaultFromLibrary" + + +def test_generator_config(): + from clinicadl.monai_networks.config.generator import GeneratorConfig + + config = GeneratorConfig( + start_shape=(2, 10, 10), latent_size=2, conv_args={"channels": [1]} + ) + assert config.start_shape == (2, 10, 10) + assert config.conv_args.channels == [1] + assert config.mlp_args == "DefaultFromLibrary" + + +def test_conv_decoder_config(): + from clinicadl.monai_networks.config.conv_decoder import ConvDecoderConfig + + config = ConvDecoderConfig( + in_channels=1, spatial_dims=2, channels=[1, 2], kernel_size=(3, 4) + ) + assert config.in_channels == 1 + assert config.kernel_size == (3, 4) + assert config.stride == "DefaultFromLibrary" + + +def test_conv_encoder_config(): + from clinicadl.monai_networks.config.conv_encoder import ConvEncoderConfig + + config = ConvEncoderConfig( + in_channels=1, spatial_dims=2, channels=[1, 2], kernel_size=[(3, 4), (4, 5)] + ) + assert config.in_channels == 1 + assert config.kernel_size == [(3, 4), (4, 5)] + assert config.padding == "DefaultFromLibrary" + + +def test_mlp_config(): + from clinicadl.monai_networks.config.mlp import MLPConfig + + config = MLPConfig( + in_channels=1, out_channels=1, hidden_channels=[2, 3], dropout=0.1 + ) + assert config.in_channels == 1 + assert config.dropout == 0.1 + assert config.act == "DefaultFromLibrary" + + +def test_resnet_config(): + from clinicadl.monai_networks.config.resnet import ResNetConfig + + config = ResNetConfig( + spatial_dims=1, in_channels=1, num_outputs=None, block_type="bottleneck" + ) + assert config.num_outputs is None + assert config.block_type == "bottleneck" + assert config.bottleneck_reduction == "DefaultFromLibrary" + + +def test_seresnet_config(): + from clinicadl.monai_networks.config.senet import SEResNetConfig + + config = SEResNetConfig( + spatial_dims=1, + in_channels=1, + num_outputs=None, + block_type="bottleneck", + se_reduction=2, + ) + assert config.num_outputs is None + assert config.block_type == "bottleneck" + assert config.se_reduction == 2 + assert config.bottleneck_reduction == "DefaultFromLibrary" + + +def test_densenet_config(): + from clinicadl.monai_networks.config.densenet import DenseNetConfig + + config = DenseNetConfig( + spatial_dims=1, in_channels=1, num_outputs=2, n_dense_layers=(1, 2) + ) + assert config.num_outputs == 2 + assert config.n_dense_layers == (1, 2) + assert config.growth_rate == "DefaultFromLibrary" + + +def test_vit_config(): + from clinicadl.monai_networks.config.vit import ViTConfig + + config = ViTConfig(in_shape=(1, 10), patch_size=2, num_outputs=1, embedding_dim=42) + assert config.num_outputs == 1 + assert config.embedding_dim == 42 + assert config.mlp_dim == "DefaultFromLibrary" + + +def test_unet_config(): + from clinicadl.monai_networks.config.unet import UNetConfig + + config = UNetConfig(spatial_dims=1, in_channels=1, out_channels=1, channels=(4, 8)) + assert config.out_channels == 1 + assert config.channels == (4, 8) + assert config.output_act == "DefaultFromLibrary" + + +def test_att_unet_config(): + from 
clinicadl.monai_networks.config.unet import AttentionUNetConfig + + config = AttentionUNetConfig( + spatial_dims=1, + in_channels=1, + out_channels=1, + channels=(4, 8), + output_act="softmax", + ) + assert config.spatial_dims == 1 + assert config.output_act == "softmax" + assert config.dropout == "DefaultFromLibrary" diff --git a/tests/unittests/monai_networks/config/test_densenet.py b/tests/unittests/monai_networks/config/test_densenet.py deleted file mode 100644 index a18b86f09..000000000 --- a/tests/unittests/monai_networks/config/test_densenet.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.densenet import DenseNetConfig - - -@pytest.fixture -def dummy_arguments(): - args = { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - } - return args - - -def test_fails_validations(dummy_arguments): - with pytest.raises(ValidationError): - DenseNetConfig(**{**dummy_arguments, **{"dropout_prob": 1.1}}) - - -def test_passes_validations(dummy_arguments): - DenseNetConfig(**{**dummy_arguments, **{"dropout_prob": 0.1}}) - - -def test_DenseNetConfig(): - config = DenseNetConfig( - spatial_dims=2, - in_channels=1, - out_channels=1, - init_features=16, - growth_rate=2, - block_config=(3, 5), - bn_size=1, - norm=("batch", {"eps": 0.5}), - dropout_prob=0.1, - ) - assert config.network == "DenseNet" - assert config.spatial_dims == 2 - assert config.in_channels == 1 - assert config.out_channels == 1 - assert config.init_features == 16 - assert config.growth_rate == 2 - assert config.block_config == (3, 5) - assert config.bn_size == 1 - assert config.norm == ("batch", {"eps": 0.5}) - assert config.act == "DefaultFromLibrary" - assert config.dropout_prob diff --git a/tests/unittests/monai_networks/config/test_factory.py b/tests/unittests/monai_networks/config/test_factory.py index 07c96e2a9..9dcd7fdc1 100644 --- a/tests/unittests/monai_networks/config/test_factory.py +++ b/tests/unittests/monai_networks/config/test_factory.py @@ -9,9 +9,9 @@ def test_create_training_config(): config = config_class( spatial_dims=1, in_channels=2, - out_channels=3, + num_outputs=None, ) - assert config.network == "DenseNet" + assert config.name == "DenseNet" assert config.spatial_dims == 1 assert config.in_channels == 2 - assert config.out_channels == 3 + assert config.num_outputs is None diff --git a/tests/unittests/monai_networks/config/test_fcn.py b/tests/unittests/monai_networks/config/test_fcn.py deleted file mode 100644 index b7991368e..000000000 --- a/tests/unittests/monai_networks/config/test_fcn.py +++ /dev/null @@ -1,97 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.fcn import ( - FullyConnectedNetConfig, - VarFullyConnectedNetConfig, -) - - -@pytest.fixture -def dummy_arguments(): - args = { - "in_channels": 5, - "out_channels": 1, - "hidden_channels": [3, 2], - "latent_size": 16, - "encode_channels": [2, 3], - "decode_channels": [3, 2], - } - return args - - -@pytest.fixture( - params=[ - {"dropout": 1.1}, - {"adn_ordering": "NDB"}, - {"adn_ordering": "NND"}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - FullyConnectedNetConfig(**bad_inputs) - with pytest.raises(ValidationError): - VarFullyConnectedNetConfig(**bad_inputs) - - -@pytest.fixture( - params=[ - {"dropout": 0.5, "adn_ordering": "DAN"}, - {"adn_ordering": "AN"}, - ] -) -def 
good_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - FullyConnectedNetConfig(**good_inputs) - VarFullyConnectedNetConfig(**good_inputs) - - -def test_FullyConnectedNetConfig(): - config = FullyConnectedNetConfig( - in_channels=5, - out_channels=1, - hidden_channels=[3, 2], - dropout=None, - act="prelu", - bias=False, - adn_ordering="ADN", - ) - assert config.network == "FullyConnectedNet" - assert config.in_channels == 5 - assert config.out_channels == 1 - assert config.hidden_channels == (3, 2) - assert config.dropout is None - assert config.act == "prelu" - assert not config.bias - assert config.adn_ordering == "ADN" - - -def test_VarFullyConnectedNetConfig(): - config = VarFullyConnectedNetConfig( - in_channels=5, - out_channels=1, - latent_size=16, - encode_channels=[2, 3], - decode_channels=[3, 2], - dropout=0.1, - act="prelu", - bias=False, - adn_ordering="ADN", - ) - assert config.network == "VarFullyConnectedNet" - assert config.in_channels == 5 - assert config.out_channels == 1 - assert config.latent_size == 16 - assert config.encode_channels == (2, 3) - assert config.decode_channels == (3, 2) - assert config.dropout == 0.1 - assert config.act == "prelu" - assert not config.bias - assert config.adn_ordering == "ADN" diff --git a/tests/unittests/monai_networks/config/test_generator.py b/tests/unittests/monai_networks/config/test_generator.py deleted file mode 100644 index 9ea1cd442..000000000 --- a/tests/unittests/monai_networks/config/test_generator.py +++ /dev/null @@ -1,72 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.generator import GeneratorConfig - - -@pytest.fixture -def dummy_arguments(): - args = {"latent_shape": (5,), "channels": (2, 4)} - return args - - -@pytest.fixture( - params=[ - {"start_shape": (3,), "strides": (1, 1)}, - {"start_shape": (1, 3), "strides": (1, 1), "dropout": 1.1}, - {"start_shape": (1, 3), "strides": (1, 1), "kernel_size": 4}, - {"start_shape": (1, 3), "strides": (1, 1), "kernel_size": (3, 3)}, - {"start_shape": (1, 3), "strides": (1, 2, 3)}, - {"start_shape": (1, 3), "strides": (1, (1, 2))}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - GeneratorConfig(**bad_inputs) - - -@pytest.fixture( - params=[ - {"start_shape": (1, 3), "strides": (1, 1), "dropout": 0.5, "kernel_size": 5}, - { - "start_shape": (1, 3, 3, 3), - "strides": ((1, 2, 3), 1), - "kernel_size": (3, 3, 3), - }, - ] -) -def good_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - GeneratorConfig(**good_inputs) - - -def test_GeneratorConfig(): - config = GeneratorConfig( - latent_shape=(3,), - start_shape=(1, 3), - channels=[2, 4], - strides=[1, 1], - kernel_size=(3,), - num_res_units=1, - act="SIGMOID", - dropout=0.1, - bias=False, - ) - assert config.network == "Generator" - assert config.latent_shape == (3,) - assert config.start_shape == (1, 3) - assert config.channels == (2, 4) - assert config.strides == (1, 1) - assert config.kernel_size == (3,) - assert config.num_res_units == 1 - assert config.act == "sigmoid" - assert config.norm == "DefaultFromLibrary" - assert config.dropout == 0.1 - assert not config.bias diff --git a/tests/unittests/monai_networks/config/test_regressor.py b/tests/unittests/monai_networks/config/test_regressor.py 
deleted file mode 100644 index 920464cc2..000000000 --- a/tests/unittests/monai_networks/config/test_regressor.py +++ /dev/null @@ -1,72 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.regressor import RegressorConfig - - -@pytest.fixture -def dummy_arguments(): - args = { - "out_shape": (1,), - "channels": [2, 4], - } - return args - - -@pytest.fixture( - params=[ - {"in_shape": (3,), "strides": (1, 1)}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "dropout": 1.1}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "kernel_size": 4}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "kernel_size": (3,)}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "kernel_size": (3, 3, 3)}, - {"in_shape": (1, 3, 3), "strides": (1, 2, 3)}, - {"in_shape": (1, 3, 3), "strides": (1, (1, 2, 3))}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - RegressorConfig(**bad_inputs) - - -@pytest.fixture( - params=[ - {"in_shape": (1, 3, 3), "strides": (1, 1), "dropout": 0.5, "kernel_size": 5}, - {"in_shape": (1, 3, 3), "strides": ((1, 2), 1), "kernel_size": (3, 3)}, - ] -) -def good_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - RegressorConfig(**good_inputs) - - -def test_RegressorConfig(): - config = RegressorConfig( - in_shape=(1, 3, 3), - out_shape=(1,), - channels=[2, 4], - strides=[1, 1], - kernel_size=(3, 5), - num_res_units=1, - act=("ELU", {"alpha": 2.0}), - dropout=0.1, - bias=False, - ) - assert config.network == "Regressor" - assert config.in_shape == (1, 3, 3) - assert config.out_shape == (1,) - assert config.channels == (2, 4) - assert config.strides == (1, 1) - assert config.kernel_size == (3, 5) - assert config.num_res_units == 1 - assert config.act == ("elu", {"alpha": 2.0}) - assert config.norm == "DefaultFromLibrary" - assert config.dropout == 0.1 - assert not config.bias diff --git a/tests/unittests/monai_networks/config/test_resnet.py b/tests/unittests/monai_networks/config/test_resnet.py deleted file mode 100644 index b238a3c93..000000000 --- a/tests/unittests/monai_networks/config/test_resnet.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.resnet import ResNetConfig - - -@pytest.fixture -def dummy_arguments(): - args = { - "block": "basic", - "layers": (2, 2, 2, 2), - } - return args - - -@pytest.fixture( - params=[ - {"block_inplanes": (2, 4, 8)}, - {"block_inplanes": (2, 4, 8, 16), "conv1_t_size": (3, 3)}, - {"block_inplanes": (2, 4, 8, 16), "conv1_t_stride": (3, 3)}, - {"block_inplanes": (2, 4, 8, 16), "shortcut_type": "C"}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - ResNetConfig(**bad_inputs) - - -@pytest.fixture( - params=[ - { - "block_inplanes": (2, 4, 8, 16), - "conv1_t_size": (3, 3, 3), - "conv1_t_stride": (3, 3, 3), - "shortcut_type": "B", - }, - {"block_inplanes": (2, 4, 8, 16), "conv1_t_size": 3, "conv1_t_stride": 3}, - ] -) -def good_inputs(request: pytest.FixtureRequest, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - ResNetConfig(**good_inputs) - - -def test_ResNetConfig(): - config = ResNetConfig( - block="bottleneck", - layers=(2, 2, 2, 2), - 
block_inplanes=(2, 4, 8, 16), - spatial_dims=3, - n_input_channels=3, - conv1_t_size=3, - conv1_t_stride=4, - no_max_pool=True, - shortcut_type="A", - widen_factor=0.8, - num_classes=3, - feed_forward=False, - bias_downsample=False, - act=("relu", {"inplace": False}), - ) - assert config.network == "ResNet" - assert config.block == "bottleneck" - assert config.layers == (2, 2, 2, 2) - assert config.block_inplanes == (2, 4, 8, 16) - assert config.spatial_dims == 3 - assert config.n_input_channels == 3 - assert config.conv1_t_size == 3 - assert config.conv1_t_stride == 4 - assert config.no_max_pool - assert config.shortcut_type == "A" - assert config.widen_factor == 0.8 - assert config.num_classes == 3 - assert not config.feed_forward - assert not config.bias_downsample - assert config.act == ("relu", {"inplace": False}) diff --git a/tests/unittests/monai_networks/config/test_resnet_features.py b/tests/unittests/monai_networks/config/test_resnet_features.py deleted file mode 100644 index 9f6131974..000000000 --- a/tests/unittests/monai_networks/config/test_resnet_features.py +++ /dev/null @@ -1,56 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.resnet import ResNetFeaturesConfig - - -@pytest.fixture( - params=[ - {"model_name": "abc"}, - {"model_name": "resnet18", "pretrained": True, "spatial_dims": 2}, - {"model_name": "resnet18", "pretrained": True, "in_channels": 2}, - { - "model_name": "resnet18", - "in_channels": 2, - }, # pretrained should be set to False - {"model_name": "resnet18", "spatial_dims": 2}, - ] -) -def bad_inputs(request: pytest.FixtureRequest): - return request.param - - -def test_fails_validations(bad_inputs: dict): - with pytest.raises(ValidationError): - ResNetFeaturesConfig(**bad_inputs) - - -@pytest.fixture( - params=[ - {"model_name": "resnet18", "pretrained": True, "spatial_dims": 3}, - {"model_name": "resnet18", "pretrained": True, "in_channels": 1}, - {"model_name": "resnet18", "pretrained": True}, - {"model_name": "resnet18", "spatial_dims": 3}, - {"model_name": "resnet18", "in_channels": 1}, - ] -) -def good_inputs(request: pytest.FixtureRequest): - return {**request.param} - - -def test_passes_validations(good_inputs: dict): - ResNetFeaturesConfig(**good_inputs) - - -def test_ResNetFeaturesConfig(): - config = ResNetFeaturesConfig( - model_name="resnet200", - pretrained=False, - spatial_dims=2, - in_channels=2, - ) - assert config.network == "ResNetFeatures" - assert config.model_name == "resnet200" - assert not config.pretrained - assert config.spatial_dims == 2 - assert config.in_channels == 2 diff --git a/tests/unittests/monai_networks/config/test_segresnet.py b/tests/unittests/monai_networks/config/test_segresnet.py deleted file mode 100644 index 44b946d49..000000000 --- a/tests/unittests/monai_networks/config/test_segresnet.py +++ /dev/null @@ -1,41 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.resnet import SegResNetConfig - - -def test_fails_validations(): - with pytest.raises(ValidationError): - SegResNetConfig(dropout_prob=1.1) - - -def test_passes_validations(): - SegResNetConfig(dropout_prob=0.5) - - -def test_SegResNetConfig(): - config = SegResNetConfig( - spatial_dims=2, - init_filters=3, - in_channels=1, - out_channels=1, - dropout_prob=0.1, - act=("ELU", {"inplace": False}), - norm=("group", {"num_groups": 4}), - use_conv_final=False, - blocks_down=[1, 2, 3], - blocks_up=[3, 2, 1], - upsample_mode="pixelshuffle", - ) - assert 
config.network == "SegResNet" - assert config.spatial_dims == 2 - assert config.init_filters == 3 - assert config.in_channels == 1 - assert config.out_channels == 1 - assert config.dropout_prob == 0.1 - assert config.act == ("elu", {"inplace": False}) - assert config.norm == ("group", {"num_groups": 4}) - assert not config.use_conv_final - assert config.blocks_down == (1, 2, 3) - assert config.blocks_up == (3, 2, 1) - assert config.upsample_mode == "pixelshuffle" diff --git a/tests/unittests/monai_networks/config/test_unet.py b/tests/unittests/monai_networks/config/test_unet.py deleted file mode 100644 index d331e0a14..000000000 --- a/tests/unittests/monai_networks/config/test_unet.py +++ /dev/null @@ -1,133 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.unet import AttentionUnetConfig, UNetConfig - - -@pytest.fixture -def dummy_arguments(): - args = { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - } - return args - - -@pytest.fixture( - params=[ - {"strides": (1, 1), "channels": (2, 4, 8), "adn_ordering": "NDB"}, - {"strides": (1, 1), "channels": (2, 4, 8), "adn_ordering": "NND"}, - {"strides": (1, 1), "channels": (2, 4, 8), "dropout": 1.1}, - {"strides": (1, 1), "channels": (2, 4, 8), "kernel_size": 4}, - {"strides": (1, 1), "channels": (2, 4, 8), "kernel_size": (3,)}, - {"strides": (1, 1), "channels": (2, 4, 8), "kernel_size": (3, 3, 3)}, - {"strides": (1, 1), "channels": (2, 4, 8), "up_kernel_size": 4}, - {"strides": (1, 1), "channels": (2, 4, 8), "up_kernel_size": (3,)}, - {"strides": (1, 1), "channels": (2, 4, 8), "up_kernel_size": (3, 3, 3)}, - {"strides": (1, 2, 3), "channels": (2, 4, 8)}, - {"strides": (1, (1, 2, 3)), "channels": (2, 4, 8)}, - {"strides": (), "channels": (2,)}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - UNetConfig(**bad_inputs) - with pytest.raises(ValidationError): - AttentionUnetConfig(**bad_inputs) - - -@pytest.fixture( - params=[ - { - "strides": (1, 1), - "channels": (2, 4, 8), - "adn_ordering": "DAN", - "dropout": 0.5, - "kernel_size": 5, - "up_kernel_size": 5, - }, - { - "strides": ((1, 2),), - "channels": (2, 4), - "adn_ordering": "AN", - "kernel_size": (3, 5), - "up_kernel_size": (3, 5), - }, - ] -) -def good_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - UNetConfig(**good_inputs) - AttentionUnetConfig(**good_inputs) - - -def test_UNetConfig(): - config = UNetConfig( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=[2, 4], - strides=[1], - kernel_size=(3, 5), - up_kernel_size=(3, 3), - num_res_units=1, - act="ElU", - norm=("BATCh", {"eps": 0.1}), - dropout=0.1, - bias=False, - adn_ordering="A", - ) - assert config.network == "UNet" - assert config.spatial_dims == 2 - assert config.in_channels == 1 - assert config.out_channels == 1 - assert config.channels == (2, 4) - assert config.strides == (1,) - assert config.kernel_size == (3, 5) - assert config.up_kernel_size == (3, 3) - assert config.num_res_units == 1 - assert config.act == "elu" - assert config.norm == ("batch", {"eps": 0.1}) - assert config.dropout == 0.1 - assert not config.bias - assert config.adn_ordering == "A" - - -def test_AttentionUnetConfig(): - config = AttentionUnetConfig( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=[2, 4], - strides=[1], - kernel_size=(3, 5), - 
up_kernel_size=(3, 3), - num_res_units=1, - act="ElU", - norm="inSTance", - dropout=0.1, - bias=False, - adn_ordering="DA", - ) - assert config.network == "AttentionUnet" - assert config.spatial_dims == 2 - assert config.in_channels == 1 - assert config.out_channels == 1 - assert config.channels == (2, 4) - assert config.strides == (1,) - assert config.kernel_size == (3, 5) - assert config.up_kernel_size == (3, 3) - assert config.num_res_units == 1 - assert config.act == "elu" - assert config.norm == "instance" - assert config.dropout == 0.1 - assert not config.bias - assert config.adn_ordering == "DA" diff --git a/tests/unittests/monai_networks/config/test_vit.py b/tests/unittests/monai_networks/config/test_vit.py deleted file mode 100644 index 737caf05e..000000000 --- a/tests/unittests/monai_networks/config/test_vit.py +++ /dev/null @@ -1,162 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.vit import ( - ViTAutoEncConfig, - ViTConfig, -) - - -@pytest.fixture -def dummy_arguments(): - args = { - "in_channels": 2, - } - return args - - -@pytest.fixture( - params=[ - {"img_size": (16, 16, 16), "patch_size": (4, 4, 4), "dropout_rate": 1.1}, - {"img_size": (16, 16), "patch_size": 4}, - {"img_size": 16, "patch_size": (4, 4)}, - {"img_size": 16, "patch_size": (4, 4)}, - { - "img_size": (16, 16, 16), - "patch_size": (4, 4, 4), - "hidden_size": 42, - "num_heads": 5, - }, - {"img_size": (16, 16, 16), "patch_size": (4, 4, 4), "num_heads": 5}, - {"img_size": (16, 16, 16), "patch_size": (4, 4, 4), "hidden_size": 42}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -@pytest.fixture( - params=[ - {"img_size": (20, 20, 20), "patch_size": (4, 4, 5)}, - {"img_size": (20, 20, 20), "patch_size": (4, 4, 9)}, - ] -) -def bad_inputs_ae(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - ViTConfig(**bad_inputs) - with pytest.raises(ValidationError): - ViTAutoEncConfig(**bad_inputs) - - -def test_fails_validations_ae(bad_inputs_ae): - with pytest.raises(ValidationError): - ViTAutoEncConfig(**bad_inputs_ae) - - -@pytest.fixture( - params=[ - { - "img_size": (16, 16, 16), - "patch_size": (4, 4, 4), - "dropout_rate": 0.5, - "hidden_size": 42, - "num_heads": 6, - }, - ] -) -def good_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -@pytest.fixture( - params=[ - {"img_size": 10, "patch_size": 3}, - ] -) -def good_inputs_vit(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - ViTConfig(**good_inputs) - ViTAutoEncConfig(**good_inputs) - - -def test_passes_validations_vit(good_inputs_vit): - ViTConfig(**good_inputs_vit) - - -def test_ViTConfig(): - config = ViTConfig( - in_channels=2, - img_size=16, - patch_size=4, - hidden_size=32, - mlp_dim=4, - num_layers=3, - num_heads=4, - proj_type="perceptron", - pos_embed_type="sincos", - classification=True, - num_classes=3, - dropout_rate=0.1, - spatial_dims=3, - post_activation=None, - qkv_bias=True, - ) - assert config.network == "ViT" - assert config.in_channels == 2 - assert config.img_size == 16 - assert config.patch_size == 4 - assert config.hidden_size == 32 - assert config.mlp_dim == 4 - assert config.num_layers == 3 - assert config.num_heads == 4 - assert config.proj_type == "perceptron" - assert config.pos_embed_type == "sincos" - assert 
config.classification
-    assert config.num_classes == 3
-    assert config.dropout_rate == 0.1
-    assert config.spatial_dims == 3
-    assert config.post_activation is None
-    assert config.qkv_bias
-    assert config.save_attn == "DefaultFromLibrary"
-
-
-def test_ViTAutoEncConfig():
-    config = ViTAutoEncConfig(
-        in_channels=2,
-        img_size=16,
-        patch_size=4,
-        out_channels=2,
-        deconv_chns=7,
-        hidden_size=32,
-        mlp_dim=4,
-        num_layers=3,
-        num_heads=4,
-        proj_type="perceptron",
-        pos_embed_type="sincos",
-        dropout_rate=0.1,
-        spatial_dims=3,
-        qkv_bias=True,
-    )
-    assert config.network == "ViTAutoEnc"
-    assert config.in_channels == 2
-    assert config.img_size == 16
-    assert config.patch_size == 4
-    assert config.out_channels == 2
-    assert config.deconv_chns == 7
-    assert config.hidden_size == 32
-    assert config.mlp_dim == 4
-    assert config.num_layers == 3
-    assert config.num_heads == 4
-    assert config.proj_type == "perceptron"
-    assert config.pos_embed_type == "sincos"
-    assert config.dropout_rate == 0.1
-    assert config.spatial_dims == 3
-    assert config.qkv_bias
-    assert config.save_attn == "DefaultFromLibrary"
diff --git a/tests/unittests/monai_networks/nn/__init__.py b/tests/unittests/monai_networks/nn/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unittests/monai_networks/nn/test_att_unet.py b/tests/unittests/monai_networks/nn/test_att_unet.py
new file mode 100644
index 000000000..711f11142
--- /dev/null
+++ b/tests/unittests/monai_networks/nn/test_att_unet.py
@@ -0,0 +1,134 @@
+import pytest
+import torch
+
+from clinicadl.monai_networks.nn import AttentionUNet
+from clinicadl.monai_networks.nn.layers.utils import ActFunction
+
+INPUT_1D = torch.randn(2, 1, 16)
+INPUT_2D = torch.randn(2, 2, 32, 64)
+INPUT_3D = torch.randn(2, 3, 16, 32, 8)
+
+
+@pytest.mark.parametrize(
+    "input_tensor,out_channels,channels,act,output_act,dropout,error",
+    [
+        (INPUT_1D, 1, (2, 3, 4), "relu", "sigmoid", None, False),
+        (INPUT_2D, 1, (2, 4, 5), "relu", None, 0.0, False),
+        (INPUT_3D, 2, (2, 3), None, ("softmax", {"dim": 1}), 0.1, False),
+        (
+            INPUT_3D,
+            2,
+            (2,),
+            None,
+            ("softmax", {"dim": 1}),
+            0.1,
+            True,
+        ),  # channels length is less than 2
+    ],
+)
+def test_attentionunet(
+    input_tensor, out_channels, channels, act, output_act, dropout, error
+):
+    batch_size, in_channels, *img_size = input_tensor.shape
+    spatial_dims = len(img_size)
+    if error:
+        with pytest.raises(ValueError):
+            AttentionUNet(
+                spatial_dims=spatial_dims,
+                in_channels=in_channels,
+                out_channels=out_channels,
+                channels=channels,
+                act=act,
+                output_act=output_act,
+                dropout=dropout,
+            )
+    else:
+        net = AttentionUNet(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            channels=channels,
+            act=act,
+            output_act=output_act,
+            dropout=dropout,
+        )
+
+        out = net(input_tensor)
+        assert out.shape == (batch_size, out_channels, *img_size)
+
+        if output_act:
+            assert net.output_act is not None
+        else:
+            assert net.output_act is None
+
+        assert net.doubleconv[1].conv.out_channels == channels[0]
+        if dropout:
+            assert net.doubleconv[1].adn.D.p == dropout
+        else:
+            with pytest.raises(AttributeError):
+                net.doubleconv[1].adn.D
+
+        for i in range(1, len(channels)):
+            down = getattr(net, f"down{i}").doubleconv
+            up = getattr(net, f"doubleconv{i}")
+            att = getattr(net, f"attention{i}")
+            assert down[0].conv.in_channels == channels[i - 1]
+            assert down[1].conv.out_channels == channels[i]
+            assert att.W_g[0].out_channels == channels[i - 1] // 2
+            assert att.W_x[0].out_channels
== channels[i - 1] // 2 + assert up[0].conv.in_channels == channels[i - 1] * 2 + assert up[1].conv.out_channels == channels[i - 1] + for m in (down, up): + if dropout is not None: + assert m[1].adn.D.p == dropout + else: + with pytest.raises(AttributeError): + m[1].adn.D + with pytest.raises(AttributeError): + down = getattr(net, f"down{i+1}") + with pytest.raises(AttributeError): + getattr(net, f"doubleconv{i+1}") + with pytest.raises(AttributeError): + getattr(net, f"attention{i+1}") + + +@pytest.mark.parametrize("act", [act for act in ActFunction]) +def test_activations(act): + batch_size, in_channels, *img_size = INPUT_2D.shape + net = AttentionUNet( + spatial_dims=2, + in_channels=in_channels, + out_channels=2, + channels=(2, 4), + act=act, + output_act=act, + ) + assert net(INPUT_2D).shape == (batch_size, 2, *img_size) + + +def test_activation_parameters(): + in_channels = INPUT_2D.shape[1] + act = ("ELU", {"alpha": 0.1}) + output_act = ("ELU", {"alpha": 0.2}) + net = AttentionUNet( + spatial_dims=2, + in_channels=in_channels, + out_channels=2, + channels=(2, 4), + act=act, + output_act=output_act, + ) + assert isinstance(net.doubleconv[0].adn.A, torch.nn.ELU) + assert net.doubleconv[0].adn.A.alpha == 0.1 + + assert isinstance(net.down1.doubleconv[0].adn.A, torch.nn.ELU) + assert net.down1.doubleconv[0].adn.A.alpha == 0.1 + + assert isinstance(net.upsample1[1].adn.A, torch.nn.ELU) + assert net.upsample1[1].adn.A.alpha == 0.1 + + assert isinstance(net.doubleconv1[1].adn.A, torch.nn.ELU) + assert net.doubleconv1[1].adn.A.alpha == 0.1 + + assert isinstance(net.output_act, torch.nn.ELU) + assert net.output_act.alpha == 0.2 diff --git a/tests/unittests/monai_networks/nn/test_autoencoder.py b/tests/unittests/monai_networks/nn/test_autoencoder.py new file mode 100644 index 000000000..c59874353 --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_autoencoder.py @@ -0,0 +1,215 @@ +import pytest +import torch +from torch.nn import GELU, Sigmoid, Tanh + +from clinicadl.monai_networks.nn import AutoEncoder +from clinicadl.monai_networks.nn.layers.utils import ActFunction + + +@pytest.mark.parametrize( + "input_tensor,kernel_size,stride,padding,dilation,pooling,pooling_indices,unpooling_mode", + [ + (torch.randn(2, 1, 21), 3, 1, 0, 1, ("max", {"kernel_size": 2}), [0], "linear"), + ( + torch.randn(2, 1, 65, 85), + (3, 5), + (1, 2), + 0, + (1, 2), + ("max", {"kernel_size": 2, "stride": 1}), + [0], + "bilinear", + ), + ( + torch.randn(2, 1, 64, 62, 61), # to test output padding + 4, + 2, + (1, 1, 0), + 1, + ("avg", {"kernel_size": 3, "stride": 2}), + [-1], + "convtranspose", + ), + ( + torch.randn(2, 1, 51, 55, 45), + 3, + 2, + 0, + 1, + ("max", {"kernel_size": 2, "ceil_mode": True}), + [0, 1, 2], + "convtranspose", + ), + ( + torch.randn(2, 1, 51, 55, 45), + 3, + 2, + 0, + 1, + [ + ("max", {"kernel_size": 2, "ceil_mode": True}), + ("adaptivemax", {"output_size": (5, 4, 3)}), + ], + [-1, 1], + "convtranspose", + ), + ], +) +def test_output_shape( + input_tensor, + kernel_size, + stride, + padding, + dilation, + pooling, + pooling_indices, + unpooling_mode, +): + net = AutoEncoder( + in_shape=input_tensor.shape[1:], + latent_size=3, + conv_args={ + "channels": [2, 4, 8], + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + "pooling": pooling, + "pooling_indices": pooling_indices, + }, + unpooling_mode=unpooling_mode, + ) + output = net(input_tensor) + assert output.shape == input_tensor.shape + + +def test_out_channels(): + input_tensor = 
torch.randn(2, 1, 64, 62, 61)
+    net = AutoEncoder(
+        in_shape=input_tensor.shape[1:],
+        latent_size=3,
+        conv_args={"channels": [2, 4, 8]},
+        mlp_args={"hidden_channels": [8, 4]},
+        out_channels=3,
+    )
+    assert net(input_tensor).shape == (2, 3, 64, 62, 61)
+    assert net.decoder.convolutions.layer2.conv.in_channels == 2
+    assert net.decoder.convolutions.layer2.conv.out_channels == 3
+
+
+@pytest.mark.parametrize(
+    "pooling,unpooling_mode",
+    [
+        (("adaptivemax", {"output_size": (17, 16, 19)}), "nearest"),
+        (("adaptivemax", {"output_size": (17, 16, 19)}), "convtranspose"),
+        (("max", {"kernel_size": 2}), "nearest"),
+        (("max", {"kernel_size": 2}), "convtranspose"),
+        (
+            ("max", {"kernel_size": 2, "stride": 1, "dilation": 2, "padding": 1}),
+            "nearest",
+        ),
+        (
+            ("max", {"kernel_size": 2, "stride": 1, "dilation": 2, "padding": 1}),
+            "convtranspose",
+        ),
+        (("avg", {"kernel_size": 3, "ceil_mode": True}), "nearest"),
+        (("avg", {"kernel_size": 3, "ceil_mode": True}), "convtranspose"),
+    ],
+)
+def test_invert_pooling(pooling, unpooling_mode):
+    input_tensor = torch.randn(2, 1, 20, 27, 22)
+    net = AutoEncoder(
+        in_shape=(1, 20, 27, 22),
+        latent_size=1,
+        conv_args={"channels": [], "pooling": pooling, "pooling_indices": [-1]},
+        mlp_args=None,
+        unpooling_mode=unpooling_mode,
+    )
+    output = net(input_tensor)
+    assert output.shape == input_tensor.shape
+
+
+@pytest.mark.parametrize(
+    "kernel_size,stride,padding,dilation",
+    [
+        ((3, 2, 1), (1, 1, 2), (1, 1, 0), 1),
+        ((4, 5, 2), (3, 1, 1), (0, 0, 1), (2, 1, 1)),
+    ],
+)
+def test_invert_conv(kernel_size, stride, padding, dilation):
+    input_tensor = torch.randn(2, 1, 20, 27, 22)
+    net = AutoEncoder(
+        in_shape=(1, 20, 27, 22),
+        latent_size=1,
+        conv_args={
+            "channels": [1],
+            "kernel_size": kernel_size,
+            "stride": stride,
+            "padding": padding,
+            "dilation": dilation,
+        },
+        mlp_args=None,
+    )
+    output = net(input_tensor)
+    assert output.shape == input_tensor.shape
+
+
+@pytest.mark.parametrize("act", [act for act in ActFunction])
+def test_out_activation(act):
+    input_tensor = torch.randn(2, 1, 32, 32)
+    net = AutoEncoder(
+        in_shape=(1, 32, 32),
+        latent_size=3,
+        conv_args={"channels": [2]},
+        output_act=act,
+    )
+    assert net(input_tensor).shape == (2, 1, 32, 32)
+
+
+def test_params():
+    net = AutoEncoder(
+        in_shape=(1, 100, 100),
+        latent_size=3,
+        conv_args={"channels": [2], "act": "celu", "output_act": "sigmoid"},
+        mlp_args={"hidden_channels": [2], "act": "relu", "output_act": "gelu"},
+        output_act="tanh",
+        out_channels=2,
+    )
+    assert net.encoder.convolutions.act == "celu"
+    assert net.decoder.convolutions.act == "celu"
+    assert net.encoder.mlp.act == "relu"
+    assert net.decoder.mlp.act == "relu"
+    assert isinstance(net.encoder.mlp.output.output_act, GELU)
+    assert isinstance(net.decoder.mlp.output.output_act, GELU)
+    assert isinstance(net.encoder.convolutions.output_act, Sigmoid)
+    assert isinstance(net.decoder.convolutions.output_act, Tanh)
+
+
+@pytest.mark.parametrize(
+    "in_shape,upsampling_mode,error",
+    [
+        ((1, 10), "bilinear", True),
+        ((1, 10, 10), "linear", True),
+        ((1, 10, 10), "trilinear", True),
+        ((1, 10, 10, 10), "bicubic", True),
+        ((1, 10), "linear", False),
+        ((1, 10, 10), "bilinear", False),
+        ((1, 10, 10, 10), "trilinear", False),
+    ],
+)
+def test_checks(in_shape, upsampling_mode, error):
+    if error:
+        with pytest.raises(ValueError):
+            AutoEncoder(
+                in_shape=in_shape,
+                latent_size=3,
+                conv_args={"channels": []},
+                unpooling_mode=upsampling_mode,
+            )
+    else:
+        AutoEncoder(
+            in_shape=in_shape,
+            latent_size=3,
+            conv_args={"channels": []},
+            unpooling_mode=upsampling_mode,
+        )
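The autoencoder tests above all pin down the same round-trip property: the decoder is built to invert whatever convolution and pooling settings the encoder was given, so the reconstruction keeps the input shape. A minimal usage sketch of the interface these tests exercise (channel sizes and shapes are illustrative only, not taken from the suite):

    import torch
    from clinicadl.monai_networks.nn import AutoEncoder

    net = AutoEncoder(
        in_shape=(1, 32, 32),                # one image: (channels, *spatial), no batch axis
        latent_size=8,
        conv_args={"channels": [4, 8]},      # encoder convolutions; the decoder mirrors them
        mlp_args={"hidden_channels": [16]},  # MLP between the conv encoder and the latent space
    )
    x = torch.randn(2, 1, 32, 32)            # batch of two images
    assert net(x).shape == x.shape           # reconstruction has the input shape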
diff --git a/tests/unittests/monai_networks/nn/test_cnn.py b/tests/unittests/monai_networks/nn/test_cnn.py
new file mode 100644
index 000000000..095c8da5d
--- /dev/null
+++ b/tests/unittests/monai_networks/nn/test_cnn.py
@@ -0,0 +1,62 @@
+import pytest
+import torch
+from torch.nn import Flatten, Linear, Softmax
+
+from clinicadl.monai_networks.nn import CNN, MLP, ConvEncoder
+
+INPUT_1D = torch.randn(3, 1, 16)
+INPUT_2D = torch.randn(3, 1, 15, 16)
+INPUT_3D = torch.randn(3, 3, 20, 21, 22)
+
+
+@pytest.mark.parametrize("input_tensor", [INPUT_1D, INPUT_2D, INPUT_3D])
+@pytest.mark.parametrize("channels", [(), (2, 4)])
+@pytest.mark.parametrize(
+    "mlp_args", [None, {"hidden_channels": []}, {"hidden_channels": (2, 4)}]
+)
+def test_cnn(input_tensor, channels, mlp_args):
+    in_shape = input_tensor.shape[1:]
+    net = CNN(
+        in_shape=in_shape,
+        num_outputs=2,
+        conv_args={"channels": channels},
+        mlp_args=mlp_args,
+    )
+    output = net(input_tensor)
+    assert output.shape == (3, 2)
+    assert isinstance(net.convolutions, ConvEncoder)
+    assert isinstance(net.mlp, MLP)
+
+    if mlp_args is None or mlp_args["hidden_channels"] == []:
+        children = net.mlp.children()
+        assert isinstance(next(children), Flatten)
+        assert isinstance(next(children).linear, Linear)
+        with pytest.raises(StopIteration):
+            next(children)
+
+    if len(channels) == 0:
+        with pytest.raises(StopIteration):
+            next(net.convolutions.parameters())
+
+
+@pytest.mark.parametrize(
+    "conv_args,mlp_args",
+    [
+        (None, {"hidden_channels": [2]}),
+        ({"channels": [2]}, {}),
+    ],
+)
+def test_checks(conv_args, mlp_args):
+    with pytest.raises(ValueError):
+        CNN(in_shape=(1, 10, 10), num_outputs=2, conv_args=conv_args, mlp_args=mlp_args)
+
+
+def test_params():
+    conv_args = {"channels": [2], "act": "celu"}
+    mlp_args = {"hidden_channels": [2], "act": "relu", "output_act": "softmax"}
+    net = CNN(
+        in_shape=(1, 10, 10), num_outputs=2, conv_args=conv_args, mlp_args=mlp_args
+    )
+    assert net.convolutions.act == "celu"
+    assert net.mlp.act == "relu"
+    assert isinstance(net.mlp.output.output_act, Softmax)
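test_cnn.py locks in the contract of the new CNN wrapper: a ConvEncoder backbone followed by an MLP head, always returning one vector of num_outputs values per batch element; with mlp_args=None the head reduces to a flatten plus a single linear layer, which is what the children() checks above verify. A short sketch under the same assumptions (argument values are illustrative):

    import torch
    from clinicadl.monai_networks.nn import CNN

    net = CNN(
        in_shape=(1, 15, 16),               # one input image: (channels, height, width)
        num_outputs=2,                      # size of the prediction vector
        conv_args={"channels": [2, 4]},     # backbone convolutions
        mlp_args={"hidden_channels": [8]},  # head; None gives flatten + one linear layer
    )
    out = net(torch.randn(3, 1, 15, 16))    # batch of three images
    assert out.shape == (3, 2)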
diff --git a/tests/unittests/monai_networks/nn/test_conv_decoder.py b/tests/unittests/monai_networks/nn/test_conv_decoder.py
new file mode 100644
index 000000000..73576918e
--- /dev/null
+++ b/tests/unittests/monai_networks/nn/test_conv_decoder.py
@@ -0,0 +1,407 @@
+import pytest
+import torch
+from torch.nn import ELU, ConvTranspose2d, Dropout, InstanceNorm2d, Upsample
+
+from clinicadl.monai_networks.nn import ConvDecoder
+from clinicadl.monai_networks.nn.layers.utils import ActFunction
+
+
+@pytest.fixture
+def input_tensor():
+    return torch.randn(2, 1, 8, 8)
+
+
+@pytest.mark.parametrize("act", [act for act in ActFunction])
+def test_activations(input_tensor, act):
+    _, in_channels, *input_size = input_tensor.shape
+    spatial_dims = len(input_size)
+    net = ConvDecoder(
+        spatial_dims=spatial_dims,
+        in_channels=in_channels,
+        channels=[2, 4, 1],
+        act=act,
+        output_act=act,
+    )
+    output_shape = net(input_tensor).shape
+    assert len(output_shape) == 4 and output_shape[1] == 1
+
+
+@pytest.mark.parametrize(
+    "kernel_size,stride,padding,output_padding,dilation,unpooling,unpooling_indices,norm,dropout,bias,adn_ordering",
+    [
+        (
+            3,
+            2,
+            0,
+            1,
+            1,
+            ("upsample", {"scale_factor": 2}),
+            [2],
+            "batch",
+            None,
+            True,
+            "ADN",
+        ),
+        (
+            (4, 4),
+            (2, 1),
+            2,
+            (1, 0),
+            2,
+            ("upsample", {"scale_factor": 2}),
+            [0, 1],
+            "instance",
+            0.5,
+            False,
+            "DAN",
+        ),
+        (
+            5,
+            1,
+            (2, 1),
+            0,
+            1,
+            [("upsample", {"size": (16, 16)}), ("convtranspose", {"kernel_size": 2})],
+            [0, 1],
+            "syncbatch",
+            0.5,
+            True,
+            "NA",
+        ),
+        (5, 1, 0, 1, (2, 3), None, [0, 1], "instance", 0.0, True, "DN"),
+        (
+            5,
+            1,
+            2,
+            0,
+            1,
+            ("convtranspose", {"kernel_size": 2}),
+            None,
+            ("group", {"num_groups": 2}),
+            None,
+            True,
+            "N",
+        ),
+        (
+            5,
+            3,
+            2,
+            (2, 1),
+            1,
+            ("convtranspose", {"kernel_size": 2}),
+            [0, 1],
+            None,
+            None,
+            True,
+            "",
+        ),
+    ],
+)
+def test_params(
+    input_tensor,
+    kernel_size,
+    stride,
+    padding,
+    output_padding,
+    dilation,
+    unpooling,
+    unpooling_indices,
+    norm,
+    dropout,
+    bias,
+    adn_ordering,
+):
+    batch_size, in_channels, *input_size = input_tensor.shape
+    spatial_dims = len(input_size)
+
+    # test size computation
+    net = ConvDecoder(
+        spatial_dims=spatial_dims,
+        in_channels=in_channels,
+        channels=[2, 4, 1],
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        output_padding=output_padding,
+        dilation=dilation,
+        unpooling=unpooling,
+        unpooling_indices=unpooling_indices,
+        dropout=dropout,
+        act=None,
+        norm=norm,
+        bias=bias,
+        adn_ordering=adn_ordering,
+        _input_size=input_size,
+    )
+    output = net(input_tensor)
+    assert output.shape == (batch_size, 1, *net.final_size)
+
+    # other checks
+    net = ConvDecoder(
+        spatial_dims=spatial_dims,
+        in_channels=in_channels,
+        channels=[2, 4, 1],
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        output_padding=output_padding,
+        dilation=dilation,
+        unpooling=unpooling,
+        unpooling_indices=unpooling_indices,
+        dropout=dropout,
+        act=None,
+        norm=norm,
+        bias=bias,
+        adn_ordering=adn_ordering,
+    )
+    assert isinstance(net.layer2[0], ConvTranspose2d)
+    with pytest.raises(IndexError):
+        net.layer2[1]  # no adn at the end
+
+    named_layers = list(net.named_children())
+    if unpooling and unpooling_indices and unpooling_indices != []:
+        for i, idx in enumerate(unpooling_indices):
+            name, layer = named_layers[idx + 1 + i]
+            if idx == -1:
+                assert name == "init_unpool"
+            else:
+                assert name == f"unpool{idx}"
+            if net.unpooling[i][0] == "upsample":
+                assert isinstance(layer, Upsample)
+            else:
+                assert isinstance(layer, ConvTranspose2d)
+    else:
+        for name, layer in named_layers:
+            assert not isinstance(layer, Upsample)
+            assert "unpool" not in name
+
+    assert net.layer0[0].kernel_size == (
+        kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
+    )
+    assert net.layer0[0].stride == (
+        stride if isinstance(stride, tuple) else (stride, stride)
+    )
+    assert net.layer0[0].padding == (
+        padding if isinstance(padding, tuple) else (padding, padding)
+    )
+    assert net.layer0[0].output_padding == (
+        output_padding
+        if isinstance(output_padding, tuple)
+        else (output_padding, output_padding)
+    )
+    assert net.layer0[0].dilation == (
+        dilation if isinstance(dilation, tuple) else (dilation, dilation)
+    )
+
+    if bias:
+        assert len(net.layer0[0].bias) > 0
+        assert len(net.layer1[0].bias) > 0
+        assert len(net.layer2[0].bias) > 0
+    else:
+        assert net.layer0[0].bias is None
+        assert net.layer1[0].bias is None
+        assert net.layer2[0].bias is None
+    if isinstance(dropout, float) and "D" in adn_ordering:
+        assert net.layer0[1].D.p == dropout
+        assert net.layer1[1].D.p == dropout
+
+
+def test_activation_parameters(input_tensor):
+    _, in_channels, *input_size = input_tensor.shape
+    spatial_dims = len(input_size)
+
+    act = ("ELU", {"alpha": 0.1})
+    output_act = ("ELU", {"alpha": 0.2})
+    net = ConvDecoder(
+        spatial_dims=spatial_dims,
in_channels=in_channels, + channels=[2, 4, 1], + act=act, + output_act=output_act, + ) + assert isinstance(net.layer0[1].A, ELU) + assert net.layer0[1].A.alpha == 0.1 + assert isinstance(net.layer1[1].A, ELU) + assert net.layer1[1].A.alpha == 0.1 + assert isinstance(net.output_act, ELU) + assert net.output_act.alpha == 0.2 + + net = ConvDecoder( + spatial_dims=spatial_dims, in_channels=in_channels, channels=[2, 4, 1], act=None + ) + with pytest.raises(AttributeError): + net.layer0[1].A + with pytest.raises(AttributeError): + net.layer1[1].A + assert net.output_act is None + + +def test_norm_parameters(input_tensor): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + norm = ("instance", {"momentum": 1.0}) + net = ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + norm=norm, + ) + assert isinstance(net.layer0[1].N, InstanceNorm2d) + assert net.layer0[1].N.momentum == 1.0 + assert isinstance(net.layer1[1].N, InstanceNorm2d) + assert net.layer1[1].N.momentum == 1.0 + + net = ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + norm=None, + ) + with pytest.raises(AttributeError): + net.layer0[1].N + with pytest.raises(AttributeError): + net.layer1[1].N + + +def test_unpool_parameters(input_tensor): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + unpooling = ("convtranspose", {"kernel_size": 3, "stride": 2}) + net = ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + unpooling=unpooling, + unpooling_indices=[1], + ) + assert isinstance(net.unpool1, ConvTranspose2d) + assert net.unpool1.stride == (2, 2) + assert net.unpool1.kernel_size == (3, 3) + + +@pytest.mark.parametrize("adn_ordering", ["DAN", "NA", "A"]) +def test_adn_ordering(input_tensor, adn_ordering): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + net = ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + dropout=0.1, + adn_ordering=adn_ordering, + act="elu", + norm="instance", + ) + objects = {"D": Dropout, "N": InstanceNorm2d, "A": ELU} + for i, letter in enumerate(adn_ordering): + assert isinstance(net.layer0[1][i], objects[letter]) + assert isinstance(net.layer1[1][i], objects[letter]) + for letter in set(["A", "D", "N"]) - set(adn_ordering): + with pytest.raises(AttributeError): + getattr(net.layer0[1], letter) + with pytest.raises(AttributeError): + getattr(net.layer1[1], letter) + + +@pytest.mark.parametrize( + "input_tensor", [torch.randn(2, 1, 16), torch.randn(2, 3, 20, 21, 22)] +) +def test_other_dimensions(input_tensor): + batch_size, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + net = ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + _input_size=input_size, + ) + output = net(input_tensor) + assert output.shape == (batch_size, 1, *net.final_size) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"kernel_size": (3, 3, 3)}, + {"stride": [1, 1]}, + {"padding": [1, 1]}, + {"dilation": (1,)}, + {"unpooling_indices": [0, 1, 2, 3]}, + {"unpooling": "upsample", "unpooling_indices": [0]}, + {"norm": "group"}, + {"norm": "layer"}, + ], +) +def test_checks(input_tensor, kwargs): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + with pytest.raises(ValueError): + ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 
4, 1], + **kwargs, + ) + + +@pytest.mark.parametrize( + "unpooling,error", + [ + (None, False), + ("abc", True), + ("upsample", True), + (("upsample",), True), + (("upsample", 2), True), + (("convtranspose", {"kernel_size": 2}), False), + (("upsample", {"scale_factor": 2}), False), + ( + [("upsample", {"scale_factor": 2}), ("convtranspose", {"kernel_size": 2})], + False, + ), + ([("upsample", {"scale_factor": 2}), None], True), + ([("upsample", {"scale_factor": 2}), "convtranspose"], True), + ([("upsample", {"scale_factor": 2}), ("convtranspose", 2)], True), + ( + [ + ("upsample", {"scale_factor": 2}), + ("convtranspose", {"kernel_size": 2}), + ("convtranspose", {"kernel_size": 2}), + ], + True, + ), + ], +) +def test_check_unpool_layer(input_tensor, unpooling, error): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + if error: + with pytest.raises(ValueError): + ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + unpooling=unpooling, + unpooling_indices=[0, 1], + ) + else: + ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + unpooling=unpooling, + unpooling_indices=[0, 1], + ) diff --git a/tests/unittests/monai_networks/nn/test_conv_encoder.py b/tests/unittests/monai_networks/nn/test_conv_encoder.py new file mode 100644 index 000000000..3a21859b8 --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_conv_encoder.py @@ -0,0 +1,400 @@ +import pytest +import torch +from torch.nn import ( + ELU, + AdaptiveAvgPool2d, + AdaptiveMaxPool2d, + AvgPool2d, + Conv2d, + Dropout, + InstanceNorm2d, + MaxPool2d, +) + +from clinicadl.monai_networks.nn import ConvEncoder +from clinicadl.monai_networks.nn.layers.utils import ActFunction + + +@pytest.fixture +def input_tensor(): + return torch.randn(2, 1, 55, 54) + + +@pytest.mark.parametrize("act", [act for act in ActFunction]) +def test_activations(input_tensor, act): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + net = ConvEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + act=act, + output_act=act, + ) + output_shape = net(input_tensor).shape + assert len(output_shape) == 4 and output_shape[1] == 1 + + +@pytest.mark.parametrize( + "kernel_size,stride,padding,dilation,pooling,pooling_indices,norm,dropout,bias,adn_ordering", + [ + ( + 3, + 1, + 0, + 1, + ("adaptivemax", {"output_size": 1}), + [2], + "batch", + None, + True, + "ADN", + ), + ( + (4, 4), + (2, 1), + 2, + 2, + ("max", {"kernel_size": 2}), + [0, 1], + "instance", + 0.5, + False, + "DAN", + ), + ( + 5, + 1, + (2, 1), + 1, + [ + ("avg", {"kernel_size": 2}), + ("max", {"kernel_size": 2}), + ("adaptiveavg", {"output_size": (2, 3)}), + ], + [-1, 1, 2], + "syncbatch", + 0.5, + True, + "NA", + ), + (5, 1, 0, (1, 2), None, [0, 1], "instance", 0.0, True, "DN"), + ( + 5, + 1, + 2, + 1, + ("avg", {"kernel_size": 2}), + None, + ("group", {"num_groups": 2}), + None, + True, + "N", + ), + (5, 1, 2, 1, ("avg", {"kernel_size": 2}), None, None, None, True, ""), + ], +) +def test_params( + input_tensor, + kernel_size, + stride, + padding, + dilation, + pooling, + pooling_indices, + norm, + dropout, + bias, + adn_ordering, +): + batch_size, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + # test output size + net = ConvEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + kernel_size=kernel_size, + stride=stride, + padding=padding, + 
dilation=dilation,
+        pooling=pooling,
+        pooling_indices=pooling_indices,
+        dropout=dropout,
+        act=None,
+        norm=norm,
+        bias=bias,
+        adn_ordering=adn_ordering,
+        _input_size=input_size,
+    )
+    output = net(input_tensor)
+    assert output.shape == (batch_size, 1, *net.final_size)
+
+    # other checks
+    net = ConvEncoder(
+        spatial_dims=spatial_dims,
+        in_channels=in_channels,
+        channels=[2, 4, 1],
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        pooling=pooling,
+        pooling_indices=pooling_indices,
+        dropout=dropout,
+        act=None,
+        norm=norm,
+        bias=bias,
+        adn_ordering=adn_ordering,
+    )
+    assert isinstance(net.layer2.conv, Conv2d)
+    with pytest.raises(IndexError):
+        net.layer2[1]  # no adn at the end
+
+    named_layers = list(net.named_children())
+    if pooling and pooling_indices and pooling_indices != []:
+        for i, idx in enumerate(pooling_indices):
+            name, layer = named_layers[idx + 1 + i]
+            if idx == -1:
+                assert name == "init_pool"
+            else:
+                assert name == f"pool{idx}"
+            pooling_mode = net.pooling[i][0]
+            if pooling_mode == "max":
+                assert isinstance(layer, MaxPool2d)
+            elif pooling_mode == "avg":
+                assert isinstance(layer, AvgPool2d)
+            elif pooling_mode == "adaptivemax":
+                assert isinstance(layer, AdaptiveMaxPool2d)
+            else:
+                assert isinstance(layer, AdaptiveAvgPool2d)
+    else:
+        for name, layer in named_layers:
+            assert not isinstance(layer, (AvgPool2d, MaxPool2d))
+            assert "pool" not in name
+
+    assert net.layer0.conv.kernel_size == (
+        kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
+    )
+    assert net.layer0.conv.stride == (
+        stride if isinstance(stride, tuple) else (stride, stride)
+    )
+    assert net.layer0.conv.padding == (
+        padding if isinstance(padding, tuple) else (padding, padding)
+    )
+    assert net.layer0.conv.dilation == (
+        dilation if isinstance(dilation, tuple) else (dilation, dilation)
+    )
+
+    if bias:
+        assert len(net.layer0.conv.bias) > 0
+        assert len(net.layer1.conv.bias) > 0
+        assert len(net.layer2.conv.bias) > 0
+    else:
+        assert net.layer0.conv.bias is None
+        assert net.layer1.conv.bias is None
+        assert net.layer2.conv.bias is None
+    if isinstance(dropout, float) and "D" in adn_ordering:
+        assert net.layer0.adn.D.p == dropout
+        assert net.layer1.adn.D.p == dropout
+
+
+def test_activation_parameters(input_tensor):
+    _, in_channels, *input_size = input_tensor.shape
+    spatial_dims = len(input_size)
+
+    act = ("ELU", {"alpha": 0.1})
+    output_act = ("ELU", {"alpha": 0.2})
+    net = ConvEncoder(
+        spatial_dims=spatial_dims,
+        in_channels=in_channels,
+        channels=[2, 4, 1],
+        act=act,
+        output_act=output_act,
+    )
+    assert isinstance(net.layer0.adn.A, ELU)
+    assert net.layer0.adn.A.alpha == 0.1
+    assert isinstance(net.layer1.adn.A, ELU)
+    assert net.layer1.adn.A.alpha == 0.1
+    assert isinstance(net.output_act, ELU)
+    assert net.output_act.alpha == 0.2
+
+    net = ConvEncoder(
+        spatial_dims=spatial_dims, in_channels=in_channels, channels=[2, 4, 1], act=None
+    )
+    with pytest.raises(AttributeError):
+        net.layer0.adn.A
+    with pytest.raises(AttributeError):
+        net.layer1.adn.A
+    assert net.output_act is None
+
+
+def test_norm_parameters(input_tensor):
+    _, in_channels, *input_size = input_tensor.shape
+    spatial_dims = len(input_size)
+
+    norm = ("instance", {"momentum": 1.0})
+    net = ConvEncoder(
+        spatial_dims=spatial_dims,
+        in_channels=in_channels,
+        channels=[2, 4, 1],
+        norm=norm,
+    )
+    assert isinstance(net.layer0.adn.N, InstanceNorm2d)
+    assert net.layer0.adn.N.momentum == 1.0
+    assert
isinstance(net.layer1.adn.N, InstanceNorm2d) + assert net.layer1.adn.N.momentum == 1.0 + + net = ConvEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + norm=None, + ) + with pytest.raises(AttributeError): + net.layer0.adn.N + with pytest.raises(AttributeError): + net.layer1.adn.N + + +def test_pool_parameters(input_tensor): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + pooling = ("avg", {"kernel_size": 3, "stride": 2}) + net = ConvEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + pooling=pooling, + pooling_indices=[1], + ) + assert isinstance(net.pool1, AvgPool2d) + assert net.pool1.stride == 2 + assert net.pool1.kernel_size == 3 + + +@pytest.mark.parametrize("adn_ordering", ["DAN", "NA", "A"]) +def test_adn_ordering(input_tensor, adn_ordering): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + net = ConvEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + dropout=0.1, + adn_ordering=adn_ordering, + act="elu", + norm="instance", + ) + objects = {"D": Dropout, "N": InstanceNorm2d, "A": ELU} + for i, letter in enumerate(adn_ordering): + assert isinstance(net.layer0.adn[i], objects[letter]) + assert isinstance(net.layer1.adn[i], objects[letter]) + for letter in set(["A", "D", "N"]) - set(adn_ordering): + with pytest.raises(AttributeError): + getattr(net.layer0.adn, letter) + with pytest.raises(AttributeError): + getattr(net.layer1.adn, letter) + + +@pytest.mark.parametrize( + "input_tensor", [torch.randn(2, 1, 16), torch.randn(2, 3, 20, 21, 22)] +) +def test_other_dimensions(input_tensor): + batch_size, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + net = ConvEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + _input_size=input_size, + ) + output = net(input_tensor) + assert output.shape == (batch_size, 1, *net.final_size) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"kernel_size": (3, 3, 3)}, + {"stride": [1, 1]}, + {"padding": [1, 1]}, + {"dilation": (1,)}, + {"pooling_indices": [0, 1, 2, 3]}, + {"pooling": "avg", "pooling_indices": [0]}, + {"norm": "group"}, + {"_input_size": (1, 10, 10), "stride": 2, "channels": [2, 4, 6, 8]}, + ], +) +def test_checks(kwargs): + if "channels" not in kwargs: + kwargs["channels"] = [2, 4, 1] + if "in_channels" not in kwargs: + kwargs["in_channels"] = 1 + if "spatial_dims" not in kwargs: + kwargs["spatial_dims"] = 2 + with pytest.raises(ValueError): + ConvEncoder(**kwargs) + + +@pytest.mark.parametrize( + "pooling,error", + [ + (None, False), + ("abc", True), + ("max", True), + (("max",), True), + (("max", 3), True), + (("avg", {"stride": 1}), True), + (("avg", {"kernel_size": 1}), False), + (("avg", {"kernel_size": 1, "stride": 1}), False), + (("abc", {"kernel_size": 1, "stride": 1}), True), + ([("avg", {"kernel_size": 1}), ("max", {"kernel_size": 1})], False), + ([("avg", {"kernel_size": 1}), None], True), + ([("avg", {"kernel_size": 1}), "max"], True), + ([("avg", {"kernel_size": 1}), ("max", 3)], True), + ([("avg", {"kernel_size": 1}), ("max", {"stride": 1})], True), + ( + [ + ("avg", {"kernel_size": 1}), + ("max", {"stride": 1}), + ("max", {"stride": 1}), + ], + True, + ), + ], +) +def test_check_pool_layers(input_tensor, pooling, error): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + if error: + with pytest.raises(ValueError): + ConvEncoder( + 
spatial_dims=spatial_dims,
+                in_channels=in_channels,
+                channels=[2, 4, 1],
+                pooling=pooling,
+                pooling_indices=[0, 1],
+            )
+    else:
+        ConvEncoder(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            channels=[2, 4, 1],
+            pooling=pooling,
+            pooling_indices=[0, 1],
+        )
diff --git a/tests/unittests/monai_networks/nn/test_densenet.py b/tests/unittests/monai_networks/nn/test_densenet.py
new file mode 100644
index 000000000..303a22a6e
--- /dev/null
+++ b/tests/unittests/monai_networks/nn/test_densenet.py
@@ -0,0 +1,138 @@
+import pytest
+import torch
+
+from clinicadl.monai_networks.nn import DenseNet, get_densenet
+from clinicadl.monai_networks.nn.densenet import SOTADenseNet
+from clinicadl.monai_networks.nn.layers.utils import ActFunction
+
+INPUT_1D = torch.randn(3, 1, 16)
+INPUT_2D = torch.randn(3, 2, 15, 16)
+INPUT_3D = torch.randn(3, 3, 20, 21, 22)
+
+
+@pytest.mark.parametrize(
+    "input_tensor,num_outputs,n_dense_layers,init_features,growth_rate,bottleneck_factor,act,output_act,dropout",
+    [
+        (INPUT_1D, 2, (3, 4), 16, 8, 2, "relu", None, 0.1),
+        (INPUT_2D, None, (3, 4, 2), 9, 5, 3, "elu", "sigmoid", 0.0),
+        (INPUT_3D, 1, (2,), 4, 4, 2, "tanh", "sigmoid", 0.1),
+    ],
+)
+def test_densenet(
+    input_tensor,
+    num_outputs,
+    n_dense_layers,
+    init_features,
+    growth_rate,
+    bottleneck_factor,
+    act,
+    output_act,
+    dropout,
+):
+    batch_size = input_tensor.shape[0]
+    net = DenseNet(
+        spatial_dims=len(input_tensor.shape[2:]),
+        in_channels=input_tensor.shape[1],
+        num_outputs=num_outputs,
+        n_dense_layers=n_dense_layers,
+        init_features=init_features,
+        growth_rate=growth_rate,
+        bottleneck_factor=bottleneck_factor,
+        act=act,
+        output_act=output_act,
+        dropout=dropout,
+    )
+    output = net(input_tensor)
+
+    if num_outputs:
+        assert output.shape == (batch_size, num_outputs)
+    else:
+        assert len(output.shape) == len(input_tensor.shape)
+
+    if output_act and num_outputs:
+        assert net.fc.output_act is not None
+    elif output_act and num_outputs is None:
+        with pytest.raises(AttributeError):
+            net.fc.output_act
+
+    features = net.features
+    for i, n in enumerate(n_dense_layers, start=1):
+        dense_block = getattr(features, f"denseblock{i}")
+        for k in range(1, n + 1):
+            dense_layer = getattr(dense_block, f"denselayer{k}").layers
+            assert dense_layer.conv1.out_channels == growth_rate * bottleneck_factor
+            assert dense_layer.conv2.out_channels == growth_rate
+            if dropout:
+                assert dense_layer.dropout.p == dropout
+        with pytest.raises(AttributeError):
+            getattr(dense_block, f"denselayer{n+1}")  # block i has exactly n layers
+    with pytest.raises(AttributeError):
+        getattr(features, f"denseblock{i+1}")  # no block beyond the last one
+
+    assert features.conv0.out_channels == init_features
+
+
+@pytest.mark.parametrize("act", [act for act in ActFunction])
+def test_activations(act):
+    batch_size = INPUT_2D.shape[0]
+    net = DenseNet(
+        spatial_dims=len(INPUT_2D.shape[2:]),
+        in_channels=INPUT_2D.shape[1],
+        n_dense_layers=(2, 2),
+        num_outputs=2,
+        act=act,
+    )
+    assert net(INPUT_2D).shape == (batch_size, 2)
+
+
+def test_activation_parameters():
+    act = ("ELU", {"alpha": 0.1})
+    output_act = ("ELU", {"alpha": 0.2})
+    net = DenseNet(
+        spatial_dims=len(INPUT_2D.shape[2:]),
+        in_channels=INPUT_2D.shape[1],
+        num_outputs=2,
+        n_dense_layers=(2, 2),
+        act=act,
+        output_act=output_act,
+    )
+    assert isinstance(net.features.denseblock1.denselayer1.layers.act1, torch.nn.ELU)
+    assert net.features.denseblock1.denselayer1.layers.act1.alpha == 0.1
+    assert isinstance(net.fc.output_act, torch.nn.ELU)
+    assert net.fc.output_act.alpha == 0.2
+
+
+@pytest.mark.parametrize( 
+ "name,num_outputs,output_act", + [ + (SOTADenseNet.DENSENET_121, 1, "sigmoid"), + (SOTADenseNet.DENSENET_161, 2, None), + (SOTADenseNet.DENSENET_169, None, "sigmoid"), + (SOTADenseNet.DENSENET_201, None, None), + ], +) +def test_get_densenet(name, num_outputs, output_act): + densenet = get_densenet( + name, num_outputs=num_outputs, output_act=output_act, pretrained=True + ) + if num_outputs: + assert densenet.fc.out.out_features == num_outputs + else: + assert densenet.fc is None + + if output_act and num_outputs: + assert densenet.fc.output_act is not None + elif output_act and num_outputs is None: + with pytest.raises(AttributeError): + densenet.fc.output_act + + +def test_get_densenet_output(): + from torchvision.models import densenet121 + + densenet = get_densenet( + SOTADenseNet.DENSENET_121, num_outputs=None, pretrained=True + ).features + gt = densenet121(weights="DEFAULT").features + x = torch.randn(1, 3, 128, 128) + assert (densenet(x) == gt(x)).all() diff --git a/tests/unittests/monai_networks/nn/test_generator.py b/tests/unittests/monai_networks/nn/test_generator.py new file mode 100644 index 000000000..1fa4fffa7 --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_generator.py @@ -0,0 +1,67 @@ +import pytest +import torch +from torch.nn import Flatten, Linear + +from clinicadl.monai_networks.nn import MLP, ConvDecoder, Generator + + +@pytest.fixture +def input_tensor(): + return torch.randn(2, 8) + + +@pytest.mark.parametrize("channels", [(), (2, 4)]) +@pytest.mark.parametrize( + "mlp_args", [None, {"hidden_channels": []}, {"hidden_channels": (2, 4)}] +) +@pytest.mark.parametrize("start_shape", [(1, 5), (1, 5, 5), (1, 5, 5)]) +def test_generator(input_tensor, start_shape, channels, mlp_args): + latent_size = input_tensor.shape[1] + net = Generator( + latent_size=latent_size, + start_shape=start_shape, + conv_args={"channels": channels}, + mlp_args=mlp_args, + ) + output = net(input_tensor) + assert output.shape[1:] == net.output_shape + assert isinstance(net.convolutions, ConvDecoder) + assert isinstance(net.mlp, MLP) + + if mlp_args is None or mlp_args["hidden_channels"] == []: + children = net.mlp.children() + assert isinstance(next(children), Flatten) + assert isinstance(next(children).linear, Linear) + with pytest.raises(StopIteration): + next(children) + + if channels == []: + with pytest.raises(StopIteration): + next(net.convolutions.parameters()) + + +@pytest.mark.parametrize( + "conv_args,mlp_args", + [ + (None, {"hidden_channels": [2]}), + ({"channels": [2]}, {}), + ], +) +def test_checks(conv_args, mlp_args): + with pytest.raises(ValueError): + Generator( + latent_size=2, + start_shape=(1, 10, 10), + conv_args=conv_args, + mlp_args=mlp_args, + ) + + +def test_params(): + conv_args = {"channels": [2], "act": "celu"} + mlp_args = {"hidden_channels": [2], "act": "relu"} + net = Generator( + latent_size=2, start_shape=(1, 10, 10), conv_args=conv_args, mlp_args=mlp_args + ) + assert net.convolutions.act == "celu" + assert net.mlp.act == "relu" diff --git a/tests/unittests/monai_networks/nn/test_mlp.py b/tests/unittests/monai_networks/nn/test_mlp.py new file mode 100644 index 000000000..5eb3105a8 --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_mlp.py @@ -0,0 +1,125 @@ +import pytest +import torch +from torch.nn import ELU, Dropout, InstanceNorm1d, Linear + +from clinicadl.monai_networks.nn import MLP +from clinicadl.monai_networks.nn.layers.utils import ActFunction + + +@pytest.fixture +def input_tensor(): + return torch.randn(8, 10) + + 
+@pytest.mark.parametrize("act", [act for act in ActFunction]) +def test_activations(input_tensor, act): + net = MLP( + in_channels=10, out_channels=2, hidden_channels=[6, 4], act=act, output_act=act + ) + assert net(input_tensor).shape == (8, 2) + + +@pytest.mark.parametrize( + "dropout,norm,bias,adn_ordering", + [ + (None, "batch", True, "ADN"), + (0.5, "layer", False, "DAN"), + (0.5, "syncbatch", True, "NA"), + (0.0, "instance", True, "DN"), + (None, ("group", {"num_groups": 2}), True, "N"), + (0.5, None, True, "ADN"), + (0.5, "batch", True, ""), + ], +) +def test_params(input_tensor, dropout, norm, bias, adn_ordering): + net = MLP( + in_channels=10, + out_channels=2, + hidden_channels=[6, 4], + dropout=dropout, + norm=norm, + act=None, + bias=bias, + adn_ordering=adn_ordering, + ) + assert net(input_tensor).shape == (8, 2) + assert isinstance(net.output.linear, Linear) + + if bias: + assert len(net.hidden0.linear.bias) > 0 + assert len(net.hidden1.linear.bias) > 0 + assert len(net.output.linear.bias) > 0 + else: + assert net.hidden0.linear.bias is None + assert net.hidden1.linear.bias is None + assert net.output.linear.bias is None + if isinstance(dropout, float) and "D" in adn_ordering: + assert net.hidden0.adn.D.p == dropout + assert net.hidden1.adn.D.p == dropout + + +def test_activation_parameters(): + act = ("ELU", {"alpha": 0.1}) + output_act = ("ELU", {"alpha": 0.2}) + net = MLP( + in_channels=10, + out_channels=2, + hidden_channels=[6, 4], + act=act, + output_act=output_act, + ) + assert isinstance(net.hidden0.adn.A, ELU) + assert net.hidden0.adn.A.alpha == 0.1 + assert isinstance(net.hidden1.adn.A, ELU) + assert net.hidden1.adn.A.alpha == 0.1 + assert isinstance(net.output.output_act, ELU) + assert net.output.output_act.alpha == 0.2 + + net = MLP(in_channels=10, out_channels=2, hidden_channels=[6, 4], act=None) + with pytest.raises(AttributeError): + net.hidden0.adn.A + with pytest.raises(AttributeError): + net.hidden1.adn.A + assert net.output.output_act is None + + +def test_norm_parameters(): + norm = ("instance", {"momentum": 1.0}) + net = MLP(in_channels=10, out_channels=2, hidden_channels=[6, 4], norm=norm) + assert isinstance(net.hidden0.adn.N, InstanceNorm1d) + assert net.hidden0.adn.N.momentum == 1.0 + assert isinstance(net.hidden1.adn.N, InstanceNorm1d) + assert net.hidden1.adn.N.momentum == 1.0 + + net = MLP(in_channels=10, out_channels=2, hidden_channels=[6, 4], act=None) + with pytest.raises(AttributeError): + net.layer_0[1].N + with pytest.raises(AttributeError): + net.layer_1[1].N + + +@pytest.mark.parametrize("adn_ordering", ["DAN", "NA", "A"]) +def test_adn_ordering(adn_ordering): + net = MLP( + in_channels=10, + out_channels=2, + hidden_channels=[6, 4], + dropout=0.1, + adn_ordering=adn_ordering, + act="elu", + norm="instance", + ) + objects = {"D": Dropout, "N": InstanceNorm1d, "A": ELU} + for i, letter in enumerate(adn_ordering): + assert isinstance(net.hidden0.adn[i], objects[letter]) + assert isinstance(net.hidden1.adn[i], objects[letter]) + for letter in set(["A", "D", "N"]) - set(adn_ordering): + with pytest.raises(AttributeError): + getattr(net.hidden0.adn, letter) + with pytest.raises(AttributeError): + getattr(net.hidden1.adn, letter) + + +def test_checks(): + with pytest.raises(ValueError): + MLP(in_channels=10, out_channels=2, hidden_channels=[6, 4], norm="group") diff --git a/tests/unittests/monai_networks/nn/test_resnet.py b/tests/unittests/monai_networks/nn/test_resnet.py new file mode 100644 index 000000000..20ed028d0 --- /dev/null +++ 
b/tests/unittests/monai_networks/nn/test_resnet.py
@@ -0,0 +1,173 @@
+import pytest
+import torch
+
+from clinicadl.monai_networks.nn import ResNet, get_resnet
+from clinicadl.monai_networks.nn.layers.resnet import ResNetBlock, ResNetBottleneck
+from clinicadl.monai_networks.nn.layers.utils import ActFunction
+from clinicadl.monai_networks.nn.resnet import SOTAResNet
+
+INPUT_1D = torch.randn(3, 1, 16)
+INPUT_2D = torch.randn(3, 2, 15, 16)
+INPUT_3D = torch.randn(3, 3, 20, 21, 22)
+
+
+@pytest.mark.parametrize(
+    "input_tensor,num_outputs,block_type,n_res_blocks,n_features,init_conv_size,init_conv_stride,bottleneck_reduction,act,output_act",
+    [
+        (INPUT_1D, 2, "basic", (2, 3), (4, 8), 7, 1, 2, "relu", None),
+        (
+            INPUT_2D,
+            None,
+            "bottleneck",
+            (3, 2, 2),
+            (8, 12, 16),
+            5,
+            (2, 1),
+            4,
+            "elu",
+            "sigmoid",
+        ),
+        (INPUT_3D, 1, "bottleneck", (2,), (3,), (4, 3, 4), 2, 1, "tanh", "sigmoid"),
+    ],
+)
+def test_resnet(
+    input_tensor,
+    num_outputs,
+    block_type,
+    n_res_blocks,
+    n_features,
+    init_conv_size,
+    init_conv_stride,
+    bottleneck_reduction,
+    act,
+    output_act,
+):
+    batch_size = input_tensor.shape[0]
+    spatial_dims = len(input_tensor.shape[2:])
+    net = ResNet(
+        spatial_dims=spatial_dims,
+        in_channels=input_tensor.shape[1],
+        num_outputs=num_outputs,
+        block_type=block_type,
+        n_res_blocks=n_res_blocks,
+        n_features=n_features,
+        init_conv_size=init_conv_size,
+        init_conv_stride=init_conv_stride,
+        bottleneck_reduction=bottleneck_reduction,
+        act=act,
+        output_act=output_act,
+    )
+    output = net(input_tensor)
+
+    if num_outputs:
+        assert output.shape == (batch_size, num_outputs)
+    else:
+        assert len(output.shape) == len(input_tensor.shape)
+
+    if output_act and num_outputs:
+        assert net.fc.output_act is not None
+    elif output_act and num_outputs is None:
+        with pytest.raises(AttributeError):
+            net.fc.output_act
+
+    for i, (n_blocks, n_feats) in enumerate(zip(n_res_blocks, n_features), start=1):
+        layer = getattr(net, f"layer{i}")
+        for k in range(n_blocks):
+            res_block = layer[k]
+            if block_type == "basic":
+                assert isinstance(res_block, ResNetBlock)
+            else:
+                assert isinstance(res_block, ResNetBottleneck)
+            if block_type == "basic":
+                assert res_block.conv2.out_channels == n_feats
+            else:
+                assert res_block.conv1.out_channels == n_feats // bottleneck_reduction
+                assert res_block.conv3.out_channels == n_feats
+        with pytest.raises(IndexError):
+            layer[k + 1]  # layer i has exactly n_blocks blocks
+    with pytest.raises(AttributeError):
+        getattr(net, f"layer{i+1}")  # no layer beyond the last one
+
+    assert net.conv0.kernel_size == (
+        init_conv_size
+        if isinstance(init_conv_size, tuple)
+        else (init_conv_size,) * spatial_dims
+    )
+    assert net.conv0.stride == (
+        init_conv_stride
+        if isinstance(init_conv_stride, tuple)
+        else (init_conv_stride,) * spatial_dims
+    )
+
+
+@pytest.mark.parametrize("act", [act for act in ActFunction])
+def test_activations(act):
+    batch_size = INPUT_2D.shape[0]
+    net = ResNet(
+        spatial_dims=len(INPUT_2D.shape[2:]),
+        in_channels=INPUT_2D.shape[1],
+        num_outputs=2,
+        n_features=(8, 16),
+        n_res_blocks=(2, 2),
+        act=act,
+    )
+    assert net(INPUT_2D).shape == (batch_size, 2)
+
+
+def test_activation_parameters():
+    act = ("ELU", {"alpha": 0.1})
+    output_act = ("ELU", {"alpha": 0.2})
+    net = ResNet(
+        spatial_dims=len(INPUT_2D.shape[2:]),
+        in_channels=INPUT_2D.shape[1],
+        num_outputs=2,
+        n_features=(8, 16),
+        n_res_blocks=(2, 2),
+        act=act,
+        output_act=output_act,
+    )
+    assert isinstance(net.layer1[0].act1, torch.nn.ELU)
+    assert net.layer1[0].act1.alpha == 0.1
+    assert isinstance(net.layer2[1].act2, torch.nn.ELU)
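+    # (act1/act2 are the activations inside a residual block; act0 below is
+    # assumed to be the stem activation, fc.output_act the head activation)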
+ assert net.layer2[1].act2.alpha == 0.1 + assert isinstance(net.act0, torch.nn.ELU) + assert net.act0.alpha == 0.1 + assert isinstance(net.fc.output_act, torch.nn.ELU) + assert net.fc.output_act.alpha == 0.2 + + +@pytest.mark.parametrize( + "name,num_outputs,output_act", + [ + (SOTAResNet.RESNET_18, 1, "sigmoid"), + (SOTAResNet.RESNET_34, 2, None), + (SOTAResNet.RESNET_50, None, "sigmoid"), + (SOTAResNet.RESNET_101, None, None), + (SOTAResNet.RESNET_152, None, None), + ], +) +def test_get_resnet(name, num_outputs, output_act): + resnet = get_resnet( + name, num_outputs=num_outputs, output_act=output_act, pretrained=True + ) + if num_outputs: + assert resnet.fc.out.out_features == num_outputs + else: + assert resnet.fc is None + + if output_act and num_outputs: + assert resnet.fc.output_act is not None + elif output_act and num_outputs is None: + with pytest.raises(AttributeError): + resnet.fc.output_act + + +def test_get_resnet_output(): + from torchvision.models import resnet18 + + resnet = get_resnet(SOTAResNet.RESNET_18, num_outputs=None, pretrained=True) + gt = resnet18(weights="DEFAULT") + gt.avgpool = torch.nn.Identity() + gt.fc = torch.nn.Identity() + x = torch.randn(1, 3, 128, 128) + assert (torch.flatten(resnet(x), start_dim=1) == gt(x)).all() diff --git a/tests/unittests/monai_networks/nn/test_senet.py b/tests/unittests/monai_networks/nn/test_senet.py new file mode 100644 index 000000000..b46eb663a --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_senet.py @@ -0,0 +1,172 @@ +import pytest +import torch + +from clinicadl.monai_networks.nn import SEResNet, get_seresnet +from clinicadl.monai_networks.nn.layers.senet import SEResNetBlock, SEResNetBottleneck +from clinicadl.monai_networks.nn.layers.utils import ActFunction +from clinicadl.monai_networks.nn.senet import SOTAResNet + +INPUT_1D = torch.randn(3, 1, 16) +INPUT_2D = torch.randn(3, 2, 15, 16) +INPUT_3D = torch.randn(3, 3, 20, 21, 22) + + +@pytest.mark.parametrize( + "input_tensor,num_outputs,block_type,n_res_blocks,n_features,init_conv_size,init_conv_stride,bottleneck_reduction,act,output_act,se_reduction", + [ + (INPUT_1D, 2, "basic", (2, 3), (4, 8), 7, 1, 2, "relu", None, 4), + ( + INPUT_2D, + None, + "bottleneck", + (3, 2, 2), + (8, 12, 16), + 5, + (2, 1), + 4, + "elu", + "sigmoid", + 2, + ), + (INPUT_3D, 1, "bottleneck", (2,), (3,), (4, 3, 4), 2, 1, "tanh", "sigmoid", 2), + ], +) +def test_seresnet( + input_tensor, + num_outputs, + block_type, + n_res_blocks, + n_features, + init_conv_size, + init_conv_stride, + bottleneck_reduction, + act, + output_act, + se_reduction, +): + batch_size = input_tensor.shape[0] + spatial_dims = len(input_tensor.shape[2:]) + net = SEResNet( + spatial_dims=spatial_dims, + in_channels=input_tensor.shape[1], + num_outputs=num_outputs, + block_type=block_type, + n_res_blocks=n_res_blocks, + n_features=n_features, + init_conv_size=init_conv_size, + init_conv_stride=init_conv_stride, + bottleneck_reduction=bottleneck_reduction, + act=act, + output_act=output_act, + se_reduction=se_reduction, + ) + output = net(input_tensor) + + if num_outputs: + assert output.shape == (batch_size, num_outputs) + else: + assert len(output.shape) == len(input_tensor.shape) + + if output_act and num_outputs: + assert net.fc.output_act is not None + elif output_act and num_outputs is None: + with pytest.raises(AttributeError): + net.fc.output_act + + for i, (n_blocks, n_feats) in enumerate(zip(n_res_blocks, n_features), start=1): + layer = getattr(net, f"layer{i}") + for k in range(n_blocks): + res_block = 
layer[k]
+            if block_type == "basic":
+                assert isinstance(res_block, SEResNetBlock)
+            else:
+                assert isinstance(res_block, SEResNetBottleneck)
+            if block_type == "basic":
+                assert res_block.conv2.out_channels == n_feats
+            else:
+                assert res_block.conv1.out_channels == n_feats // bottleneck_reduction
+                assert res_block.conv3.out_channels == n_feats
+        with pytest.raises(IndexError):
+            layer[k + 1]  # layer i has exactly n_blocks blocks
+    with pytest.raises(AttributeError):
+        getattr(net, f"layer{i+1}")  # no layer beyond the last one
+
+    assert net.conv0.kernel_size == (
+        init_conv_size
+        if isinstance(init_conv_size, tuple)
+        else (init_conv_size,) * spatial_dims
+    )
+    assert net.conv0.stride == (
+        init_conv_stride
+        if isinstance(init_conv_stride, tuple)
+        else (init_conv_stride,) * spatial_dims
+    )
+
+
+@pytest.mark.parametrize("act", [act for act in ActFunction])
+def test_activations(act):
+    batch_size = INPUT_2D.shape[0]
+    net = SEResNet(
+        spatial_dims=len(INPUT_2D.shape[2:]),
+        in_channels=INPUT_2D.shape[1],
+        num_outputs=2,
+        n_features=(8, 16),
+        n_res_blocks=(2, 2),
+        act=act,
+        se_reduction=2,
+    )
+    assert net(INPUT_2D).shape == (batch_size, 2)
+
+
+def test_activation_parameters():
+    act = ("ELU", {"alpha": 0.1})
+    output_act = ("ELU", {"alpha": 0.2})
+    net = SEResNet(
+        spatial_dims=len(INPUT_2D.shape[2:]),
+        in_channels=INPUT_2D.shape[1],
+        num_outputs=2,
+        n_features=(8, 16),
+        n_res_blocks=(2, 2),
+        act=act,
+        output_act=output_act,
+        se_reduction=2,
+    )
+    assert isinstance(net.layer1[0].act1, torch.nn.ELU)
+    assert net.layer1[0].act1.alpha == 0.1
+    assert isinstance(net.layer2[1].act2, torch.nn.ELU)
+    assert net.layer2[1].act2.alpha == 0.1
+    assert isinstance(net.act0, torch.nn.ELU)
+    assert net.act0.alpha == 0.1
+    assert isinstance(net.fc.output_act, torch.nn.ELU)
+    assert net.fc.output_act.alpha == 0.2
+
+
+@pytest.mark.parametrize(
+    "name,num_outputs,output_act",
+    [
+        (SOTAResNet.SE_RESNET_50, 1, "sigmoid"),
+        (SOTAResNet.SE_RESNET_101, 2, None),
+        (SOTAResNet.SE_RESNET_152, None, "sigmoid"),
+    ],
+)
+def test_get_seresnet(name, num_outputs, output_act):
+    seresnet = get_seresnet(
+        name,
+        num_outputs=num_outputs,
+        output_act=output_act,
+    )
+    if num_outputs:
+        assert seresnet.fc.out.out_features == num_outputs
+    else:
+        assert seresnet.fc is None
+
+    if output_act and num_outputs:
+        assert seresnet.fc.output_act is not None
+    elif output_act and num_outputs is None:
+        with pytest.raises(AttributeError):
+            seresnet.fc.output_act
+
+
+def test_get_seresnet_error():
+    with pytest.raises(ValueError):
+        get_seresnet(SOTAResNet.SE_RESNET_50, num_outputs=1, pretrained=True)
diff --git a/tests/unittests/monai_networks/nn/test_unet.py b/tests/unittests/monai_networks/nn/test_unet.py
new file mode 100644
index 000000000..b8f06faa8
--- /dev/null
+++ b/tests/unittests/monai_networks/nn/test_unet.py
@@ -0,0 +1,127 @@
+import pytest
+import torch
+
+from clinicadl.monai_networks.nn import UNet
+from clinicadl.monai_networks.nn.layers.utils import ActFunction
+
+INPUT_1D = torch.randn(2, 1, 16)
+INPUT_2D = torch.randn(2, 2, 32, 64)
+INPUT_3D = torch.randn(2, 3, 16, 32, 8)
+
+
+@pytest.mark.parametrize(
+    "input_tensor,out_channels,channels,act,output_act,dropout,error",
+    [
+        (INPUT_1D, 1, (2, 3, 4), "relu", "sigmoid", None, False),
+        (INPUT_2D, 1, (2, 4, 5), "relu", None, 0.0, False),
+        (INPUT_3D, 2, (2, 3), None, ("softmax", {"dim": 1}), 0.1, False),
+        (
+            INPUT_3D,
+            2,
+            (2,),
+            None,
+            ("softmax", {"dim": 1}),
+            0.1,
+            True,
+        ),  # channels length is less than 2
+    ],
+)
+def test_unet(input_tensor, out_channels, channels, act, output_act, 
dropout, error):
+    batch_size, in_channels, *img_size = input_tensor.shape
+    spatial_dims = len(img_size)
+    if error:
+        with pytest.raises(ValueError):
+            UNet(
+                spatial_dims=spatial_dims,
+                in_channels=in_channels,
+                out_channels=out_channels,
+                channels=channels,
+                act=act,
+                output_act=output_act,
+                dropout=dropout,
+            )
+    else:
+        net = UNet(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            channels=channels,
+            act=act,
+            output_act=output_act,
+            dropout=dropout,
+        )
+
+        out = net(input_tensor)
+        assert out.shape == (batch_size, out_channels, *img_size)
+
+        if output_act:
+            assert net.output_act is not None
+        else:
+            assert net.output_act is None
+
+        assert net.doubleconv[1].conv.out_channels == channels[0]
+        if dropout is not None:
+            assert net.doubleconv[1].adn.D.p == dropout
+        else:
+            with pytest.raises(AttributeError):
+                net.doubleconv[1].adn.D
+
+        for i in range(1, len(channels)):
+            down = getattr(net, f"down{i}").doubleconv
+            up = getattr(net, f"doubleconv{i}")
+            assert down[0].conv.in_channels == channels[i - 1]
+            assert down[1].conv.out_channels == channels[i]
+            assert up[0].conv.in_channels == channels[i - 1] * 2
+            assert up[1].conv.out_channels == channels[i - 1]
+            for m in (down, up):
+                if dropout is not None:
+                    assert m[1].adn.D.p == dropout
+                else:
+                    with pytest.raises(AttributeError):
+                        m[1].adn.D
+        with pytest.raises(AttributeError):
+            getattr(net, f"down{i+1}")
+        with pytest.raises(AttributeError):
+            getattr(net, f"doubleconv{i+1}")
+
+
+@pytest.mark.parametrize("act", [act for act in ActFunction])
+def test_activations(act):
+    batch_size, in_channels, *img_size = INPUT_2D.shape
+    net = UNet(
+        spatial_dims=2,
+        in_channels=in_channels,
+        out_channels=2,
+        channels=(2, 4),
+        act=act,
+        output_act=act,
+    )
+    assert net(INPUT_2D).shape == (batch_size, 2, *img_size)
+
+
+def test_activation_parameters():
+    in_channels = INPUT_2D.shape[1]
+    act = ("ELU", {"alpha": 0.1})
+    output_act = ("ELU", {"alpha": 0.2})
+    net = UNet(
+        spatial_dims=2,
+        in_channels=in_channels,
+        out_channels=2,
+        channels=(2, 4),
+        act=act,
+        output_act=output_act,
+    )
+    assert isinstance(net.doubleconv[0].adn.A, torch.nn.ELU)
+    assert net.doubleconv[0].adn.A.alpha == 0.1
+
+    assert isinstance(net.down1.doubleconv[0].adn.A, torch.nn.ELU)
+    assert net.down1.doubleconv[0].adn.A.alpha == 0.1
+
+    assert isinstance(net.upsample1[1].adn.A, torch.nn.ELU)
+    assert net.upsample1[1].adn.A.alpha == 0.1
+
+    assert isinstance(net.doubleconv1[1].adn.A, torch.nn.ELU)
+    assert net.doubleconv1[1].adn.A.alpha == 0.1
+
+    assert isinstance(net.output_act, torch.nn.ELU)
+    assert net.output_act.alpha == 0.2
diff --git a/tests/unittests/monai_networks/nn/test_vae.py b/tests/unittests/monai_networks/nn/test_vae.py
new file mode 100644
index 000000000..ca2fb24b8
--- /dev/null
+++ b/tests/unittests/monai_networks/nn/test_vae.py
@@ -0,0 +1,99 @@
+import pytest
+import torch
+from numpy import isclose
+from torch.nn import ReLU
+
+from clinicadl.monai_networks.nn import VAE
+
+
+@pytest.mark.parametrize(
+    "input_tensor,kernel_size,stride,padding,dilation,pooling,pooling_indices",
+    [
+        (torch.randn(2, 1, 21), 3, 1, 0, 1, ("max", {"kernel_size": 2}), [0]),
+        (
+            torch.randn(2, 1, 65, 85),
+            (3, 5),
+            (1, 2),
+            0,
+            (1, 2),
+            ("max", {"kernel_size": 2, "stride": 1}),
+            [0],
+        ),
+        (
+            torch.randn(2, 1, 64, 62, 61),  # to test output padding
+            4,
+            2,
+            (1, 1, 0),
+            1,
+            ("avg", {"kernel_size": 3, "stride": 2}),
+            [0],
+        ),
+        (
+            torch.randn(2, 1, 51, 55, 45),
+            3,
+            2,
+            0,
+            1,
+            ("max", 
{"kernel_size": 2, "ceil_mode": True}), + [0, 1], + ), + ( + torch.randn(2, 1, 51, 55, 45), + 3, + 2, + 0, + 1, + [ + ("max", {"kernel_size": 2, "ceil_mode": True}), + ("max", {"kernel_size": 2, "stride": 1, "ceil_mode": False}), + ], + [0, 1], + ), + ], +) +def test_output_shape( + input_tensor, kernel_size, stride, padding, dilation, pooling, pooling_indices +): + latent_size = 3 + net = VAE( + in_shape=input_tensor.shape[1:], + latent_size=latent_size, + conv_args={ + "channels": [2, 4, 8], + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + "pooling": pooling, + "pooling_indices": pooling_indices, + }, + ) + recon, mu, log_var = net(input_tensor) + assert recon.shape == input_tensor.shape + assert mu.shape == (input_tensor.shape[0], latent_size) + assert log_var.shape == (input_tensor.shape[0], latent_size) + + +def test_mu_log_var(): + net = VAE( + in_shape=(1, 5, 5), + latent_size=4, + conv_args={"channels": []}, + mlp_args={"hidden_channels": [12], "output_act": "relu", "act": "celu"}, + ) + assert net.mu.linear.in_features == 12 + assert net.log_var.linear.in_features == 12 + assert isinstance(net.mu.output_act, ReLU) + assert isinstance(net.log_var.output_act, ReLU) + assert net.encoder(torch.randn(2, 1, 5, 5)).shape == (2, 12) + _, mu, log_var = net(torch.randn(2, 1, 5, 5)) + assert not isclose(mu.detach().numpy(), log_var.detach().numpy()).all() + + net = VAE( + in_shape=(1, 5, 5), + latent_size=4, + conv_args={"channels": []}, + mlp_args={"hidden_channels": [12]}, + ) + assert net.mu.linear.in_features == 12 + assert net.log_var.linear.in_features == 12 diff --git a/tests/unittests/monai_networks/nn/test_vit.py b/tests/unittests/monai_networks/nn/test_vit.py new file mode 100644 index 000000000..741d6e5f8 --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_vit.py @@ -0,0 +1,279 @@ +import numpy as np +import pytest +import torch + +from clinicadl.monai_networks.nn import ViT, get_vit +from clinicadl.monai_networks.nn.layers.utils import ActFunction +from clinicadl.monai_networks.nn.vit import SOTAViT + +INPUT_1D = torch.randn(2, 1, 16) +INPUT_2D = torch.randn(2, 2, 15, 16) +INPUT_3D = torch.randn(2, 3, 24, 24, 24) + + +@pytest.mark.parametrize( + "input_tensor,patch_size,num_outputs,embedding_dim,num_layers,num_heads,mlp_dim,pos_embed_type,output_act,dropout,error", + [ + (INPUT_1D, 4, 1, 25, 3, 5, 26, None, "softmax", None, False), + ( + INPUT_1D, + 5, + 1, + 25, + 3, + 5, + 26, + None, + "softmax", + None, + True, + ), # img not divisible by patch + ( + INPUT_1D, + 4, + 1, + 25, + 3, + 4, + 26, + None, + "softmax", + None, + True, + ), # embedding not divisible by num heads + (INPUT_1D, 4, 1, 24, 5, 4, 26, "sincos", "softmax", None, True), # sincos + (INPUT_2D, (3, 4), None, 24, 2, 4, 42, "learnable", "tanh", 0.1, False), + ( + INPUT_2D, + 4, + None, + 24, + 2, + 6, + 42, + "learnable", + "tanh", + 0.1, + True, + ), # img not divisible by patch + ( + INPUT_2D, + (3, 4), + None, + 24, + 2, + 5, + 42, + "learnable", + "tanh", + 0.1, + True, + ), # embedding not divisible by num heads + ( + INPUT_2D, + (3, 4), + None, + 18, + 2, + 6, + 42, + "sincos", + "tanh", + 0.1, + True, + ), # sincos : embedding not divisible by 4 + (INPUT_2D, (3, 4), None, 24, 2, 6, 42, "sincos", "tanh", 0.1, False), + ( + INPUT_3D, + 6, + 2, + 15, + 2, + 3, + 42, + "sincos", + None, + 0.0, + True, + ), # sincos : embedding not divisible by 6 + (INPUT_3D, 6, 2, 18, 2, 3, 42, "sincos", None, 0.0, False), + ], +) +def test_vit( + input_tensor, + 
patch_size, + num_outputs, + embedding_dim, + num_layers, + num_heads, + mlp_dim, + pos_embed_type, + output_act, + dropout, + error, +): + batch_size = input_tensor.shape[0] + img_size = input_tensor.shape[2:] + spatial_dims = len(img_size) + if error: + with pytest.raises(ValueError): + ViT( + in_shape=input_tensor.shape[1:], + patch_size=patch_size, + num_outputs=num_outputs, + embedding_dim=embedding_dim, + num_layers=num_layers, + num_heads=num_heads, + mlp_dim=mlp_dim, + pos_embed_type=pos_embed_type, + output_act=output_act, + dropout=dropout, + ) + else: + net = ViT( + in_shape=input_tensor.shape[1:], + patch_size=patch_size, + num_outputs=num_outputs, + embedding_dim=embedding_dim, + num_layers=num_layers, + num_heads=num_heads, + mlp_dim=mlp_dim, + pos_embed_type=pos_embed_type, + output_act=output_act, + dropout=dropout, + ) + output = net(input_tensor) + + if num_outputs: + assert output.shape == (batch_size, num_outputs) + else: + n_patches = int( + np.prod( + np.array(img_size) + // np.array( + patch_size + if isinstance(patch_size, tuple) + else (patch_size,) * spatial_dims + ) + ) + ) + assert output.shape == (batch_size, n_patches, embedding_dim) + + if output_act and num_outputs: + assert net.fc.output_act is not None + elif output_act and num_outputs is None: + with pytest.raises(AttributeError): + net.fc.output_act + + assert net.conv_proj.out_channels == embedding_dim + encoder = net.encoder.layers + for transformer_block in encoder: + assert isinstance(transformer_block.norm1, torch.nn.LayerNorm) + assert isinstance(transformer_block.norm2, torch.nn.LayerNorm) + assert transformer_block.self_attention.num_heads == num_heads + assert transformer_block.self_attention.dropout == ( + dropout if dropout is not None else 0.0 + ) + assert transformer_block.self_attention.embed_dim == embedding_dim + assert transformer_block.mlp[0].out_features == mlp_dim + assert transformer_block.mlp[2].p == ( + dropout if dropout is not None else 0.0 + ) + assert transformer_block.mlp[4].p == ( + dropout if dropout is not None else 0.0 + ) + assert net.encoder.dropout.p == (dropout if dropout is not None else 0.0) + assert isinstance(net.encoder.norm, torch.nn.LayerNorm) + + pos_embedding = net.encoder.pos_embedding + if pos_embed_type is None: + assert not pos_embedding.requires_grad + assert (pos_embedding == torch.zeros_like(pos_embedding)).all() + elif pos_embed_type == "sincos": + assert not pos_embedding.requires_grad + if num_outputs: + assert ( + pos_embedding[0, 1, 0] == 0.0 + ) # first element of of sincos embedding of first patch is zero + else: + assert pos_embedding[0, 0, 0] == 0.0 + else: + assert pos_embedding.requires_grad + if num_outputs: + assert pos_embedding[0, 1, 0] != 0.0 + else: + assert pos_embedding[0, 0, 0] != 0.0 + + with pytest.raises(IndexError): + encoder[num_layers] + + +@pytest.mark.parametrize("act", [act for act in ActFunction]) +def test_activations(act): + batch_size = INPUT_2D.shape[0] + net = ViT( + in_shape=INPUT_2D.shape[1:], + patch_size=(3, 4), + num_outputs=1, + embedding_dim=12, + num_layers=2, + num_heads=3, + mlp_dim=24, + output_act=act, + ) + assert net(INPUT_2D).shape == (batch_size, 1) + + +def test_activation_parameters(): + output_act = ("ELU", {"alpha": 0.2}) + net = ViT( + in_shape=(1, 12, 12), + patch_size=3, + num_outputs=1, + embedding_dim=12, + num_layers=2, + num_heads=3, + mlp_dim=24, + output_act=output_act, + ) + assert isinstance(net.fc.output_act, torch.nn.ELU) + assert net.fc.output_act.alpha == 0.2 + + 
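+# Shape bookkeeping behind the parametrized cases above (an added sketch,
+# assuming ViT tokenizes the image into non-overlapping patches): every
+# spatial size must be divisible by the patch size, and the token count is
+# the product of the per-dimension ratios, e.g. (24, 24, 24) with patch
+# size 6 gives 4 * 4 * 4 = 64 patches. _n_patches is a hypothetical helper.
+def _n_patches(img_size, patch_size):
+    return int(np.prod(np.array(img_size) // np.array(patch_size)))
+
+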
+@pytest.mark.parametrize( + "name,num_outputs,output_act,img_size", + [ + (SOTAViT.B_16, 1, "sigmoid", (224, 224)), + (SOTAViT.B_32, 2, None, (224, 224)), + (SOTAViT.L_16, None, "sigmoid", (224, 224)), + (SOTAViT.L_32, None, None, (224, 224)), + ], +) +def test_get_vit(name, num_outputs, output_act, img_size): + input_tensor = torch.randn(1, 3, *img_size) + + vit = get_vit(name, num_outputs=num_outputs, output_act=output_act, pretrained=True) + if num_outputs: + assert vit.fc.out.out_features == num_outputs + else: + assert vit.fc is None + + if output_act and num_outputs: + assert vit.fc.output_act is not None + elif output_act and num_outputs is None: + assert vit.fc is None + + vit(input_tensor) + + +def test_get_vit_output(): + from torchvision.models import vit_b_16 + + gt = vit_b_16(weights="DEFAULT") + gt.heads = torch.nn.Identity() + x = torch.randn(1, 3, 224, 224) + + vit = get_vit(SOTAViT.B_16, num_outputs=1, pretrained=True) + vit.fc = torch.nn.Identity() + with torch.no_grad(): + assert (vit(x) == gt(x)).all() diff --git a/tests/unittests/monai_networks/nn/utils/__init__.py b/tests/unittests/monai_networks/nn/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unittests/monai_networks/nn/utils/test_checks.py b/tests/unittests/monai_networks/nn/utils/test_checks.py new file mode 100644 index 000000000..27cc234f5 --- /dev/null +++ b/tests/unittests/monai_networks/nn/utils/test_checks.py @@ -0,0 +1,127 @@ +import pytest + +from clinicadl.monai_networks.nn.utils.checks import ( + _check_conv_parameter, + check_adn_ordering, + check_conv_args, + check_mlp_args, + check_norm_layer, + check_pool_indices, + ensure_list_of_tuples, +) + + +@pytest.mark.parametrize( + "adn,error", + [("ADN", False), ("ND", False), ("A", False), ("AAD", True), ("ADM", True)], +) +def test_check_adn_ordering(adn, error): + if error: + with pytest.raises(ValueError): + check_adn_ordering(adn) + else: + check_adn_ordering(adn) + + +@pytest.mark.parametrize( + "parameter,expected_output", + [ + (5, (5, 5, 5)), + ((5, 4, 4), (5, 4, 4)), + ([5, 4], [(5, 5, 5), (4, 4, 4)]), + ([5, (4, 3, 3)], [(5, 5, 5), (4, 3, 3)]), + ((5, 5), None), + ([5, 5, 5], None), + ([5, (4, 4)], None), + (5.0, None), + ], +) +def test_check_conv_parameter(parameter, expected_output): + if expected_output: + assert ( + _check_conv_parameter(parameter, dim=3, n_layers=2, name="abc") + == expected_output + ) + else: + with pytest.raises(ValueError): + _check_conv_parameter(parameter, dim=3, n_layers=2, name="abc") + + +@pytest.mark.parametrize( + "parameter,expected_output", + [ + (5, [(5, 5, 5), (5, 5, 5)]), + ((5, 4, 4), [(5, 4, 4), (5, 4, 4)]), + ([5, 4], [(5, 5, 5), (4, 4, 4)]), + ([5, (4, 3, 3)], [(5, 5, 5), (4, 3, 3)]), + ], +) +def test_ensure_list_of_tuples(parameter, expected_output): + assert ( + ensure_list_of_tuples(parameter, dim=3, n_layers=2, name="abc") + == expected_output + ) + + +@pytest.mark.parametrize( + "indices,n_layers,error", + [ + ([0, 1, 2], 4, False), + ([0, 1, 2], 3, False), + ([-1, 1, 2], 3, False), + ([0, 1, 2], 2, True), + ([-2, 1, 2], 3, True), + ], +) +def test_check_pool_indices(indices, n_layers, error): + if error: + with pytest.raises(ValueError): + _ = check_pool_indices(indices, n_layers) + else: + check_pool_indices(indices, n_layers) + + +@pytest.mark.parametrize( + "inputs,error", + [ + (None, False), + ("abc", True), + ("batch", False), + ("group", True), + (("batch",), True), + (("batch", 3), True), + (("batch", {"eps": 0.1}), False), + (("group", {"num_groups": 
2}), False), + (("group", {"num_groups": 2, "eps": 0.1}), False), + ], +) +def test_check_norm_layer(inputs, error): + if error: + with pytest.raises(ValueError): + _ = check_norm_layer(inputs) + else: + assert check_norm_layer(inputs) == inputs + + +@pytest.mark.parametrize( + "conv_args,error", + [(None, True), ({"kernel_size": 3}, True), ({"channels": [2]}, False)], +) +def test_check_conv_args(conv_args, error): + if error: + with pytest.raises(ValueError): + check_conv_args(conv_args) + else: + check_conv_args(conv_args) + + +@pytest.mark.parametrize( + "mlp_args,error", + [({"act": "tanh"}, True), ({"hidden_channels": [2]}, False)], +) +def test_check_mlp_args(mlp_args, error): + if error: + with pytest.raises(ValueError): + check_mlp_args(mlp_args) + else: + check_mlp_args(mlp_args) diff --git a/tests/unittests/monai_networks/nn/utils/test_shapes.py b/tests/unittests/monai_networks/nn/utils/test_shapes.py new file mode 100644 index 000000000..b7ae2d444 --- /dev/null +++ b/tests/unittests/monai_networks/nn/utils/test_shapes.py @@ -0,0 +1,281 @@ +import pytest +import torch + +from clinicadl.monai_networks.nn.utils.shapes import ( + _calculate_adaptivepool_out_shape, + _calculate_avgpool_out_shape, + _calculate_maxpool_out_shape, + _calculate_upsample_out_shape, + calculate_conv_out_shape, + calculate_convtranspose_out_shape, + calculate_pool_out_shape, + calculate_unpool_out_shape, +) + +INPUT_1D = torch.randn(2, 1, 10) +INPUT_2D = torch.randn(2, 1, 32, 32) +INPUT_3D = torch.randn(2, 1, 20, 21, 22) + + +@pytest.mark.parametrize( + "input_tensor,kernel_size,stride,padding,dilation", + [ + (INPUT_3D, 7, 2, (1, 2, 3), 3), + (INPUT_2D, (5, 3), 1, 0, (2, 2)), + (INPUT_1D, 3, 1, 2, 1), + ], +) +def test_calculate_conv_out_shape(input_tensor, kernel_size, stride, padding, dilation): + in_shape = input_tensor.shape[2:] + dim = len(input_tensor.shape[2:]) + args = { + "in_channels": 1, + "out_channels": 1, + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + } + if dim == 1: + conv = torch.nn.Conv1d + elif dim == 2: + conv = torch.nn.Conv2d + else: + conv = torch.nn.Conv3d + + output_shape = conv(**args)(input_tensor).shape[2:] + assert ( + calculate_conv_out_shape(in_shape, kernel_size, stride, padding, dilation) + == output_shape + ) + + +@pytest.mark.parametrize( + "input_tensor,kernel_size,stride,padding,dilation,output_padding", + [ + (INPUT_3D, 7, 2, (1, 2, 3), 3, 0), + (INPUT_2D, (5, 3), 1, 0, (2, 2), (1, 0)), + (INPUT_1D, 3, 3, 2, 1, 2), + ], +) +def test_calculate_convtranspose_out_shape( + input_tensor, kernel_size, stride, padding, dilation, output_padding +): + in_shape = input_tensor.shape[2:] + dim = len(input_tensor.shape[2:]) + args = { + "in_channels": 1, + "out_channels": 1, + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + "output_padding": output_padding, + } + if dim == 1: + conv = torch.nn.ConvTranspose1d + elif dim == 2: + conv = torch.nn.ConvTranspose2d + else: + conv = torch.nn.ConvTranspose3d + + output_shape = conv(**args)(input_tensor).shape[2:] + assert ( + calculate_convtranspose_out_shape( + in_shape, kernel_size, stride, padding, output_padding, dilation + ) + == output_shape + ) + + +@pytest.mark.parametrize( + "input_tensor,kernel_size,stride,padding,dilation,ceil_mode", + [ + (INPUT_3D, 7, 2, (1, 2, 3), 3, False), + (INPUT_3D, 7, 2, (1, 2, 3), 3, True), + (INPUT_2D, (5, 3), 1, 0, (2, 2), False), + (INPUT_2D, (5, 3), 1, 0, (2, 2), True), + (INPUT_1D, 2, 1, 1, 1, 
False), + (INPUT_1D, 2, 1, 1, 1, True), + ], +) +def test_calculate_maxpool_out_shape( + input_tensor, kernel_size, stride, padding, dilation, ceil_mode +): + in_shape = input_tensor.shape[2:] + dim = len(input_tensor.shape[2:]) + args = { + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + "ceil_mode": ceil_mode, + } + if dim == 1: + max_pool = torch.nn.MaxPool1d + elif dim == 2: + max_pool = torch.nn.MaxPool2d + else: + max_pool = torch.nn.MaxPool3d + + output_shape = max_pool(**args)(input_tensor).shape[2:] + assert ( + _calculate_maxpool_out_shape( + in_shape, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode + ) + == output_shape + ) + + +@pytest.mark.parametrize( + "input_tensor,kernel_size,stride,padding,ceil_mode", + [ + (INPUT_3D, 7, 2, (1, 2, 3), False), + (INPUT_3D, 7, 2, (1, 2, 3), True), + (INPUT_2D, (5, 3), 1, 0, False), + (INPUT_2D, (5, 3), 1, 0, True), + (INPUT_1D, 2, 1, 1, False), + (INPUT_1D, 2, 1, 1, True), + ( + INPUT_1D, + 2, + 3, + 1, + True, + ), # special case with ceil_mode (see: https://pytorch.org/docs/stable/generated/torch.nn.AvgPool1d.html) + ], +) +def test_calculate_avgpool_out_shape( + input_tensor, kernel_size, stride, padding, ceil_mode +): + in_shape = input_tensor.shape[2:] + dim = len(in_shape) + args = { + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "ceil_mode": ceil_mode, + } + if dim == 1: + avg_pool = torch.nn.AvgPool1d + elif dim == 2: + avg_pool = torch.nn.AvgPool2d + else: + avg_pool = torch.nn.AvgPool3d + output_shape = avg_pool(**args)(input_tensor).shape[2:] + assert ( + _calculate_avgpool_out_shape( + in_shape, kernel_size, stride, padding, ceil_mode=ceil_mode + ) + == output_shape + ) + + +@pytest.mark.parametrize( + "input_tensor,kwargs", + [ + (INPUT_3D, {"output_size": 1}), + (INPUT_2D, {"output_size": (1, 2)}), + (INPUT_1D, {"output_size": 3}), + ], +) +def test_calculate_adaptivepool_out_shape(input_tensor, kwargs): + in_shape = input_tensor.shape[2:] + dim = len(in_shape) + if dim == 1: + avg_pool = torch.nn.AdaptiveAvgPool1d + max_pool = torch.nn.AdaptiveMaxPool1d + elif dim == 2: + avg_pool = torch.nn.AdaptiveAvgPool2d + max_pool = torch.nn.AdaptiveMaxPool2d + else: + avg_pool = torch.nn.AdaptiveAvgPool3d + max_pool = torch.nn.AdaptiveMaxPool3d + + output_shape = max_pool(**kwargs)(input_tensor).shape[2:] + assert _calculate_adaptivepool_out_shape(in_shape, **kwargs) == output_shape + + output_shape = avg_pool(**kwargs)(input_tensor).shape[2:] + assert _calculate_adaptivepool_out_shape(in_shape, **kwargs) == output_shape + + +def test_calculate_pool_out_shape(): + in_shape = INPUT_3D.shape[2:] + assert calculate_pool_out_shape( + pool_mode="max", + in_shape=in_shape, + kernel_size=7, + stride=2, + padding=(1, 2, 3), + dilation=3, + ceil_mode=True, + ) == (3, 4, 6) + assert calculate_pool_out_shape( + pool_mode="avg", + in_shape=in_shape, + kernel_size=7, + stride=2, + padding=(1, 2, 3), + ceil_mode=True, + ) == (9, 10, 12) + assert calculate_pool_out_shape( + pool_mode="adaptiveavg", + in_shape=in_shape, + output_size=(3, 4, 5), + ) == (3, 4, 5) + assert calculate_pool_out_shape( + pool_mode="adaptivemax", + in_shape=in_shape, + output_size=1, + ) == (1, 1, 1) + with pytest.raises(ValueError): + calculate_pool_out_shape( + pool_mode="abc", + in_shape=in_shape, + kernel_size=7, + stride=2, + padding=(1, 2, 3), + dilation=3, + ceil_mode=True, + ) + + +@pytest.mark.parametrize( + "input_tensor,kwargs", + [ + (INPUT_3D, {"scale_factor": 2}), + (INPUT_2D, 
{"size": (40, 41)}), + (INPUT_2D, {"size": 40}), + (INPUT_2D, {"scale_factor": (3, 2)}), + (INPUT_1D, {"scale_factor": 2}), + ], +) +def test_calculate_upsample_out_shape(input_tensor, kwargs): + in_shape = input_tensor.shape[2:] + unpool = torch.nn.Upsample(**kwargs) + + output_shape = unpool(input_tensor).shape[2:] + assert _calculate_upsample_out_shape(in_shape, **kwargs) == output_shape + + +def test_calculate_unpool_out_shape(): + in_shape = INPUT_3D.shape[2:] + assert calculate_unpool_out_shape( + unpool_mode="convtranspose", + in_shape=in_shape, + kernel_size=5, + stride=1, + padding=0, + output_padding=0, + dilation=1, + ) == (24, 25, 26) + assert calculate_unpool_out_shape( + unpool_mode="upsample", + in_shape=in_shape, + scale_factor=2, + ) == (40, 42, 44) + with pytest.raises(ValueError): + calculate_unpool_out_shape( + unpool_mode="abc", + in_shape=in_shape, + ) diff --git a/tests/unittests/monai_networks/test_factory.py b/tests/unittests/monai_networks/test_factory.py index 28e8113fe..961238111 100644 --- a/tests/unittests/monai_networks/test_factory.py +++ b/tests/unittests/monai_networks/test_factory.py @@ -1,10 +1,15 @@ import pytest -from monai.networks.nets import ResNet -from monai.networks.nets.resnet import ResNetBottleneck -from torch.nn import Conv2d -from clinicadl.monai_networks import get_network -from clinicadl.monai_networks.config import create_network_config +from clinicadl.monai_networks import ( + ImplementedNetworks, + get_network, + get_network_from_config, +) +from clinicadl.monai_networks.config.autoencoder import AutoEncoderConfig +from clinicadl.monai_networks.factory import _update_config_with_defaults +from clinicadl.monai_networks.nn import AutoEncoder + +tested = [] @pytest.mark.parametrize( @@ -13,124 +18,285 @@ ( "AutoEncoder", { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "channels": [2, 2], - "strides": [1, 1], + "in_shape": (1, 64, 65), + "latent_size": 1, + "conv_args": {"channels": [2, 4]}, }, ), ( - "VarAutoEncoder", + "VAE", { - "spatial_dims": 3, - "in_shape": (1, 16, 16, 16), - "out_channels": 1, - "latent_size": 16, - "channels": [2, 2], - "strides": [1, 1], + "in_shape": (1, 64, 65), + "latent_size": 1, + "conv_args": {"channels": [2, 4]}, }, ), ( - "Regressor", + "CNN", { - "in_shape": (1, 16, 16, 16), - "out_shape": (1, 16, 16, 16), - "channels": [2, 2], - "strides": [1, 1], + "in_shape": (1, 64, 65), + "num_outputs": 1, + "conv_args": {"channels": [2, 4]}, }, ), ( - "Classifier", + "Generator", { - "in_shape": (1, 16, 16, 16), - "classes": 2, - "channels": [2, 2], - "strides": [1, 1], + "latent_size": 1, + "start_shape": (1, 5, 5), + "conv_args": {"channels": [2, 4]}, }, ), ( - "Discriminator", - {"in_shape": (1, 16, 16, 16), "channels": [2, 2], "strides": [1, 1]}, + "ConvDecoder", + { + "spatial_dims": 2, + "in_channels": 1, + "channels": [2, 4], + }, ), ( - "Critic", - {"in_shape": (1, 16, 16, 16), "channels": [2, 2], "strides": [1, 1]}, + "ConvEncoder", + { + "spatial_dims": 2, + "in_channels": 1, + "channels": [2, 4], + }, ), - ("DenseNet", {"spatial_dims": 3, "in_channels": 1, "out_channels": 1}), ( - "FullyConnectedNet", - {"in_channels": 3, "out_channels": 1, "hidden_channels": [2, 3]}, + "MLP", + { + "in_channels": 1, + "out_channels": 2, + "hidden_channels": [2, 4], + }, ), ( - "VarFullyConnectedNet", + "AttentionUNet", { + "spatial_dims": 2, "in_channels": 1, - "out_channels": 1, - "latent_size": 16, - "encode_channels": [2, 2], - "decode_channels": [2, 2], + "out_channels": 2, }, ), ( - "Generator", + 
"UNet", { - "latent_shape": (3,), - "start_shape": (1, 16, 16, 16), - "channels": [2, 2], - "strides": [1, 1], + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 2, }, ), ( "ResNet", { - "block": "bottleneck", - "layers": (4, 4, 4, 4), - "block_inplanes": (5, 5, 5, 5), "spatial_dims": 2, + "in_channels": 1, + "num_outputs": 1, }, ), - ("ResNetFeatures", {"model_name": "resnet10"}), - ("SegResNet", {}), ( - "UNet", + "DenseNet", { - "spatial_dims": 3, + "spatial_dims": 2, "in_channels": 1, - "out_channels": 1, - "channels": [2, 2, 2], - "strides": [1, 1], + "num_outputs": 1, }, ), ( - "AttentionUnet", + "SEResNet", { - "spatial_dims": 3, + "spatial_dims": 2, "in_channels": 1, - "out_channels": 1, - "channels": [2, 2, 2], - "strides": [1, 1], + "num_outputs": 1, + }, + ), + ( + "ViT", + { + "in_shape": (1, 64, 65), + "patch_size": (4, 5), + "num_outputs": 1, + }, + ), + ( + "ResNet-18", + { + "num_outputs": 1, + }, + ), + ( + "ResNet-34", + { + "num_outputs": 1, + }, + ), + ( + "ResNet-50", + { + "num_outputs": 1, + }, + ), + ( + "ResNet-101", + { + "num_outputs": 1, + }, + ), + ( + "ResNet-152", + { + "num_outputs": 1, + "pretrained": True, + }, + ), + ( + "DenseNet-121", + { + "num_outputs": 1, + }, + ), + ( + "DenseNet-161", + { + "num_outputs": 1, + }, + ), + ( + "DenseNet-169", + { + "num_outputs": 1, + }, + ), + ( + "DenseNet-201", + { + "num_outputs": 1, + "pretrained": True, + }, + ), + ( + "SEResNet-50", + { + "num_outputs": 1, + }, + ), + ( + "SEResNet-101", + { + "num_outputs": 1, + }, + ), + ( + "SEResNet-152", + { + "num_outputs": 1, + }, + ), + ( + "ViT-B/16", + { + "num_outputs": 1, + "pretrained": True, + }, + ), + ( + "ViT-B/32", + { + "num_outputs": 1, + }, + ), + ( + "ViT-L/16", + { + "num_outputs": 1, + }, + ), + ( + "ViT-L/32", + { + "num_outputs": 1, }, ), - ("ViT", {"in_channels": 3, "img_size": 16, "patch_size": 4}), - ("ViTAutoEnc", {"in_channels": 3, "img_size": 16, "patch_size": 4}), ], ) def test_get_network(network_name, params): - config = create_network_config(network_name)(**params) - network, updated_config = get_network(config) + tested.append(network_name) + _ = get_network(name=network_name, **params) + if network_name == "ViT-L/32": # the last one + assert set(tested) == set( + net.value for net in ImplementedNetworks + ) # check we haven't miss a network + + +def test_update_config_with_defaults(): + config = AutoEncoderConfig( + latent_size=1, + in_shape=(1, 10, 10), + conv_args={"channels": [1, 2], "dropout": 0.2}, + mlp_args={"hidden_channels": [5], "act": "relu"}, + ) + _update_config_with_defaults(config, AutoEncoder.__init__) + assert config.in_shape == (1, 10, 10) + assert config.latent_size == 1 + assert config.conv_args.channels == [1, 2] + assert config.conv_args.dropout == 0.2 + assert config.conv_args.act == "prelu" + assert config.mlp_args.hidden_channels == [5] + assert config.mlp_args.act == "relu" + assert config.mlp_args.norm == "batch" + assert config.out_channels is None + + +def test_parameters(): + net, updated_config = get_network( + "AutoEncoder", + return_config=True, + latent_size=1, + in_shape=(1, 10, 10), + conv_args={"channels": [1, 2], "dropout": 0.2}, + mlp_args={"hidden_channels": [5], "act": "relu"}, + ) + assert isinstance(net, AutoEncoder) + assert net.encoder.mlp.out_channels == 1 + assert net.encoder.mlp.hidden_channels == [5] + assert net.encoder.mlp.act == "relu" + assert net.encoder.mlp.norm == "batch" + assert net.in_shape == (1, 10, 10) + assert net.encoder.convolutions.channels == (1, 2) + assert 
net.encoder.convolutions.dropout == 0.2 + assert net.encoder.convolutions.act == "prelu" + + assert updated_config.in_shape == (1, 10, 10) + assert updated_config.latent_size == 1 + assert updated_config.conv_args.channels == [1, 2] + assert updated_config.conv_args.dropout == 0.2 + assert updated_config.conv_args.act == "prelu" + assert updated_config.mlp_args.hidden_channels == [5] + assert updated_config.mlp_args.act == "relu" + assert updated_config.mlp_args.norm == "batch" + assert updated_config.out_channels is None + + +def test_without_return(): + net = get_network( + "AutoEncoder", + return_config=False, + latent_size=1, + in_shape=(1, 10, 10), + conv_args={"channels": [1, 2]}, + ) + assert isinstance(net, AutoEncoder) - if network_name == "ResNet": - assert isinstance(network, ResNet) - assert isinstance(network.layer1[0], ResNetBottleneck) - assert len(network.layer1) == 4 - assert network.layer1[0].conv1.in_channels == 5 - assert isinstance(network.layer1[0].conv1, Conv2d) - assert updated_config.network == "ResNet" - assert updated_config.block == "bottleneck" - assert updated_config.layers == (4, 4, 4, 4) - assert updated_config.block_inplanes == (5, 5, 5, 5) - assert updated_config.spatial_dims == 2 - assert updated_config.conv1_t_size == 7 - assert updated_config.act == ("relu", {"inplace": True}) +def test_get_network_from_config(): + config = AutoEncoderConfig( + latent_size=1, + in_shape=(1, 10, 10), + conv_args={"channels": [1, 2], "dropout": 0.2}, + mlp_args={"hidden_channels": [5], "act": "relu"}, + ) + net, updated_config = get_network_from_config(config) + assert isinstance(net, AutoEncoder) + assert updated_config.conv_args.act == "prelu" + assert config.conv_args.act == "DefaultFromLibrary"
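+
+
+# Added failure-path sketch (assumption: an unknown name is routed through
+# ImplementedNetworks, whose _missing_ hook raises ValueError): asking the
+# factory for a network that is not implemented should fail loudly rather
+# than fall back silently.
+def test_get_network_unknown_name():
+    with pytest.raises(ValueError):
+        get_network("NotANetwork", num_outputs=1)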