diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 1a2b6f7d9..949d9efff 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -2,9 +2,9 @@ name: 'Lint codebase' on: pull_request: - branches: [ "dev", "refactoring" ] + branches: [ "dev", "refactoring", "clinicadl_v2" ] push: - branches: [ "dev", "refactoring" ] + branches: [ "dev", "refactoring", "clinicadl_v2" ] permissions: contents: read diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 219e86c2b..21f71cd43 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,9 +2,9 @@ name: Test on: push: - branches: ["dev", "refactoring"] + branches: ["dev", "refactoring", "clinicadl_v2"] pull_request: - branches: ["dev", "refactoring"] + branches: ["dev", "refactoring", "clinicadl_v2"] permissions: contents: read diff --git a/clinicadl/API/complicated_case.py b/clinicadl/API/complicated_case.py new file mode 100644 index 000000000..b71862eff --- /dev/null +++ b/clinicadl/API/complicated_case.py @@ -0,0 +1,135 @@ +from pathlib import Path + +import torchio.transforms as transforms + +from clinicadl.dataset.caps_reader import CapsReader +from clinicadl.dataset.concat import ConcatDataset +from clinicadl.dataset.config.extraction import ExtractionConfig +from clinicadl.dataset.config.preprocessing import ( + PreprocessingConfig, + T1PreprocessingConfig, +) +from clinicadl.dataset.old_caps_dataset import ( + CapsDatasetPatch, + CapsDatasetRoi, + CapsDatasetSlice, +) +from clinicadl.experiment_manager.experiment_manager import ExperimentManager +from clinicadl.losses.config import CrossEntropyLossConfig +from clinicadl.losses.factory import get_loss_function +from clinicadl.model.clinicadl_model import ClinicaDLModel +from clinicadl.networks.config import ImplementedNetworks +from clinicadl.networks.factory import ( + ConvEncoderOptions, + create_network_config, + get_network_from_config, +) +from 
clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig +from clinicadl.optimization.optimizer.factory import get_optimizer +from clinicadl.predictor.predictor import Predictor +from clinicadl.splitter.kfold import KFolder +from clinicadl.splitter.split import get_single_split, split_tsv +from clinicadl.trainer.trainer import Trainer +from clinicadl.transforms.config import TransformsConfig + +# Create the Maps Manager / Read/write manager / +maps_path = Path("/") +manager = ExperimentManager( + maps_path, overwrite=False +) # a ajouter dans le manager: mlflow/ profiler/ etc ... + +caps_directory = Path("caps_directory") # output of clinica pipelines +caps_reader = CapsReader(caps_directory, manager=manager) + +preprocessing_1 = caps_reader.get_preprocessing("t1-linear") +caps_reader.prepare_data( + preprocessing=preprocessing_1, data_tsv=Path(""), n_proc=2 +) # don't return anything -> just extract the image tensor and compute some information for each images + + +transforms_1 = TransformsConfig( + object_augmentation=[transforms.Crop, transforms.Transform], + image_augmentation=[transforms.Crop, transforms.Transform], + extraction=ExtractionPatchConfig(patch_size=3), + image_transforms=[transforms.Blur, transforms.Ghosting], + object_transforms=[transforms.BiasField, transforms.Motion], +) # not mandatory + +preprocessing_2 = caps_reader.get_preprocessing("pet-linear") +transforms_2 = TransformsConfig( + object_augmentation=[transforms.Crop, transforms.Transform], + image_augmentation=[transforms.Crop, transforms.Transform], + extraction=ExtractionSliceConfig(), + image_transforms=[transforms.Blur, transforms.Ghosting], + object_transforms=[transforms.BiasField, transforms.Motion], +) + +sub_ses_tsv = Path("") +split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv + +dataset_t1_roi = caps_reader.get_dataset( + preprocessing=preprocessing_1, + sub_ses_tsv=split_dir / "train.tsv", + transforms=transforms_1, +) # do we give 
config or object for transforms ? +dataset_pet_patch = caps_reader.get_dataset( + preprocessing=preprocessing_2, + sub_ses_tsv=split_dir / "train.tsv", + transforms=transforms_2, +) + +dataset_multi_modality_multi_extract = ConcatDataset( + [ + dataset_t1_roi, + dataset_pet_patch, + caps_reader.get_dataset_from_json(json_path=Path(""), sub_ses_tsv=sub_ses_tsv), + ] +) # 2 train.tsv en entrée qu'il faut concat et pareil pour les transforms à faire attention + +# TODO : think about adding transforms in extract_json + + +config_file = Path("config_file") +trainer = Trainer.from_json(config_file=config_file, manager=manager) + +# CAS CROSS-VALIDATION +splitter = KFolder( + n_splits=3, caps_dataset=dataset_multi_modality_multi_extract, manager=manager +) + +for split in splitter.split_iterator(split_list=[0, 1]): + # bien définir ce qu'il y a dans l'objet split + + network_config = create_network_config(ImplementedNetworks.CNN)( + in_shape=[2, 2, 2], + num_outputs=1, + conv_args=ConvEncoderOptions(channels=[3, 2, 2]), + ) + model = ClinicaDLModelClassif.from_config( + network_config=network_config, + loss_config=CrossEntropyLossConfig(), + optimizer_config=AdamConfig(), + ) + + trainer.train(model, split) + # le trainer va instancier un predictor/valdiator dans le train ou dans le init + +# TEST + +preprocessing_test = caps_reader.get_preprocessing("pet-linear") +transforms_test = Transforms( + object_augmentation=[transforms.Crop, transforms.Transform], + image_augmentation=[transforms.Crop, transforms.Transform], + extraction=ExtractioImageConfig(), + image_transforms=[transforms.Blur, transforms.Ghosting], + object_transforms=[transforms.BiasField, transforms.Motion], +) + +dataset_test = caps_reader.get_dataset( + preprocessing=preprocessing_test, + sub_ses_tsv=split_dir / "test.tsv", + transforms=transforms_test, +) + +predictor = Predictor(manager=manager) +predictor.predict(dataset_test=dataset_test, split_number=2) diff --git a/clinicadl/API/cross_val.py 
b/clinicadl/API/cross_val.py new file mode 100644 index 000000000..f2a6b90fe --- /dev/null +++ b/clinicadl/API/cross_val.py @@ -0,0 +1,64 @@ +from pathlib import Path + +import torchio.transforms as transforms + +from clinicadl.dataset.caps_reader import CapsReader +from clinicadl.dataset.concat import ConcatDataset +from clinicadl.dataset.config.extraction import ExtractionConfig +from clinicadl.dataset.config.preprocessing import ( + PreprocessingConfig, + T1PreprocessingConfig, +) +from clinicadl.dataset.old_caps_dataset import ( + CapsDatasetPatch, + CapsDatasetRoi, + CapsDatasetSlice, +) +from clinicadl.experiment_manager.experiment_manager import ExperimentManager +from clinicadl.losses.config import CrossEntropyLossConfig +from clinicadl.losses.factory import get_loss_function +from clinicadl.model.clinicadl_model import ClinicaDLModel +from clinicadl.networks.config import ImplementedNetworks +from clinicadl.networks.factory import ( + ConvEncoderOptions, + create_network_config, + get_network_from_config, +) +from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig +from clinicadl.optimization.optimizer.factory import get_optimizer +from clinicadl.predictor.predictor import Predictor +from clinicadl.splitter.kfold import KFolder +from clinicadl.splitter.split import get_single_split, split_tsv +from clinicadl.trainer.trainer import Trainer +from clinicadl.transforms.config import TransformsConfig + +# SIMPLE EXPERIMENT WITH A CAPS ALREADY EXISTING + +maps_path = Path("/") +manager = ExperimentManager(maps_path, overwrite=False) + +dataset_t1_image = CapsDatasetPatch.from_json( + extraction=Path("test.json"), + sub_ses_tsv=Path("split_dir") / "train.tsv", +) +config_file = Path("config_file") +trainer = Trainer.from_json(config_file=config_file, manager=manager) + +# CAS CROSS-VALIDATION +splitter = KFolder(n_splits=3, caps_dataset=dataset_t1_image, manager=manager) + +for split in splitter.split_iterator(split_list=[0, 1]): + # bien 
définir ce qu'il y a dans l'objet split + + loss, loss_config = get_loss_function(CrossEntropyLossConfig()) + network_config = create_network_config(ImplementedNetworks.CNN)( + in_shape=[2, 2, 2], + num_outputs=1, + conv_args=ConvEncoderOptions(channels=[3, 2, 2]), + ) + network, _ = get_network_from_config(network_config) + optimizer, _ = get_optimizer(network, AdamConfig()) + model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer) + + trainer.train(model, split) + # le trainer va instancier un predictor/valdiator dans le train ou dans le init diff --git a/clinicadl/API/single_split.py b/clinicadl/API/single_split.py new file mode 100644 index 000000000..4074a7c0b --- /dev/null +++ b/clinicadl/API/single_split.py @@ -0,0 +1,89 @@ +from pathlib import Path + +import torchio.transforms as transforms + +from clinicadl.dataset.caps_reader import CapsReader +from clinicadl.dataset.concat import ConcatDataset +from clinicadl.dataset.config.extraction import ExtractionConfig, ExtractionPatchConfig +from clinicadl.dataset.config.preprocessing import ( + PreprocessingConfig, + T1PreprocessingConfig, +) +from clinicadl.experiment_manager.experiment_manager import ExperimentManager +from clinicadl.losses.config import CrossEntropyLossConfig +from clinicadl.losses.factory import get_loss_function +from clinicadl.model.clinicadl_model import ClinicaDLModel +from clinicadl.networks.config import ImplementedNetworks +from clinicadl.networks.factory import ( + ConvEncoderOptions, + create_network_config, + get_network_from_config, +) +from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig +from clinicadl.optimization.optimizer.factory import get_optimizer +from clinicadl.predictor.predictor import Predictor +from clinicadl.splitter.kfold import KFolder +from clinicadl.splitter.split import get_single_split, split_tsv +from clinicadl.trainer.trainer import Trainer +from clinicadl.transforms.config import TransformsConfig +from 
clinicadl.transforms.transforms import Transforms +from clinicadl.utils.enum import ExtractionMethod + +# SIMPLE EXPERIMENT + + +maps_path = Path("/") +manager = ExperimentManager(maps_path, overwrite=False) + +caps_directory = Path("caps_directory") # output of clinica pipelines +caps_reader = CapsReader( + caps_directory, manager=manager +) # un peu bizarre de passer un maps_path a cet endroit via le manager pq on veut pas forcmeent faire un entrainement ?? + +preprocessing_t1 = caps_reader.get_preprocessing("t1-linear") +caps_reader.prepare_data( + preprocessing=preprocessing_t1, + data_tsv=Path(""), + n_proc=2, + use_uncropped_images=False, +) +transforms_1 = Transforms( + object_augmentation=[transforms.RandomMotion()], # default = no transforms + image_augmentation=[transforms.RandomMotion()], # default = no transforms + object_transforms=[transforms.Blur((0.4, 0.5, 0.6))], # default = none + image_transforms=[transforms.Noise(0.2, 0.5, 3)], # default = MiniMax + extraction=ExtractionPatchConfig(patch_size=30, stride_size=20), # default = Image +) # not mandatory + +sub_ses_tsv = Path("") +split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv + +# for the cli I think we should have a splitter for the single split case so we can have the same behaviour for single and kfold + +dataset_t1_image = caps_reader.get_dataset( + preprocessing=preprocessing_t1, + sub_ses_tsv=split_dir / "train.tsv", + transforms=transforms_1, +) # do we give config or ob + + +# CAS SINGLE SPLIT +split = get_single_split( + n_subject_validation=0, + caps_dataset=dataset_t1_image, + manager=manager, +) +config_file = Path("config_file") +trainer = Trainer.from_json(config_file=config_file, manager=manager) + + +loss, loss_config = get_loss_function(CrossEntropyLossConfig()) +network_config = create_network_config(ImplementedNetworks.CNN)( + in_shape=[2, 2, 2], num_outputs=1, conv_args=ConvEncoderOptions(channels=[3, 2, 2]) +) +network, _ = 
get_network_from_config(network_config) +optimizer, _ = get_optimizer(network, AdamConfig()) +model = ClinicaDLModelClassif(network=network, loss=loss, optimizer=optimizer) + +trainer.train(model, split) +# le trainer va instancier un predictor/valdiator dans le train ou dans le init diff --git a/clinicadl/API_test.py b/clinicadl/API_test.py index 0581b879a..a6eb9fa72 100644 --- a/clinicadl/API_test.py +++ b/clinicadl/API_test.py @@ -1,83 +1,225 @@ from pathlib import Path -from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.caps_dataset.data import return_dataset -from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData -from clinicadl.trainer.config.classification import ClassificationConfig +import torchio + +from clinicadl.dataset.caps_dataset import ( + CapsDatasetPatch, + CapsDatasetRoi, + CapsDatasetSlice, +) +from clinicadl.dataset.caps_reader import CapsReader +from clinicadl.dataset.concat import ConcatDataset +from clinicadl.dataset.config.extraction import ExtractionConfig +from clinicadl.dataset.config.preprocessing import ( + PreprocessingConfig, + T1PreprocessingConfig, +) +from clinicadl.experiment_manager.experiment_manager import ExperimentManager +from clinicadl.losses.config import CrossEntropyLossConfig +from clinicadl.losses.factory import get_loss_function +from clinicadl.model.clinicadl_model import ClinicaDLModel +from clinicadl.networks.config import ImplementedNetworks +from clinicadl.networks.factory import ( + ConvEncoderOptions, + create_network_config, + get_network_from_config, +) +from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig +from clinicadl.optimization.optimizer.factory import get_optimizer +from clinicadl.predictor.predictor import Predictor +from clinicadl.splitter.kfold import KFolder +from clinicadl.splitter.split import get_single_split, split_tsv from clinicadl.trainer.trainer import Trainer -from clinicadl.utils.enum import 
ExtractionMethod, Preprocessing, Task -from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options +from clinicadl.transforms.transforms import Transforms + +# Create the Maps Manager / Read/write manager / +maps_path = Path("/") +manager = ExperimentManager(maps_path, overwrite=False) + +caps_directory = Path("caps_directory") # output of clinica pipelines +caps_reader = CapsReader(caps_directory, manager=manager) + +preprocessing_1 = caps_reader.get_preprocessing("t1-linear") +extraction_1 = caps_reader.extract_slice(preprocessing=preprocessing_1, arg_slice=2) +transforms_1 = Transforms( + data_augmentation=[torchio.t1, torchio.t2], + image_transforms=[torchio.t1, torchio.t2], + object_transforms=[torchio.t1, torchio.t2], +) # not mandatory + +preprocessing_2 = caps_reader.get_preprocessing("pet-linear") +extraction_2 = caps_reader.extract_patch(preprocessing=preprocessing_2, arg_patch=2) +transforms_2 = Transforms( + data_augmentation=[torchio.t2], + image_transforms=[torchio.t1], + object_transforms=[torchio.t1, torchio.t2], +) + +sub_ses_tsv = Path("") +split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv + +dataset_t1_roi = caps_reader.get_dataset( + extraction=extraction_1, + preprocessing=preprocessing_1, + sub_ses_tsv=split_dir / "train.tsv", + transforms=transforms_1, +) # do we give config or object for transforms ? 
+dataset_pet_patch = caps_reader.get_dataset( + extraction=extraction_2, + preprocessing=preprocessing_2, + sub_ses_tsv=split_dir / "train.tsv", + transforms=transforms_2, +) + +dataset_multi_modality_multi_extract = ConcatDataset( + [dataset_t1_roi, dataset_pet_patch] +) # 2 train.tsv en entrée qu'il faut concat et pareil pour les transforms à faire attention + +config_file = Path("config_file") +trainer = Trainer.from_json(config_file=config_file, manager=manager) + +# CAS CROSS-VALIDATION +splitter = KFolder( + n_splits=3, caps_dataset=dataset_multi_modality_multi_extract, manager=manager +) + +for split in splitter.split_iterator(split_list=[0, 1]): + # bien définir ce qu'il y a dans l'objet split + + loss, loss_config = get_loss_function(CrossEntropyLossConfig()) + network_config = create_network_config(ImplementedNetworks.CNN)( + in_shape=[2, 2, 2], + num_outputs=1, + conv_args=ConvEncoderOptions(channels=[3, 2, 2]), + ) + network, _ = get_network_from_config(network_config) + optimizer, _ = get_optimizer(network, AdamConfig()) + model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer) + + trainer.train(model, split) + # le trainer va instancier un predictor/valdiator dans le train ou dans le init + + +# CAS SINGLE SPLIT +split = get_single_split( + n_subject_validation=0, + caps_dataset=dataset_multi_modality_multi_extract, + manager=manager, +) + +loss, loss_config = get_loss_function(CrossEntropyLossConfig()) +network_config = create_network_config(ImplementedNetworks.CNN)( + in_shape=[2, 2, 2], num_outputs=1, conv_args=ConvEncoderOptions(channels=[3, 2, 2]) +) +network, _ = get_network_from_config(network_config) +optimizer, _ = get_optimizer(network, AdamConfig()) +model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer) + +trainer.train(model, split) +# le trainer va instancier un predictor/valdiator dans le train ou dans le init + -image_config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - 
extraction=ExtractionMethod.IMAGE, - preprocessing_type=Preprocessing.T1_LINEAR, +# TEST + +preprocessing_test: PreprocessingConfig = caps_reader.get_preprocessing("pet-linear") +extraction_test: ExtractionConfig = caps_reader.extract_patch( + preprocessing=preprocessing_2, arg_patch=2 +) +transforms_test = Transforms( + data_augmentation=[torchio.t2], + image_transforms=[torchio.t1], + object_transforms=[torchio.t1, torchio.t2], +) + +dataset_test = caps_reader.get_dataset( + extraction=extraction_test, + preprocessing=preprocessing_test, + sub_ses_tsv=split_dir / "test.tsv", + transforms=transforms_test, ) -DeepLearningPrepareData(image_config) - -dataset = return_dataset( - input_dir, - data_df, - preprocessing_dict, - transforms_config, - label, - label_code, - cnn_index, - label_presence, - multi_cohort, +predictor = Predictor(manager=manager) +predictor.predict(dataset_test=dataset_test, split=2) + + +# SIMPLE EXPERIMENT + + +maps_path = Path("/") +manager = ExperimentManager(maps_path, overwrite=False) + +caps_directory = Path("caps_directory") # output of clinica pipelines +caps_reader = CapsReader(caps_directory, manager=manager) + +extraction_1 = caps_reader.extract_image(preprocessing=T1PreprocessingConfig()) +transforms_1 = Transforms( + data_augmentation=[torchio.transforms.RandomMotion] +) # not mandatory + +sub_ses_tsv = Path("") +split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv + +dataset_t1_image = caps_reader.get_dataset( + extraction=extraction_1, + preprocessing=T1PreprocessingConfig(), + sub_ses_tsv=split_dir / "train.tsv", + transforms=transforms_1, +) # do we give config or ob + + +config_file = Path("config_file") +trainer = Trainer.from_json(config_file=config_file, manager=manager) + +# CAS CROSS-VALIDATION +splitter = KFolder(n_splits=3, caps_dataset=dataset_t1_image, manager=manager) + +for split in splitter.split_iterator(split_list=[0, 1]): + # bien définir ce qu'il y a dans l'objet split + + loss, loss_config 
= get_loss_function(CrossEntropyLossConfig()) + network_config = create_network_config(ImplementedNetworks.CNN)( + in_shape=[2, 2, 2], + num_outputs=1, + conv_args=ConvEncoderOptions(channels=[3, 2, 2]), + ) + network, _ = get_network_from_config(network_config) + optimizer, _ = get_optimizer(network, AdamConfig()) + model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer) + + trainer.train(model, split) + # le trainer va instancier un predictor/valdiator dans le train ou dans le init + + +# SIMPLE EXPERIMENT WITH A CAPS ALREADY EXISTING + +maps_path = Path("/") +manager = ExperimentManager(maps_path, overwrite=False) + +# sub_ses_tsv = Path("") +# split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv + +dataset_t1_image = CapsDatasetPatch.from_json( + extraction=extract_json, + sub_ses_tsv=split_dir / "train.tsv", ) +config_file = Path("config_file") +trainer = Trainer.from_json(config_file=config_file, manager=manager) + +# CAS CROSS-VALIDATION +splitter = KFolder(n_splits=3, caps_dataset=dataset_t1_image, manager=manager) + +for split in splitter.split_iterator(split_list=[0, 1]): + # bien définir ce qu'il y a dans l'objet split + + loss, loss_config = get_loss_function(CrossEntropyLossConfig()) + network_config = create_network_config(ImplementedNetworks.CNN)( + in_shape=[2, 2, 2], + num_outputs=1, + conv_args=ConvEncoderOptions(channels=[3, 2, 2]), + ) + network, _ = get_network_from_config(network_config) + optimizer, _ = get_optimizer(network, AdamConfig()) + model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer) -split_config = SplitConfig() -splitter = Splitter(split_config) - -validator_config = ValidatorConfig() -validator = Validator(validator_config) - -train_config = ClassificationConfig() -trainer = Trainer(train_config, validator) - -for split in splitter.split_iterator(): - for network in range( - first_network, self.maps_manager.num_networks - ): # for multi_network - ###### actual _train_single 
method of the Trainer ############ - train_loader = trainer.get_dataloader(dataset, split, network, "train", config) - valid_loader = validator.get_dataloader( - dataset, split, network, "valid", config - ) # ?? validatior, trainer ? - - trainer._train( - train_loader, - valid_loader, - split=split, - network=network, - resume=resume, # in a config class - callbacks=[CodeCarbonTracker], # in a config class ? - ) - - validator._ensemble_prediction( - self.maps_manager, - "train", - split, - self.config.validation.selection_metrics, - ) - validator._ensemble_prediction( - self.maps_manager, - "validation", - split, - self.config.validation.selection_metrics, - ) - ###### end ############ - - -for split in splitter.split_iterator(): - for network in range( - first_network, self.maps_manager.num_networks - ): # for multi_network - ###### actual _train_single method of the Trainer ############ - test_loader = trainer.get_dataloader(dataset, split, network, "test", config) - validator.predict(test_loader) - -interpret_config = InterpretConfig(**kwargs) -predict_manager = PredictManager(interpret_config) -predict_manager.interpret() + trainer.train(model, split) + # le trainer va instancier un predictor/valdiator dans le train ou dans le init diff --git a/clinicadl/API_test_v2.py b/clinicadl/API_test_v2.py new file mode 100644 index 000000000..0bf730d3c --- /dev/null +++ b/clinicadl/API_test_v2.py @@ -0,0 +1,112 @@ +from pathlib import Path +from clinicadl.caps_dataset2.config.preprocessing import PreprocessingConfig +from clinicadl.caps_dataset2.config.extraction import ExtractionConfig +from clinicadl.caps_dataset2.data import CapsDatasetRoi, CapsDatasetPatch, CapsDatasetSlice +from clinicadl.transforms.config import TransformsConfig +import torchio +from clinicadl import tsvtools +from clinicadl.trainer.trainer import Trainer + +class ExperimentManager: + pass + +class CapsReader: + pass + +class Transforms: + pass + +class Predictor: + pass + +class ClinicaDLModel: + 
pass + +class KFolder: + pass + +def get_loss_function(): + pass + +def get_network_from_config(): + pass + +def create_network_config(): + pass + +def get_single_split(): + pass + +# Create the Maps Manager / Read/write manager / +maps_path = Path("/") +manager = ExperimentManager(maps_path, overwrite = False) + +caps_directory = Path("caps_directory") # output of clinica pipelines +caps_reader = CapsReader(caps_directory, manager= manager) + +preprocessing_1: PreprocessingConfig = caps_reader.get_preprocessing("t1-linear") +extraction_1: ExtractionConfig = caps_reader.extract_slice(preprocessing=preprocessing_1, arg_slice = 2) +transforms_1 = Transforms(data_augmentation = [torchio.t1, torchio.t2], image_transforms= [torchio.t1, torchio.t2], object_transforms=[torchio.t1, torchio.t2]) # not mandatory + +preprocessing_2: PreprocessingConfig = caps_reader.get_preprocessing("pet-linear") +extraction_2: ExtractionConfig= caps_reader.extract_patch(preprocessing=preprocessing_2, arg_patch = 2) +transforms_2 = Transforms(data_augmentation = [torchio.t2], image_transforms= [torchio.t1], object_transforms=[torchio.t1, torchio.t2]) + +sub_ses_tsv = Path("") +split_dir = tsvtools.split_tsv(sub_ses_tsv) #-> creer un test.tsv et un train.tsv + +dataset_t1_roi: CapsDatasetRoi = caps_reader.get_dataset( extraction = extraction_1, preprocessing = preprocessing_1, sub_ses_tsv = split_dir / "train.tsv", transforms = transforms_1) +dataset_pet_patch: CapsDatasetPatch = caps_reader.get_dataset(extraction = extraction_2, preprocessing= preprocessing_2, sub_ses_tsv = split_dir / "train.tsv", transforms = transforms_2) + +dataset_multi_modality_multi_extract = concat_dataset(dataset_t1, dataset_pet) # 2 train.tsv en entrée qu'il faut concat et pareil pour les transforms à faire attention + +config_file = Path("config_file") +trainer = Trainer.from_json(config_file=config_file, manager=manager) + +# CAS CROSS-VALIDATION +splitter = KFolder(n_splits = 3, 
caps_dataset=dataset_multi_modality_multi_extract, manager = manager) + +for split in splitter.split_iterator(split_list = [0,1]): + # bien définir ce qu'il y a dans l'objet split + + loss, loss_config = get_loss_function(CrossEntropyLossConfig()) + network_config = create_network_config(ImplementedNetworks.CNN)(in_shape = [2,2,2], num_outputs = 1, conv_args = ConvEncoderOptions(channels = [3, 2, 2])) + network = get_network_from_config(network_config) + + model = ClinicaDLModel( + network= network, + loss=loss, + optimizer= AdamConfig() + ) + + trainer.train(model, split) + # le trainer va instancier un predictor/valdiator dans le train ou dans le init + + +# CAS SINGLE SPLIT +split = get_single_split(n_subject_validation = 0, caps_dataset=dataset_multi_modality_multi_extract, manager = manager) + +loss, loss_config = get_loss_function(CrossEntropyLossConfig()) +network_config = create_network_config(ImplementedNetworks.CNN)(in_shape = [2,2,2], num_outputs = 1, conv_args = ConvEncoderOptions(channels = [3, 2, 2])) +network = get_network_from_config(network_config) + +model = ClinicaDLModel( + network= network, + loss=loss, + optimizer= AdamConfig() +) + +trainer.train(model, split) +# le trainer va instancier un predictor/valdiator dans le train ou dans le init + + +# TEST + +preprocessing_test: PreprocessingConfig = caps_reader.get_preprocessing("pet-linear") +extraction_test: ExtractionConfig= caps_reader.extract_patch(preprocessing=preprocessing_2, arg_patch = 2) +transforms_test = Transforms(data_augmentation = [torchio.t2], image_transforms= [torchio.t1], object_transforms=[torchio.t1, torchio.t2]) + +dataset_test: CapsDatasetROI = caps_reader.get_dataset( extraction = extraction_test, preprocessing = preprocessing_test, sub_ses_tsv = split_dir / "test.tsv", transforms = transforms_test) + +predictor = Predictor(manager= manager) +predictor.predict(dataset_test= dataset_test, split = 2) diff --git a/clinicadl/commandline/modules_options/data.py 
b/clinicadl/commandline/modules_options/data.py index 569cbab6c..a881440c3 100644 --- a/clinicadl/commandline/modules_options/data.py +++ b/clinicadl/commandline/modules_options/data.py @@ -1,8 +1,8 @@ import click -from clinicadl.caps_dataset.data_config import DataConfig from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type +from clinicadl.dataset.data_config import DataConfig # Data baseline = click.option( diff --git a/clinicadl/commandline/modules_options/dataloader.py b/clinicadl/commandline/modules_options/dataloader.py index dcaa66aa9..bf4d4c781 100644 --- a/clinicadl/commandline/modules_options/dataloader.py +++ b/clinicadl/commandline/modules_options/dataloader.py @@ -1,8 +1,8 @@ import click -from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type +from clinicadl.dataset.dataloader_config import DataLoaderConfig # DataLoader batch_size = click.option( diff --git a/clinicadl/commandline/modules_options/extraction.py b/clinicadl/commandline/modules_options/extraction.py index fc0db1f98..e382eecc2 100644 --- a/clinicadl/commandline/modules_options/extraction.py +++ b/clinicadl/commandline/modules_options/extraction.py @@ -1,14 +1,14 @@ import click -from clinicadl.caps_dataset.extraction.config import ( +from clinicadl.config.config_utils import get_default_from_config_class as get_default +from clinicadl.config.config_utils import get_type_from_config_class as get_type +from clinicadl.dataset.config.extraction import ( ExtractionConfig, ExtractionImageConfig, ExtractionPatchConfig, ExtractionROIConfig, ExtractionSliceConfig, ) -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import 
get_type_from_config_class as get_type extract_json = click.option( "-ej", diff --git a/clinicadl/commandline/modules_options/maps_manager.py b/clinicadl/commandline/modules_options/maps_manager.py index f973f441a..69574d42c 100644 --- a/clinicadl/commandline/modules_options/maps_manager.py +++ b/clinicadl/commandline/modules_options/maps_manager.py @@ -1,7 +1,7 @@ import click from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.maps_manager.config import MapsManagerConfig +from clinicadl.experiment_manager.config import MapsManagerConfig maps_dir = click.argument("maps_dir", type=get_type("maps_dir", MapsManagerConfig)) data_group = click.option("data_group", type=get_type("data_group", MapsManagerConfig)) diff --git a/clinicadl/commandline/modules_options/network.py b/clinicadl/commandline/modules_options/network.py index 995ea6ccc..c0b8716e1 100644 --- a/clinicadl/commandline/modules_options/network.py +++ b/clinicadl/commandline/modules_options/network.py @@ -2,7 +2,7 @@ from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.network.config import NetworkConfig +from clinicadl.networks.old_network.config import NetworkConfig # Model multi_network = click.option( diff --git a/clinicadl/commandline/modules_options/optimization.py b/clinicadl/commandline/modules_options/optimization.py index fd88dc06e..66bedebd0 100644 --- a/clinicadl/commandline/modules_options/optimization.py +++ b/clinicadl/commandline/modules_options/optimization.py @@ -2,7 +2,7 @@ from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.optimizer.optimization import OptimizationConfig +from clinicadl.optimization.config import OptimizationConfig # Optimization accumulation_steps = click.option( 
diff --git a/clinicadl/commandline/modules_options/optimizer.py b/clinicadl/commandline/modules_options/optimizer.py index 57e3903e3..1012adfe0 100644 --- a/clinicadl/commandline/modules_options/optimizer.py +++ b/clinicadl/commandline/modules_options/optimizer.py @@ -2,7 +2,7 @@ from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.optimizer.optimizer import OptimizerConfig +from clinicadl.optimization.optimizer import OptimizerConfig # Optimizer learning_rate = click.option( diff --git a/clinicadl/commandline/modules_options/preprocessing.py b/clinicadl/commandline/modules_options/preprocessing.py index 641e91518..131ba5324 100644 --- a/clinicadl/commandline/modules_options/preprocessing.py +++ b/clinicadl/commandline/modules_options/preprocessing.py @@ -1,13 +1,13 @@ import click -from clinicadl.caps_dataset.preprocessing.config import ( +from clinicadl.config.config_utils import get_default_from_config_class as get_default +from clinicadl.config.config_utils import get_type_from_config_class as get_type +from clinicadl.dataset.config.preprocessing import ( CustomPreprocessingConfig, DTIPreprocessingConfig, PETPreprocessingConfig, PreprocessingConfig, ) -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type tracer = click.option( "--tracer", diff --git a/clinicadl/commandline/modules_options/ssda.py b/clinicadl/commandline/modules_options/ssda.py deleted file mode 100644 index 8119726ef..000000000 --- a/clinicadl/commandline/modules_options/ssda.py +++ /dev/null @@ -1,45 +0,0 @@ -import click - -from clinicadl.config.config.ssda import SSDAConfig -from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.config.config_utils import get_type_from_config_class as get_type - -# 
SSDA -caps_target = click.option( - "--caps_target", - "-d", - type=get_type("caps_target", SSDAConfig), - default=get_default("caps_target", SSDAConfig), - help="CAPS of target data.", - show_default=True, -) -preprocessing_json_target = click.option( - "--preprocessing_json_target", - "-d", - type=get_type("preprocessing_json_target", SSDAConfig), - default=get_default("preprocessing_json_target", SSDAConfig), - help="Path to json target.", - show_default=True, -) -ssda_network = click.option( - "--ssda_network/--single_network", - default=get_default("ssda_network", SSDAConfig), - help="If provided uses a ssda-network framework.", - show_default=True, -) -tsv_target_lab = click.option( - "--tsv_target_lab", - "-d", - type=get_type("tsv_target_lab", SSDAConfig), - default=get_default("tsv_target_lab", SSDAConfig), - help="TSV of labeled target data.", - show_default=True, -) -tsv_target_unlab = click.option( - "--tsv_target_unlab", - "-d", - type=get_type("tsv_target_unlab", SSDAConfig), - default=get_default("tsv_target_unlab", SSDAConfig), - help="TSV of unllabeled target data.", - show_default=True, -) diff --git a/clinicadl/commandline/modules_options/validation.py b/clinicadl/commandline/modules_options/validation.py index 858dd956e..089357866 100644 --- a/clinicadl/commandline/modules_options/validation.py +++ b/clinicadl/commandline/modules_options/validation.py @@ -2,7 +2,7 @@ from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.splitter.validation import ValidationConfig +from clinicadl.predictor.validation import ValidationConfig # Validation valid_longitudinal = click.option( diff --git a/clinicadl/commandline/pipelines/generate/artifacts/cli.py b/clinicadl/commandline/pipelines/generate/artifacts/cli.py index b4a98b40a..68d1ec869 100644 --- a/clinicadl/commandline/pipelines/generate/artifacts/cli.py +++ 
b/clinicadl/commandline/pipelines/generate/artifacts/cli.py @@ -6,8 +6,6 @@ import torchio as tio from joblib import Parallel, delayed -from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.caps_dataset.caps_dataset_utils import find_file_type from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import ( data, @@ -15,6 +13,8 @@ preprocessing, ) from clinicadl.commandline.pipelines.generate.artifacts import options as artifacts +from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.dataset.caps_dataset_utils import find_file_type from clinicadl.generate.generate_config import GenerateArtifactsConfig from clinicadl.generate.generate_utils import ( load_and_check_tsv, diff --git a/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py b/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py index cb68269ca..82c4e5cb3 100644 --- a/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py +++ b/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py @@ -7,13 +7,13 @@ from joblib import Parallel, delayed from nilearn.image import resample_to_img -from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.caps_dataset.caps_dataset_utils import find_file_type from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import data, dataloader, preprocessing from clinicadl.commandline.pipelines.generate.hypometabolic import ( options as hypometabolic, ) +from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.dataset.caps_dataset_utils import find_file_type from clinicadl.generate.generate_config import GenerateHypometabolicConfig from clinicadl.generate.generate_utils import ( load_and_check_tsv, diff --git a/clinicadl/commandline/pipelines/generate/random/cli.py b/clinicadl/commandline/pipelines/generate/random/cli.py index cf8e8d9e8..8ea26a5d0 100644 --- 
a/clinicadl/commandline/pipelines/generate/random/cli.py +++ b/clinicadl/commandline/pipelines/generate/random/cli.py @@ -7,8 +7,6 @@ import pandas as pd from joblib import Parallel, delayed -from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.caps_dataset.caps_dataset_utils import find_file_type from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import ( data, @@ -16,6 +14,8 @@ preprocessing, ) from clinicadl.commandline.pipelines.generate.random import options as random +from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.dataset.caps_dataset_utils import find_file_type from clinicadl.generate.generate_config import GenerateRandomConfig from clinicadl.generate.generate_utils import ( load_and_check_tsv, diff --git a/clinicadl/commandline/pipelines/generate/trivial/cli.py b/clinicadl/commandline/pipelines/generate/trivial/cli.py index b48651811..d683865f2 100644 --- a/clinicadl/commandline/pipelines/generate/trivial/cli.py +++ b/clinicadl/commandline/pipelines/generate/trivial/cli.py @@ -6,8 +6,6 @@ import pandas as pd from joblib import Parallel, delayed -from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.caps_dataset.caps_dataset_utils import find_file_type from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import ( data, @@ -15,6 +13,8 @@ preprocessing, ) from clinicadl.commandline.pipelines.generate.trivial import options as trivial +from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.dataset.caps_dataset_utils import find_file_type from clinicadl.generate.generate_config import GenerateTrivialConfig from clinicadl.generate.generate_utils import ( im_loss_roi_gaussian_distribution, @@ -118,7 +118,6 @@ def create_trivial_image(subject_id: int) -> pd.DataFrame: if caps_config.data.mask_path is None: caps_config.data.mask_path = get_mask_path() 
path_to_mask = caps_config.data.mask_path / f"mask-{label + 1}.nii" - print(path_to_mask) if path_to_mask.is_file(): atlas_to_mask = nib.loadsave.load(path_to_mask).get_fdata() else: diff --git a/clinicadl/commandline/pipelines/interpret/cli.py b/clinicadl/commandline/pipelines/interpret/cli.py index 3509eaf23..9f4fb8a87 100644 --- a/clinicadl/commandline/pipelines/interpret/cli.py +++ b/clinicadl/commandline/pipelines/interpret/cli.py @@ -1,3 +1,5 @@ +from pathlib import Path + import click from clinicadl.commandline import arguments @@ -10,7 +12,7 @@ ) from clinicadl.commandline.pipelines.interpret import options from clinicadl.interpret.config import InterpretConfig -from clinicadl.predict.predict_manager import PredictManager +from clinicadl.predictor.old_predictor import Predictor @click.command("interpret", no_args_is_help=True) @@ -40,9 +42,13 @@ def cli(**kwargs): NAME is the name of the saliency map task. METHOD is the method used to extract an attribution map. """ + from clinicadl.utils.iotools.train_utils import merge_cli_and_maps_json_options - interpret_config = InterpretConfig(**kwargs) - predict_manager = PredictManager(interpret_config) + dict_ = merge_cli_and_maps_json_options( + Path(kwargs["input_maps"]) / "maps.json", **kwargs + ) + interpret_config = InterpretConfig(**dict_) + predict_manager = Predictor(interpret_config) predict_manager.interpret() diff --git a/clinicadl/commandline/pipelines/interpret/options.py b/clinicadl/commandline/pipelines/interpret/options.py index 5313b4a90..43cada4c4 100644 --- a/clinicadl/commandline/pipelines/interpret/options.py +++ b/clinicadl/commandline/pipelines/interpret/options.py @@ -2,28 +2,28 @@ from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type -from clinicadl.interpret.config import InterpretConfig +from clinicadl.interpret.config import InterpretBaseConfig # interpret specific name = 
click.argument( "name", - type=get_type("name", InterpretConfig), + type=get_type("name", InterpretBaseConfig), ) method = click.argument( "method", - type=get_type("method", InterpretConfig), # ["gradients", "grad-cam"] + type=get_type("method", InterpretBaseConfig), # ["gradients", "grad-cam"] ) level = click.option( "--level_grad_cam", - type=get_type("level", InterpretConfig), - default=get_default("level", InterpretConfig), + type=get_type("level", InterpretBaseConfig), + default=get_default("level", InterpretBaseConfig), help="level of the feature map (after the layer corresponding to the number) chosen for grad-cam.", show_default=True, ) target_node = click.option( "--target_node", - type=get_type("target_node", InterpretConfig), - default=get_default("target_node", InterpretConfig), + type=get_type("target_node", InterpretBaseConfig), + default=get_default("target_node", InterpretBaseConfig), help="Which target node the gradients explain. Default takes the first output node.", show_default=True, ) diff --git a/clinicadl/commandline/pipelines/predict/cli.py b/clinicadl/commandline/pipelines/predict/cli.py index fa7303008..119c12678 100644 --- a/clinicadl/commandline/pipelines/predict/cli.py +++ b/clinicadl/commandline/pipelines/predict/cli.py @@ -10,8 +10,8 @@ validation, ) from clinicadl.commandline.pipelines.predict import options -from clinicadl.predict.config import PredictConfig -from clinicadl.predict.predict_manager import PredictManager +from clinicadl.predictor.config import PredictConfig +from clinicadl.predictor.old_predictor import Predictor @click.command(name="predict", no_args_is_help=True) @@ -61,7 +61,7 @@ def cli(input_maps_directory, data_group, **kwargs): """ predict_config = PredictConfig(**kwargs) - predict_manager = PredictManager(predict_config) + predict_manager = Predictor(predict_config) predict_manager.predict() diff --git a/clinicadl/commandline/pipelines/predict/options.py b/clinicadl/commandline/pipelines/predict/options.py 
index 003dfe275..cbb8980ca 100644 --- a/clinicadl/commandline/pipelines/predict/options.py +++ b/clinicadl/commandline/pipelines/predict/options.py @@ -1,13 +1,11 @@ import click from clinicadl.config.config_utils import get_default_from_config_class as get_default -from clinicadl.predict.config import PredictConfig +from clinicadl.predictor.config import PredictConfig # predict specific use_labels = click.option( "--use_labels/--no_labels", - show_default=True, - default=get_default("use_labels", PredictConfig), help="Set this option to --no_labels if your dataset does not contain ground truth labels.", ) save_tensor = click.option( diff --git a/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py b/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py index c9630c507..d162dcf97 100644 --- a/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py +++ b/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py @@ -1,6 +1,5 @@ import click -from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import ( data, @@ -8,7 +7,8 @@ extraction, preprocessing, ) -from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData +from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.dataset.prepare_data.prepare_data import DeepLearningPrepareData from clinicadl.utils.enum import ExtractionMethod diff --git a/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py b/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py index 4c43df851..f4f888a71 100644 --- a/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py +++ b/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py @@ -1,6 +1,5 @@ import click -from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig from clinicadl.commandline import arguments from 
clinicadl.commandline.modules_options import ( data, @@ -8,7 +7,8 @@ extraction, preprocessing, ) -from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData +from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.dataset.prepare_data.prepare_data import DeepLearningPrepareData from clinicadl.utils.enum import ExtractionMethod diff --git a/clinicadl/commandline/pipelines/quality_check/pet_linear/cli.py b/clinicadl/commandline/pipelines/quality_check/pet_linear/cli.py index 455bb5299..938895b82 100644 --- a/clinicadl/commandline/pipelines/quality_check/pet_linear/cli.py +++ b/clinicadl/commandline/pipelines/quality_check/pet_linear/cli.py @@ -44,7 +44,7 @@ def cli( SUVR_REFERENCE_REGION is the reference region used to perform intensity normalization {pons|cerebellumPons|pons2|cerebellumPons2}. """ - from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig + from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig from .....quality_check.pet_linear.quality_check import ( quality_check as pet_linear_qc, diff --git a/clinicadl/commandline/pipelines/quality_check/t1_linear/cli.py b/clinicadl/commandline/pipelines/quality_check/t1_linear/cli.py index f73971a63..6c55b3586 100755 --- a/clinicadl/commandline/pipelines/quality_check/t1_linear/cli.py +++ b/clinicadl/commandline/pipelines/quality_check/t1_linear/cli.py @@ -1,8 +1,8 @@ import click -from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import computational, data, dataloader +from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.enum import ExtractionMethod, Preprocessing diff --git a/clinicadl/commandline/pipelines/train/classification/cli.py b/clinicadl/commandline/pipelines/train/classification/cli.py index 539f6cd42..21d57f365 
100644 --- a/clinicadl/commandline/pipelines/train/classification/cli.py +++ b/clinicadl/commandline/pipelines/train/classification/cli.py @@ -13,7 +13,6 @@ optimizer, reproducibility, split, - ssda, transforms, validation, ) @@ -24,7 +23,7 @@ options as transfer_learning, ) from clinicadl.trainer.config.classification import ClassificationConfig -from clinicadl.trainer.trainer import Trainer +from clinicadl.trainer.old_trainer import Trainer from clinicadl.utils.enum import Task from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options @@ -63,12 +62,6 @@ @dataloader.batch_size @dataloader.sampler @dataloader.n_proc -# ssda option -@ssda.ssda_network -@ssda.caps_target -@ssda.tsv_target_lab -@ssda.tsv_target_unlab -@ssda.preprocessing_json_target # Cross validation @split.n_splits @split.split @@ -115,4 +108,5 @@ def cli(**kwargs): options = merge_cli_and_config_file_options(Task.CLASSIFICATION, **kwargs) config = ClassificationConfig(**options) trainer = Trainer(config) + trainer.train(split_list=config.split.split, overwrite=True) diff --git a/clinicadl/commandline/pipelines/train/from_json/cli.py b/clinicadl/commandline/pipelines/train/from_json/cli.py index c0130a9b9..5e6771258 100644 --- a/clinicadl/commandline/pipelines/train/from_json/cli.py +++ b/clinicadl/commandline/pipelines/train/from_json/cli.py @@ -6,7 +6,7 @@ from clinicadl.commandline.modules_options import ( split, ) -from clinicadl.trainer.trainer import Trainer +from clinicadl.trainer.old_trainer import Trainer @click.command(name="from_json", no_args_is_help=True) @@ -27,6 +27,8 @@ def cli(**kwargs): logger.info(f"Reading JSON file at path {kwargs['config_file']}...") trainer = Trainer.from_json( - config_file=kwargs["config_file"], maps_path=kwargs["output_maps_directory"] + config_file=kwargs["config_file"], + maps_path=kwargs["output_maps_directory"], + split=kwargs["split"], ) trainer.train(split_list=kwargs["split"], overwrite=True) diff --git 
a/clinicadl/commandline/pipelines/train/reconstruction/cli.py b/clinicadl/commandline/pipelines/train/reconstruction/cli.py index d63bf63f8..1bad88443 100644 --- a/clinicadl/commandline/pipelines/train/reconstruction/cli.py +++ b/clinicadl/commandline/pipelines/train/reconstruction/cli.py @@ -13,7 +13,6 @@ optimizer, reproducibility, split, - ssda, transforms, validation, ) @@ -24,7 +23,7 @@ options as transfer_learning, ) from clinicadl.trainer.config.reconstruction import ReconstructionConfig -from clinicadl.trainer.trainer import Trainer +from clinicadl.trainer.old_trainer import Trainer from clinicadl.utils.enum import Task from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options @@ -63,12 +62,6 @@ @dataloader.batch_size @dataloader.sampler @dataloader.n_proc -# ssda option -@ssda.ssda_network -@ssda.caps_target -@ssda.tsv_target_lab -@ssda.tsv_target_unlab -@ssda.preprocessing_json_target # Cross validation @split.n_splits @split.split diff --git a/clinicadl/commandline/pipelines/train/regression/cli.py b/clinicadl/commandline/pipelines/train/regression/cli.py index ff6dd68ca..95a623604 100644 --- a/clinicadl/commandline/pipelines/train/regression/cli.py +++ b/clinicadl/commandline/pipelines/train/regression/cli.py @@ -13,7 +13,6 @@ optimizer, reproducibility, split, - ssda, transforms, validation, ) @@ -22,7 +21,7 @@ options as transfer_learning, ) from clinicadl.trainer.config.regression import RegressionConfig -from clinicadl.trainer.trainer import Trainer +from clinicadl.trainer.old_trainer import Trainer from clinicadl.utils.enum import Task from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options @@ -61,12 +60,6 @@ @dataloader.batch_size @dataloader.sampler @dataloader.n_proc -# ssda o -@ssda.ssda_network -@ssda.caps_target -@ssda.tsv_target_lab -@ssda.tsv_target_unlab -@ssda.preprocessing_json_target # Cross validation @split.n_splits @split.split diff --git 
a/clinicadl/commandline/pipelines/train/resume/cli.py b/clinicadl/commandline/pipelines/train/resume/cli.py index 1fc34a0f4..90efa4244 100644 --- a/clinicadl/commandline/pipelines/train/resume/cli.py +++ b/clinicadl/commandline/pipelines/train/resume/cli.py @@ -4,7 +4,7 @@ from clinicadl.commandline.modules_options import ( split, ) -from clinicadl.trainer.trainer import Trainer +from clinicadl.trainer.old_trainer import Trainer @click.command(name="resume", no_args_is_help=True) @@ -16,4 +16,4 @@ def cli(input_maps_directory, split): INPUT_MAPS_DIRECTORY is the path to the MAPS folder where training job has started. """ trainer = Trainer.from_maps(input_maps_directory) - trainer.resume(split) + trainer.resume() diff --git a/clinicadl/config/config/ssda.py b/clinicadl/config/config/ssda.py deleted file mode 100644 index caf52634d..000000000 --- a/clinicadl/config/config/ssda.py +++ /dev/null @@ -1,41 +0,0 @@ -from logging import getLogger -from pathlib import Path -from typing import Any, Dict - -from pydantic import BaseModel, ConfigDict, computed_field - -from clinicadl.utils.iotools.utils import read_preprocessing - -logger = getLogger("clinicadl.ssda_config") - - -class SSDAConfig(BaseModel): - """Config class to perform SSDA.""" - - caps_target: Path = Path("") - preprocessing_json_target: Path = Path("") - ssda_network: bool = False - tsv_target_lab: Path = Path("") - tsv_target_unlab: Path = Path("") - # pydantic config - model_config = ConfigDict(validate_assignment=True) - - @computed_field - @property - def preprocessing_dict_target(self) -> Dict[str, Any]: - """ - Gets the preprocessing dictionary from a target preprocessing json file. - - Returns - ------- - Dict[str, Any] - The preprocessing dictionary. 
- """ - if not self.ssda_network: - return {} - - preprocessing_json_target = ( - self.caps_target / "tensor_extraction" / self.preprocessing_json_target - ) - - return read_preprocessing(preprocessing_json_target) diff --git a/clinicadl/caps_dataset/__init__.py b/clinicadl/dataset/__init__.py similarity index 100% rename from clinicadl/caps_dataset/__init__.py rename to clinicadl/dataset/__init__.py diff --git a/clinicadl/caps_dataset/data.py b/clinicadl/dataset/caps_dataset.py similarity index 99% rename from clinicadl/caps_dataset/data.py rename to clinicadl/dataset/caps_dataset.py index 638f49e9d..d45dc5aa6 100644 --- a/clinicadl/caps_dataset/data.py +++ b/clinicadl/dataset/caps_dataset.py @@ -10,14 +10,14 @@ import torch from torch.utils.data import Dataset -from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.caps_dataset.extraction.config import ( +from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.dataset.config.extraction import ( ExtractionImageConfig, ExtractionPatchConfig, ExtractionROIConfig, ExtractionSliceConfig, ) -from clinicadl.prepare_data.prepare_data_utils import ( +from clinicadl.dataset.prepare_data.prepare_data_utils import ( compute_discarded_slices, extract_patch_path, extract_patch_tensor, diff --git a/clinicadl/caps_dataset/caps_dataset_config.py b/clinicadl/dataset/caps_dataset_config.py similarity index 93% rename from clinicadl/caps_dataset/caps_dataset_config.py rename to clinicadl/dataset/caps_dataset_config.py index b7086944c..0eac3ffd3 100644 --- a/clinicadl/caps_dataset/caps_dataset_config.py +++ b/clinicadl/dataset/caps_dataset_config.py @@ -3,10 +3,8 @@ from pydantic import BaseModel, ConfigDict -from clinicadl.caps_dataset.data_config import DataConfig -from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig -from clinicadl.caps_dataset.extraction import config as extraction -from clinicadl.caps_dataset.preprocessing.config import ( +from 
clinicadl.dataset.config import extraction +from clinicadl.dataset.config.preprocessing import ( CustomPreprocessingConfig, DTIPreprocessingConfig, FlairPreprocessingConfig, @@ -14,7 +12,9 @@ PreprocessingConfig, T1PreprocessingConfig, ) -from clinicadl.caps_dataset.preprocessing.utils import ( +from clinicadl.dataset.data_config import DataConfig +from clinicadl.dataset.dataloader_config import DataLoaderConfig +from clinicadl.dataset.utils import ( bids_nii, dwi_dti, linear_nii, diff --git a/clinicadl/caps_dataset/caps_dataset_utils.py b/clinicadl/dataset/caps_dataset_utils.py similarity index 96% rename from clinicadl/caps_dataset/caps_dataset_utils.py rename to clinicadl/dataset/caps_dataset_utils.py index b87c6ed22..b54ba373d 100644 --- a/clinicadl/caps_dataset/caps_dataset_utils.py +++ b/clinicadl/dataset/caps_dataset_utils.py @@ -2,15 +2,15 @@ from pathlib import Path from typing import Any, Dict, Optional, Tuple -from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.caps_dataset.preprocessing.config import ( +from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.dataset.config.preprocessing import ( CustomPreprocessingConfig, DTIPreprocessingConfig, FlairPreprocessingConfig, PETPreprocessingConfig, T1PreprocessingConfig, ) -from clinicadl.caps_dataset.preprocessing.utils import ( +from clinicadl.dataset.utils import ( bids_nii, dwi_dti, linear_nii, @@ -179,7 +179,7 @@ def read_json(json_path: Path) -> Dict[str, Any]: if "preprocessing" not in parameters: parameters["preprocessing"] = parameters["preprocessing_dict"]["preprocessing"] - from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig + from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig config = CapsDatasetConfig.from_preprocessing_and_extraction_method( extraction=parameters["mode"], diff --git a/clinicadl/dataset/caps_reader.py b/clinicadl/dataset/caps_reader.py new file mode 100644 index 
000000000..14199616e --- /dev/null +++ b/clinicadl/dataset/caps_reader.py @@ -0,0 +1,62 @@ +from pathlib import Path +from typing import Optional + +from clinicadl.dataset.caps_dataset import CapsDataset +from clinicadl.dataset.config.extraction import ( + ExtractionConfig, + ExtractionImageConfig, + ExtractionPatchConfig, + ExtractionROIConfig, + ExtractionSliceConfig, +) +from clinicadl.dataset.config.preprocessing import PreprocessingConfig +from clinicadl.experiment_manager.experiment_manager import ExperimentManager +from clinicadl.transforms.config import TransformsConfig + + +class CapsReader: + def __init__(self, caps_directory: Path, manager: ExperimentManager): + """TO COMPLETE""" + pass + + def get_dataset( + self, + extraction: ExtractionConfig, + preprocessing: PreprocessingConfig, + sub_ses_tsv: Path, + transforms: TransformsConfig, + ) -> CapsDataset: + return CapsDataset(extraction, preprocessing, sub_ses_tsv, transforms) + + def get_preprocessing(self, preprocessing: str) -> PreprocessingConfig: + """TO COMPLETE""" + + return PreprocessingConfig() + + def extract_slice( + self, preprocessing: PreprocessingConfig, arg_slice: Optional[int] = None + ) -> ExtractionSliceConfig: + """TO COMPLETE""" + + return ExtractionSliceConfig() + + def extract_patch( + self, preprocessing: PreprocessingConfig, arg_patch: Optional[int] = None + ) -> ExtractionPatchConfig: + """TO COMPLETE""" + + return ExtractionPatchConfig() + + def extract_roi( + self, preprocessing: PreprocessingConfig, arg_roi: Optional[int] = None + ) -> ExtractionROIConfig: + """TO COMPLETE""" + + return ExtractionROIConfig() + + def extract_image( + self, preprocessing: PreprocessingConfig, arg_image: Optional[int] = None + ) -> ExtractionImageConfig: + """TO COMPLETE""" + + return ExtractionImageConfig() diff --git a/clinicadl/dataset/concat.py b/clinicadl/dataset/concat.py new file mode 100644 index 000000000..f0b420dfe --- /dev/null +++ b/clinicadl/dataset/concat.py @@ -0,0 +1,6 @@ +from 
clinicadl.dataset.caps_dataset import CapsDataset + + +class ConcatDataset(CapsDataset): + def __init__(self, list_: list[CapsDataset]): + """TO COMPLETE""" diff --git a/clinicadl/maps_manager/__init__.py b/clinicadl/dataset/config/__init__.py similarity index 100% rename from clinicadl/maps_manager/__init__.py rename to clinicadl/dataset/config/__init__.py diff --git a/clinicadl/caps_dataset/extraction/config.py b/clinicadl/dataset/config/extraction.py similarity index 100% rename from clinicadl/caps_dataset/extraction/config.py rename to clinicadl/dataset/config/extraction.py diff --git a/clinicadl/caps_dataset/preprocessing/config.py b/clinicadl/dataset/config/preprocessing.py similarity index 100% rename from clinicadl/caps_dataset/preprocessing/config.py rename to clinicadl/dataset/config/preprocessing.py diff --git a/clinicadl/caps_dataset/data_config.py b/clinicadl/dataset/data_config.py similarity index 92% rename from clinicadl/caps_dataset/data_config.py rename to clinicadl/dataset/data_config.py index 35aed91b5..39e6a6254 100644 --- a/clinicadl/caps_dataset/data_config.py +++ b/clinicadl/dataset/data_config.py @@ -24,7 +24,7 @@ class DataConfig(BaseModel): # TODO : put in data module that must be passed by the user. """ - caps_directory: Path + caps_directory: Optional[Path] = None baseline: bool = False diagnoses: Tuple[str, ...] = ("AD", "CN") data_df: Optional[pd.DataFrame] = None @@ -122,7 +122,6 @@ def preprocessing_dict(self) -> Dict[str, Any]: ValueError In case of multi-cohort dataset, if no preprocessing file is found in any CAPS. """ - from clinicadl.caps_dataset.data import CapsDataset if self.preprocessing_json is not None: if not self.multi_cohort: @@ -147,15 +146,17 @@ def preprocessing_dict(self) -> Dict[str, Any]: f"in {caps_dict}." 
) - preprocessing_dict = read_preprocessing(preprocessing_json) + preprocessing_dict = read_preprocessing(preprocessing_json) - if ( - preprocessing_dict["mode"] == "roi" - and "roi_background_value" not in preprocessing_dict - ): - preprocessing_dict["roi_background_value"] = 0 + if ( + preprocessing_dict["mode"] == "roi" + and "roi_background_value" not in preprocessing_dict + ): + preprocessing_dict["roi_background_value"] = 0 - return preprocessing_dict + return preprocessing_dict + else: + return None @computed_field @property diff --git a/clinicadl/caps_dataset/dataloader_config.py b/clinicadl/dataset/dataloader_config.py similarity index 100% rename from clinicadl/caps_dataset/dataloader_config.py rename to clinicadl/dataset/dataloader_config.py diff --git a/clinicadl/network/autoencoder/__init__.py b/clinicadl/dataset/prepare_data/__init__.py similarity index 100% rename from clinicadl/network/autoencoder/__init__.py rename to clinicadl/dataset/prepare_data/__init__.py diff --git a/clinicadl/prepare_data/prepare_data.py b/clinicadl/dataset/prepare_data/prepare_data.py similarity index 97% rename from clinicadl/prepare_data/prepare_data.py rename to clinicadl/dataset/prepare_data/prepare_data.py index e9b7fc073..d9ef1c412 100644 --- a/clinicadl/prepare_data/prepare_data.py +++ b/clinicadl/dataset/prepare_data/prepare_data.py @@ -5,9 +5,9 @@ from joblib import Parallel, delayed from torch import save as save_tensor -from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.caps_dataset.caps_dataset_utils import compute_folder_and_file_type -from clinicadl.caps_dataset.extraction.config import ( +from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.dataset.caps_dataset_utils import compute_folder_and_file_type +from clinicadl.dataset.config.extraction import ( ExtractionConfig, ExtractionImageConfig, ExtractionPatchConfig, diff --git a/clinicadl/prepare_data/prepare_data_utils.py 
b/clinicadl/dataset/prepare_data/prepare_data_utils.py similarity index 100% rename from clinicadl/prepare_data/prepare_data_utils.py rename to clinicadl/dataset/prepare_data/prepare_data_utils.py diff --git a/clinicadl/caps_dataset/preprocessing/utils.py b/clinicadl/dataset/utils.py similarity index 76% rename from clinicadl/caps_dataset/preprocessing/utils.py rename to clinicadl/dataset/utils.py index 0aa93004d..7af1da539 100644 --- a/clinicadl/caps_dataset/preprocessing/utils.py +++ b/clinicadl/dataset/utils.py @@ -1,6 +1,6 @@ from typing import Optional -from clinicadl.caps_dataset.preprocessing import config as preprocessing_config +from clinicadl.dataset.config import preprocessing from clinicadl.utils.enum import ( LinearModality, Preprocessing, @@ -11,7 +11,7 @@ def bids_nii( - config: preprocessing_config.PreprocessingConfig, + config: preprocessing.PreprocessingConfig, reconstruction: Optional[str] = None, ) -> FileType: """Return the query dict required to capture PET scans. @@ -41,7 +41,7 @@ def bids_nii( f"ClinicaDL is Unable to read this modality ({config.preprocessing}) of images, please chose one from this list: {list[Preprocessing]}" ) - if isinstance(config, preprocessing_config.PETPreprocessingConfig): + if isinstance(config, preprocessing.PETPreprocessingConfig): trc = "" if config.tracer is None else f"_trc-{Tracer(config.tracer).value}" rec = "" if reconstruction is None else f"_rec-{reconstruction}" description = "PET data" @@ -56,13 +56,13 @@ def bids_nii( ) return file_type - elif isinstance(config, preprocessing_config.T1PreprocessingConfig): + elif isinstance(config, preprocessing.T1PreprocessingConfig): return FileType(pattern="anat/sub-*_ses-*_T1w.nii*", description="T1w MRI") - elif isinstance(config, preprocessing_config.FlairPreprocessingConfig): + elif isinstance(config, preprocessing.FlairPreprocessingConfig): return FileType(pattern="sub-*_ses-*_flair.nii*", description="FLAIR T2w MRI") - elif isinstance(config, 
preprocessing_config.DTIPreprocessingConfig): + elif isinstance(config, preprocessing.DTIPreprocessingConfig): return FileType(pattern="dwi/sub-*_ses-*_dwi.nii*", description="DWI NIfTI") else: @@ -70,15 +70,15 @@ def bids_nii( def linear_nii( - config: preprocessing_config.PreprocessingConfig, + config: preprocessing, ) -> FileType: - if isinstance(config, preprocessing_config.T1PreprocessingConfig): + if isinstance(config, preprocessing.T1PreprocessingConfig): needed_pipeline = Preprocessing.T1_LINEAR modality = LinearModality.T1W - elif isinstance(config, preprocessing_config.T2PreprocessingConfig): + elif isinstance(config, preprocessing.T2PreprocessingConfig): needed_pipeline = Preprocessing.T2_LINEAR modality = LinearModality.T2W - elif isinstance(config, preprocessing_config.FlairPreprocessingConfig): + elif isinstance(config, preprocessing.FlairPreprocessingConfig): needed_pipeline = Preprocessing.FLAIR_LINEAR modality = LinearModality.FLAIR else: @@ -102,7 +102,7 @@ def linear_nii( return file_type -def dwi_dti(config: preprocessing_config.DTIPreprocessingConfig) -> FileType: +def dwi_dti(config: preprocessing.DTIPreprocessingConfig) -> FileType: """Return the query dict required to capture DWI DTI images. 
Parameters @@ -113,12 +113,12 @@ def dwi_dti(config: preprocessing_config.DTIPreprocessingConfig) -> FileType: ------- FileType : """ - if isinstance(config, preprocessing_config.DTIPreprocessingConfig): + if isinstance(config, preprocessing.DTIPreprocessingConfig): measure = config.dti_measure space = config.dti_space else: raise ClinicaDLArgumentError( - f"PreprocessingConfig is of type {config} but should be of type{preprocessing_config.DTIPreprocessingConfig}" + f"preprocessing is of type {config} but should be of type{preprocessing.DTIPreprocessingConfig}" ) return FileType( @@ -128,10 +128,10 @@ def dwi_dti(config: preprocessing_config.DTIPreprocessingConfig) -> FileType: ) -def pet_linear_nii(config: preprocessing_config.PETPreprocessingConfig) -> FileType: - if not isinstance(config, preprocessing_config.PETPreprocessingConfig): +def pet_linear_nii(config: preprocessing.PETPreprocessingConfig) -> FileType: + if not isinstance(config, preprocessing.PETPreprocessingConfig): raise ClinicaDLArgumentError( - f"PreprocessingConfig is of type {config} but should be of type{preprocessing_config.PETPreprocessingConfig}" + f"preprocessing is of type {config} but should be of type{preprocessing.PETPreprocessingConfig}" ) if config.use_uncropped_image: diff --git a/clinicadl/network/cnn/__init__.py b/clinicadl/experiment_manager/__init__.py similarity index 100% rename from clinicadl/network/cnn/__init__.py rename to clinicadl/experiment_manager/__init__.py diff --git a/clinicadl/maps_manager/config.py b/clinicadl/experiment_manager/config.py similarity index 100% rename from clinicadl/maps_manager/config.py rename to clinicadl/experiment_manager/config.py diff --git a/clinicadl/experiment_manager/experiment_manager.py b/clinicadl/experiment_manager/experiment_manager.py new file mode 100644 index 000000000..f3e4aaac8 --- /dev/null +++ b/clinicadl/experiment_manager/experiment_manager.py @@ -0,0 +1,7 @@ +from pathlib import Path + + +class ExperimentManager: + def 
__init__(self, maps_path: Path, overwrite: bool) -> None: + """TO COMPLETE""" + pass diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/experiment_manager/maps_manager.py similarity index 98% rename from clinicadl/maps_manager/maps_manager.py rename to clinicadl/experiment_manager/maps_manager.py index 76cb544fe..84c7806ca 100644 --- a/clinicadl/maps_manager/maps_manager.py +++ b/clinicadl/experiment_manager/maps_manager.py @@ -9,17 +9,17 @@ import pandas as pd import torch -from clinicadl.caps_dataset.caps_dataset_utils import read_json -from clinicadl.caps_dataset.data import ( +from clinicadl.dataset.caps_dataset import ( return_dataset, ) -from clinicadl.metrics.metric_module import MetricModule -from clinicadl.metrics.utils import ( +from clinicadl.dataset.caps_dataset_utils import read_json +from clinicadl.metrics.old_metrics.metric_module import MetricModule +from clinicadl.metrics.old_metrics.utils import ( check_selection_metric, ) -from clinicadl.predict.utils import get_prediction +from clinicadl.predictor.utils import get_prediction from clinicadl.splitter.config import SplitterConfig -from clinicadl.splitter.splitter import Splitter +from clinicadl.splitter.old_splitter import Splitter from clinicadl.trainer.tasks_utils import ( ensemble_prediction, evaluation_metrics, @@ -346,7 +346,7 @@ def _write_information(self): """ from datetime import datetime - import clinicadl.network as network_package + import clinicadl.networks.old_network as network_package model_class = getattr(network_package, self.architecture) args = list( @@ -589,7 +589,7 @@ def _init_model( gpu (bool): If given, a new value for the device of the model will be computed. network (int): Index of the network trained (used in multi-network setting only). 
""" - import clinicadl.network as network_package + import clinicadl.networks.old_network as network_package logger.debug(f"Initialization of model {self.architecture}") # or choose to implement a dictionary diff --git a/clinicadl/hugging_face/hugging_face.py b/clinicadl/hugging_face/hugging_face.py index 00f729e35..22b6bbb02 100644 --- a/clinicadl/hugging_face/hugging_face.py +++ b/clinicadl/hugging_face/hugging_face.py @@ -5,7 +5,7 @@ import toml -from clinicadl.caps_dataset.caps_dataset_utils import read_json +from clinicadl.dataset.caps_dataset_utils import read_json from clinicadl.utils.exceptions import ClinicaDLArgumentError from clinicadl.utils.iotools.maps_manager_utils import ( remove_unused_tasks, diff --git a/clinicadl/interpret/config.py b/clinicadl/interpret/config.py index 41c8dcea9..d679a82e4 100644 --- a/clinicadl/interpret/config.py +++ b/clinicadl/interpret/config.py @@ -1,23 +1,35 @@ from logging import getLogger from pathlib import Path -from typing import Optional +from typing import Any, Dict, Optional from pydantic import BaseModel, field_validator -from clinicadl.caps_dataset.data_config import DataConfig as DataBaseConfig -from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig +from clinicadl.dataset.data_config import DataConfig +from clinicadl.dataset.dataloader_config import DataLoaderConfig +from clinicadl.experiment_manager.config import ( + MapsManagerConfig as MapsManagerConfigBase, +) +from clinicadl.experiment_manager.maps_manager import MapsManager from clinicadl.interpret.gradients import GradCam, Gradients, VanillaBackProp -from clinicadl.maps_manager.config import MapsManagerConfig +from clinicadl.predictor.validation import ValidationConfig from clinicadl.splitter.config import SplitConfig -from clinicadl.splitter.validation import ValidationConfig +from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.enum 
import InterpretationMethod +from clinicadl.utils.exceptions import ClinicaDLArgumentError logger = getLogger("clinicadl.interpret_config") -class DataConfig(DataBaseConfig): - caps_directory: Optional[Path] = None +class MapsManagerConfig(MapsManagerConfigBase): + save_tensor: bool = False + + def check_output_saving_tensor(self, network_task: str) -> None: + # Check if task is reconstruction for "save_tensor" and "save_nifti" + if self.save_tensor and network_task != "reconstruction": + raise ClinicaDLArgumentError( + "Cannot save tensors if the network task is not reconstruction. Please remove --save_tensor option." + ) class InterpretBaseConfig(BaseModel): @@ -44,13 +56,57 @@ def get_method(self) -> Gradients: raise ValueError(f"The method {self.method.value} is not implemented") -class InterpretConfig( - MapsManagerConfig, - InterpretBaseConfig, - DataConfig, - ValidationConfig, - ComputationalConfig, - DataLoaderConfig, - SplitConfig, -): +class InterpretConfig(BaseModel): """Config class to perform Transfer Learning.""" + + maps_manager: MapsManagerConfig + data: DataConfig + validation: ValidationConfig + computational: ComputationalConfig + dataloader: DataLoaderConfig + split: SplitConfig + interpret: InterpretBaseConfig + + def __init__(self, **kwargs): + super().__init__( + maps_manager=kwargs, + computational=kwargs, + dataloader=kwargs, + data=kwargs, + split=kwargs, + validation=kwargs, + interpret=kwargs, + ) + + def _update(self, config_dict: Dict[str, Any]) -> None: + """Updates the configs with a dict given by the user.""" + self.data.__dict__.update(config_dict) + self.split.__dict__.update(config_dict) + self.validation.__dict__.update(config_dict) + self.maps_manager.__dict__.update(config_dict) + self.split.__dict__.update(config_dict) + self.computational.__dict__.update(config_dict) + self.dataloader.__dict__.update(config_dict) + self.interpret.__dict__.update(config_dict) + + def adapt_with_maps_manager_info(self, maps_manager: 
MapsManager): + self.maps_manager.check_output_saving_nifti(maps_manager.network_task) + self.data.diagnoses = ( + maps_manager.diagnoses + if self.data.diagnoses is None or len(self.data.diagnoses) == 0 + else self.data.diagnoses + ) + + self.dataloader.batch_size = ( + maps_manager.batch_size + if not self.dataloader.batch_size + else self.dataloader.batch_size + ) + self.dataloader.n_proc = ( + maps_manager.n_proc + if not self.dataloader.n_proc + else self.dataloader.n_proc + ) + + self.split.adapt_cross_val_with_maps_manager_info(maps_manager) + self.maps_manager.check_output_saving_tensor(maps_manager.network_task) diff --git a/clinicadl/interpret/gradients.py b/clinicadl/interpret/gradients.py index b62308f38..d6e11815f 100644 --- a/clinicadl/interpret/gradients.py +++ b/clinicadl/interpret/gradients.py @@ -50,7 +50,7 @@ class GradCam(Gradients): """ def __init__(self, model): - from clinicadl.network.sub_network import CNN + from clinicadl.networks.old_network.sub_network import CNN super().__init__(model=model) if not isinstance(model, CNN): diff --git a/clinicadl/monai_metrics/__init__.py b/clinicadl/metrics/__init__.py similarity index 100% rename from clinicadl/monai_metrics/__init__.py rename to clinicadl/metrics/__init__.py diff --git a/clinicadl/monai_metrics/config/__init__.py b/clinicadl/metrics/config/__init__.py similarity index 100% rename from clinicadl/monai_metrics/config/__init__.py rename to clinicadl/metrics/config/__init__.py diff --git a/clinicadl/monai_metrics/config/base.py b/clinicadl/metrics/config/base.py similarity index 100% rename from clinicadl/monai_metrics/config/base.py rename to clinicadl/metrics/config/base.py diff --git a/clinicadl/monai_metrics/config/classification.py b/clinicadl/metrics/config/classification.py similarity index 100% rename from clinicadl/monai_metrics/config/classification.py rename to clinicadl/metrics/config/classification.py diff --git a/clinicadl/monai_metrics/config/enum.py 
b/clinicadl/metrics/config/enum.py similarity index 100% rename from clinicadl/monai_metrics/config/enum.py rename to clinicadl/metrics/config/enum.py diff --git a/clinicadl/monai_metrics/config/factory.py b/clinicadl/metrics/config/factory.py similarity index 100% rename from clinicadl/monai_metrics/config/factory.py rename to clinicadl/metrics/config/factory.py diff --git a/clinicadl/monai_metrics/config/generation.py b/clinicadl/metrics/config/generation.py similarity index 100% rename from clinicadl/monai_metrics/config/generation.py rename to clinicadl/metrics/config/generation.py diff --git a/clinicadl/monai_metrics/config/reconstruction.py b/clinicadl/metrics/config/reconstruction.py similarity index 100% rename from clinicadl/monai_metrics/config/reconstruction.py rename to clinicadl/metrics/config/reconstruction.py diff --git a/clinicadl/monai_metrics/config/regression.py b/clinicadl/metrics/config/regression.py similarity index 100% rename from clinicadl/monai_metrics/config/regression.py rename to clinicadl/metrics/config/regression.py diff --git a/clinicadl/monai_metrics/config/segmentation.py b/clinicadl/metrics/config/segmentation.py similarity index 100% rename from clinicadl/monai_metrics/config/segmentation.py rename to clinicadl/metrics/config/segmentation.py diff --git a/clinicadl/monai_metrics/factory.py b/clinicadl/metrics/factory.py similarity index 100% rename from clinicadl/monai_metrics/factory.py rename to clinicadl/metrics/factory.py diff --git a/clinicadl/metrics/metric_module.py b/clinicadl/metrics/old_metrics/metric_module.py similarity index 100% rename from clinicadl/metrics/metric_module.py rename to clinicadl/metrics/old_metrics/metric_module.py diff --git a/clinicadl/metrics/utils.py b/clinicadl/metrics/old_metrics/utils.py similarity index 100% rename from clinicadl/metrics/utils.py rename to clinicadl/metrics/old_metrics/utils.py diff --git a/clinicadl/network/unet/__init__.py b/clinicadl/model/__init__.py similarity index 100% 
rename from clinicadl/network/unet/__init__.py rename to clinicadl/model/__init__.py diff --git a/clinicadl/model/clinicadl_model.py b/clinicadl/model/clinicadl_model.py new file mode 100644 index 000000000..8785ed97b --- /dev/null +++ b/clinicadl/model/clinicadl_model.py @@ -0,0 +1,8 @@ +import torch.nn as nn +import torch.optim as optim + + +class ClinicaDLModel: + def __init__(self, network: nn.Module, loss: nn.Module, optimizer=optim.optimizer): + """TO COMPLETE""" + pass diff --git a/clinicadl/monai_networks/__init__.py b/clinicadl/monai_networks/__init__.py deleted file mode 100644 index 1d74473d4..000000000 --- a/clinicadl/monai_networks/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .config import ImplementedNetworks, NetworkConfig, create_network_config -from .factory import get_network diff --git a/clinicadl/monai_networks/config/__init__.py b/clinicadl/monai_networks/config/__init__.py deleted file mode 100644 index 10b8795dc..000000000 --- a/clinicadl/monai_networks/config/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .base import NetworkConfig -from .factory import create_network_config -from .utils.enum import ImplementedNetworks diff --git a/clinicadl/monai_networks/config/autoencoder.py b/clinicadl/monai_networks/config/autoencoder.py deleted file mode 100644 index a6df1a20c..000000000 --- a/clinicadl/monai_networks/config/autoencoder.py +++ /dev/null @@ -1,89 +0,0 @@ -from typing import Optional, Tuple, Union - -from pydantic import ( - NonNegativeInt, - PositiveInt, - computed_field, - model_validator, -) - -from clinicadl.utils.factories import DefaultFromLibrary - -from .base import VaryingDepthNetworkConfig -from .utils.enum import ImplementedNetworks - -__all__ = ["AutoEncoderConfig", "VarAutoEncoderConfig"] - - -class AutoEncoderConfig(VaryingDepthNetworkConfig): - """Config class for autoencoders.""" - - spatial_dims: PositiveInt - in_channels: PositiveInt - out_channels: PositiveInt - - inter_channels: Union[ - 
Optional[Tuple[PositiveInt, ...]], DefaultFromLibrary - ] = DefaultFromLibrary.YES - inter_dilations: Union[ - Optional[Tuple[PositiveInt, ...]], DefaultFromLibrary - ] = DefaultFromLibrary.YES - num_inter_units: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES - padding: Union[ - Optional[Union[PositiveInt, Tuple[PositiveInt, ...]]], DefaultFromLibrary - ] = DefaultFromLibrary.YES - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.AE - - @computed_field - @property - def dim(self) -> int: - """Dimension of the images.""" - return self.spatial_dims - - @model_validator(mode="after") - def model_validator(self): - """Checks coherence between parameters.""" - if self.padding != DefaultFromLibrary.YES: - assert self._check_dimensions( - self.padding - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for padding. You passed {self.padding}." - if isinstance(self.inter_channels, tuple) and isinstance( - self.inter_dilations, tuple - ): - assert len(self.inter_channels) == len( - self.inter_dilations - ), "inter_channels and inter_dilations muust have the same size." - elif isinstance(self.inter_dilations, tuple) and not isinstance( - self.inter_channels, tuple - ): - raise ValueError( - "You passed inter_dilations but didn't pass inter_channels." - ) - return self - - -class VarAutoEncoderConfig(AutoEncoderConfig): - """Config class for variational autoencoders.""" - - in_shape: Tuple[PositiveInt, ...] 
- in_channels: Optional[int] = None - latent_size: PositiveInt - use_sigmoid: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.VAE - - @model_validator(mode="after") - def model_validator_bis(self): - """Checks coherence between parameters.""" - assert ( - len(self.in_shape[1:]) == self.spatial_dims - ), f"You passed {self.spatial_dims} for spatial_dims, but in_shape suggests {len(self.in_shape[1:])} spatial dimensions." diff --git a/clinicadl/monai_networks/config/base.py b/clinicadl/monai_networks/config/base.py deleted file mode 100644 index 6e0ff1b6b..000000000 --- a/clinicadl/monai_networks/config/base.py +++ /dev/null @@ -1,168 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from enum import Enum -from typing import Any, Dict, Optional, Tuple, Union - -from pydantic import ( - BaseModel, - ConfigDict, - NonNegativeFloat, - NonNegativeInt, - PositiveInt, - computed_field, - field_validator, - model_validator, -) - -from clinicadl.utils.factories import DefaultFromLibrary - -from .utils.enum import ( - ImplementedActFunctions, - ImplementedNetworks, - ImplementedNormLayers, -) - - -class NetworkConfig(BaseModel, ABC): - """Base config class to configure neural networks.""" - - kernel_size: Union[ - PositiveInt, Tuple[PositiveInt, ...], DefaultFromLibrary - ] = DefaultFromLibrary.YES - up_kernel_size: Union[ - PositiveInt, Tuple[PositiveInt, ...], DefaultFromLibrary - ] = DefaultFromLibrary.YES - num_res_units: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES - act: Union[ - ImplementedActFunctions, - Tuple[ImplementedActFunctions, Dict[str, Any]], - DefaultFromLibrary, - ] = DefaultFromLibrary.YES - norm: Union[ - ImplementedNormLayers, - Tuple[ImplementedNormLayers, Dict[str, Any]], - DefaultFromLibrary, - ] = DefaultFromLibrary.YES - bias: Union[bool, 
DefaultFromLibrary] = DefaultFromLibrary.YES - adn_ordering: Union[Optional[str], DefaultFromLibrary] = DefaultFromLibrary.YES - # pydantic config - model_config = ConfigDict( - validate_assignment=True, - use_enum_values=True, - validate_default=True, - protected_namespaces=(), - ) - - @computed_field - @property - @abstractmethod - def network(self) -> ImplementedNetworks: - """The name of the network.""" - - @computed_field - @property - @abstractmethod - def dim(self) -> int: - """Dimension of the images.""" - - @classmethod - def base_validator_dropout(cls, v): - """Checks that dropout is between 0 and 1.""" - if isinstance(v, float): - assert ( - 0 <= v <= 1 - ), f"dropout must be between 0 and 1 but it has been set to {v}." - return v - - @field_validator("kernel_size", "up_kernel_size") - @classmethod - def base_is_odd(cls, value, field): - """Checks if a field is odd.""" - if value != DefaultFromLibrary.YES: - if isinstance(value, int): - value_ = (value,) - else: - value_ = value - for v in value_: - assert v % 2 == 1, f"{field.field_name} must be odd." - return value - - @field_validator("adn_ordering", mode="after") - @classmethod - def base_adn_validator(cls, v): - """Checks ADN sequence.""" - if v != DefaultFromLibrary.YES: - for letter in v: - assert ( - letter in {"A", "D", "N"} - ), f"adn_ordering must be composed by 'A', 'D' or/and 'N'. You passed {letter}." - assert len(v) == len( - set(v) - ), "adn_ordering cannot contain duplicated letter." - - return v - - @classmethod - def base_at_least_2d(cls, v, ctx): - """Checks that a tuple has at least a length of two.""" - if isinstance(v, tuple): - assert ( - len(v) >= 2 - ), f"{ctx.field_name} should have at least two dimensions (with the first one for the channel)." 
- return v - - @model_validator(mode="after") - def base_model_validator(self): - """Checks coherence between parameters.""" - if self.kernel_size != DefaultFromLibrary.YES: - assert self._check_dimensions( - self.kernel_size - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for kernel_size. You passed {self.kernel_size}." - if self.up_kernel_size != DefaultFromLibrary.YES: - assert self._check_dimensions( - self.up_kernel_size - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for up_kernel_size. You passed {self.up_kernel_size}." - return self - - def _check_dimensions( - self, - value: Union[float, Tuple[float, ...]], - ) -> bool: - """Checks if a tuple has the right dimension.""" - if isinstance(value, tuple): - return len(value) == self.dim - return True - - -class VaryingDepthNetworkConfig(NetworkConfig, ABC): - """ - Base config class to configure neural networks. - More precisely, we refer to MONAI's networks with 'channels' and 'strides' parameters. - """ - - channels: Tuple[PositiveInt, ...] - strides: Tuple[Union[PositiveInt, Tuple[PositiveInt, ...]], ...] - dropout: Union[ - Optional[NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - - @field_validator("dropout") - @classmethod - def validator_dropout(cls, v): - """Checks that dropout is between 0 and 1.""" - return cls.base_validator_dropout(v) - - @model_validator(mode="after") - def channels_strides_validator(self): - """Checks coherence between parameters.""" - n_layers = len(self.channels) - assert ( - len(self.strides) == n_layers - ), f"There are {n_layers} layers but you passed {len(self.strides)} strides." - for s in self.strides: - assert self._check_dimensions( - s - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for strides. You passed {s}." 
- - return self diff --git a/clinicadl/monai_networks/config/classifier.py b/clinicadl/monai_networks/config/classifier.py deleted file mode 100644 index a01bd0efc..000000000 --- a/clinicadl/monai_networks/config/classifier.py +++ /dev/null @@ -1,63 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, Optional, Tuple, Union - -from pydantic import PositiveInt, computed_field - -from clinicadl.utils.factories import DefaultFromLibrary - -from .regressor import RegressorConfig -from .utils.enum import ImplementedActFunctions, ImplementedNetworks - -__all__ = ["ClassifierConfig", "DiscriminatorConfig", "CriticConfig"] - - -class ClassifierConfig(RegressorConfig): - """Config class for classifiers.""" - - classes: PositiveInt - out_shape: Optional[Tuple[PositiveInt, ...]] = None - last_act: Optional[ - Union[ - ImplementedActFunctions, - Tuple[ImplementedActFunctions, Dict[str, Any]], - DefaultFromLibrary, - ] - ] = DefaultFromLibrary.YES - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.CLASSIFIER - - -class DiscriminatorConfig(ClassifierConfig): - """Config class for discriminators.""" - - classes: Optional[PositiveInt] = None - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.DISCRIMINATOR - - -class CriticConfig(ClassifierConfig): - """Config class for discriminators.""" - - classes: Optional[PositiveInt] = None - last_act: Optional[ - Union[ - ImplementedActFunctions, - Tuple[ImplementedActFunctions, Dict[str, Any]], - DefaultFromLibrary, - ] - ] = None - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.CRITIC diff --git a/clinicadl/monai_networks/config/densenet.py b/clinicadl/monai_networks/config/densenet.py deleted file mode 100644 index 796d82203..000000000 --- 
a/clinicadl/monai_networks/config/densenet.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations - -from typing import Tuple, Union - -from pydantic import ( - NonNegativeFloat, - PositiveInt, - computed_field, - field_validator, -) - -from clinicadl.utils.factories import DefaultFromLibrary - -from .base import NetworkConfig -from .utils.enum import ImplementedNetworks - -__all__ = ["DenseNetConfig"] - - -class DenseNetConfig(NetworkConfig): - """Config class for DenseNet.""" - - spatial_dims: PositiveInt - in_channels: PositiveInt - out_channels: PositiveInt - init_features: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - growth_rate: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - block_config: Union[ - Tuple[PositiveInt, ...], DefaultFromLibrary - ] = DefaultFromLibrary.YES - bn_size: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - dropout_prob: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.DENSE_NET - - @computed_field - @property - def dim(self) -> int: - """Dimension of the images.""" - return self.spatial_dims - - @field_validator("dropout_prob") - @classmethod - def validator_dropout(cls, v): - """Checks that dropout is between 0 and 1.""" - return cls.base_validator_dropout(v) diff --git a/clinicadl/monai_networks/config/factory.py b/clinicadl/monai_networks/config/factory.py deleted file mode 100644 index 55e0fad39..000000000 --- a/clinicadl/monai_networks/config/factory.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Type, Union - -from .autoencoder import * -from .base import NetworkConfig -from .classifier import * -from .densenet import * -from .fcn import * -from .generator import * -from .regressor import * -from .resnet import * -from .unet import * -from .utils.enum import ImplementedNetworks -from .vit import * - 
- -def create_network_config( - network: Union[str, ImplementedNetworks], -) -> Type[NetworkConfig]: - """ - A factory function to create a config class suited for the network. - - Parameters - ---------- - network : Union[str, ImplementedNetworks] - The name of the neural network. - - Returns - ------- - Type[NetworkConfig] - The config class. - """ - network = ImplementedNetworks(network) - config_name = "".join([network, "Config"]) - config = globals()[config_name] - - return config diff --git a/clinicadl/monai_networks/config/fcn.py b/clinicadl/monai_networks/config/fcn.py deleted file mode 100644 index 3bb23d6cb..000000000 --- a/clinicadl/monai_networks/config/fcn.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import Optional, Tuple, Union - -from pydantic import ( - NonNegativeFloat, - PositiveInt, - computed_field, - field_validator, -) - -from clinicadl.utils.factories import DefaultFromLibrary - -from .base import NetworkConfig -from .utils.enum import ImplementedNetworks - -__all__ = ["FullyConnectedNetConfig", "VarFullyConnectedNetConfig"] - - -class FullyConnectedNetConfig(NetworkConfig): - """Config class for fully connected networks.""" - - in_channels: PositiveInt - out_channels: PositiveInt - hidden_channels: Tuple[PositiveInt, ...] 
- - dropout: Union[ - Optional[NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.FCN - - @computed_field - @property - def dim(self) -> Optional[int]: - """Dimension of the images.""" - return None - - @field_validator("dropout") - @classmethod - def validator_dropout(cls, v): - """Checks that dropout is between 0 and 1.""" - return cls.base_validator_dropout(v) - - -class VarFullyConnectedNetConfig(NetworkConfig): - """Config class for fully connected networks.""" - - in_channels: PositiveInt - out_channels: PositiveInt - latent_size: PositiveInt - encode_channels: Tuple[PositiveInt, ...] - decode_channels: Tuple[PositiveInt, ...] - - dropout: Union[ - Optional[NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.VAR_FCN - - @computed_field - @property - def dim(self) -> Optional[int]: - """Dimension of the images.""" - return None - - @field_validator("dropout") - @classmethod - def validator_dropout(cls, v): - """Checks that dropout is between 0 and 1.""" - return cls.base_validator_dropout(v) diff --git a/clinicadl/monai_networks/config/generator.py b/clinicadl/monai_networks/config/generator.py deleted file mode 100644 index b864d371d..000000000 --- a/clinicadl/monai_networks/config/generator.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import annotations - -from typing import Tuple - -from pydantic import ( - PositiveInt, - computed_field, - field_validator, -) - -from .base import VaryingDepthNetworkConfig -from .utils.enum import ImplementedNetworks - -__all__ = ["GeneratorConfig"] - - -class GeneratorConfig(VaryingDepthNetworkConfig): - """Config class for generators.""" - - latent_shape: Tuple[PositiveInt, ...] 
- start_shape: Tuple[PositiveInt, ...] - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.GENERATOR - - @computed_field - @property - def dim(self) -> int: - """Dimension of the images.""" - return len(self.start_shape[1:]) - - @field_validator("start_shape") - def at_least_2d(cls, v, field): - """Checks that a tuple has at least a length of two.""" - return cls.base_at_least_2d(v, field) diff --git a/clinicadl/monai_networks/config/regressor.py b/clinicadl/monai_networks/config/regressor.py deleted file mode 100644 index 5410e31fa..000000000 --- a/clinicadl/monai_networks/config/regressor.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -from typing import Tuple - -from pydantic import PositiveInt, computed_field, field_validator - -from .base import VaryingDepthNetworkConfig -from .utils.enum import ImplementedNetworks - -__all__ = ["RegressorConfig"] - - -class RegressorConfig(VaryingDepthNetworkConfig): - """Config class for regressors.""" - - in_shape: Tuple[PositiveInt, ...] - out_shape: Tuple[PositiveInt, ...] 
- - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.REGRESSOR - - @computed_field - @property - def dim(self) -> int: - """Dimension of the images.""" - return len(self.in_shape[1:]) - - @field_validator("in_shape") - def at_least_2d(cls, v, ctx): - """Checks that a tuple has at least a length of two.""" - return cls.base_at_least_2d(v, ctx) diff --git a/clinicadl/monai_networks/config/resnet.py b/clinicadl/monai_networks/config/resnet.py deleted file mode 100644 index 96bb6e193..000000000 --- a/clinicadl/monai_networks/config/resnet.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import annotations - -from enum import Enum -from typing import Optional, Tuple, Union - -from pydantic import ( - NonNegativeFloat, - PositiveFloat, - PositiveInt, - computed_field, - field_validator, - model_validator, -) - -from clinicadl.utils.factories import DefaultFromLibrary - -from .base import NetworkConfig -from .utils.enum import ( - ImplementedNetworks, - ResNetBlocks, - ResNets, - ShortcutTypes, - UpsampleModes, -) - -__all__ = ["ResNetConfig", "ResNetFeaturesConfig", "SegResNetConfig"] - - -class ResNetConfig(NetworkConfig): - """Config class for ResNet.""" - - block: ResNetBlocks - layers: Tuple[PositiveInt, PositiveInt, PositiveInt, PositiveInt] - block_inplanes: Tuple[PositiveInt, PositiveInt, PositiveInt, PositiveInt] - - spatial_dims: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - n_input_channels: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - conv1_t_size: Union[ - PositiveInt, Tuple[PositiveInt, ...], DefaultFromLibrary - ] = DefaultFromLibrary.YES - conv1_t_stride: Union[ - PositiveInt, Tuple[PositiveInt, ...], DefaultFromLibrary - ] = DefaultFromLibrary.YES - no_max_pool: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - shortcut_type: Union[ShortcutTypes, DefaultFromLibrary] = DefaultFromLibrary.YES - widen_factor: 
Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - num_classes: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - feed_forward: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - bias_downsample: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.RES_NET - - @computed_field - @property - def dim(self) -> int: - """Dimension of the images.""" - return self.spatial_dims if self.spatial_dims != DefaultFromLibrary.YES else 3 - - @model_validator(mode="after") - def model_validator(self): - """Checks coherence between parameters.""" - if self.conv1_t_size != DefaultFromLibrary.YES: - assert self._check_dimensions( - self.conv1_t_size - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for conv1_t_size. You passed {self.conv1_t_size}." - if self.conv1_t_stride != DefaultFromLibrary.YES: - assert self._check_dimensions( - self.conv1_t_stride - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for conv1_t_stride. You passed {self.conv1_t_stride}." 
- - return self - - -class ResNetFeaturesConfig(NetworkConfig): - """Config class for ResNet backbones.""" - - model_name: ResNets - - pretrained: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - spatial_dims: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - in_channels: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.RES_NET_FEATURES - - @computed_field - @property - def dim(self) -> int: - """Dimension of the images.""" - return self.spatial_dims if self.spatial_dims != DefaultFromLibrary.YES else 3 - - @model_validator(mode="after") - def model_validator(self): - """Checks coherence between parameters.""" - if self.pretrained == DefaultFromLibrary.YES or self.pretrained: - assert ( - self.spatial_dims == DefaultFromLibrary.YES or self.spatial_dims == 3 - ), "Pretrained weights are only available with spatial_dims=3. Otherwise, set pretrained to False." - assert ( - self.in_channels == DefaultFromLibrary.YES or self.in_channels == 1 - ), "Pretrained weights are only available with in_channels=1. Otherwise, set pretrained to False." 
- - return self - - -class SegResNetConfig(NetworkConfig): - """Config class for SegResNet.""" - - spatial_dims: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - init_filters: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - in_channels: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - out_channels: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - dropout_prob: Union[ - Optional[NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - use_conv_final: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - blocks_down: Union[ - Tuple[PositiveInt, ...], DefaultFromLibrary - ] = DefaultFromLibrary.YES - blocks_up: Union[ - Tuple[PositiveInt, ...], DefaultFromLibrary - ] = DefaultFromLibrary.YES - upsample_mode: Union[UpsampleModes, DefaultFromLibrary] = DefaultFromLibrary.YES - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.SEG_RES_NET - - @computed_field - @property - def dim(self) -> int: - """Dimension of the images.""" - return self.spatial_dims if self.spatial_dims != DefaultFromLibrary.YES else 3 - - @field_validator("dropout_prob") - @classmethod - def validator_dropout(cls, v): - """Checks that dropout is between 0 and 1.""" - return cls.base_validator_dropout(v) diff --git a/clinicadl/monai_networks/config/unet.py b/clinicadl/monai_networks/config/unet.py deleted file mode 100644 index e7fd3498b..000000000 --- a/clinicadl/monai_networks/config/unet.py +++ /dev/null @@ -1,64 +0,0 @@ -from __future__ import annotations - -from typing import Union - -from pydantic import ( - PositiveInt, - computed_field, - model_validator, -) - -from clinicadl.utils.factories import DefaultFromLibrary - -from .base import VaryingDepthNetworkConfig -from .utils.enum import ImplementedNetworks - -__all__ = ["UNetConfig", "AttentionUnetConfig"] - - -class UNetConfig(VaryingDepthNetworkConfig): - """Config class 
for UNet.""" - - spatial_dims: PositiveInt - in_channels: PositiveInt - out_channels: PositiveInt - adn_ordering: Union[str, DefaultFromLibrary] = DefaultFromLibrary.YES - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.UNET - - @computed_field - @property - def dim(self) -> int: - """Dimension of the images.""" - return self.spatial_dims - - @model_validator(mode="after") - def channels_strides_validator(self): - """Checks coherence between parameters.""" - n_layers = len(self.channels) - assert ( - n_layers >= 2 - ), f"Channels must be at least of length 2. You passed {self.channels}." - assert ( - len(self.strides) == n_layers - 1 - ), f"Length of strides must be equal to len(channels)-1. You passed channels={self.channels} and strides={self.strides}." - for s in self.strides: - assert self._check_dimensions( - s - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for strides. You passed {s}." 
- - return self - - -class AttentionUnetConfig(UNetConfig): - """Config class for Attention UNet.""" - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.ATT_UNET diff --git a/clinicadl/monai_networks/config/utils/enum.py b/clinicadl/monai_networks/config/utils/enum.py deleted file mode 100644 index 941e34972..000000000 --- a/clinicadl/monai_networks/config/utils/enum.py +++ /dev/null @@ -1,129 +0,0 @@ -from enum import Enum - - -class ImplementedNetworks(str, Enum): - """Implemented neural networks in ClinicaDL.""" - - REGRESSOR = "Regressor" - CLASSIFIER = "Classifier" - DISCRIMINATOR = "Discriminator" - CRITIC = "Critic" - AE = "AutoEncoder" - VAE = "VarAutoEncoder" - DENSE_NET = "DenseNet" - FCN = "FullyConnectedNet" - VAR_FCN = "VarFullyConnectedNet" - GENERATOR = "Generator" - RES_NET = "ResNet" - RES_NET_FEATURES = "ResNetFeatures" - SEG_RES_NET = "SegResNet" - UNET = "UNet" - ATT_UNET = "AttentionUnet" - VIT = "ViT" - VIT_AE = "ViTAutoEnc" - - @classmethod - def _missing_(cls, value): - raise ValueError( - f"{value} is not implemented. 
Implemented neural networks are: " - + ", ".join([repr(m.value) for m in cls]) - ) - - -class CaseInsensitiveEnum(str, Enum): - @classmethod - def _missing_(cls, value): - if isinstance(value, str): - value = value.lower() - for member in cls: - if member.lower() == value: - return member - return None - - -class ImplementedActFunctions(CaseInsensitiveEnum): - """Supported activation functions in ClinicaDL.""" - - ELU = "elu" - RELU = "relu" - LEAKY_RELU = "leakyrelu" - PRELU = "prelu" - RELU6 = "relu6" - SELU = "selu" - CELU = "celu" - GELU = "gelu" - SIGMOID = "sigmoid" - TANH = "tanh" - SOFTMAX = "softmax" - LOGSOFTMAX = "logsoftmax" - SWISH = "swish" - MEMSWISH = "memswish" - MISH = "mish" - GEGLU = "geglu" - - -class ImplementedNormLayers(CaseInsensitiveEnum): - """Supported normalization layers in ClinicaDL.""" - - GROUP = "group" - LAYER = "layer" - LOCAL_RESPONSE = "localresponse" - SYNCBATCH = "syncbatch" - INSTANCE_NVFUSER = "instance_nvfuser" - BATCH = "batch" - INSTANCE = "instance" - - -class ResNetBlocks(str, Enum): - """Supported ResNet blocks.""" - - BASIC = "basic" - BOTTLENECK = "bottleneck" - - -class ShortcutTypes(str, Enum): - """Supported shortcut types for ResNets.""" - - A = "A" - B = "B" - - -class ResNets(str, Enum): - """Supported ResNet networks.""" - - RESNET_10 = "resnet10" - RESNET_18 = "resnet18" - RESNET_34 = "resnet34" - RESNET_50 = "resnet50" - RESNET_101 = "resnet101" - RESNET_152 = "resnet152" - RESNET_200 = "resnet200" - - -class UpsampleModes(str, Enum): - """Supported upsampling modes for ResNets.""" - - DECONV = "deconv" - NON_TRAINABLE = "nontrainable" - PIXEL_SHUFFLE = "pixelshuffle" - - -class PatchEmbeddingTypes(str, Enum): - """Supported patch embedding types for VITs.""" - - CONV = "conv" - PERCEPTRON = "perceptron" - - -class PosEmbeddingTypes(str, Enum): - """Supported positional embedding types for VITs.""" - - NONE = "none" - LEARNABLE = "learnable" - SINCOS = "sincos" - - -class ClassificationActivation(str, 
Enum): - """Supported activation layer for classification in ViT.""" - - TANH = "Tanh" diff --git a/clinicadl/monai_networks/config/vit.py b/clinicadl/monai_networks/config/vit.py deleted file mode 100644 index 206d0d881..000000000 --- a/clinicadl/monai_networks/config/vit.py +++ /dev/null @@ -1,154 +0,0 @@ -from enum import Enum -from typing import Optional, Tuple, Union - -from pydantic import ( - NonNegativeFloat, - PositiveInt, - computed_field, - field_validator, - model_validator, -) - -from clinicadl.utils.factories import DefaultFromLibrary - -from .base import NetworkConfig -from .utils.enum import ( - ClassificationActivation, - ImplementedNetworks, - PatchEmbeddingTypes, - PosEmbeddingTypes, -) - -__all__ = ["ViTConfig", "ViTAutoEncConfig"] - - -class ViTConfig(NetworkConfig): - """Config class for ViT networks.""" - - in_channels: PositiveInt - img_size: Union[PositiveInt, Tuple[PositiveInt, ...]] - patch_size: Union[PositiveInt, Tuple[PositiveInt, ...]] - - hidden_size: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - mlp_dim: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - num_layers: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - num_heads: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - proj_type: Union[PatchEmbeddingTypes, DefaultFromLibrary] = DefaultFromLibrary.YES - pos_embed_type: Union[ - PosEmbeddingTypes, DefaultFromLibrary - ] = DefaultFromLibrary.YES - classification: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - num_classes: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - dropout_rate: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - spatial_dims: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - post_activation: Union[ - Optional[ClassificationActivation], DefaultFromLibrary - ] = DefaultFromLibrary.YES - qkv_bias: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES - save_attn: Union[bool, 
DefaultFromLibrary] = DefaultFromLibrary.YES - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.VIT - - @computed_field - @property - def dim(self) -> int: - """Dimension of the images.""" - return self.spatial_dims if self.spatial_dims != DefaultFromLibrary.YES else 3 - - @field_validator("dropout_rate") - def validator_dropout(cls, v): - """Checks that dropout is between 0 and 1.""" - return cls.base_validator_dropout(v) - - @model_validator(mode="before") - def check_einops(self): - """Checks if the library einops is installed.""" - from importlib import util - - spec = util.find_spec("einops") - if spec is None: - raise ModuleNotFoundError("einops is not installed") - return self - - @model_validator(mode="after") - def model_validator(self): - """Checks coherence between parameters.""" - assert self._check_dimensions( - self.img_size - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for img_size. You passed {self.img_size}." - assert self._check_dimensions( - self.patch_size - ), f"You must passed an int or a sequence of {self.dim} ints (the dimensionality of your images) for patch_size. You passed {self.patch_size}." - - if ( - self.hidden_size != DefaultFromLibrary.YES - and self.num_heads != DefaultFromLibrary.YES - ): - assert self._divide( - self.hidden_size, self.num_heads - ), f"hidden_size must be divisible by num_heads. You passed hidden_size={self.hidden_size} and num_heads={self.num_heads}." 
- elif ( - self.hidden_size != DefaultFromLibrary.YES - and self.num_heads == DefaultFromLibrary.YES - ): - raise ValueError("If you pass hidden_size, please also pass num_heads.") - elif ( - self.hidden_size == DefaultFromLibrary.YES - and self.num_heads != DefaultFromLibrary.YES - ): - raise ValueError("If you pass num_head, please also pass hidden_size.") - - return self - - def _divide( - self, - numerator: Union[int, Tuple[int, ...]], - denominator: Union[int, Tuple[int, ...]], - ) -> bool: - print(self.dim) - """Checks if numerator is divisible by denominator.""" - if isinstance(numerator, int): - numerator = (numerator,) * self.dim - if isinstance(denominator, int): - denominator = (denominator,) * self.dim - for n, d in zip(numerator, denominator): - if n % d != 0: - return False - return True - - -class ViTAutoEncConfig(ViTConfig): - """Config class for ViT autoencoders.""" - - out_channels: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - deconv_chns: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - - @computed_field - @property - def network(self) -> ImplementedNetworks: - """The name of the network.""" - return ImplementedNetworks.VIT_AE - - @model_validator(mode="after") - def model_validator_bis(self): - """Checks coherence between parameters.""" - assert self._divide( - self.img_size, self.patch_size - ), f"img_size must be divisible by patch_size. You passed hidden_size={self.img_size} and num_heads={self.patch_size}." - assert self._is_sqrt( - self.patch_size - ), f"patch_size must be square number(s). You passed {self.patch_size}." 
- - return self - - def _is_sqrt(self, value: Union[int, Tuple[int, ...]]) -> bool: - """Checks if value is a square number.""" - import math - - if isinstance(value, int): - value = (value,) * self.dim - return all([int(math.sqrt(v)) == math.sqrt(v) for v in value]) diff --git a/clinicadl/monai_networks/factory.py b/clinicadl/monai_networks/factory.py deleted file mode 100644 index 1e509f3d1..000000000 --- a/clinicadl/monai_networks/factory.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Tuple - -import monai.networks.nets as networks -import torch.nn as nn - -from clinicadl.utils.factories import DefaultFromLibrary, get_args_and_defaults - -from .config.base import NetworkConfig - - -def get_network(config: NetworkConfig) -> Tuple[nn.Module, NetworkConfig]: - """ - Factory function to get a Neural Network from MONAI. - - Parameters - ---------- - config : NetworkConfig - The config class with the parameters of the network. - - Returns - ------- - nn.Module - The neural network. - NetworkConfig - The updated config class: the arguments set to default will be updated - with their effective values (the default values from the library). - Useful for reproducibility. 
- """ - network_class = getattr(networks, config.network) - expected_args, config_dict = get_args_and_defaults(network_class.__init__) - for arg, value in config.model_dump().items(): - if arg in expected_args and value != DefaultFromLibrary.YES: - config_dict[arg] = value - - network = network_class(**config_dict) - updated_config = config.model_copy(update=config_dict) - - return network, updated_config diff --git a/clinicadl/networks/__init__.py b/clinicadl/networks/__init__.py new file mode 100644 index 000000000..ea44f7516 --- /dev/null +++ b/clinicadl/networks/__init__.py @@ -0,0 +1,2 @@ +from .config import ImplementedNetworks, NetworkConfig +from .factory import get_network, get_network_from_config diff --git a/clinicadl/networks/config/__init__.py b/clinicadl/networks/config/__init__.py new file mode 100644 index 000000000..1c39fa4fa --- /dev/null +++ b/clinicadl/networks/config/__init__.py @@ -0,0 +1,2 @@ +from .base import ImplementedNetworks, NetworkConfig, NetworkType +from .factory import create_network_config diff --git a/clinicadl/networks/config/autoencoder.py b/clinicadl/networks/config/autoencoder.py new file mode 100644 index 000000000..50ef4bcca --- /dev/null +++ b/clinicadl/networks/config/autoencoder.py @@ -0,0 +1,45 @@ +from typing import Optional, Sequence, Union + +from pydantic import PositiveInt, computed_field + +from clinicadl.networks.nn.layers.utils import ( + ActivationParameters, + UnpoolingMode, +) +from clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkConfig +from .conv_encoder import ConvEncoderOptions +from .mlp import MLPOptions + + +class AutoEncoderConfig(NetworkConfig): + """Config class for AutoEncoder.""" + + in_shape: Sequence[PositiveInt] + latent_size: PositiveInt + conv_args: ConvEncoderOptions + mlp_args: Union[Optional[MLPOptions], DefaultFromLibrary] = DefaultFromLibrary.YES + out_channels: Union[ + Optional[PositiveInt], DefaultFromLibrary + ] = 
DefaultFromLibrary.YES + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + unpooling_mode: Union[UnpoolingMode, DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.AE + + +class VAEConfig(AutoEncoderConfig): + """Config class for Variational AutoEncoder.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.VAE diff --git a/clinicadl/networks/config/base.py b/clinicadl/networks/config/base.py new file mode 100644 index 000000000..6d61d16fd --- /dev/null +++ b/clinicadl/networks/config/base.py @@ -0,0 +1,98 @@ +from abc import ABC, abstractmethod +from enum import Enum +from typing import Optional, Union + +from pydantic import BaseModel, ConfigDict, PositiveInt, computed_field + +from clinicadl.networks.nn.layers.utils import ActivationParameters +from clinicadl.utils.factories import DefaultFromLibrary + + +class ImplementedNetworks(str, Enum): + """Implemented neural networks in ClinicaDL.""" + + MLP = "MLP" + CONV_ENCODER = "ConvEncoder" + CONV_DECODER = "ConvDecoder" + CNN = "CNN" + GENERATOR = "Generator" + AE = "AutoEncoder" + VAE = "VAE" + DENSENET = "DenseNet" + DENSENET_121 = "DenseNet-121" + DENSENET_161 = "DenseNet-161" + DENSENET_169 = "DenseNet-169" + DENSENET_201 = "DenseNet-201" + RESNET = "ResNet" + RESNET_18 = "ResNet-18" + RESNET_34 = "ResNet-34" + RESNET_50 = "ResNet-50" + RESNET_101 = "ResNet-101" + RESNET_152 = "ResNet-152" + SE_RESNET = "SEResNet" + SE_RESNET_50 = "SEResNet-50" + SE_RESNET_101 = "SEResNet-101" + SE_RESNET_152 = "SEResNet-152" + UNET = "UNet" + ATT_UNET = "AttentionUNet" + VIT = "ViT" + VIT_B_16 = "ViT-B/16" + VIT_B_32 = "ViT-B/32" + VIT_L_16 = "ViT-L/16" + VIT_L_32 = "ViT-L/32" + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not 
implemented. Implemented neural networks are: " + + ", ".join([repr(m.value) for m in cls]) + ) + + +class NetworkType(str, Enum): + """ + Useful to know where to look for the network. + See :py:func:`clinicadl.monai_networks.factory.get_network` + """ + + CUSTOM = "custom" # our own networks + RESNET = "sota-ResNet" + DENSENET = "sota-DenseNet" + SE_RESNET = "sota-SEResNet" + VIT = "sota-ViT" + + +class NetworkConfig(BaseModel, ABC): + """Base config class to configure neural networks.""" + + # pydantic config + model_config = ConfigDict( + validate_assignment=True, + use_enum_values=True, + validate_default=True, + ) + + @computed_field + @property + @abstractmethod + def name(self) -> ImplementedNetworks: + """The name of the network.""" + + @computed_field + @property + def _type(self) -> NetworkType: + """ + To know where to look for the network. + Default to 'custom'. + """ + return NetworkType.CUSTOM + + +class PreTrainedConfig(NetworkConfig): + """Base config class for SOTA networks.""" + + num_outputs: Optional[PositiveInt] + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + pretrained: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES diff --git a/clinicadl/networks/config/cnn.py b/clinicadl/networks/config/cnn.py new file mode 100644 index 000000000..a7d2043db --- /dev/null +++ b/clinicadl/networks/config/cnn.py @@ -0,0 +1,24 @@ +from typing import Optional, Sequence, Union + +from pydantic import PositiveInt, computed_field + +from clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkConfig +from .conv_encoder import ConvEncoderOptions +from .mlp import MLPOptions + + +class CNNConfig(NetworkConfig): + """Config class for CNN.""" + + in_shape: Sequence[PositiveInt] + num_outputs: PositiveInt + conv_args: ConvEncoderOptions + mlp_args: Union[Optional[MLPOptions], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def 
name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.CNN diff --git a/clinicadl/networks/config/conv_decoder.py b/clinicadl/networks/config/conv_decoder.py new file mode 100644 index 000000000..91547e052 --- /dev/null +++ b/clinicadl/networks/config/conv_decoder.py @@ -0,0 +1,65 @@ +from typing import Optional, Sequence, Union + +from pydantic import BaseModel, ConfigDict, PositiveFloat, PositiveInt, computed_field + +from clinicadl.networks.nn.layers.utils import ( + ActivationParameters, + ConvNormalizationParameters, + ConvParameters, + UnpoolingParameters, +) +from clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkConfig + + +class ConvDecoderOptions(BaseModel): + """ + Config class for ConvDecoder when it is a submodule. + See for example: :py:class:`clinicadl.monai_networks.nn.generator.Generator` + """ + + channels: Sequence[PositiveInt] + kernel_size: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + stride: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + padding: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + output_padding: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + dilation: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + unpooling: Union[ + Optional[UnpoolingParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + unpooling_indices: Union[ + Optional[Sequence[int]], DefaultFromLibrary + ] = DefaultFromLibrary.YES + act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + norm: Union[ + Optional[ConvNormalizationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + dropout: Union[Optional[PositiveFloat], DefaultFromLibrary] = DefaultFromLibrary.YES + bias: Union[bool, DefaultFromLibrary] = 
DefaultFromLibrary.YES + adn_ordering: Union[str, DefaultFromLibrary] = DefaultFromLibrary.YES + + # pydantic config + model_config = ConfigDict( + validate_assignment=True, + use_enum_values=True, + validate_default=True, + ) + + +class ConvDecoderConfig(NetworkConfig, ConvDecoderOptions): + """Config class for ConvDecoder.""" + + spatial_dims: PositiveInt + in_channels: PositiveInt + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.CONV_DECODER diff --git a/clinicadl/networks/config/conv_encoder.py b/clinicadl/networks/config/conv_encoder.py new file mode 100644 index 000000000..1bddbc947 --- /dev/null +++ b/clinicadl/networks/config/conv_encoder.py @@ -0,0 +1,64 @@ +from typing import Optional, Sequence, Union + +from pydantic import BaseModel, ConfigDict, PositiveFloat, PositiveInt, computed_field + +from clinicadl.networks.nn.layers.utils import ( + ActivationParameters, + ConvNormalizationParameters, + ConvParameters, + PoolingParameters, +) +from clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkConfig + + +class ConvEncoderOptions(BaseModel): + """ + Config class for ConvEncoder when it is a submodule. 
+ See for example: :py:class:`clinicadl.monai_networks.nn.cnn.CNN` + """ + + channels: Sequence[PositiveInt] + kernel_size: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + stride: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + padding: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + dilation: Union[ConvParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + pooling: Union[ + Optional[PoolingParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + pooling_indices: Union[ + Optional[Sequence[int]], DefaultFromLibrary + ] = DefaultFromLibrary.YES + act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + norm: Union[ + Optional[ConvNormalizationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + dropout: Union[Optional[PositiveFloat], DefaultFromLibrary] = DefaultFromLibrary.YES + bias: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES + adn_ordering: Union[str, DefaultFromLibrary] = DefaultFromLibrary.YES + + # pydantic config + model_config = ConfigDict( + validate_assignment=True, + use_enum_values=True, + validate_default=True, + ) + + +class ConvEncoderConfig(NetworkConfig, ConvEncoderOptions): + """Config class for ConvEncoder.""" + + spatial_dims: PositiveInt + in_channels: PositiveInt + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.CONV_ENCODER diff --git a/clinicadl/networks/config/densenet.py b/clinicadl/networks/config/densenet.py new file mode 100644 index 000000000..022f26cca --- /dev/null +++ b/clinicadl/networks/config/densenet.py @@ -0,0 +1,83 @@ +from typing import Optional, Sequence, Union + +from pydantic import PositiveFloat, PositiveInt, computed_field + +from clinicadl.networks.nn.layers.utils import ActivationParameters +from 
clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkConfig, NetworkType, PreTrainedConfig + + +class DenseNetConfig(NetworkConfig): + """Config class for DenseNet.""" + + spatial_dims: PositiveInt + in_channels: PositiveInt + num_outputs: Optional[PositiveInt] + n_dense_layers: Union[ + Sequence[PositiveInt], DefaultFromLibrary + ] = DefaultFromLibrary.YES + init_features: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES + growth_rate: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES + bottleneck_factor: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES + act: Union[ActivationParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + dropout: Union[Optional[PositiveFloat], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.DENSENET + + +class PreTrainedDenseNetConfig(PreTrainedConfig): + """Base config class for SOTA DenseNets.""" + + @computed_field + @property + def _type(self) -> NetworkType: + """To know where to look for the network.""" + return NetworkType.DENSENET + + +class DenseNet121Config(PreTrainedDenseNetConfig): + """Config class for DenseNet-121.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.DENSENET_121 + + +class DenseNet161Config(PreTrainedDenseNetConfig): + """Config class for DenseNet-161.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.DENSENET_161 + + +class DenseNet169Config(PreTrainedDenseNetConfig): + """Config class for DenseNet-169.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + 
return ImplementedNetworks.DENSENET_169 + + +class DenseNet201Config(PreTrainedDenseNetConfig): + """Config class for DenseNet-201.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.DENSENET_201 diff --git a/clinicadl/networks/config/factory.py b/clinicadl/networks/config/factory.py new file mode 100644 index 000000000..2b7e5bdc1 --- /dev/null +++ b/clinicadl/networks/config/factory.py @@ -0,0 +1,56 @@ +from typing import Type, Union + +# pylint: disable=unused-import +from .autoencoder import AutoEncoderConfig, VAEConfig +from .base import ImplementedNetworks, NetworkConfig +from .cnn import CNNConfig +from .conv_decoder import ConvDecoderConfig +from .conv_encoder import ConvEncoderConfig +from .densenet import ( + DenseNet121Config, + DenseNet161Config, + DenseNet169Config, + DenseNet201Config, + DenseNetConfig, +) +from .generator import GeneratorConfig +from .mlp import MLPConfig +from .resnet import ( + ResNet18Config, + ResNet34Config, + ResNet50Config, + ResNet101Config, + ResNet152Config, + ResNetConfig, +) +from .senet import ( + SEResNet50Config, + SEResNet101Config, + SEResNet152Config, + SEResNetConfig, +) +from .unet import AttentionUNetConfig, UNetConfig +from .vit import ViTB16Config, ViTB32Config, ViTConfig, ViTL16Config, ViTL32Config + + +def create_network_config( + network: Union[str, ImplementedNetworks], +) -> Type[NetworkConfig]: + """ + A factory function to create a config class suited for the network. + + Parameters + ---------- + network : Union[str, ImplementedNetworks] + The name of the neural network. + + Returns + ------- + Type[NetworkConfig] + The config class. 
+ """ + network = ImplementedNetworks(network).value.replace("-", "").replace("/", "") + config_name = "".join([network, "Config"]) + config = globals()[config_name] + + return config diff --git a/clinicadl/networks/config/generator.py b/clinicadl/networks/config/generator.py new file mode 100644 index 000000000..6c7836474 --- /dev/null +++ b/clinicadl/networks/config/generator.py @@ -0,0 +1,24 @@ +from typing import Optional, Sequence, Union + +from pydantic import PositiveInt, computed_field + +from clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkConfig +from .conv_decoder import ConvDecoderOptions +from .mlp import MLPOptions + + +class GeneratorConfig(NetworkConfig): + """Config class for Generator.""" + + latent_size: PositiveInt + start_shape: Sequence[PositiveInt] + conv_args: ConvDecoderOptions + mlp_args: Union[Optional[MLPOptions], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.GENERATOR diff --git a/clinicadl/networks/config/mlp.py b/clinicadl/networks/config/mlp.py new file mode 100644 index 000000000..2f72eda88 --- /dev/null +++ b/clinicadl/networks/config/mlp.py @@ -0,0 +1,52 @@ +from typing import Optional, Sequence, Union + +from pydantic import BaseModel, ConfigDict, PositiveFloat, PositiveInt, computed_field + +from clinicadl.networks.nn.layers.utils import ( + ActivationParameters, + NormalizationParameters, +) +from clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkConfig + + +class MLPOptions(BaseModel): + """ + Config class for MLP when it is a submodule. 
+ See for example: :py:class:`clinicadl.monai_networks.nn.cnn.CNN` + """ + + hidden_channels: Sequence[PositiveInt] + act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + norm: Union[ + Optional[NormalizationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + dropout: Union[Optional[PositiveFloat], DefaultFromLibrary] = DefaultFromLibrary.YES + bias: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES + adn_ordering: Union[str, DefaultFromLibrary] = DefaultFromLibrary.YES + + # pydantic config + model_config = ConfigDict( + validate_assignment=True, + use_enum_values=True, + validate_default=True, + ) + + +class MLPConfig(NetworkConfig, MLPOptions): + """Config class for Multi Layer Perceptron.""" + + in_channels: PositiveInt + out_channels: PositiveInt + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.MLP diff --git a/clinicadl/networks/config/resnet.py b/clinicadl/networks/config/resnet.py new file mode 100644 index 000000000..ddc53a125 --- /dev/null +++ b/clinicadl/networks/config/resnet.py @@ -0,0 +1,103 @@ +from typing import Optional, Sequence, Union + +from pydantic import PositiveInt, computed_field + +from clinicadl.networks.nn.layers.utils import ActivationParameters +from clinicadl.networks.nn.resnet import ResNetBlockType +from clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkConfig, NetworkType, PreTrainedConfig + + +class ResNetConfig(NetworkConfig): + """Config class for ResNet.""" + + spatial_dims: PositiveInt + in_channels: PositiveInt + num_outputs: Optional[PositiveInt] + block_type: Union[str, ResNetBlockType, DefaultFromLibrary] = DefaultFromLibrary.YES + n_res_blocks: Union[ + Sequence[PositiveInt], DefaultFromLibrary + ] = DefaultFromLibrary.YES + 
n_features: Union[ + Sequence[PositiveInt], DefaultFromLibrary + ] = DefaultFromLibrary.YES + init_conv_size: Union[ + Sequence[PositiveInt], PositiveInt, DefaultFromLibrary + ] = DefaultFromLibrary.YES + init_conv_stride: Union[ + Sequence[PositiveInt], PositiveInt, DefaultFromLibrary + ] = DefaultFromLibrary.YES + bottleneck_reduction: Union[ + PositiveInt, DefaultFromLibrary + ] = DefaultFromLibrary.YES + act: Union[ActivationParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.RESNET + + +class PreTrainedResNetConfig(PreTrainedConfig): + """Base config class for SOTA ResNets.""" + + @computed_field + @property + def _type(self) -> NetworkType: + """To know where to look for the network.""" + return NetworkType.RESNET + + +class ResNet18Config(PreTrainedResNetConfig): + """Config class for ResNet-18.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.RESNET_18 + + +class ResNet34Config(PreTrainedResNetConfig): + """Config class for ResNet-34.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.RESNET_34 + + +class ResNet50Config(PreTrainedResNetConfig): + """Config class for ResNet-50.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.RESNET_50 + + +class ResNet101Config(PreTrainedResNetConfig): + """Config class for ResNet-101.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.RESNET_101 + + +class ResNet152Config(PreTrainedResNetConfig): + """Config class for ResNet-152.""" + + 
@computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.RESNET_152 diff --git a/clinicadl/networks/config/senet.py b/clinicadl/networks/config/senet.py new file mode 100644 index 000000000..79a356726 --- /dev/null +++ b/clinicadl/networks/config/senet.py @@ -0,0 +1,60 @@ +from typing import Union + +from pydantic import PositiveInt, computed_field + +from clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkType, PreTrainedConfig +from .resnet import ResNetConfig + + +class SEResNetConfig(ResNetConfig): + """Config class for Squeeze-and-Excitation ResNet.""" + + se_reduction: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.SE_RESNET + + +class PreTrainedSEResNetConfig(PreTrainedConfig): + """Base config class for SOTA SE-ResNets.""" + + @computed_field + @property + def _type(self) -> NetworkType: + """To know where to look for the network.""" + return NetworkType.SE_RESNET + + +class SEResNet50Config(PreTrainedSEResNetConfig): + """Config class for SE-ResNet-50.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.SE_RESNET_50 + + +class SEResNet101Config(PreTrainedSEResNetConfig): + """Config class for SE-ResNet-101.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.SE_RESNET_101 + + +class SEResNet152Config(PreTrainedSEResNetConfig): + """Config class for SE-ResNet-152.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.SE_RESNET_152 diff --git a/clinicadl/networks/config/unet.py b/clinicadl/networks/config/unet.py new file mode 100644 
index 000000000..b1faf542e --- /dev/null +++ b/clinicadl/networks/config/unet.py @@ -0,0 +1,38 @@ +from typing import Optional, Sequence, Union + +from pydantic import PositiveFloat, PositiveInt, computed_field + +from clinicadl.networks.nn.layers.utils import ActivationParameters +from clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkConfig + + +class UNetConfig(NetworkConfig): + """Config class for UNet.""" + + spatial_dims: PositiveInt + in_channels: PositiveInt + out_channels: PositiveInt + channels: Union[Sequence[PositiveInt], DefaultFromLibrary] = DefaultFromLibrary.YES + act: Union[ActivationParameters, DefaultFromLibrary] = DefaultFromLibrary.YES + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + dropout: Union[Optional[PositiveFloat], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.UNET + + +class AttentionUNetConfig(UNetConfig): + """Config class for AttentionUNet.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.ATT_UNET diff --git a/clinicadl/networks/config/vit.py b/clinicadl/networks/config/vit.py new file mode 100644 index 000000000..ea4103f5d --- /dev/null +++ b/clinicadl/networks/config/vit.py @@ -0,0 +1,84 @@ +from typing import Optional, Sequence, Union + +from pydantic import PositiveFloat, PositiveInt, computed_field + +from clinicadl.networks.nn.layers.utils import ActivationParameters +from clinicadl.networks.nn.vit import PosEmbedType +from clinicadl.utils.factories import DefaultFromLibrary + +from .base import ImplementedNetworks, NetworkConfig, NetworkType, PreTrainedConfig + + +class ViTConfig(NetworkConfig): + """Config class for ViT networks.""" + + in_shape: Sequence[PositiveInt] + patch_size: 
Union[Sequence[PositiveInt], PositiveInt] + num_outputs: Optional[PositiveInt] + embedding_dim: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES + num_layers: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES + num_heads: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES + mlp_dim: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES + pos_embed_type: Union[ + Optional[Union[str, PosEmbedType]], DefaultFromLibrary + ] = DefaultFromLibrary.YES + output_act: Union[ + Optional[ActivationParameters], DefaultFromLibrary + ] = DefaultFromLibrary.YES + dropout: Union[Optional[PositiveFloat], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.VIT + + +class PreTrainedViTConfig(PreTrainedConfig): + """Base config class for SOTA ViTs.""" + + @computed_field + @property + def _type(self) -> NetworkType: + """To know where to look for the network.""" + return NetworkType.VIT + + +class ViTB16Config(PreTrainedViTConfig): + """Config class for ViT-B/16.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.VIT_B_16 + + +class ViTB32Config(PreTrainedViTConfig): + """Config class for ViT-B/32.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.VIT_B_32 + + +class ViTL16Config(PreTrainedViTConfig): + """Config class for ViT-L/16.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.VIT_L_16 + + +class ViTL32Config(PreTrainedViTConfig): + """Config class for ViT-L/32.""" + + @computed_field + @property + def name(self) -> ImplementedNetworks: + """The name of the network.""" + return ImplementedNetworks.VIT_L_32 diff --git 
a/clinicadl/networks/factory.py b/clinicadl/networks/factory.py new file mode 100644 index 000000000..a1822af52 --- /dev/null +++ b/clinicadl/networks/factory.py @@ -0,0 +1,123 @@ +from copy import deepcopy +from typing import Any, Callable, Tuple, Union + +import torch.nn as nn +from pydantic import BaseModel + +import clinicadl.networks.nn as nets +from clinicadl.utils.factories import DefaultFromLibrary, get_args_and_defaults + +from .config import ( + ImplementedNetworks, + NetworkConfig, + NetworkType, + create_network_config, +) +from .config.conv_decoder import ConvDecoderOptions +from .config.conv_encoder import ConvEncoderOptions +from .config.mlp import MLPOptions +from .nn import MLP, ConvDecoder, ConvEncoder + + +def get_network( + name: Union[str, ImplementedNetworks], return_config: bool = False, **kwargs: Any +) -> Union[nn.Module, Tuple[nn.Module, NetworkConfig]]: + """ + Factory function to get a neural network from its name and parameters. + + Parameters + ---------- + name : Union[str, ImplementedNetworks] + the name of the neural network. Check our documentation to know + available networks. + return_config : bool (optional, default=False) + if the function should return the config class regrouping the parameters of the + neural network. Useful to keep track of the hyperparameters. + kwargs : Any + the parameters of the neural network. Check our documentation on networks to + know these parameters. + + Returns + ------- + nn.Module + the neural network. + NetworkConfig + the associated config class. Only returned if `return_config` is True. + """ + config = create_network_config(name)(**kwargs) + network, updated_config = get_network_from_config(config) + + return network if not return_config else (network, updated_config) + + +def get_network_from_config(config: NetworkConfig) -> Tuple[nn.Module, NetworkConfig]: + """ + Factory function to get a neural network from a NetworkConfig instance. 
+ + Parameters + ---------- + config : NetworkConfig + the configuration object. + + Returns + ------- + nn.Module + the neural network. + NetworkConfig + the updated config class: the arguments set to default will be updated + with their effective values (the default values from the network). + Useful for reproducibility. + """ + config = deepcopy(config) + network_type = config._type # pylint: disable=protected-access + + if network_type == NetworkType.CUSTOM: + network_class: type[nn.Module] = getattr(nets, config.name) + if config.name == ImplementedNetworks.SE_RESNET: + _update_config_with_defaults( + config, getattr(nets, ImplementedNetworks.RESNET.value).__init__ + ) # SEResNet has some default values in ResNet + elif config.name == ImplementedNetworks.ATT_UNET: + _update_config_with_defaults( + config, getattr(nets, ImplementedNetworks.UNET.value).__init__ + ) + _update_config_with_defaults(config, network_class.__init__) + + config_dict = config.model_dump(exclude={"name", "_type"}) + network = network_class(**config_dict) + + else: # sota networks + if network_type == NetworkType.RESNET: + getter: Callable[..., nn.Module] = nets.get_resnet + elif network_type == NetworkType.DENSENET: + getter: Callable[..., nn.Module] = nets.get_densenet + elif network_type == NetworkType.SE_RESNET: + getter: Callable[..., nn.Module] = nets.get_seresnet + elif network_type == NetworkType.VIT: + getter: Callable[..., nn.Module] = nets.get_vit + _update_config_with_defaults(config, getter) # pylint: disable=possibly-used-before-assignment + + config_dict = config.model_dump(exclude={"_type"}) + network = getter(**config_dict) + + return network, config + + +def _update_config_with_defaults(config: BaseModel, function: Callable) -> BaseModel: + """ + Updates a config object by setting the parameters left to 'default' to their actual + default values, extracted from 'function'. 
+ """ + _, defaults = get_args_and_defaults(function) + + for arg, value in config: + if isinstance(value, MLPOptions): + _update_config_with_defaults( + value, MLP.__init__ + ) # we need to update the sub config object + elif isinstance(value, ConvEncoderOptions): + _update_config_with_defaults(value, ConvEncoder.__init__) + elif isinstance(value, ConvDecoderOptions): + _update_config_with_defaults(value, ConvDecoder.__init__) + elif value == DefaultFromLibrary.YES and arg in defaults: + setattr(config, arg, defaults[arg]) diff --git a/clinicadl/networks/nn/__init__.py b/clinicadl/networks/nn/__init__.py new file mode 100644 index 000000000..0e1c7054a --- /dev/null +++ b/clinicadl/networks/nn/__init__.py @@ -0,0 +1,13 @@ +from .att_unet import AttentionUNet +from .autoencoder import AutoEncoder +from .cnn import CNN +from .conv_decoder import ConvDecoder +from .conv_encoder import ConvEncoder +from .densenet import DenseNet, get_densenet +from .generator import Generator +from .mlp import MLP +from .resnet import ResNet, get_resnet +from .senet import SEResNet, get_seresnet +from .unet import UNet +from .vae import VAE +from .vit import ViT, get_vit diff --git a/clinicadl/networks/nn/att_unet.py b/clinicadl/networks/nn/att_unet.py new file mode 100644 index 000000000..77ef02081 --- /dev/null +++ b/clinicadl/networks/nn/att_unet.py @@ -0,0 +1,207 @@ +from typing import Any + +import torch +from monai.networks.nets.attentionunet import AttentionBlock + +from .layers.unet import ConvBlock, UpSample +from .unet import BaseUNet + + +class AttentionUNet(BaseUNet): + """ + Attention-UNet based on [Attention U-Net: Learning Where to Look for the Pancreas](https://arxiv.org/pdf/1804.03999). + + The user can customize the number of encoding blocks, the number of channels in each block, as well as other parameters + like the activation function. + + .. warning:: AttentionUNet works only with images whose dimensions are high enough powers of 2. 
More precisely, if n is the + number of max pooling operation in your AttentionUNet (which is equal to `len(channels)-1`), the image must have :math:`2^{k}` + pixels in each dimension, with :math:`k \\geq n` (e.g. shape (:math:`2^{n}`, :math:`2^{n+3}`) for a 2D image). + + Parameters + ---------- + spatial_dims : int + number of spatial dimensions of the input image. + in_channels : int + number of channels in the input image. + out_channels : int + number of output channels. + kwargs : Any + any optional argument accepted by (:py:class:`clinicadl.monai_networks.nn.unet.UNet`). + + Examples + -------- + >>> AttentionUNet( + spatial_dims=2, + in_channels=1, + out_channels=2, + channels=(4, 8), + act="elu", + output_act=("softmax", {"dim": 1}), + dropout=0.1, + ) + AttentionUNet( + (doubleconv): ConvBlock( + (0): Convolution( + (conv): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (1): Convolution( + (conv): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + ) + (down1): DownBlock( + (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) + (doubleconv): ConvBlock( + (0): Convolution( + (conv): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (1): Convolution( + (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + 
) + ) + (upsample1): UpSample( + (0): Upsample(scale_factor=2.0, mode='nearest') + (1): Convolution( + (conv): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + ) + (attention1): AttentionBlock( + (W_g): Sequential( + (0): Convolution( + (conv): Conv2d(4, 2, kernel_size=(1, 1), stride=(1, 1)) + ) + (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (W_x): Sequential( + (0): Convolution( + (conv): Conv2d(4, 2, kernel_size=(1, 1), stride=(1, 1)) + ) + (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (psi): Sequential( + (0): Convolution( + (conv): Conv2d(2, 1, kernel_size=(1, 1), stride=(1, 1)) + ) + (1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (2): Sigmoid() + ) + (relu): ReLU() + ) + (doubleconv1): ConvBlock( + (0): Convolution( + (conv): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (1): Convolution( + (conv): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + ) + (reduce_channels): Convolution( + (conv): Conv2d(4, 2, kernel_size=(1, 1), stride=(1, 1)) + ) + (output_act): Softmax(dim=1) + ) + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + **kwargs: Any, + ): + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + **kwargs, + ) + + def _build_decoder(self): + for i in range(len(self.channels) - 1, 0, -1): + 
self.add_module( + f"upsample{i}", + UpSample( + spatial_dims=self.spatial_dims, + in_channels=self.channels[i], + out_channels=self.channels[i - 1], + act=self.act, + dropout=self.dropout, + ), + ) + self.add_module( + f"attention{i}", + AttentionBlock( + spatial_dims=self.spatial_dims, + f_l=self.channels[i - 1], + f_g=self.channels[i - 1], + f_int=self.channels[i - 1] // 2, + dropout=self.dropout, + ), + ) + self.add_module( + f"doubleconv{i}", + ConvBlock( + spatial_dims=self.spatial_dims, + in_channels=self.channels[i - 1] * 2, + out_channels=self.channels[i - 1], + act=self.act, + dropout=self.dropout, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_history = [self.doubleconv(x)] + + for i in range(1, len(self.channels)): + x = self.get_submodule(f"down{i}")(x_history[-1]) + x_history.append(x) + + x_history.pop() # the output of bottleneck is not used as a gating signal + for i in range(len(self.channels) - 1, 0, -1): + up = self.get_submodule(f"upsample{i}")(x) + att_res = self.get_submodule(f"attention{i}")(g=x_history.pop(), x=up) + merged = torch.cat((att_res, up), dim=1) + x = self.get_submodule(f"doubleconv{i}")(merged) + + out = self.reduce_channels(x) + + if self.output_act is not None: + out = self.output_act(out) + + return out diff --git a/clinicadl/networks/nn/autoencoder.py b/clinicadl/networks/nn/autoencoder.py new file mode 100644 index 000000000..5cf823eeb --- /dev/null +++ b/clinicadl/networks/nn/autoencoder.py @@ -0,0 +1,416 @@ +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union + +import numpy as np +import torch.nn as nn + +from .cnn import CNN +from .conv_encoder import ConvEncoder +from .generator import Generator +from .layers.utils import ( + ActivationParameters, + PoolingLayer, + SingleLayerPoolingParameters, + SingleLayerUnpoolingParameters, + UnpoolingLayer, + UnpoolingMode, +) +from .mlp import MLP +from .utils import ( + calculate_conv_out_shape, 
calculate_convtranspose_out_shape, + calculate_pool_out_shape, +) + + +class AutoEncoder(nn.Sequential): + """ + An autoencoder with convolutional and fully connected layers. + + The user must pass the arguments to build an encoder, from its convolutional and + fully connected parts, and the decoder will be automatically built by taking the + symmetrical network. + + More precisely, to build the decoder, the order of the encoding layers is reverted, convolutions are + replaced by transposed convolutions and pooling layers are replaced by either upsampling or transposed + convolution layers. + Please note that the order of `Activation`, `Dropout` and `Normalization`, defined with the + argument `adn_ordering` in `conv_args`, is the same for the encoder and the decoder. + + Note that an `AutoEncoder` is an aggregation of a `CNN` (:py:class:`clinicadl.monai_networks.nn. + cnn.CNN`) and a `Generator` (:py:class:`clinicadl.monai_networks.nn.generator.Generator`). + + Parameters + ---------- + in_shape : Sequence[int] + sequence of integers stating the dimension of the input tensor (minus batch dimension). + latent_size : int + size of the latent vector. + conv_args : Dict[str, Any] + the arguments for the convolutional part of the encoder. The arguments are those accepted + by :py:class:`clinicadl.monai_networks.nn.conv_encoder.ConvEncoder`, except `in_shape` that + is specified here. So, the only mandatory argument is `channels`. + mlp_args : Optional[Dict[str, Any]] (optional, default=None) + the arguments for the MLP part of the encoder . The arguments are those accepted by + :py:class:`clinicadl.monai_networks.nn.mlp.MLP`, except `in_channels` that is inferred + from the output of the convolutional part, and `out_channels` that is set to `latent_size`. + So, the only mandatory argument is `hidden_channels`.\n + If None, the MLP part will be reduced to a single linear layer. + out_channels : Optional[int] (optional, default=None) + number of output channels. 
If None, the output will have the same number of channels as the + input. + output_act : Optional[ActivationParameters] (optional, default=None) + a potential activation layer applied to the output of the network, and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them. + unpooling_mode : Union[str, UnpoolingMode] (optional, default=UnpoolingMode.NEAREST) + type of unpooling. Can be either `"nearest"`, `"linear"`, `"bilinear"`, `"bicubic"`, `"trilinear"` or + `"convtranspose"`.\n + - `nearest`: unpooling is performed by upsampling with the :italic:`nearest` algorithm (see [PyTorch's Upsample layer] + (https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html)). + - `linear`: unpooling is performed by upsampling with the :italic:`linear` algorithm. Only works with 1D images (excluding the + channel dimension). + - `bilinear`: unpooling is performed by upsampling with the :italic:`bilinear` algorithm. Only works with 2D images. + - `bicubic`: unpooling is performed by upsampling with the :italic:`bicubic` algorithm. Only works with 2D images. + - `trilinear`: unpooling is performed by upsampling with the :italic:`trilinear` algorithm. Only works with 3D images. + - `convtranspose`: unpooling is performed with a transposed convolution, whose parameters (kernel size, stride, etc.) are + computed to reverse the pooling operation. 
+ + Examples + -------- + >>> AutoEncoder( + in_shape=(1, 16, 16), + latent_size=8, + conv_args={ + "channels": [2, 4], + "pooling_indices": [0], + "pooling": ("avg", {"kernel_size": 2}), + }, + mlp_args={"hidden_channels": [32], "output_act": "relu"}, + out_channels=2, + output_act="sigmoid", + unpooling_mode="bilinear", + ) + AutoEncoder( + (encoder): CNN( + (convolutions): ConvEncoder( + (layer0): Convolution( + (conv): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)) + (adn): ADN( + (N): InstanceNorm2d(2, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False) + (A): PReLU(num_parameters=1) + ) + ) + (pool0): AvgPool2d(kernel_size=2, stride=2, padding=0) + (layer1): Convolution( + (conv): Conv2d(2, 4, kernel_size=(3, 3), stride=(1, 1)) + ) + ) + (mlp): MLP( + (flatten): Flatten(start_dim=1, end_dim=-1) + (hidden0): Sequential( + (linear): Linear(in_features=100, out_features=32, bias=True) + (adn): ADN( + (N): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (A): PReLU(num_parameters=1) + ) + ) + (output): Sequential( + (linear): Linear(in_features=32, out_features=8, bias=True) + (output_act): ReLU() + ) + ) + ) + (decoder): Generator( + (mlp): MLP( + (flatten): Flatten(start_dim=1, end_dim=-1) + (hidden0): Sequential( + (linear): Linear(in_features=8, out_features=32, bias=True) + (adn): ADN( + (N): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (A): PReLU(num_parameters=1) + ) + ) + (output): Sequential( + (linear): Linear(in_features=32, out_features=100, bias=True) + (output_act): ReLU() + ) + ) + (reshape): Reshape() + (convolutions): ConvDecoder( + (layer0): Convolution( + (conv): ConvTranspose2d(4, 4, kernel_size=(3, 3), stride=(1, 1)) + (adn): ADN( + (N): InstanceNorm2d(4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False) + (A): PReLU(num_parameters=1) + ) + ) + (unpool0): Upsample(size=(14, 14), mode=) + (layer1): Convolution( + (conv): ConvTranspose2d(4, 2, 
kernel_size=(3, 3), stride=(1, 1)) + ) + (output_act): Sigmoid() + ) + ) + ) + + """ + + def __init__( + self, + in_shape: Sequence[int], + latent_size: int, + conv_args: Dict[str, Any], + mlp_args: Optional[Dict[str, Any]] = None, + out_channels: Optional[int] = None, + output_act: Optional[ActivationParameters] = None, + unpooling_mode: Union[str, UnpoolingMode] = UnpoolingMode.NEAREST, + ) -> None: + super().__init__() + self.in_shape = in_shape + self.latent_size = latent_size + self.out_channels = out_channels if out_channels else self.in_shape[0] + self._output_act = output_act + self.unpooling_mode = self._check_unpooling_mode(unpooling_mode) + self.spatial_dims = len(in_shape[1:]) + + self.encoder = CNN( + in_shape=self.in_shape, + num_outputs=latent_size, + conv_args=conv_args, + mlp_args=mlp_args, + ) + inter_channels = ( + conv_args["channels"][-1] if len(conv_args["channels"]) > 0 else in_shape[0] + ) + inter_shape = (inter_channels, *self.encoder.convolutions.final_size) + self.decoder = Generator( + latent_size=latent_size, + start_shape=inter_shape, + conv_args=self._invert_conv_args(conv_args, self.encoder.convolutions), + mlp_args=self._invert_mlp_args(mlp_args, self.encoder.mlp), + ) + + @classmethod + def _invert_mlp_args(cls, args: Dict[str, Any], mlp: MLP) -> Dict[str, Any]: + """ + Inverts arguments passed for the MLP part of the encoder, to get the MLP part of + the decoder. + """ + if args is None: + args = {} + args["hidden_channels"] = cls._invert_list_arg(mlp.hidden_channels) + + return args + + def _invert_conv_args( + self, args: Dict[str, Any], conv: ConvEncoder + ) -> Dict[str, Any]: + """ + Inverts arguments passed for the convolutional part of the encoder, to get the convolutional + part of the decoder. 
+ """ + if len(args["channels"]) == 0: + args["channels"] = [] + else: + args["channels"] = self._invert_list_arg(conv.channels[:-1]) + [ + self.out_channels + ] + args["kernel_size"] = self._invert_list_arg(conv.kernel_size) + args["stride"] = self._invert_list_arg(conv.stride) + args["dilation"] = self._invert_list_arg(conv.dilation) + args["padding"], args["output_padding"] = self._get_paddings_list(conv) + + args["unpooling_indices"] = ( + conv.n_layers - np.array(conv.pooling_indices) - 2 + ).astype(int) + args["unpooling"] = [] + sizes_before_pooling = [ + size + for size, (layer_name, _) in zip(conv.size_details, conv.named_children()) + if "pool" in layer_name + ] + for size, pooling in zip(sizes_before_pooling[::-1], conv.pooling[::-1]): + args["unpooling"].append(self._invert_pooling_layer(size, pooling)) + + if "pooling" in args: + del args["pooling"] + if "pooling_indices" in args: + del args["pooling_indices"] + + args["output_act"] = self._output_act if self._output_act else None + + return args + + @classmethod + def _invert_list_arg(cls, arg: Union[Any, List[Any]]) -> Union[Any, List[Any]]: + """ + Reverses lists. + """ + return list(arg[::-1]) if isinstance(arg, Sequence) else arg + + def _invert_pooling_layer( + self, + size_before_pool: Sequence[int], + pooling: SingleLayerPoolingParameters, + ) -> SingleLayerUnpoolingParameters: + """ + Gets the unpooling layer. + """ + if self.unpooling_mode == UnpoolingMode.CONV_TRANS: + return ( + UnpoolingLayer.CONV_TRANS, + self._invert_pooling_with_convtranspose(size_before_pool, pooling), + ) + else: + return ( + UnpoolingLayer.UPSAMPLE, + {"size": size_before_pool, "mode": self.unpooling_mode}, + ) + + @classmethod + def _invert_pooling_with_convtranspose( + cls, + size_before_pool: Sequence[int], + pooling: SingleLayerPoolingParameters, + ) -> Dict[str, Any]: + """ + Computes the arguments of the transposed convolution, based on the pooling layer. 
+ """ + pooling_mode, pooling_args = pooling + if ( + pooling_mode == PoolingLayer.ADAPT_AVG + or pooling_mode == PoolingLayer.ADAPT_MAX + ): + input_size_np = np.array(size_before_pool) + output_size_np = np.array(pooling_args["output_size"]) + stride_np = input_size_np // output_size_np # adaptive pooling formulas + kernel_size_np = ( + input_size_np - (output_size_np - 1) * stride_np + ) # adaptive pooling formulas + args = { + "kernel_size": tuple(int(k) for k in kernel_size_np), + "stride": tuple(int(s) for s in stride_np), + } + padding, output_padding = cls._find_convtranspose_paddings( + pooling_mode, + size_before_pool, + output_size=pooling_args["output_size"], + **args, + ) + + elif pooling_mode == PoolingLayer.MAX or pooling_mode == PoolingLayer.AVG: + if "stride" not in pooling_args: + pooling_args["stride"] = pooling_args["kernel_size"] + args = { + arg: value + for arg, value in pooling_args.items() + if arg in ["kernel_size", "stride", "padding", "dilation"] + } + padding, output_padding = cls._find_convtranspose_paddings( + pooling_mode, + size_before_pool, + **pooling_args, + ) + + args["padding"] = padding # pylint: disable=possibly-used-before-assignment + args["output_padding"] = output_padding # pylint: disable=possibly-used-before-assignment + + return args + + @classmethod + def _get_paddings_list(cls, conv: ConvEncoder) -> List[Tuple[int, ...]]: + """ + Finds output padding list. 
+ """ + padding = [] + output_padding = [] + size_before_convs = [ + size + for size, (layer_name, _) in zip(conv.size_details, conv.named_children()) + if "layer" in layer_name + ] + for size, k, s, p, d in zip( + size_before_convs, + conv.kernel_size, + conv.stride, + conv.padding, + conv.dilation, + ): + p, out_p = cls._find_convtranspose_paddings( + "conv", size, kernel_size=k, stride=s, padding=p, dilation=d + ) + padding.append(p) + output_padding.append(out_p) + + return cls._invert_list_arg(padding), cls._invert_list_arg(output_padding) + + @classmethod + def _find_convtranspose_paddings( + cls, + layer_type: Union[Literal["conv"], PoolingLayer], + in_shape: Union[Sequence[int], int], + padding: Union[Sequence[int], int] = 0, + **kwargs, + ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """ + Finds padding and output padding necessary to recover the right image size after + a transposed convolution. + """ + if layer_type == "conv": + layer_out_shape = calculate_conv_out_shape(in_shape, **kwargs) + elif layer_type in list(PoolingLayer): + layer_out_shape = calculate_pool_out_shape(layer_type, in_shape, **kwargs) + + convt_out_shape = calculate_convtranspose_out_shape(layer_out_shape, **kwargs) # pylint: disable=possibly-used-before-assignment + output_padding = np.atleast_1d(in_shape) - np.atleast_1d(convt_out_shape) + + if ( + output_padding < 0 + ).any(): # can happen with ceil_mode=True for maxpool. 
Then, add some padding + padding = np.atleast_1d(padding) * np.ones_like( + output_padding + ) # to have the same shape as output_padding + padding[output_padding < 0] += np.maximum(np.abs(output_padding) // 2, 1)[ + output_padding < 0 + ] # //2 because 2*padding pixels are removed + + convt_out_shape = calculate_convtranspose_out_shape( + layer_out_shape, padding=padding, **kwargs + ) + output_padding = np.atleast_1d(in_shape) - np.atleast_1d(convt_out_shape) + padding = tuple(int(s) for s in padding) + + return padding, tuple(int(s) for s in output_padding) + + def _check_unpooling_mode( + self, unpooling_mode: Union[str, UnpoolingMode] + ) -> UnpoolingMode: + """ + Checks consistency between data shape and unpooling mode. + """ + unpooling_mode = UnpoolingMode(unpooling_mode) + if unpooling_mode == UnpoolingMode.LINEAR and len(self.in_shape) != 2: + raise ValueError( + f"unpooling mode `linear` only works with 2D data (counting the channel dimension). " + f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data." + ) + elif unpooling_mode == UnpoolingMode.BILINEAR and len(self.in_shape) != 3: + raise ValueError( + f"unpooling mode `bilinear` only works with 3D data (counting the channel dimension). " + f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data." + ) + elif unpooling_mode == UnpoolingMode.BICUBIC and len(self.in_shape) != 3: + raise ValueError( + f"unpooling mode `bicubic` only works with 3D data (counting the channel dimension). " + f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data." + ) + elif unpooling_mode == UnpoolingMode.TRILINEAR and len(self.in_shape) != 4: + raise ValueError( + f"unpooling mode `trilinear` only works with 4D data (counting the channel dimension). " + f"Got in_shape={self.in_shape}, which is understood as {len(self.in_shape)}D data." 
+ ) + + return unpooling_mode diff --git a/clinicadl/networks/nn/cnn.py b/clinicadl/networks/nn/cnn.py new file mode 100644 index 000000000..1479ecaea --- /dev/null +++ b/clinicadl/networks/nn/cnn.py @@ -0,0 +1,124 @@ +from typing import Any, Dict, Optional, Sequence + +import numpy as np +import torch.nn as nn + +from .conv_encoder import ConvEncoder +from .mlp import MLP +from .utils import check_conv_args, check_mlp_args + + +class CNN(nn.Sequential): + """ + A regressor/classifier with first convolutional layers and then fully connected layers. + + This network is a simple aggregation of a Fully Convolutional Network (:py:class:`clinicadl. + monai_networks.nn.conv_encoder.ConvEncoder`) and a Multi Layer Perceptron (:py:class:`clinicadl. + monai_networks.nn.mlp.MLP`). + + Parameters + ---------- + in_shape : Sequence[int] + sequence of integers stating the dimension of the input tensor (minus batch dimension). + num_outputs : int + number of variables to predict. + conv_args : Dict[str, Any] + the arguments for the convolutional part. The arguments are those accepted by + :py:class:`clinicadl.monai_networks.nn.conv_encoder.ConvEncoder`, except `in_shape` + that is specified here. So, the only mandatory argument is `channels`. + mlp_args : Optional[Dict[str, Any]] (optional, default=None) + the arguments for the MLP part. The arguments are those accepted by + :py:class:`clinicadl.monai_networks.nn.mlp.MLP`, except `in_channels` that is inferred + from the output of the convolutional part, and `out_channels` that is set to `num_outputs`. + So, the only mandatory argument is `hidden_channels`.\n + If None, the MLP part will be reduced to a single linear layer. 
+ + Examples + -------- + # a classifier + >>> CNN( + in_shape=(1, 10, 10), + num_outputs=2, + conv_args={"channels": [2, 4], "norm": None, "act": None}, + mlp_args={"hidden_channels": [5], "act": "elu", "norm": None, "output_act": "softmax"}, + ) + CNN( + (convolutions): ConvEncoder( + (layer0): Convolution( + (conv): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)) + ) + (layer1): Convolution( + (conv): Conv2d(2, 4, kernel_size=(3, 3), stride=(1, 1)) + ) + ) + (mlp): MLP( + (flatten): Flatten(start_dim=1, end_dim=-1) + (hidden0): Sequential( + (linear): Linear(in_features=144, out_features=5, bias=True) + (adn): ADN( + (A): ELU(alpha=1.0) + ) + ) + (output): Sequential( + (linear): Linear(in_features=5, out_features=2, bias=True) + (output_act): Softmax(dim=None) + ) + ) + ) + + # a regressor + >>> CNN( + in_shape=(1, 10, 10), + num_outputs=2, + conv_args={"channels": [2, 4], "norm": None, "act": None}, + ) + CNN( + (convolutions): ConvEncoder( + (layer0): Convolution( + (conv): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)) + ) + (layer1): Convolution( + (conv): Conv2d(2, 4, kernel_size=(3, 3), stride=(1, 1)) + ) + ) + (mlp): MLP( + (flatten): Flatten(start_dim=1, end_dim=-1) + (output): Linear(in_features=144, out_features=2, bias=True) + ) + ) + """ + + def __init__( + self, + in_shape: Sequence[int], + num_outputs: int, + conv_args: Dict[str, Any], + mlp_args: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + check_conv_args(conv_args) + check_mlp_args(mlp_args) + self.in_shape = in_shape + self.num_outputs = num_outputs + + in_channels, *input_size = in_shape + spatial_dims = len(input_size) + + self.convolutions = ConvEncoder( + in_channels=in_channels, + spatial_dims=spatial_dims, + _input_size=tuple(input_size), + **conv_args, + ) + + n_channels = ( + conv_args["channels"][-1] if len(conv_args["channels"]) > 0 else in_shape[0] + ) + flatten_shape = int(np.prod(self.convolutions.final_size) * n_channels) + if mlp_args is None: + mlp_args 
= {"hidden_channels": []} + self.mlp = MLP( + in_channels=flatten_shape, + out_channels=num_outputs, + **mlp_args, + ) diff --git a/clinicadl/networks/nn/conv_decoder.py b/clinicadl/networks/nn/conv_decoder.py new file mode 100644 index 000000000..28c9be96f --- /dev/null +++ b/clinicadl/networks/nn/conv_decoder.py @@ -0,0 +1,388 @@ +from typing import Callable, Optional, Sequence, Tuple + +import torch.nn as nn +from monai.networks.blocks import Convolution +from monai.networks.layers.utils import get_act_layer +from monai.utils.misc import ensure_tuple + +from .layers.unpool import get_unpool_layer +from .layers.utils import ( + ActFunction, + ActivationParameters, + ConvNormalizationParameters, + ConvNormLayer, + ConvParameters, + NormLayer, + SingleLayerUnpoolingParameters, + UnpoolingLayer, + UnpoolingParameters, +) +from .utils import ( + calculate_convtranspose_out_shape, + calculate_unpool_out_shape, + check_adn_ordering, + check_norm_layer, + check_pool_indices, + ensure_list_of_tuples, +) + + +class ConvDecoder(nn.Sequential): + """ + Fully convolutional decoder network with transposed convolutions, unpooling, normalization, activation + and dropout layers. + + Parameters + ---------- + spatial_dims : int + number of spatial dimensions of the input image. + in_channels : int + number of channels in the input image. + channels : Sequence[int] + sequence of integers stating the output channels of each transposed convolution. Thus, this + parameter also controls the number of transposed convolutions. + kernel_size : ConvParameters (optional, default=3) + the kernel size of the transposed convolutions. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the kernel sizes for each layer. 
+ The length of the list must be equal to the number of transposed convolution layers (i.e. + `len(channels)`). + stride : ConvParameters (optional, default=1) + the stride of the transposed convolutions. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the strides for each layer. + The length of the list must be equal to the number of transposed convolution layers (i.e. + `len(channels)`). + padding : ConvParameters (optional, default=0) + the padding of the transposed convolutions. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the paddings for each layer. + The length of the list must be equal to the number of transposed convolution layers (i.e. + `len(channels)`). + output_padding : ConvParameters (optional, default=0) + the output padding of the transposed convolutions. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the output paddings for each layer. + The length of the list must be equal to the number of transposed convolution layers (i.e. + `len(channels)`). + dilation : ConvParameters (optional, default=1) + the dilation factor of the transposed convolutions. 
Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the dilations for each layer. + The length of the list must be equal to the number of transposed convolution layers (i.e. + `len(channels)`). + unpooling : Optional[UnpoolingParameters] (optional, default=(UnpoolingLayer.UPSAMPLE, {"scale_factor": 2})) + the unpooling mode and the arguments of the unpooling layer, passed as `(unpooling_mode, arguments)`. + If None, no unpooling will be performed in the network.\n + `unpooling_mode` can be either `upsample` or `convtranspose`. Please refer to PyTorch's [Upsample] + (https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html) or [ConvTranspose](https:// + pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html) to know the mandatory and optional + arguments.\n + If a list is passed, it will be understood as `(unpooling_mode, arguments)` for each unpooling layer.\n + Note: no need to pass `in_channels` and `out_channels` for `convtranspose` because the unpooling + layers are not intended to modify the number of channels. + unpooling_indices : Optional[Sequence[int]] (optional, default=None) + indices of the transposed convolution layers after which unpooling should be performed. + If None, no unpooling will be performed. An index equal to -1 will be understood as a pooling layer before + the first transposed convolution. + act : Optional[ActivationParameters] (optional, default=ActFunction.PRELU) + the activation function used after a transposed convolution layer, and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. 
If None, no activation will be used.\n + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them. + output_act : Optional[ActivationParameters] (optional, default=None) + a potential activation layer applied to the output of the network. Should be pass in the same way as `act`. + If None, no last activation will be applied. + norm : Optional[ConvNormalizationParameters] (optional, default=NormLayer.INSTANCE) + the normalization type used after a transposed convolution layer, and optionally the arguments of the normalization + layer. Should be passed as `norm_type` or `(norm_type, parameters)`. If None, no normalization will be + performed.\n + `norm_type` can be any value in {`batch`, `group`, `instance`, `syncbatch`}. Please refer to PyTorch's + [normalization layers](https://pytorch.org/docs/stable/nn.html#normalization-layers) to know the mandatory and + optional arguments for each of them.\n + Please note that arguments `num_channels`, `num_features` of the normalization layer + should not be passed, as they are automatically inferred from the output of the previous layer in the network. + dropout : Optional[float] (optional, default=None) + dropout ratio. If None, no dropout. + bias : bool (optional, default=True) + whether to have a bias term in transposed convolutions. + adn_ordering : str (optional, default="NDA") + order of operations `Activation`, `Dropout` and `Normalization` after a transposed convolutional layer (except the + last one).\n + For example if "ND" is passed, `Normalization` and then `Dropout` will be performed (without `Activation`).\n + Note: ADN will not be applied after the last convolution. 
+ + Examples + -------- + >>> ConvDecoder( + in_channels=16, + spatial_dims=2, + channels=[8, 4, 1], + kernel_size=(3, 5), + stride=2, + padding=[1, 0, 0], + output_padding=[0, 0, (1, 2)], + dilation=1, + unpooling=[("upsample", {"scale_factor": 2}), ("upsample", {"size": (32, 32)})], + unpooling_indices=[0, 1], + act="elu", + output_act="relu", + norm=("batch", {"eps": 1e-05}), + dropout=0.1, + bias=True, + adn_ordering="NDA", + ) + ConvDecoder( + (layer0): Convolution( + (conv): ConvTranspose2d(16, 8, kernel_size=(3, 5), stride=(2, 2), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (unpool0): Upsample(scale_factor=2.0, mode='nearest') + (layer1): Convolution( + (conv): ConvTranspose2d(8, 4, kernel_size=(3, 5), stride=(2, 2)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (unpool1): Upsample(size=(32, 32), mode='nearest') + (layer2): Convolution( + (conv): ConvTranspose2d(4, 1, kernel_size=(3, 5), stride=(2, 2), output_padding=(1, 2)) + ) + (output_act): ReLU() + ) + + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + channels: Sequence[int], + kernel_size: ConvParameters = 3, + stride: ConvParameters = 1, + padding: ConvParameters = 0, + output_padding: ConvParameters = 0, + dilation: ConvParameters = 1, + unpooling: Optional[UnpoolingParameters] = ( + UnpoolingLayer.UPSAMPLE, + {"scale_factor": 2}, + ), + unpooling_indices: Optional[Sequence[int]] = None, + act: Optional[ActivationParameters] = ActFunction.PRELU, + output_act: Optional[ActivationParameters] = None, + norm: Optional[ConvNormalizationParameters] = ConvNormLayer.INSTANCE, + dropout: Optional[float] = None, + bias: bool = True, + adn_ordering: str = "NDA", + _input_size: Optional[Sequence[int]] = None, + ) -> None: + 
super().__init__() + + self._current_size = _input_size if _input_size else None + + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.channels = ensure_tuple(channels) + self.n_layers = len(self.channels) + + self.kernel_size = ensure_list_of_tuples( + kernel_size, self.spatial_dims, self.n_layers, "kernel_size" + ) + self.stride = ensure_list_of_tuples( + stride, self.spatial_dims, self.n_layers, "stride" + ) + self.padding = ensure_list_of_tuples( + padding, self.spatial_dims, self.n_layers, "padding" + ) + self.output_padding = ensure_list_of_tuples( + output_padding, self.spatial_dims, self.n_layers, "output_padding" + ) + self.dilation = ensure_list_of_tuples( + dilation, self.spatial_dims, self.n_layers, "dilation" + ) + + self.unpooling_indices = check_pool_indices(unpooling_indices, self.n_layers) + self.unpooling = self._check_unpool_layers(unpooling) + self.act = act + self.norm = check_norm_layer(norm) + if self.norm == NormLayer.LAYER: + raise ValueError("Layer normalization not implemented in ConvDecoder.") + self.dropout = dropout + self.bias = bias + self.adn_ordering = check_adn_ordering(adn_ordering) + + n_unpoolings = 0 + if self.unpooling and -1 in self.unpooling_indices: + unpooling_layer = self._get_unpool_layer( + self.unpooling[n_unpoolings], n_channels=self.in_channels + ) + self.add_module("init_unpool", unpooling_layer) + n_unpoolings += 1 + + echannel = self.in_channels + for i, (c, k, s, p, o_p, d) in enumerate( + zip( + self.channels, + self.kernel_size, + self.stride, + self.padding, + self.output_padding, + self.dilation, + ) + ): + conv_layer = self._get_convtranspose_layer( + in_channels=echannel, + out_channels=c, + kernel_size=k, + stride=s, + padding=p, + output_padding=o_p, + dilation=d, + is_last=(i == len(channels) - 1), + ) + self.add_module(f"layer{i}", conv_layer) + echannel = c # use the output channel number as the input for the next loop + if self.unpooling and i in self.unpooling_indices: + 
unpooling_layer = self._get_unpool_layer( + self.unpooling[n_unpoolings], n_channels=c + ) + self.add_module(f"unpool{i}", unpooling_layer) + n_unpoolings += 1 + + self.output_act = get_act_layer(output_act) if output_act else None + + @property + def final_size(self): + """ + To know the size of an image at the end of the network. + """ + return self._current_size + + @final_size.setter + def final_size(self, fct: Callable[[Tuple[int, ...]], Tuple[int, ...]]): + """ + Takes as input the function used to update the current image size. + """ + if self._current_size is not None: + self._current_size = fct(self._current_size) + + def _get_convtranspose_layer( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, ...], + stride: Tuple[int, ...], + padding: Tuple[int, ...], + output_padding: Tuple[int, ...], + dilation: Tuple[int, ...], + is_last: bool, + ) -> Convolution: + """ + Gets the parametrized TransposedConvolution-ADN block and updates the current output size. + """ + self.final_size = lambda size: calculate_convtranspose_out_shape( + size, kernel_size, stride, padding, output_padding, dilation + ) + + return Convolution( + is_transposed=True, + conv_only=is_last, + spatial_dims=self.spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + strides=stride, + kernel_size=kernel_size, + padding=padding, + output_padding=output_padding, + dilation=dilation, + act=self.act, + norm=self.norm, + dropout=self.dropout, + bias=self.bias, + adn_ordering=self.adn_ordering, + ) + + def _get_unpool_layer( + self, unpooling: SingleLayerUnpoolingParameters, n_channels: int + ) -> nn.Module: + """ + Gets the parametrized unpooling layer and updates the current output size. 
+ """ + unpool_layer = get_unpool_layer( + unpooling, + spatial_dims=self.spatial_dims, + in_channels=n_channels, + out_channels=n_channels, + ) + self.final_size = lambda size: calculate_unpool_out_shape( + unpool_mode=unpooling[0], + in_shape=size, + **unpool_layer.__dict__, + ) + return unpool_layer + + @classmethod + def _check_single_unpool_layer( + cls, unpooling: SingleLayerUnpoolingParameters + ) -> SingleLayerUnpoolingParameters: + """ + Checks unpooling arguments for a single pooling layer. + """ + if not isinstance(unpooling, tuple) or len(unpooling) != 2: + raise ValueError( + "unpooling must be double (or a list of doubles) with first the type of unpooling and then the parameters of " + f"the unpooling layer in a dict. Got {unpooling}" + ) + _ = UnpoolingLayer(unpooling[0]) # check unpooling mode + args = unpooling[1] + if not isinstance(args, dict): + raise ValueError( + f"The arguments of the unpooling layer must be passed in a dict. Got {args}" + ) + + return unpooling + + def _check_unpool_layers( + self, unpooling: UnpoolingParameters + ) -> UnpoolingParameters: + """ + Checks argument unpooling. + """ + if unpooling is None: + return unpooling + if isinstance(unpooling, list): + for unpool_layer in unpooling: + self._check_single_unpool_layer(unpool_layer) + if len(unpooling) != len(self.unpooling_indices): + raise ValueError( + "If you pass a list for unpooling, the size of that list must match " + f"the size of unpooling_indices. Got: unpooling={unpooling} and " + f"unpooling_indices={self.unpooling_indices}" + ) + elif isinstance(unpooling, tuple): + self._check_single_unpool_layer(unpooling) + unpooling = (unpooling,) * len(self.unpooling_indices) + else: + raise ValueError( + f"unpooling can be either None, a double (string, dictionary) or a list of such doubles. 
Got {unpooling}" + ) + + return unpooling diff --git a/clinicadl/networks/nn/conv_encoder.py b/clinicadl/networks/nn/conv_encoder.py new file mode 100644 index 000000000..f3ec66484 --- /dev/null +++ b/clinicadl/networks/nn/conv_encoder.py @@ -0,0 +1,392 @@ +from typing import Callable, List, Optional, Sequence, Tuple + +import numpy as np +import torch.nn as nn +from monai.networks.blocks import Convolution +from monai.networks.layers.utils import get_act_layer, get_pool_layer +from monai.utils.misc import ensure_tuple + +from .layers.utils import ( + ActFunction, + ActivationParameters, + ConvNormalizationParameters, + ConvNormLayer, + ConvParameters, + NormLayer, + PoolingLayer, + PoolingParameters, + SingleLayerPoolingParameters, +) +from .utils import ( + calculate_conv_out_shape, + calculate_pool_out_shape, + check_adn_ordering, + check_norm_layer, + check_pool_indices, + ensure_list_of_tuples, +) + + +class ConvEncoder(nn.Sequential): + """ + Fully convolutional encoder network with convolutional, pooling, normalization, activation + and dropout layers. + + Parameters + ---------- + spatial_dims : int + number of spatial dimensions of the input image. + in_channels : int + number of channels in the input image. + channels : Sequence[int] + sequence of integers stating the output channels of each convolutional layer. Thus, this + parameter also controls the number of convolutional layers. + kernel_size : ConvParameters (optional, default=3) + the kernel size of the convolutional layers. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the kernel sizes for each layer. + The length of the list must be equal to the number of convolutional layers (i.e. `len(channels)`). 
+ stride : ConvParameters (optional, default=1) + the stride of the convolutional layers. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the strides for each layer. + The length of the list must be equal to the number of convolutional layers (i.e. `len(channels)`). + padding : ConvParameters (optional, default=0) + the padding of the convolutional layers. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the paddings for each layer. + The length of the list must be equal to the number of convolutional layers (i.e. `len(channels)`). + dilation : ConvParameters (optional, default=1) + the dilation factor of the convolutional layers. Can be an integer, a tuple or a list.\n + If integer, the value will be used for all layers and all dimensions.\n + If tuple (of integers), it will be interpreted as the values for each dimension. These values + will be used for all the layers.\n + If list (of tuples or integers), it will be interpreted as the dilations for each layer. + The length of the list must be equal to the number of convolutional layers (i.e. `len(channels)`). + pooling : Optional[PoolingParameters] (optional, default=(PoolingLayer.MAX, {"kernel_size": 2})) + the pooling mode and the arguments of the pooling layer, passed as `(pooling_mode, arguments)`. + If None, no pooling will be performed in the network.\n + `pooling_mode` can be either `max`, `avg`, `adaptivemax` or `adaptiveavg`. 
Please refer to PyTorch's [documentation] + (https://pytorch.org/docs/stable/nn.html#pooling-layers) to know the mandatory and optional arguments.\n + If a list is passed, it will be understood as `(pooling_mode, arguments)` for each pooling layer. + pooling_indices : Optional[Sequence[int]] (optional, default=None) + indices of the convolutional layers after which pooling should be performed. + If None, no pooling will be performed. An index equal to -1 will be understood as an unpooling layer before + the first convolution. + act : Optional[ActivationParameters] (optional, default=ActFunction.PRELU) + the activation function used after a convolutional layer, and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them. + output_act : Optional[ActivationParameters] (optional, default=None) + a potential activation layer applied to the output of the network. Should be pass in the same way as `act`. + If None, no last activation will be applied. + norm : Optional[ConvNormalizationParameters] (optional, default=NormLayer.INSTANCE) + the normalization type used after a convolutional layer, and optionally the arguments of the normalization + layer. Should be passed as `norm_type` or `(norm_type, parameters)`. If None, no normalization will be + performed.\n + `norm_type` can be any value in {`batch`, `group`, `instance`, `syncbatch`}. 
Please refer to PyTorch's + [normalization layers](https://pytorch.org/docs/stable/nn.html#normalization-layers) to know the mandatory and + optional arguments for each of them.\n + Please note that arguments `num_channels`, `num_features` of the normalization layer + should not be passed, as they are automatically inferred from the output of the previous layer in the network. + dropout : Optional[float] (optional, default=None) + dropout ratio. If None, no dropout. + bias : bool (optional, default=True) + whether to have a bias term in convolutions. + adn_ordering : str (optional, default="NDA") + order of operations `Activation`, `Dropout` and `Normalization` after a convolutional layer (except the last + one). + For example if "ND" is passed, `Normalization` and then `Dropout` will be performed (without `Activation`).\n + Note: ADN will not be applied after the last convolution. + + Examples + -------- + >>> ConvEncoder( + spatial_dims=2, + in_channels=1, + channels=[2, 4, 8], + kernel_size=(3, 5), + stride=1, + padding=[1, (0, 1), 0], + dilation=1, + pooling=[("max", {"kernel_size": 2}), ("avg", {"kernel_size": 2})], + pooling_indices=[0, 1], + act="elu", + output_act="relu", + norm=("batch", {"eps": 1e-05}), + dropout=0.1, + bias=True, + adn_ordering="NDA", + ) + ConvEncoder( + (layer0): Convolution( + (conv): Conv2d(1, 2, kernel_size=(3, 5), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (pool0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) + (layer1): Convolution( + (conv): Conv2d(2, 4, kernel_size=(3, 5), stride=(1, 1), padding=(0, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0) + (layer2): 
Convolution( + (conv): Conv2d(4, 8, kernel_size=(3, 5), stride=(1, 1)) + ) + (output_act): ReLU() + ) + + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + channels: Sequence[int], + kernel_size: ConvParameters = 3, + stride: ConvParameters = 1, + padding: ConvParameters = 0, + dilation: ConvParameters = 1, + pooling: Optional[PoolingParameters] = ( + PoolingLayer.MAX, + {"kernel_size": 2}, + ), + pooling_indices: Optional[Sequence[int]] = None, + act: Optional[ActivationParameters] = ActFunction.PRELU, + output_act: Optional[ActivationParameters] = None, + norm: Optional[ConvNormalizationParameters] = ConvNormLayer.INSTANCE, + dropout: Optional[float] = None, + bias: bool = True, + adn_ordering: str = "NDA", + _input_size: Optional[Sequence[int]] = None, + ) -> None: + super().__init__() + + self._current_size = _input_size if _input_size else None + self._size_details = [self._current_size] if _input_size else None + + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.channels = ensure_tuple(channels) + self.n_layers = len(self.channels) + + self.kernel_size = ensure_list_of_tuples( + kernel_size, self.spatial_dims, self.n_layers, "kernel_size" + ) + self.stride = ensure_list_of_tuples( + stride, self.spatial_dims, self.n_layers, "stride" + ) + self.padding = ensure_list_of_tuples( + padding, self.spatial_dims, self.n_layers, "padding" + ) + self.dilation = ensure_list_of_tuples( + dilation, self.spatial_dims, self.n_layers, "dilation" + ) + + self.pooling_indices = check_pool_indices(pooling_indices, self.n_layers) + self.pooling = self._check_pool_layers(pooling) + self.act = act + self.norm = check_norm_layer(norm) + if self.norm == NormLayer.LAYER: + raise ValueError("Layer normalization not implemented in ConvEncoder.") + self.dropout = dropout + self.bias = bias + self.adn_ordering = check_adn_ordering(adn_ordering) + + n_poolings = 0 + if self.pooling and -1 in self.pooling_indices: + pooling_layer = 
self._get_pool_layer(self.pooling[n_poolings]) + self.add_module("init_pool", pooling_layer) + n_poolings += 1 + + echannel = self.in_channels + for i, (c, k, s, p, d) in enumerate( + zip( + self.channels, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + ) + ): + conv_layer = self._get_conv_layer( + in_channels=echannel, + out_channels=c, + kernel_size=k, + stride=s, + padding=p, + dilation=d, + is_last=(i == len(channels) - 1), + ) + self.add_module(f"layer{i}", conv_layer) + echannel = c # use the output channel number as the input for the next loop + if self.pooling and i in self.pooling_indices: + pooling_layer = self._get_pool_layer(self.pooling[n_poolings]) + self.add_module(f"pool{i}", pooling_layer) + n_poolings += 1 + + self.output_act = get_act_layer(output_act) if output_act else None + + @property + def final_size(self): + """ + To know the size of an image at the end of the network. + """ + return self._current_size + + @property + def size_details(self): + """ + To know the sizes of intermediate images. + """ + return self._size_details + + @final_size.setter + def final_size(self, fct: Callable[[Tuple[int, ...]], Tuple[int, ...]]): + """ + Takes as input the function used to update the current image size. + """ + if self._current_size is not None: + self._current_size = fct(self._current_size) + self._size_details.append(self._current_size) + self._check_size() + + def _get_conv_layer( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, ...], + stride: Tuple[int, ...], + padding: Tuple[int, ...], + dilation: Tuple[int, ...], + is_last: bool, + ) -> Convolution: + """ + Gets the parametrized Convolution-ADN block and updates the current output size. 
+ """ + self.final_size = lambda size: calculate_conv_out_shape( + size, kernel_size, stride, padding, dilation + ) + + return Convolution( + conv_only=is_last, + spatial_dims=self.spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + strides=stride, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + act=self.act, + norm=self.norm, + dropout=self.dropout, + bias=self.bias, + adn_ordering=self.adn_ordering, + ) + + def _get_pool_layer(self, pooling: SingleLayerPoolingParameters) -> nn.Module: + """ + Gets the parametrized pooling layer and updates the current output size. + """ + pool_layer = get_pool_layer(pooling, spatial_dims=self.spatial_dims) + old_size = self.final_size + self.final_size = lambda size: calculate_pool_out_shape( + pool_mode=pooling[0], in_shape=size, **pool_layer.__dict__ + ) + + if ( + self.final_size is not None + and (np.array(old_size) < np.array(self.final_size)).any() + ): + raise ValueError( + f"You passed {pooling} as a pooling layer. But before this layer, the size of the image " + f"was {old_size}. So, pooling can't be performed." + ) + + return pool_layer + + def _check_size(self) -> None: + """ + Checks that image size never reaches 0. + """ + if self._current_size is not None and (np.array(self._current_size) <= 0).any(): + raise ValueError( + f"Failed to build the network. An image of size 0 or less has been reached. Stopped at:\n {self}" + ) + + @classmethod + def _check_single_pool_layer( + cls, pooling: SingleLayerPoolingParameters + ) -> SingleLayerPoolingParameters: + """ + Checks pooling arguments for a single pooling layer. + """ + if not isinstance(pooling, tuple) or len(pooling) != 2: + raise ValueError( + "pooling must be a double (or a list of doubles) with first the type of pooling and then the parameters " + f"of the pooling layer in a dict. 
Got {pooling}" + ) + pooling_type = PoolingLayer(pooling[0]) + args = pooling[1] + if not isinstance(args, dict): + raise ValueError( + f"The arguments of the pooling layer must be passed in a dict. Got {args}" + ) + if ( + pooling_type == PoolingLayer.MAX or pooling_type == PoolingLayer.AVG + ) and "kernel_size" not in args: + raise ValueError( + f"For {pooling_type} pooling mode, `kernel_size` argument must be passed. " + f"Got {args}" + ) + elif ( + pooling_type == PoolingLayer.ADAPT_AVG + or pooling_type == PoolingLayer.ADAPT_MAX + ) and "output_size" not in args: + raise ValueError( + f"For {pooling_type} pooling mode, `output_size` argument must be passed. " + f"Got {args}" + ) + + def _check_pool_layers( + self, pooling: PoolingParameters + ) -> List[SingleLayerPoolingParameters]: + """ + Check argument pooling. + """ + if pooling is None: + return pooling + if isinstance(pooling, list): + for pool_layer in pooling: + self._check_single_pool_layer(pool_layer) + if len(pooling) != len(self.pooling_indices): + raise ValueError( + "If you pass a list for pooling, the size of that list must match " + f"the size of pooling_indices. Got: pooling={pooling} and " + f"pooling_indices={self.pooling_indices}" + ) + elif isinstance(pooling, tuple): + self._check_single_pool_layer(pooling) + pooling = [pooling] * len(self.pooling_indices) + else: + raise ValueError( + f"pooling can be either None, a double (string, dictionary) or a list of such doubles. 
class DenseNet(nn.Sequential):
    """
    DenseNet based on the [Densely Connected Convolutional Networks](https://arxiv.org/pdf/1608.06993) paper.
    Adapted from [MONAI's implementation](https://docs.monai.io/en/stable/networks.html#densenet).

    The user can customize the number of dense blocks, the number of dense layers in each block, as well as
    other parameters like the growth rate.

    DenseNet is a fully convolutional network that can work with input of any size, provided that it is large
    enough not to be reduced to a 1-pixel image (before the adaptive average pooling).

    Parameters
    ----------
    spatial_dims : int
        number of spatial dimensions of the input image.
    in_channels : int
        number of channels in the input image.
    num_outputs : Optional[int]
        number of output variables after the last linear layer.\n
        If None, the features before the last fully connected layer will be returned.
    n_dense_layers : Sequence[int] (optional, default=(6, 12, 24, 16))
        number of dense layers in each dense block. Thus, this parameter also defines the number of dense blocks.
        Default is set to DenseNet-121 parameter.
    init_features : int (optional, default=64)
        number of feature maps after the initial convolution. Default is set to 64, as in the original paper.
    growth_rate : int (optional, default=32)
        how many feature maps to add at each dense layer. Default is set to 32, as in the original paper.
    bottleneck_factor : int (optional, default=4)
        multiplicative factor for bottleneck layers (1x1 convolutions). The output of these bottleneck layers will
        have `bottleneck_factor * growth_rate` feature maps. Default is 4, as in the original paper.
    act : ActivationParameters (optional, default=("relu", {"inplace": True}))
        the activation function used in the convolutional part, and optionally its arguments.
        Should be passed as `activation_name` or `(activation_name, arguments)`.
        `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`,
        `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activation functions]
        (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional
        arguments for each of them.\n
        Default is "relu", as in the original paper.
    output_act : Optional[ActivationParameters] (optional, default=None)
        if `num_outputs` is not None, a potential activation layer applied to the outputs of the network.
        Should be passed in the same way as `act`.
        If None, no last activation will be applied.
    dropout : Optional[float] (optional, default=None)
        dropout ratio. If None, no dropout.

    Examples
    --------
    >>> DenseNet(spatial_dims=2, in_channels=1, num_outputs=2, output_act="softmax", n_dense_layers=(2, 2))
    DenseNet(
        (features): Sequential(
            (conv0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (act0): ReLU(inplace=True)
            (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
            (denseblock1): _DenseBlock(
                (denselayer1): _DenseLayer(
                    (layers): Sequential(
                        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                        (act1): ReLU(inplace=True)
                        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                        (act2): ReLU(inplace=True)
                        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                    )
                )
                (denselayer2): _DenseLayer(
                    (layers): Sequential(
                        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                        (act1): ReLU(inplace=True)
                        (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                        (act2): ReLU(inplace=True)
                        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                    )
                )
            )
            (transition1): _Transition(
                (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (act): ReLU(inplace=True)
                (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
            )
            (denseblock2): _DenseBlock(
                (denselayer1): _DenseLayer(
                    (layers): Sequential(
                        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                        (act1): ReLU(inplace=True)
                        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                        (act2): ReLU(inplace=True)
                        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                    )
                )
                (denselayer2): _DenseLayer(
                    (layers): Sequential(
                        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                        (act1): ReLU(inplace=True)
                        (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                        (act2): ReLU(inplace=True)
                        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                    )
                )
            )
            (norm5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (fc): Sequential(
            (act): ReLU(inplace=True)
            (pool): AdaptiveAvgPool2d(output_size=1)
            (flatten): Flatten(start_dim=1, end_dim=-1)
            (out): Linear(in_features=128, out_features=2, bias=True)
            (output_act): Softmax(dim=None)
        )
    )
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        num_outputs: Optional[int],
        n_dense_layers: Sequence[int] = (6, 12, 24, 16),
        init_features: int = 64,
        growth_rate: int = 32,
        bottleneck_factor: int = 4,
        act: ActivationParameters = ("relu", {"inplace": True}),
        output_act: Optional[ActivationParameters] = None,
        dropout: Optional[float] = None,
    ) -> None:
        super().__init__()
        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        self.num_outputs = num_outputs
        self.n_dense_layers = n_dense_layers
        self.init_features = init_features
        self.growth_rate = growth_rate
        self.bottleneck_factor = bottleneck_factor
        self.act = act
        self.dropout = dropout

        # Delegate the actual construction to MONAI's DenseNet; `out_channels`
        # must be a valid int, so 1 is used as a dummy when num_outputs is None
        # (the classification head is discarded below in that case).
        base_densenet = BaseDenseNet(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=num_outputs if num_outputs else 1,
            init_features=init_features,
            growth_rate=growth_rate,
            block_config=n_dense_layers,
            bn_size=bottleneck_factor,
            act=act,
            dropout_prob=dropout if dropout else 0.0,
        )
        self.features = base_densenet.features
        # With num_outputs=None the network returns features (no fc head).
        self.fc = base_densenet.class_layers if num_outputs else None
        if self.fc:
            self.fc.output_act = get_act_layer(output_act) if output_act else None

        # Harmonize sub-module names: MONAI names activations 'relu*'.
        self._rename_act(self)

    @classmethod
    def _rename_act(cls, module: nn.Module) -> None:
        """
        Rename activation layers from 'relu' to 'act'.
        """
        # Recursively walk the module tree; when a child is named '*relu*',
        # rebuild the parent's _modules dict with the renamed key (order kept).
        for name, layer in list(module.named_children()):
            if "relu" in name:
                module._modules = OrderedDict(  # pylint: disable=protected-access
                    [
                        (key.replace("relu", "act"), sub_m)
                        for key, sub_m in module._modules.items()  # pylint: disable=protected-access
                    ]
                )
            else:
                cls._rename_act(layer)
class SOTADenseNet(str, Enum):
    """Supported DenseNet networks."""

    # Values are the user-facing names accepted by `get_densenet`
    # (it converts the input with `SOTADenseNet(name)`).
    DENSENET_121 = "DenseNet-121"
    DENSENET_161 = "DenseNet-161"
    DENSENET_169 = "DenseNet-169"
    DENSENET_201 = "DenseNet-201"
def get_densenet(
    name: Union[str, SOTADenseNet],
    num_outputs: Optional[int],
    output_act: ActivationParameters = None,
    pretrained: bool = False,
) -> DenseNet:
    """
    To get a DenseNet implemented in the [Densely Connected Convolutional Networks](https://arxiv.org/pdf/1608.06993)
    paper.

    Only the last fully connected layer will be changed to match `num_outputs`.

    The user can also use the pretrained models from `torchvision`. Note that the last fully connected layer will not
    use pretrained weights, as it is task specific.

    .. warning:: `DenseNet-121`, `DenseNet-161`, `DenseNet-169` and `DenseNet-201` only work with 2D images with 3 channels.

    Notes: `torchvision` does not provide an implementation for `DenseNet-264` but provides a `DenseNet-161` that is not
    mentioned in the paper.

    Parameters
    ----------
    name : Union[str, SOTADenseNet]
        The name of the DenseNet. Available networks are `DenseNet-121`, `DenseNet-161`, `DenseNet-169` and `DenseNet-201`.
    num_outputs : Optional[int]
        number of output variables after the last linear layer.\n
        If None, the features before the last fully connected layer will be returned.
    output_act : ActivationParameters (optional, default=None)
        if `num_outputs` is not None, a potential activation layer applied to the outputs of the network,
        and optionally its arguments.
        Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n
        `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`,
        `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activation functions]
        (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional
        arguments for each of them.
    pretrained : bool (optional, default=False)
        whether to use pretrained weights. The pretrained weights used are the default ones from [torchvision](https://
        pytorch.org/vision/main/models/densenet.html).

    Returns
    -------
    DenseNet
        The network, with potentially pretrained weights.

    Raises
    ------
    ValueError
        if `name` is not one of the supported networks (raised by `SOTADenseNet(name)`).
    """
    name = SOTADenseNet(name)
    # (n_dense_layers, growth_rate, init_features, weight url) per network.
    # A total mapping replaces the old if/elif chain, which could leave the
    # config variables unbound and needed a pylint suppression.
    configs = {
        SOTADenseNet.DENSENET_121: (
            (6, 12, 24, 16), 32, 64, DenseNet121_Weights.DEFAULT.url,
        ),
        SOTADenseNet.DENSENET_161: (
            (6, 12, 36, 24), 48, 96, DenseNet161_Weights.DEFAULT.url,
        ),
        SOTADenseNet.DENSENET_169: (
            (6, 12, 32, 32), 32, 64, DenseNet169_Weights.DEFAULT.url,
        ),
        SOTADenseNet.DENSENET_201: (
            (6, 12, 48, 32), 32, 64, DenseNet201_Weights.DEFAULT.url,
        ),
    }
    n_dense_layers, growth_rate, init_features, model_url = configs[name]

    densenet = DenseNet(
        spatial_dims=2,
        in_channels=3,
        num_outputs=num_outputs,
        n_dense_layers=n_dense_layers,
        growth_rate=growth_rate,
        init_features=init_features,
        output_act=output_act,
    )
    if not pretrained:
        return densenet

    pretrained_dict = load_state_dict_from_url(model_url, progress=True)
    # Drop the classifier weights (task-specific) and strip the 'features.'
    # prefix so keys match `densenet.features`.
    features_state_dict = {
        k.replace("features.", ""): v
        for k, v in pretrained_dict.items()
        if "classifier" not in k
    }
    densenet.features.load_state_dict(_state_dict_adapter(features_state_dict))

    return densenet
+ """ + pattern = re.compile( + r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + ) + + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + new_key = re.sub(r"^(.*denselayer\d+)\.", r"\1.layers.", new_key) + state_dict[new_key] = state_dict[key] + del state_dict[key] + + return state_dict diff --git a/clinicadl/networks/nn/generator.py b/clinicadl/networks/nn/generator.py new file mode 100644 index 000000000..5f68a2e58 --- /dev/null +++ b/clinicadl/networks/nn/generator.py @@ -0,0 +1,131 @@ +from typing import Any, Dict, Optional, Sequence + +import numpy as np +import torch.nn as nn +from monai.networks.layers.simplelayers import Reshape + +from .conv_decoder import ConvDecoder +from .mlp import MLP +from .utils import check_conv_args, check_mlp_args + + +class Generator(nn.Sequential): + """ + A generator with first fully connected layers and then convolutional layers. + + This network is a simple aggregation of a Multi Layer Perceptron (:py:class: + `clinicadl.monai_networks.nn.mlp.MLP`) and a Fully Convolutional Network + (:py:class:`clinicadl.monai_networks.nn.conv_decoder.ConvDecoder`). + + Parameters + ---------- + latent_size : int + size of the latent vector. + start_shape : Sequence[int] + sequence of integers stating the initial shape of the image, i.e. the shape at the + beginning of the convolutional part (minus batch dimension, but including the number + of channels).\n + Thus, `start_shape` determines the dimension of the output of the generator (the exact + shape depends on the convolutional part and can be accessed via the class attribute + `output_shape`). + conv_args : Dict[str, Any] + the arguments for the convolutional part. The arguments are those accepted by + :py:class:`clinicadl.monai_networks.nn.conv_decoder.ConvDecoder`, except `in_shape` that + is specified here via `start_shape`. So, the only mandatory argument is `channels`. 
class Generator(nn.Sequential):
    """
    A generator with first fully connected layers and then convolutional layers.

    This network is a simple aggregation of a Multi Layer Perceptron (:py:class:
    `clinicadl.monai_networks.nn.mlp.MLP`) and a Fully Convolutional Network
    (:py:class:`clinicadl.monai_networks.nn.conv_decoder.ConvDecoder`).

    Parameters
    ----------
    latent_size : int
        size of the latent vector.
    start_shape : Sequence[int]
        sequence of integers stating the initial shape of the image, i.e. the shape at the
        beginning of the convolutional part (minus batch dimension, but including the number
        of channels).\n
        Thus, `start_shape` determines the dimension of the output of the generator (the exact
        shape depends on the convolutional part and can be accessed via the class attribute
        `output_shape`).
    conv_args : Dict[str, Any]
        the arguments for the convolutional part. The arguments are those accepted by
        :py:class:`clinicadl.monai_networks.nn.conv_decoder.ConvDecoder`, except `in_shape` that
        is specified here via `start_shape`. So, the only mandatory argument is `channels`.
    mlp_args : Optional[Dict[str, Any]] (optional, default=None)
        the arguments for the MLP part. The arguments are those accepted by
        :py:class:`clinicadl.monai_networks.nn.mlp.MLP`, except `in_channels` that is specified
        here via `latent_size`, and `out_channels` that is inferred from `start_shape`.
        So, the only mandatory argument is `hidden_channels`.\n
        If None, the MLP part will be reduced to a single linear layer.

    Examples
    --------
    >>> Generator(
            latent_size=8,
            start_shape=(8, 2, 2),
            conv_args={"channels": [4, 2], "norm": None, "act": None},
            mlp_args={"hidden_channels": [16], "act": "elu", "norm": None},
        )
    Generator(
        (mlp): MLP(
            (flatten): Flatten(start_dim=1, end_dim=-1)
            (hidden0): Sequential(
                (linear): Linear(in_features=8, out_features=16, bias=True)
                (adn): ADN(
                    (A): ELU(alpha=1.0)
                )
            )
            (output): Linear(in_features=16, out_features=32, bias=True)
        )
        (reshape): Reshape()
        (convolutions): ConvDecoder(
            (layer0): Convolution(
                (conv): ConvTranspose2d(8, 4, kernel_size=(3, 3), stride=(1, 1))
            )
            (layer1): Convolution(
                (conv): ConvTranspose2d(4, 2, kernel_size=(3, 3), stride=(1, 1))
            )
        )
    )

    >>> Generator(
            latent_size=8,
            start_shape=(8, 2, 2),
            conv_args={"channels": [4, 2], "norm": None, "act": None, "output_act": "relu"},
        )
    Generator(
        (mlp): MLP(
            (flatten): Flatten(start_dim=1, end_dim=-1)
            (output): Linear(in_features=8, out_features=32, bias=True)
        )
        (reshape): Reshape()
        (convolutions): ConvDecoder(
            (layer0): Convolution(
                (conv): ConvTranspose2d(8, 4, kernel_size=(3, 3), stride=(1, 1))
            )
            (layer1): Convolution(
                (conv): ConvTranspose2d(4, 2, kernel_size=(3, 3), stride=(1, 1))
            )
            (output_act): ReLU()
        )
    )
    """

    def __init__(
        self,
        latent_size: int,
        start_shape: Sequence[int],
        conv_args: Dict[str, Any],
        mlp_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        check_conv_args(conv_args)
        check_mlp_args(mlp_args)
        self.latent_size = latent_size
        self.start_shape = start_shape

        # NOTE: this class is an nn.Sequential, so children run in the order
        # they are registered below: mlp -> reshape -> convolutions.
        flatten_shape = int(np.prod(start_shape))
        if mlp_args is None:
            # no hidden layers: the MLP degenerates to a single linear layer
            mlp_args = {"hidden_channels": []}
        self.mlp = MLP(
            in_channels=latent_size,
            out_channels=flatten_shape,
            **mlp_args,
        )

        self.reshape = Reshape(*start_shape)
        # start_shape = (channels, *spatial_size)
        inter_channels, *inter_size = start_shape
        self.convolutions = ConvDecoder(
            in_channels=inter_channels,
            spatial_dims=len(inter_size),
            _input_size=inter_size,
            **conv_args,
        )

        # Output channels: last conv channel count, or the start channels if
        # the conv part has no layer.
        n_channels = (
            conv_args["channels"][-1]
            if len(conv_args["channels"]) > 0
            else start_shape[0]
        )
        self.output_shape = (n_channels, *self.convolutions.final_size)
class ResNetBlock(nn.Module):
    """
    ResNet basic block: two 3x3 convolutions with a residual (identity or
    downsampled) connection.

    Adapted from MONAI's implementation:
    https://github.com/Project-MONAI/MONAI/blob/59a7211070538586369afd4a01eca0a7fe2e742e/
    monai/networks/nets/resnet.py#L71
    """

    # Output channels = planes * expansion (no widening in the basic block).
    expansion = 1

    def __init__(
        self,
        in_planes: int,
        planes: int,
        spatial_dims: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        act: ActivationParameters = ("relu", {"inplace": True}),
    ) -> None:
        super().__init__()

        conv_factory: Callable = Conv[Conv.CONV, spatial_dims]
        norm_factory: Callable = Norm[Norm.BATCH, spatial_dims]

        # First convolution carries the (optional) spatial downsampling.
        self.conv1 = conv_factory(  # pylint: disable=not-callable
            in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False
        )
        self.norm1 = norm_factory(planes)  # pylint: disable=not-callable
        self.act1 = get_act_layer(name=act)
        # Second convolution keeps the spatial size.
        self.conv2 = conv_factory(  # pylint: disable=not-callable
            planes, planes, kernel_size=3, padding=1, bias=False
        )
        self.norm2 = norm_factory(planes)  # pylint: disable=not-callable
        self.downsample = downsample
        self.act2 = get_act_layer(name=act)
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shortcut branch: identity, or a projection when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        hidden: torch.Tensor = self.act1(self.norm1(self.conv1(x)))
        hidden = self.norm2(self.conv2(hidden))

        hidden += shortcut
        return self.act2(hidden)
class ResNetBottleneck(nn.Module):
    """
    ResNet bottleneck block. Adapted from MONAI's implementation:
    https://github.com/Project-MONAI/MONAI/blob/59a7211070538586369afd4a01eca0a7fe2e742e/
    monai/networks/nets/resnet.py#L124
    """

    # Output channels = planes * expansion (the 1x1 expansion convolution).
    expansion = 4

    def __init__(
        self,
        in_planes: int,
        planes: int,
        spatial_dims: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        act: ActivationParameters = ("relu", {"inplace": True}),
    ) -> None:
        super().__init__()

        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]

        # 1x1 reduction convolution.
        self.conv1 = conv_type(in_planes, planes, kernel_size=1, bias=False)  # pylint: disable=not-callable
        self.norm1 = norm_type(planes)  # pylint: disable=not-callable
        self.act1 = get_act_layer(name=act)
        # 3x3 convolution, carries the (optional) spatial downsampling.
        self.conv2 = conv_type(  # pylint: disable=not-callable
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.norm2 = norm_type(planes)  # pylint: disable=not-callable
        self.act2 = get_act_layer(name=act)
        # 1x1 expansion convolution back to planes * expansion channels.
        self.conv3 = conv_type(  # pylint: disable=not-callable
            planes, planes * self.expansion, kernel_size=1, bias=False
        )
        self.norm3 = norm_type(planes * self.expansion)  # pylint: disable=not-callable
        self.downsample = downsample
        self.act3 = get_act_layer(name=act)
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x

        out: torch.Tensor = self.conv1(x)
        out = self.norm1(out)
        out = self.act1(out)

        out = self.conv2(out)
        out = self.norm2(out)
        out = self.act2(out)

        out = self.conv3(out)
        out = self.norm3(out)

        # Projection shortcut when input/output shapes differ.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.act3(out)

        return out
class SEResNetBlock(nn.Module):
    """
    ResNet basic block with a squeeze-and-excitation layer.
    Adapted from MONAI's ResNetBlock:
    https://github.com/Project-MONAI/MONAI/blob/59a7211070538586369afd4a01eca0a7fe2e742e/
    monai/networks/nets/resnet.py#L71
    """

    # Output channels = planes * expansion.
    expansion = 1
    # Channel reduction ratio of the squeeze-and-excitation layer.
    reduction = 16

    def __init__(
        self,
        in_planes: int,
        planes: int,
        spatial_dims: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        act: ActivationParameters = ("relu", {"inplace": True}),
    ) -> None:
        super().__init__()

        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]

        # First convolution carries the (optional) spatial downsampling.
        self.conv1 = conv_type(  # pylint: disable=not-callable
            in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False
        )
        self.norm1 = norm_type(planes)  # pylint: disable=not-callable
        self.act1 = get_act_layer(name=act)
        self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False)  # pylint: disable=not-callable
        self.norm2 = norm_type(planes)  # pylint: disable=not-callable
        # Channel-wise recalibration applied before the residual addition.
        self.se_layer = ChannelSELayer(
            spatial_dims=spatial_dims,
            in_channels=planes,
            r=self.reduction,
            acti_type_1=("relu", {"inplace": True}),
            acti_type_2="sigmoid",
        )
        self.downsample = downsample
        self.act2 = get_act_layer(name=act)
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act1(out)

        out = self.conv2(out)
        out = self.norm2(out)

        # Projection shortcut when input/output shapes differ.
        if self.downsample is not None:
            residual = self.downsample(x)

        out = self.se_layer(out)
        out += residual
        out = self.act2(out)

        return out
class SEResNetBottleneck(nn.Module):
    """
    ResNet bottleneck block with a squeeze-and-excitation layer.
    Adapted from MONAI's ResNetBottleneck:
    https://github.com/Project-MONAI/MONAI/blob/59a7211070538586369afd4a01eca0a7fe2e742e/
    monai/networks/nets/resnet.py#L124
    """

    # Output channels = planes * expansion (the 1x1 expansion convolution).
    expansion = 4
    # Channel reduction ratio of the squeeze-and-excitation layer.
    reduction = 16

    def __init__(
        self,
        in_planes: int,
        planes: int,
        spatial_dims: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        act: ActivationParameters = ("relu", {"inplace": True}),
    ) -> None:
        super().__init__()

        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]

        # 1x1 reduction convolution.
        self.conv1 = conv_type(in_planes, planes, kernel_size=1, bias=False)  # pylint: disable=not-callable
        self.norm1 = norm_type(planes)  # pylint: disable=not-callable
        self.act1 = get_act_layer(name=act)
        # 3x3 convolution, carries the (optional) spatial downsampling.
        self.conv2 = conv_type(  # pylint: disable=not-callable
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.norm2 = norm_type(planes)  # pylint: disable=not-callable
        self.act2 = get_act_layer(name=act)
        # 1x1 expansion convolution back to planes * expansion channels.
        self.conv3 = conv_type(  # pylint: disable=not-callable
            planes, planes * self.expansion, kernel_size=1, bias=False
        )
        self.norm3 = norm_type(planes * self.expansion)  # pylint: disable=not-callable
        # Channel-wise recalibration applied before the residual addition.
        self.se_layer = ChannelSELayer(
            spatial_dims=spatial_dims,
            in_channels=planes * self.expansion,
            r=self.reduction,
            acti_type_1=("relu", {"inplace": True}),
            acti_type_2="sigmoid",
        )
        self.downsample = downsample
        self.act3 = get_act_layer(name=act)
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act1(out)

        out = self.conv2(out)
        out = self.norm2(out)
        out = self.act2(out)

        out = self.conv3(out)
        out = self.norm3(out)

        # Projection shortcut when input/output shapes differ.
        if self.downsample is not None:
            residual = self.downsample(x)

        out = self.se_layer(out)
        out += residual
        out = self.act3(out)

        return out
class ConvBlock(nn.Sequential):
    """
    UNet double convolution block: two successive `Convolution` modules
    (conv -> norm -> dropout -> activation, per `adn_ordering="NDA"`),
    both with batch normalization.

    Parameters
    ----------
    spatial_dims : int
        number of spatial dimensions of the input.
    in_channels : int
        number of channels of the input.
    out_channels : int
        number of channels after the first convolution (kept by the second).
    act : ActivationParameters (optional, default=ActFunction.RELU)
        the activation function used after each convolution.
    dropout : Optional[float] (optional, default=None)
        dropout ratio. If None, no dropout.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        act: ActivationParameters = ActFunction.RELU,
        dropout: Optional[float] = None,
    ):
        super().__init__()
        self.add_module(
            "0",
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                strides=1,
                padding=None,
                adn_ordering="NDA",
                act=act,
                norm=NormLayer.BATCH,
                dropout=dropout,
            ),
        )
        self.add_module(
            "1",
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                strides=1,
                padding=None,
                adn_ordering="NDA",
                act=act,
                norm=NormLayer.BATCH,
                dropout=dropout,
            ),
        )
class DownBlock(nn.Sequential):
    """
    UNet down block with first max pooling and then two convolutions.

    Parameters
    ----------
    spatial_dims : int
        number of spatial dimensions of the input.
    in_channels : int
        number of channels of the input.
    out_channels : int
        number of channels of the output.
    act : ActivationParameters (optional, default=ActFunction.RELU)
        the activation function used after each convolution.
    dropout : Optional[float] (optional, default=None)
        dropout ratio. If None, no dropout.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        act: ActivationParameters = ActFunction.RELU,
        dropout: Optional[float] = None,
    ):
        super().__init__()
        # nn.Sequential runs children in registration order:
        # 2x max pooling first, then the double convolution.
        self.pool = get_pool_layer(("max", {"kernel_size": 2}), spatial_dims)
        self.doubleconv = ConvBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            act=act,
            dropout=dropout,
        )
def get_unpool_layer(
    name: Union[UnpoolingLayer, Tuple[UnpoolingLayer, Dict[str, Any]]],
    spatial_dims: int,
    in_channels: Optional[int] = None,
    out_channels: Optional[int] = None,
) -> nn.Module:
    """
    Creates an unpooling layer instance.

    Parameters
    ----------
    name : Union[UnpoolingLayer, Tuple[UnpoolingLayer, Dict[str, Any]]]
        the unpooling type, potentially with arguments in a dict.
    spatial_dims : int
        number of spatial dimensions of the input.
    in_channels : Optional[int] (optional, default=None)
        number of input channels if the unpool layer requires this parameter.
    out_channels : Optional[int] (optional, default=None)
        number of output channels if the unpool layer requires this parameter.

    Returns
    -------
    nn.Module
        the parametrized unpooling layer.
    """
    # (Fixed: the docstring previously contained a duplicated, truncated
    # Parameters/Returns section.)
    unpool_name, unpool_args = split_args(name)
    unpool_name = UnpoolingLayer(unpool_name)
    unpool_type = Unpool[unpool_name, spatial_dims]
    kw_args = dict(unpool_args)
    # Fill in the channel arguments only when the concrete layer type accepts
    # them and the caller has not already provided them explicitly.
    if has_option(unpool_type, "in_channels") and "in_channels" not in kw_args:
        kw_args["in_channels"] = in_channels
    if has_option(unpool_type, "out_channels") and "out_channels" not in kw_args:
        kw_args["out_channels"] = out_channels

    return unpool_type(**kw_args)  # pylint: disable=not-callable
class UnpoolingLayer(CaseInsensitiveEnum):
    """Supported unpooling layers in ClinicaDL."""

    CONV_TRANS = "convtranspose"
    UPSAMPLE = "upsample"


class ActFunction(CaseInsensitiveEnum):
    """Supported activation functions in ClinicaDL."""

    ELU = "elu"
    RELU = "relu"
    LEAKY_RELU = "leakyrelu"
    PRELU = "prelu"
    RELU6 = "relu6"
    SELU = "selu"
    CELU = "celu"
    GELU = "gelu"
    SIGMOID = "sigmoid"
    TANH = "tanh"
    SOFTMAX = "softmax"
    LOGSOFTMAX = "logsoftmax"
    MISH = "mish"


class PoolingLayer(CaseInsensitiveEnum):
    """Supported pooling layers in ClinicaDL."""

    MAX = "max"
    AVG = "avg"
    ADAPT_AVG = "adaptiveavg"
    ADAPT_MAX = "adaptivemax"


class NormLayer(CaseInsensitiveEnum):
    """Supported normalization layers in ClinicaDL."""

    GROUP = "group"
    LAYER = "layer"
    SYNCBATCH = "syncbatch"
    BATCH = "batch"
    INSTANCE = "instance"


class ConvNormLayer(CaseInsensitiveEnum):
    """Supported normalization layers with convolutions in ClinicaDL."""

    # Same as NormLayer minus LAYER (layer norm is not offered here).
    GROUP = "group"
    SYNCBATCH = "syncbatch"
    BATCH = "batch"
    INSTANCE = "instance"


class UnpoolingMode(CaseInsensitiveEnum):
    """Supported unpooling mode for AutoEncoders in ClinicaDL."""

    # Interpolation modes, plus transposed convolution.
    NEAREST = "nearest"
    LINEAR = "linear"
    BILINEAR = "bilinear"
    BICUBIC = "bicubic"
    TRILINEAR = "trilinear"
    CONV_TRANS = "convtranspose"
# Conv parameter for one layer: a single int (shared across spatial dims)
# or a tuple with one value per spatial dim.
SingleLayerConvParameter = Union[int, Tuple[int, ...]]
# One parameter set shared by all conv layers, or a list with one per layer.
ConvParameters = Union[SingleLayerConvParameter, List[SingleLayerConvParameter]]

# Pooling layer: (type, arguments) pair; one shared pair or one per layer.
PoolingType = Union[str, PoolingLayer]
SingleLayerPoolingParameters = Tuple[PoolingType, Dict[str, Any]]
PoolingParameters = Union[
    SingleLayerPoolingParameters, List[SingleLayerPoolingParameters]
]

# Unpooling layer: (type, arguments) pair; one shared pair or one per layer.
UnpoolingType = Union[str, UnpoolingLayer]
SingleLayerUnpoolingParameters = Tuple[UnpoolingType, Dict[str, Any]]
UnpoolingParameters = Union[
    SingleLayerUnpoolingParameters, List[SingleLayerUnpoolingParameters]
]

# Normalization layer: type alone, or (type, arguments).
NormalizationType = Union[str, NormLayer]
NormalizationParameters = Union[
    NormalizationType, Tuple[NormalizationType, Dict[str, Any]]
]

# Normalization restricted to types usable with convolutions.
ConvNormalizationType = Union[str, ConvNormLayer]
ConvNormalizationParameters = Union[
    ConvNormalizationType, Tuple[ConvNormalizationType, Dict[str, Any]]
]

# Activation: type alone, or (type, arguments).
ActivationType = Union[str, ActFunction]
ActivationParameters = Union[ActivationType, Tuple[ActivationType, Dict[str, Any]]]
Encoder(nn.Module): + """Encoder with multiple transformer blocks.""" + + def __init__( + self, + seq_length: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + pos_embedding: Optional[nn.Parameter] = None, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ) -> None: + super().__init__() + + if pos_embedding is not None: + self.pos_embedding = pos_embedding + else: + self.pos_embedding = nn.Parameter( + torch.empty(1, seq_length, hidden_dim).normal_(std=0.02) + ) # from BERT + self.dropout = nn.Dropout(dropout) + self.layers = nn.ModuleList( + [ + EncoderBlock( + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + norm_layer, + ) + for _ in range(num_layers) + ] + ) + self.norm = norm_layer(hidden_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.pos_embedding + + x = self.dropout(x) + for layer in self.layers: + x = layer(x) + + return self.norm(x) diff --git a/clinicadl/networks/nn/mlp.py b/clinicadl/networks/nn/mlp.py new file mode 100644 index 000000000..a27b2ad4e --- /dev/null +++ b/clinicadl/networks/nn/mlp.py @@ -0,0 +1,146 @@ +from collections import OrderedDict +from typing import Optional, Sequence + +import torch.nn as nn +from monai.networks.blocks import ADN +from monai.networks.layers.utils import get_act_layer +from monai.networks.nets import FullyConnectedNet as BaseMLP + +from .layers.utils import ( + ActFunction, + ActivationParameters, + NormalizationParameters, + NormLayer, +) +from .utils import check_adn_ordering, check_norm_layer + + +class MLP(BaseMLP): + """Simple full-connected layer neural network (or Multi-Layer Perceptron) with linear, normalization, activation + and dropout layers. + + Parameters + ---------- + in_channels : int + number of input channels (i.e. number of features). + out_channels : int + number of output channels. 
+ hidden_channels : Sequence[int] + number of output channels for each hidden layer. Thus, this parameter also controls the number of hidden layers. + act : Optional[ActivationParameters] (optional, default=ActFunction.PRELU) + the activation function used after a linear layer, and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them. + output_act : Optional[ActivationParameters] (optional, default=None) + a potential activation layer applied to the output of the network. Should be pass in the same way as `act`. + If None, no last activation will be applied. + norm : Optional[NormalizationParameters] (optional, default=NormLayer.BATCH) + the normalization type used after a linear layer, and optionally the arguments of the normalization + layer. Should be passed as `norm_type` or `(norm_type, parameters)`. If None, no normalization will be + performed.\n + `norm_type` can be any value in {`batch`, `group`, `instance`, `layer`, `syncbatch`}. Please refer to PyTorch's + [normalization layers](https://pytorch.org/docs/stable/nn.html#normalization-layers) to know the mandatory and + optional arguments for each of them.\n + Please note that arguments `num_channels`, `num_features` and `normalized_shape` of the normalization layer + should not be passed, as they are automatically inferred from the output of the previous layer in the network. + dropout : Optional[float] (optional, default=None) + dropout ratio. If None, no dropout. + bias : bool (optional, default=True) + whether to have a bias term in linear layers. 
+ adn_ordering : str (optional, default="NDA") + order of operations `Activation`, `Dropout` and `Normalization` after a linear layer (except the last + one). + For example if "ND" is passed, `Normalization` and then `Dropout` will be performed (without `Activation`).\n + Note: ADN will not be applied after the last linear layer. + + Examples + -------- + >>> MLP(in_channels=12, out_channels=2, hidden_channels=[8, 4], dropout=0.1, act=("elu", {"alpha": 0.5}), + norm=("group", {"num_groups": 2}), bias=True, adn_ordering="ADN", output_act="softmax") + MLP( + (flatten): Flatten(start_dim=1, end_dim=-1) + (hidden0): Sequential( + (linear): Linear(in_features=12, out_features=8, bias=True) + (adn): ADN( + (A): ELU(alpha=0.5) + (D): Dropout(p=0.1, inplace=False) + (N): GroupNorm(2, 8, eps=1e-05, affine=True) + ) + ) + (hidden1): Sequential( + (linear): Linear(in_features=8, out_features=4, bias=True) + (adn): ADN( + (A): ELU(alpha=0.5) + (D): Dropout(p=0.1, inplace=False) + (N): GroupNorm(2, 4, eps=1e-05, affine=True) + ) + ) + (output): Sequential( + (linear): Linear(in_features=4, out_features=2, bias=True) + (output_act): Softmax(dim=None) + ) + ) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: Sequence[int], + act: Optional[ActivationParameters] = ActFunction.PRELU, + output_act: Optional[ActivationParameters] = None, + norm: Optional[NormalizationParameters] = NormLayer.BATCH, + dropout: Optional[float] = None, + bias: bool = True, + adn_ordering: str = "NDA", + ) -> None: + self.norm = check_norm_layer(norm) + super().__init__( + in_channels, + out_channels, + hidden_channels, + dropout, + act, + bias, + check_adn_ordering(adn_ordering), + ) + self.output = nn.Sequential(OrderedDict([("linear", self.output)])) + self.output.output_act = get_act_layer(output_act) if output_act else None + # renaming + self._modules = OrderedDict( + [ + (key.replace("hidden_", "hidden"), sub_m) + for key, sub_m in self._modules.items() + ] 
+ ) + + def _get_layer(self, in_channels: int, out_channels: int, bias: bool) -> nn.Module: + """ + Gets the parametrized Linear layer + ADN block. + """ + if self.norm == NormLayer.LAYER: + norm = ("layer", {"normalized_shape": out_channels}) + else: + norm = self.norm + seq = nn.Sequential( + OrderedDict( + [ + ("linear", nn.Linear(in_channels, out_channels, bias)), + ( + "adn", + ADN( + ordering=self.adn_ordering, + act=self.act, + norm=norm, + dropout=self.dropout, + dropout_dim=1, + in_channels=out_channels, + ), + ), + ] + ) + ) + return seq diff --git a/clinicadl/networks/nn/resnet.py b/clinicadl/networks/nn/resnet.py new file mode 100644 index 000000000..1ba90b30c --- /dev/null +++ b/clinicadl/networks/nn/resnet.py @@ -0,0 +1,566 @@ +import re +from collections import OrderedDict +from copy import deepcopy +from enum import Enum +from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union + +import torch +import torch.nn as nn +from monai.networks.layers.factories import Conv, Norm, Pool +from monai.networks.layers.utils import get_act_layer +from monai.utils import ensure_tuple_rep +from torch.hub import load_state_dict_from_url +from torchvision.models.resnet import ( + ResNet18_Weights, + ResNet34_Weights, + ResNet50_Weights, + ResNet101_Weights, + ResNet152_Weights, +) + +from .layers.resnet import ResNetBlock, ResNetBottleneck +from .layers.senet import SEResNetBlock, SEResNetBottleneck +from .layers.utils import ActivationParameters + + +class ResNetBlockType(str, Enum): + """Supported ResNet blocks.""" + + BASIC = "basic" + BOTTLENECK = "bottleneck" + + +class GeneralResNet(nn.Module): + """Common base class for ResNet and SEResNet.""" + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_outputs: Optional[int], + block_type: Union[str, ResNetBlockType], + n_res_blocks: Sequence[int], + n_features: Sequence[int], + init_conv_size: Union[Sequence[int], int], + init_conv_stride: Union[Sequence[int], int], + 
bottleneck_reduction: int, + se_reduction: Optional[int], + act: ActivationParameters, + output_act: ActivationParameters, + ) -> None: + super().__init__() + + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.num_outputs = num_outputs + self.block_type = block_type + self._check_args_consistency(n_res_blocks, n_features) + self.n_res_blocks = n_res_blocks + self.n_features = n_features + self.bottleneck_reduction = bottleneck_reduction + self.se_reduction = se_reduction + self.act = act + self.squeeze_excitation = True if se_reduction else False + + self.init_conv_size = ensure_tuple_rep(init_conv_size, spatial_dims) + self.init_conv_stride = ensure_tuple_rep(init_conv_stride, spatial_dims) + + block, in_planes = self._get_block(block_type) + + conv_type, norm_type, pool_type, avgp_type = self._get_layers() + + block_avgpool = [0, 1, (1, 1), (1, 1, 1)] + + self.in_planes = in_planes[0] + self.n_layers = len(in_planes) + self.bias_downsample = False + + self.conv0 = conv_type( # pylint: disable=not-callable + in_channels, + self.in_planes, + kernel_size=self.init_conv_size, + stride=self.init_conv_stride, + padding=tuple(k // 2 for k in self.init_conv_size), + bias=False, + ) + self.norm0 = norm_type(self.in_planes) # pylint: disable=not-callable + self.act0 = get_act_layer(name=act) + self.pool0 = pool_type(kernel_size=3, stride=2, padding=1) # pylint: disable=not-callable + self.layer1 = self._make_resnet_layer( + block, in_planes[0], n_res_blocks[0], spatial_dims, act + ) + for i, (n_blocks, n_feats) in enumerate( + zip(n_res_blocks[1:], in_planes[1:]), start=2 + ): + self.add_module( + f"layer{i}", + self._make_resnet_layer( + block, + planes=n_feats, + blocks=n_blocks, + spatial_dims=spatial_dims, + stride=2, + act=act, + ), + ) + self.fc = ( + nn.Sequential( + OrderedDict( + [ + ("pool", avgp_type(block_avgpool[spatial_dims])), # pylint: disable=not-callable + ("flatten", nn.Flatten(1)), + ("out", nn.Linear(n_features[-1], 
num_outputs)), + ] + ) + ) + if num_outputs + else None + ) + if self.fc: + self.fc.output_act = get_act_layer(output_act) if output_act else None + + self._init_module(conv_type, norm_type) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv0(x) + x = self.norm0(x) + x = self.act0(x) + x = self.pool0(x) + + for i in range(1, self.n_layers + 1): + x = self.get_submodule(f"layer{i}")(x) + + if self.fc is not None: + x = self.fc(x) + + return x + + def _get_block(self, block_type: Union[str, ResNetBlockType]) -> nn.Module: + """ + Gets the residual block, depending on the block choice made by the user and depending + on whether squeeze-excitation mode or not. + """ + block_type = ResNetBlockType(block_type) + if block_type == ResNetBlockType.BASIC: + in_planes = self.n_features + if self.squeeze_excitation: + block = SEResNetBlock + block.reduction = self.se_reduction + else: + block = ResNetBlock + elif block_type == ResNetBlockType.BOTTLENECK: + in_planes = self._bottleneck_reduce( + self.n_features, self.bottleneck_reduction + ) + if self.squeeze_excitation: + block = SEResNetBottleneck + block.reduction = self.se_reduction + else: + block = ResNetBottleneck + block.expansion = self.bottleneck_reduction + + return block, in_planes + + def _get_layers(self): + """ + Gets convolution, normalization, pooling and adaptative average pooling layers. 
+ """ + conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[ + Conv.CONV, self.spatial_dims + ] + norm_type: Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[ + Norm.BATCH, self.spatial_dims + ] + pool_type: Type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]] = Pool[ + Pool.MAX, self.spatial_dims + ] + avgp_type: Type[ + Union[nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d] + ] = Pool[Pool.ADAPTIVEAVG, self.spatial_dims] + + return conv_type, norm_type, pool_type, avgp_type + + def _make_resnet_layer( + self, + block: Type[Union[ResNetBlock, ResNetBottleneck]], + planes: int, + blocks: int, + spatial_dims: int, + act: ActivationParameters, + stride: int = 1, + ) -> nn.Sequential: + """ + Builds a ResNet layer. + """ + conv_type: Callable = Conv[Conv.CONV, spatial_dims] + norm_type: Callable = Norm[Norm.BATCH, spatial_dims] + + downsample = None + if stride != 1 or self.in_planes != planes * block.expansion: + downsample = nn.Sequential( + conv_type( # pylint: disable=not-callable + self.in_planes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=self.bias_downsample, + ), + norm_type(planes * block.expansion), # pylint: disable=not-callable + ) + + layers = [ + block( + in_planes=self.in_planes, + planes=planes, + spatial_dims=spatial_dims, + stride=stride, + downsample=downsample, + act=act, + ) + ] + + self.in_planes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.in_planes, planes, spatial_dims=spatial_dims, act=act) + ) + + return nn.Sequential(*layers) + + def _init_module( + self, conv_type: Type[nn.Module], norm_type: Type[nn.Module] + ) -> None: + """ + Initializes the parameters. 
+ """ + for m in self.modules(): + if isinstance(m, conv_type): + nn.init.kaiming_normal_( + torch.as_tensor(m.weight), mode="fan_out", nonlinearity="relu" + ) + elif isinstance(m, norm_type): + nn.init.constant_(torch.as_tensor(m.weight), 1) + nn.init.constant_(torch.as_tensor(m.bias), 0) + elif isinstance(m, nn.Linear): + nn.init.constant_(torch.as_tensor(m.bias), 0) + + @classmethod + def _bottleneck_reduce( + cls, n_features: Sequence[int], bottleneck_reduction: int + ) -> Sequence[int]: + """ + Finds number of feature maps for the bottleneck layers. + """ + reduced_features = [] + for n in n_features: + if n % bottleneck_reduction != 0: + raise ValueError( + "All elements of n_features must be divisible by bottleneck_reduction. " + f"Got {n} in n_features and bottleneck_reduction={bottleneck_reduction}" + ) + reduced_features.append(n // bottleneck_reduction) + + return reduced_features + + @classmethod + def _check_args_consistency( + cls, n_res_blocks: Sequence[int], n_features: Sequence[int] + ) -> None: + """ + Checks consistency between `n_res_blocks` and `n_features`. + """ + if not isinstance(n_res_blocks, Sequence): + raise ValueError(f"n_res_blocks must be a sequence, got {n_res_blocks}") + if not isinstance(n_features, Sequence): + raise ValueError(f"n_features must be a sequence, got {n_features}") + if len(n_features) != len(n_res_blocks): + raise ValueError( + f"n_features and n_res_blocks must have the same length, got n_features={n_features} " + f"and n_res_blocks={n_res_blocks}" + ) + + +class ResNet(GeneralResNet): + """ + ResNet based on the [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) paper. + Adapted from [MONAI's implementation](https://docs.monai.io/en/stable/networks.html#resnet). + + The user can customize the number of residual blocks, the number of downsampling blocks, the number of channels + in each block, as well as other parameters like the type of residual block used. 
+ + ResNet is a fully convolutional network that can work with input of any size, provided that is it large + enough not to be reduced to a 1-pixel image (before the adaptative average pooling). + + Parameters + ---------- + spatial_dims : int + number of spatial dimensions of the input image. + in_channels : int + number of channels in the input image. + num_outputs : Optional[int] + number of output variables after the last linear layer.\n + If None, the features before the last fully connected layer (including average pooling) will be returned. + block_type : Union[str, ResNetBlockType] (optional, default=ResNetBlockType.BASIC) + type of residual block. Either `basic` or `bottleneck`. Default to `basic`, as in `ResNet-18`. + n_res_blocks : Sequence[int] (optional, default=(2, 2, 2, 2)) + number of residual block in each ResNet layer. A ResNet layer refers here to the set of residual blocks + between two downsamplings. The length of `n_res_blocks` thus determines the number of ResNet layers. + Default to `(2, 2, 2, 2)`, as in `ResNet-18`. + n_features : Sequence[int] (optional, default=(64, 128, 256, 512)) + number of output feature maps for each ResNet layer. The length of `n_features` must be equal to the length + of `n_res_blocks`. Default to `(64, 128, 256, 512)`, as in `ResNet-18`. + init_conv_size : Union[Sequence[int], int] (optional, default=7) + kernel_size for the first convolution. + If tuple, it will be understood as the values for each dimension. + Default to 7, as in the original paper. + init_conv_stride : Union[Sequence[int], int] (optional, default=2) + stride for the first convolution. + If tuple, it will be understood as the values for each dimension. + Default to 2, as in the original paper. + bottleneck_reduction : int (optional, default=4) + if `block_type='bottleneck'`, `bottleneck_reduction` determines the reduction factor for the number + of feature maps in bottleneck layers (1x1 convolutions). Default to 4, as in the original paper. 
+ act : ActivationParameters (optional, default=("relu", {"inplace": True})) + the activation function used in the convolutional part, and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them.\n + Default is "relu", as in the original paper. + output_act : Optional[ActivationParameters] (optional, default=None) + if `num_outputs` is not None, a potential activation layer applied to the outputs of the network. + Should be pass in the same way as `act`. + If None, no last activation will be applied. + + Examples + -------- + >>> ResNet( + spatial_dims=2, + in_channels=1, + num_outputs=2, + block_type="bottleneck", + bottleneck_reduction=4, + n_features=(8, 16), + n_res_blocks=(2, 2), + output_act="softmax", + init_conv_size=5, + ) + ResNet( + (conv0): Conv2d(1, 2, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False) + (norm0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act0): ReLU(inplace=True) + (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + (layer1): Sequential( + (0): ResNetBottleneck( + (conv1): Conv2d(2, 2, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act1): ReLU(inplace=True) + (conv2): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (norm2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act2): ReLU(inplace=True) + (conv3): Conv2d(2, 8, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm3): BatchNorm2d(8, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True) + (downsample): Sequential( + (0): Conv2d(2, 8, kernel_size=(1, 1), stride=(1, 1), bias=False) + (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (act3): ReLU(inplace=True) + ) + (1): ResNetBottleneck( + (conv1): Conv2d(8, 2, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act1): ReLU(inplace=True) + (conv2): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (norm2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act2): ReLU(inplace=True) + (conv3): Conv2d(2, 8, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act3): ReLU(inplace=True) + ) + ) + (layer2): Sequential( + (0): ResNetBottleneck( + (conv1): Conv2d(8, 4, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act1): ReLU(inplace=True) + (conv2): Conv2d(4, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + (norm2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act2): ReLU(inplace=True) + (conv3): Conv2d(4, 16, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (downsample): Sequential( + (0): Conv2d(8, 16, kernel_size=(1, 1), stride=(2, 2), bias=False) + (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (act3): ReLU(inplace=True) + ) + (1): ResNetBottleneck( + (conv1): Conv2d(16, 4, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act1): ReLU(inplace=True) + (conv2): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), 
padding=(1, 1), bias=False) + (norm2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act2): ReLU(inplace=True) + (conv3): Conv2d(4, 16, kernel_size=(1, 1), stride=(1, 1), bias=False) + (norm3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act3): ReLU(inplace=True) + ) + ) + (fc): Sequential( + (pool): AdaptiveAvgPool2d(output_size=(1, 1)) + (flatten): Flatten(start_dim=1, end_dim=-1) + (out): Linear(in_features=16, out_features=2, bias=True) + (output_act): Softmax(dim=None) + ) + ) + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_outputs: Optional[int], + block_type: Union[str, ResNetBlockType] = ResNetBlockType.BASIC, + n_res_blocks: Sequence[int] = (2, 2, 2, 2), + n_features: Sequence[int] = (64, 128, 256, 512), + init_conv_size: Union[Sequence[int], int] = 7, + init_conv_stride: Union[Sequence[int], int] = 2, + bottleneck_reduction: int = 4, + act: ActivationParameters = ("relu", {"inplace": True}), + output_act: Optional[ActivationParameters] = None, + ) -> None: + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_outputs=num_outputs, + block_type=block_type, + n_res_blocks=n_res_blocks, + n_features=n_features, + init_conv_size=init_conv_size, + init_conv_stride=init_conv_stride, + bottleneck_reduction=bottleneck_reduction, + se_reduction=None, + act=act, + output_act=output_act, + ) + + +class SOTAResNet(str, Enum): + """Supported ResNet networks.""" + + RESNET_18 = "ResNet-18" + RESNET_34 = "ResNet-34" + RESNET_50 = "ResNet-50" + RESNET_101 = "ResNet-101" + RESNET_152 = "ResNet-152" + + +def get_resnet( + name: Union[str, SOTAResNet], + num_outputs: Optional[int], + output_act: ActivationParameters = None, + pretrained: bool = False, +) -> ResNet: + """ + To get a ResNet implemented in the [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) + paper. 
+ + Only the last fully connected layer will be changed to match `num_outputs`. + + The user can also use the pretrained models from `torchvision`. Note that the last fully connected layer will not + used pretrained weights, as it is task specific. + + .. warning:: `ResNet-18`, `ResNet-34`, `ResNet-50`, `ResNet-101` and `ResNet-152` only works with 2D images with 3 + channels. + + Parameters + ---------- + model : Union[str, SOTAResNet] + The name of the ResNet. Available networks are `ResNet-18`, `ResNet-34`, `ResNet-50`, `ResNet-101` and `ResNet-152`. + num_outputs : Optional[int] + number of output variables after the last linear layer.\n + If None, the features before the last fully connected layer will be returned. + output_act : ActivationParameters (optional, default=None) + if `num_outputs` is not None, a potential activation layer applied to the outputs of the network, + and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them. + pretrained : bool (optional, default=False) + whether to use pretrained weights. The pretrained weights used are the default ones from [torchvision](https:// + pytorch.org/vision/main/models/resnet.html). + + Returns + ------- + ResNet + The network, with potentially pretrained weights. 
+ """ + name = SOTAResNet(name) + if name == SOTAResNet.RESNET_18: + block_type = ResNetBlockType.BASIC + n_res_blocks = (2, 2, 2, 2) + n_features = (64, 128, 256, 512) + model_url = ResNet18_Weights.DEFAULT.url + elif name == SOTAResNet.RESNET_34: + block_type = ResNetBlockType.BASIC + n_res_blocks = (3, 4, 6, 3) + n_features = (64, 128, 256, 512) + model_url = ResNet34_Weights.DEFAULT.url + elif name == SOTAResNet.RESNET_50: + block_type = ResNetBlockType.BOTTLENECK + n_res_blocks = (3, 4, 6, 3) + n_features = (256, 512, 1024, 2048) + model_url = ResNet50_Weights.DEFAULT.url + elif name == SOTAResNet.RESNET_101: + block_type = ResNetBlockType.BOTTLENECK + n_res_blocks = (3, 4, 23, 3) + n_features = (256, 512, 1024, 2048) + model_url = ResNet101_Weights.DEFAULT.url + elif name == SOTAResNet.RESNET_152: + block_type = ResNetBlockType.BOTTLENECK + n_res_blocks = (3, 8, 36, 3) + n_features = (256, 512, 1024, 2048) + model_url = ResNet152_Weights.DEFAULT.url + + # pylint: disable=possibly-used-before-assignment + resnet = ResNet( + spatial_dims=2, + in_channels=3, + num_outputs=num_outputs, + n_res_blocks=n_res_blocks, + block_type=block_type, + n_features=n_features, + output_act=output_act, + ) + if pretrained: + fc_layers = deepcopy(resnet.fc) + resnet.fc = None + pretrained_dict = load_state_dict_from_url(model_url, progress=True) + resnet.load_state_dict(_state_dict_adapter(pretrained_dict)) + resnet.fc = fc_layers + + return resnet + + +def _state_dict_adapter(state_dict: Mapping[str, Any]) -> Mapping[str, Any]: + """ + A mapping between torchvision's layer names and ours. 
+ """ + state_dict = {k: v for k, v in state_dict.items() if "fc" not in k} + + mappings = [ + (r"(?>> SEResNet( + spatial_dims=2, + in_channels=1, + num_outputs=2, + block_type="basic", + se_reduction=2, + n_features=(8,), + n_res_blocks=(2,), + output_act="softmax", + init_conv_size=5, + ) + SEResNet( + (conv0): Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False) + (norm0): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act0): ReLU(inplace=True) + (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + (layer1): Sequential( + (0): SEResNetBlock( + (conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (norm1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act1): ReLU(inplace=True) + (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (norm2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (se_layer): ChannelSELayer( + (avg_pool): AdaptiveAvgPool2d(output_size=1) + (fc): Sequential( + (0): Linear(in_features=8, out_features=4, bias=True) + (1): ReLU(inplace=True) + (2): Linear(in_features=4, out_features=8, bias=True) + (3): Sigmoid() + ) + ) + (act2): ReLU(inplace=True) + ) + (1): SEResNetBlock( + (conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (norm1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (act1): ReLU(inplace=True) + (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (norm2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (se_layer): ChannelSELayer( + (avg_pool): AdaptiveAvgPool2d(output_size=1) + (fc): Sequential( + (0): Linear(in_features=8, out_features=4, bias=True) + (1): ReLU(inplace=True) + (2): Linear(in_features=4, out_features=8, bias=True) + (3): Sigmoid() + ) + ) + (act2): 
ReLU(inplace=True) + ) + ) + (fc): Sequential( + (pool): AdaptiveAvgPool2d(output_size=(1, 1)) + (flatten): Flatten(start_dim=1, end_dim=-1) + (out): Linear(in_features=8, out_features=2, bias=True) + (output_act): Softmax(dim=None) + ) + ) + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_outputs: Optional[int], + se_reduction: int = 16, + **kwargs: Any, + ) -> None: + # get defaults from resnet + _, default_resnet_args = get_args_and_defaults(ResNet.__init__) + for arg, value in default_resnet_args.items(): + if arg not in kwargs: + kwargs[arg] = value + + self._check_se_channels(kwargs["n_features"], se_reduction) + + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_outputs=num_outputs, + se_reduction=se_reduction, + **kwargs, + ) + + @classmethod + def _check_se_channels(cls, n_features: Sequence[int], se_reduction: int) -> None: + """ + Checks that the output of residual blocks always have a number of channels greater + than squeeze-excitation bottleneck reduction factor. + """ + if not isinstance(n_features, Sequence): + raise ValueError(f"n_features must be a sequence. Got {n_features}") + for n in n_features: + if n < se_reduction: + raise ValueError( + f"elements of n_features must be greater or equal to se_reduction. Got {n} in n_features " + f"and se_reduction={se_reduction}" + ) + + +class SOTAResNet(str, Enum): + """Supported SEResNet networks.""" + + SE_RESNET_50 = "SEResNet-50" + SE_RESNET_101 = "SEResNet-101" + SE_RESNET_152 = "SEResNet-152" + + +def get_seresnet( + name: Union[str, SOTAResNet], + num_outputs: Optional[int], + output_act: ActivationParameters = None, + pretrained: bool = False, +) -> SEResNet: + """ + To get a Squeeze-and-Excitation ResNet implemented in the [Squeeze-and-Excitation Networks](https://arxiv.org/pdf/ + 1709.01507) paper. + + Only the last fully connected layer will be changed to match `num_outputs`. + + .. 
warning:: `SEResNet-50`, `SEResNet-101` and `SEResNet-152` only works with 2D images with 3 channels. + + Note: pretrained weights are not yet available for these networks. + + Parameters + ---------- + model : Union[str, SOTAResNet] + the name of the SEResNet. Available networks are `SEResNet-50`, `SEResNet-101` and `SEResNet-152`. + num_outputs : Optional[int] + number of output variables after the last linear layer.\n + If None, the features before the last fully connected layer will be returned. + output_act : ActivationParameters (optional, default=None) + if `num_outputs` is not None, a potential activation layer applied to the outputs of the network, + and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them. + pretrained : bool (optional, default=False) + pretrained networks are not yet available for SE-ResNets. Leave this argument to False. + + Returns + ------- + SEResNet + the network. + """ + if pretrained is not False: + raise ValueError( + "Pretrained networks are not yet available for SE-ResNets. Please leave " + "'pretrained' to False." 
+ ) + + name = SOTAResNet(name) + if name == SOTAResNet.SE_RESNET_50: + block_type = ResNetBlockType.BOTTLENECK + n_res_blocks = (3, 4, 6, 3) + n_features = (256, 512, 1024, 2048) + elif name == SOTAResNet.SE_RESNET_101: + block_type = ResNetBlockType.BOTTLENECK + n_res_blocks = (3, 4, 23, 3) + n_features = (256, 512, 1024, 2048) + elif name == SOTAResNet.SE_RESNET_152: + block_type = ResNetBlockType.BOTTLENECK + n_res_blocks = (3, 8, 36, 3) + n_features = (256, 512, 1024, 2048) + + # pylint: disable=possibly-used-before-assignment + resnet = SEResNet( + spatial_dims=2, + in_channels=3, + num_outputs=num_outputs, + n_res_blocks=n_res_blocks, + block_type=block_type, + n_features=n_features, + output_act=output_act, + ) + + return resnet diff --git a/clinicadl/networks/nn/unet.py b/clinicadl/networks/nn/unet.py new file mode 100644 index 000000000..dd1e59141 --- /dev/null +++ b/clinicadl/networks/nn/unet.py @@ -0,0 +1,250 @@ +from abc import ABC, abstractmethod +from typing import Optional, Sequence + +import torch +import torch.nn as nn +from monai.networks.blocks.convolutions import Convolution +from monai.networks.layers.utils import get_act_layer + +from .layers.unet import ConvBlock, DownBlock, UpSample +from .layers.utils import ActFunction, ActivationParameters + + +class BaseUNet(nn.Module, ABC): + """Base class for UNet and AttentionUNet.""" + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int] = (64, 128, 256, 512, 1024), + act: ActivationParameters = ActFunction.RELU, + output_act: Optional[ActivationParameters] = None, + dropout: Optional[float] = None, + ): + super().__init__() + if not isinstance(channels, Sequence) or len(channels) < 2: + raise ValueError( + f"channels should be a sequence, whose length is no less than 2. 
Got {channels}" + ) + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.act = act + self.dropout = dropout + + self.doubleconv = ConvBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + act=act, + dropout=dropout, + ) + self._build_encoder() + self._build_decoder() + self.reduce_channels = Convolution( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=out_channels, + kernel_size=1, + strides=1, + padding=0, + conv_only=True, + ) + self.output_act = get_act_layer(output_act) if output_act else None + + @abstractmethod + def forward(self, x: torch.Tensor) -> torch.Tensor: + pass + + def _build_encoder(self) -> None: + for i in range(1, len(self.channels)): + self.add_module( + f"down{i}", + DownBlock( + spatial_dims=self.spatial_dims, + in_channels=self.channels[i - 1], + out_channels=self.channels[i], + act=self.act, + dropout=self.dropout, + ), + ) + + @abstractmethod + def _build_decoder(self) -> None: + pass + + +class UNet(BaseUNet): + """ + UNet based on [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/pdf/1505.04597). + + The user can customize the number of encoding blocks, the number of channels in each block, as well as other parameters + like the activation function. + + .. warning:: UNet works only with images whose dimensions are high enough powers of 2. More precisely, if n is the number + of max pooling operation in your UNet (which is equal to `len(channels)-1`), the image must have :math:`2^{k}` + pixels in each dimension, with :math:`k \\geq n` (e.g. shape (:math:`2^{n}`, :math:`2^{n+3}`) for a 2D image). + + Note: the implementation proposed here is not exactly the one described in the original paper. 
Padding is added to + convolutions so that the feature maps keep a constant size (except when they are passed to `max pool` or `up-sample` + layers), batch normalization is used, and `up-conv` layers are here made with an [Upsample](https://pytorch.org/docs/ + stable/generated/torch.nn.Upsample.html) layer followed by a 3x3 convolution. + + Parameters + ---------- + spatial_dims : int + number of spatial dimensions of the input image. + in_channels : int + number of channels in the input image. + out_channels : int + number of output channels. + channels : Sequence[int] (optional, default=(64, 128, 256, 512, 1024)) + sequence of integers stating the number of channels in each UNet block. Thus, this parameter also controls + the number of UNet blocks. The length `channels` should be nos less than 2.\n + Default to `(64, 128, 256, 512, 1024)`, as in the original paper. + act : ActivationParameters (optional, default=ActFunction.RELU) + the activation function used in the convolutional part, and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them.\n + Default is "relu", as in the original paper. + output_act : Optional[ActivationParameters] (optional, default=None) + a potential activation layer applied to the output of the network. Should be pass in the same way as `act`. + If None, no last activation will be applied. + dropout : Optional[float] (optional, default=None) + dropout ratio. If None, no dropout. 
+ + Examples + -------- + >>> UNet( + spatial_dims=2, + in_channels=1, + out_channels=2, + channels=(4, 8), + act="elu", + output_act=("softmax", {"dim": 1}), + dropout=0.1, + ) + UNet( + (doubleconv): ConvBlock( + (0): Convolution( + (conv): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (1): Convolution( + (conv): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + ) + (down1): DownBlock( + (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) + (doubleconv): ConvBlock( + (0): Convolution( + (conv): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (1): Convolution( + (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + ) + ) + (upsample1): UpSample( + (0): Upsample(scale_factor=2.0, mode='nearest') + (1): Convolution( + (conv): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + ) + (doubleconv1): ConvBlock( + (0): Convolution( + (conv): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, 
inplace=False) + (A): ELU(alpha=1.0) + ) + ) + (1): Convolution( + (conv): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (adn): ADN( + (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (D): Dropout(p=0.1, inplace=False) + (A): ELU(alpha=1.0) + ) + ) + ) + (reduce_channels): Convolution( + (conv): Conv2d(4, 2, kernel_size=(1, 1), stride=(1, 1)) + ) + (output_act): Softmax(dim=1) + ) + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_history = [self.doubleconv(x)] + + for i in range(1, len(self.channels)): + x = self.get_submodule(f"down{i}")(x_history[-1]) + x_history.append(x) + + x_history.pop() # the output of bottelneck is not used as a residual + for i in range(len(self.channels) - 1, 0, -1): + up = self.get_submodule(f"upsample{i}")(x) + merged = torch.cat((x_history.pop(), up), dim=1) + x = self.get_submodule(f"doubleconv{i}")(merged) + + out = self.reduce_channels(x) + + if self.output_act is not None: + out = self.output_act(out) + + return out + + def _build_decoder(self): + for i in range(len(self.channels) - 1, 0, -1): + self.add_module( + f"upsample{i}", + UpSample( + spatial_dims=self.spatial_dims, + in_channels=self.channels[i], + out_channels=self.channels[i - 1], + act=self.act, + dropout=self.dropout, + ), + ) + self.add_module( + f"doubleconv{i}", + ConvBlock( + spatial_dims=self.spatial_dims, + in_channels=self.channels[i - 1] * 2, + out_channels=self.channels[i - 1], + act=self.act, + dropout=self.dropout, + ), + ) diff --git a/clinicadl/networks/nn/utils/__init__.py b/clinicadl/networks/nn/utils/__init__.py new file mode 100644 index 000000000..ce603f205 --- /dev/null +++ b/clinicadl/networks/nn/utils/__init__.py @@ -0,0 +1,14 @@ +from .checks import ( + check_adn_ordering, + check_conv_args, + check_mlp_args, + check_norm_layer, + check_pool_indices, + ensure_list_of_tuples, +) +from .shapes import ( + calculate_conv_out_shape, + calculate_convtranspose_out_shape, + 
calculate_pool_out_shape, + calculate_unpool_out_shape, +) diff --git a/clinicadl/networks/nn/utils/checks.py b/clinicadl/networks/nn/utils/checks.py new file mode 100644 index 000000000..1917a2894 --- /dev/null +++ b/clinicadl/networks/nn/utils/checks.py @@ -0,0 +1,167 @@ +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +from ..layers.utils import ( + ConvParameters, + NormalizationParameters, + NormLayer, + PoolingLayer, +) + +__all__ = [ + "ensure_list_of_tuples", + "check_norm_layer", + "check_conv_args", + "check_mlp_args", + "check_pool_indices", +] + + +def ensure_list_of_tuples( + parameter: ConvParameters, dim: int, n_layers: int, name: str +) -> List[Tuple[int, ...]]: + """ + Checks spatial parameters (e.g. kernel_size) and returns a list of tuples. + Each element of the list corresponds to the parameters of one layer, and + each element of the tuple corresponds to the parameters for one dimension. + """ + parameter = _check_conv_parameter(parameter, dim, n_layers, name) + if isinstance(parameter, tuple): + return [parameter] * n_layers + else: + return parameter + + +def check_norm_layer( + norm: Optional[NormalizationParameters], +) -> Optional[NormalizationParameters]: + """ + Checks that the argument for normalization layers has the right format (i.e. + `norm_type` or (`norm_type`, `norm_layer_parameters`)) and checks potential + mandatory arguments in `norm_layer_parameters`. + """ + if norm is None: + return norm + + if not isinstance(norm, str) and not isinstance(norm, PoolingLayer): + if ( + not isinstance(norm, tuple) + or len(norm) != 2 + or not isinstance(norm[1], dict) + ): + raise ValueError( + "norm must be either the name of the normalization layer or a double with first the name and then the " + f"arguments of the layer in a dict. 
Got {norm}" + ) + norm_mode = NormLayer(norm[0]) + args = norm[1] + else: + norm_mode = NormLayer(norm) + args = {} + if norm_mode == NormLayer.GROUP and "num_groups" not in args: + raise ValueError( + f"num_groups is a mandatory argument for GroupNorm and must be passed in `norm`. Got `norm`={norm}" + ) + + return norm + + +def check_adn_ordering(adn: str) -> str: + """ + Checks ADN sequence. + """ + if not isinstance(adn, str): + raise ValueError(f"adn_ordering must be a string. Got {adn}") + + for letter in adn: + if letter not in { + "A", + "D", + "N", + }: + raise ValueError( + f"adn_ordering must be composed by 'A', 'D' or/and 'N'. Got {letter}" + ) + if len(adn) != len(set(adn)): + raise ValueError(f"adn_ordering cannot contain duplicated letter. Got {adn}") + + return adn + + +def check_conv_args(conv_args: Dict[str, Any]) -> None: + """ + Checks that `conv_args` is a dict with at least the mandatory argument `channels`. + """ + if not isinstance(conv_args, dict): + raise ValueError( + f"conv_args must be a dict with the arguments for the convolutional part. Got: {conv_args}" + ) + if "channels" not in conv_args: + raise ValueError( + "channels is a mandatory argument for the convolutional part and must therefore be " + f"passed in conv_args. Got conv_args={conv_args}" + ) + + +def check_mlp_args(mlp_args: Optional[Dict[str, Any]]) -> None: + """ + Checks that `mlp_args` is a dict with at least the mandatory argument `hidden_channels`. + """ + if mlp_args is not None: + if not isinstance(mlp_args, dict): + raise ValueError( + f"mlp_args must be a dict with the arguments for the MLP part. Got: {mlp_args}" + ) + if "hidden_channels" not in mlp_args: + raise ValueError( + "hidden_channels is a mandatory argument for the MLP part and must therefore be " + f"passed in mlp_args. 
Got mlp_args={mlp_args}" + ) + + +def check_pool_indices( + pooling_indices: Optional[Sequence[int]], n_layers: int +) -> Sequence[int]: + """ + Checks that the (un)pooling indices are consistent with the number of layers. + """ + if pooling_indices is not None: + for idx in pooling_indices: + if idx > n_layers - 1: + raise ValueError( + f"indices in (un)pooling_indices must be smaller than len(channels)-1, got (un)pooling_indices={pooling_indices} and len(channels)={n_layers}" + ) + elif idx < -1: + raise ValueError( + f"indices in (un)pooling_indices must be greater or equal to -1, got (un)pooling_indices={pooling_indices}" + ) + return sorted(pooling_indices) + else: + return [] + + +def _check_conv_parameter( + parameter: ConvParameters, dim: int, n_layers: int, name: str +) -> Union[Tuple[int, ...], List[Tuple[int, ...]]]: + """ + Checks spatial parameters (e.g. kernel_size). + """ + if isinstance(parameter, int): + return (parameter,) * dim + elif isinstance(parameter, tuple): + if len(parameter) != dim: + raise ValueError( + f"If a tuple is passed for {name}, its dimension must be {dim}. Got {parameter}" + ) + return parameter + elif isinstance(parameter, list): + if len(parameter) != n_layers: + raise ValueError( + f"If a list is passed, {name} must contain as many elements as there are layers. " + f"There are {n_layers} layers, but got {parameter}" + ) + checked_params = [] + for param in parameter: + checked_params.append(_check_conv_parameter(param, dim, n_layers, name)) + return checked_params + else: + raise ValueError(f"{name} must be an int, a tuple or a list. 
Got {name}") diff --git a/clinicadl/networks/nn/utils/shapes.py b/clinicadl/networks/nn/utils/shapes.py new file mode 100644 index 000000000..a649af076 --- /dev/null +++ b/clinicadl/networks/nn/utils/shapes.py @@ -0,0 +1,203 @@ +from math import ceil +from typing import Optional, Sequence, Tuple, Union + +import numpy as np + +from ..layers.utils import PoolingLayer, UnpoolingLayer + +__all__ = [ + "calculate_conv_out_shape", + "calculate_convtranspose_out_shape", + "calculate_pool_out_shape", + "calculate_unpool_out_shape", +] + + +def calculate_conv_out_shape( + in_shape: Union[Sequence[int], int], + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int] = 1, + padding: Union[Sequence[int], int] = 0, + dilation: Union[Sequence[int], int] = 1, + **kwargs, # for uniformization +) -> Tuple[int, ...]: + """ + Calculates the output shape of a convolution layer. All arguments can be scalars or multiple + values. Always return a tuple. + """ + in_shape_np = np.atleast_1d(in_shape) + kernel_size_np = np.atleast_1d(kernel_size) + stride_np = np.atleast_1d(stride) + padding_np = np.atleast_1d(padding) + dilation_np = np.atleast_1d(dilation) + + out_shape_np = ( + (in_shape_np + 2 * padding_np - dilation_np * (kernel_size_np - 1) - 1) + / stride_np + ) + 1 + + return tuple(int(s) for s in out_shape_np) + + +def calculate_convtranspose_out_shape( + in_shape: Union[Sequence[int], int], + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int] = 1, + padding: Union[Sequence[int], int] = 0, + output_padding: Union[Sequence[int], int] = 0, + dilation: Union[Sequence[int], int] = 1, + **kwargs, # for uniformization +) -> Tuple[int, ...]: + """ + Calculates the output shape of a transposed convolution layer. All arguments can be scalars or + multiple values. Always return a tuple. 
+ """ + in_shape_np = np.atleast_1d(in_shape) + kernel_size_np = np.atleast_1d(kernel_size) + stride_np = np.atleast_1d(stride) + padding_np = np.atleast_1d(padding) + dilation_np = np.atleast_1d(dilation) + output_padding_np = np.atleast_1d(output_padding) + + out_shape_np = ( + (in_shape_np - 1) * stride_np + - 2 * padding_np + + dilation_np * (kernel_size_np - 1) + + output_padding_np + + 1 + ) + + return tuple(int(s) for s in out_shape_np) + + +def calculate_pool_out_shape( + pool_mode: Union[str, PoolingLayer], + in_shape: Union[Sequence[int], int], + **kwargs, +) -> Tuple[int, ...]: + """ + Calculates the output shape of a pooling layer. The first argument is the type of pooling + performed (`max` or `avg`). All other arguments can be scalars or multiple values, except + `ceil_mode`. + Always return a tuple. + """ + pool_mode = PoolingLayer(pool_mode) + if pool_mode == PoolingLayer.MAX: + return _calculate_maxpool_out_shape(in_shape, **kwargs) + elif pool_mode == PoolingLayer.AVG: + return _calculate_avgpool_out_shape(in_shape, **kwargs) + elif pool_mode == PoolingLayer.ADAPT_MAX or pool_mode == PoolingLayer.ADAPT_AVG: + return _calculate_adaptivepool_out_shape(in_shape, **kwargs) + + +def calculate_unpool_out_shape( + unpool_mode: Union[str, UnpoolingLayer], + in_shape: Union[Sequence[int], int], + **kwargs, +) -> Tuple[int, ...]: + """ + Calculates the output shape of an unpooling layer. The first argument is the type of unpooling + performed (`upsample` or `convtranspose`). + Always return a tuple. 
+ """ + unpool_mode = UnpoolingLayer(unpool_mode) + if unpool_mode == UnpoolingLayer.UPSAMPLE: + return _calculate_upsample_out_shape(in_shape, **kwargs) + elif unpool_mode == UnpoolingLayer.CONV_TRANS: + return calculate_convtranspose_out_shape(in_shape, **kwargs) + + +def _calculate_maxpool_out_shape( + in_shape: Union[Sequence[int], int], + kernel_size: Union[Sequence[int], int], + stride: Optional[Union[Sequence[int], int]] = None, + padding: Union[Sequence[int], int] = 0, + dilation: Union[Sequence[int], int] = 1, + ceil_mode: bool = False, + **kwargs, # for uniformization +) -> Tuple[int, ...]: + """ + Calculates the output shape of a MaxPool layer. + """ + if stride is None: + stride = kernel_size + + in_shape_np = np.atleast_1d(in_shape) + kernel_size_np = np.atleast_1d(kernel_size) + stride_np = np.atleast_1d(stride) + padding_np = np.atleast_1d(padding) + dilation_np = np.atleast_1d(dilation) + + out_shape_np = ( + (in_shape_np + 2 * padding_np - dilation_np * (kernel_size_np - 1) - 1) + / stride_np + ) + 1 + if ceil_mode: + out_shape = tuple(ceil(s) for s in out_shape_np) + else: + out_shape = tuple(int(s) for s in out_shape_np) + + return out_shape + + +def _calculate_avgpool_out_shape( + in_shape: Union[Sequence[int], int], + kernel_size: Union[Sequence[int], int], + stride: Optional[Union[Sequence[int], int]] = None, + padding: Union[Sequence[int], int] = 0, + ceil_mode: bool = False, + **kwargs, # for uniformization +) -> Tuple[int, ...]: + """ + Calculates the output shape of an AvgPool layer. 
+ """ + if stride is None: + stride = kernel_size + + in_shape_np = np.atleast_1d(in_shape) + kernel_size_np = np.atleast_1d(kernel_size) + stride_np = np.atleast_1d(stride) + padding_np = np.atleast_1d(padding) + + out_shape_np = ((in_shape_np + 2 * padding_np - kernel_size_np) / stride_np) + 1 + if ceil_mode: + out_shape_np = np.ceil(out_shape_np) + out_shape_np[(out_shape_np - 1) * stride_np >= in_shape_np + padding_np] -= 1 + + return tuple(int(s) for s in out_shape_np) + + +def _calculate_adaptivepool_out_shape( + in_shape: Union[Sequence[int], int], + output_size: Union[Sequence[int], int], + **kwargs, # for uniformization +) -> Tuple[int, ...]: + """ + Calculates the output shape of an AdaptiveMaxPool or AdaptiveAvgPool layer. + """ + in_shape_np = np.atleast_1d(in_shape) + out_shape_np = np.ones_like(in_shape_np) * np.atleast_1d(output_size) + + return tuple(int(s) for s in out_shape_np) + + +def _calculate_upsample_out_shape( + in_shape: Union[Sequence[int], int], + scale_factor: Optional[Union[Sequence[int], int]] = None, + size: Optional[Union[Sequence[int], int]] = None, + **kwargs, # for uniformization +) -> Tuple[int, ...]: + """ + Calculates the output shape of an Upsample layer. 
+ """ + in_shape_np = np.atleast_1d(in_shape) + if size and scale_factor: + raise ValueError("Pass either size or scale_factor, not both.") + elif size: + out_shape_np = np.ones_like(in_shape_np) * np.atleast_1d(size) + elif scale_factor: + out_shape_np = in_shape_np * scale_factor + else: + raise ValueError("Pass one of size or scale_factor.") + + return tuple(int(s) for s in out_shape_np) diff --git a/clinicadl/networks/nn/vae.py b/clinicadl/networks/nn/vae.py new file mode 100644 index 000000000..9dac6b43b --- /dev/null +++ b/clinicadl/networks/nn/vae.py @@ -0,0 +1,200 @@ +from copy import deepcopy +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn + +from .autoencoder import AutoEncoder +from .layers.utils import ActivationParameters, UnpoolingMode + + +class VAE(nn.Module): + """ + A Variational AutoEncoder with convolutional and fully connected layers. + + The user must pass the arguments to build an encoder, from its convolutional and + fully connected parts, and the decoder will be automatically built by taking the + symmetrical network. + + More precisely, to build the decoder, the order of the encoding layers is reverted, convolutions are + replaced by transposed convolutions and pooling layers are replaced by either upsampling or transposed + convolution layers. + Please note that the order of `Activation`, `Dropout` and `Normalization`, defined with the + argument `adn_ordering` in `conv_args`, is the same for the encoder and the decoder. + + Note that an `AutoEncoder` is an aggregation of a `CNN` (:py:class:`clinicadl.monai_networks.nn. + cnn.CNN`), whose last linear layer is duplicated to infer both the mean and the log variance, + and a `Generator` (:py:class:`clinicadl.monai_networks.nn.generator.Generator`). + + Parameters + ---------- + in_shape : Sequence[int] + sequence of integers stating the dimension of the input tensor (minus batch dimension). 
+ latent_size : int + size of the latent vector. + conv_args : Dict[str, Any] + the arguments for the convolutional part of the encoder. The arguments are those accepted + by :py:class:`clinicadl.monai_networks.nn.conv_encoder.ConvEncoder`, except `in_shape` that + is specified here. So, the only mandatory argument is `channels`. + mlp_args : Optional[Dict[str, Any]] (optional, default=None) + the arguments for the MLP part of the encoder . The arguments are those accepted by + :py:class:`clinicadl.monai_networks.nn.mlp.MLP`, except `in_channels` that is inferred + from the output of the convolutional part, and `out_channels` that is set to `latent_size`. + So, the only mandatory argument is `hidden_channels`.\n + If None, the MLP part will be reduced to a single linear layer.\n + The last linear layer will be duplicated to infer both the mean and the log variance. + out_channels : Optional[int] (optional, default=None) + number of output channels. If None, the output will have the same number of channels as the + input. + output_act : Optional[ActivationParameters] (optional, default=None) + a potential activation layer applied to the output of the network, and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them. + unpooling_mode : Union[str, UnpoolingMode] (optional, default=UnpoolingMode.NEAREST) + type of unpooling. 
Can be either `"nearest"`, `"linear"`, `"bilinear"`, `"bicubic"`, `"trilinear"` or + `"convtranspose"`.\n + - `nearest`: unpooling is performed by upsampling with the :italic:`nearest` algorithm (see [PyTorch's Upsample layer] + (https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html)). + - `linear`: unpooling is performed by upsampling with the :italic:`linear` algorithm. Only works with 1D images (excluding the + channel dimension). + - `bilinear`: unpooling is performed by upsampling with the :italic:`bilinear` algorithm. Only works with 2D images. + - `bicubic`: unpooling is performed by upsampling with the :italic:`bicubic` algorithm. Only works with 2D images. + - `trilinear`: unpooling is performed by upsampling with the :italic:`trilinear` algorithm. Only works with 3D images. + - `convtranspose`: unpooling is performed with a transposed convolution, whose parameters (kernel size, stride, etc.) are + computed to reverse the pooling operation. + + Examples + -------- + >>> VAE( + in_shape=(1, 16, 16), + latent_size=4, + conv_args={"channels": [2]}, + mlp_args={"hidden_channels": [16], "output_act": "relu"}, + out_channels=2, + output_act="sigmoid", + unpooling_mode="bilinear", + ) + VAE( + (encoder): CNN( + (convolutions): ConvEncoder( + (layer0): Convolution( + (conv): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)) + ) + ) + (mlp): MLP( + (flatten): Flatten(start_dim=1, end_dim=-1) + (hidden0): Sequential( + (linear): Linear(in_features=392, out_features=16, bias=True) + (adn): ADN( + (N): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (A): PReLU(num_parameters=1) + ) + ) + (output): Identity() + ) + ) + (mu): Sequential( + (linear): Linear(in_features=16, out_features=4, bias=True) + (output_act): ReLU() + ) + (log_var): Sequential( + (linear): Linear(in_features=16, out_features=4, bias=True) + (output_act): ReLU() + ) + (decoder): Generator( + (mlp): MLP( + (flatten): Flatten(start_dim=1, end_dim=-1) + 
(hidden0): Sequential( + (linear): Linear(in_features=4, out_features=16, bias=True) + (adn): ADN( + (N): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (A): PReLU(num_parameters=1) + ) + ) + (output): Sequential( + (linear): Linear(in_features=16, out_features=392, bias=True) + (output_act): ReLU() + ) + ) + (reshape): Reshape() + (convolutions): ConvDecoder( + (layer0): Convolution( + (conv): ConvTranspose2d(2, 2, kernel_size=(3, 3), stride=(1, 1)) + ) + (output_act): Sigmoid() + ) + ) + ) + """ + + def __init__( + self, + in_shape: Sequence[int], + latent_size: int, + conv_args: Dict[str, Any], + mlp_args: Optional[Dict[str, Any]] = None, + out_channels: Optional[int] = None, + output_act: Optional[ActivationParameters] = None, + unpooling_mode: Union[str, UnpoolingMode] = UnpoolingMode.NEAREST, + ) -> None: + super().__init__() + ae = AutoEncoder( + in_shape, + latent_size, + conv_args, + mlp_args, + out_channels, + output_act, + unpooling_mode, + ) + + # replace last mlp layer by two parallel layers + mu_layers = deepcopy(ae.encoder.mlp.output) + log_var_layers = deepcopy(ae.encoder.mlp.output) + self._reset_weights( + log_var_layers + ) # to have different initialization for the two layers + ae.encoder.mlp.output = nn.Identity() + + self.encoder = ae.encoder + self.mu = mu_layers + self.log_var = log_var_layers + self.decoder = ae.decoder + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Encoding, sampling and decoding. + """ + feature = self.encoder(x) + mu = self.mu(feature) + log_var = self.log_var(feature) + z = self.reparameterize(mu, log_var) + + return self.decoder(z), mu, log_var + + def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor: + """ + Samples a random vector from a gaussian distribution, given the mean and log-variance + of this distribution. 
+ """ + std = torch.exp(0.5 * log_var) + + if self.training: # multiply random noise with std only during training + std = torch.randn_like(std).mul(std) + + return std.add_(mu) + + @classmethod + def _reset_weights(cls, layer: Union[nn.Sequential, nn.Linear]) -> None: + """ + Resets the output layer(s) of an MLP. + """ + if isinstance(layer, nn.Linear): + layer.reset_parameters() + else: + layer.linear.reset_parameters() diff --git a/clinicadl/networks/nn/vit.py b/clinicadl/networks/nn/vit.py new file mode 100644 index 000000000..372e1728a --- /dev/null +++ b/clinicadl/networks/nn/vit.py @@ -0,0 +1,420 @@ +import math +import re +from collections import OrderedDict +from copy import deepcopy +from enum import Enum +from typing import Any, Mapping, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding +from monai.networks.layers import Conv +from monai.networks.layers.utils import get_act_layer +from monai.utils import ensure_tuple_rep +from torch.hub import load_state_dict_from_url +from torchvision.models.vision_transformer import ( + ViT_B_16_Weights, + ViT_B_32_Weights, + ViT_L_16_Weights, + ViT_L_32_Weights, +) + +from .layers.utils import ActFunction, ActivationParameters +from .layers.vit import Encoder + + +class PosEmbedType(str, Enum): + """Available position embedding types for ViT.""" + + LEARN = "learnable" + SINCOS = "sincos" + + +class ViT(nn.Module): + """ + Vision Transformer based on the [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale] + (https://arxiv.org/pdf/2010.11929) paper. + Adapted from [torchvision's implementation](https://pytorch.org/vision/main/models/vision_transformer.html). + + The user can customize the patch size, the embedding dimension, the number of transformer blocks, the number of + attention heads, as well as other parameters like the type of position embedding. 
+ + Parameters + ---------- + in_shape : Sequence[int] + sequence of integers stating the dimension of the input tensor (minus batch dimension). + patch_size : Union[Sequence[int], int] + sequence of integers stating the patch size (minus batch and channel dimensions). If int, the same + patch size will be used for all dimensions. + Patch size must divide image size in all dimensions. + num_outputs : Optional[int] + number of output variables after the last linear layer.\n + If None, the patch embeddings after the last transformer block will be returned. + embedding_dim : int (optional, default=768) + size of the embedding vectors. Must be divisible by `num_heads` as each head will be responsible for + a part of the embedding vectors. Default to 768, as for 'ViT-Base' in the original paper. + num_layers : int (optional, default=12) + number of consecutive transformer blocks. Default to 12, as for 'ViT-Base' in the original paper. + num_heads : int (optional, default=12) + number of heads in the self-attention block. Must divide `embedding_size`. + Default to 12, as for 'ViT-Base' in the original paper. + mlp_dim : int (optional, default=3072) + size of the hidden layer in the MLP part of the transformer block. Default to 3072, as for 'ViT-Base' + in the original paper. + pos_embed_type : Optional[Union[str, PosEmbedType]] (optional, default="learnable") + type of position embedding. Can be either `"learnable"`, `"sincos"` or `None`.\n + - `learnable`: the position embeddings are parameters that will be learned during the training + process. + - `sincos`: the position embeddings are fixed and determined with sinus and cosinus formulas (based on Dosovitskiy et al., + 'Attention Is All You Need, https://arxiv.org/pdf/1706.03762). Only implemented for 2D and 3D images. With `sincos` + position embedding, `embedding_dim` must be divisible by 4 for 2D images and by 6 for 3D images. 
+ - `None`: no position embeddings are used.\n + Default to `"learnable"`, as in the original paper. + output_act : Optional[ActivationParameters] (optional, default=ActFunction.TANH) + if `num_outputs` is not None, a potential activation layer applied to the outputs of the network, + and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activationfunctions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them.\n + Default to `"tanh"`, as in the original paper. + dropout : Optional[float] (optional, default=None) + dropout ratio. If None, no dropout. + + Examples + -------- + >>> ViT( + in_shape=(3, 60, 64), + patch_size=4, + num_outputs=2, + embedding_dim=32, + num_layers=2, + num_heads=4, + mlp_dim=128, + output_act="softmax", + ) + ViT( + (conv_proj): Conv2d(3, 32, kernel_size=(4, 4), stride=(4, 4)) + (encoder): Encoder( + (dropout): Dropout(p=0.0, inplace=False) + (layers): ModuleList( + (0-1): 2 x EncoderBlock( + (norm1): LayerNorm((32,), eps=1e-06, elementwise_affine=True) + (self_attention): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True) + ) + (dropout): Dropout(p=0.0, inplace=False) + (norm2): LayerNorm((32,), eps=1e-06, elementwise_affine=True) + (mlp): MLPBlock( + (0): Linear(in_features=32, out_features=128, bias=True) + (1): GELU(approximate='none') + (2): Dropout(p=0.0, inplace=False) + (3): Linear(in_features=128, out_features=32, bias=True) + (4): Dropout(p=0.0, inplace=False) + ) + ) + ) + (norm): LayerNorm((32,), eps=1e-06, elementwise_affine=True) + ) + (fc): Sequential( + (out): Linear(in_features=32, out_features=2, 
bias=True) + (output_act): Softmax(dim=None) + ) + ) + """ + + def __init__( + self, + in_shape: Sequence[int], + patch_size: Union[Sequence[int], int], + num_outputs: Optional[int], + embedding_dim: int = 768, + num_layers: int = 12, + num_heads: int = 12, + mlp_dim: int = 3072, + pos_embed_type: Optional[Union[str, PosEmbedType]] = PosEmbedType.LEARN, + output_act: Optional[ActivationParameters] = ActFunction.TANH, + dropout: Optional[float] = None, + ) -> None: + super().__init__() + + self.in_channels, *self.img_size = in_shape + self.spatial_dims = len(self.img_size) + self.patch_size = ensure_tuple_rep(patch_size, self.spatial_dims) + + self._check_embedding_dim(embedding_dim, num_heads) + self._check_patch_size(self.img_size, self.patch_size) + self.embedding_dim = embedding_dim + self.classification = True if num_outputs else False + dropout = dropout if dropout else 0.0 + + self.conv_proj = Conv[Conv.CONV, self.spatial_dims]( # pylint: disable=not-callable + in_channels=self.in_channels, + out_channels=self.embedding_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + ) + self.seq_length = int( + np.prod(np.array(self.img_size) // np.array(self.patch_size)) + ) + + # Add a class token + if self.classification: + self.class_token = nn.Parameter(torch.zeros(1, 1, self.embedding_dim)) + self.seq_length += 1 + + pos_embedding = self._get_pos_embedding(pos_embed_type) + self.encoder = Encoder( + self.seq_length, + num_layers, + num_heads, + self.embedding_dim, + mlp_dim, + dropout=dropout, + attention_dropout=dropout, + pos_embedding=pos_embedding, + ) + + if self.classification: + self.class_token = nn.Parameter(torch.zeros(1, 1, embedding_dim)) + self.fc = nn.Sequential( + OrderedDict([("out", nn.Linear(embedding_dim, num_outputs))]) + ) + self.fc.output_act = get_act_layer(output_act) if output_act else None + else: + self.fc = None + + self._init_layers() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv_proj(x) + # (n, 
hidden_dim, n_h, n_w) -> (n, (h * w * d), hidden_dim) + x = x.flatten(2).transpose(-1, -2) + n = x.shape[0] + + # Expand the class token to the full batch + if self.fc: + batch_class_token = self.class_token.expand(n, -1, -1) + x = torch.cat([batch_class_token, x], dim=1) + + x = self.encoder(x) + + # Classifier "token" as used by standard language architectures + if self.fc: + x = x[:, 0] + x = self.fc(x) + + return x + + def _get_pos_embedding( + self, pos_embed_type: Optional[Union[str, PosEmbedType]] + ) -> Optional[nn.Parameter]: + """ + Gets position embeddings. If `pos_embed_type` is "learnable", will return None as it will be handled + by the encoder module. + """ + if pos_embed_type is None: + pos_embed = nn.Parameter( + torch.zeros(1, self.seq_length, self.embedding_dim) + ) + pos_embed.requires_grad = False + return pos_embed + + pos_embed_type = PosEmbedType(pos_embed_type) + + if pos_embed_type == PosEmbedType.LEARN: + return None # will be initialized inside the Encoder + + elif pos_embed_type == PosEmbedType.SINCOS: + if self.spatial_dims != 2 and self.spatial_dims != 3: + raise ValueError( + f"{self.spatial_dims}D sincos position embedding not implemented" + ) + elif self.spatial_dims == 2 and self.embedding_dim % 4: + raise ValueError( + f"embedding_dim must be divisible by 4 for 2D sincos position embedding. Got embedding_dim={self.embedding_dim}" + ) + elif self.spatial_dims == 3 and self.embedding_dim % 6: + raise ValueError( + f"embedding_dim must be divisible by 6 for 3D sincos position embedding. 
Got embedding_dim={self.embedding_dim}" + ) + grid_size = [] + for in_size, pa_size in zip(self.img_size, self.patch_size): + grid_size.append(in_size // pa_size) + pos_embed = build_sincos_position_embedding( + grid_size, self.embedding_dim, self.spatial_dims + ) + if self.classification: + pos_embed = torch.nn.Parameter( + torch.cat([torch.zeros(1, 1, self.embedding_dim), pos_embed], dim=1) + ) # add 0 for class token pos embedding + pos_embed.requires_grad = False + return pos_embed + + def _init_layers(self): + """ + Initializes some layers, based on torchvision's implementation: https://pytorch.org/vision/main/ + _modules/torchvision/models/vision_transformer.html + """ + fan_in = self.conv_proj.in_channels * np.prod(self.conv_proj.kernel_size) + nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in)) + nn.init.zeros_(self.conv_proj.bias) + + @classmethod + def _check_embedding_dim(cls, embedding_dim: int, num_heads: int) -> None: + """ + Checks consistency between embedding dimension and number of heads. + """ + if embedding_dim % num_heads != 0: + raise ValueError( + f"embedding_dim should be divisible by num_heads. Got embedding_dim={embedding_dim} " + f" and num_heads={num_heads}" + ) + + @classmethod + def _check_patch_size( + cls, img_size: Tuple[int, ...], patch_size: Tuple[int, ...] + ) -> None: + """ + Checks consistency between image size and patch size. + """ + for i, p in zip(img_size, patch_size): + if i % p != 0: + raise ValueError( + f"img_size should be divisible by patch_size. 
Got img_size={img_size} " + f" and patch_size={patch_size}" + ) + + +class SOTAViT(str, Enum): + """Supported ViT networks.""" + + B_16 = "ViT-B/16" + B_32 = "ViT-B/32" + L_16 = "ViT-L/16" + L_32 = "ViT-L/32" + + +def get_vit( + name: Union[str, SOTAViT], + num_outputs: Optional[int], + output_act: ActivationParameters = None, + pretrained: bool = False, +) -> ViT: + """ + To get a Vision Transformer implemented in the [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale] + (https://arxiv.org/pdf/2010.11929) paper. + + Only the last fully connected layer will be changed to match `num_outputs`. + + The user can also use the pretrained models from `torchvision`. Note that the last fully connected layer will not + use pretrained weights, as it is task specific. + + .. warning:: `ViT-B/16`, `ViT-B/32`, `ViT-L/16` and `ViT-L/32` work with 2D images of size (224, 224), with 3 channels. + + Parameters + ---------- + name : Union[str, SOTAViT] + The name of the Vision Transformer. Available networks are `ViT-B/16`, `ViT-B/32`, `ViT-L/16` and `ViT-L/32`. + num_outputs : Optional[int] + number of output variables after the last linear layer.\n + If None, the features before the last fully connected layer will be returned. + output_act : ActivationParameters (optional, default=None) + if `num_outputs` is not None, a potential activation layer applied to the outputs of the network, + and optionally its arguments. + Should be passed as `activation_name` or `(activation_name, arguments)`. If None, no activation will be used.\n + `activation_name` can be any value in {`celu`, `elu`, `gelu`, `leakyrelu`, `logsoftmax`, `mish`, `prelu`, + `relu`, `relu6`, `selu`, `sigmoid`, `softmax`, `tanh`}. Please refer to PyTorch's [activation functions] + (https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) to know the optional + arguments for each of them. 
+ pretrained : bool (optional, default=False) + whether to use pretrained weights. The pretrained weights used are the default ones from [torchvision](https:// + pytorch.org/vision/main/models/vision_transformer.html). + + Returns + ------- + ViT + The network, with potentially pretrained weights. + """ + name = SOTAViT(name) + if name == SOTAViT.B_16: + in_shape = (3, 224, 224) + patch_size = 16 + embedding_dim = 768 + mlp_dim = 3072 + num_layers = 12 + num_heads = 12 + model_url = ViT_B_16_Weights.DEFAULT.url + elif name == SOTAViT.B_32: + in_shape = (3, 224, 224) + patch_size = 32 + embedding_dim = 768 + mlp_dim = 3072 + num_layers = 12 + num_heads = 12 + model_url = ViT_B_32_Weights.DEFAULT.url + elif name == SOTAViT.L_16: + in_shape = (3, 224, 224) + patch_size = 16 + embedding_dim = 1024 + mlp_dim = 4096 + num_layers = 24 + num_heads = 16 + model_url = ViT_L_16_Weights.DEFAULT.url + elif name == SOTAViT.L_32: + in_shape = (3, 224, 224) + patch_size = 32 + embedding_dim = 1024 + mlp_dim = 4096 + num_layers = 24 + num_heads = 16 + model_url = ViT_L_32_Weights.DEFAULT.url + + # pylint: disable=possibly-used-before-assignment + vit = ViT( + in_shape=in_shape, + patch_size=patch_size, + num_outputs=num_outputs, + embedding_dim=embedding_dim, + mlp_dim=mlp_dim, + num_heads=num_heads, + num_layers=num_layers, + output_act=output_act, + ) + + if pretrained: + pretrained_dict = load_state_dict_from_url(model_url, progress=True) + if num_outputs is None: + del pretrained_dict["class_token"] + pretrained_dict["encoder.pos_embedding"] = pretrained_dict[ + "encoder.pos_embedding" + ][:, 1:] # remove class token position embedding + fc_layers = deepcopy(vit.fc) + vit.fc = None + vit.load_state_dict(_state_dict_adapter(pretrained_dict)) + vit.fc = fc_layers + + return vit + + +def _state_dict_adapter(state_dict: Mapping[str, Any]) -> Mapping[str, Any]: + """ + A mapping between torchvision's layer names and ours. 
+ """ + state_dict = {k: v for k, v in state_dict.items() if "heads" not in k} + + mappings = [ + ("ln_", "norm"), + ("ln", "norm"), + (r"encoder_layer_(\d+)", r"\1"), + ] + + for key in list(state_dict.keys()): + new_key = key + for transform in mappings: + new_key = re.sub(transform[0], transform[1], new_key) + state_dict[new_key] = state_dict.pop(key) + + return state_dict diff --git a/clinicadl/network/__init__.py b/clinicadl/networks/old_network/__init__.py similarity index 100% rename from clinicadl/network/__init__.py rename to clinicadl/networks/old_network/__init__.py diff --git a/clinicadl/nn/__init__.py b/clinicadl/networks/old_network/autoencoder/__init__.py similarity index 100% rename from clinicadl/nn/__init__.py rename to clinicadl/networks/old_network/autoencoder/__init__.py diff --git a/clinicadl/network/autoencoder/cnn_transformer.py b/clinicadl/networks/old_network/autoencoder/cnn_transformer.py similarity index 98% rename from clinicadl/network/autoencoder/cnn_transformer.py rename to clinicadl/networks/old_network/autoencoder/cnn_transformer.py index 39270d2a1..d60789879 100644 --- a/clinicadl/network/autoencoder/cnn_transformer.py +++ b/clinicadl/networks/old_network/autoencoder/cnn_transformer.py @@ -2,7 +2,7 @@ from torch import nn -from clinicadl.network.network_utils import ( +from clinicadl.networks.old_network.network_utils import ( CropMaxUnpool2d, CropMaxUnpool3d, PadMaxPool2d, diff --git a/clinicadl/network/autoencoder/models.py b/clinicadl/networks/old_network/autoencoder/models.py similarity index 89% rename from clinicadl/network/autoencoder/models.py rename to clinicadl/networks/old_network/autoencoder/models.py index 8ac40191e..ba0e3928a 100644 --- a/clinicadl/network/autoencoder/models.py +++ b/clinicadl/networks/old_network/autoencoder/models.py @@ -1,9 +1,9 @@ from torch import nn -from clinicadl.network.autoencoder.cnn_transformer import CNN_Transformer -from clinicadl.network.cnn.models import Conv4_FC3, Conv5_FC3 -from 
clinicadl.network.sub_network import AutoEncoder -from clinicadl.network.vae.vae_layers import ( +from clinicadl.networks.old_network.autoencoder.cnn_transformer import CNN_Transformer +from clinicadl.networks.old_network.cnn.models import Conv4_FC3, Conv5_FC3 +from clinicadl.networks.old_network.sub_network import AutoEncoder +from clinicadl.networks.old_network.vae.vae_layers import ( DecoderLayer3D, EncoderLayer3D, Flatten, diff --git a/clinicadl/network/cnn/SECNN.py b/clinicadl/networks/old_network/cnn/SECNN.py similarity index 100% rename from clinicadl/network/cnn/SECNN.py rename to clinicadl/networks/old_network/cnn/SECNN.py diff --git a/clinicadl/predict/__init__.py b/clinicadl/networks/old_network/cnn/__init__.py similarity index 100% rename from clinicadl/predict/__init__.py rename to clinicadl/networks/old_network/cnn/__init__.py diff --git a/clinicadl/network/cnn/models.py b/clinicadl/networks/old_network/cnn/models.py similarity index 97% rename from clinicadl/network/cnn/models.py rename to clinicadl/networks/old_network/cnn/models.py index 87d5e3ce5..af03969e4 100644 --- a/clinicadl/network/cnn/models.py +++ b/clinicadl/networks/old_network/cnn/models.py @@ -4,11 +4,11 @@ from torch import nn from torchvision.models.resnet import BasicBlock -from clinicadl.network.cnn.resnet import ResNetDesigner, model_urls -from clinicadl.network.cnn.resnet3D import ResNetDesigner3D -from clinicadl.network.cnn.SECNN import SECNNDesigner3D -from clinicadl.network.network_utils import PadMaxPool2d, PadMaxPool3d -from clinicadl.network.sub_network import CNN, CNN_SSDA +from clinicadl.networks.old_network.cnn.resnet import ResNetDesigner, model_urls +from clinicadl.networks.old_network.cnn.resnet3D import ResNetDesigner3D +from clinicadl.networks.old_network.cnn.SECNN import SECNNDesigner3D +from clinicadl.networks.old_network.network_utils import PadMaxPool2d, PadMaxPool3d +from clinicadl.networks.old_network.sub_network import CNN, CNN_SSDA def 
get_layers_fn(input_size): diff --git a/clinicadl/network/cnn/random.py b/clinicadl/networks/old_network/cnn/random.py similarity index 98% rename from clinicadl/network/cnn/random.py rename to clinicadl/networks/old_network/cnn/random.py index 897a014d1..221fee3f5 100644 --- a/clinicadl/network/cnn/random.py +++ b/clinicadl/networks/old_network/cnn/random.py @@ -1,7 +1,7 @@ import numpy as np -from clinicadl.network.network_utils import * -from clinicadl.network.sub_network import CNN +from clinicadl.networks.old_network.network_utils import * +from clinicadl.networks.old_network.sub_network import CNN from clinicadl.utils.exceptions import ClinicaDLNetworksError diff --git a/clinicadl/network/cnn/resnet.py b/clinicadl/networks/old_network/cnn/resnet.py similarity index 100% rename from clinicadl/network/cnn/resnet.py rename to clinicadl/networks/old_network/cnn/resnet.py diff --git a/clinicadl/network/cnn/resnet3D.py b/clinicadl/networks/old_network/cnn/resnet3D.py similarity index 100% rename from clinicadl/network/cnn/resnet3D.py rename to clinicadl/networks/old_network/cnn/resnet3D.py diff --git a/clinicadl/network/config.py b/clinicadl/networks/old_network/config.py similarity index 100% rename from clinicadl/network/config.py rename to clinicadl/networks/old_network/config.py diff --git a/clinicadl/network/network.py b/clinicadl/networks/old_network/network.py similarity index 100% rename from clinicadl/network/network.py rename to clinicadl/networks/old_network/network.py diff --git a/clinicadl/network/network_utils.py b/clinicadl/networks/old_network/network_utils.py similarity index 100% rename from clinicadl/network/network_utils.py rename to clinicadl/networks/old_network/network_utils.py diff --git a/clinicadl/prepare_data/__init__.py b/clinicadl/networks/old_network/nn/__init__.py similarity index 100% rename from clinicadl/prepare_data/__init__.py rename to clinicadl/networks/old_network/nn/__init__.py diff --git a/clinicadl/nn/blocks/__init__.py 
b/clinicadl/networks/old_network/nn/blocks/__init__.py similarity index 100% rename from clinicadl/nn/blocks/__init__.py rename to clinicadl/networks/old_network/nn/blocks/__init__.py diff --git a/clinicadl/nn/blocks/decoder.py b/clinicadl/networks/old_network/nn/blocks/decoder.py similarity index 98% rename from clinicadl/nn/blocks/decoder.py rename to clinicadl/networks/old_network/nn/blocks/decoder.py index 27938c8d7..06db04937 100644 --- a/clinicadl/nn/blocks/decoder.py +++ b/clinicadl/networks/old_network/nn/blocks/decoder.py @@ -1,7 +1,7 @@ import torch.nn as nn import torch.nn.functional as F -from clinicadl.nn.layers import Unflatten2D, get_norm_layer +from clinicadl.networks.old_network.nn.layers import Unflatten2D, get_norm_layer __all__ = [ "Decoder2D", diff --git a/clinicadl/nn/blocks/encoder.py b/clinicadl/networks/old_network/nn/blocks/encoder.py similarity index 98% rename from clinicadl/nn/blocks/encoder.py rename to clinicadl/networks/old_network/nn/blocks/encoder.py index fde13b956..290855dae 100644 --- a/clinicadl/nn/blocks/encoder.py +++ b/clinicadl/networks/old_network/nn/blocks/encoder.py @@ -1,7 +1,7 @@ import torch.nn as nn import torch.nn.functional as F -from clinicadl.nn.layers import get_norm_layer +from clinicadl.networks.old_network.nn.layers import get_norm_layer __all__ = [ "Encoder2D", diff --git a/clinicadl/nn/blocks/residual.py b/clinicadl/networks/old_network/nn/blocks/residual.py similarity index 100% rename from clinicadl/nn/blocks/residual.py rename to clinicadl/networks/old_network/nn/blocks/residual.py diff --git a/clinicadl/nn/blocks/se.py b/clinicadl/networks/old_network/nn/blocks/se.py similarity index 100% rename from clinicadl/nn/blocks/se.py rename to clinicadl/networks/old_network/nn/blocks/se.py diff --git a/clinicadl/nn/blocks/unet.py b/clinicadl/networks/old_network/nn/blocks/unet.py similarity index 100% rename from clinicadl/nn/blocks/unet.py rename to clinicadl/networks/old_network/nn/blocks/unet.py diff --git 
a/clinicadl/nn/layers/__init__.py b/clinicadl/networks/old_network/nn/layers/__init__.py similarity index 100% rename from clinicadl/nn/layers/__init__.py rename to clinicadl/networks/old_network/nn/layers/__init__.py diff --git a/clinicadl/nn/layers/factory/__init__.py b/clinicadl/networks/old_network/nn/layers/factory/__init__.py similarity index 100% rename from clinicadl/nn/layers/factory/__init__.py rename to clinicadl/networks/old_network/nn/layers/factory/__init__.py diff --git a/clinicadl/nn/layers/factory/conv.py b/clinicadl/networks/old_network/nn/layers/factory/conv.py similarity index 100% rename from clinicadl/nn/layers/factory/conv.py rename to clinicadl/networks/old_network/nn/layers/factory/conv.py diff --git a/clinicadl/nn/layers/factory/norm.py b/clinicadl/networks/old_network/nn/layers/factory/norm.py similarity index 100% rename from clinicadl/nn/layers/factory/norm.py rename to clinicadl/networks/old_network/nn/layers/factory/norm.py diff --git a/clinicadl/nn/layers/factory/pool.py b/clinicadl/networks/old_network/nn/layers/factory/pool.py similarity index 100% rename from clinicadl/nn/layers/factory/pool.py rename to clinicadl/networks/old_network/nn/layers/factory/pool.py diff --git a/clinicadl/nn/layers/pool.py b/clinicadl/networks/old_network/nn/layers/pool.py similarity index 100% rename from clinicadl/nn/layers/pool.py rename to clinicadl/networks/old_network/nn/layers/pool.py diff --git a/clinicadl/nn/layers/reverse.py b/clinicadl/networks/old_network/nn/layers/reverse.py similarity index 100% rename from clinicadl/nn/layers/reverse.py rename to clinicadl/networks/old_network/nn/layers/reverse.py diff --git a/clinicadl/nn/layers/unflatten.py b/clinicadl/networks/old_network/nn/layers/unflatten.py similarity index 100% rename from clinicadl/nn/layers/unflatten.py rename to clinicadl/networks/old_network/nn/layers/unflatten.py diff --git a/clinicadl/nn/layers/unpool.py b/clinicadl/networks/old_network/nn/layers/unpool.py similarity index 
100% rename from clinicadl/nn/layers/unpool.py rename to clinicadl/networks/old_network/nn/layers/unpool.py diff --git a/clinicadl/nn/networks/__init__.py b/clinicadl/networks/old_network/nn/networks/__init__.py similarity index 92% rename from clinicadl/nn/networks/__init__.py rename to clinicadl/networks/old_network/nn/networks/__init__.py index c77097e60..3b88830fb 100644 --- a/clinicadl/nn/networks/__init__.py +++ b/clinicadl/networks/old_network/nn/networks/__init__.py @@ -8,7 +8,6 @@ resnet18, ) from .random import RandomArchitecture -from .ssda import Conv5_FC3_SSDA from .unet import UNet from .vae import ( CVAE_3D, diff --git a/clinicadl/nn/networks/ae.py b/clinicadl/networks/old_network/nn/networks/ae.py similarity index 92% rename from clinicadl/nn/networks/ae.py rename to clinicadl/networks/old_network/nn/networks/ae.py index 1a8ed283f..aabe9b15a 100644 --- a/clinicadl/nn/networks/ae.py +++ b/clinicadl/networks/old_network/nn/networks/ae.py @@ -1,17 +1,17 @@ import numpy as np from torch import nn -from clinicadl.nn.blocks import Decoder3D, Encoder3D -from clinicadl.nn.layers import ( +from clinicadl.networks.old_network.nn.blocks import Decoder3D, Encoder3D +from clinicadl.networks.old_network.nn.layers import ( CropMaxUnpool2d, CropMaxUnpool3d, PadMaxPool2d, PadMaxPool3d, Unflatten3D, ) -from clinicadl.nn.networks.cnn import Conv4_FC3, Conv5_FC3 -from clinicadl.nn.networks.factory import autoencoder_from_cnn -from clinicadl.nn.utils import compute_output_size +from clinicadl.networks.old_network.nn.networks.cnn import Conv4_FC3, Conv5_FC3 +from clinicadl.networks.old_network.nn.networks.factory import autoencoder_from_cnn +from clinicadl.networks.old_network.nn.utils import compute_output_size from clinicadl.utils.enum import BaseEnum diff --git a/clinicadl/nn/networks/cnn.py b/clinicadl/networks/old_network/nn/networks/cnn.py similarity index 98% rename from clinicadl/nn/networks/cnn.py rename to clinicadl/networks/old_network/nn/networks/cnn.py index 
eb2104b1e..cfdf610d7 100644 --- a/clinicadl/nn/networks/cnn.py +++ b/clinicadl/networks/old_network/nn/networks/cnn.py @@ -4,7 +4,7 @@ from torch import nn from torchvision.models.resnet import BasicBlock -from clinicadl.nn.layers.factory import ( +from clinicadl.networks.old_network.nn.layers.factory import ( get_conv_layer, get_norm_layer, get_pool_layer, @@ -63,8 +63,6 @@ def __init__(self, convolution_layers: nn.Module, fc_layers: nn.Module) -> None: def forward(self, x): inter = self.convolutions(x) - print(self.convolutions) - print(inter.shape) return self.fc(inter) diff --git a/clinicadl/nn/networks/factory/__init__.py b/clinicadl/networks/old_network/nn/networks/factory/__init__.py similarity index 100% rename from clinicadl/nn/networks/factory/__init__.py rename to clinicadl/networks/old_network/nn/networks/factory/__init__.py diff --git a/clinicadl/nn/networks/factory/ae.py b/clinicadl/networks/old_network/nn/networks/factory/ae.py similarity index 97% rename from clinicadl/nn/networks/factory/ae.py rename to clinicadl/networks/old_network/nn/networks/factory/ae.py index fccb14484..99dcd162e 100644 --- a/clinicadl/nn/networks/factory/ae.py +++ b/clinicadl/networks/old_network/nn/networks/factory/ae.py @@ -5,7 +5,7 @@ from torch import nn -from clinicadl.nn.layers import ( +from clinicadl.networks.old_network.nn.layers import ( CropMaxUnpool2d, CropMaxUnpool3d, PadMaxPool2d, @@ -13,7 +13,7 @@ ) if TYPE_CHECKING: - from clinicadl.nn.networks.cnn import CNN + from clinicadl.networks.old_network.nn.networks.cnn import CNN def autoencoder_from_cnn(model: CNN) -> Tuple[nn.Module, nn.Module]: diff --git a/clinicadl/nn/networks/factory/resnet.py b/clinicadl/networks/old_network/nn/networks/factory/resnet.py similarity index 98% rename from clinicadl/nn/networks/factory/resnet.py rename to clinicadl/networks/old_network/nn/networks/factory/resnet.py index 251199c92..0500c9ece 100644 --- a/clinicadl/nn/networks/factory/resnet.py +++ 
b/clinicadl/networks/old_network/nn/networks/factory/resnet.py @@ -3,7 +3,7 @@ import torch from torch import nn -from clinicadl.nn.blocks import ResBlock +from clinicadl.networks.old_network.nn.blocks import ResBlock model_urls = {"resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth"} diff --git a/clinicadl/nn/networks/factory/secnn.py b/clinicadl/networks/old_network/nn/networks/factory/secnn.py similarity index 96% rename from clinicadl/nn/networks/factory/secnn.py rename to clinicadl/networks/old_network/nn/networks/factory/secnn.py index 270f0a357..f12e6de1a 100644 --- a/clinicadl/nn/networks/factory/secnn.py +++ b/clinicadl/networks/old_network/nn/networks/factory/secnn.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from clinicadl.nn.blocks import ResBlock_SE +from clinicadl.networks.old_network.nn.blocks import ResBlock_SE class SECNNDesigner3D(nn.Module): diff --git a/clinicadl/nn/networks/random.py b/clinicadl/networks/old_network/nn/networks/random.py similarity index 98% rename from clinicadl/nn/networks/random.py rename to clinicadl/networks/old_network/nn/networks/random.py index 50b18dd60..4122cdf0b 100644 --- a/clinicadl/nn/networks/random.py +++ b/clinicadl/networks/old_network/nn/networks/random.py @@ -1,8 +1,8 @@ import numpy as np import torch.nn as nn -from clinicadl.nn.layers import PadMaxPool2d, PadMaxPool3d -from clinicadl.nn.networks.cnn import CNN +from clinicadl.networks.old_network.nn.layers import PadMaxPool2d, PadMaxPool3d +from clinicadl.networks.old_network.nn.networks.cnn import CNN from clinicadl.utils.exceptions import ClinicaDLNetworksError diff --git a/clinicadl/nn/networks/ssda.py b/clinicadl/networks/old_network/nn/networks/ssda.py similarity index 98% rename from clinicadl/nn/networks/ssda.py rename to clinicadl/networks/old_network/nn/networks/ssda.py index a87cb33b5..a774f5c86 100644 --- a/clinicadl/nn/networks/ssda.py +++ b/clinicadl/networks/old_network/nn/networks/ssda.py @@ -2,7 +2,7 @@ import 
torch import torch.nn as nn -from clinicadl.nn.layers import ( +from clinicadl.networks.old_network.nn.layers import ( GradientReversal, get_conv_layer, get_norm_layer, diff --git a/clinicadl/nn/networks/unet.py b/clinicadl/networks/old_network/nn/networks/unet.py similarity index 90% rename from clinicadl/nn/networks/unet.py rename to clinicadl/networks/old_network/nn/networks/unet.py index 45850de29..36450b01f 100644 --- a/clinicadl/nn/networks/unet.py +++ b/clinicadl/networks/old_network/nn/networks/unet.py @@ -1,6 +1,6 @@ from torch import nn -from clinicadl.nn.blocks import UNetDown, UNetFinalLayer, UNetUp +from clinicadl.networks.old_network.nn.blocks import UNetDown, UNetFinalLayer, UNetUp class UNet(nn.Module): diff --git a/clinicadl/nn/networks/vae.py b/clinicadl/networks/old_network/nn/networks/vae.py similarity index 99% rename from clinicadl/nn/networks/vae.py rename to clinicadl/networks/old_network/nn/networks/vae.py index 9e9b3e72f..fe7564ef9 100644 --- a/clinicadl/nn/networks/vae.py +++ b/clinicadl/networks/old_network/nn/networks/vae.py @@ -1,14 +1,14 @@ import torch import torch.nn as nn -from clinicadl.nn.blocks import ( +from clinicadl.networks.old_network.nn.blocks import ( Decoder3D, Encoder3D, VAE_Decoder2D, VAE_Encoder2D, ) -from clinicadl.nn.layers import Unflatten3D -from clinicadl.nn.utils import multiply_list +from clinicadl.networks.old_network.nn.layers import Unflatten3D +from clinicadl.networks.old_network.nn.utils import multiply_list from clinicadl.utils.enum import BaseEnum diff --git a/clinicadl/nn/utils.py b/clinicadl/networks/old_network/nn/utils.py similarity index 98% rename from clinicadl/nn/utils.py rename to clinicadl/networks/old_network/nn/utils.py index dc3afd71c..263afc407 100644 --- a/clinicadl/nn/utils.py +++ b/clinicadl/networks/old_network/nn/utils.py @@ -64,7 +64,6 @@ def compute_output_size( input_ = torch.randn(input_size).unsqueeze(0) if isinstance(layer, nn.MaxUnpool3d) or isinstance(layer, nn.MaxUnpool2d): 
indices = torch.zeros_like(input_, dtype=int) - print(indices) output = layer(input_, indices) else: output = layer(input_) diff --git a/clinicadl/network/sub_network.py b/clinicadl/networks/old_network/sub_network.py similarity index 98% rename from clinicadl/network/sub_network.py rename to clinicadl/networks/old_network/sub_network.py index 9d17e8600..e3feb1347 100644 --- a/clinicadl/network/sub_network.py +++ b/clinicadl/networks/old_network/sub_network.py @@ -4,8 +4,8 @@ import torch from torch import nn -from clinicadl.network.network import Network -from clinicadl.network.network_utils import ( +from clinicadl.networks.old_network.network import Network +from clinicadl.networks.old_network.network_utils import ( CropMaxUnpool2d, CropMaxUnpool3d, PadMaxPool2d, diff --git a/clinicadl/networks/old_network/unet/__init__.py b/clinicadl/networks/old_network/unet/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/clinicadl/network/unet/unet.py b/clinicadl/networks/old_network/unet/unet.py similarity index 98% rename from clinicadl/network/unet/unet.py rename to clinicadl/networks/old_network/unet/unet.py index 3743f13d8..f23729def 100644 --- a/clinicadl/network/unet/unet.py +++ b/clinicadl/networks/old_network/unet/unet.py @@ -1,7 +1,7 @@ import torch from torch import nn -from clinicadl.network.network import Network +from clinicadl.networks.old_network.network import Network class UNetDown(nn.Module): diff --git a/clinicadl/networks/old_network/vae/__init__.py b/clinicadl/networks/old_network/vae/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/clinicadl/network/vae/advanced_CVAE.py b/clinicadl/networks/old_network/vae/advanced_CVAE.py similarity index 97% rename from clinicadl/network/vae/advanced_CVAE.py rename to clinicadl/networks/old_network/vae/advanced_CVAE.py index 2da43916e..d174df157 100644 --- a/clinicadl/network/vae/advanced_CVAE.py +++ b/clinicadl/networks/old_network/vae/advanced_CVAE.py @@ -2,8 +2,8 @@ 
import torch.nn as nn import torch.nn.functional as F -from clinicadl.network.network import Network -from clinicadl.network.vae.vae_utils import multiply_list +from clinicadl.networks.old_network.network import Network +from clinicadl.networks.old_network.vae.vae_utils import multiply_list class CVAE_3D_final_conv(Network): diff --git a/clinicadl/network/vae/base_vae.py b/clinicadl/networks/old_network/vae/base_vae.py similarity index 97% rename from clinicadl/network/vae/base_vae.py rename to clinicadl/networks/old_network/vae/base_vae.py index b9ccb4808..c8d2cbef2 100644 --- a/clinicadl/network/vae/base_vae.py +++ b/clinicadl/networks/old_network/vae/base_vae.py @@ -1,6 +1,6 @@ import torch -from clinicadl.network.network import Network +from clinicadl.networks.old_network.network import Network class BaseVAE(Network): diff --git a/clinicadl/network/vae/convolutional_VAE.py b/clinicadl/networks/old_network/vae/convolutional_VAE.py similarity index 98% rename from clinicadl/network/vae/convolutional_VAE.py rename to clinicadl/networks/old_network/vae/convolutional_VAE.py index ab29c842e..5021b826d 100644 --- a/clinicadl/network/vae/convolutional_VAE.py +++ b/clinicadl/networks/old_network/vae/convolutional_VAE.py @@ -4,8 +4,8 @@ import torch.nn as nn import torch.nn.functional as F -from clinicadl.network.network import Network -from clinicadl.network.vae.vae_utils import multiply_list +from clinicadl.networks.old_network.network import Network +from clinicadl.networks.old_network.vae.vae_utils import multiply_list class CVAE_3D(Network): diff --git a/clinicadl/network/vae/vae_layers.py b/clinicadl/networks/old_network/vae/vae_layers.py similarity index 99% rename from clinicadl/network/vae/vae_layers.py rename to clinicadl/networks/old_network/vae/vae_layers.py index dfa9f0e15..a84067b99 100644 --- a/clinicadl/network/vae/vae_layers.py +++ b/clinicadl/networks/old_network/vae/vae_layers.py @@ -1,7 +1,7 @@ import torch.nn.functional as F from torch import nn 
-from clinicadl.network.vae.vae_utils import get_norm2d, get_norm3d +from clinicadl.networks.old_network.vae.vae_utils import get_norm2d, get_norm3d class EncoderLayer2D(nn.Module): diff --git a/clinicadl/network/vae/vae_utils.py b/clinicadl/networks/old_network/vae/vae_utils.py similarity index 100% rename from clinicadl/network/vae/vae_utils.py rename to clinicadl/networks/old_network/vae/vae_utils.py diff --git a/clinicadl/network/vae/vanilla_vae.py b/clinicadl/networks/old_network/vae/vanilla_vae.py similarity index 99% rename from clinicadl/network/vae/vanilla_vae.py rename to clinicadl/networks/old_network/vae/vanilla_vae.py index 200db6cc1..a7494f385 100644 --- a/clinicadl/network/vae/vanilla_vae.py +++ b/clinicadl/networks/old_network/vae/vanilla_vae.py @@ -1,8 +1,8 @@ import torch from torch import nn -from clinicadl.network.vae.base_vae import BaseVAE -from clinicadl.network.vae.vae_layers import ( +from clinicadl.networks.old_network.vae.base_vae import BaseVAE +from clinicadl.networks.old_network.vae.vae_layers import ( DecoderLayer3D, EncoderLayer3D, Flatten, diff --git a/clinicadl/optim/__init__.py b/clinicadl/optimization/__init__.py similarity index 100% rename from clinicadl/optim/__init__.py rename to clinicadl/optimization/__init__.py diff --git a/clinicadl/optim/config.py b/clinicadl/optimization/config.py similarity index 100% rename from clinicadl/optim/config.py rename to clinicadl/optimization/config.py diff --git a/clinicadl/optim/early_stopping/__init__.py b/clinicadl/optimization/early_stopping/__init__.py similarity index 100% rename from clinicadl/optim/early_stopping/__init__.py rename to clinicadl/optimization/early_stopping/__init__.py diff --git a/clinicadl/optim/early_stopping/config.py b/clinicadl/optimization/early_stopping/config.py similarity index 100% rename from clinicadl/optim/early_stopping/config.py rename to clinicadl/optimization/early_stopping/config.py diff --git a/clinicadl/optim/early_stopping/early_stopper.py 
b/clinicadl/optimization/early_stopping/early_stopper.py similarity index 100% rename from clinicadl/optim/early_stopping/early_stopper.py rename to clinicadl/optimization/early_stopping/early_stopper.py diff --git a/clinicadl/optim/lr_scheduler/__init__.py b/clinicadl/optimization/lr_scheduler/__init__.py similarity index 100% rename from clinicadl/optim/lr_scheduler/__init__.py rename to clinicadl/optimization/lr_scheduler/__init__.py diff --git a/clinicadl/optim/lr_scheduler/config.py b/clinicadl/optimization/lr_scheduler/config.py similarity index 100% rename from clinicadl/optim/lr_scheduler/config.py rename to clinicadl/optimization/lr_scheduler/config.py diff --git a/clinicadl/optim/lr_scheduler/enum.py b/clinicadl/optimization/lr_scheduler/enum.py similarity index 100% rename from clinicadl/optim/lr_scheduler/enum.py rename to clinicadl/optimization/lr_scheduler/enum.py diff --git a/clinicadl/optim/lr_scheduler/factory.py b/clinicadl/optimization/lr_scheduler/factory.py similarity index 100% rename from clinicadl/optim/lr_scheduler/factory.py rename to clinicadl/optimization/lr_scheduler/factory.py diff --git a/clinicadl/optim/optimizer/__init__.py b/clinicadl/optimization/optimizer/__init__.py similarity index 100% rename from clinicadl/optim/optimizer/__init__.py rename to clinicadl/optimization/optimizer/__init__.py diff --git a/clinicadl/optim/optimizer/config.py b/clinicadl/optimization/optimizer/config.py similarity index 100% rename from clinicadl/optim/optimizer/config.py rename to clinicadl/optimization/optimizer/config.py diff --git a/clinicadl/optim/optimizer/enum.py b/clinicadl/optimization/optimizer/enum.py similarity index 100% rename from clinicadl/optim/optimizer/enum.py rename to clinicadl/optimization/optimizer/enum.py diff --git a/clinicadl/optim/optimizer/factory.py b/clinicadl/optimization/optimizer/factory.py similarity index 100% rename from clinicadl/optim/optimizer/factory.py rename to clinicadl/optimization/optimizer/factory.py 
diff --git a/clinicadl/optim/optimizer/utils.py b/clinicadl/optimization/optimizer/utils.py similarity index 100% rename from clinicadl/optim/optimizer/utils.py rename to clinicadl/optimization/optimizer/utils.py diff --git a/clinicadl/optimizer/optimization.py b/clinicadl/optimizer/optimization.py deleted file mode 100644 index eba352f2e..000000000 --- a/clinicadl/optimizer/optimization.py +++ /dev/null @@ -1,16 +0,0 @@ -from logging import getLogger - -from pydantic import BaseModel, ConfigDict -from pydantic.types import PositiveInt - -logger = getLogger("clinicadl.optimization_config") - - -class OptimizationConfig(BaseModel): - """Config class to configure the optimization process.""" - - accumulation_steps: PositiveInt = 1 - epochs: PositiveInt = 20 - profiler: bool = False - # pydantic config - model_config = ConfigDict(validate_assignment=True) diff --git a/clinicadl/optimizer/optimizer.py b/clinicadl/optimizer/optimizer.py deleted file mode 100644 index 2beb9b913..000000000 --- a/clinicadl/optimizer/optimizer.py +++ /dev/null @@ -1,18 +0,0 @@ -from logging import getLogger - -from pydantic import BaseModel, ConfigDict -from pydantic.types import NonNegativeFloat, PositiveFloat - -from clinicadl.utils.enum import Optimizer - -logger = getLogger("clinicadl.optimizer_config") - - -class OptimizerConfig(BaseModel): - """Config class to configure the optimizer.""" - - learning_rate: PositiveFloat = 1e-4 - optimizer: Optimizer = Optimizer.ADAM - weight_decay: NonNegativeFloat = 1e-4 - # pydantic config - model_config = ConfigDict(validate_assignment=True) diff --git a/clinicadl/predict/config.py b/clinicadl/predict/config.py deleted file mode 100644 index a96b4b104..000000000 --- a/clinicadl/predict/config.py +++ /dev/null @@ -1,40 +0,0 @@ -from logging import getLogger - -from clinicadl.caps_dataset.data_config import DataConfig as DataBaseConfig -from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig -from clinicadl.maps_manager.config import ( 
- MapsManagerConfig as MapsManagerBaseConfig, -) -from clinicadl.splitter.config import SplitConfig -from clinicadl.splitter.validation import ValidationConfig -from clinicadl.utils.computational.computational import ComputationalConfig -from clinicadl.utils.exceptions import ClinicaDLArgumentError # type: ignore - -logger = getLogger("clinicadl.predict_config") - - -class MapsManagerConfig(MapsManagerBaseConfig): - save_tensor: bool = False - save_latent_tensor: bool = False - - def check_output_saving_tensor(self, network_task: str) -> None: - # Check if task is reconstruction for "save_tensor" and "save_nifti" - if self.save_tensor and network_task != "reconstruction": - raise ClinicaDLArgumentError( - "Cannot save tensors if the network task is not reconstruction. Please remove --save_tensor option." - ) - - -class DataConfig(DataBaseConfig): - use_labels: bool = True - - -class PredictConfig( - MapsManagerConfig, - DataConfig, - ValidationConfig, - ComputationalConfig, - DataLoaderConfig, - SplitConfig, -): - """Config class to perform Transfer Learning.""" diff --git a/clinicadl/predictor/__init__.py b/clinicadl/predictor/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/clinicadl/predictor/config.py b/clinicadl/predictor/config.py new file mode 100644 index 000000000..6075890aa --- /dev/null +++ b/clinicadl/predictor/config.py @@ -0,0 +1,105 @@ +from logging import getLogger +from typing import Any, Dict + +from pydantic import BaseModel, ConfigDict, computed_field + +from clinicadl.dataset.data_config import DataConfig as DataBaseConfig +from clinicadl.dataset.dataloader_config import DataLoaderConfig +from clinicadl.experiment_manager.config import ( + MapsManagerConfig as MapsManagerBaseConfig, +) +from clinicadl.experiment_manager.maps_manager import MapsManager +from clinicadl.predictor.validation import ValidationConfig +from clinicadl.splitter.config import SplitConfig +from clinicadl.transforms.config import TransformsConfig 
+from clinicadl.utils.computational.computational import ComputationalConfig +from clinicadl.utils.enum import Task +from clinicadl.utils.exceptions import ClinicaDLArgumentError # type: ignore + +logger = getLogger("clinicadl.predict_config") + + +class MapsManagerConfig(MapsManagerBaseConfig): + save_tensor: bool = False + save_latent_tensor: bool = False + + def check_output_saving_tensor(self, network_task: str) -> None: + # Check if task is reconstruction for "save_tensor" and "save_nifti" + if self.save_tensor and network_task != "reconstruction": + raise ClinicaDLArgumentError( + "Cannot save tensors if the network task is not reconstruction. Please remove --save_tensor option." + ) + + +class DataConfig(DataBaseConfig): + use_labels: bool = True + + +class PredictConfig(BaseModel): + """Config class to perform Transfer Learning.""" + + maps_manager: MapsManagerConfig + data: DataConfig + validation: ValidationConfig + computational: ComputationalConfig + dataloader: DataLoaderConfig + split: SplitConfig + transforms: TransformsConfig + + model_config = ConfigDict(validate_assignment=True) + + def __init__(self, **kwargs): + super().__init__( + maps_manager=kwargs, + computational=kwargs, + dataloader=kwargs, + data=kwargs, + split=kwargs, + validation=kwargs, + transforms=kwargs, + ) + + def _update(self, config_dict: Dict[str, Any]) -> None: + """Updates the configs with a dict given by the user.""" + self.data.__dict__.update(config_dict) + self.split.__dict__.update(config_dict) + self.validation.__dict__.update(config_dict) + self.maps_manager.__dict__.update(config_dict) + self.split.__dict__.update(config_dict) + self.computational.__dict__.update(config_dict) + self.dataloader.__dict__.update(config_dict) + self.transforms.__dict__.update(config_dict) + + def adapt_with_maps_manager_info(self, maps_manager: MapsManager): + self.maps_manager.check_output_saving_nifti(maps_manager.network_task) + self.data.diagnoses = ( + maps_manager.diagnoses + if 
self.data.diagnoses is None or len(self.data.diagnoses) == 0 + else self.data.diagnoses + ) + + self.dataloader.batch_size = ( + maps_manager.batch_size + if not self.dataloader.batch_size + else self.dataloader.batch_size + ) + self.dataloader.n_proc = ( + maps_manager.n_proc + if not self.dataloader.n_proc + else self.dataloader.n_proc + ) + + self.split.adapt_cross_val_with_maps_manager_info(maps_manager) + self.maps_manager.check_output_saving_tensor(maps_manager.network_task) + + self.transforms = TransformsConfig( + normalize=maps_manager.normalize, + data_augmentation=maps_manager.data_augmentation, + size_reduction=maps_manager.size_reduction, + size_reduction_factor=maps_manager.size_reduction_factor, + ) + + if self.split.split is None and self.split.n_splits == 0: + from clinicadl.splitter.split_utils import find_splits + + self.split.split = find_splits(self.maps_manager.maps_dir) diff --git a/clinicadl/predict/predict_manager.py b/clinicadl/predictor/old_predictor.py similarity index 54% rename from clinicadl/predict/predict_manager.py rename to clinicadl/predictor/old_predictor.py index 55515dc8e..e4d727190 100644 --- a/clinicadl/predict/predict_manager.py +++ b/clinicadl/predictor/old_predictor.py @@ -2,328 +2,138 @@ import shutil from logging import getLogger from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import pandas as pd import torch import torch.distributed as dist from torch.amp import autocast +from torch.nn.modules.loss import _Loss from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from clinicadl.caps_dataset.data import ( +from clinicadl.dataset.caps_dataset import ( return_dataset, ) +from clinicadl.experiment_manager.maps_manager import MapsManager from clinicadl.interpret.config import InterpretConfig -from clinicadl.maps_manager.maps_manager import MapsManager -from 
clinicadl.metrics.utils import ( + +from clinicadl.metrics.old_metrics.metric_module import MetricModule +from clinicadl.metrics.old_metrics.utils import ( check_selection_metric, find_selection_metrics, ) -from clinicadl.predict.config import PredictConfig -from clinicadl.trainer.tasks_utils import generate_label_code, get_criterion +from clinicadl.networks.old_network.network import Network + +from clinicadl.predictor.config import PredictConfig +from clinicadl.trainer.tasks_utils import ( + columns, + compute_metrics, + generate_label_code, + generate_test_row, + get_criterion, +) from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.computational.ddp import DDP, cluster +from clinicadl.utils.enum import Task from clinicadl.utils.exceptions import ( ClinicaDLArgumentError, ClinicaDLDataLeakageError, MAPSError, ) -from clinicadl.validator.validator import Validator logger = getLogger("clinicadl.predict_manager") level_list: List[str] = ["warning", "info", "debug"] -class PredictManager: +class Predictor: def __init__(self, _config: Union[PredictConfig, InterpretConfig]) -> None: - self.maps_manager = MapsManager(_config.maps_dir) self._config = _config - self.validator = Validator() + + + self.maps_manager = MapsManager(_config.maps_manager.maps_dir) + self._config.adapt_with_maps_manager_info(self.maps_manager) + tmp = self._config.data.model_dump( + exclude=set(["preprocessing_dict", "mode", "caps_dict"]) + ) + tmp.update(self._config.split.model_dump()) + tmp.update(self._config.validation.model_dump()) + self.splitter = Splitter(SplitterConfig(**tmp)) def predict( self, label_code: Union[str, dict[str, int]] = "default", ): - """Performs the prediction task on a subset of caps_directory defined in a TSV file. - Parameters - ---------- - data_group : str - name of the data group tested. - caps_directory : Path (optional, default=None) - path to the CAPS folder. 
For more information please refer to - [clinica documentation](https://aramislab.paris.inria.fr/clinica/docs/public/latest/CAPS/Introduction/). - Default will load the value of an existing data group - tsv_path : Path (optional, default=None) - path to a TSV file containing the list of participants and sessions to test. - Default will load the DataFrame of an existing data group - split_list : List[int] (optional, default=None) - list of splits to test. Default perform prediction on all splits available. - selection_metrics : List[str] (optional, default=None) - list of selection metrics to test. - Default performs the prediction on all selection metrics available. - multi_cohort : bool (optional, default=False) - If True considers that tsv_path is the path to a multi-cohort TSV. - diagnoses : List[str] (optional, default=()) - List of diagnoses to load if tsv_path is a split_directory. - Default uses the same as in training step. - use_labels : bool (optional, default=True) - If True, the labels must exist in test meta-data and metrics are computed. - batch_size : int (optional, default=None) - If given, sets the value of batch_size, else use the same as in training step. - n_proc : int (optional, default=None) - If given, sets the value of num_workers, else use the same as in training step. - gpu : bool (optional, default=None) - If given, a new value for the device of the model will be computed. - amp : bool (optional, default=False) - If enabled, uses Automatic Mixed Precision (requires GPU usage). - overwrite : bool (optional, default=False) - If True erase the occurrences of data_group. - label : str (optional, default=None) - Target label used for training (if network_task in [`regression`, `classification`]). - label_code : Optional[Dict[str, int]] (optional, default="default") - dictionary linking the target values to a node number. 
- save_tensor : bool (optional, default=False) - If true, save the tensor predicted for reconstruction task - save_nifti : bool (optional, default=False) - If true, save the nifti associated to the prediction for reconstruction task. - save_latent_tensor : bool (optional, default=False) - If true, save the tensor from the latent space for reconstruction task. - skip_leak_check : bool (optional, default=False) - If true, skip the leak check (not recommended). - Examples - -------- - >>> _input_ - _output_ - """ + """Performs the prediction task on a subset of caps_directory defined in a TSV file.""" - assert isinstance(self._config, PredictConfig) - - self._config.check_output_saving_nifti(self.maps_manager.network_task) - self._config.diagnoses = ( - self.maps_manager.diagnoses - if self._config.diagnoses is None or len(self._config.diagnoses) == 0 - else self._config.diagnoses - ) - - self._config.batch_size = ( - self.maps_manager.batch_size - if not self._config.batch_size - else self._config.batch_size - ) - self._config.n_proc = ( - self.maps_manager.n_proc if not self._config.n_proc else self._config.n_proc - ) - - self._config.adapt_cross_val_with_maps_manager_info(self.maps_manager) - self._config.check_output_saving_tensor(self.maps_manager.network_task) - - transforms = TransformsConfig( - normalize=self.maps_manager.normalize, - data_augmentation=self.maps_manager.data_augmentation, - size_reduction=self.maps_manager.size_reduction, - size_reduction_factor=self.maps_manager.size_reduction_factor, - ) - group_df = self._config.create_groupe_df() + group_df = self._config.data.create_groupe_df() self._check_data_group(group_df) criterion = get_criterion( self.maps_manager.network_task, self.maps_manager.loss ) - self._check_data_group(df=group_df) - assert self._config.split # don't know if needed ? try to raise an exception ? 
- # assert self._config.label - - for split in self._config.split: + for split in self.splitter.split_iterator(): logger.info(f"Prediction of split {split}") group_df, group_parameters = self.get_group_info( - self._config.data_group, split + self._config.maps_manager.data_group, split ) # Find label code if not given - if self._config.is_given_label_code(self.maps_manager.label, label_code): + if self._config.data.is_given_label_code( + self.maps_manager.label, label_code + ): generate_label_code( - self.maps_manager.network_task, group_df, self._config.label + self.maps_manager.network_task, group_df, self._config.data.label ) # Erase previous TSV files on master process - if not self._config.selection_metrics: + if not self._config.validation.selection_metrics: split_selection_metrics = find_selection_metrics( self.maps_manager.maps_path, split, ) else: - split_selection_metrics = self._config.selection_metrics + split_selection_metrics = self._config.validation.selection_metrics for selection in split_selection_metrics: tsv_dir = ( self.maps_manager.maps_path / f"split-{split}" / f"best-{selection}" - / self._config.data_group + / self._config.maps_manager.data_group ) - tsv_pattern = f"{self._config.data_group}*.tsv" + tsv_pattern = f"{self._config.maps_manager.data_group}*.tsv" for tsv_file in tsv_dir.glob(tsv_pattern): tsv_file.unlink() - self._config.check_label(self.maps_manager.label) + + self._config.data.check_label(self.maps_manager.label) if self.maps_manager.multi_network: - self._predict_multi( - group_parameters, - group_df, - transforms, - label_code, - criterion, - split, - split_selection_metrics, - ) + for network in range(self.maps_manager.num_networks): + self._predict_single( + group_parameters, + group_df, + self._config.transforms, + label_code, + criterion, + split, + split_selection_metrics, + network, + ) else: self._predict_single( group_parameters, group_df, - transforms, + self._config.transforms, label_code, criterion, split, 
split_selection_metrics, ) if cluster.master: - self.validator._ensemble_prediction( + self._ensemble_prediction( self.maps_manager, - self._config.data_group, + self._config.maps_manager.data_group, split, - self._config.selection_metrics, - self._config.use_labels, - self._config.skip_leak_check, - ) - - def _predict_multi( - self, - group_parameters, - group_df, - transforms, - label_code, - criterion, - split, - split_selection_metrics, - ): - """_summary_ - Parameters - ---------- - group_parameters : _type_ - _description_ - group_df : _type_ - _description_ - all_transforms : _type_ - _description_ - use_labels : _type_ - _description_ - label : _type_ - _description_ - label_code : _type_ - _description_ - batch_size : _type_ - _description_ - n_proc : _type_ - _description_ - criterion : _type_ - _description_ - data_group : _type_ - _description_ - split : _type_ - _description_ - split_selection_metrics : _type_ - _description_ - gpu : _type_ - _description_ - amp : _type_ - _description_ - save_tensor : _type_ - _description_ - save_latent_tensor : _type_ - _description_ - save_nifti : _type_ - _description_ - selection_metrics : _type_ - _description_ - Examples - -------- - >>> _input_ - _output_ - Notes - ----- - _notes_ - See Also - -------- - - _related_ - """ - assert isinstance(self._config, PredictConfig) - # assert self._config.label - - for network in range(self.maps_manager.num_networks): - data_test = return_dataset( - group_parameters["caps_directory"], - group_df, - self.maps_manager.preprocessing_dict, - transforms_config=transforms, - multi_cohort=group_parameters["multi_cohort"], - label_presence=self._config.use_labels, - label=self._config.label, - label_code=( - self.maps_manager.label_code - if label_code == "default" - else label_code - ), - cnn_index=network, - ) - test_loader = DataLoader( - data_test, - batch_size=( - self._config.batch_size - if self._config.batch_size is not None - else self.maps_manager.batch_size - ), - 
shuffle=False, - sampler=DistributedSampler( - data_test, - num_replicas=cluster.world_size, - rank=cluster.rank, - shuffle=False, - ), - num_workers=self._config.n_proc - if self._config.n_proc is not None - else self.maps_manager.n_proc, - ) - self.validator._test_loader( - maps_manager=self.maps_manager, - dataloader=test_loader, - criterion=criterion, - data_group=self._config.data_group, - split=split, - selection_metrics=split_selection_metrics, - use_labels=self._config.use_labels, - gpu=self._config.gpu, - amp=self._config.amp, - network=network, - ) - if self._config.save_tensor: - logger.debug("Saving tensors") - self.validator._compute_output_tensors( - self.maps_manager, - data_test, - self._config.data_group, - split, - self._config.selection_metrics, - gpu=self._config.gpu, - network=network, - ) - if self._config.save_nifti: - self._compute_output_nifti( - data_test, - split, - network=network, - ) - if self._config.save_latent_tensor: - self._compute_latent_tensors( - dataset=data_test, - split=split, - network=network, + self._config.validation.selection_metrics, + self._config.data.use_labels, + self._config.validation.skip_leak_check, ) def _predict_single( @@ -335,78 +145,31 @@ def _predict_single( criterion, split, split_selection_metrics, + network: Optional[int] = None, ): - """_summary_ - Parameters - ---------- - group_parameters : _type_ - _description_ - group_df : _type_ - _description_ - all_transforms : _type_ - _description_ - use_labels : _type_ - _description_ - label : _type_ - _description_ - label_code : _type_ - _description_ - batch_size : _type_ - _description_ - n_proc : _type_ - _description_ - criterion : _type_ - _description_ - data_group : _type_ - _description_ - split : _type_ - _description_ - split_selection_metrics : _type_ - _description_ - gpu : _type_ - _description_ - amp : _type_ - _description_ - save_tensor : _type_ - _description_ - save_latent_tensor : _type_ - _description_ - save_nifti : _type_ - 
_description_ - selection_metrics : _type_ - _description_ - Examples - -------- - >>> _input_ - _output_ - Notes - ----- - _notes_ - See Also - -------- - - _related_ - """ + """_summary_""" assert isinstance(self._config, PredictConfig) - # assert self._config.label + # assert self._config.data.label data_test = return_dataset( group_parameters["caps_directory"], group_df, self.maps_manager.preprocessing_dict, - transforms_config=transforms, + transforms_config=self._config.transforms, multi_cohort=group_parameters["multi_cohort"], - label_presence=self._config.use_labels, - label=self._config.label, + label_presence=self._config.data.use_labels, + label=self._config.data.label, label_code=( self.maps_manager.label_code if label_code == "default" else label_code ), + cnn_index=network, ) test_loader = DataLoader( data_test, batch_size=( - self._config.batch_size - if self._config.batch_size is not None + self._config.dataloader.batch_size + if self._config.dataloader.batch_size is not None else self.maps_manager.batch_size ), shuffle=False, @@ -416,40 +179,44 @@ def _predict_single( rank=cluster.rank, shuffle=False, ), - num_workers=self._config.n_proc - if self._config.n_proc is not None + num_workers=self._config.dataloader.n_proc + if self._config.dataloader.n_proc is not None else self.maps_manager.n_proc, ) - self.validator._test_loader( - self.maps_manager, - test_loader, - criterion, - self._config.data_group, - split, - split_selection_metrics, - use_labels=self._config.use_labels, - gpu=self._config.gpu, - amp=self._config.amp, + self._test_loader( + maps_manager=self.maps_manager, + dataloader=test_loader, + criterion=criterion, + data_group=self._config.maps_manager.data_group, + split=split, + selection_metrics=split_selection_metrics, + use_labels=self._config.data.use_labels, + gpu=self._config.computational.gpu, + amp=self._config.computational.amp, + network=network, ) - if self._config.save_tensor: + if self._config.maps_manager.save_tensor: 
logger.debug("Saving tensors") - self.validator._compute_output_tensors( - self.maps_manager, - data_test, - self._config.data_group, - split, - self._config.selection_metrics, - gpu=self._config.gpu, + self._compute_output_tensors( + maps_manager=self.maps_manager, + dataset=data_test, + data_group=self._config.maps_manager.data_group, + split=split, + selection_metrics=self._config.validation.selection_metrics, + gpu=self._config.computational.gpu, + network=network, ) - if self._config.save_nifti: + if self._config.maps_manager.save_nifti: self._compute_output_nifti( - data_test, - split, + dataset=data_test, + split=split, + network=network, ) - if self._config.save_latent_tensor: + if self._config.maps_manager.save_latent_tensor: self._compute_latent_tensors( dataset=data_test, split=split, + network=network, ) def _compute_latent_tensors( @@ -478,13 +245,13 @@ def _compute_latent_tensors( network : _type_ (optional, default=None) Index of the network tested (only used in multi-network setting). 
""" - for selection_metric in self._config.selection_metrics: + for selection_metric in self._config.validation.selection_metrics: # load the best trained model during the training model, _ = self.maps_manager._init_model( transfer_path=self.maps_manager.maps_path, split=split, transfer_selection=selection_metric, - gpu=self._config.gpu, + gpu=self._config.computational.gpu, network=network, nb_unfrozen_layer=self.maps_manager.nb_unfrozen_layer, ) @@ -498,7 +265,7 @@ def _compute_latent_tensors( self.maps_manager.maps_path / f"split-{split}" / f"best-{selection_metric}" - / self._config.data_group + / self._config.maps_manager.data_group / "latent_tensors" ) if cluster.master: @@ -555,13 +322,13 @@ def _compute_output_nifti( import nibabel as nib from numpy import eye - for selection_metric in self._config.selection_metrics: + for selection_metric in self._config.validation.selection_metrics: # load the best trained model during the training model, _ = self.maps_manager._init_model( transfer_path=self.maps_manager.maps_path, split=split, transfer_selection=selection_metric, - gpu=self._config.gpu, + gpu=self._config.computational.gpu, network=network, nb_unfrozen_layer=self.maps_manager.nb_unfrozen_layer, ) @@ -575,7 +342,7 @@ def _compute_output_nifti( self.maps_manager.maps_path / f"split-{split}" / f"best-{selection_metric}" - / self._config.data_group + / self._config.maps_manager.data_group / "nifti_images" ) if cluster.master: @@ -608,77 +375,10 @@ def _compute_output_nifti( def interpret(self): """Performs the interpretation task on a subset of caps_directory defined in a TSV file. The mean interpretation is always saved, to save the individual interpretations set save_individual to True. - Parameters - ---------- - data_group : str - Name of the data group interpreted. - name : str - Name of the interpretation procedure. - method : str - Method used for extraction (ex: gradients, grad-cam...). 
- caps_directory : Path (optional, default=None) - Path to the CAPS folder. For more information please refer to - [clinica documentation](https://aramislab.paris.inria.fr/clinica/docs/public/latest/CAPS/Introduction/). - Default will load the value of an existing data group. - tsv_path : Path (optional, default=None) - Path to a TSV file containing the list of participants and sessions to test. - Default will load the DataFrame of an existing data group. - split_list : list[int] (optional, default=None) - List of splits to interpret. Default perform interpretation on all splits available. - selection_metrics : list[str] (optional, default=None) - List of selection metrics to interpret. - Default performs the interpretation on all selection metrics available. - multi_cohort : bool (optional, default=False) - If True considers that tsv_path is the path to a multi-cohort TSV. - diagnoses : list[str] (optional, default=()) - List of diagnoses to load if tsv_path is a split_directory. - Default uses the same as in training step. - target_node : int (optional, default=0) - Node from which the interpretation is computed. - save_individual : bool (optional, default=False) - If True saves the individual map of each participant / session couple. - batch_size : int (optional, default=None) - If given, sets the value of batch_size, else use the same as in training step. - n_proc : int (optional, default=None) - If given, sets the value of num_workers, else use the same as in training step. - gpu : bool (optional, default=None) - If given, a new value for the device of the model will be computed. - amp : bool (optional, default=False) - If enabled, uses Automatic Mixed Precision (requires GPU usage). - overwrite : bool (optional, default=False) - If True erase the occurrences of data_group. - overwrite_name : bool (optional, default=False) - If True erase the occurrences of name. 
- level : int (optional, default=None) - Layer number in the convolutional part after which the feature map is chosen. - save_nifti : bool (optional, default=False) - If True, save the interpretation map in nifti format. - Raises - ------ - NotImplementedError - If the method is not implemented - NotImplementedError - If the interpretaion of multi network is asked - MAPSError - If the interpretation has already been determined. """ assert isinstance(self._config, InterpretConfig) - self._config.diagnoses = ( - self.maps_manager.diagnoses - if self._config.diagnoses is None or len(self._config.diagnoses) == 0 - else self._config.diagnoses - ) - self._config.batch_size = ( - self.maps_manager.batch_size - if not self._config.batch_size - else self._config.batch_size - ) - self._config.n_proc = ( - self.maps_manager.n_proc if not self._config.n_proc else self._config.n_proc - ) - - self._config.adapt_cross_val_with_maps_manager_info(self.maps_manager) + self._config.adapt_with_maps_manager_info(self.maps_manager) if self.maps_manager.multi_network: raise NotImplementedError( @@ -690,14 +390,13 @@ def interpret(self): size_reduction=self.maps_manager.size_reduction, size_reduction_factor=self.maps_manager.size_reduction_factor, ) - group_df = self._config.create_groupe_df() + group_df = self._config.data.create_groupe_df() self._check_data_group(group_df) - assert self._config.split - for split in self._config.split: + for split in self.splitter.split_iterator(): logger.info(f"Interpretation of split {split}") df_group, parameters_group = self.get_group_info( - self._config.data_group, split + self._config.maps_manager.data_group, split ) data_test = return_dataset( parameters_group["caps_directory"], @@ -711,30 +410,30 @@ def interpret(self): ) test_loader = DataLoader( data_test, - batch_size=self._config.batch_size, + batch_size=self._config.dataloader.batch_size, shuffle=False, - num_workers=self._config.n_proc, + num_workers=self._config.dataloader.n_proc, ) - if 
not self._config.selection_metrics: - self._config.selection_metrics = find_selection_metrics( + if not self._config.validation.selection_metrics: + self._config.validation.selection_metrics = find_selection_metrics( self.maps_manager.maps_path, split, ) - for selection_metric in self._config.selection_metrics: + for selection_metric in self._config.validation.selection_metrics: logger.info(f"Interpretation of metric {selection_metric}") results_path = ( self.maps_manager.maps_path / f"split-{split}" / f"best-{selection_metric}" - / self._config.data_group - / f"interpret-{self._config.name}" + / self._config.maps_manager.data_group + / f"interpret-{self._config.interpret.name}" ) if (results_path).is_dir(): - if self._config.overwrite_name: + if self._config.interpret.overwrite_name: shutil.rmtree(results_path) else: raise MAPSError( - f"Interpretation name {self._config.name} is already written. " + f"Interpretation name {self._config.interpret.name} is already written. " f"Please choose another name or set overwrite_name to True." 
) results_path.mkdir(parents=True) @@ -742,28 +441,28 @@ def interpret(self): transfer_path=self.maps_manager.maps_path, split=split, transfer_selection=selection_metric, - gpu=self._config.gpu, + gpu=self._config.computational.gpu, ) - interpreter = self._config.get_method()(model) + interpreter = self._config.interpret.get_method()(model) cum_maps = [0] * data_test.elem_per_image for data in test_loader: images = data["image"].to(model.device) map_pt = interpreter.generate_gradients( images, - self._config.target_node, - level=self._config.level, - amp=self._config.amp, + self._config.interpret.target_node, + level=self._config.interpret.level, + amp=self._config.computational.amp, ) for i in range(len(data["participant_id"])): mode_id = data[f"{self.maps_manager.mode}_id"][i] cum_maps[mode_id] += map_pt[i] - if self._config.save_individual: + if self._config.interpret.save_individual: single_path = ( results_path / f"{data['participant_id'][i]}_{data['session_id'][i]}_{self.maps_manager.mode}-{data[f'{self.maps_manager.mode}_id'][i]}_map.pt" ) torch.save(map_pt[i], single_path) - if self._config.save_nifti: + if self._config.maps_manager.save_nifti: import nibabel as nib from numpy import eye @@ -781,7 +480,7 @@ def interpret(self): mode_map, results_path / f"mean_{self.maps_manager.mode}-{i}_map.pt", ) - if self._config.save_nifti: + if self._config.maps_manager.save_nifti: import nibabel as nib from numpy import eye @@ -801,22 +500,6 @@ def _check_data_group( Parameters ---------- - data_group : str - name of the data group - caps_directory : str (optional, default=None) - input CAPS directory - df : pd.DataFrame (optional, default=None) - Table of participant_id / session_id of the data group - multi_cohort : bool (optional, default=False) - indicates if the input data comes from several CAPS - overwrite : bool (optional, default=False) - If True former definition of data group is erased - label : str (optional, default=None) - label name if applicable - 
split_list : list[int] (optional, default=None) - _description_ - skip_leak_check : bool (optional, default=False) - _description_ Raises ------ @@ -828,17 +511,21 @@ def _check_data_group( when caps_directory or df are not given and data group does not exist """ - group_dir = self.maps_manager.maps_path / "groups" / self._config.data_group + group_dir = ( + self.maps_manager.maps_path + / "groups" + / self._config.maps_manager.data_group + ) logger.debug(f"Group path {group_dir}") if group_dir.is_dir(): # Data group already exists - if self._config.overwrite: - if self._config.data_group in ["train", "validation"]: + if self._config.maps_manager.overwrite: + if self._config.maps_manager.data_group in ["train", "validation"]: raise MAPSError("Cannot overwrite train or validation data group.") else: - # if not split_list: - # split_list = self.maps_manager.find_splits() + if not self._config.split.split: + self._config.split.split = self.maps_manager.find_splits() assert self._config.split - for split in self._config.split: + for split in self._config.split.split: selection_metrics = find_selection_metrics( self.maps_manager.maps_path, split, @@ -848,40 +535,40 @@ def _check_data_group( self.maps_manager.maps_path / f"split-{split}" / f"best-{selection}" - / self._config.data_group + / self._config.maps_manager.data_group ) if results_path.is_dir(): shutil.rmtree(results_path) elif df is not None or ( - self._config.caps_directory is not None - and self._config.caps_directory != Path("") + self._config.data.caps_directory is not None + and self._config.data.caps_directory != Path("") ): raise ClinicaDLArgumentError( - f"Data group {self._config.data_group} is already defined. " + f"Data group {self._config.maps_manager.data_group} is already defined. " f"Please do not give any caps_directory, tsv_path or multi_cohort to use it. " - f"To erase {self._config.data_group} please set overwrite to True." 
+ f"To erase {self._config.maps_manager.data_group} please set overwrite to True." ) elif not group_dir.is_dir() and ( - self._config.caps_directory is None or df is None + self._config.data.caps_directory is None or df is None ): # Data group does not exist yet / was overwritten + missing data raise ClinicaDLArgumentError( - f"The data group {self._config.data_group} does not already exist. " + f"The data group {self._config.maps_manager.data_group} does not already exist. " f"Please specify a caps_directory and a tsv_path to create this data group." ) elif ( not group_dir.is_dir() ): # Data group does not exist yet / was overwritten + all data is provided - if self._config.skip_leak_check: + if self._config.validation.skip_leak_check: logger.info("Skipping data leakage check") else: - self._check_leakage(self._config.data_group, df) + self._check_leakage(self._config.maps_manager.data_group, df) self._write_data_group( - self._config.data_group, + self._config.maps_manager.data_group, df, - self._config.caps_directory, - self._config.multi_cohort, - label=self._config.label, + self._config.data.caps_directory, + self._config.data.multi_cohort, + label=self._config.data.label, ) def get_group_info( @@ -997,8 +684,8 @@ def _write_data_group( group_path.mkdir(parents=True) columns = ["participant_id", "session_id", "cohort"] - if self._config.label in df.columns.values: - columns += [self._config.label] + if self._config.data.label in df.columns.values: + columns += [self._config.data.label] if label is not None and label in df.columns.values: columns += [label] @@ -1088,3 +775,387 @@ def get_interpretation( weights_only=True, ) return map_pt + + def test( + self, + mode: str, + metrics_module: MetricModule, + n_classes: int, + network_task, + model: Network, + dataloader: DataLoader, + criterion: _Loss, + use_labels: bool = True, + amp: bool = False, + report_ci=False, + ) -> Tuple[pd.DataFrame, Dict[str, float]]: + """ + Computes the predictions and evaluation 
metrics. + + Parameters + ---------- + model: Network + The model trained. + dataloader: DataLoader + Wrapper of a CapsDataset. + criterion: _Loss + Function to calculate the loss. + use_labels: bool + If True the true_label will be written in output DataFrame + and metrics dict will be created. + amp: bool + If True, enables Pytorch's automatic mixed precision. + + Returns + ------- + the results and metrics on the image level. + """ + model.eval() + dataloader.dataset.eval() + + results_df = pd.DataFrame(columns=columns(network_task, mode, n_classes)) + total_loss = {} + with torch.no_grad(): + for i, data in enumerate(dataloader): + # initialize the loss list to save the loss components + with autocast("cuda", enabled=amp): + outputs, loss_dict = model(data, criterion, use_labels=use_labels) + + if i == 0: + for loss_component in loss_dict.keys(): + total_loss[loss_component] = 0 + for loss_component in total_loss.keys(): + total_loss[loss_component] += loss_dict[loss_component].float() + + # Generate detailed DataFrame + for idx in range(len(data["participant_id"])): + row = generate_test_row( + network_task, + mode, + metrics_module, + n_classes, + idx, + data, + outputs.float(), + ) + row_df = pd.DataFrame( + row, columns=columns(network_task, mode, n_classes) + ) + results_df = pd.concat([results_df, row_df]) + + del outputs, loss_dict + dataframes = [None] * dist.get_world_size() + dist.gather_object( + results_df, dataframes if dist.get_rank() == 0 else None, dst=0 + ) + if dist.get_rank() == 0: + results_df = pd.concat(dataframes) + del dataframes + results_df.reset_index(inplace=True, drop=True) + + if not use_labels: + metrics_dict = None + else: + metrics_dict = compute_metrics( + network_task, results_df, metrics_module, report_ci=report_ci + ) + for loss_component in total_loss.keys(): + dist.reduce(total_loss[loss_component], dst=0) + loss_value = total_loss[loss_component].item() / cluster.world_size + + if report_ci: + 
metrics_dict["Metric_names"].append(loss_component) + metrics_dict["Metric_values"].append(loss_value) + metrics_dict["Lower_CI"].append("N/A") + metrics_dict["Upper_CI"].append("N/A") + metrics_dict["SE"].append("N/A") + + else: + metrics_dict[loss_component] = loss_value + + torch.cuda.empty_cache() + + return results_df, metrics_dict + + def test_da( + self, + mode: str, + metrics_module: MetricModule, + n_classes: int, + network_task: Union[str, Task], + model: Network, + dataloader: DataLoader, + criterion: _Loss, + alpha: float = 0, + use_labels: bool = True, + target: bool = True, + report_ci=False, + ) -> Tuple[pd.DataFrame, Dict[str, float]]: + """ + Computes the predictions and evaluation metrics. + + Args: + model: the model trained. + dataloader: wrapper of a CapsDataset. + criterion: function to calculate the loss. + use_labels: If True the true_label will be written in output DataFrame + and metrics dict will be created. + Returns: + the results and metrics on the image level. 
+ """ + model.eval() + dataloader.dataset.eval() + results_df = pd.DataFrame(columns=columns(network_task, mode, n_classes)) + total_loss = 0 + with torch.no_grad(): + for i, data in enumerate(dataloader): + outputs, loss_dict = model.compute_outputs_and_loss_test( + data, criterion, alpha, target + ) + total_loss += loss_dict["loss"].item() + + # Generate detailed DataFrame + for idx in range(len(data["participant_id"])): + row = generate_test_row( + network_task, + mode, + metrics_module, + n_classes, + idx, + data, + outputs, + ) + row_df = pd.DataFrame( + row, columns=columns(network_task, mode, n_classes) + ) + results_df = pd.concat([results_df, row_df]) + + del outputs, loss_dict + results_df.reset_index(inplace=True, drop=True) + + if not use_labels: + metrics_dict = None + else: + metrics_dict = compute_metrics( + network_task, results_df, metrics_module, report_ci=report_ci + ) + if report_ci: + metrics_dict["Metric_names"].append("loss") + metrics_dict["Metric_values"].append(total_loss) + metrics_dict["Lower_CI"].append("N/A") + metrics_dict["Upper_CI"].append("N/A") + metrics_dict["SE"].append("N/A") + + else: + metrics_dict["loss"] = total_loss + + torch.cuda.empty_cache() + + return results_df, metrics_dict + + def _test_loader( + self, + maps_manager: MapsManager, + dataloader, + criterion, + data_group: str, + split: int, + selection_metrics, + use_labels=True, + gpu=None, + amp=False, + network=None, + report_ci=True, + ): + """ + Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files. + + Args: + dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset. + criterion (torch.nn.modules.loss._Loss): optimization criterion used during training. + data_group (str): name of the data group used for the testing task. + split (int): Index of the split used to train the model tested. + selection_metrics (list[str]): List of metrics used to select the best models which are tested. 
+ use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed. + gpu (bool): If given, a new value for the device of the model will be computed. + amp (bool): If enabled, uses Automatic Mixed Precision (requires GPU usage). + network (int): Index of the network tested (only used in multi-network setting). + """ + for selection_metric in selection_metrics: + if cluster.master: + log_dir = ( + maps_manager.maps_path + / f"split-{split}" + / f"best-{selection_metric}" + / data_group + ) + maps_manager.write_description_log( + log_dir, + data_group, + dataloader.dataset.config.data.caps_dict, + dataloader.dataset.config.data.data_df, + ) + + # load the best trained model during the training + model, _ = maps_manager._init_model( + transfer_path=maps_manager.maps_path, + split=split, + transfer_selection=selection_metric, + gpu=gpu, + network=network, + ) + model = DDP( + model, + fsdp=maps_manager.fully_sharded_data_parallel, + amp=maps_manager.amp, + ) + + prediction_df, metrics = self.test( + mode=maps_manager.mode, + metrics_module=maps_manager.metrics_module, + n_classes=maps_manager.n_classes, + network_task=maps_manager.network_task, + model=model, + dataloader=dataloader, + criterion=criterion, + use_labels=use_labels, + amp=amp, + report_ci=report_ci, + ) + if use_labels: + if network is not None: + metrics[f"{maps_manager.mode}_id"] = network + + loss_to_log = ( + metrics["Metric_values"][-1] if report_ci else metrics["loss"] + ) + + logger.info( + f"{maps_manager.mode} level {data_group} loss is {loss_to_log} for model selected on {selection_metric}" + ) + + if cluster.master: + # Replace here + maps_manager._mode_level_to_tsv( + prediction_df, + metrics, + split, + selection_metric, + data_group=data_group, + ) + + @torch.no_grad() + def _compute_output_tensors( + self, + maps_manager: MapsManager, + dataset, + data_group, + split, + selection_metrics, + nb_images=None, + gpu=None, + network=None, + ): + """ + Compute the 
output tensors and saves them in the MAPS. + + Args: +<<<<<<< HEAD +<<<<<<<< HEAD:clinicadl/predictor/old_predictor.py + dataset (clinicadl.dataset.caps_dataset.CapsDataset): wrapper of the data set. +======== + dataset (clinicadl.caps_dataset.data.CapsDataset): wrapper of the data set. +>>>>>>>> 1ae72275 (Cb extract validator (#666)):clinicadl/predictor/predictor.py +======= + dataset (clinicadl.dataset.caps_dataset.CapsDataset): wrapper of the data set. +>>>>>>> 109ee64a (Base for v2 (#676)) + data_group (str): name of the data group used for the task. + split (int): split number. + selection_metrics (list[str]): metrics used for model selection. + nb_images (int): number of full images to write. Default computes the outputs of the whole data set. + gpu (bool): If given, a new value for the device of the model will be computed. + network (int): Index of the network tested (only used in multi-network setting). + """ + for selection_metric in selection_metrics: + # load the best trained model during the training + model, _ = maps_manager._init_model( + transfer_path=maps_manager.maps_path, + split=split, + transfer_selection=selection_metric, + gpu=gpu, + network=network, + nb_unfrozen_layer=maps_manager.nb_unfrozen_layer, + ) + model = DDP( + model, + fsdp=maps_manager.fully_sharded_data_parallel, + amp=maps_manager.amp, + ) + model.eval() + + tensor_path = ( + maps_manager.maps_path + / f"split-{split}" + / f"best-{selection_metric}" + / data_group + / "tensors" + ) + if cluster.master: + tensor_path.mkdir(parents=True, exist_ok=True) + dist.barrier() + + if nb_images is None: # Compute outputs for the whole data set + nb_modes = len(dataset) + else: + nb_modes = nb_images * dataset.elem_per_image + + for i in [ + *range(cluster.rank, nb_modes, cluster.world_size), + *range(int(nb_modes % cluster.world_size <= cluster.rank)), + ]: + data = dataset[i] + image = data["image"] + x = image.unsqueeze(0).to(model.device) + with autocast("cuda", 
enabled=maps_manager.std_amp): + output = model(x) + output = output.squeeze(0).cpu().float() + participant_id = data["participant_id"] + session_id = data["session_id"] + mode_id = data[f"{maps_manager.mode}_id"] + input_filename = f"{participant_id}_{session_id}_{maps_manager.mode}-{mode_id}_input.pt" + output_filename = f"{participant_id}_{session_id}_{maps_manager.mode}-{mode_id}_output.pt" + torch.save(image, tensor_path / input_filename) + torch.save(output, tensor_path / output_filename) + logger.debug(f"File saved at {[input_filename, output_filename]}") + + def _ensemble_prediction( + self, + maps_manager: MapsManager, + data_group, + split, + selection_metrics, + use_labels=True, + skip_leak_check=False, + ): + """Computes the results on the image-level.""" + + if not selection_metrics: + selection_metrics = find_selection_metrics(maps_manager.maps_path, split) + + for selection_metric in selection_metrics: + ##################### + # Soft voting + if maps_manager.num_networks > 1 and not skip_leak_check: + maps_manager._ensemble_to_tsv( + split, + selection=selection_metric, + data_group=data_group, + use_labels=use_labels, + ) + elif maps_manager.mode != "image" and not skip_leak_check: + maps_manager._mode_to_image_tsv( + split, + selection=selection_metric, + data_group=data_group, + use_labels=use_labels, + ) diff --git a/clinicadl/predictor/predictor.py b/clinicadl/predictor/predictor.py new file mode 100644 index 000000000..157d49c8e --- /dev/null +++ b/clinicadl/predictor/predictor.py @@ -0,0 +1,13 @@ + +from clinicadl.dataset.caps_dataset import CapsDataset +from clinicadl.experiment_manager.experiment_manager import ExperimentManager + + +class Predictor: + def __init__(self, manager: ExperimentManager): + """TO COMPLETE""" + pass + + def predict(self, dataset_test: CapsDataset, split: int): + """TO COMPLETE""" + pass diff --git a/clinicadl/predict/utils.py b/clinicadl/predictor/utils.py similarity index 95% rename from 
clinicadl/predict/utils.py rename to clinicadl/predictor/utils.py index c66372764..6aea27e65 100644 --- a/clinicadl/predict/utils.py +++ b/clinicadl/predictor/utils.py @@ -3,7 +3,7 @@ import pandas as pd -from clinicadl.metrics.utils import check_selection_metric +from clinicadl.metrics.old_metrics.utils import check_selection_metric from clinicadl.splitter.split_utils import print_description_log from clinicadl.utils.exceptions import MAPSError diff --git a/clinicadl/splitter/validation.py b/clinicadl/predictor/validation.py similarity index 100% rename from clinicadl/splitter/validation.py rename to clinicadl/predictor/validation.py diff --git a/clinicadl/quality_check/pet_linear/quality_check.py b/clinicadl/quality_check/pet_linear/quality_check.py index 7c355b09c..d54eabac2 100644 --- a/clinicadl/quality_check/pet_linear/quality_check.py +++ b/clinicadl/quality_check/pet_linear/quality_check.py @@ -12,8 +12,8 @@ import pandas as pd from joblib import Parallel, delayed -from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.caps_dataset.preprocessing.utils import pet_linear_nii +from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.dataset.utils import pet_linear_nii from clinicadl.utils.iotools.clinica_utils import ( RemoteFileStructure, clinicadl_file_reader, diff --git a/clinicadl/quality_check/pet_linear/utils.py b/clinicadl/quality_check/pet_linear/utils.py index 1edba9e15..e27c60d2b 100644 --- a/clinicadl/quality_check/pet_linear/utils.py +++ b/clinicadl/quality_check/pet_linear/utils.py @@ -6,7 +6,7 @@ import numpy as np -from clinicadl.transforms.transforms import MinMaxNormalization +from clinicadl.transforms.factory import MinMaxNormalization def get_metric(contour_np, image_np, inside): diff --git a/clinicadl/quality_check/t1_linear/quality_check.py b/clinicadl/quality_check/t1_linear/quality_check.py index 7063c0c68..373f5228c 100755 --- 
a/clinicadl/quality_check/t1_linear/quality_check.py +++ b/clinicadl/quality_check/t1_linear/quality_check.py @@ -11,7 +11,7 @@ from torch.amp import autocast from torch.utils.data import DataLoader -from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig from clinicadl.generate.generate_utils import load_and_check_tsv from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.exceptions import ClinicaDLArgumentError diff --git a/clinicadl/quality_check/t1_linear/utils.py b/clinicadl/quality_check/t1_linear/utils.py index 20d4d5462..0ac67736c 100755 --- a/clinicadl/quality_check/t1_linear/utils.py +++ b/clinicadl/quality_check/t1_linear/utils.py @@ -8,9 +8,9 @@ import torch from torch.utils.data import Dataset -from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.caps_dataset.caps_dataset_utils import compute_folder_and_file_type -from clinicadl.caps_dataset.preprocessing.utils import linear_nii +from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.dataset.caps_dataset_utils import compute_folder_and_file_type +from clinicadl.dataset.utils import linear_nii from clinicadl.utils.enum import Preprocessing from clinicadl.utils.exceptions import ClinicaDLException from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader @@ -30,7 +30,7 @@ def __init__( data_df (DataFrame): Subject and session list. 
""" - from clinicadl.transforms.transforms import MinMaxNormalization + from clinicadl.transforms.factory import MinMaxNormalization self.img_dir = config.data.caps_directory self.df = config.data.data_df diff --git a/clinicadl/random_search/random_search.py b/clinicadl/random_search/random_search.py index 7929e9382..f38f248d2 100755 --- a/clinicadl/random_search/random_search.py +++ b/clinicadl/random_search/random_search.py @@ -4,7 +4,7 @@ from pathlib import Path -from clinicadl.trainer.trainer import Trainer +from clinicadl.trainer.old_trainer import Trainer from .random_search_config import RandomSearchConfig, create_training_config from .random_search_utils import get_space_dict, random_sampling diff --git a/clinicadl/random_search/random_search_config.py b/clinicadl/random_search/random_search_config.py index 2e1d728a9..3d6a65d2d 100644 --- a/clinicadl/random_search/random_search_config.py +++ b/clinicadl/random_search/random_search_config.py @@ -15,7 +15,7 @@ from clinicadl.utils.enum import Normalization, Pooling, Task if TYPE_CHECKING: - from clinicadl.trainer.trainer import TrainConfig + from clinicadl.trainer.old_trainer import TrainConfig class RandomSearchConfig( diff --git a/clinicadl/random_search/random_search_utils.py b/clinicadl/random_search/random_search_utils.py index ed164ea0c..f8f3bca9a 100644 --- a/clinicadl/random_search/random_search_utils.py +++ b/clinicadl/random_search/random_search_utils.py @@ -124,7 +124,6 @@ def random_sampling(rs_options: Dict[str, Any]) -> Dict[str, Any]: "mode": "fixed", "multi_cohort": "fixed", "multi_network": "choice", - "ssda_netork": "fixed", "n_fcblocks": "randint", "n_splits": "fixed", "n_proc": "fixed", diff --git a/clinicadl/resources/config/train_config.toml b/clinicadl/resources/config/train_config.toml index f4f2afe30..9e5f54657 100644 --- a/clinicadl/resources/config/train_config.toml +++ b/clinicadl/resources/config/train_config.toml @@ -4,7 +4,6 @@ [Model] architecture = "default" # ex : Conv5_FC3 
multi_network = false -ssda_network = false [Architecture] # CNN diff --git a/clinicadl/splitter/config.py b/clinicadl/splitter/config.py index 53413fdda..da4e32707 100644 --- a/clinicadl/splitter/config.py +++ b/clinicadl/splitter/config.py @@ -6,9 +6,9 @@ from pydantic import BaseModel, ConfigDict, field_validator from pydantic.types import NonNegativeInt -from clinicadl.caps_dataset.data_config import DataConfig +from clinicadl.dataset.data_config import DataConfig +from clinicadl.predictor.validation import ValidationConfig from clinicadl.splitter.split_utils import find_splits -from clinicadl.splitter.validation import ValidationConfig logger = getLogger("clinicadl.split_config") diff --git a/clinicadl/splitter/kfold.py b/clinicadl/splitter/kfold.py new file mode 100644 index 000000000..37805dfba --- /dev/null +++ b/clinicadl/splitter/kfold.py @@ -0,0 +1,24 @@ +from typing import Optional + +from clinicadl.dataset.caps_dataset import CapsDataset +from clinicadl.experiment_manager.experiment_manager import ExperimentManager + + +class Split: + def __init__( + self, + ): + """TO COMPLETE""" + pass + + +class KFolder: + def __init__( + self, n_splits: int, caps_dataset: CapsDataset, manager: ExperimentManager + ) -> None: + """TO COMPLETE""" + + def split_iterator(self, split_list: Optional[list] = None) -> list[Split]: + """TO COMPLETE""" + + return list[Split()] diff --git a/clinicadl/splitter/splitter.py b/clinicadl/splitter/old_splitter.py similarity index 86% rename from clinicadl/splitter/splitter.py rename to clinicadl/splitter/old_splitter.py index 3bbdde461..d39b14a5b 100644 --- a/clinicadl/splitter/splitter.py +++ b/clinicadl/splitter/old_splitter.py @@ -1,4 +1,5 @@ import abc +import shutil from logging import getLogger from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -6,6 +7,8 @@ import pandas as pd from clinicadl.splitter.config import SplitterConfig +from clinicadl.utils import cluster +from clinicadl.utils.exceptions 
import MAPSError logger = getLogger("clinicadl.split_manager") @@ -14,7 +17,7 @@ class Splitter: def __init__( self, config: SplitterConfig, - split_list: Optional[List[int]] = None, + # split_list: Optional[List[int]] = None, ): """_summary_ @@ -29,19 +32,19 @@ def __init__( """ self.config = config - self.split_list = split_list + # self.config.split.split = split_list - self.caps_dict = self.config.data.caps_dict # TODO : check if useful ? + # self.caps_dict = self.config.data.caps_dict # TODO : check if useful ? def max_length(self) -> int: """Maximum number of splits""" return self.config.split.n_splits def __len__(self): - if not self.split_list: + if not self.config.split.split: return self.config.split.n_splits else: - return len(self.split_list) + return len(self.config.split.split) @property def allowed_splits_list(self): @@ -203,13 +206,32 @@ def _get_tsv_paths(self, cohort_path, *args) -> Tuple[Path, Path]: def split_iterator(self): """Returns an iterable to iterate on all splits wanted.""" - if not self.split_list: + + if not self.config.split.split: return range(self.config.split.n_splits) else: - return self.split_list + return self.config.split.split def _check_item(self, item): if item not in self.allowed_splits_list: raise IndexError( f"Split index {item} out of allowed splits {self.allowed_splits_list}." ) + + def check_split_list(self, maps_path, overwrite): + existing_splits = [] + for split in self.split_iterator(): + split_path = maps_path / f"split-{split}" + if split_path.is_dir(): + if overwrite: + if cluster.master: + shutil.rmtree(split_path) + else: + existing_splits.append(split) + + if len(existing_splits) > 0: + raise MAPSError( + f"Splits {existing_splits} already exist. Please " + f"specify a list of splits not intersecting the previous list, " + f"or use overwrite to erase previously trained splits." 
+ ) diff --git a/clinicadl/splitter/split.py b/clinicadl/splitter/split.py new file mode 100644 index 000000000..72c4f9d82 --- /dev/null +++ b/clinicadl/splitter/split.py @@ -0,0 +1,18 @@ +from pathlib import Path + +from clinicadl.dataset.caps_dataset import CapsDataset +from clinicadl.experiment_manager.experiment_manager import ExperimentManager +from clinicadl.splitter.kfold import Split + + +def split_tsv(sub_ses_tsv: Path) -> Path: + """TO COMPLETE""" + + split_dir = Path("") + return split_dir + + +def get_single_split( + n_subject_validation: int, caps_dataset: CapsDataset, manager: ExperimentManager +) -> Split: + pass diff --git a/clinicadl/maps_manager/tmp_config.py b/clinicadl/tmp_config.py similarity index 92% rename from clinicadl/maps_manager/tmp_config.py rename to clinicadl/tmp_config.py index a31af7edb..54a791b1e 100644 --- a/clinicadl/maps_manager/tmp_config.py +++ b/clinicadl/tmp_config.py @@ -19,8 +19,8 @@ ) from typing_extensions import Self -from clinicadl.caps_dataset.data import return_dataset -from clinicadl.metrics.metric_module import MetricModule +from clinicadl.dataset.caps_dataset import return_dataset +from clinicadl.metrics.old_metrics.metric_module import MetricModule from clinicadl.splitter.split_utils import find_splits from clinicadl.trainer.tasks_utils import ( evaluation_metrics, @@ -28,7 +28,7 @@ get_default_network, output_size, ) -from clinicadl.transforms import transforms +from clinicadl.transforms import factory from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.enum import ( Compensation, @@ -58,6 +58,7 @@ class TmpConfig(BaseModel): arguments needed : caps_directory, maps_path, loss """ + # ??? 
output_size: Optional[int] = None n_classes: Optional[int] = None network_task: Optional[str] = None @@ -70,18 +71,21 @@ class TmpConfig(BaseModel): std_amp: Optional[bool] = None preprocessing_dict: Optional[dict] = None + # CALLBACKS emissions_calculator: bool = False track_exp: Optional[ExperimentTracking] = None + # COMPUTATIONAL amp: bool = False fully_sharded_data_parallel: bool = False gpu: bool = True + # SPLIT n_splits: NonNegativeInt = 0 split: Optional[Tuple[NonNegativeInt, ...]] = None tsv_path: Optional[Path] = None # not needed in predict ? - # DataConfig + # DATA caps_directory: Path baseline: bool = False diagnoses: Tuple[str, ...] = ("AD", "CN") @@ -94,55 +98,68 @@ class TmpConfig(BaseModel): data_tsv: Optional[Path] = None n_subjects: int = 300 + # DATALOADER batch_size: PositiveInt = 8 n_proc: PositiveInt = 2 sampler: Sampler = Sampler.RANDOM + # EARLY STOPPING patience: NonNegativeInt = 0 tolerance: NonNegativeFloat = 0.0 + patience_epochs: NonNegativeInt = 0 + # LEARNING RATE adaptive_learning_rate: bool = False + # MAPS MANAGER maps_path: Path data_group: Optional[str] = None overwrite: bool = False save_nifti: bool = False + # NETWORK architecture: str = "default" dropout: NonNegativeFloat = 0.0 loss: str multi_network: bool = False + # OPTIMIZATION accumulation_steps: PositiveInt = 1 epochs: PositiveInt = 20 profiler: bool = False + # OPTIMIZER learning_rate: PositiveFloat = 1e-4 optimizer: Optimizer = Optimizer.ADAM weight_decay: NonNegativeFloat = 1e-4 + # REPRODUCIBILITY compensation: Compensation = Compensation.MEMORY deterministic: bool = False save_all_models: bool = False seed: int = 0 config_file: Optional[Path] = None + # SSDA caps_target: Path = Path("") preprocessing_json_target: Path = Path("") ssda_network: bool = False tsv_target_lab: Path = Path("") tsv_target_unlab: Path = Path("") + # TRANSFER LEARNING nb_unfrozen_layer: NonNegativeInt = 0 transfer_path: Optional[Path] = None transfer_selection_metric: str = "loss" + # 
TRANSFORMS data_augmentation: Tuple[Transform, ...] = () train_transformations: Optional[Tuple[Transform, ...]] = None normalize: bool = True size_reduction: bool = False size_reduction_factor: SizeReductionFactor = SizeReductionFactor.TWO + # VALIDATION evaluation_steps: NonNegativeInt = 0 selection_metrics: Tuple[str, ...] = () valid_longitudinal: bool = False @@ -282,7 +299,7 @@ def adapt_cross_val_with_maps_manager_info( ): # maps_manager is of type MapsManager but need to be in a MapsConfig type in the future # TEMPORARY if not self.split: - self.split = find_splits(maps_manager.maps_path, maps_manager.split_name) + self.split = find_splits(maps_manager.maps_path) logger.debug(f"List of splits {self.split}") def create_groupe_df(self): @@ -363,7 +380,7 @@ def check_preprocessing_dict(self) -> Self: ValueError In case of multi-cohort dataset, if no preprocessing file is found in any CAPS. """ - from clinicadl.caps_dataset.data import CapsDataset + from clinicadl.dataset.caps_dataset import CapsDataset if self.preprocessing_dict is None: if self.preprocessing_json is not None: @@ -467,16 +484,16 @@ def get_transforms( transforms to apply in train and evaluation mode / transforms to apply in evaluation mode only. 
""" augmentation_dict = { - "Noise": transforms.RandomNoising(sigma=0.1), + "Noise": factory.RandomNoising(sigma=0.1), "Erasing": torch_transforms.RandomErasing(), - "CropPad": transforms.RandomCropPad(10), - "Smoothing": transforms.RandomSmoothing(), - "Motion": transforms.RandomMotion((2, 4), (2, 4), 2), - "Ghosting": transforms.RandomGhosting((4, 10)), - "Spike": transforms.RandomSpike(1, (1, 3)), - "BiasField": transforms.RandomBiasField(0.5), - "RandomBlur": transforms.RandomBlur((0, 2)), - "RandomSwap": transforms.RandomSwap(15, 100), + "CropPad": factory.RandomCropPad(10), + "Smoothing": factory.RandomSmoothing(), + "Motion": factory.RandomMotion((2, 4), (2, 4), 2), + "Ghosting": factory.RandomGhosting((4, 10)), + "Spike": factory.RandomSpike(1, (1, 3)), + "BiasField": factory.RandomBiasField(0.5), + "RandomBlur": factory.RandomBlur((0, 2)), + "RandomSwap": factory.RandomSwap(15, 100), "None": None, } @@ -491,12 +508,12 @@ def get_transforms( ] ) - transformations_list.append(transforms.NanRemoval()) + transformations_list.append(factory.NanRemoval()) if self.normalize: - transformations_list.append(transforms.MinMaxNormalization()) + transformations_list.append(factory.MinMaxNormalization()) if self.size_reduction: transformations_list.append( - transforms.SizeReduction(self.size_reduction_factor) + factory.SizeReduction(self.size_reduction_factor) ) all_transformations = torch_transforms.Compose(transformations_list) diff --git a/clinicadl/trainer/config/classification.py b/clinicadl/trainer/config/classification.py index 5e71d032e..25a8d7f6b 100644 --- a/clinicadl/trainer/config/classification.py +++ b/clinicadl/trainer/config/classification.py @@ -3,9 +3,9 @@ from pydantic import computed_field, field_validator -from clinicadl.caps_dataset.data_config import DataConfig as BaseDataConfig -from clinicadl.network.config import NetworkConfig as BaseNetworkConfig -from clinicadl.splitter.validation import ValidationConfig as BaseValidationConfig +from 
clinicadl.dataset.data_config import DataConfig as BaseDataConfig +from clinicadl.networks.old_network.config import NetworkConfig as BaseNetworkConfig +from clinicadl.predictor.validation import ValidationConfig as BaseValidationConfig from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils.enum import ClassificationLoss, ClassificationMetric, Task diff --git a/clinicadl/trainer/config/reconstruction.py b/clinicadl/trainer/config/reconstruction.py index bf39886d4..8a1dd825e 100644 --- a/clinicadl/trainer/config/reconstruction.py +++ b/clinicadl/trainer/config/reconstruction.py @@ -3,8 +3,8 @@ from pydantic import PositiveFloat, PositiveInt, computed_field, field_validator -from clinicadl.network.config import NetworkConfig as BaseNetworkConfig -from clinicadl.splitter.validation import ValidationConfig as BaseValidationConfig +from clinicadl.networks.old_network.config import NetworkConfig as BaseNetworkConfig +from clinicadl.predictor.validation import ValidationConfig as BaseValidationConfig from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils.enum import ( Normalization, diff --git a/clinicadl/trainer/config/regression.py b/clinicadl/trainer/config/regression.py index 37e690f01..7504138d8 100644 --- a/clinicadl/trainer/config/regression.py +++ b/clinicadl/trainer/config/regression.py @@ -3,9 +3,9 @@ from pydantic import computed_field, field_validator -from clinicadl.caps_dataset.data_config import DataConfig as BaseDataConfig -from clinicadl.network.config import NetworkConfig as BaseNetworkConfig -from clinicadl.splitter.validation import ValidationConfig as BaseValidationConfig +from clinicadl.dataset.data_config import DataConfig as BaseDataConfig +from clinicadl.networks.old_network.config import NetworkConfig as BaseNetworkConfig +from clinicadl.predictor.validation import ValidationConfig as BaseValidationConfig from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils.enum import 
RegressionLoss, RegressionMetric, Task diff --git a/clinicadl/trainer/config/train.py b/clinicadl/trainer/config/train.py index c44febe6b..96e1b2081 100644 --- a/clinicadl/trainer/config/train.py +++ b/clinicadl/trainer/config/train.py @@ -10,17 +10,16 @@ ) from clinicadl.callbacks.config import CallbacksConfig -from clinicadl.caps_dataset.data_config import DataConfig -from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig from clinicadl.config.config.lr_scheduler import LRschedulerConfig from clinicadl.config.config.reproducibility import ReproducibilityConfig -from clinicadl.config.config.ssda import SSDAConfig -from clinicadl.maps_manager.config import MapsManagerConfig -from clinicadl.network.config import NetworkConfig -from clinicadl.optimizer.optimization import OptimizationConfig -from clinicadl.optimizer.optimizer import OptimizerConfig +from clinicadl.dataset.data_config import DataConfig +from clinicadl.dataset.dataloader_config import DataLoaderConfig +from clinicadl.experiment_manager.config import MapsManagerConfig +from clinicadl.networks.old_network.config import NetworkConfig +from clinicadl.optimization.config import OptimizationConfig +from clinicadl.optimization.optimizer.config import OptimizerConfig +from clinicadl.predictor.validation import ValidationConfig from clinicadl.splitter.config import SplitConfig -from clinicadl.splitter.validation import ValidationConfig from clinicadl.trainer.transfer_learning import TransferLearningConfig from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.computational.computational import ComputationalConfig @@ -47,10 +46,8 @@ class TrainConfig(BaseModel, ABC): maps_manager: MapsManagerConfig model: NetworkConfig optimization: OptimizationConfig - optimizer: OptimizerConfig reproducibility: ReproducibilityConfig split: SplitConfig - ssda: SSDAConfig transfer_learning: TransferLearningConfig transforms: TransformsConfig validation: ValidationConfig @@ -74,10 +71,8 @@ 
def __init__(self, **kwargs): maps_manager=kwargs, model=kwargs, optimization=kwargs, - optimizer=kwargs, reproducibility=kwargs, split=kwargs, - ssda=kwargs, transfer_learning=kwargs, transforms=kwargs, validation=kwargs, @@ -94,10 +89,8 @@ def _update(self, config_dict: Dict[str, Any]) -> None: self.maps_manager.__dict__.update(config_dict) self.model.__dict__.update(config_dict) self.optimization.__dict__.update(config_dict) - self.optimizer.__dict__.update(config_dict) self.reproducibility.__dict__.update(config_dict) self.split.__dict__.update(config_dict) - self.ssda.__dict__.update(config_dict) self.transfer_learning.__dict__.update(config_dict) self.transforms.__dict__.update(config_dict) self.validation.__dict__.update(config_dict) diff --git a/clinicadl/trainer/old_trainer.py b/clinicadl/trainer/old_trainer.py new file mode 100644 index 000000000..fb798d788 --- /dev/null +++ b/clinicadl/trainer/old_trainer.py @@ -0,0 +1,901 @@ +from __future__ import annotations # noqa: I001 + + +from contextlib import nullcontext +from datetime import datetime +from logging import getLogger +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Callable + +import pandas as pd +import torch +import torch.distributed as dist +from torch.amp.grad_scaler import GradScaler +from torch.amp.autocast_mode import autocast +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from clinicadl.splitter.split_utils import find_finished_splits, find_stopped_splits +from clinicadl.dataset.caps_dataset import return_dataset +from clinicadl.utils.early_stopping.early_stopping import EarlyStopping +from clinicadl.utils.exceptions import MAPSError +from clinicadl.utils.computational.ddp import DDP +from clinicadl.utils import cluster +from clinicadl.utils.logwriter import LogWriter +from clinicadl.dataset.caps_dataset_utils import read_json +from clinicadl.metrics.old_metrics.metric_module import 
RetainBest +from clinicadl.utils.seed import pl_worker_init_function, seed_everything +from clinicadl.experiment_manager.maps_manager import MapsManager +from clinicadl.utils.seed import get_seed +from clinicadl.utils.enum import Task +from clinicadl.utils.iotools.trainer_utils import ( + create_parameters_dict, + patch_to_read_json, +) +from clinicadl.trainer.tasks_utils import create_training_config +from clinicadl.predictor.old_predictor import Predictor +from clinicadl.predictor.config import PredictConfig +from clinicadl.splitter.old_splitter import Splitter +from clinicadl.splitter.config import SplitterConfig +from clinicadl.transforms.config import TransformsConfig + +if TYPE_CHECKING: + from clinicadl.callbacks.callbacks import Callback + from clinicadl.trainer.config.train import TrainConfig + +from clinicadl.trainer.tasks_utils import ( + evaluation_metrics, + generate_sampler, + get_criterion, + save_outputs, +) + +logger = getLogger("clinicadl.trainer") + + +class Trainer: + """Temporary Trainer extracted from the MAPSManager.""" + + def __init__( + self, + config: TrainConfig, + ) -> None: + """ + Parameters + ---------- + config : TrainConfig + """ + self.config = config + + self.maps_manager = self._init_maps_manager(config) + predict_config = PredictConfig(**config.get_dict()) + self.validator = Predictor(predict_config) + + # test + splitter_config = SplitterConfig(**self.config.get_dict()) + self.splitter = Splitter(splitter_config) + self._check_args() + + def _init_maps_manager(self, config) -> MapsManager: + # temporary: to match CLI data. 
TODO : change CLI data + + parameters, maps_path = create_parameters_dict(config) + + if maps_path.is_dir(): + return MapsManager( + maps_path, verbose=None + ) # TODO : precise which parameters in config are useful + else: + # parameters["maps_path"] = maps_path + return MapsManager( + maps_path, parameters, verbose=None + ) # TODO : precise which parameters in config are useful + + @classmethod + def from_json( + cls, + config_file: str | Path, + maps_path: str | Path, + split: Optional[list[int]] = None, + ) -> Trainer: + """ + Creates a Trainer from a json configuration file. + + Parameters + ---------- + config_file : str | Path + The parameters, stored in a json files. + maps_path : str | Path + The folder where the results of a futur training will be stored. + + Returns + ------- + Trainer + The Trainer object, instantiated with parameters found in config_file. + + Raises + ------ + FileNotFoundError + If config_file doesn't exist. + """ + config_file = Path(config_file) + + if not (config_file).is_file(): + raise FileNotFoundError(f"No file found at {str(config_file)}.") + config_dict = patch_to_read_json(read_json(config_file)) # TODO : remove patch + config_dict["maps_dir"] = maps_path + config_dict["split"] = split if split else () + config_object = create_training_config(config_dict["network_task"])( + **config_dict + ) + return cls(config_object) + + @classmethod + def from_maps(cls, maps_path: str | Path) -> Trainer: + """ + Creates a Trainer from a json configuration file. + + Parameters + ---------- + maps_path : str | Path + The path of the MAPS folder. + + Returns + ------- + Trainer + The Trainer object, instantiated with parameters found in maps_path. + + Raises + ------ + MAPSError + If maps_path folder doesn't exist or there is no maps.json file in it. + """ + maps_path = Path(maps_path) + + if not (maps_path / "maps.json").is_file(): + raise MAPSError( + f"MAPS was not found at {str(maps_path)}." 
+ f"To initiate a new MAPS please give a train_dict." + ) + return cls.from_json(maps_path / "maps.json", maps_path) + + def resume(self) -> None: + """ + Resume a prematurely stopped training. + + Parameters + ---------- + splits : List[int] + The splits that must be resumed. + """ + stopped_splits = set(find_stopped_splits(self.config.maps_manager.maps_dir)) + finished_splits = set(find_finished_splits(self.config.maps_manager.maps_dir)) + # TODO : check these two lines. Why do we need a self.splitter? + + splitter_config = SplitterConfig(**self.config.get_dict()) + self.splitter = Splitter(splitter_config) + + split_iterator = self.splitter.split_iterator() + ### + absent_splits = set(split_iterator) - stopped_splits - finished_splits + + logger.info( + f"Finished splits {finished_splits}\n" + f"Stopped splits {stopped_splits}\n" + f"Absent splits {absent_splits}" + ) + + if len(stopped_splits) == 0 and len(absent_splits) == 0: + raise ValueError( + "Training has been completed on all the splits you passed." 
+ ) + if len(stopped_splits) > 0: + self._resume(list(stopped_splits)) + if len(absent_splits) > 0: + self.train(list(absent_splits), overwrite=True) + + def _check_args(self): + self.config.reproducibility.seed = get_seed(self.config.reproducibility.seed) + # if len(self.config.data.label_code) == 0: + # self.config.data.label_code = self.maps_manager.label_code + # TODO: deal with label_code and replace self.maps_manager.label_code + from clinicadl.trainer.tasks_utils import generate_label_code + + if ( + "label_code" not in self.config.data.model_dump() + or len(self.config.data.label_code) == 0 + or self.config.data.label_code is None + ): # Allows to set custom label code in TOML + train_df = self.splitter[0]["train"] + self.config.data.label_code = generate_label_code( + self.config.network_task, train_df, self.config.data.label + ) + + def train( + self, + split_list: Optional[List[int]] = None, + overwrite: bool = False, + ) -> None: + """ + Performs the training task for a defined list of splits. + + Parameters + ---------- + split_list : Optional[List[int]] (optional, default=None) + List of splits on which the training task is performed. + Default trains all splits of the cross-validation. + overwrite : bool (optional, default=False) + If True, previously trained splits that are going to be trained + are erased. + + Raises + ------ + MAPSError + If splits specified in input already exist and overwrite is False. 
+ """ + + # splitter_config = SplitterConfig(**self.config.get_dict()) + # self.splitter = Splitter(splitter_config) + # self.splitter.check_split_list(self.config.maps_manager.maps_dir, self.config.maps_manager.overwrite) + self.splitter.check_split_list( + self.config.maps_manager.maps_dir, + overwrite, # overwrite change so careful it is not the maps manager overwrite parameters here + ) + for split in self.splitter.split_iterator(): + logger.info(f"Training split {split}") + seed_everything( + self.config.reproducibility.seed, + self.config.reproducibility.deterministic, + self.config.reproducibility.compensation, + ) + + split_df_dict = self.splitter[split] + + if self.config.model.multi_network: + resume, first_network = self.init_first_network(False, split) + for network in range(first_network, self.maps_manager.num_networks): + self._train_single( + split, split_df_dict, network=network, resume=resume + ) + else: + self._train_single(split, split_df_dict, resume=False) + + # def check_split_list(self, split_list, overwrite): + # existing_splits = [] + # splitter_config = SplitterConfig(**self.config.get_dict()) + # self.splitter = Splitter(splitter_config) + # for split in self.splitter.split_iterator(): + # split_path = self.maps_manager.maps_path / f"split-{split}" + # if split_path.is_dir(): + # if overwrite: + # if cluster.master: + # shutil.rmtree(split_path) + # else: + # existing_splits.append(split) + + # if len(existing_splits) > 0: + # raise MAPSError( + # f"Splits {existing_splits} already exist. Please " + # f"specify a list of splits not intersecting the previous list, " + # f"or use overwrite to erase previously trained splits." + # ) + + def _resume( + self, + split_list: Optional[List[int]] = None, + ) -> None: + """ + Resumes the training task for a defined list of splits. + + Parameters + ---------- + split_list : Optional[List[int]] (optional, default=None) + List of splits on which the training task is performed. 
+ If None, the training task is performed on all splits. + + Raises + ------ + MAPSError + If splits specified in input do not exist. + """ + missing_splits = [] + splitter_config = SplitterConfig(**self.config.get_dict()) + self.splitter = Splitter(splitter_config) + for split in self.splitter.split_iterator(): + if not (self.maps_manager.maps_path / f"split-{split}" / "tmp").is_dir(): + missing_splits.append(split) + + if len(missing_splits) > 0: + raise MAPSError( + f"Splits {missing_splits} were not initialized. " + f"Please try train command on these splits and resume only others." + ) + + for split in self.splitter.split_iterator(): + logger.info(f"Training split {split}") + seed_everything( + self.config.reproducibility.seed, + self.config.reproducibility.deterministic, + self.config.reproducibility.compensation, + ) + + split_df_dict = self.splitter[split] + if self.config.model.multi_network: + resume, first_network = self.init_first_network(True, split) + for network in range(first_network, self.maps_manager.num_networks): + self._train_single( + split, split_df_dict, network=network, resume=resume + ) + else: + self._train_single(split, split_df_dict, resume=True) + + def init_first_network(self, resume: bool, split: int): + first_network = 0 + if resume: + training_logs = [ + int(str(network_folder).split("-")[1]) + for network_folder in list( + ( + self.maps_manager.maps_path / f"split-{split}" / "training_logs" + ).iterdir() + ) + ] + first_network = max(training_logs) + if not (self.maps_manager.maps_path / "tmp").is_dir(): + first_network += 1 + resume = False + return resume, first_network + + def get_dataloader( + self, + data_df: pd.DataFrame, + cnn_index: Optional[int] = None, + sampler_option: str = "random", + dp_degree: Optional[int] = None, + rank: Optional[int] = None, + worker_init_fn: Optional[Callable[[int], None]] = None, + shuffle: Optional[bool] = None, + num_replicas: Optional[int] = None, + homemade_sampler: bool = False, + ): + 
dataset = return_dataset( + input_dir=self.config.data.caps_directory, + data_df=data_df, + preprocessing_dict=self.config.data.preprocessing_dict, + transforms_config=self.config.transforms, + multi_cohort=self.config.data.multi_cohort, + label=self.config.data.label, + label_code=self.config.data.label_code, + cnn_index=cnn_index, + ) + if homemade_sampler: + sampler = generate_sampler( + network_task=self.maps_manager.network_task, + dataset=dataset, + sampler_option=sampler_option, + label_code=self.config.data.label_code, + dp_degree=dp_degree, + rank=rank, + ) + else: + sampler = DistributedSampler( + dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) + + train_loader = DataLoader( + dataset=dataset, + batch_size=self.config.dataloader.batch_size, + sampler=sampler, + num_workers=self.config.dataloader.n_proc, + worker_init_fn=worker_init_fn, + shuffle=shuffle, + ) + logger.debug(f"Train loader size is {len(train_loader)}") + + return train_loader + + def _train_single( + self, + split, + split_df_dict: Dict, + network: Optional[int] = None, + resume: bool = False, + ) -> None: + """ + Trains a single CNN for all inputs. + + Parameters + ---------- + split_list : Optional[List[int]] (optional, default=None) + List of splits on which the training task is performed. + If None, performs training on all splits of the cross-validation. + resume : bool (optional, default=False) + If True, the job is resumed from checkpoint. 
+ """ + + logger.debug("Loading training data...") + + train_loader = self.get_dataloader( + data_df=split_df_dict["train"], + cnn_index=network, + sampler_option=self.config.dataloader.sampler, + dp_degree=cluster.world_size, # type: ignore + rank=cluster.rank, # type: ignore + worker_init_fn=pl_worker_init_function, + homemade_sampler=True, + ) + + logger.debug(f"Train loader size is {len(train_loader)}") + logger.debug("Loading validation data...") + + valid_loader = self.get_dataloader( + data_df=split_df_dict["validation"], + cnn_index=network, + num_replicas=cluster.world_size, # type: ignore + rank=cluster.rank, # type: ignore + shuffle=False, + homemade_sampler=False, + ) + + logger.debug(f"Validation loader size is {len(valid_loader)}") + from clinicadl.callbacks.callbacks import CodeCarbonTracker + + self._train( + train_loader, + valid_loader, + split, + resume=resume, + callbacks=[CodeCarbonTracker], + network=network, + ) + + if network is not None: + resume = False + + if cluster.master: + self.validator._ensemble_prediction( + self.maps_manager, + "train", + split, + self.config.validation.selection_metrics, + ) + self.validator._ensemble_prediction( + self.maps_manager, + "validation", + split, + self.config.validation.selection_metrics, + ) + + self.maps_manager._erase_tmp(split) + + def _train( + self, + train_loader: DataLoader, + valid_loader: DataLoader, + split: int, + network: Optional[int] = None, + resume: bool = False, + callbacks: list[Callback] = [], + ): + """ + Core function shared by train and resume. + + Parameters + ---------- + train_loader : torch.utils.data.DataLoader + DataLoader wrapping the training set. + valid_loader : torch.utils.data.DataLoader + DataLoader wrapping the validation set. + split : int + Index of the split trained. + network : int (optional, default=None) + Index of the network trained (used in multi-network setting only). 
+ resume : bool (optional, default=False) + If True the job is resumed from the checkpoint. + callbacks : List[Callback] (optional, default=[]) + List of callbacks to call during training. + + Raises + ------ + Exception + _description_ + """ + self._init_callbacks() + model, beginning_epoch = self.maps_manager._init_model( + split=split, + resume=resume, + transfer_path=self.config.transfer_learning.transfer_path, + transfer_selection=self.config.transfer_learning.transfer_selection_metric, + nb_unfrozen_layer=self.config.transfer_learning.nb_unfrozen_layer, + ) + model = DDP( + model, + fsdp=self.config.computational.fully_sharded_data_parallel, + amp=self.config.computational.amp, + ) + criterion = get_criterion( + self.maps_manager.network_task, self.config.model.loss + ) + + optimizer = self._init_optimizer(model, split=split, resume=resume) + self.callback_handler.on_train_begin( + self.maps_manager.parameters, + criterion=criterion, + optimizer=optimizer, + split=split, + maps_path=self.maps_manager.maps_path, + ) + + model.train() + train_loader.dataset.train() + + early_stopping = EarlyStopping( + "min", + min_delta=self.config.early_stopping.tolerance, + patience=self.config.early_stopping.patience, + ) + metrics_valid = {"loss": None} + + if cluster.master: + log_writer = LogWriter( + self.maps_manager.maps_path, + evaluation_metrics(self.maps_manager.network_task) + ["loss"], + split, + resume=resume, + beginning_epoch=beginning_epoch, + network=network, + ) + # retain_best = RetainBest( + # selection_metrics=list(self.config.validation.selection_metrics) + # ) ??? 
+ + epoch = beginning_epoch + + retain_best = RetainBest( + selection_metrics=list(self.config.validation.selection_metrics) + ) + + scaler = GradScaler("cuda", enabled=self.config.computational.amp) + profiler = self._init_profiler() + + if self.config.callbacks.track_exp == "wandb": + from clinicadl.callbacks.tracking_exp import WandB_handler + + if self.config.lr_scheduler.adaptive_learning_rate: + from torch.optim.lr_scheduler import ReduceLROnPlateau + + # Initialize the ReduceLROnPlateau scheduler + scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1) + + while epoch < self.config.optimization.epochs and not early_stopping.step( + metrics_valid["loss"] + ): + # self.callback_handler.on_epoch_begin(self.parameters, epoch = epoch) + + if isinstance(train_loader.sampler, DistributedSampler): + # It should always be true for a random sampler. But just in case + # we get a WeightedRandomSampler or a forgotten RandomSampler, + # we do not want to execute this line. + train_loader.sampler.set_epoch(epoch) + + model.zero_grad(set_to_none=True) + evaluation_flag, step_flag = True, True + + with profiler: + for i, data in enumerate(train_loader): + update: bool = ( + i + 1 + ) % self.config.optimization.accumulation_steps == 0 + sync = nullcontext() if update else model.no_sync() + with sync: + with autocast("cuda", enabled=self.maps_manager.std_amp): + _, loss_dict = model(data, criterion) + logger.debug(f"Train loss dictionary {loss_dict}") + loss = loss_dict["loss"] + scaler.scale(loss).backward() + + if update: + step_flag = False + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + + del loss + + # Evaluate the model only when no gradients are accumulated + if ( + self.config.validation.evaluation_steps != 0 + and (i + 1) % self.config.validation.evaluation_steps == 0 + ): + evaluation_flag = False + + _, metrics_train = self.validator.test( + mode=self.maps_manager.mode, + 
metrics_module=self.maps_manager.metrics_module, + n_classes=self.maps_manager.n_classes, + network_task=self.maps_manager.network_task, + model=model, + dataloader=train_loader, + criterion=criterion, + amp=self.maps_manager.std_amp, + ) + _, metrics_valid = self.validator.test( + mode=self.maps_manager.mode, + metrics_module=self.maps_manager.metrics_module, + n_classes=self.maps_manager.n_classes, + network_task=self.maps_manager.network_task, + model=model, + dataloader=valid_loader, + criterion=criterion, + amp=self.maps_manager.std_amp, + ) + + model.train() + train_loader.dataset.train() + + if cluster.master: + log_writer.step( + epoch, + i, + metrics_train, + metrics_valid, + len(train_loader), + ) + logger.info( + f"{self.config.data.mode} level training loss is {metrics_train['loss']} " + f"at the end of iteration {i}" + ) + logger.info( + f"{self.config.data.mode} level validation loss is {metrics_valid['loss']} " + f"at the end of iteration {i}" + ) + + profiler.step() + + # If no step has been performed, raise Exception + if step_flag: + raise Exception( + "The model has not been updated once in the epoch. The accumulation step may be too large." + ) + + # If no evaluation has been performed, warn the user + elif evaluation_flag and self.config.validation.evaluation_steps != 0: + logger.warning( + f"Your evaluation steps {self.config.validation.evaluation_steps} are too big " + f"compared to the size of the dataset. " + f"The model is evaluated only once at the end epochs." 
+ ) + + # Update weights one last time if gradients were computed without update + if (i + 1) % self.config.optimization.accumulation_steps != 0: + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + + # Always test the results and save them once at the end of the epoch + model.zero_grad(set_to_none=True) + logger.debug(f"Last checkpoint at the end of the epoch {epoch}") + + _, metrics_train = self.validator.test( + mode=self.maps_manager.mode, + metrics_module=self.maps_manager.metrics_module, + n_classes=self.maps_manager.n_classes, + network_task=self.maps_manager.network_task, + model=model, + dataloader=train_loader, + criterion=criterion, + amp=self.maps_manager.std_amp, + ) + _, metrics_valid = self.validator.test( + mode=self.maps_manager.mode, + metrics_module=self.maps_manager.metrics_module, + n_classes=self.maps_manager.n_classes, + network_task=self.maps_manager.network_task, + model=model, + dataloader=valid_loader, + criterion=criterion, + amp=self.maps_manager.std_amp, + ) + + model.train() + train_loader.dataset.train() + + self.callback_handler.on_epoch_end( + self.maps_manager.parameters, + metrics_train=metrics_train, + metrics_valid=metrics_valid, + mode=self.config.data.mode, + i=i, + ) + + model_weights = { + "model": model.state_dict(), + "epoch": epoch, + "name": self.config.model.architecture, + } + optimizer_weights = { + "optimizer": model.optim_state_dict(optimizer), + "epoch": epoch, + "name": self.config.model.architecture, + } + + if cluster.master: + # Save checkpoints and best models + best_dict = retain_best.step(metrics_valid) + self.maps_manager._write_weights( + model_weights, + best_dict, + split, + network=network, + save_all_models=self.config.reproducibility.save_all_models, + ) + self.maps_manager._write_weights( + optimizer_weights, + None, + split, + filename="optimizer.pth.tar", + save_all_models=self.config.reproducibility.save_all_models, + ) + dist.barrier() + + if 
self.config.lr_scheduler.adaptive_learning_rate: + scheduler.step( + metrics_valid["loss"] + ) # Update learning rate based on validation loss + + epoch += 1 + + del model + self.validator._test_loader( + self.maps_manager, + train_loader, + criterion, + "train", + split, + self.config.validation.selection_metrics, + amp=self.maps_manager.std_amp, + network=network, + ) + self.validator._test_loader( + self.maps_manager, + valid_loader, + criterion, + "validation", + split, + self.config.validation.selection_metrics, + amp=self.maps_manager.std_amp, + network=network, + ) + + if save_outputs(self.maps_manager.network_task): + self.validator._compute_output_tensors( + self.maps_manager, + train_loader.dataset, + "train", + split, + self.config.validation.selection_metrics, + nb_images=1, + network=network, + ) + self.validator._compute_output_tensors( + self.maps_manager, + valid_loader.dataset, + "validation", + split, + self.config.validation.selection_metrics, + nb_images=1, + network=network, + ) + + self.callback_handler.on_train_end(parameters=self.maps_manager.parameters) + + def _init_callbacks(self) -> None: + """ + Initializes training callbacks. 
+ """ + from clinicadl.callbacks.callbacks import CallbacksHandler, LoggerCallback + + # if self.callbacks is None: + # self.callbacks = [Callback()] + + self.callback_handler = CallbacksHandler() # callbacks=self.callbacks) + + if self.config.callbacks.emissions_calculator: + from clinicadl.callbacks.callbacks import CodeCarbonTracker + + self.callback_handler.add_callback(CodeCarbonTracker()) + + if self.config.callbacks.track_exp: + from clinicadl.callbacks.callbacks import Tracker + + self.callback_handler.add_callback(Tracker) + + self.callback_handler.add_callback(LoggerCallback()) + # self.callback_handler.add_callback(MetricConsolePrinterCallback()) + + def _init_optimizer( + self, + model: DDP, + split: Optional[int] = None, + resume: bool = False, + ) -> torch.optim.Optimizer: + """ + Initializes the optimizer. + + Parameters + ---------- + model : clinicadl.utils.maps_manager.ddp.DDP + The parallelizer. + split : int (optional, default=None) + The split considered. Should not be None if resume is True, but is + useless when resume is False. + resume : bool (optional, default=False) + If True, uses checkpoint to recover optimizer's old state. + + Returns + ------- + torch.optim.Optimizer + The optimizer. + """ + + optimizer_cls = getattr(torch.optim, self.config.optimizer.optimizer) + parameters = filter(lambda x: x.requires_grad, model.parameters()) + optimizer_kwargs = dict( + lr=self.config.optimizer.learning_rate, + weight_decay=self.config.optimizer.weight_decay, + ) + + optimizer = optimizer_cls(parameters, **optimizer_kwargs) + + if resume: + checkpoint_path = ( + self.maps_manager.maps_path + / f"split-{split}" + / "tmp" + / "optimizer.pth.tar" + ) + checkpoint_state = torch.load( + checkpoint_path, map_location=model.device, weights_only=True + ) + model.load_optim_state_dict(optimizer, checkpoint_state["optimizer"]) + + return optimizer + + def _init_profiler(self) -> torch.profiler.profile: + """ + Initializes the profiler. 
+ + Returns + ------- + torch.profiler.profile + Profiler context manager. + """ + if self.config.optimization.profiler: + # TODO: no more profiler ???? + from clinicadl.utils.cluster.profiler import ( + ProfilerActivity, + profile, + schedule, + tensorboard_trace_handler, + ) + + time = datetime.now().strftime("%H:%M:%S") + filename = [self.maps_manager.maps_path / "profiler" / f"clinicadl_{time}"] + dist.broadcast_object_list(filename, src=0) + profiler = profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule(wait=2, warmup=2, active=30, repeat=1), + on_trace_ready=tensorboard_trace_handler(filename[0]), + profile_memory=True, + record_shapes=False, + with_stack=False, + with_flops=False, + ) + else: + profiler = nullcontext() + profiler.step = lambda *args, **kwargs: None + + return profiler diff --git a/clinicadl/trainer/tasks_utils.py b/clinicadl/trainer/tasks_utils.py index dc28d0acd..a14bfa4a9 100644 --- a/clinicadl/trainer/tasks_utils.py +++ b/clinicadl/trainer/tasks_utils.py @@ -18,9 +18,9 @@ from torch.utils.data import DataLoader, Sampler, sampler from torch.utils.data.distributed import DistributedSampler -from clinicadl.caps_dataset.data import CapsDataset -from clinicadl.metrics.metric_module import MetricModule -from clinicadl.network.network import Network +from clinicadl.dataset.caps_dataset import CapsDataset +from clinicadl.metrics.old_metrics.metric_module import MetricModule +from clinicadl.networks.old_network.network import Network from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils import cluster from clinicadl.utils.enum import ( @@ -125,7 +125,7 @@ def validate_criterion(criterion_name: str, compatible_losses: List[str]): } if criterion in reconstruction_losses: - from clinicadl.network.vae.vae_utils import ( + from clinicadl.networks.old_network.vae.vae_utils import ( VAEBernoulliLoss, VAEContinuousBernoulliLoss, VAEGaussianLoss, @@ -603,6 +603,7 @@ def generate_sampler( 
network_task: Union[str, Task], dataset: CapsDataset, sampler_option: str = "random", + label_code: Optional[dict] = None, n_bins: int = 5, dp_degree: Optional[int] = None, rank: Optional[int] = None, @@ -622,7 +623,7 @@ def generate_sampler( def calculate_weights_classification(df): labels = df[dataset.config.data.label].unique() - codes = {dataset.config.data.label_code[label] for label in labels} + codes = {label_code[label] for label in labels} count = np.zeros(len(codes)) for idx in df.index: diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 66ceb0dd1..386a7d148 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -1,1500 +1,20 @@ -from __future__ import annotations # noqa: I001 - -import shutil -from contextlib import nullcontext -from datetime import datetime -from logging import getLogger from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Callable - -import pandas as pd -import torch -import torch.distributed as dist -from torch.amp.grad_scaler import GradScaler -from torch.amp.autocast_mode import autocast -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -from clinicadl.splitter.split_utils import find_finished_splits, find_stopped_splits -from clinicadl.caps_dataset.data import return_dataset -from clinicadl.utils.early_stopping.early_stopping import EarlyStopping -from clinicadl.utils.exceptions import MAPSError -from clinicadl.utils.computational.ddp import DDP -from clinicadl.utils import cluster -from clinicadl.utils.logwriter import LogWriter -from clinicadl.caps_dataset.caps_dataset_utils import read_json -from clinicadl.metrics.metric_module import RetainBest -from clinicadl.utils.seed import pl_worker_init_function, seed_everything -from clinicadl.maps_manager.maps_manager import MapsManager -from clinicadl.utils.seed import get_seed -from clinicadl.utils.enum import Task -from 
clinicadl.utils.iotools.trainer_utils import ( - create_parameters_dict, - patch_to_read_json, -) -from clinicadl.trainer.tasks_utils import create_training_config -from clinicadl.validator.validator import Validator -from clinicadl.splitter.splitter import Splitter -from clinicadl.splitter.config import SplitterConfig -from clinicadl.transforms.config import TransformsConfig -if TYPE_CHECKING: - from clinicadl.callbacks.callbacks import Callback - from clinicadl.trainer.config.train import TrainConfig - -from clinicadl.trainer.tasks_utils import ( - evaluation_metrics, - generate_sampler, - get_criterion, - save_outputs, -) - -logger = getLogger("clinicadl.trainer") +from clinicadl.dataset.caps_dataset import CapsDataset +from clinicadl.experiment_manager.experiment_manager import ExperimentManager +from clinicadl.model.clinicadl_model import ClinicaDLModel +from clinicadl.splitter.kfold import Split class Trainer: - """Temporary Trainer extracted from the MAPSManager.""" - - def __init__( - self, - config: TrainConfig, - ) -> None: - """ - Parameters - ---------- - config : TrainConfig - """ - self.config = config - - self.maps_manager = self._init_maps_manager(config) - self.validator = Validator() - self._check_args() - - def _init_maps_manager(self, config) -> MapsManager: - # temporary: to match CLI data. TODO : change CLI data - - parameters, maps_path = create_parameters_dict(config) - - if maps_path.is_dir(): - return MapsManager( - maps_path, verbose=None - ) # TODO : precise which parameters in config are useful - else: - # parameters["maps_path"] = maps_path - return MapsManager( - maps_path, parameters, verbose=None - ) # TODO : precise which parameters in config are useful + def __init__(self) -> None: + """TO COMPLETE""" @classmethod - def from_json(cls, config_file: str | Path, maps_path: str | Path) -> Trainer: - """ - Creates a Trainer from a json configuration file. 
- - Parameters - ---------- - config_file : str | Path - The parameters, stored in a json files. - maps_path : str | Path - The folder where the results of a futur training will be stored. - - Returns - ------- - Trainer - The Trainer object, instantiated with parameters found in config_file. - - Raises - ------ - FileNotFoundError - If config_file doesn't exist. - """ - config_file = Path(config_file) - - if not (config_file).is_file(): - raise FileNotFoundError(f"No file found at {str(config_file)}.") - config_dict = patch_to_read_json(read_json(config_file)) # TODO : remove patch - config_dict["maps_dir"] = maps_path - config_object = create_training_config(config_dict["network_task"])( - **config_dict - ) - return cls(config_object) - - @classmethod - def from_maps(cls, maps_path: str | Path) -> Trainer: - """ - Creates a Trainer from a json configuration file. - - Parameters - ---------- - maps_path : str | Path - The path of the MAPS folder. - - Returns - ------- - Trainer - The Trainer object, instantiated with parameters found in maps_path. - - Raises - ------ - MAPSError - If maps_path folder doesn't exist or there is no maps.json file in it. - """ - maps_path = Path(maps_path) - - if not (maps_path / "maps.json").is_file(): - raise MAPSError( - f"MAPS was not found at {str(maps_path)}." - f"To initiate a new MAPS please give a train_dict." - ) - return cls.from_json(maps_path / "maps.json", maps_path) - - def resume(self, splits: List[int]) -> None: - """ - Resume a prematurely stopped training. - - Parameters - ---------- - splits : List[int] - The splits that must be resumed. - """ - stopped_splits = set(find_stopped_splits(self.config.maps_manager.maps_dir)) - finished_splits = set(find_finished_splits(self.maps_manager.maps_path)) - # TODO : check these two lines. Why do we need a split_manager? 
- - splitter_config = SplitterConfig(**self.config.get_dict()) - split_manager = Splitter(splitter_config, split_list=splits) - - split_iterator = split_manager.split_iterator() - ### - absent_splits = set(split_iterator) - stopped_splits - finished_splits - - logger.info( - f"Finished splits {finished_splits}\n" - f"Stopped splits {stopped_splits}\n" - f"Absent splits {absent_splits}" - ) - - if len(stopped_splits) == 0 and len(absent_splits) == 0: - raise ValueError( - "Training has been completed on all the splits you passed." - ) - if len(stopped_splits) > 0: - self._resume(list(stopped_splits)) - if len(absent_splits) > 0: - self.train(list(absent_splits), overwrite=True) - - def _check_args(self): - self.config.reproducibility.seed = get_seed(self.config.reproducibility.seed) - # if (len(self.config.data.label_code) == 0): - # self.config.data.label_code = self.maps_manager.label_code - # TODO: deal with label_code and replace self.maps_manager.label_code - - def train( - self, - split_list: Optional[List[int]] = None, - overwrite: bool = False, - ) -> None: - """ - Performs the training task for a defined list of splits. - - Parameters - ---------- - split_list : Optional[List[int]] (optional, default=None) - List of splits on which the training task is performed. - Default trains all splits of the cross-validation. - overwrite : bool (optional, default=False) - If True, previously trained splits that are going to be trained - are erased. - - Raises - ------ - MAPSError - If splits specified in input already exist and overwrite is False. 
- """ - - self.check_split_list(split_list=split_list, overwrite=overwrite) - - if self.config.ssda.ssda_network: - self._train_ssda(split_list, resume=False) - - else: - splitter_config = SplitterConfig(**self.config.get_dict()) - split_manager = Splitter(splitter_config, split_list=split_list) - - for split in split_manager.split_iterator(): - logger.info(f"Training split {split}") - seed_everything( - self.config.reproducibility.seed, - self.config.reproducibility.deterministic, - self.config.reproducibility.compensation, - ) - - split_df_dict = split_manager[split] - - if self.config.model.multi_network: - resume, first_network = self.init_first_network(False, split) - for network in range(first_network, self.maps_manager.num_networks): - self._train_single( - split, split_df_dict, network=network, resume=resume - ) - else: - self._train_single(split, split_df_dict, resume=False) - - def check_split_list(self, split_list, overwrite): - existing_splits = [] - splitter_config = SplitterConfig(**self.config.get_dict()) - split_manager = Splitter(splitter_config, split_list=split_list) - for split in split_manager.split_iterator(): - split_path = self.maps_manager.maps_path / f"split-{split}" - if split_path.is_dir(): - if overwrite: - if cluster.master: - shutil.rmtree(split_path) - else: - existing_splits.append(split) - - if len(existing_splits) > 0: - raise MAPSError( - f"Splits {existing_splits} already exist. Please " - f"specify a list of splits not intersecting the previous list, " - f"or use overwrite to erase previously trained splits." - ) - - def _resume( - self, - split_list: Optional[List[int]] = None, - ) -> None: - """ - Resumes the training task for a defined list of splits. - - Parameters - ---------- - split_list : Optional[List[int]] (optional, default=None) - List of splits on which the training task is performed. - If None, the training task is performed on all splits. 
- - Raises - ------ - MAPSError - If splits specified in input do not exist. - """ - missing_splits = [] - splitter_config = SplitterConfig(**self.config.get_dict()) - split_manager = Splitter(splitter_config, split_list=split_list) - for split in split_manager.split_iterator(): - if not (self.maps_manager.maps_path / f"split-{split}" / "tmp").is_dir(): - missing_splits.append(split) - - if len(missing_splits) > 0: - raise MAPSError( - f"Splits {missing_splits} were not initialized. " - f"Please try train command on these splits and resume only others." - ) - - if self.config.ssda.ssda_network: - self._train_ssda(split_list, resume=True) - else: - for split in split_manager.split_iterator(): - logger.info(f"Training split {split}") - seed_everything( - self.config.reproducibility.seed, - self.config.reproducibility.deterministic, - self.config.reproducibility.compensation, - ) - - split_df_dict = split_manager[split] - if self.config.model.multi_network: - resume, first_network = self.init_first_network(True, split) - for network in range(first_network, self.maps_manager.num_networks): - self._train_single( - split, split_df_dict, network=network, resume=resume - ) - else: - self._train_single(split, split_df_dict, resume=True) - - def init_first_network(self, resume: bool, split: int): - first_network = 0 - if resume: - training_logs = [ - int(str(network_folder).split("-")[1]) - for network_folder in list( - ( - self.maps_manager.maps_path / f"split-{split}" / "training_logs" - ).iterdir() - ) - ] - first_network = max(training_logs) - if not (self.maps_manager.maps_path / "tmp").is_dir(): - first_network += 1 - resume = False - return resume, first_network - - def get_dataloader( - self, - data_df: pd.DataFrame, - cnn_index: Optional[int] = None, - sampler_option: str = "random", - dp_degree: Optional[int] = None, - rank: Optional[int] = None, - worker_init_fn: Optional[Callable[[int], None]] = None, - shuffle: Optional[bool] = None, - num_replicas: 
Optional[int] = None, - homemade_sampler: bool = False, - ): - dataset = return_dataset( - input_dir=self.config.data.caps_directory, - data_df=data_df, - preprocessing_dict=self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, - multi_cohort=self.config.data.multi_cohort, - label=self.config.data.label, - label_code=self.maps_manager.label_code, - cnn_index=cnn_index, - ) - if homemade_sampler: - sampler = generate_sampler( - network_task=self.maps_manager.network_task, - dataset=dataset, - sampler_option=sampler_option, - dp_degree=dp_degree, - rank=rank, - ) - else: - sampler = DistributedSampler( - dataset, - num_replicas=num_replicas, - rank=rank, - shuffle=shuffle, - ) - - train_loader = DataLoader( - dataset=dataset, - batch_size=self.config.dataloader.batch_size, - sampler=sampler, - num_workers=self.config.dataloader.n_proc, - worker_init_fn=worker_init_fn, - shuffle=shuffle, - ) - logger.debug(f"Train loader size is {len(train_loader)}") - - return train_loader - - def _train_single( - self, - split, - split_df_dict: Dict, - network: Optional[int] = None, - resume: bool = False, - ) -> None: - """ - Trains a single CNN for all inputs. - - Parameters - ---------- - split_list : Optional[List[int]] (optional, default=None) - List of splits on which the training task is performed. - If None, performs training on all splits of the cross-validation. - resume : bool (optional, default=False) - If True, the job is resumed from checkpoint. 
- """ - - logger.debug("Loading training data...") - - train_loader = self.get_dataloader( - data_df=split_df_dict["train"], - cnn_index=network, - sampler_option=self.config.dataloader.sampler, - dp_degree=cluster.world_size, # type: ignore - rank=cluster.rank, # type: ignore - worker_init_fn=pl_worker_init_function, - homemade_sampler=True, - ) - - logger.debug(f"Train loader size is {len(train_loader)}") - logger.debug("Loading validation data...") - - valid_loader = self.get_dataloader( - data_df=split_df_dict["validation"], - cnn_index=network, - num_replicas=cluster.world_size, # type: ignore - rank=cluster.rank, # type: ignore - shuffle=False, - homemade_sampler=False, - ) - - logger.debug(f"Validation loader size is {len(valid_loader)}") - from clinicadl.callbacks.callbacks import CodeCarbonTracker - - self._train( - train_loader, - valid_loader, - split, - resume=resume, - callbacks=[CodeCarbonTracker], - network=network, - ) - - if network is not None: - resume = False - - if cluster.master: - self.validator._ensemble_prediction( - self.maps_manager, - "train", - split, - self.config.validation.selection_metrics, - ) - self.validator._ensemble_prediction( - self.maps_manager, - "validation", - split, - self.config.validation.selection_metrics, - ) - - self.maps_manager._erase_tmp(split) - - def _train_ssda( - self, - split_list: Optional[List[int]] = None, - resume: bool = False, - ) -> None: - """ - Trains a single CNN for a source and target domain using semi-supervised domain adaptation. - - Parameters - ---------- - split_list : Optional[List[int]] (optional, default=None) - List of splits on which the training task is performed. - If None, performs training on all splits of the cross-validation. - resume : bool (optional, default=False) - If True, the job is resumed from checkpoint. 
- """ - - splitter_config = SplitterConfig(**self.config.get_dict()) - - split_manager = Splitter(splitter_config, split_list=split_list) - split_manager_target_lab = Splitter(splitter_config, split_list=split_list) - - for split in split_manager.split_iterator(): - logger.info(f"Training split {split}") - seed_everything( - self.config.reproducibility.seed, - self.config.reproducibility.deterministic, - self.config.reproducibility.compensation, - ) - - split_df_dict = split_manager[split] - split_df_dict_target_lab = split_manager_target_lab[split] - - logger.debug("Loading source training data...") - data_train_source = return_dataset( - self.config.data.caps_directory, - split_df_dict["train"], - self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, - multi_cohort=self.config.data.multi_cohort, - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - - logger.debug("Loading target labelled training data...") - data_train_target_labeled = return_dataset( - Path(self.config.ssda.caps_target), # TO CHECK - split_df_dict_target_lab["train"], - self.config.ssda.preprocessing_dict_target, - transforms_config=self.config.transforms, - multi_cohort=False, # A checker - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - from torch.utils.data import ConcatDataset - - combined_dataset = ConcatDataset( - [data_train_source, data_train_target_labeled] - ) - - logger.debug("Loading target unlabelled training data...") - data_target_unlabeled = return_dataset( - Path(self.config.ssda.caps_target), - pd.read_csv(self.config.ssda.tsv_target_unlab, sep="\t"), - self.config.ssda.preprocessing_dict_target, - transforms_config=self.config.transforms, - multi_cohort=False, # A checker - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - - logger.debug("Loading validation source data...") - data_valid_source = return_dataset( - self.config.data.caps_directory, - 
split_df_dict["validation"], - self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, - multi_cohort=self.config.data.multi_cohort, - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - logger.debug("Loading validation target labelled data...") - data_valid_target_labeled = return_dataset( - Path(self.config.ssda.caps_target), - split_df_dict_target_lab["validation"], - self.config.ssda.preprocessing_dict_target, - transforms_config=self.config.transforms, - multi_cohort=False, - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - train_source_sampler = generate_sampler( - self.maps_manager.network_task, - data_train_source, - self.config.dataloader.sampler, - ) - - logger.info( - f"Getting train and validation loader with batch size {self.config.dataloader.batch_size}" - ) - - ## Oversampling of the target dataset - from torch.utils.data import SubsetRandomSampler - - # Create index lists for target labeled dataset - labeled_indices = list(range(len(data_train_target_labeled))) - - # Oversample the indices for the target labelled dataset to match the size of the labeled source dataset - data_train_source_size = ( - len(data_train_source) // self.config.dataloader.batch_size - ) - labeled_oversampled_indices = labeled_indices * ( - data_train_source_size // len(labeled_indices) - ) - - # Append remaining indices to match the size of the largest dataset - labeled_oversampled_indices += labeled_indices[ - : data_train_source_size % len(labeled_indices) - ] - - # Create SubsetRandomSamplers using the oversampled indices - labeled_sampler = SubsetRandomSampler(labeled_oversampled_indices) - - train_source_loader = DataLoader( - data_train_source, - batch_size=self.config.dataloader.batch_size, - sampler=train_source_sampler, - # shuffle=True, # len(data_train_source) < len(data_train_target_labeled), - num_workers=self.config.dataloader.n_proc, - worker_init_fn=pl_worker_init_function, - 
drop_last=True, - ) - logger.info( - f"Train source loader size is {len(train_source_loader)*self.config.dataloader.batch_size}" - ) - train_target_loader = DataLoader( - data_train_target_labeled, - batch_size=1, # To limit the need of oversampling - # sampler=train_target_sampler, - sampler=labeled_sampler, - num_workers=self.config.dataloader.n_proc, - worker_init_fn=pl_worker_init_function, - # shuffle=True, # len(data_train_target_labeled) < len(data_train_source), - drop_last=True, - ) - logger.info( - f"Train target labeled loader size oversample is {len(train_target_loader)}" - ) - - data_train_target_labeled.df = data_train_target_labeled.df[ - ["participant_id", "session_id", "diagnosis", "cohort", "domain"] - ] - - train_target_unl_loader = DataLoader( - data_target_unlabeled, - batch_size=self.config.dataloader.batch_size, - num_workers=self.config.dataloader.n_proc, - # sampler=unlabeled_sampler, - worker_init_fn=pl_worker_init_function, - shuffle=True, - drop_last=True, - ) - - logger.info( - f"Train target unlabeled loader size is {len(train_target_unl_loader)*self.config.dataloader.batch_size}" - ) - - valid_loader_source = DataLoader( - data_valid_source, - batch_size=self.config.dataloader.batch_size, - shuffle=False, - num_workers=self.config.dataloader.n_proc, - ) - logger.info( - f"Validation loader source size is {len(valid_loader_source)*self.config.dataloader.batch_size}" - ) - - valid_loader_target = DataLoader( - data_valid_target_labeled, - batch_size=self.config.dataloader.batch_size, # To check - shuffle=False, - num_workers=self.config.dataloader.n_proc, - ) - logger.info( - f"Validation loader target size is {len(valid_loader_target)*self.config.dataloader.batch_size}" - ) - - self._train_ssdann( - train_source_loader, - train_target_loader, - train_target_unl_loader, - valid_loader_target, - valid_loader_source, - split, - resume=resume, - ) - - self.validator._ensemble_prediction( - self.maps_manager, - "train", - split, - 
self.config.validation.selection_metrics, - ) - self.validator._ensemble_prediction( - self.maps_manager, - "validation", - split, - self.config.validation.selection_metrics, - ) - - self.maps_manager._erase_tmp(split) - - def _train( - self, - train_loader: DataLoader, - valid_loader: DataLoader, - split: int, - network: Optional[int] = None, - resume: bool = False, - callbacks: list[Callback] = [], - ): - """ - Core function shared by train and resume. - - Parameters - ---------- - train_loader : torch.utils.data.DataLoader - DataLoader wrapping the training set. - valid_loader : torch.utils.data.DataLoader - DataLoader wrapping the validation set. - split : int - Index of the split trained. - network : int (optional, default=None) - Index of the network trained (used in multi-network setting only). - resume : bool (optional, default=False) - If True the job is resumed from the checkpoint. - callbacks : List[Callback] (optional, default=[]) - List of callbacks to call during training. 
- - Raises - ------ - Exception - _description_ - """ - self._init_callbacks() - model, beginning_epoch = self.maps_manager._init_model( - split=split, - resume=resume, - transfer_path=self.config.transfer_learning.transfer_path, - transfer_selection=self.config.transfer_learning.transfer_selection_metric, - nb_unfrozen_layer=self.config.transfer_learning.nb_unfrozen_layer, - ) - model = DDP( - model, - fsdp=self.config.computational.fully_sharded_data_parallel, - amp=self.config.computational.amp, - ) - criterion = get_criterion( - self.maps_manager.network_task, self.config.model.loss - ) - - optimizer = self._init_optimizer(model, split=split, resume=resume) - self.callback_handler.on_train_begin( - self.maps_manager.parameters, - criterion=criterion, - optimizer=optimizer, - split=split, - maps_path=self.maps_manager.maps_path, - ) - - model.train() - train_loader.dataset.train() - - early_stopping = EarlyStopping( - "min", - min_delta=self.config.early_stopping.tolerance, - patience=self.config.early_stopping.patience, - ) - metrics_valid = {"loss": None} - - if cluster.master: - log_writer = LogWriter( - self.maps_manager.maps_path, - evaluation_metrics(self.maps_manager.network_task) + ["loss"], - split, - resume=resume, - beginning_epoch=beginning_epoch, - network=network, - ) - # retain_best = RetainBest( - # selection_metrics=list(self.config.validation.selection_metrics) - # ) ??? 
- - epoch = beginning_epoch - - retain_best = RetainBest( - selection_metrics=list(self.config.validation.selection_metrics) - ) - - scaler = GradScaler("cuda", enabled=self.config.computational.amp) - profiler = self._init_profiler() - - if self.config.callbacks.track_exp == "wandb": - from clinicadl.callbacks.tracking_exp import WandB_handler - - if self.config.lr_scheduler.adaptive_learning_rate: - from torch.optim.lr_scheduler import ReduceLROnPlateau - - # Initialize the ReduceLROnPlateau scheduler - scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1) - - while epoch < self.config.optimization.epochs and not early_stopping.step( - metrics_valid["loss"] - ): - # self.callback_handler.on_epoch_begin(self.parameters, epoch = epoch) - - if isinstance(train_loader.sampler, DistributedSampler): - # It should always be true for a random sampler. But just in case - # we get a WeightedRandomSampler or a forgotten RandomSampler, - # we do not want to execute this line. - train_loader.sampler.set_epoch(epoch) - - model.zero_grad(set_to_none=True) - evaluation_flag, step_flag = True, True - - with profiler: - for i, data in enumerate(train_loader): - update: bool = ( - i + 1 - ) % self.config.optimization.accumulation_steps == 0 - sync = nullcontext() if update else model.no_sync() - with sync: - with autocast("cuda", enabled=self.maps_manager.std_amp): - _, loss_dict = model(data, criterion) - logger.debug(f"Train loss dictionary {loss_dict}") - loss = loss_dict["loss"] - scaler.scale(loss).backward() - - if update: - step_flag = False - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad(set_to_none=True) - - del loss - - # Evaluate the model only when no gradients are accumulated - if ( - self.config.validation.evaluation_steps != 0 - and (i + 1) % self.config.validation.evaluation_steps == 0 - ): - evaluation_flag = False - - _, metrics_train = self.validator.test( - mode=self.maps_manager.mode, - 
metrics_module=self.maps_manager.metrics_module, - n_classes=self.maps_manager.n_classes, - network_task=self.maps_manager.network_task, - model=model, - dataloader=train_loader, - criterion=criterion, - amp=self.maps_manager.std_amp, - ) - _, metrics_valid = self.validator.test( - mode=self.maps_manager.mode, - metrics_module=self.maps_manager.metrics_module, - n_classes=self.maps_manager.n_classes, - network_task=self.maps_manager.network_task, - model=model, - dataloader=valid_loader, - criterion=criterion, - amp=self.maps_manager.std_amp, - ) - - model.train() - train_loader.dataset.train() - - if cluster.master: - log_writer.step( - epoch, - i, - metrics_train, - metrics_valid, - len(train_loader), - ) - logger.info( - f"{self.config.data.mode} level training loss is {metrics_train['loss']} " - f"at the end of iteration {i}" - ) - logger.info( - f"{self.config.data.mode} level validation loss is {metrics_valid['loss']} " - f"at the end of iteration {i}" - ) - - profiler.step() - - # If no step has been performed, raise Exception - if step_flag: - raise Exception( - "The model has not been updated once in the epoch. The accumulation step may be too large." - ) - - # If no evaluation has been performed, warn the user - elif evaluation_flag and self.config.validation.evaluation_steps != 0: - logger.warning( - f"Your evaluation steps {self.config.validation.evaluation_steps} are too big " - f"compared to the size of the dataset. " - f"The model is evaluated only once at the end epochs." 
- ) - - # Update weights one last time if gradients were computed without update - if (i + 1) % self.config.optimization.accumulation_steps != 0: - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad(set_to_none=True) - - # Always test the results and save them once at the end of the epoch - model.zero_grad(set_to_none=True) - logger.debug(f"Last checkpoint at the end of the epoch {epoch}") - - _, metrics_train = self.validator.test( - mode=self.maps_manager.mode, - metrics_module=self.maps_manager.metrics_module, - n_classes=self.maps_manager.n_classes, - network_task=self.maps_manager.network_task, - model=model, - dataloader=train_loader, - criterion=criterion, - amp=self.maps_manager.std_amp, - ) - _, metrics_valid = self.validator.test( - mode=self.maps_manager.mode, - metrics_module=self.maps_manager.metrics_module, - n_classes=self.maps_manager.n_classes, - network_task=self.maps_manager.network_task, - model=model, - dataloader=valid_loader, - criterion=criterion, - amp=self.maps_manager.std_amp, - ) - - model.train() - train_loader.dataset.train() - - self.callback_handler.on_epoch_end( - self.maps_manager.parameters, - metrics_train=metrics_train, - metrics_valid=metrics_valid, - mode=self.config.data.mode, - i=i, - ) - - model_weights = { - "model": model.state_dict(), - "epoch": epoch, - "name": self.config.model.architecture, - } - optimizer_weights = { - "optimizer": model.optim_state_dict(optimizer), - "epoch": epoch, - "name": self.config.model.architecture, - } - - if cluster.master: - # Save checkpoints and best models - best_dict = retain_best.step(metrics_valid) - self.maps_manager._write_weights( - model_weights, - best_dict, - split, - network=network, - save_all_models=self.config.reproducibility.save_all_models, - ) - self.maps_manager._write_weights( - optimizer_weights, - None, - split, - filename="optimizer.pth.tar", - save_all_models=self.config.reproducibility.save_all_models, - ) - dist.barrier() - - if 
self.config.lr_scheduler.adaptive_learning_rate: - scheduler.step( - metrics_valid["loss"] - ) # Update learning rate based on validation loss - - epoch += 1 - - del model - self.validator._test_loader( - self.maps_manager, - train_loader, - criterion, - "train", - split, - self.config.validation.selection_metrics, - amp=self.maps_manager.std_amp, - network=network, - ) - self.validator._test_loader( - self.maps_manager, - valid_loader, - criterion, - "validation", - split, - self.config.validation.selection_metrics, - amp=self.maps_manager.std_amp, - network=network, - ) - - if save_outputs(self.maps_manager.network_task): - self.validator._compute_output_tensors( - self.maps_manager, - train_loader.dataset, - "train", - split, - self.config.validation.selection_metrics, - nb_images=1, - network=network, - ) - self.validator._compute_output_tensors( - self.maps_manager, - valid_loader.dataset, - "validation", - split, - self.config.validation.selection_metrics, - nb_images=1, - network=network, - ) - - self.callback_handler.on_train_end(parameters=self.maps_manager.parameters) - - def _train_ssdann( - self, - train_source_loader: DataLoader, - train_target_loader: DataLoader, - train_target_unl_loader: DataLoader, - valid_loader: DataLoader, - valid_source_loader: DataLoader, - split: int, - network: Optional[Any] = None, - resume: bool = False, - evaluate_source: bool = True, # TO MODIFY - ): - """ - _summary_ - - Parameters - ---------- - train_source_loader : torch.utils.data.DataLoader - _description_ - train_target_loader : torch.utils.data.DataLoader - _description_ - train_target_unl_loader : torch.utils.data.DataLoader - _description_ - valid_loader : torch.utils.data.DataLoader - _description_ - valid_source_loader : torch.utils.data.DataLoader - _description_ - split : int - _description_ - network : Optional[Any] (optional, default=None) - _description_ - resume : bool (optional, default=False) - _description_ - evaluate_source : bool (optional, 
default=True) - _description_ - - Raises - ------ - Exception - _description_ - """ - model, beginning_epoch = self.maps_manager._init_model( - split=split, - resume=resume, - transfer_path=self.config.transfer_learning.transfer_path, - transfer_selection=self.config.transfer_learning.transfer_selection_metric, - ) - - criterion = get_criterion( - self.maps_manager.network_task, self.config.model.loss - ) - logger.debug(f"Criterion for {self.config.network_task} is {criterion}") - optimizer = self._init_optimizer(model, split=split, resume=resume) - - logger.debug(f"Optimizer used for training is optimizer") - - model.train() - train_source_loader.dataset.train() - train_target_loader.dataset.train() - train_target_unl_loader.dataset.train() - - early_stopping = EarlyStopping( - "min", - min_delta=self.config.early_stopping.tolerance, - patience=self.config.early_stopping.patience, - ) - - metrics_valid_target = {"loss": None} - metrics_valid_source = {"loss": None} - - log_writer = LogWriter( - self.maps_manager.maps_path, - evaluation_metrics(self.maps_manager.network_task) + ["loss"], - split, - resume=resume, - beginning_epoch=beginning_epoch, - network=network, - ) - epoch = log_writer.beginning_epoch - - retain_best = RetainBest( - selection_metrics=list(self.config.validation.selection_metrics) - ) - import numpy as np - - while epoch < self.config.optimization.epochs and not early_stopping.step( - metrics_valid_target["loss"] - ): - logger.info(f"Beginning epoch {epoch}.") - - model.zero_grad() - evaluation_flag, step_flag = True, True - - for i, (data_source, data_target, data_target_unl) in enumerate( - zip(train_source_loader, train_target_loader, train_target_unl_loader) - ): - p = ( - float(epoch * len(train_target_loader)) - / 10 - / len(train_target_loader) - ) - alpha = 2.0 / (1.0 + np.exp(-10 * p)) - 1 - # alpha = 0 - _, _, loss_dict = model.compute_outputs_and_loss( - data_source, data_target, data_target_unl, criterion, alpha - ) # TO CHECK - 
logger.debug(f"Train loss dictionary {loss_dict}") - loss = loss_dict["loss"] - loss.backward() - if (i + 1) % self.config.optimization.accumulation_steps == 0: - step_flag = False - optimizer.step() - optimizer.zero_grad() - - del loss - - # Evaluate the model only when no gradients are accumulated - if ( - self.config.validation.evaluation_steps != 0 - and (i + 1) % self.config.validation.evaluation_steps == 0 - ): - evaluation_flag = False - - # Evaluate on target data - logger.info("Evaluation on target data") - ( - _, - metrics_train_target, - ) = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=train_target_loader, - criterion=criterion, - alpha=alpha, - target=True, - ) # TO CHECK - - ( - _, - metrics_valid_target, - ) = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=valid_loader, - criterion=criterion, - alpha=alpha, - target=True, - ) - - model.train() - train_target_loader.dataset.train() - - log_writer.step( - epoch, - i, - metrics_train_target, - metrics_valid_target, - len(train_target_loader), - "training_target.tsv", - ) - logger.info( - f"{self.config.data.mode} level training loss for target data is {metrics_train_target['loss']} " - f"at the end of iteration {i}" - ) - logger.info( - f"{self.config.data.mode} level validation loss for target data is {metrics_valid_target['loss']} " - f"at the end of iteration {i}" - ) - - # Evaluate on source data - logger.info("Evaluation on source data") - ( - _, - metrics_train_source, - ) = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - 
dataloader=train_source_loader, - criterion=criterion, - alpha=alpha, - ) - ( - _, - metrics_valid_source, - ) = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=valid_source_loader, - criterion=criterion, - alpha=alpha, - ) - - model.train() - train_source_loader.dataset.train() - - log_writer.step( - epoch, - i, - metrics_train_source, - metrics_valid_source, - len(train_source_loader), - ) - logger.info( - f"{self.config.data.mode} level training loss for source data is {metrics_train_source['loss']} " - f"at the end of iteration {i}" - ) - logger.info( - f"{self.config.data.mode} level validation loss for source data is {metrics_valid_source['loss']} " - f"at the end of iteration {i}" - ) - - # If no step has been performed, raise Exception - if step_flag: - raise Exception( - "The model has not been updated once in the epoch. The accumulation step may be too large." - ) - - # If no evaluation has been performed, warn the user - elif evaluation_flag and self.config.validation.evaluation_steps != 0: - logger.warning( - f"Your evaluation steps {self.config.validation.evaluation_steps} are too big " - f"compared to the size of the dataset. " - f"The model is evaluated only once at the end epochs." - ) - - # Update weights one last time if gradients were computed without update - if (i + 1) % self.config.optimization.accumulation_steps != 0: - optimizer.step() - optimizer.zero_grad() - # Always test the results and save them once at the end of the epoch - model.zero_grad() - logger.debug(f"Last checkpoint at the end of the epoch {epoch}") - - if evaluate_source: - logger.info( - f"Evaluate source data at the end of the epoch {epoch} with alpha: {alpha}." 
- ) - _, metrics_train_source = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=train_source_loader, - criterion=criterion, - alpha=alpha, - target=True, - report_ci=False, - ) - _, metrics_valid_source = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=valid_source_loader, - criterion=criterion, - alpha=alpha, - target=True, - report_ci=False, - ) - - log_writer.step( - epoch, - i, - metrics_train_source, - metrics_valid_source, - len(train_source_loader), - ) - - logger.info( - f"{self.config.data.mode} level training loss for source data is {metrics_train_source['loss']} " - f"at the end of iteration {i}" - ) - logger.info( - f"{self.config.data.mode} level validation loss for source data is {metrics_valid_source['loss']} " - f"at the end of iteration {i}" - ) - - _, metrics_train_target = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=train_target_loader, - criterion=criterion, - alpha=alpha, - target=True, - ) - _, metrics_valid_target = test_da( - mode=self.maps_manager.mode, - n_classes=self.maps_manager.n_classes, - metrics_module=self.maps_manager.metrics_module, - network_task=self.maps_manager.network_task, - model=model, - dataloader=valid_loader, - criterion=criterion, - alpha=alpha, - target=True, - ) - - model.train() - train_source_loader.dataset.train() - train_target_loader.dataset.train() - - log_writer.step( - epoch, - i, - metrics_train_target, - metrics_valid_target, - len(train_target_loader), - "training_target.tsv", - ) - - logger.info( - f"{self.config.data.mode} level 
training loss for target data is {metrics_train_target['loss']} " - f"at the end of iteration {i}" - ) - logger.info( - f"{self.config.data.mode} level validation loss for target data is {metrics_valid_target['loss']} " - f"at the end of iteration {i}" - ) - - # Save checkpoints and best models - best_dict = retain_best.step(metrics_valid_target) - self.maps_manager._write_weights( - { - "model": model.state_dict(), - "epoch": epoch, - "name": self.config.model.architecture, - }, - best_dict, - split, - network=network, - save_all_models=False, - ) - self.maps_manager._write_weights( - { - "optimizer": optimizer.state_dict(), # TO MODIFY - "epoch": epoch, - "name": self.config.optimizer, - }, - None, - split, - filename="optimizer.pth.tar", - save_all_models=False, - ) - - epoch += 1 - - self.validator._test_loader_ssda( - self.maps_manager, - train_target_loader, - criterion, - data_group="train", - split=split, - selection_metrics=self.config.validation.selection_metrics, - network=network, - target=True, - alpha=0, - ) - self.validator._test_loader_ssda( - self.maps_manager, - valid_loader, - criterion, - data_group="validation", - split=split, - selection_metrics=self.config.validation.selection_metrics, - network=network, - target=True, - alpha=0, - ) - - if save_outputs(self.maps_manager.network_task): - self.validator._compute_output_tensors( - self.maps_manager, - train_target_loader.dataset, - "train", - split, - self.config.validation.selection_metrics, - nb_images=1, - network=network, - ) - self.validator._compute_output_tensors( - self.maps_manager, - train_target_loader.dataset, - "validation", - split, - self.config.validation.selection_metrics, - nb_images=1, - network=network, - ) - - def _init_callbacks(self) -> None: - """ - Initializes training callbacks. 
- """ - from clinicadl.callbacks.callbacks import CallbacksHandler, LoggerCallback - - # if self.callbacks is None: - # self.callbacks = [Callback()] - - self.callback_handler = CallbacksHandler() # callbacks=self.callbacks) - - if self.config.callbacks.emissions_calculator: - from clinicadl.callbacks.callbacks import CodeCarbonTracker - - self.callback_handler.add_callback(CodeCarbonTracker()) - - if self.config.callbacks.track_exp: - from clinicadl.callbacks.callbacks import Tracker - - self.callback_handler.add_callback(Tracker) - - self.callback_handler.add_callback(LoggerCallback()) - # self.callback_handler.add_callback(MetricConsolePrinterCallback()) - - def _init_optimizer( - self, - model: DDP, - split: Optional[int] = None, - resume: bool = False, - ) -> torch.optim.Optimizer: - """ - Initializes the optimizer. - - Parameters - ---------- - model : clinicadl.utils.maps_manager.ddp.DDP - The parallelizer. - split : int (optional, default=None) - The split considered. Should not be None if resume is True, but is - useless when resume is False. - resume : bool (optional, default=False) - If True, uses checkpoint to recover optimizer's old state. - - Returns - ------- - torch.optim.Optimizer - The optimizer. - """ - - optimizer_cls = getattr(torch.optim, self.config.optimizer.optimizer) - parameters = filter(lambda x: x.requires_grad, model.parameters()) - optimizer_kwargs = dict( - lr=self.config.optimizer.learning_rate, - weight_decay=self.config.optimizer.weight_decay, - ) - - optimizer = optimizer_cls(parameters, **optimizer_kwargs) - - if resume: - checkpoint_path = ( - self.maps_manager.maps_path - / f"split-{split}" - / "tmp" - / "optimizer.pth.tar" - ) - checkpoint_state = torch.load( - checkpoint_path, map_location=model.device, weights_only=True - ) - model.load_optim_state_dict(optimizer, checkpoint_state["optimizer"]) - - return optimizer - - def _init_profiler(self) -> torch.profiler.profile: - """ - Initializes the profiler. 
- - Returns - ------- - torch.profiler.profile - Profiler context manager. - """ - if self.config.optimization.profiler: - # TODO: no more profiler ???? - from clinicadl.utils.cluster.profiler import ( - ProfilerActivity, - profile, - schedule, - tensorboard_trace_handler, - ) - - time = datetime.now().strftime("%H:%M:%S") - filename = [self.maps_manager.maps_path / "profiler" / f"clinicadl_{time}"] - dist.broadcast_object_list(filename, src=0) - profiler = profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - schedule=schedule(wait=2, warmup=2, active=30, repeat=1), - on_trace_ready=tensorboard_trace_handler(filename[0]), - profile_memory=True, - record_shapes=False, - with_stack=False, - with_flops=False, - ) - else: - profiler = nullcontext() - profiler.step = lambda *args, **kwargs: None + def from_json(cls, config_file: Path, manager: ExperimentManager) -> Trainer: + """TO COMPLETE""" + return cls() - return profiler + def train(self, model: ClinicaDLModel, split: Split): + """TO COMPLETE""" + pass diff --git a/clinicadl/transforms/config.py b/clinicadl/transforms/config.py index 91c80c0de..79f4db7c7 100644 --- a/clinicadl/transforms/config.py +++ b/clinicadl/transforms/config.py @@ -4,7 +4,7 @@ import torchvision.transforms as torch_transforms from pydantic import BaseModel, ConfigDict, field_validator -from clinicadl.transforms import transforms +from clinicadl.transforms import factory from clinicadl.utils.enum import ( SizeReductionFactor, Transform, @@ -47,16 +47,16 @@ def get_transforms( transforms to apply in train and evaluation mode / transforms to apply in evaluation mode only. 
""" augmentation_dict = { - "Noise": transforms.RandomNoising(sigma=0.1), + "Noise": factory.RandomNoising(sigma=0.1), "Erasing": torch_transforms.RandomErasing(), - "CropPad": transforms.RandomCropPad(10), - "Smoothing": transforms.RandomSmoothing(), - "Motion": transforms.RandomMotion((2, 4), (2, 4), 2), - "Ghosting": transforms.RandomGhosting((4, 10)), - "Spike": transforms.RandomSpike(1, (1, 3)), - "BiasField": transforms.RandomBiasField(0.5), - "RandomBlur": transforms.RandomBlur((0, 2)), - "RandomSwap": transforms.RandomSwap(15, 100), + "CropPad": factory.RandomCropPad(10), + "Smoothing": factory.RandomSmoothing(), + "Motion": factory.RandomMotion((2, 4), (2, 4), 2), + "Ghosting": factory.RandomGhosting((4, 10)), + "Spike": factory.RandomSpike(1, (1, 3)), + "BiasField": factory.RandomBiasField(0.5), + "RandomBlur": factory.RandomBlur((0, 2)), + "RandomSwap": factory.RandomSwap(15, 100), "None": None, } @@ -71,12 +71,12 @@ def get_transforms( ] ) - transformations_list.append(transforms.NanRemoval()) + transformations_list.append(factory.NanRemoval()) if self.normalize: - transformations_list.append(transforms.MinMaxNormalization()) + transformations_list.append(factory.MinMaxNormalization()) if self.size_reduction: transformations_list.append( - transforms.SizeReduction(self.size_reduction_factor) + factory.SizeReduction(self.size_reduction_factor) ) all_transformations = torch_transforms.Compose(transformations_list) diff --git a/clinicadl/transforms/factory.py b/clinicadl/transforms/factory.py new file mode 100644 index 000000000..7b6de09a7 --- /dev/null +++ b/clinicadl/transforms/factory.py @@ -0,0 +1,228 @@ +# coding: utf8 + +from logging import getLogger + +import numpy as np +import torch +import torchio as tio + +from clinicadl.utils.exceptions import ClinicaDLConfigurationError + +logger = getLogger("clinicadl") + +################################## +# Transformations +################################## + + +class RandomNoising(object): + """Applies a 
random zoom to a tensor""" + + def __init__(self, sigma=0.1): + self.sigma = sigma + + def __call__(self, image): + import random + + sigma = random.uniform(0, self.sigma) + dist = torch.distributions.normal.Normal(0, sigma) + return image + dist.sample(image.shape) + + +class RandomSmoothing(object): + """Applies a random zoom to a tensor""" + + def __init__(self, sigma=1): + self.sigma = sigma + + def __call__(self, image): + import random + + from scipy.ndimage import gaussian_filter + + sigma = random.uniform(0, self.sigma) + image = gaussian_filter(image, sigma) # smoothing of data + image = torch.from_numpy(image).float() + return image + + +class RandomCropPad(object): + def __init__(self, length): + self.length = length + + def __call__(self, image): + dimensions = len(image.shape) - 1 + crop = np.random.randint(-self.length, self.length, dimensions) + if dimensions == 2: + output = torch.nn.functional.pad( + image, (-crop[0], crop[0], -crop[1], crop[1]) + ) + elif dimensions == 3: + output = torch.nn.functional.pad( + image, (-crop[0], crop[0], -crop[1], crop[1], -crop[2], crop[2]) + ) + else: + raise ValueError( + f"RandomCropPad is only available for 2D or 3D data. 
Image is {dimensions}D" + ) + return output + + +class GaussianSmoothing(object): + def __init__(self, sigma): + self.sigma = sigma + + def __call__(self, sample): + from scipy.ndimage.filters import gaussian_filter + + image = sample["image"] + np.nan_to_num(image, copy=False) + smoothed_image = gaussian_filter(image, sigma=self.sigma) + sample["image"] = smoothed_image + + return sample + + +class RandomMotion(object): + """Applies a Random Motion""" + + def __init__(self, translation, rotation, num_transforms): + self.rotation = rotation + self.translation = translation + self.num_transforms = num_transforms + + def __call__(self, image): + motion = tio.RandomMotion( + degrees=self.rotation, + translation=self.translation, + num_transforms=self.num_transforms, + ) + image = motion(image) + + return image + + +class RandomGhosting(object): + """Applies a Random Ghosting""" + + def __init__(self, num_ghosts): + self.num_ghosts = num_ghosts + + def __call__(self, image): + ghost = tio.RandomGhosting(num_ghosts=self.num_ghosts) + image = ghost(image) + + return image + + +class RandomSpike(object): + """Applies a Random Spike""" + + def __init__(self, num_spikes, intensity): + self.num_spikes = num_spikes + self.intensity = intensity + + def __call__(self, image): + spike = tio.RandomSpike( + num_spikes=self.num_spikes, + intensity=self.intensity, + ) + image = spike(image) + + return image + + +class RandomBiasField(object): + """Applies a Random Bias Field""" + + def __init__(self, coefficients): + self.coefficients = coefficients + + def __call__(self, image): + bias_field = tio.RandomBiasField(coefficients=self.coefficients) + image = bias_field(image) + + return image + + +class RandomBlur(object): + """Applies a Random Blur""" + + def __init__(self, std): + self.std = std + + def __call__(self, image): + blur = tio.RandomBlur(std=self.std) + image = blur(image) + + return image + + +class RandomSwap(object): + """Applies a Random Swap""" + + def __init__(self, 
patch_size, num_iterations): + self.patch_size = patch_size + self.num_iterations = num_iterations + + def __call__(self, image): + swap = tio.RandomSwap( + patch_size=self.patch_size, num_iterations=self.num_iterations + ) + image = swap(image) + + return image + + +class ToTensor(object): + """Convert image type to Tensor and diagnosis to diagnosis code""" + + def __call__(self, image): + np.nan_to_num(image, copy=False) + image = image.astype(float) + + return torch.from_numpy(image[np.newaxis, :]).float() + + +class MinMaxNormalization(object): + """Normalizes a tensor between 0 and 1""" + + def __call__(self, image): + return (image - image.min()) / (image.max() - image.min()) + + +class NanRemoval(object): + def __init__(self): + self.nan_detected = False # Avoid warning each time new data is seen + + def __call__(self, image): + if torch.isnan(image).any().item(): + if not self.nan_detected: + logger.warning( + "NaN values were found in your images and will be removed." + ) + self.nan_detected = True + return torch.nan_to_num(image) + else: + return image + + +class SizeReduction(object): + """Reshape the input tensor to be of size [80, 96, 80]""" + + def __init__(self, size_reduction_factor=2) -> None: + self.size_reduction_factor = size_reduction_factor + + def __call__(self, image): + if self.size_reduction_factor == 2: + return image[:, 4:164:2, 8:200:2, 8:168:2] + elif self.size_reduction_factor == 3: + return image[:, 0:168:3, 8:200:3, 4:172:3] + elif self.size_reduction_factor == 4: + return image[:, 4:164:4, 8:200:4, 8:168:4] + elif self.size_reduction_factor == 5: + return image[:, 4:164:5, 0:200:5, 8:168:5] + else: + raise ClinicaDLConfigurationError( + "size_reduction_factor must be 2, 3, 4 or 5." 
+ ) diff --git a/clinicadl/transforms/transforms.py b/clinicadl/transforms/transforms.py index 7b6de09a7..be18c382a 100644 --- a/clinicadl/transforms/transforms.py +++ b/clinicadl/transforms/transforms.py @@ -1,228 +1,14 @@ -# coding: utf8 +from typing import List -from logging import getLogger +import torchio -import numpy as np -import torch -import torchio as tio -from clinicadl.utils.exceptions import ClinicaDLConfigurationError - -logger = getLogger("clinicadl") - -################################## -# Transformations -################################## - - -class RandomNoising(object): - """Applies a random zoom to a tensor""" - - def __init__(self, sigma=0.1): - self.sigma = sigma - - def __call__(self, image): - import random - - sigma = random.uniform(0, self.sigma) - dist = torch.distributions.normal.Normal(0, sigma) - return image + dist.sample(image.shape) - - -class RandomSmoothing(object): - """Applies a random zoom to a tensor""" - - def __init__(self, sigma=1): - self.sigma = sigma - - def __call__(self, image): - import random - - from scipy.ndimage import gaussian_filter - - sigma = random.uniform(0, self.sigma) - image = gaussian_filter(image, sigma) # smoothing of data - image = torch.from_numpy(image).float() - return image - - -class RandomCropPad(object): - def __init__(self, length): - self.length = length - - def __call__(self, image): - dimensions = len(image.shape) - 1 - crop = np.random.randint(-self.length, self.length, dimensions) - if dimensions == 2: - output = torch.nn.functional.pad( - image, (-crop[0], crop[0], -crop[1], crop[1]) - ) - elif dimensions == 3: - output = torch.nn.functional.pad( - image, (-crop[0], crop[0], -crop[1], crop[1], -crop[2], crop[2]) - ) - else: - raise ValueError( - f"RandomCropPad is only available for 2D or 3D data. 
Image is {dimensions}D" - ) - return output - - -class GaussianSmoothing(object): - def __init__(self, sigma): - self.sigma = sigma - - def __call__(self, sample): - from scipy.ndimage.filters import gaussian_filter - - image = sample["image"] - np.nan_to_num(image, copy=False) - smoothed_image = gaussian_filter(image, sigma=self.sigma) - sample["image"] = smoothed_image - - return sample - - -class RandomMotion(object): - """Applies a Random Motion""" - - def __init__(self, translation, rotation, num_transforms): - self.rotation = rotation - self.translation = translation - self.num_transforms = num_transforms - - def __call__(self, image): - motion = tio.RandomMotion( - degrees=self.rotation, - translation=self.translation, - num_transforms=self.num_transforms, - ) - image = motion(image) - - return image - - -class RandomGhosting(object): - """Applies a Random Ghosting""" - - def __init__(self, num_ghosts): - self.num_ghosts = num_ghosts - - def __call__(self, image): - ghost = tio.RandomGhosting(num_ghosts=self.num_ghosts) - image = ghost(image) - - return image - - -class RandomSpike(object): - """Applies a Random Spike""" - - def __init__(self, num_spikes, intensity): - self.num_spikes = num_spikes - self.intensity = intensity - - def __call__(self, image): - spike = tio.RandomSpike( - num_spikes=self.num_spikes, - intensity=self.intensity, - ) - image = spike(image) - - return image - - -class RandomBiasField(object): - """Applies a Random Bias Field""" - - def __init__(self, coefficients): - self.coefficients = coefficients - - def __call__(self, image): - bias_field = tio.RandomBiasField(coefficients=self.coefficients) - image = bias_field(image) - - return image - - -class RandomBlur(object): - """Applies a Random Blur""" - - def __init__(self, std): - self.std = std - - def __call__(self, image): - blur = tio.RandomBlur(std=self.std) - image = blur(image) - - return image - - -class RandomSwap(object): - """Applies a Random Swap""" - - def __init__(self, 
patch_size, num_iterations): - self.patch_size = patch_size - self.num_iterations = num_iterations - - def __call__(self, image): - swap = tio.RandomSwap( - patch_size=self.patch_size, num_iterations=self.num_iterations - ) - image = swap(image) - - return image - - -class ToTensor(object): - """Convert image type to Tensor and diagnosis to diagnosis code""" - - def __call__(self, image): - np.nan_to_num(image, copy=False) - image = image.astype(float) - - return torch.from_numpy(image[np.newaxis, :]).float() - - -class MinMaxNormalization(object): - """Normalizes a tensor between 0 and 1""" - - def __call__(self, image): - return (image - image.min()) / (image.max() - image.min()) - - -class NanRemoval(object): - def __init__(self): - self.nan_detected = False # Avoid warning each time new data is seen - - def __call__(self, image): - if torch.isnan(image).any().item(): - if not self.nan_detected: - logger.warning( - "NaN values were found in your images and will be removed." - ) - self.nan_detected = True - return torch.nan_to_num(image) - else: - return image - - -class SizeReduction(object): - """Reshape the input tensor to be of size [80, 96, 80]""" - - def __init__(self, size_reduction_factor=2) -> None: - self.size_reduction_factor = size_reduction_factor - - def __call__(self, image): - if self.size_reduction_factor == 2: - return image[:, 4:164:2, 8:200:2, 8:168:2] - elif self.size_reduction_factor == 3: - return image[:, 0:168:3, 8:200:3, 4:172:3] - elif self.size_reduction_factor == 4: - return image[:, 4:164:4, 8:200:4, 8:168:4] - elif self.size_reduction_factor == 5: - return image[:, 4:164:5, 0:200:5, 8:168:5] - else: - raise ClinicaDLConfigurationError( - "size_reduction_factor must be 2, 3, 4 or 5." 
- ) +class Transforms: + def __init__( + self, + data_augmentation=List[torchio], + image_transforms=List[torchio], + object_transforms=List[torchio], + ) -> None: + """TO COMPLETE""" + self.data_augmentation = data_augmentation diff --git a/clinicadl/utils/cli_param/option.py b/clinicadl/utils/cli_param/option.py index 6ff86cda2..75438ceda 100644 --- a/clinicadl/utils/cli_param/option.py +++ b/clinicadl/utils/cli_param/option.py @@ -58,13 +58,6 @@ multiple=True, default=None, ) -ssda_network = click.option( - "--ssda_network", - type=bool, - default=False, - show_default=True, - help="ssda training.", -) valid_longitudinal = click.option( "--valid_longitudinal/--valid_baseline", type=bool, diff --git a/clinicadl/utils/enum.py b/clinicadl/utils/enum.py index 3e9031534..4e5c7721c 100644 --- a/clinicadl/utils/enum.py +++ b/clinicadl/utils/enum.py @@ -1,6 +1,17 @@ from enum import Enum +class CaseInsensitiveEnum(str, Enum): + @classmethod + def _missing_(cls, value): + if isinstance(value, str): + value = value.lower() + for member in cls: + if member.lower() == value: + return member + return None + + class BaseEnum(Enum): """Base Enum object that will print valid inputs if the value passed is not valid.""" diff --git a/clinicadl/utils/iotools/train_utils.py b/clinicadl/utils/iotools/train_utils.py index e4347de3b..7989f7142 100644 --- a/clinicadl/utils/iotools/train_utils.py +++ b/clinicadl/utils/iotools/train_utils.py @@ -90,7 +90,7 @@ def get_model_list(architecture=None, input_size=None, model_layers=False): """ from inspect import getmembers, isclass - import clinicadl.network as network_package + import clinicadl.networks.old_network as network_package if not architecture: print("The list of currently available models is:") @@ -198,3 +198,69 @@ def merge_cli_and_config_file_options(task: Task, **kwargs) -> Dict[str, Any]: pass ### return options + + +def merge_cli_and_maps_json_options(maps_json: Path, **kwargs) -> Dict[str, Any]: + """ + Merges options from 
the CLI (passed by the user) and from the config file + (if it exists). + + Priority is given to options passed by the user via the CLI. If it is not + provided, it will look for the option in the possible config file. + If an option is not passed by the user and not found in the config file, it will + not be in the output. + + Parameters + ---------- + maps_json : Path + The path to the maps.json file to read options from. + + Returns + ------- + Dict[str, Any] + A dictionary with training options. + """ + + from clinicadl.dataset.caps_dataset_utils import read_json + + + options = read_json(maps_json) + for arg in kwargs: + if ( + click.get_current_context().get_parameter_source(arg) + == ParameterSource.COMMANDLINE + ): + options[arg] = kwargs[arg] + + return options + + +def merge_options_and_maps_json_options(maps_json: Path, **kwargs) -> Dict[str, Any]: + """ + Merges options from the CLI (passed by the user) and from the config file + (if it exists). + + Priority is given to options passed by the user via the CLI. If it is not + provided, it will look for the option in the possible config file. + If an option is not passed by the user and not found in the config file, it will + not be in the output. + + Parameters + ---------- + maps_json : Path + The path to the maps.json file to read options from. + + Returns + ------- + Dict[str, Any] + A dictionary with training options. 
+ """ + + from clinicadl.dataset.caps_dataset_utils import read_json + + + options = read_json(maps_json) + for arg in kwargs: + options[arg] = kwargs[arg] + + return options diff --git a/clinicadl/utils/iotools/trainer_utils.py b/clinicadl/utils/iotools/trainer_utils.py index b77229ea6..ac1b6a3bf 100644 --- a/clinicadl/utils/iotools/trainer_utils.py +++ b/clinicadl/utils/iotools/trainer_utils.py @@ -19,8 +19,7 @@ def create_parameters_dict(config): parameters["transfer_path"] = False if parameters["data_augmentation"] == (): parameters["data_augmentation"] = False - parameters["preprocessing_dict_target"] = parameters["preprocessing_json_target"] - del parameters["preprocessing_json_target"] + del parameters["preprocessing_json"] # if "tsv_path" in parameters: # parameters["tsv_path"] = parameters["tsv_path"] diff --git a/clinicadl/utils/meta_maps/getter.py b/clinicadl/utils/meta_maps/getter.py index 1fa524950..4698c34ab 100644 --- a/clinicadl/utils/meta_maps/getter.py +++ b/clinicadl/utils/meta_maps/getter.py @@ -6,8 +6,8 @@ import pandas as pd -from clinicadl.maps_manager.maps_manager import MapsManager -from clinicadl.metrics.utils import find_selection_metrics, get_metrics +from clinicadl.experiment_manager.maps_manager import MapsManager +from clinicadl.metrics.old_metrics.utils import find_selection_metrics, get_metrics from clinicadl.splitter.split_utils import find_splits from clinicadl.utils.exceptions import MAPSError diff --git a/clinicadl/validator/config.py b/clinicadl/validator/config.py deleted file mode 100644 index 2f8c8a30a..000000000 --- a/clinicadl/validator/config.py +++ /dev/null @@ -1,47 +0,0 @@ -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Optional, Union - -from pydantic import ( - BaseModel, - ConfigDict, - computed_field, - field_validator, -) - -from clinicadl.utils.factories import DefaultFromLibrary - - -class ValidatorConfig(BaseModel): - """Base config class to configure the validator.""" - - 
maps_path: Path - mode: str - network_task: str - num_networks: Optional[int] = None - fsdp: Optional[bool] = None - amp: Optional[bool] = None - metrics_module: Optional = None - n_classes: Optional[int] = None - nb_unfrozen_layers: Optional[int] = None - std_amp: Optional[bool] = None - - # pydantic config - model_config = ConfigDict( - validate_assignment=True, - use_enum_values=True, - validate_default=True, - ) - - @computed_field - @property - @abstractmethod - def metric(self) -> str: - """The name of the metric.""" - - @field_validator("get_not_nans", mode="after") - @classmethod - def validator_get_not_nans(cls, v): - assert not v, "get_not_nans not supported in ClinicaDL. Please set to False." - - return v diff --git a/clinicadl/validator/validator.py b/clinicadl/validator/validator.py deleted file mode 100644 index c8f5e9451..000000000 --- a/clinicadl/validator/validator.py +++ /dev/null @@ -1,496 +0,0 @@ -from logging import getLogger -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import pandas as pd -import torch -import torch.distributed as dist -from torch.amp import autocast -from torch.nn.modules.loss import _Loss -from torch.utils.data import DataLoader - -from clinicadl.maps_manager.maps_manager import MapsManager -from clinicadl.metrics.metric_module import MetricModule -from clinicadl.metrics.utils import find_selection_metrics -from clinicadl.network.network import Network -from clinicadl.trainer.tasks_utils import columns, compute_metrics, generate_test_row -from clinicadl.utils import cluster -from clinicadl.utils.computational.ddp import DDP, init_ddp -from clinicadl.utils.enum import ( - ClassificationLoss, - ClassificationMetric, - ReconstructionLoss, - ReconstructionMetric, - RegressionLoss, - RegressionMetric, - Task, -) -from clinicadl.utils.exceptions import ( - ClinicaDLArgumentError, - ClinicaDLConfigurationError, - MAPSError, -) - -logger = 
getLogger("clinicadl.maps_manager") -level_list: List[str] = ["warning", "info", "debug"] - - -# TODO save weights on CPU for better compatibility - - -class Validator: - def test( - self, - mode: str, - metrics_module: MetricModule, - n_classes: int, - network_task, - model: Network, - dataloader: DataLoader, - criterion: _Loss, - use_labels: bool = True, - amp: bool = False, - report_ci=False, - ) -> Tuple[pd.DataFrame, Dict[str, float]]: - """ - Computes the predictions and evaluation metrics. - - Parameters - ---------- - model: Network - The model trained. - dataloader: DataLoader - Wrapper of a CapsDataset. - criterion: _Loss - Function to calculate the loss. - use_labels: bool - If True the true_label will be written in output DataFrame - and metrics dict will be created. - amp: bool - If True, enables Pytorch's automatic mixed precision. - - Returns - ------- - the results and metrics on the image level. - """ - model.eval() - dataloader.dataset.eval() - - results_df = pd.DataFrame(columns=columns(network_task, mode, n_classes)) - total_loss = {} - with torch.no_grad(): - for i, data in enumerate(dataloader): - # initialize the loss list to save the loss components - with autocast("cuda", enabled=amp): - outputs, loss_dict = model(data, criterion, use_labels=use_labels) - - if i == 0: - for loss_component in loss_dict.keys(): - total_loss[loss_component] = 0 - for loss_component in total_loss.keys(): - total_loss[loss_component] += loss_dict[loss_component].float() - - # Generate detailed DataFrame - for idx in range(len(data["participant_id"])): - row = generate_test_row( - network_task, - mode, - metrics_module, - n_classes, - idx, - data, - outputs.float(), - ) - row_df = pd.DataFrame( - row, columns=columns(network_task, mode, n_classes) - ) - results_df = pd.concat([results_df, row_df]) - - del outputs, loss_dict - dataframes = [None] * dist.get_world_size() - dist.gather_object( - results_df, dataframes if dist.get_rank() == 0 else None, dst=0 - ) - 
if dist.get_rank() == 0: - results_df = pd.concat(dataframes) - del dataframes - results_df.reset_index(inplace=True, drop=True) - - if not use_labels: - metrics_dict = None - else: - metrics_dict = compute_metrics( - network_task, results_df, metrics_module, report_ci=report_ci - ) - for loss_component in total_loss.keys(): - dist.reduce(total_loss[loss_component], dst=0) - loss_value = total_loss[loss_component].item() / cluster.world_size - - if report_ci: - metrics_dict["Metric_names"].append(loss_component) - metrics_dict["Metric_values"].append(loss_value) - metrics_dict["Lower_CI"].append("N/A") - metrics_dict["Upper_CI"].append("N/A") - metrics_dict["SE"].append("N/A") - - else: - metrics_dict[loss_component] = loss_value - - torch.cuda.empty_cache() - - return results_df, metrics_dict - - def test_da( - self, - mode: str, - metrics_module: MetricModule, - n_classes: int, - network_task: Union[str, Task], - model: Network, - dataloader: DataLoader, - criterion: _Loss, - alpha: float = 0, - use_labels: bool = True, - target: bool = True, - report_ci=False, - ) -> Tuple[pd.DataFrame, Dict[str, float]]: - """ - Computes the predictions and evaluation metrics. - - Args: - model: the model trained. - dataloader: wrapper of a CapsDataset. - criterion: function to calculate the loss. - use_labels: If True the true_label will be written in output DataFrame - and metrics dict will be created. - Returns: - the results and metrics on the image level. 
- """ - model.eval() - dataloader.dataset.eval() - results_df = pd.DataFrame(columns=columns(network_task, mode, n_classes)) - total_loss = 0 - with torch.no_grad(): - for i, data in enumerate(dataloader): - outputs, loss_dict = model.compute_outputs_and_loss_test( - data, criterion, alpha, target - ) - total_loss += loss_dict["loss"].item() - - # Generate detailed DataFrame - for idx in range(len(data["participant_id"])): - row = generate_test_row( - network_task, - mode, - metrics_module, - n_classes, - idx, - data, - outputs, - ) - row_df = pd.DataFrame( - row, columns=columns(network_task, mode, n_classes) - ) - results_df = pd.concat([results_df, row_df]) - - del outputs, loss_dict - results_df.reset_index(inplace=True, drop=True) - - if not use_labels: - metrics_dict = None - else: - metrics_dict = compute_metrics( - network_task, results_df, metrics_module, report_ci=report_ci - ) - if report_ci: - metrics_dict["Metric_names"].append("loss") - metrics_dict["Metric_values"].append(total_loss) - metrics_dict["Lower_CI"].append("N/A") - metrics_dict["Upper_CI"].append("N/A") - metrics_dict["SE"].append("N/A") - - else: - metrics_dict["loss"] = total_loss - - torch.cuda.empty_cache() - - return results_df, metrics_dict - - def _test_loader( - self, - maps_manager: MapsManager, - dataloader, - criterion, - data_group: str, - split: int, - selection_metrics, - use_labels=True, - gpu=None, - amp=False, - network=None, - report_ci=True, - ): - """ - Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files. - - Args: - dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset. - criterion (torch.nn.modules.loss._Loss): optimization criterion used during training. - data_group (str): name of the data group used for the testing task. - split (int): Index of the split used to train the model tested. - selection_metrics (list[str]): List of metrics used to select the best models which are tested. 
- use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed. - gpu (bool): If given, a new value for the device of the model will be computed. - amp (bool): If enabled, uses Automatic Mixed Precision (requires GPU usage). - network (int): Index of the network tested (only used in multi-network setting). - """ - for selection_metric in selection_metrics: - if cluster.master: - log_dir = ( - maps_manager.maps_path - / f"split-{split}" - / f"best-{selection_metric}" - / data_group - ) - maps_manager.write_description_log( - log_dir, - data_group, - dataloader.dataset.config.data.caps_dict, - dataloader.dataset.config.data.data_df, - ) - - # load the best trained model during the training - model, _ = maps_manager._init_model( - transfer_path=maps_manager.maps_path, - split=split, - transfer_selection=selection_metric, - gpu=gpu, - network=network, - ) - model = DDP( - model, - fsdp=maps_manager.fully_sharded_data_parallel, - amp=maps_manager.amp, - ) - - prediction_df, metrics = self.test( - mode=maps_manager.mode, - metrics_module=maps_manager.metrics_module, - n_classes=maps_manager.n_classes, - network_task=maps_manager.network_task, - model=model, - dataloader=dataloader, - criterion=criterion, - use_labels=use_labels, - amp=amp, - report_ci=report_ci, - ) - if use_labels: - if network is not None: - metrics[f"{maps_manager.mode}_id"] = network - - loss_to_log = ( - metrics["Metric_values"][-1] if report_ci else metrics["loss"] - ) - - logger.info( - f"{maps_manager.mode} level {data_group} loss is {loss_to_log} for model selected on {selection_metric}" - ) - - if cluster.master: - # Replace here - maps_manager._mode_level_to_tsv( - prediction_df, - metrics, - split, - selection_metric, - data_group=data_group, - ) - - def _test_loader_ssda( - self, - maps_manager: MapsManager, - dataloader, - criterion, - alpha, - data_group, - split, - selection_metrics, - use_labels=True, - gpu=None, - network=None, - target=False, - 
report_ci=True, - ): - """ - Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files. - - Args: - dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset. - criterion (torch.nn.modules.loss._Loss): optimization criterion used during training. - data_group (str): name of the data group used for the testing task. - split (int): Index of the split used to train the model tested. - selection_metrics (list[str]): List of metrics used to select the best models which are tested. - use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed. - gpu (bool): If given, a new value for the device of the model will be computed. - network (int): Index of the network tested (only used in multi-network setting). - """ - for selection_metric in selection_metrics: - log_dir = ( - maps_manager.maps_path - / f"split-{split}" - / f"best-{selection_metric}" - / data_group - ) - maps_manager.write_description_log( - log_dir, - data_group, - dataloader.dataset.caps_dict, - dataloader.dataset.df, - ) - - # load the best trained model during the training - model, _ = maps_manager._init_model( - transfer_path=maps_manager.maps_path, - split=split, - transfer_selection=selection_metric, - gpu=gpu, - network=network, - ) - prediction_df, metrics = self.test_da( - network_task=maps_manager.network_task, - model=model, - dataloader=dataloader, - criterion=criterion, - target=target, - report_ci=report_ci, - mode=maps_manager.mode, - metrics_module=maps_manager.metrics_module, - n_classes=maps_manager.n_classes, - ) - if use_labels: - if network is not None: - metrics[f"{maps_manager.mode}_id"] = network - - if report_ci: - loss_to_log = metrics["Metric_values"][-1] - else: - loss_to_log = metrics["loss"] - - logger.info( - f"{maps_manager.mode} level {data_group} loss is {loss_to_log} for model selected on {selection_metric}" - ) - - # Replace here - maps_manager._mode_level_to_tsv( - prediction_df, 
metrics, split, selection_metric, data_group=data_group - ) - - @torch.no_grad() - def _compute_output_tensors( - self, - maps_manager: MapsManager, - dataset, - data_group, - split, - selection_metrics, - nb_images=None, - gpu=None, - network=None, - ): - """ - Compute the output tensors and saves them in the MAPS. - - Args: - dataset (clinicadl.caps_dataset.data.CapsDataset): wrapper of the data set. - data_group (str): name of the data group used for the task. - split (int): split number. - selection_metrics (list[str]): metrics used for model selection. - nb_images (int): number of full images to write. Default computes the outputs of the whole data set. - gpu (bool): If given, a new value for the device of the model will be computed. - network (int): Index of the network tested (only used in multi-network setting). - """ - for selection_metric in selection_metrics: - # load the best trained model during the training - model, _ = maps_manager._init_model( - transfer_path=maps_manager.maps_path, - split=split, - transfer_selection=selection_metric, - gpu=gpu, - network=network, - nb_unfrozen_layer=maps_manager.nb_unfrozen_layer, - ) - model = DDP( - model, - fsdp=maps_manager.fully_sharded_data_parallel, - amp=maps_manager.amp, - ) - model.eval() - - tensor_path = ( - maps_manager.maps_path - / f"split-{split}" - / f"best-{selection_metric}" - / data_group - / "tensors" - ) - if cluster.master: - tensor_path.mkdir(parents=True, exist_ok=True) - dist.barrier() - - if nb_images is None: # Compute outputs for the whole data set - nb_modes = len(dataset) - else: - nb_modes = nb_images * dataset.elem_per_image - - for i in [ - *range(cluster.rank, nb_modes, cluster.world_size), - *range(int(nb_modes % cluster.world_size <= cluster.rank)), - ]: - data = dataset[i] - image = data["image"] - x = image.unsqueeze(0).to(model.device) - with autocast("cuda", enabled=maps_manager.std_amp): - output = model(x) - output = output.squeeze(0).cpu().float() - participant_id = 
data["participant_id"] - session_id = data["session_id"] - mode_id = data[f"{maps_manager.mode}_id"] - input_filename = f"{participant_id}_{session_id}_{maps_manager.mode}-{mode_id}_input.pt" - output_filename = f"{participant_id}_{session_id}_{maps_manager.mode}-{mode_id}_output.pt" - torch.save(image, tensor_path / input_filename) - torch.save(output, tensor_path / output_filename) - logger.debug(f"File saved at {[input_filename, output_filename]}") - - def _ensemble_prediction( - self, - maps_manager: MapsManager, - data_group, - split, - selection_metrics, - use_labels=True, - skip_leak_check=False, - ): - """Computes the results on the image-level.""" - - if not selection_metrics: - selection_metrics = find_selection_metrics(maps_manager.maps_path, split) - - for selection_metric in selection_metrics: - ##################### - # Soft voting - if maps_manager.num_networks > 1 and not skip_leak_check: - maps_manager._ensemble_to_tsv( - split, - selection=selection_metric, - data_group=data_group, - use_labels=use_labels, - ) - elif maps_manager.mode != "image" and not skip_leak_check: - maps_manager._mode_to_image_tsv( - split, - selection=selection_metric, - data_group=data_group, - use_labels=use_labels, - ) diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index e5a4a7302..000000000 --- a/tests/conftest.py +++ /dev/null @@ -1,42 +0,0 @@ -# coding: utf8 - -""" -This file contains a set of functional tests designed to check the correct execution of the pipeline and the -different functions available in ClinicaDL -""" - -import pytest - - -def pytest_addoption(parser): - parser.addoption( - "--input_data_directory", - action="store", - help="Directory for (only-read) inputs for tests", - ) - parser.addoption( - "--no-gpu", - action="store_true", - help="""To run tests on cpu. Default is False. - To use carefully, only to run tests locally. Should not be used in final CI tests. 
- Concretely, the tests won't fail if gpu option is false in the output MAPS whereas - it is true in the reference MAPS.""", - ) - parser.addoption( - "--adapt-base-dir", - action="store_true", - help="""To virtually change the base directory in the paths stored in the MAPS of the CI data. - Default is False. - To use carefully, only to run tests locally. Should not be used in final CI tests. - Concretely, the tests won't fail if only the base directories differ in the paths stored - in the output and reference MAPS.""", - ) - - -@pytest.fixture -def cmdopt(request): - config_param = {} - config_param["input"] = request.config.getoption("--input_data_directory") - config_param["no-gpu"] = request.config.getoption("--no-gpu") - config_param["adapt-base-dir"] = request.config.getoption("--adapt-base-dir") - return config_param diff --git a/tests/test_cli.py b/tests/test_cli.py deleted file mode 100644 index 687592bec..000000000 --- a/tests/test_cli.py +++ /dev/null @@ -1,200 +0,0 @@ -# coding: utf8 - -import pytest -from click.testing import CliRunner - -from clinicadl.cmdline import cli -from clinicadl.utils.enum import SliceDirection - - -# Test to ensure that the help string, at the command line, is invoked without errors -# Test for the first level at the command line -@pytest.fixture( - params=[ - "prepare-data", - "generate", - "interpret", - "predict", - "quality-check", - "random-search", - "train", - "tsvtools", - ] -) -def cli_args_first_lv(request): - task = request.param - return task - - -def test_first_lv(cli_args_first_lv): - runner = CliRunner() - task = cli_args_first_lv - print(f"Testing input cli {task}") - result = runner.invoke(cli, f"{task} -h") - assert result.exit_code == 0 - - -# Test for prepare-data cli, second level -@pytest.fixture( - params=[ - "image", - "slice", - "patch", - "roi", - ] -) -def prepare_data_cli_arg1(request): - return request.param - - -@pytest.fixture( - params=[ - "t1-linear", - "pet-linear", - "custom", - ] -) -def 
prepare_data_cli_arg2(request): - return request.param - - -def test_second_lv_prepare_data(prepare_data_cli_arg1, prepare_data_cli_arg2): - runner = CliRunner() - arg1 = prepare_data_cli_arg1 - arg2 = prepare_data_cli_arg2 - print(f"Testing input prepare_data cli {arg1} {arg2}") - result = runner.invoke(cli, f"prepare-data {arg1} {arg2} -h") - assert result.exit_code == 0 - - -# Test for the generate cli, second level -@pytest.fixture( - params=[ - "shepplogan", - "random", - "trivial", - ] -) -def generate_cli_arg1(request): - return request.param - - -def test_second_lv_generate(generate_cli_arg1): - runner = CliRunner() - arg1 = generate_cli_arg1 - print(f"Testing input generate cli {arg1}") - result = runner.invoke(cli, f"generate {arg1} -h") - assert result.exit_code == 0 - - -# Test for the interpret cli, second level -@pytest.fixture( - params=[ - "", - ] -) -def interpret_cli_arg1(request): - return request.param - - -def test_second_lv_interpret(interpret_cli_arg1): - runner = CliRunner() - cli_input = interpret_cli_arg1 - print(f"Testing input generate cli {cli_input}") - result = runner.invoke(cli, f"interpret {cli_input} -h") - assert result.exit_code == 0 - - -# Test for the predict cli, second level -@pytest.fixture( - params=[ - "", - ] -) -def predict_cli_arg1(request): - return request.param - - -def test_second_lv_predict(predict_cli_arg1): - runner = CliRunner() - cli_input = predict_cli_arg1 - print(f"Testing input predict cli {cli_input}") - result = runner.invoke(cli, f"predict {cli_input} -h") - assert result.exit_code == 0 - - -# Test for the train cli, second level -@pytest.fixture( - params=[ - "classification", - "regression", - "reconstruction", - "from_json", - "resume", - "list_models", - ] -) -def train_cli_arg1(request): - return request.param - - -def test_second_lv_train(train_cli_arg1): - runner = CliRunner() - cli_input = train_cli_arg1 - print(f"Testing input train cli {cli_input}") - result = runner.invoke(cli, f"train 
{cli_input} -h") - assert result.exit_code == 0 - - -# Test for the random-search cli, second level -@pytest.fixture(params=["generate", "analysis"]) -def rs_cli_arg1(request): - task = request.param - return task - - -def test_second_lv_random_search(rs_cli_arg1): - runner = CliRunner() - arg1 = rs_cli_arg1 - print(f"Testing input random-search cli {arg1}") - result = runner.invoke(cli, f"random-search {arg1} -h") - assert result.exit_code == 0 - - -# Test for the quality-check cli, second level -@pytest.fixture(params=["t1-linear", "t1-volume"]) -def qc_cli_arg1(request): - task = request.param - return task - - -def test_second_lv_quality_check(qc_cli_arg1): - runner = CliRunner() - arg1 = qc_cli_arg1 - print(f"Testing input quality-check cli {arg1}") - result = runner.invoke(cli, f"quality-check {arg1} -h") - assert result.exit_code == 0 - - -# Test for the tsvtool cli, second level -@pytest.fixture( - params=[ - "analysis", - "get-labels", - "kfold", - "split", - "prepare-experiment", - "get-progression", - "get-metadata", - ] -) -def tsvtool_cli_arg1(request): - return request.param - - -def test_second_lv_tsvtool(tsvtool_cli_arg1): - runner = CliRunner() - arg1 = tsvtool_cli_arg1 - print(f"Testing input tsvtools cli {arg1}") - result = runner.invoke(cli, f"tsvtools {arg1} -h") - assert result.exit_code == 0 diff --git a/tests/test_generate.py b/tests/test_generate.py deleted file mode 100644 index 9fc03535b..000000000 --- a/tests/test_generate.py +++ /dev/null @@ -1,119 +0,0 @@ -# coding: utf8 - -import os -from pathlib import Path - -import pytest - -from tests.testing_tools import clean_folder, compare_folders - - -@pytest.fixture( - params=[ - "random_example", - "trivial_example", - "shepplogan_example", - "hypometabolic_example", - "artifacts_example", - ] -) -def test_name(request): - return request.param - - -def test_generate(cmdopt, tmp_path, test_name): - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "generate" / "in" - ref_dir = 
base_dir / "generate" / "ref" - tmp_out_dir = tmp_path / "generate" / "out" - tmp_out_dir.mkdir(parents=True) - - clean_folder(tmp_out_dir, recreate=True) - - data_caps_pet = str(input_dir / "caps_pet") - data_caps_folder = str(input_dir / "caps") - - if test_name == "trivial_example": - output_folder = tmp_out_dir / test_name - test_input = [ - "generate", - "trivial", - data_caps_folder, - str(output_folder), - "--n_subjects", - "4", - "--preprocessing", - "t1-linear", - ] - elif test_name == "hypometabolic_example": - output_folder = tmp_out_dir / test_name - test_input = [ - "generate", - "hypometabolic", - data_caps_pet, - str(output_folder), - "--n_subjects", - "2", - "--pathology", - "ad", - "--anomaly_degree", - "50", - "--sigma", - "5", - ] - - elif test_name == "random_example": - output_folder = tmp_out_dir / test_name - test_input = [ - "generate", - "random", - data_caps_folder, - str(output_folder), - "--n_subjects", - "4", - "--mean", - "4000", - "--sigma", - "1000", - "--preprocessing", - "t1-linear", - ] - - elif test_name == "shepplogan_example": - n_subjects = 10 - output_folder = tmp_out_dir / test_name - test_input = [ - "generate", - "shepplogan", - str(output_folder), - "--n_subjects", - f"{n_subjects}", - ] - elif test_name == "artifacts_example": - output_folder = tmp_out_dir / test_name - test_input = [ - "generate", - "artifacts", - data_caps_folder, - str(output_folder), - "--preprocessing", - "t1-linear", - "--noise", - "--motion", - "--contrast", - ] - - else: - raise NotImplementedError(f"Test {test_name} is not implemented.") - - flag_error = not os.system("clinicadl " + " ".join(test_input)) - - assert flag_error - - if test_name == "shepplogan_example": - file = list((output_folder / "tensor_extraction").iterdir()) - old_name = output_folder / "tensor_extraction" / file[0] - new_name = output_folder / "tensor_extraction" / "extract_test.json" - old_name.rename(new_name) - - assert compare_folders(output_folder, ref_dir / test_name, 
tmp_out_dir) diff --git a/tests/test_interpret.py b/tests/test_interpret.py deleted file mode 100644 index 7b4c9358b..000000000 --- a/tests/test_interpret.py +++ /dev/null @@ -1,90 +0,0 @@ -# coding: utf8 - -import os -import shutil -from pathlib import Path - -import pytest - -from clinicadl.interpret.config import InterpretConfig -from clinicadl.predict.predict_manager import PredictManager - - -@pytest.fixture(params=["classification", "regression"]) -def test_name(request): - return request.param - - -def test_interpret(cmdopt, tmp_path, test_name): - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "interpret" / "in" - ref_dir = base_dir / "interpret" / "ref" - tmp_out_dir = tmp_path / "interpret" / "out" - tmp_out_dir.mkdir(parents=True) - - labels_dir_str = str(input_dir / "labels_list" / "2_fold") - maps_tmp_out_dir = str(tmp_out_dir / "maps") - if test_name == "classification": - cnn_input = [ - "train", - "classification", - str(input_dir / "caps_image"), - "t1-linear_mode-image.json", - labels_dir_str, - maps_tmp_out_dir, - "--architecture Conv5_FC3", - "--epochs", - "1", - "--n_splits", - "2", - "--split", - "0", - ] - - elif test_name == "regression": - cnn_input = [ - "train", - "regression", - str(input_dir / "caps_patch"), - "t1-linear_mode-patch.json", - labels_dir_str, - maps_tmp_out_dir, - "--architecture Conv5_FC3", - "--epochs", - "1", - "--n_splits", - "2", - "--split", - "0", - ] - else: - raise NotImplementedError(f"Test {test_name} is not implemented.") - - if cmdopt["no-gpu"]: - cnn_input.append("--no-gpu") - - run_interpret(cnn_input, tmp_out_dir, ref_dir) - - -def run_interpret(cnn_input, tmp_out_dir, ref_dir): - from clinicadl.utils.enum import InterpretationMethod - - maps_path = tmp_out_dir / "maps" - if maps_path.is_dir(): - shutil.rmtree(maps_path) - - train_error = not os.system("clinicadl " + " ".join(cnn_input)) - assert train_error - - for method in list(InterpretationMethod): - interpret_config = InterpretConfig( - 
maps_dir=maps_path, - data_group="train", - name=f"test-{method}", - method_cls=method, - ) - interpret_manager = PredictManager(interpret_config) - interpret_manager.interpret() - interpret_map = interpret_manager.get_interpretation( - "train", f"test-{interpret_config.method}" - ) diff --git a/tests/test_predict.py b/tests/test_predict.py deleted file mode 100644 index 849f0e20d..000000000 --- a/tests/test_predict.py +++ /dev/null @@ -1,136 +0,0 @@ -# coding: utf8 -import json -import shutil -from os.path import exists -from pathlib import Path - -import pytest - -from clinicadl.metrics.utils import get_metrics -from clinicadl.predict.predict_manager import PredictManager -from clinicadl.predict.utils import get_prediction - -from .testing_tools import compare_folders, modify_maps - - -@pytest.fixture( - params=[ - "predict_image_classification", - "predict_roi_regression", - "predict_slice_classification", - "predict_patch_regression", - "predict_patch_multi_classification", - "predict_roi_reconstruction", - ] -) -def test_name(request): - return request.param - - -def test_predict(cmdopt, tmp_path, test_name): - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "predict" / "in" - ref_dir = base_dir / "predict" / "ref" - tmp_out_dir = tmp_path / "predict" / "out" - tmp_out_dir.mkdir(parents=True) - - if test_name == "predict_image_classification": - maps_name = "maps_image_cnn" - modes = ["image"] - use_labels = True - elif test_name == "predict_slice_classification": - maps_name = "maps_slice_cnn" - modes = ["image", "slice"] - use_labels = True - elif test_name == "predict_patch_regression": - maps_name = "maps_patch_cnn" - modes = ["image", "patch"] - use_labels = False - elif test_name == "predict_roi_regression": - maps_name = "maps_roi_cnn" - modes = ["image", "roi"] - use_labels = False - elif test_name == "predict_patch_multi_classification": - maps_name = "maps_patch_multi_cnn" - modes = ["image", "patch"] - use_labels = False - elif test_name 
== "predict_roi_reconstruction": - maps_name = "maps_roi_ae" - modes = ["roi"] - use_labels = False - else: - raise NotImplementedError(f"Test {test_name} is not implemented.") - - shutil.copytree(input_dir / maps_name, tmp_out_dir / maps_name) - model_folder = tmp_out_dir / maps_name - - if cmdopt["adapt-base-dir"]: - with open(model_folder / "maps.json", "r") as f: - config = json.load(f) - config = modify_maps( - maps=config, - base_dir=base_dir, - no_gpu=cmdopt["no-gpu"], - adapt_base_dir=cmdopt["adapt-base-dir"], - ) - with open(model_folder / "maps.json", "w") as f: - json.dump(config, f, skipkeys=True, indent=4) - - with open(model_folder / "groups/test-RANDOM/maps.json", "r") as f: - config = json.load(f) - config = modify_maps( - maps=config, - base_dir=base_dir, - no_gpu=False, - adapt_base_dir=cmdopt["adapt-base-dir"], - ) - with open(model_folder / "groups/test-RANDOM/maps.json", "w") as f: - json.dump(config, f, skipkeys=True, indent=4) - - tmp_out_subdir = str(model_folder / "split-0/best-loss/test-RANDOM") - if exists(tmp_out_subdir): - shutil.rmtree(tmp_out_subdir) - - # # Correction of JSON file for ROI - # if "roi" in modes: - # json_path = model_folder / "maps.json" - # with open(json_path, "r") as f: - # parameters = json.load(f) - # parameters["roi_list"] = ["leftHippocampusBox", "rightHippocampusBox"] - # json_data = json.dumps(parameters, skipkeys=True, indent=4) - # with open(json_path, "w") as f: - # f.write(json_data) - - from clinicadl.predict.config import PredictConfig - - predict_config = PredictConfig( - maps_dir=model_folder, - data_group="test-RANDOM", - caps_directory=input_dir / "caps_random", - tsv_path=input_dir / "caps_random/data.tsv", - gpu=False, - use_labels=use_labels, - overwrite=True, - diagnoses=["CN"], - ) - predict_manager = PredictManager(predict_config) - predict_manager.predict() - - for mode in modes: - get_prediction( - predict_manager.maps_manager.maps_path, - data_group="test-RANDOM", - mode=mode, - ) - if 
use_labels: - get_metrics( - predict_manager.maps_manager.maps_path, - data_group="test-RANDOM", - mode=mode, - ) - - assert compare_folders( - tmp_out_dir / maps_name, - input_dir / maps_name, - tmp_out_dir, - ) diff --git a/tests/test_prepare_data.py b/tests/test_prepare_data.py deleted file mode 100644 index b6dda43d1..000000000 --- a/tests/test_prepare_data.py +++ /dev/null @@ -1,209 +0,0 @@ -# coding: utf8 - -import os -import shutil -import warnings -from os import PathLike -from os.path import join -from pathlib import Path -from typing import Any, Dict, List - -import pytest - -from clinicadl.caps_dataset.caps_dataset_config import ( - CapsDatasetConfig, - get_preprocessing, -) -from clinicadl.caps_dataset.extraction.config import ExtractionROIConfig -from clinicadl.caps_dataset.preprocessing.config import ( - CustomPreprocessingConfig, - PETPreprocessingConfig, -) -from clinicadl.utils.enum import ( - ExtractionMethod, - Preprocessing, - SUVRReferenceRegions, - Tracer, -) -from tests.testing_tools import clean_folder, compare_folders - -warnings.filterwarnings("ignore") - - -@pytest.fixture( - params=[ - "slice", - "patch", - "image", - "roi", - ] -) -def test_name(request): - return request.param - - -def test_prepare_data(cmdopt, tmp_path, test_name): - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "prepare_data" / "in" - ref_dir = base_dir / "prepare_data" / "ref" - tmp_out_dir = tmp_path / "prepare_data" / "out" - tmp_out_dir.mkdir(parents=True) - - clean_folder(tmp_out_dir, recreate=True) - - input_caps_directory = input_dir / "caps" - input_caps_flair_directory = input_dir / "caps_flair" - if test_name == "image": - if (tmp_out_dir / "caps_image").is_dir(): - shutil.rmtree(tmp_out_dir / "caps_image") - shutil.copytree(input_caps_directory, tmp_out_dir / "caps_image") - - if (tmp_out_dir / "caps_image_flair").is_dir(): - shutil.rmtree(tmp_out_dir / "caps_image_flair") - shutil.copytree(input_caps_flair_directory, tmp_out_dir / 
"caps_image_flair") - - config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - extraction=ExtractionMethod.IMAGE, - preprocessing_type=Preprocessing.T1_LINEAR, - preprocessing=Preprocessing.T1_LINEAR, - caps_directory=tmp_out_dir / "caps_image", - ) - - elif test_name == "patch": - if (tmp_out_dir / "caps_patch").is_dir(): - shutil.rmtree(tmp_out_dir / "caps_patch") - shutil.copytree(input_caps_directory, tmp_out_dir / "caps_patch") - - if (tmp_out_dir / "caps_patch_flair").is_dir(): - shutil.rmtree(tmp_out_dir / "caps_patch_flair") - shutil.copytree(input_caps_flair_directory, tmp_out_dir / "caps_patch_flair") - - config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - extraction=ExtractionMethod.PATCH, - preprocessing_type=Preprocessing.T1_LINEAR, - preprocessing=Preprocessing.T1_LINEAR, - caps_directory=tmp_out_dir / "caps_patch", - ) - - elif test_name == "slice": - if (tmp_out_dir / "caps_slice").is_dir(): - shutil.rmtree(tmp_out_dir / "caps_slice") - shutil.copytree(input_caps_directory, tmp_out_dir / "caps_slice") - - if (tmp_out_dir / "caps_slice_flair").is_dir(): - shutil.rmtree(tmp_out_dir / "caps_slice_flair") - shutil.copytree(input_caps_flair_directory, tmp_out_dir / "caps_slice_flair") - - config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - extraction=ExtractionMethod.SLICE, - preprocessing_type=Preprocessing.T1_LINEAR, - preprocessing=Preprocessing.T1_LINEAR, - caps_directory=tmp_out_dir / "caps_slice", - ) - - elif test_name == "roi": - if (tmp_out_dir / "caps_roi").is_dir(): - shutil.rmtree(tmp_out_dir / "caps_roi") - shutil.copytree(input_caps_directory, tmp_out_dir / "caps_roi") - - if (tmp_out_dir / "caps_roi_flair").is_dir(): - shutil.rmtree(tmp_out_dir / "caps_roi_flair") - shutil.copytree(input_caps_flair_directory, tmp_out_dir / "caps_roi_flair") - - config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - extraction=ExtractionMethod.ROI, - 
preprocessing_type=Preprocessing.T1_LINEAR, - preprocessing=Preprocessing.T1_LINEAR, - caps_directory=tmp_out_dir / "caps_image", - roi_list=["rightHippocampusBox", "leftHippocampusBox"], - ) - - else: - print(f"Test {test_name} not available.") - assert 0 - - run_test_prepare_data(input_dir, ref_dir, tmp_out_dir, test_name, config) - - -def run_test_prepare_data( - input_dir, ref_dir, out_dir, test_name: str, config: CapsDatasetConfig -): - modalities = ["t1-linear", "pet-linear", "flair-linear"] - uncropped_image = [True, False] - acquisition_label = ["18FAV45", "11CPIB"] - config.extraction.save_features = True - - for modality in modalities: - config.preprocessing.preprocessing = Preprocessing(modality) - config.preprocessing = get_preprocessing(Preprocessing(modality))() - if modality == "pet-linear": - for acq in acquisition_label: - assert isinstance(config.preprocessing, PETPreprocessingConfig) - config.preprocessing.tracer = Tracer(acq) - config.preprocessing.suvr_reference_region = SUVRReferenceRegions( - "pons2" - ) - config.preprocessing.use_uncropped_image = False - config.extraction.extract_json = ( - f"{modality}-{acq}_mode-{test_name}.json" - ) - tsv_file = join(input_dir, f"pet_{acq}.tsv") - mode = test_name - extract_generic(out_dir, mode, tsv_file, config) - - elif modality == "custom": - assert isinstance(config.preprocessing, CustomPreprocessingConfig) - config.preprocessing.use_uncropped_image = True - config.preprocessing.custom_suffix = ( - "graymatter_space-Ixi549Space_modulated-off_probability.nii.gz" - ) - if isinstance(config.extraction, ExtractionROIConfig): - config.extraction.roi_custom_template = "Ixi549Space" - config.extraction.extract_json = f"{modality}_mode-{test_name}.json" - tsv_file = input_dir / "subjects.tsv" - mode = test_name - extract_generic(out_dir, mode, tsv_file, config) - - elif modality == "t1-linear": - for flag in uncropped_image: - config.preprocessing.use_uncropped_image = flag - config.extraction.extract_json 
= ( - f"{modality}_crop-{not flag}_mode-{test_name}.json" - ) - mode = test_name - extract_generic(out_dir, mode, None, config) - - elif modality == "flair-linear": - config.data.caps_directory = Path( - str(config.data.caps_directory) + "_flair" - ) - config.extraction.save_features = False - for flag in uncropped_image: - config.preprocessing.use_uncropped_image = flag - config.extraction.extract_json = ( - f"{modality}_crop-{not flag}_mode-{test_name}.json" - ) - mode = f"{test_name}_flair" - extract_generic(out_dir, mode, None, config) - else: - raise NotImplementedError( - f"Test for modality {modality} was not implemented." - ) - - assert compare_folders( - out_dir / f"caps_{test_name}_flair", - ref_dir / f"caps_{test_name}_flair", - out_dir, - ) - assert compare_folders( - out_dir / f"caps_{test_name}", ref_dir / f"caps_{test_name}", out_dir - ) - - -def extract_generic(out_dir, mode, tsv_file, config: CapsDatasetConfig): - from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData - - config.data.caps_directory = out_dir / f"caps_{mode}" - config.data.data_tsv = tsv_file - config.dataloader.n_proc = 1 - DeepLearningPrepareData(config) diff --git a/tests/test_qc.py b/tests/test_qc.py deleted file mode 100644 index 653986d9d..000000000 --- a/tests/test_qc.py +++ /dev/null @@ -1,94 +0,0 @@ -import shutil -from os import system -from os.path import join -from pathlib import Path - -import pandas as pd -import pytest - -from tests.testing_tools import compare_folders - - -@pytest.fixture(params=["t1-linear", "t1-volume", "pet-linear"]) -def test_name(request): - return request.param - - -def test_qc(cmdopt, tmp_path, test_name): - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "qualityCheck" / "in" - ref_dir = base_dir / "qualityCheck" / "ref" - tmp_out_dir = tmp_path / "qualityCheck" / "out" - tmp_out_dir.mkdir(parents=True) - - if test_name == "t1-linear": - out_tsv = tmp_out_dir / "QC.tsv" - test_input = [ - "t1-linear", - 
str(input_dir / "caps"), - str(out_tsv), - "--no-gpu", - ] - - elif test_name == "t1-volume": - out_dir = tmp_out_dir / "QC_T1V" - test_input = [ - "t1-volume", - str(input_dir / "caps_T1V"), - str(out_dir), - "Ixi549Space", - ] - - elif test_name == "pet-linear": - out_tsv = tmp_out_dir / "QC_pet.tsv" - test_input = [ - "pet-linear", - str(input_dir / "caps_pet"), - str(out_tsv), - "--tracer ", - "18FFDG", - "-suvr ", - "cerebellumPons2", - "--threshold", - "0.5", - ] - else: - raise NotImplementedError( - f"Quality check test on {test_name} is not implemented ." - ) - - flag_error = not system(f"clinicadl quality-check " + " ".join(test_input)) - assert flag_error - - if test_name == "t1-linear": - ref_tsv = join(ref_dir, "QC.tsv") - ref_df = pd.read_csv(ref_tsv, sep="\t") - ref_df.reset_index(inplace=True) - - out_df = pd.read_csv(out_tsv, sep="\t") - out_df.reset_index(inplace=True) - - out_df["pass_probability"] = round(out_df["pass_probability"], 2) - ref_df["pass_probability"] = round(ref_df["pass_probability"], 2) - - system(f"diff {out_tsv} {ref_tsv} ") - assert out_df.equals(ref_df) - - elif test_name == "t1-volume": - assert compare_folders(out_dir, ref_dir / "QC_T1V", tmp_out_dir) - - elif test_name == "pet-linear": - out_df = pd.read_csv(out_tsv, sep="\t") - ref_tsv = join(ref_dir, "QC_pet.tsv") - ref_df = pd.read_csv(ref_tsv, sep="\t") - out_df.reset_index(inplace=True) - ref_df.reset_index(inplace=True) - - out_df = pd.read_csv(out_tsv, sep="\t") - out_df.reset_index(inplace=True) - - out_df["pass_probability"] = round(out_df["pass_probability"], 2) - ref_df["pass_probability"] = round(ref_df["pass_probability"], 2) - - system(f"diff {out_tsv} {ref_tsv} ") - assert out_df.equals(ref_df) diff --git a/tests/test_random_search.py b/tests/test_random_search.py deleted file mode 100644 index 864f8b1fa..000000000 --- a/tests/test_random_search.py +++ /dev/null @@ -1,66 +0,0 @@ -# coding: utf8 - -import os -import shutil -from os.path import join -from 
pathlib import Path - -import pytest - -from .testing_tools import compare_folders, modify_toml - - -# random searxh for ROI with CNN -@pytest.fixture( - params=[ - "rs_roi_cnn", - ] -) -def test_name(request): - return request.param - - -def test_random_search(cmdopt, tmp_path, test_name): - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "randomSearch" / "in" - ref_dir = base_dir / "randomSearch" / "ref" - tmp_out_dir = tmp_path / "randomSearch" / "out" - - if os.path.exists(tmp_out_dir): - shutil.rmtree(tmp_out_dir) - tmp_out_dir.mkdir(parents=True) - - if test_name == "rs_roi_cnn": - toml_path = join(input_dir / "random_search.toml") - generate_input = ["random-search", str(tmp_out_dir), "job-1"] - else: - raise NotImplementedError(f"Test {test_name} is not implemented.") - - # Write random_search.toml file - shutil.copy(toml_path, tmp_out_dir) - - if cmdopt["no-gpu"] or cmdopt["adapt-base-dir"]: - modify_toml( - toml_path=tmp_out_dir / "random_search.toml", - base_dir=base_dir, - no_gpu=cmdopt["no-gpu"], - adapt_base_dir=cmdopt["adapt-base-dir"], - ) - - flag_error_generate = not os.system("clinicadl " + " ".join(generate_input)) - performances_flag = os.path.exists( - tmp_out_dir / "job-1" / "split-0" / "best-loss" / "train" - ) - assert flag_error_generate - assert performances_flag - - assert compare_folders( - tmp_out_dir / "job-1" / "groups", - ref_dir / "job-1" / "groups", - tmp_out_dir, - ) - assert compare_folders( - tmp_out_dir / "job-1" / "split-0" / "best-loss", - ref_dir / "job-1" / "split-0" / "best-loss", - tmp_out_dir, - ) diff --git a/tests/test_resume.py b/tests/test_resume.py deleted file mode 100644 index 1598267d8..000000000 --- a/tests/test_resume.py +++ /dev/null @@ -1,75 +0,0 @@ -# coding: utf8 -import json -import shutil -from os import system -from pathlib import Path - -import pytest - -from clinicadl.maps_manager.maps_manager import MapsManager -from clinicadl.splitter.config import SplitterConfig -from 
clinicadl.splitter.splitter import Splitter - -from .testing_tools import modify_maps - - -@pytest.fixture( - params=[ - "stopped_1", - "stopped_2", - "stopped_3", - ] -) -def test_name(request): - return request.param - - -def test_resume(cmdopt, tmp_path, test_name): - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "resume" / "in" - ref_dir = base_dir / "resume" / "ref" - tmp_out_dir = tmp_path / "resume" / "out" - tmp_out_dir.mkdir(parents=True) - - shutil.copytree(input_dir / test_name, tmp_out_dir / test_name) - maps_stopped = tmp_out_dir / test_name - - if cmdopt["no-gpu"] or cmdopt["adapt-base-dir"]: # modify the input MAPS - with open(maps_stopped / "maps.json", "r") as f: - config = json.load(f) - config = modify_maps( - maps=config, - base_dir=base_dir, - no_gpu=cmdopt["no-gpu"], - adapt_base_dir=cmdopt["adapt-base-dir"], - ) - with open(maps_stopped / "maps.json", "w") as f: - json.dump(config, f, skipkeys=True, indent=4) - - flag_error = not system(f"clinicadl -vv train resume {maps_stopped}") - assert flag_error - - maps_manager = MapsManager(maps_stopped) - splitter_config = SplitterConfig(**maps_manager.parameters) - split_manager = Splitter(splitter_config) - - for split in split_manager.split_iterator(): - performances_flag = ( - maps_stopped / f"split-{split}" / "best-loss" / "train" - ).exists() - assert performances_flag - - with open(maps_stopped / "maps.json", "r") as out: - json_data_out = json.load(out) - with open(ref_dir / "maps_image_cnn" / "maps.json", "r") as ref: - json_data_ref = json.load(ref) - - if cmdopt["no-gpu"] or cmdopt["adapt-base-dir"]: - json_data_ref = modify_maps( - maps=json_data_ref, - base_dir=base_dir, - no_gpu=cmdopt["no-gpu"], - adapt_base_dir=cmdopt["adapt-base-dir"], - ) - - assert json_data_ref == json_data_out diff --git a/tests/test_train_ae.py b/tests/test_train_ae.py deleted file mode 100644 index c7fbcb276..000000000 --- a/tests/test_train_ae.py +++ /dev/null @@ -1,121 +0,0 @@ -# coding: utf8 - 
-import json -import os -import shutil -from pathlib import Path - -import pytest - -from .testing_tools import clean_folder, compare_folders, modify_maps - - -@pytest.fixture( - params=[ - "image_ae", - "patch_multi_ae", - "roi_ae", - "slice_ae", - ] -) -def test_name(request): - return request.param - - -def test_train_ae(cmdopt, tmp_path, test_name): - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "train" / "in" - ref_dir = base_dir / "train" / "ref" - tmp_out_dir = tmp_path / "train" / "out" - tmp_out_dir.mkdir(parents=True) - - clean_folder(tmp_out_dir, recreate=True) - - labels_path = str(input_dir / "labels_list" / "2_fold") - config_path = str(input_dir / "train_config.toml") - split = 0 - - if test_name == "image_ae": - split = 1 - test_input = [ - "train", - "reconstruction", - str(input_dir / "caps_image"), - "t1-linear_crop-True_mode-image.json", - labels_path, - str(tmp_out_dir), - "-c", - config_path, - "--split", - str(split), - ] - elif test_name == "patch_multi_ae": - test_input = [ - "train", - "reconstruction", - str(input_dir / "caps_patch"), - "t1-linear_crop-True_mode-patch.json", - labels_path, - str(tmp_out_dir), - "-c", - config_path, - "--multi_network", - ] - elif test_name == "roi_ae": - test_input = [ - "train", - "reconstruction", - str(input_dir / "caps_roi"), - "t1-linear_crop-True_mode-roi.json", - labels_path, - str(tmp_out_dir), - "-c", - config_path, - ] - elif test_name == "slice_ae": - test_input = [ - "train", - "reconstruction", - str(input_dir / "caps_slice"), - "t1-linear_crop-True_mode-slice.json", - labels_path, - str(tmp_out_dir), - "-c", - config_path, - ] - else: - raise NotImplementedError(f"Test {test_name} is not implemented.") - - if cmdopt["no-gpu"]: - test_input.append("--no-gpu") - - if tmp_out_dir.is_dir(): - shutil.rmtree(tmp_out_dir) - - flag_error = not os.system("clinicadl " + " ".join(test_input)) - assert flag_error - - with open(tmp_out_dir / "maps.json", "r") as out: - json_data_out = 
json.load(out) - with open(ref_dir / ("maps_" + test_name) / "maps.json", "r") as ref: - json_data_ref = json.load(ref) - - if cmdopt["no-gpu"] or cmdopt["adapt-base-dir"]: - json_data_ref = modify_maps( - maps=json_data_ref, - base_dir=base_dir, - no_gpu=cmdopt["no-gpu"], - adapt_base_dir=cmdopt["adapt-base-dir"], - ) - assert json_data_out == json_data_ref # ["mode"] == mode - - assert compare_folders( - tmp_out_dir / "groups", - ref_dir / ("maps_" + test_name) / "groups", - tmp_path, - ) - assert compare_folders( - tmp_out_dir / f"split-{split}" / "best-loss", - ref_dir / ("maps_" + test_name) / f"split-{split}" / "best-loss", - tmp_path, - ) diff --git a/tests/test_train_cnn.py b/tests/test_train_cnn.py deleted file mode 100644 index 761fedbee..000000000 --- a/tests/test_train_cnn.py +++ /dev/null @@ -1,142 +0,0 @@ -# coding: utf8 - -import json -import os -import shutil -from pathlib import Path - -import pytest - -from .testing_tools import compare_folders, modify_maps - - -@pytest.fixture( - params=[ - "slice_cnn", - "image_cnn", - "patch_cnn", - "patch_multi_cnn", - "roi_cnn", - ] -) -def test_name(request): - return request.param - - -def test_train_cnn(cmdopt, tmp_path, test_name): - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "train" / "in" - ref_dir = base_dir / "train" / "ref" - tmp_out_dir = tmp_path / "train" / "out" - tmp_out_dir.mkdir(parents=True) - - labels_path = input_dir / "labels_list" / "2_fold" - config_path = input_dir / "train_config.toml" - split = "0" - - if test_name == "slice_cnn": - split_ref = 0 - test_input = [ - "train", - "classification", - str(input_dir / "caps_slice"), - "t1-linear_crop-True_mode-slice.json", - str(labels_path), - str(tmp_out_dir), - "-c", - str(config_path), - ] - elif test_name == "image_cnn": - split_ref = 1 - test_input = [ - "train", - "regression", - str(input_dir / "caps_image"), - "t1-linear_crop-True_mode-image.json", - str(labels_path), - str(tmp_out_dir), - "-c", - str(config_path), - 
] - elif test_name == "patch_cnn": - split_ref = 0 - test_input = [ - "train", - "classification", - str(input_dir / "caps_patch"), - "t1-linear_crop-True_mode-patch.json", - str(labels_path), - str(tmp_out_dir), - "-c", - str(config_path), - "--split", - split, - ] - elif test_name == "patch_multi_cnn": - split_ref = 0 - test_input = [ - "train", - "classification", - str(input_dir / "caps_patch"), - "t1-linear_crop-True_mode-patch.json", - str(labels_path), - str(tmp_out_dir), - "-c", - str(config_path), - "--multi_network", - ] - elif test_name == "roi_cnn": - split_ref = 0 - test_input = [ - "train", - "classification", - str(input_dir / "caps_roi"), - "t1-linear_crop-True_mode-roi.json", - str(labels_path), - str(tmp_out_dir), - "-c", - str(config_path), - ] - else: - raise NotImplementedError(f"Test {test_name} is not implemented.") - - if cmdopt["no-gpu"]: - test_input.append("--no-gpu") - - if tmp_out_dir.is_dir(): - shutil.rmtree(tmp_out_dir) - - flag_error = not os.system("clinicadl " + " ".join(test_input)) - assert flag_error - - performances_flag = ( - tmp_out_dir / f"split-{split}" / "best-loss" / "train" - ).exists() - assert performances_flag - - with open(tmp_out_dir / "maps.json", "r") as out: - json_data_out = json.load(out) - with open(ref_dir / ("maps_" + test_name) / "maps.json", "r") as ref: - json_data_ref = json.load(ref) - - if cmdopt["no-gpu"] or cmdopt["adapt-base-dir"]: - json_data_ref = modify_maps( - maps=json_data_ref, - base_dir=base_dir, - no_gpu=cmdopt["no-gpu"], - adapt_base_dir=cmdopt["adapt-base-dir"], - ) - assert json_data_out == json_data_ref # ["mode"] == mode - - assert compare_folders( - tmp_out_dir / "groups", - ref_dir / ("maps_" + test_name) / "groups", - tmp_path, - ) - assert compare_folders( - tmp_out_dir / "split-0" / "best-loss", - ref_dir / ("maps_" + test_name) / f"split-{split_ref}" / "best-loss", - tmp_path, - ) - - shutil.rmtree(tmp_out_dir) diff --git a/tests/test_train_from_json.py 
b/tests/test_train_from_json.py deleted file mode 100644 index 06b307b0f..000000000 --- a/tests/test_train_from_json.py +++ /dev/null @@ -1,85 +0,0 @@ -import json -import shutil -from os import system -from pathlib import Path - -from .testing_tools import compare_folders_with_hashes, create_hashes_dict, modify_maps - - -def test_json_compatibility(cmdopt, tmp_path): - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "train_from_json" / "in" - tmp_out_dir = tmp_path / "train_from_json" / "out" - tmp_out_dir.mkdir(parents=True) - - split = "0" - config_json = input_dir / "maps_roi_cnn/maps.json" - reproduced_maps_dir = tmp_out_dir / "maps_reproduced" - - if reproduced_maps_dir.exists(): - shutil.rmtree(reproduced_maps_dir) - - if cmdopt["no-gpu"] or cmdopt["adapt-base-dir"]: # virtually modify the input MAPS - with open(config_json, "r") as f: - config = json.load(f) - config_json = tmp_out_dir / "modified_maps.json" - config = modify_maps( - maps=config, - base_dir=base_dir, - no_gpu=cmdopt["no-gpu"], - adapt_base_dir=cmdopt["adapt-base-dir"], - ) - with open(config_json, "w+") as f: - json.dump(config, f, skipkeys=True, indent=4) - - flag_error = not system( - f"clinicadl train from_json {str(config_json)} {str(reproduced_maps_dir)} -s {split}" - ) - assert flag_error - - -def test_determinism(cmdopt, tmp_path): - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "train_from_json" / "in" - tmp_out_dir = tmp_path / "train_from_json" / "out" - tmp_out_dir.mkdir(parents=True) - - maps_dir = tmp_out_dir / "maps_roi_cnn" - reproduced_maps_dir = tmp_out_dir / "reproduced_MAPS" - if maps_dir.exists(): - shutil.rmtree(maps_dir) - if reproduced_maps_dir.exists(): - shutil.rmtree(reproduced_maps_dir) - test_input = [ - "train", - "classification", - str(input_dir / "caps_roi"), - "t1-linear_mode-roi.json", - str(input_dir / "labels_list" / "2_fold"), - str(maps_dir), - "-c", - str(input_dir / "reproducibility_config.toml"), - ] - - if cmdopt["no-gpu"]: - 
test_input.append("--no-gpu") - - # Run first experiment - flag_error = not system("clinicadl " + " ".join(test_input)) - assert flag_error - input_hashes = create_hashes_dict( - maps_dir, - ignore_pattern_list=["tensorboard", ".log", "training.tsv", "maps.json"], - ) - - # Reproduce experiment (train from json) - config_json = tmp_out_dir / "maps_roi_cnn/maps.json" - flag_error = not system( - f"clinicadl train from_json {str(config_json)} {str(reproduced_maps_dir)} -s 0" - ) - assert flag_error - compare_folders_with_hashes( - reproduced_maps_dir, - input_hashes, - ignore_pattern_list=["tensorboard", ".log", "training.tsv", "maps.json"], - ) diff --git a/tests/test_transfer_learning.py b/tests/test_transfer_learning.py deleted file mode 100644 index d49cbd61f..000000000 --- a/tests/test_transfer_learning.py +++ /dev/null @@ -1,184 +0,0 @@ -import json -import os -import shutil -from pathlib import Path - -import pytest - -from .testing_tools import compare_folders, modify_maps - - -# Everything is tested on roi except for cnn --> multicnn (patch) as multicnn is not implemented for roi. 
-@pytest.fixture( - params=[ - "transfer_ae_ae", - "transfer_ae_cnn", - "transfer_cnn_cnn", - "transfer_cnn_multicnn", - ] -) -def test_name(request): - return request.param - - -def test_transfer_learning(cmdopt, tmp_path, test_name): - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "transferLearning" / "in" - ref_dir = base_dir / "transferLearning" / "ref" - tmp_out_dir = tmp_path / "transferLearning" / "out" - tmp_target_dir = tmp_path / "transferLearning" / "target" - tmp_out_dir.mkdir(parents=True) - - caps_roi_path = input_dir / "caps_roi" - extract_roi_str = "t1-linear_mode-roi.json" - labels_path = input_dir / "labels_list" / "2_fold" - config_path = input_dir / "train_config.toml" - if test_name == "transfer_ae_ae": - source_task = [ - "train", - "reconstruction", - str(caps_roi_path), - extract_roi_str, - str(labels_path), - str(tmp_out_dir / "maps_roi_ae"), - "-c", - str(config_path), - ] - target_task = [ - "train", - "reconstruction", - str(caps_roi_path), - extract_roi_str, - str(labels_path), - str(tmp_target_dir), - "-c", - str(config_path), - "--transfer_path", - str(tmp_out_dir / "maps_roi_ae"), - ] - name = "aeTOae" - elif test_name == "transfer_ae_cnn": - source_task = [ - "train", - "reconstruction", - str(caps_roi_path), - extract_roi_str, - str(labels_path), - str(tmp_out_dir / "maps_roi_ae"), - "-c", - str(config_path), - ] - target_task = [ - "train", - "classification", - str(caps_roi_path), - extract_roi_str, - str(labels_path), - str(tmp_target_dir), - "-c", - str(config_path), - "--transfer_path", - str(tmp_out_dir / "maps_roi_ae"), - ] - name = "aeTOcnn" - elif test_name == "transfer_cnn_cnn": - source_task = [ - "train", - "classification", - str(caps_roi_path), - extract_roi_str, - str(labels_path), - str(tmp_out_dir / "maps_roi_cnn"), - "-c", - str(config_path), - ] - target_task = [ - "train", - "classification", - str(caps_roi_path), - extract_roi_str, - str(labels_path), - str(tmp_target_dir), - "-c", - 
str(config_path), - "--transfer_path", - str(tmp_out_dir / "maps_roi_cnn"), - ] - name = "cnnTOcnn" - elif test_name == "transfer_cnn_multicnn": - source_task = [ - "train", - "classification", - str(caps_roi_path), - extract_roi_str, - str(labels_path), - str(tmp_out_dir / "maps_roi_cnn"), - "-c", - str(config_path), - ] - target_task = [ - "train", - "classification", - str(caps_roi_path), - extract_roi_str, - str(labels_path), - str(tmp_target_dir), - "-c", - str(config_path), - "--transfer_path", - str(tmp_out_dir / "maps_roi_cnn"), - "--multi_network", - ] - name = "cnnTOmulticnn" - else: - raise NotImplementedError(f"Test {test_name} is not implemented.") - - if cmdopt["no-gpu"]: - source_task.append("--no-gpu") - target_task.append("--no-gpu") - - if tmp_out_dir.exists(): - shutil.rmtree(tmp_out_dir) - if tmp_target_dir.exists(): - shutil.rmtree(tmp_target_dir) - - flag_source = not os.system("clinicadl -vvv " + " ".join(source_task)) - flag_target = not os.system("clinicadl -vvv " + " ".join(target_task)) - assert flag_source - assert flag_target - - with open(tmp_target_dir / "maps.json", "r") as out: - json_data_out = json.load(out) - with open(ref_dir / ("maps_roi_" + name) / "maps.json", "r") as ref: - json_data_ref = json.load(ref) - - # TODO : uncomment when CI data are correct - # ref_source_dir = Path(json_data_ref["transfer_path"]).parent - # json_data_ref["transfer_path"] = str( - # tmp_out_dir / Path(json_data_ref["transfer_path"]).relative_to(ref_source_dir) - # ) - # if cmdopt["no-gpu"] or cmdopt["adapt-base-dir"]: - # json_data_ref = modify_maps( - # maps=json_data_ref, - # base_dir=base_dir, - # no_gpu=cmdopt["no-gpu"], - # adapt_base_dir=cmdopt["adapt-base-dir"], - # ) - # TODO: remove and update data - json_data_ref["caps_directory"] = json_data_out["caps_directory"] - json_data_ref["gpu"] = json_data_out["gpu"] - json_data_ref["transfer_path"] = json_data_out["transfer_path"] - json_data_ref["tsv_path"] = json_data_out["tsv_path"] - ### - 
assert json_data_out == json_data_ref # ["mode"] == mode - - assert compare_folders( - tmp_target_dir / "groups", - ref_dir / ("maps_roi_" + name) / "groups", - tmp_path, - ) - assert compare_folders( - tmp_target_dir / "split-0" / "best-loss", - ref_dir / ("maps_roi_" + name) / "split-0" / "best-loss", - tmp_path, - ) diff --git a/tests/test_tsvtools.py b/tests/test_tsvtools.py deleted file mode 100644 index f77a333ac..000000000 --- a/tests/test_tsvtools.py +++ /dev/null @@ -1,284 +0,0 @@ -import os -import shutil -from pathlib import Path - -import pandas as pd -import pytest - -from clinicadl.tsvtools.tsvtools_utils import extract_baseline -from tests.testing_tools import compare_folders - -""" -Check the absence of data leakage - 1) Baseline datasets contain only one scan per subject - 2) No intersection between train and test sets - 3) Absence of MCI train subjects in test sets of subcategories of MCI -""" - - -@pytest.fixture( - params=[ - "test_getlabels", - "test_split", - "test_analysis", - "test_get_progression", - "test_prepare_experiment", - "test_get_metadata", - ] -) -def test_name(request): - return request.param - - -def test_tsvtools(cmdopt, tmp_path, test_name): - if test_name == "test_getlabels": - run_test_getlabels(cmdopt, tmp_path) - elif test_name == "test_split": - run_test_split(cmdopt, tmp_path) - elif test_name == "test_analysis": - run_test_analysis(cmdopt, tmp_path) - elif test_name == "test_prepare_experiment": - run_test_prepare_experiment(cmdopt, tmp_path) - elif test_name == "test_get_progression": - run_test_get_progression(cmdopt, tmp_path) - elif test_name == "test_get_metadata": - run_test_get_metadata(cmdopt, tmp_path) - - -def check_is_subject_unique(labels_path_baseline: Path): - print("Check subject uniqueness", labels_path_baseline) - - flag_is_unique = True - check_df = pd.read_csv(labels_path_baseline, sep="\t") - check_df.set_index(["participant_id", "session_id"], inplace=True) - if labels_path_baseline.name[-12:] != 
"baseline.tsv": - check_df = extract_baseline(check_df, set_index=False) - for _, subject_df in check_df.groupby(level=0): - if len(subject_df) > 1: - flag_is_unique = False - assert flag_is_unique - - -def check_is_independant( - train_path_baseline: Path, test_path_baseline: Path, subject_flag=True -): - print("Check independence") - - flag_is_independant = True - train_df = pd.read_csv(train_path_baseline, sep="\t") - train_df.set_index(["participant_id", "session_id"], inplace=True) - test_df = pd.read_csv(test_path_baseline, sep="\t") - test_df.set_index(["participant_id", "session_id"], inplace=True) - - for subject, session in train_df.index: - if (subject, session) in test_df.index: - flag_is_independant = False - - assert flag_is_independant - - -def run_test_suite(data_tsv, n_splits): - check_train = True - if n_splits == 0: - train_baseline_tsv = data_tsv / "train_baseline.tsv" - test_baseline_tsv = data_tsv / "test_baseline.tsv" - if not train_baseline_tsv.exists(): - check_train = False - - check_is_subject_unique(test_baseline_tsv) - if check_train: - check_is_subject_unique(train_baseline_tsv) - check_is_independant(train_baseline_tsv, test_baseline_tsv) - - else: - for split_number in range(n_splits): - for folder, _, files in os.walk(data_tsv / "split"): - folder = Path(folder) - - for file in files: - if file[-3:] == "tsv": - check_is_subject_unique(folder / file) - train_baseline_tsv = folder / "train_baseline.tsv" - test_baseline_tsv = folder / "test_baseline.tsv" - if train_baseline_tsv.exists(): - if test_baseline_tsv.exists(): - check_is_independant(train_baseline_tsv, test_baseline_tsv) - - -def run_test_getlabels(cmdopt, tmp_path): - """Checks that getlabels is working and that it is coherent with - previous version in reference_path.""" - - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "tsvtools" / "in" - ref_dir = base_dir / "tsvtools" / "ref" - tmp_out_dir = tmp_path / "tsvtools" / "out" - tmp_out_dir.mkdir(parents=True) - - 
import shutil - - bids_output = tmp_out_dir / "bids" - bids_directory = input_dir / "bids" - restrict_tsv = input_dir / "restrict.tsv" - output_tsv = tmp_out_dir - if tmp_out_dir.exists(): - shutil.rmtree(tmp_out_dir) - tmp_out_dir.mkdir(parents=True) - shutil.copytree(bids_directory, bids_output) - merged_tsv = input_dir / "merge-tsv.tsv" - missing_mods_directory = input_dir / "missing_mods" - - flag_getlabels = not os.system( - f"clinicadl -vvv tsvtools get-labels {str(bids_output)} {str(output_tsv)} " - f"-d AD -d MCI -d CN -d Dementia " - f"--merged_tsv {str(merged_tsv)} --missing_mods {str(missing_mods_directory)} " - f"--restriction_tsv {str(restrict_tsv)}" - ) - assert flag_getlabels - - out_df = pd.read_csv(tmp_out_dir / "labels.tsv", sep="\t") - ref_df = pd.read_csv(ref_dir / "labels.tsv", sep="\t") - assert out_df.equals(ref_df) - - -def run_test_split(cmdopt, tmp_path): - """Checks that: - - split and kfold are working - - the loading functions can find the output - - no data leakage is introduced in split and kfold. 
- """ - - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "tsvtools" / "in" - ref_dir = base_dir / "tsvtools" / "ref" - tmp_out_dir = tmp_path / "tsvtools" / "out" - tmp_out_dir.mkdir(parents=True) - - n_test = 10 - n_splits = 2 - train_tsv = tmp_out_dir / "split/train.tsv" - labels_tsv = tmp_out_dir / "labels.tsv" - shutil.copyfile(input_dir / "labels.tsv", labels_tsv) - - flag_split = not os.system( - f"clinicadl -vvv tsvtools split {str(labels_tsv)} --subset_name test --n_test {n_test}" - ) - flag_getmetadata = not os.system( - f"clinicadl -vvv tsvtools get-metadata {str(train_tsv)} {str(labels_tsv)} -voi age -voi sex -voi diagnosis" - ) - flag_kfold = not os.system( - f"clinicadl -vvv tsvtools kfold {str(train_tsv)} --n_splits {n_splits} --subset_name validation" - ) - assert flag_split - assert flag_getmetadata - assert flag_kfold - - assert compare_folders(tmp_out_dir / "split", ref_dir / "split", tmp_out_dir) - - run_test_suite(tmp_out_dir, n_splits) - - -def run_test_analysis(cmdopt, tmp_path): - """Checks that analysis can be performed.""" - - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "tsvtools" / "in" - ref_dir = base_dir / "tsvtools" / "ref" - tmp_out_dir = tmp_path / "tsvtools" / "out" - tmp_out_dir.mkdir(parents=True) - - merged_tsv = input_dir / "merge-tsv.tsv" - labels_tsv = input_dir / "labels.tsv" - output_tsv = tmp_out_dir / "analysis.tsv" - ref_analysis_tsv = ref_dir / "analysis.tsv" - - flag_analysis = not os.system( - f"clinicadl tsvtools analysis {str(merged_tsv)} {str(labels_tsv)} {str(output_tsv)} " - f"--diagnoses CN --diagnoses MCI --diagnoses Dementia" - ) - - assert flag_analysis - ref_df = pd.read_csv(ref_analysis_tsv, sep="\t") - out_df = pd.read_csv(output_tsv, sep="\t") - assert out_df.equals(ref_df) - - -def run_test_get_progression(cmdopt, tmp_path): - """Checks that get-progression can be performed""" - - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "tsvtools" / "in" - ref_dir = base_dir / 
"tsvtools" / "ref" - tmp_out_dir = tmp_path / "tsvtools" / "out" - tmp_out_dir.mkdir(parents=True) - - input_progression_tsv = input_dir / "labels.tsv" - progression_tsv = tmp_out_dir / "progression.tsv" - ref_progression_tsv = ref_dir / "progression.tsv" - shutil.copyfile(input_progression_tsv, progression_tsv) - - flag_get_progression = not os.system( - f"clinicadl tsvtools get-progression {str(progression_tsv)} " - ) - assert flag_get_progression - - ref_df = pd.read_csv(ref_progression_tsv, sep="\t") - out_df = pd.read_csv(progression_tsv, sep="\t") - assert out_df.equals(ref_df) - - -def run_test_prepare_experiment(cmdopt, tmp_path): - """Checks that: - - split and kfold are working - - the loading functions can find the output - - no data leakage is introduced in split and kfold. - """ - - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "tsvtools" / "in" - ref_dir = base_dir / "tsvtools" / "ref" - tmp_out_dir = tmp_path / "tsvtools" / "out" - tmp_out_dir.mkdir(parents=True) - - labels_tsv = tmp_out_dir / "labels.tsv" - shutil.copyfile(input_dir / "labels.tsv", labels_tsv) - - validation_type = "kfold" - n_valid = 2 - n_test = 10 - flag_prepare_experiment = not os.system( - f"clinicadl -vvv tsvtools prepare-experiment {str(labels_tsv)} --n_test {n_test} --validation_type {validation_type} --n_validation {n_valid}" - ) - - assert flag_prepare_experiment - - assert compare_folders(tmp_out_dir / "split", ref_dir / "split", tmp_out_dir) - - run_test_suite(tmp_out_dir, n_valid) - - -def run_test_get_metadata(cmdopt, tmp_path): - """Checks that get-metadata can be performed""" - base_dir = Path(cmdopt["input"]) - input_dir = base_dir / "tsvtools" / "in" - ref_dir = base_dir / "tsvtools" / "ref" - tmp_out_dir = tmp_path / "tsvtools" / "out" - tmp_out_dir.mkdir(parents=True) - - input_metadata_tsv = input_dir / "restrict.tsv" - metadata_tsv = tmp_out_dir / "metadata.tsv" - input_labels_tsv = input_dir / "labels.tsv" - labels_tsv = tmp_out_dir / "labels.tsv" 
- ref_metadata_tsv = ref_dir / "metadata.tsv" - - shutil.copyfile(input_metadata_tsv, metadata_tsv) - shutil.copyfile(input_labels_tsv, labels_tsv) - - flag_get_metadata = not os.system( - f"clinicadl tsvtools get-metadata {str(metadata_tsv)} {str(labels_tsv)} -voi diagnosis -voi sex -voi age" - ) - assert flag_get_metadata - - ref_df = pd.read_csv(ref_metadata_tsv, sep="\t") - out_df = pd.read_csv(metadata_tsv, sep="\t") - assert out_df.equals(ref_df) diff --git a/tests/testing_tools.py b/tests/testing_tools.py deleted file mode 100644 index 4044d1022..000000000 --- a/tests/testing_tools.py +++ /dev/null @@ -1,255 +0,0 @@ -import pathlib -from os import PathLike -from pathlib import Path -from typing import Any, Dict, List - - -def ignore_pattern(file_path: pathlib.Path, ignore_pattern_list: List[str]) -> bool: - if not ignore_pattern_list: - return False - - for pattern in ignore_pattern_list: - if pattern in file_path.__str__(): - return True - return False - - -def create_hashes_dict( - path_folder: pathlib.Path, ignore_pattern_list: List[str] = None -) -> Dict[str, str]: - """ - Computes a dictionary of files with their corresponding hashes - - Args: - path_folder: starting point for the tree listing. - ignore_pattern_list: list of patterns to be ignored to create hash dictionary. 
- - Returns: - all_files: a dictionary of the form {/path/to/file.extension: hash(file.extension)} - """ - import hashlib - - def file_as_bytes(input_file): - with input_file: - return input_file.read() - - all_files = [] - for file in path_folder.rglob("*"): - if not ignore_pattern(file, ignore_pattern_list) and file.is_file(): - all_files.append(file) - - dict_hashes = { - fname.relative_to(path_folder).__str__(): str( - hashlib.md5(file_as_bytes(open(fname, "rb"))).digest() - ) - for fname in all_files - } - return dict_hashes - - -def compare_folders_with_hashes( - path_folder: pathlib.Path, - hashes_dict: Dict[str, str], - ignore_pattern_list: List[str] = None, -): - """ - Compares the files of a folder against a reference - - Args: - path_folder: starting point for the tree listing. - hashes_dict: a dictionary of the form {/path/to/file.extension: hash(file.extension)} - ignore_pattern_list: list of patterns to be ignored to create hash dictionary. - """ - hashes_new = create_hashes_dict(path_folder, ignore_pattern_list) - - if hashes_dict != hashes_new: - error_message1 = "" - error_message2 = "" - for key in hashes_dict: - if key not in hashes_new: - error_message1 += "{0} not found !\n".format(key) - elif hashes_dict[key] != hashes_new[key]: - error_message2 += "{0} does not match the reference file !\n".format( - key - ) - raise ValueError(error_message1 + error_message2) - - -def models_equal(state_dict_1, state_dict_2, epsilon=0): - import torch - - for key_item_1, key_item_2 in zip(state_dict_1.items(), state_dict_2.items()): - if torch.mean(torch.abs(key_item_1[1] - key_item_2[1])) > epsilon: - print(f"Not equivalent: {key_item_1[0]} != {key_item_2[0]}") - return False - return True - - -def tree(dir_: PathLike, file_out: PathLike): - """Creates a file (file_out) with a visual tree representing the file - hierarchy at a given directory - - .. note:: - Does not display empty directories. 
- - """ - from pathlib import Path - - if not dir_.is_dir(): - raise FileNotFoundError(f"No directory found at {dir_}.") - - file_content = "" - - for path in sorted(Path(dir_).rglob("*")): - if path.is_dir() and not any(path.iterdir()): - continue - depth = len(path.relative_to(dir_).parts) - spacer = " " * depth - file_content = file_content + f"{spacer}+ {path.name}\n" - - Path(file_out).write_text(file_content) - - -def compare_folders(outdir: PathLike, refdir: PathLike, tmp_path: PathLike) -> bool: - """ - Compares the file hierarchy of two folders. - - Args: - outdir: path to the first fodler. - refdir: path to the second folder. - tmp_path: path to a temporary folder. - """ - - from filecmp import cmp - from pathlib import PurePath - - file_out = PurePath(tmp_path) / "file_out.txt" - file_ref = PurePath(tmp_path) / "file_ref.txt" - tree(outdir, file_out) - tree(refdir, file_ref) - if not cmp(file_out, file_ref): - with open(file_out, "r") as fin: - out_message = fin.read() - with open(file_ref, "r") as fin: - ref_message = fin.read() - raise ValueError( - "Comparison of out and ref directories shows mismatch :\n " - "OUT :\n" + out_message + "\n REF :\n" + ref_message - ) - return True - - -def compare_folder_with_files(folder: str, file_list: List[str]) -> bool: - """Compare file existing in two folders - Args: - folder: path to a folder - file_list: list of files which must be found in folder - Returns: - True if files in file_list were all found in folder. 
- """ - import os - - folder_list = [] - for root, dirs, files in os.walk(folder): - folder_list.extend(files) - - print(f"Missing files {set(file_list) - set(folder_list)}") - return set(file_list).issubset(set(folder_list)) - - -def clean_folder(path, recreate=True): - from os import makedirs - from os.path import abspath, exists - from shutil import rmtree - - abs_path = abspath(path) - if exists(abs_path): - rmtree(abs_path) - if recreate: - makedirs(abs_path) - - -def modify_maps( - maps: Dict[str, Any], - base_dir: Path, - no_gpu: bool = False, - adapt_base_dir: bool = False, -) -> Dict[str, Any]: - """ - Modifies a MAPS dictionary if the user passed --no-gpu or --adapt-base-dir flags. - - Parameters - ---------- - maps : Dict[str, Any] - The MAPS dictionary. - base_dir : Path - The base directory, where CI data are stored. - no_gpu : bool (optional, default=False) - Whether the user activated the --no-gpu flag. - adapt_base_dir : bool (optional, default=False) - Whether the user activated the --adapt-base-dir flag. - - Returns - ------- - Dict[str, Any] - The modified MAPS dictionary. - """ - if no_gpu: - maps["gpu"] = False - if adapt_base_dir: - base_dir = base_dir.resolve() - ref_base_dir = Path(maps["caps_directory"]).parents[2] - maps["caps_directory"] = str( - base_dir / Path(maps["caps_directory"]).relative_to(ref_base_dir) - ) - try: - maps["tsv_path"] = str( - base_dir / Path(maps["tsv_path"]).relative_to(ref_base_dir) - ) - except KeyError: # maps with only caps directory - pass - return maps - - -def modify_toml( - toml_path: Path, - base_dir: Path, - no_gpu: bool = False, - adapt_base_dir: bool = False, -) -> None: - """ - Modifies a TOML file if the user passed --no-gpu or --adapt-base-dir flags. - - Parameters - ---------- - toml_path : Path - The path of the TOML file. - base_dir : Path - The base directory, where CI data are stored. - no_gpu : bool (optional, default=False) - Whether the user activated the --no-gpu flag. 
- adapt_base_dir : bool (optional, default=False) - Whether the user activated the --adapt-base-dir flag. - """ - import toml - - config = toml.load(toml_path) - if no_gpu: - try: - config["Computational"]["gpu"] = False - except KeyError: - config["Computational"] = {"gpu": False} - if adapt_base_dir: - random_search_config = config["Random_Search"] - base_dir = base_dir.resolve() - ref_base_dir = Path(random_search_config["caps_directory"]).parents[2] - random_search_config["caps_directory"] = str( - base_dir - / Path(random_search_config["caps_directory"]).relative_to(ref_base_dir) - ) - random_search_config["tsv_path"] = str( - base_dir / Path(random_search_config["tsv_path"]).relative_to(ref_base_dir) - ) - f = open(toml_path, "w") - toml.dump(config, f) - f.close() diff --git a/tests/unittests/monai_metrics/config/test_classification.py b/tests/unittests/monai_metrics/config/test_classification.py index e4192a254..16941099d 100644 --- a/tests/unittests/monai_metrics/config/test_classification.py +++ b/tests/unittests/monai_metrics/config/test_classification.py @@ -1,11 +1,11 @@ import pytest from pydantic import ValidationError -from clinicadl.monai_metrics.config.classification import ( +from clinicadl.metrics.config.classification import ( ROCAUCConfig, create_confusion_matrix_config, ) -from clinicadl.monai_metrics.config.enum import ConfusionMatrixMetric +from clinicadl.metrics.config.enum import ConfusionMatrixMetric # ROCAUC diff --git a/tests/unittests/monai_metrics/config/test_factory.py b/tests/unittests/monai_metrics/config/test_factory.py index 5eed6f459..97be7453d 100644 --- a/tests/unittests/monai_metrics/config/test_factory.py +++ b/tests/unittests/monai_metrics/config/test_factory.py @@ -1,6 +1,6 @@ import pytest -from clinicadl.monai_metrics.config import ImplementedMetrics, create_metric_config +from clinicadl.metrics.config import ImplementedMetrics, create_metric_config def test_create_training_config(): diff --git 
a/tests/unittests/monai_metrics/config/test_generation.py b/tests/unittests/monai_metrics/config/test_generation.py index 4e1691567..1c55fb221 100644 --- a/tests/unittests/monai_metrics/config/test_generation.py +++ b/tests/unittests/monai_metrics/config/test_generation.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from clinicadl.monai_metrics.config.generation import MMDMetricConfig +from clinicadl.metrics.config.generation import MMDMetricConfig def test_fails_validation(): diff --git a/tests/unittests/monai_metrics/config/test_reconstruction.py b/tests/unittests/monai_metrics/config/test_reconstruction.py index 521c2717e..4ade04f5c 100644 --- a/tests/unittests/monai_metrics/config/test_reconstruction.py +++ b/tests/unittests/monai_metrics/config/test_reconstruction.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from clinicadl.monai_metrics.config.reconstruction import ( +from clinicadl.metrics.config.reconstruction import ( MultiScaleSSIMConfig, PSNRConfig, SSIMConfig, diff --git a/tests/unittests/monai_metrics/config/test_regression.py b/tests/unittests/monai_metrics/config/test_regression.py index f95f20b6a..7c4407e30 100644 --- a/tests/unittests/monai_metrics/config/test_regression.py +++ b/tests/unittests/monai_metrics/config/test_regression.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from clinicadl.monai_metrics.config.regression import ( +from clinicadl.metrics.config.regression import ( MAEConfig, MSEConfig, RMSEConfig, diff --git a/tests/unittests/monai_metrics/config/test_segmentation.py b/tests/unittests/monai_metrics/config/test_segmentation.py index 537f289c9..52fa8c501 100644 --- a/tests/unittests/monai_metrics/config/test_segmentation.py +++ b/tests/unittests/monai_metrics/config/test_segmentation.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from clinicadl.monai_metrics.config.segmentation import ( +from clinicadl.metrics.config.segmentation import 
( DiceConfig, GeneralizedDiceConfig, HausdorffDistanceConfig, diff --git a/tests/unittests/monai_metrics/test_factory.py b/tests/unittests/monai_metrics/test_factory.py index 5d265e416..3896c6bbc 100644 --- a/tests/unittests/monai_metrics/test_factory.py +++ b/tests/unittests/monai_metrics/test_factory.py @@ -6,8 +6,8 @@ def test_get_metric(): from monai.metrics import SSIMMetric - from clinicadl.monai_metrics import get_metric - from clinicadl.monai_metrics.config import ImplementedMetrics, create_metric_config + from clinicadl.metrics import get_metric + from clinicadl.metrics.config import ImplementedMetrics, create_metric_config for metric_name in [e.value for e in ImplementedMetrics if e != "Loss"]: if ( @@ -53,7 +53,7 @@ def loss_fn_bis(y_pred: Tensor) -> Tensor: def test_loss_to_metric(): from torch import randn - from clinicadl.monai_metrics import loss_to_metric + from clinicadl.metrics import loss_to_metric y_pred = randn(10, 5, 5) y_true = randn(10, 5, 5) diff --git a/tests/unittests/monai_networks/config/__init__.py b/tests/unittests/monai_networks/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unittests/monai_networks/config/test_autoencoder.py b/tests/unittests/monai_networks/config/test_autoencoder.py deleted file mode 100644 index 707695434..000000000 --- a/tests/unittests/monai_networks/config/test_autoencoder.py +++ /dev/null @@ -1,171 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.autoencoder import ( - AutoEncoderConfig, - VarAutoEncoderConfig, -) - - -@pytest.fixture -def dummy_arguments(): - args = { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "channels": [2, 4], - "latent_size": 16, - } - return args - - -@pytest.fixture( - params=[ - {"in_shape": (1, 10, 10), "strides": (1, 1), "dropout": 1.1}, - {"in_shape": (1, 10, 10), "strides": (1, 1), "kernel_size": 4}, - {"in_shape": (1, 10, 10), "strides": (1, 1), "kernel_size": (3,)}, - 
{"in_shape": (1, 10, 10), "strides": (1, 1), "kernel_size": (3, 3, 3)}, - {"in_shape": (1, 10, 10), "strides": (1, 1), "up_kernel_size": 4}, - {"in_shape": (1, 10, 10), "strides": (1, 1), "up_kernel_size": (3,)}, - {"in_shape": (1, 10, 10), "strides": (1, 1), "up_kernel_size": (3, 3, 3)}, - { - "in_shape": (1, 10, 10), - "strides": (1, 1), - "inter_channels": (2, 2), - "inter_dilations": (2,), - }, - {"in_shape": (1, 10, 10), "strides": (1, 1), "inter_dilations": (2, 2)}, - {"in_shape": (1, 10, 10), "strides": (1, 1), "padding": (1, 1, 1)}, - {"in_shape": (1, 10, 10), "strides": (1, 2, 3)}, - {"in_shape": (1, 10, 10), "strides": (1, (1, 2, 3))}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -@pytest.fixture( - params=[ - {"in_shape": (1,), "strides": (1, 1)}, - {"in_shape": (1, 10), "strides": (1, 1)}, - ] -) -def bad_inputs_vae(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - AutoEncoderConfig(**bad_inputs) - with pytest.raises(ValidationError): - VarAutoEncoderConfig(**bad_inputs) - - -def test_fails_validations_vae(bad_inputs_vae): - with pytest.raises(ValidationError): - VarAutoEncoderConfig(**bad_inputs_vae) - - -@pytest.fixture( - params=[ - { - "in_shape": (1, 10, 10), - "strides": (1, 1), - "dropout": 0.5, - "kernel_size": 5, - "inter_channels": (2, 2), - "inter_dilations": (3, 3), - "padding": (2, 2), - }, - { - "in_shape": (1, 10, 10), - "strides": ((1, 2), 1), - "kernel_size": (3, 3), - "padding": 2, - "up_kernel_size": 5, - }, - ] -) -def good_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - AutoEncoderConfig(**good_inputs) - VarAutoEncoderConfig(**good_inputs) - - -def test_AutoEncoderConfig(): - config = AutoEncoderConfig( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=[2, 4], - strides=[1, 1], 
- kernel_size=(3, 5), - up_kernel_size=(3, 3), - num_res_units=1, - inter_channels=(2, 2), - inter_dilations=(3, 3), - num_inter_units=1, - norm=("BATCh", {"eps": 0.1}), - dropout=0.1, - bias=False, - padding=1, - ) - assert config.network == "AutoEncoder" - assert config.spatial_dims == 2 - assert config.in_channels == 1 - assert config.out_channels == 1 - assert config.channels == (2, 4) - assert config.strides == (1, 1) - assert config.kernel_size == (3, 5) - assert config.num_res_units == 1 - assert config.inter_channels == (2, 2) - assert config.inter_dilations == (3, 3) - assert config.num_inter_units == 1 - assert config.norm == ("batch", {"eps": 0.1}) - assert config.act == "DefaultFromLibrary" - assert config.dropout == 0.1 - assert not config.bias - assert config.padding == 1 - - -def test_VarAutoEncoderConfig(): - config = VarAutoEncoderConfig( - spatial_dims=2, - in_shape=(1, 10, 10), - out_channels=1, - latent_size=16, - channels=[2, 4], - strides=[1, 1], - kernel_size=(3, 5), - up_kernel_size=(3, 3), - num_res_units=1, - inter_channels=(2, 2), - inter_dilations=(3, 3), - num_inter_units=1, - norm=("BATCh", {"eps": 0.1}), - dropout=0.1, - bias=False, - padding=1, - use_sigmoid=False, - ) - assert config.network == "VarAutoEncoder" - assert config.spatial_dims == 2 - assert config.in_shape == (1, 10, 10) - assert config.out_channels == 1 - assert config.latent_size == 16 - assert config.channels == (2, 4) - assert config.strides == (1, 1) - assert config.kernel_size == (3, 5) - assert config.num_res_units == 1 - assert config.inter_channels == (2, 2) - assert config.inter_dilations == (3, 3) - assert config.num_inter_units == 1 - assert config.norm == ("batch", {"eps": 0.1}) - assert config.act == "DefaultFromLibrary" - assert config.dropout == 0.1 - assert not config.bias - assert config.padding == 1 - assert not config.use_sigmoid diff --git a/tests/unittests/monai_networks/config/test_classifier.py 
b/tests/unittests/monai_networks/config/test_classifier.py deleted file mode 100644 index f63b774d5..000000000 --- a/tests/unittests/monai_networks/config/test_classifier.py +++ /dev/null @@ -1,132 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.classifier import ( - ClassifierConfig, - CriticConfig, - DiscriminatorConfig, -) - - -@pytest.fixture -def dummy_arguments(): - args = { - "classes": 2, - "channels": [2, 4], - } - return args - - -@pytest.fixture( - params=[ - {"in_shape": (3,), "strides": (1, 1)}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "dropout": 1.1}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "kernel_size": 4}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "kernel_size": (3,)}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "kernel_size": (3, 3, 3)}, - {"in_shape": (1, 3, 3), "strides": (1, 2, 3)}, - {"in_shape": (1, 3, 3), "strides": (1, (1, 2, 3))}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - ClassifierConfig(**bad_inputs) - with pytest.raises(ValidationError): - CriticConfig(**bad_inputs) - with pytest.raises(ValidationError): - DiscriminatorConfig(**bad_inputs) - - -@pytest.fixture( - params=[ - {"in_shape": (1, 3, 3), "strides": (1, 1), "dropout": 0.5, "kernel_size": 5}, - {"in_shape": (1, 3, 3), "strides": ((1, 2), 1), "kernel_size": (3, 3)}, - ] -) -def good_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - ClassifierConfig(**good_inputs) - CriticConfig(**good_inputs) - DiscriminatorConfig(**good_inputs) - - -def test_ClassifierConfig(): - config = ClassifierConfig( - in_shape=(1, 3, 3), - classes=2, - channels=[2, 4], - strides=[1, 1], - kernel_size=(3, 5), - num_res_units=1, - act=("ELU", {"alpha": 2.0}), - dropout=0.1, - bias=False, - last_act=None, - ) - assert 
config.network == "Classifier" - assert config.in_shape == (1, 3, 3) - assert config.classes == 2 - assert config.channels == (2, 4) - assert config.strides == (1, 1) - assert config.kernel_size == (3, 5) - assert config.num_res_units == 1 - assert config.act == ("elu", {"alpha": 2.0}) - assert config.norm == "DefaultFromLibrary" - assert config.dropout == 0.1 - assert not config.bias - assert config.last_act is None - - -def test_CriticConfig(): - config = CriticConfig( - in_shape=(1, 3, 3), - channels=[2, 4], - strides=[1, 1], - kernel_size=(3, 5), - num_res_units=1, - act=("ELU", {"alpha": 2.0}), - dropout=0.1, - bias=False, - ) - assert config.network == "Critic" - assert config.in_shape == (1, 3, 3) - assert config.channels == (2, 4) - assert config.strides == (1, 1) - assert config.kernel_size == (3, 5) - assert config.num_res_units == 1 - assert config.act == ("elu", {"alpha": 2.0}) - assert config.norm == "DefaultFromLibrary" - assert config.dropout == 0.1 - assert not config.bias - - -def test_DiscriminatorConfig(): - config = DiscriminatorConfig( - in_shape=(1, 3, 3), - channels=[2, 4], - strides=[1, 1], - kernel_size=(3, 5), - num_res_units=1, - act=("ELU", {"alpha": 2.0}), - dropout=0.1, - bias=False, - last_act=("eLu", {"alpha": 0.5}), - ) - assert config.network == "Discriminator" - assert config.in_shape == (1, 3, 3) - assert config.channels == (2, 4) - assert config.strides == (1, 1) - assert config.kernel_size == (3, 5) - assert config.num_res_units == 1 - assert config.act == ("elu", {"alpha": 2.0}) - assert config.norm == "DefaultFromLibrary" - assert config.dropout == 0.1 - assert not config.bias - assert config.last_act == ("elu", {"alpha": 0.5}) diff --git a/tests/unittests/monai_networks/config/test_config.py b/tests/unittests/monai_networks/config/test_config.py new file mode 100644 index 000000000..f4ef7c65e --- /dev/null +++ b/tests/unittests/monai_networks/config/test_config.py @@ -0,0 +1,232 @@ +import pytest + +from 
clinicadl.networks.config.densenet import ( + DenseNet121Config, + DenseNet161Config, + DenseNet169Config, + DenseNet201Config, +) +from clinicadl.networks.config.resnet import ( + ResNet18Config, + ResNet34Config, + ResNet50Config, + ResNet101Config, + ResNet152Config, +) +from clinicadl.networks.config.senet import ( + SEResNet50Config, + SEResNet101Config, + SEResNet152Config, +) +from clinicadl.networks.config.vit import ( + ViTB16Config, + ViTB32Config, + ViTL16Config, + ViTL32Config, +) + + +@pytest.mark.parametrize( + "config_class", + [DenseNet121Config, DenseNet161Config, DenseNet169Config, DenseNet201Config], +) +def test_sota_densenet_config(config_class): + config = config_class(pretrained=True, num_outputs=None) + + assert config.num_outputs is None + assert config.pretrained + assert config.output_act == "DefaultFromLibrary" + assert config._type == "sota-DenseNet" + + +@pytest.mark.parametrize( + "config_class", + [ResNet18Config, ResNet34Config, ResNet50Config, ResNet101Config, ResNet152Config], +) +def test_sota_resnet_config(config_class): + config = config_class(pretrained=False, num_outputs=None) + + assert config.num_outputs is None + assert not config.pretrained + assert config.output_act == "DefaultFromLibrary" + assert config._type == "sota-ResNet" + + +@pytest.mark.parametrize( + "config_class", [SEResNet50Config, SEResNet101Config, SEResNet152Config] +) +def test_sota_senet_config(config_class): + config = config_class(output_act="relu", num_outputs=1) + + assert config.num_outputs == 1 + assert config.pretrained == "DefaultFromLibrary" + assert config.output_act == "relu" + assert config._type == "sota-SEResNet" + + +@pytest.mark.parametrize( + "config_class", [ViTB16Config, ViTB32Config, ViTL16Config, ViTL32Config] +) +def test_sota_vit_config(config_class): + config = config_class(output_act="relu", num_outputs=1) + + assert config.num_outputs == 1 + assert config.pretrained == "DefaultFromLibrary" + assert config.output_act == "relu" + 
assert config._type == "sota-ViT" + + +def test_autoencoder_config(): + from clinicadl.networks.config.autoencoder import AutoEncoderConfig + + config = AutoEncoderConfig( + in_shape=(1, 10, 10), + latent_size=1, + conv_args={"channels": [1]}, + output_act="softmax", + ) + assert config.in_shape == (1, 10, 10) + assert config.conv_args.channels == [1] + assert config.output_act == "softmax" + assert config.out_channels == "DefaultFromLibrary" + + +def test_vae_config(): + from clinicadl.networks.config.autoencoder import VAEConfig + + config = VAEConfig( + in_shape=(1, 10), + latent_size=1, + conv_args={"channels": [1], "adn_ordering": "NA"}, + output_act=("elu", {"alpha": 0.1}), + ) + assert config.in_shape == (1, 10) + assert config.conv_args.adn_ordering == "NA" + assert config.output_act == ("elu", {"alpha": 0.1}) + assert config.mlp_args == "DefaultFromLibrary" + + +def test_cnn_config(): + from clinicadl.networks.config.cnn import CNNConfig + + config = CNNConfig( + in_shape=(2, 10, 10, 10), num_outputs=1, conv_args={"channels": [1]} + ) + assert config.in_shape == (2, 10, 10, 10) + assert config.conv_args.channels == [1] + assert config.mlp_args == "DefaultFromLibrary" + + +def test_generator_config(): + from clinicadl.networks.config.generator import GeneratorConfig + + config = GeneratorConfig( + start_shape=(2, 10, 10), latent_size=2, conv_args={"channels": [1]} + ) + assert config.start_shape == (2, 10, 10) + assert config.conv_args.channels == [1] + assert config.mlp_args == "DefaultFromLibrary" + + +def test_conv_decoder_config(): + from clinicadl.networks.config.conv_decoder import ConvDecoderConfig + + config = ConvDecoderConfig( + in_channels=1, spatial_dims=2, channels=[1, 2], kernel_size=(3, 4) + ) + assert config.in_channels == 1 + assert config.kernel_size == (3, 4) + assert config.stride == "DefaultFromLibrary" + + +def test_conv_encoder_config(): + from clinicadl.networks.config.conv_encoder import ConvEncoderConfig + + config = 
ConvEncoderConfig( + in_channels=1, spatial_dims=2, channels=[1, 2], kernel_size=[(3, 4), (4, 5)] + ) + assert config.in_channels == 1 + assert config.kernel_size == [(3, 4), (4, 5)] + assert config.padding == "DefaultFromLibrary" + + +def test_mlp_config(): + from clinicadl.networks.config.mlp import MLPConfig + + config = MLPConfig( + in_channels=1, out_channels=1, hidden_channels=[2, 3], dropout=0.1 + ) + assert config.in_channels == 1 + assert config.dropout == 0.1 + assert config.act == "DefaultFromLibrary" + + +def test_resnet_config(): + from clinicadl.networks.config.resnet import ResNetConfig + + config = ResNetConfig( + spatial_dims=1, in_channels=1, num_outputs=None, block_type="bottleneck" + ) + assert config.num_outputs is None + assert config.block_type == "bottleneck" + assert config.bottleneck_reduction == "DefaultFromLibrary" + + +def test_seresnet_config(): + from clinicadl.networks.config.senet import SEResNetConfig + + config = SEResNetConfig( + spatial_dims=1, + in_channels=1, + num_outputs=None, + block_type="bottleneck", + se_reduction=2, + ) + assert config.num_outputs is None + assert config.block_type == "bottleneck" + assert config.se_reduction == 2 + assert config.bottleneck_reduction == "DefaultFromLibrary" + + +def test_densenet_config(): + from clinicadl.networks.config.densenet import DenseNetConfig + + config = DenseNetConfig( + spatial_dims=1, in_channels=1, num_outputs=2, n_dense_layers=(1, 2) + ) + assert config.num_outputs == 2 + assert config.n_dense_layers == (1, 2) + assert config.growth_rate == "DefaultFromLibrary" + + +def test_vit_config(): + from clinicadl.networks.config.vit import ViTConfig + + config = ViTConfig(in_shape=(1, 10), patch_size=2, num_outputs=1, embedding_dim=42) + assert config.num_outputs == 1 + assert config.embedding_dim == 42 + assert config.mlp_dim == "DefaultFromLibrary" + + +def test_unet_config(): + from clinicadl.networks.config.unet import UNetConfig + + config = UNetConfig(spatial_dims=1, 
in_channels=1, out_channels=1, channels=(4, 8)) + assert config.out_channels == 1 + assert config.channels == (4, 8) + assert config.output_act == "DefaultFromLibrary" + + +def test_att_unet_config(): + from clinicadl.networks.config.unet import AttentionUNetConfig + + config = AttentionUNetConfig( + spatial_dims=1, + in_channels=1, + out_channels=1, + channels=(4, 8), + output_act="softmax", + ) + assert config.spatial_dims == 1 + assert config.output_act == "softmax" + assert config.dropout == "DefaultFromLibrary" diff --git a/tests/unittests/monai_networks/config/test_densenet.py b/tests/unittests/monai_networks/config/test_densenet.py deleted file mode 100644 index a18b86f09..000000000 --- a/tests/unittests/monai_networks/config/test_densenet.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.densenet import DenseNetConfig - - -@pytest.fixture -def dummy_arguments(): - args = { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - } - return args - - -def test_fails_validations(dummy_arguments): - with pytest.raises(ValidationError): - DenseNetConfig(**{**dummy_arguments, **{"dropout_prob": 1.1}}) - - -def test_passes_validations(dummy_arguments): - DenseNetConfig(**{**dummy_arguments, **{"dropout_prob": 0.1}}) - - -def test_DenseNetConfig(): - config = DenseNetConfig( - spatial_dims=2, - in_channels=1, - out_channels=1, - init_features=16, - growth_rate=2, - block_config=(3, 5), - bn_size=1, - norm=("batch", {"eps": 0.5}), - dropout_prob=0.1, - ) - assert config.network == "DenseNet" - assert config.spatial_dims == 2 - assert config.in_channels == 1 - assert config.out_channels == 1 - assert config.init_features == 16 - assert config.growth_rate == 2 - assert config.block_config == (3, 5) - assert config.bn_size == 1 - assert config.norm == ("batch", {"eps": 0.5}) - assert config.act == "DefaultFromLibrary" - assert config.dropout_prob diff --git 
a/tests/unittests/monai_networks/config/test_factory.py b/tests/unittests/monai_networks/config/test_factory.py index 07c96e2a9..3f91b52e1 100644 --- a/tests/unittests/monai_networks/config/test_factory.py +++ b/tests/unittests/monai_networks/config/test_factory.py @@ -1,4 +1,4 @@ -from clinicadl.monai_networks.config import ImplementedNetworks, create_network_config +from clinicadl.networks.config import ImplementedNetworks, create_network_config def test_create_training_config(): @@ -9,9 +9,9 @@ def test_create_training_config(): config = config_class( spatial_dims=1, in_channels=2, - out_channels=3, + num_outputs=None, ) - assert config.network == "DenseNet" + assert config.name == "DenseNet" assert config.spatial_dims == 1 assert config.in_channels == 2 - assert config.out_channels == 3 + assert config.num_outputs is None diff --git a/tests/unittests/monai_networks/config/test_fcn.py b/tests/unittests/monai_networks/config/test_fcn.py deleted file mode 100644 index b7991368e..000000000 --- a/tests/unittests/monai_networks/config/test_fcn.py +++ /dev/null @@ -1,97 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.fcn import ( - FullyConnectedNetConfig, - VarFullyConnectedNetConfig, -) - - -@pytest.fixture -def dummy_arguments(): - args = { - "in_channels": 5, - "out_channels": 1, - "hidden_channels": [3, 2], - "latent_size": 16, - "encode_channels": [2, 3], - "decode_channels": [3, 2], - } - return args - - -@pytest.fixture( - params=[ - {"dropout": 1.1}, - {"adn_ordering": "NDB"}, - {"adn_ordering": "NND"}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - FullyConnectedNetConfig(**bad_inputs) - with pytest.raises(ValidationError): - VarFullyConnectedNetConfig(**bad_inputs) - - -@pytest.fixture( - params=[ - {"dropout": 0.5, "adn_ordering": "DAN"}, - {"adn_ordering": "AN"}, - ] -) 
-def good_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - FullyConnectedNetConfig(**good_inputs) - VarFullyConnectedNetConfig(**good_inputs) - - -def test_FullyConnectedNetConfig(): - config = FullyConnectedNetConfig( - in_channels=5, - out_channels=1, - hidden_channels=[3, 2], - dropout=None, - act="prelu", - bias=False, - adn_ordering="ADN", - ) - assert config.network == "FullyConnectedNet" - assert config.in_channels == 5 - assert config.out_channels == 1 - assert config.hidden_channels == (3, 2) - assert config.dropout is None - assert config.act == "prelu" - assert not config.bias - assert config.adn_ordering == "ADN" - - -def test_VarFullyConnectedNetConfig(): - config = VarFullyConnectedNetConfig( - in_channels=5, - out_channels=1, - latent_size=16, - encode_channels=[2, 3], - decode_channels=[3, 2], - dropout=0.1, - act="prelu", - bias=False, - adn_ordering="ADN", - ) - assert config.network == "VarFullyConnectedNet" - assert config.in_channels == 5 - assert config.out_channels == 1 - assert config.latent_size == 16 - assert config.encode_channels == (2, 3) - assert config.decode_channels == (3, 2) - assert config.dropout == 0.1 - assert config.act == "prelu" - assert not config.bias - assert config.adn_ordering == "ADN" diff --git a/tests/unittests/monai_networks/config/test_generator.py b/tests/unittests/monai_networks/config/test_generator.py deleted file mode 100644 index 9ea1cd442..000000000 --- a/tests/unittests/monai_networks/config/test_generator.py +++ /dev/null @@ -1,72 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.generator import GeneratorConfig - - -@pytest.fixture -def dummy_arguments(): - args = {"latent_shape": (5,), "channels": (2, 4)} - return args - - -@pytest.fixture( - params=[ - {"start_shape": (3,), "strides": (1, 1)}, - {"start_shape": (1, 3), "strides": (1, 1), "dropout": 1.1}, - {"start_shape": 
(1, 3), "strides": (1, 1), "kernel_size": 4}, - {"start_shape": (1, 3), "strides": (1, 1), "kernel_size": (3, 3)}, - {"start_shape": (1, 3), "strides": (1, 2, 3)}, - {"start_shape": (1, 3), "strides": (1, (1, 2))}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - GeneratorConfig(**bad_inputs) - - -@pytest.fixture( - params=[ - {"start_shape": (1, 3), "strides": (1, 1), "dropout": 0.5, "kernel_size": 5}, - { - "start_shape": (1, 3, 3, 3), - "strides": ((1, 2, 3), 1), - "kernel_size": (3, 3, 3), - }, - ] -) -def good_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - GeneratorConfig(**good_inputs) - - -def test_GeneratorConfig(): - config = GeneratorConfig( - latent_shape=(3,), - start_shape=(1, 3), - channels=[2, 4], - strides=[1, 1], - kernel_size=(3,), - num_res_units=1, - act="SIGMOID", - dropout=0.1, - bias=False, - ) - assert config.network == "Generator" - assert config.latent_shape == (3,) - assert config.start_shape == (1, 3) - assert config.channels == (2, 4) - assert config.strides == (1, 1) - assert config.kernel_size == (3,) - assert config.num_res_units == 1 - assert config.act == "sigmoid" - assert config.norm == "DefaultFromLibrary" - assert config.dropout == 0.1 - assert not config.bias diff --git a/tests/unittests/monai_networks/config/test_regressor.py b/tests/unittests/monai_networks/config/test_regressor.py deleted file mode 100644 index 920464cc2..000000000 --- a/tests/unittests/monai_networks/config/test_regressor.py +++ /dev/null @@ -1,72 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.regressor import RegressorConfig - - -@pytest.fixture -def dummy_arguments(): - args = { - "out_shape": (1,), - "channels": [2, 4], - } - return args - - -@pytest.fixture( - params=[ - {"in_shape": 
(3,), "strides": (1, 1)}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "dropout": 1.1}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "kernel_size": 4}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "kernel_size": (3,)}, - {"in_shape": (1, 3, 3), "strides": (1, 1), "kernel_size": (3, 3, 3)}, - {"in_shape": (1, 3, 3), "strides": (1, 2, 3)}, - {"in_shape": (1, 3, 3), "strides": (1, (1, 2, 3))}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - RegressorConfig(**bad_inputs) - - -@pytest.fixture( - params=[ - {"in_shape": (1, 3, 3), "strides": (1, 1), "dropout": 0.5, "kernel_size": 5}, - {"in_shape": (1, 3, 3), "strides": ((1, 2), 1), "kernel_size": (3, 3)}, - ] -) -def good_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - RegressorConfig(**good_inputs) - - -def test_RegressorConfig(): - config = RegressorConfig( - in_shape=(1, 3, 3), - out_shape=(1,), - channels=[2, 4], - strides=[1, 1], - kernel_size=(3, 5), - num_res_units=1, - act=("ELU", {"alpha": 2.0}), - dropout=0.1, - bias=False, - ) - assert config.network == "Regressor" - assert config.in_shape == (1, 3, 3) - assert config.out_shape == (1,) - assert config.channels == (2, 4) - assert config.strides == (1, 1) - assert config.kernel_size == (3, 5) - assert config.num_res_units == 1 - assert config.act == ("elu", {"alpha": 2.0}) - assert config.norm == "DefaultFromLibrary" - assert config.dropout == 0.1 - assert not config.bias diff --git a/tests/unittests/monai_networks/config/test_resnet.py b/tests/unittests/monai_networks/config/test_resnet.py deleted file mode 100644 index b238a3c93..000000000 --- a/tests/unittests/monai_networks/config/test_resnet.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.resnet import ResNetConfig - 
- -@pytest.fixture -def dummy_arguments(): - args = { - "block": "basic", - "layers": (2, 2, 2, 2), - } - return args - - -@pytest.fixture( - params=[ - {"block_inplanes": (2, 4, 8)}, - {"block_inplanes": (2, 4, 8, 16), "conv1_t_size": (3, 3)}, - {"block_inplanes": (2, 4, 8, 16), "conv1_t_stride": (3, 3)}, - {"block_inplanes": (2, 4, 8, 16), "shortcut_type": "C"}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - ResNetConfig(**bad_inputs) - - -@pytest.fixture( - params=[ - { - "block_inplanes": (2, 4, 8, 16), - "conv1_t_size": (3, 3, 3), - "conv1_t_stride": (3, 3, 3), - "shortcut_type": "B", - }, - {"block_inplanes": (2, 4, 8, 16), "conv1_t_size": 3, "conv1_t_stride": 3}, - ] -) -def good_inputs(request: pytest.FixtureRequest, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - ResNetConfig(**good_inputs) - - -def test_ResNetConfig(): - config = ResNetConfig( - block="bottleneck", - layers=(2, 2, 2, 2), - block_inplanes=(2, 4, 8, 16), - spatial_dims=3, - n_input_channels=3, - conv1_t_size=3, - conv1_t_stride=4, - no_max_pool=True, - shortcut_type="A", - widen_factor=0.8, - num_classes=3, - feed_forward=False, - bias_downsample=False, - act=("relu", {"inplace": False}), - ) - assert config.network == "ResNet" - assert config.block == "bottleneck" - assert config.layers == (2, 2, 2, 2) - assert config.block_inplanes == (2, 4, 8, 16) - assert config.spatial_dims == 3 - assert config.n_input_channels == 3 - assert config.conv1_t_size == 3 - assert config.conv1_t_stride == 4 - assert config.no_max_pool - assert config.shortcut_type == "A" - assert config.widen_factor == 0.8 - assert config.num_classes == 3 - assert not config.feed_forward - assert not config.bias_downsample - assert config.act == ("relu", {"inplace": False}) diff --git 
a/tests/unittests/monai_networks/config/test_resnet_features.py b/tests/unittests/monai_networks/config/test_resnet_features.py deleted file mode 100644 index 9f6131974..000000000 --- a/tests/unittests/monai_networks/config/test_resnet_features.py +++ /dev/null @@ -1,56 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.resnet import ResNetFeaturesConfig - - -@pytest.fixture( - params=[ - {"model_name": "abc"}, - {"model_name": "resnet18", "pretrained": True, "spatial_dims": 2}, - {"model_name": "resnet18", "pretrained": True, "in_channels": 2}, - { - "model_name": "resnet18", - "in_channels": 2, - }, # pretrained should be set to False - {"model_name": "resnet18", "spatial_dims": 2}, - ] -) -def bad_inputs(request: pytest.FixtureRequest): - return request.param - - -def test_fails_validations(bad_inputs: dict): - with pytest.raises(ValidationError): - ResNetFeaturesConfig(**bad_inputs) - - -@pytest.fixture( - params=[ - {"model_name": "resnet18", "pretrained": True, "spatial_dims": 3}, - {"model_name": "resnet18", "pretrained": True, "in_channels": 1}, - {"model_name": "resnet18", "pretrained": True}, - {"model_name": "resnet18", "spatial_dims": 3}, - {"model_name": "resnet18", "in_channels": 1}, - ] -) -def good_inputs(request: pytest.FixtureRequest): - return {**request.param} - - -def test_passes_validations(good_inputs: dict): - ResNetFeaturesConfig(**good_inputs) - - -def test_ResNetFeaturesConfig(): - config = ResNetFeaturesConfig( - model_name="resnet200", - pretrained=False, - spatial_dims=2, - in_channels=2, - ) - assert config.network == "ResNetFeatures" - assert config.model_name == "resnet200" - assert not config.pretrained - assert config.spatial_dims == 2 - assert config.in_channels == 2 diff --git a/tests/unittests/monai_networks/config/test_segresnet.py b/tests/unittests/monai_networks/config/test_segresnet.py deleted file mode 100644 index 44b946d49..000000000 --- 
a/tests/unittests/monai_networks/config/test_segresnet.py +++ /dev/null @@ -1,41 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.resnet import SegResNetConfig - - -def test_fails_validations(): - with pytest.raises(ValidationError): - SegResNetConfig(dropout_prob=1.1) - - -def test_passes_validations(): - SegResNetConfig(dropout_prob=0.5) - - -def test_SegResNetConfig(): - config = SegResNetConfig( - spatial_dims=2, - init_filters=3, - in_channels=1, - out_channels=1, - dropout_prob=0.1, - act=("ELU", {"inplace": False}), - norm=("group", {"num_groups": 4}), - use_conv_final=False, - blocks_down=[1, 2, 3], - blocks_up=[3, 2, 1], - upsample_mode="pixelshuffle", - ) - assert config.network == "SegResNet" - assert config.spatial_dims == 2 - assert config.init_filters == 3 - assert config.in_channels == 1 - assert config.out_channels == 1 - assert config.dropout_prob == 0.1 - assert config.act == ("elu", {"inplace": False}) - assert config.norm == ("group", {"num_groups": 4}) - assert not config.use_conv_final - assert config.blocks_down == (1, 2, 3) - assert config.blocks_up == (3, 2, 1) - assert config.upsample_mode == "pixelshuffle" diff --git a/tests/unittests/monai_networks/config/test_unet.py b/tests/unittests/monai_networks/config/test_unet.py deleted file mode 100644 index d331e0a14..000000000 --- a/tests/unittests/monai_networks/config/test_unet.py +++ /dev/null @@ -1,133 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.unet import AttentionUnetConfig, UNetConfig - - -@pytest.fixture -def dummy_arguments(): - args = { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - } - return args - - -@pytest.fixture( - params=[ - {"strides": (1, 1), "channels": (2, 4, 8), "adn_ordering": "NDB"}, - {"strides": (1, 1), "channels": (2, 4, 8), "adn_ordering": "NND"}, - {"strides": (1, 1), "channels": (2, 4, 8), "dropout": 1.1}, - {"strides": (1, 1), 
"channels": (2, 4, 8), "kernel_size": 4}, - {"strides": (1, 1), "channels": (2, 4, 8), "kernel_size": (3,)}, - {"strides": (1, 1), "channels": (2, 4, 8), "kernel_size": (3, 3, 3)}, - {"strides": (1, 1), "channels": (2, 4, 8), "up_kernel_size": 4}, - {"strides": (1, 1), "channels": (2, 4, 8), "up_kernel_size": (3,)}, - {"strides": (1, 1), "channels": (2, 4, 8), "up_kernel_size": (3, 3, 3)}, - {"strides": (1, 2, 3), "channels": (2, 4, 8)}, - {"strides": (1, (1, 2, 3)), "channels": (2, 4, 8)}, - {"strides": (), "channels": (2,)}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - UNetConfig(**bad_inputs) - with pytest.raises(ValidationError): - AttentionUnetConfig(**bad_inputs) - - -@pytest.fixture( - params=[ - { - "strides": (1, 1), - "channels": (2, 4, 8), - "adn_ordering": "DAN", - "dropout": 0.5, - "kernel_size": 5, - "up_kernel_size": 5, - }, - { - "strides": ((1, 2),), - "channels": (2, 4), - "adn_ordering": "AN", - "kernel_size": (3, 5), - "up_kernel_size": (3, 5), - }, - ] -) -def good_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - UNetConfig(**good_inputs) - AttentionUnetConfig(**good_inputs) - - -def test_UNetConfig(): - config = UNetConfig( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=[2, 4], - strides=[1], - kernel_size=(3, 5), - up_kernel_size=(3, 3), - num_res_units=1, - act="ElU", - norm=("BATCh", {"eps": 0.1}), - dropout=0.1, - bias=False, - adn_ordering="A", - ) - assert config.network == "UNet" - assert config.spatial_dims == 2 - assert config.in_channels == 1 - assert config.out_channels == 1 - assert config.channels == (2, 4) - assert config.strides == (1,) - assert config.kernel_size == (3, 5) - assert config.up_kernel_size == (3, 3) - assert config.num_res_units == 1 - assert config.act == "elu" - assert config.norm 
== ("batch", {"eps": 0.1}) - assert config.dropout == 0.1 - assert not config.bias - assert config.adn_ordering == "A" - - -def test_AttentionUnetConfig(): - config = AttentionUnetConfig( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=[2, 4], - strides=[1], - kernel_size=(3, 5), - up_kernel_size=(3, 3), - num_res_units=1, - act="ElU", - norm="inSTance", - dropout=0.1, - bias=False, - adn_ordering="DA", - ) - assert config.network == "AttentionUnet" - assert config.spatial_dims == 2 - assert config.in_channels == 1 - assert config.out_channels == 1 - assert config.channels == (2, 4) - assert config.strides == (1,) - assert config.kernel_size == (3, 5) - assert config.up_kernel_size == (3, 3) - assert config.num_res_units == 1 - assert config.act == "elu" - assert config.norm == "instance" - assert config.dropout == 0.1 - assert not config.bias - assert config.adn_ordering == "DA" diff --git a/tests/unittests/monai_networks/config/test_vit.py b/tests/unittests/monai_networks/config/test_vit.py deleted file mode 100644 index 737caf05e..000000000 --- a/tests/unittests/monai_networks/config/test_vit.py +++ /dev/null @@ -1,162 +0,0 @@ -import pytest -from pydantic import ValidationError - -from clinicadl.monai_networks.config.vit import ( - ViTAutoEncConfig, - ViTConfig, -) - - -@pytest.fixture -def dummy_arguments(): - args = { - "in_channels": 2, - } - return args - - -@pytest.fixture( - params=[ - {"img_size": (16, 16, 16), "patch_size": (4, 4, 4), "dropout_rate": 1.1}, - {"img_size": (16, 16), "patch_size": 4}, - {"img_size": 16, "patch_size": (4, 4)}, - {"img_size": 16, "patch_size": (4, 4)}, - { - "img_size": (16, 16, 16), - "patch_size": (4, 4, 4), - "hidden_size": 42, - "num_heads": 5, - }, - {"img_size": (16, 16, 16), "patch_size": (4, 4, 4), "num_heads": 5}, - {"img_size": (16, 16, 16), "patch_size": (4, 4, 4), "hidden_size": 42}, - ] -) -def bad_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - 
-@pytest.fixture( - params=[ - {"img_size": (20, 20, 20), "patch_size": (4, 4, 5)}, - {"img_size": (20, 20, 20), "patch_size": (4, 4, 9)}, - ] -) -def bad_inputs_ae(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_fails_validations(bad_inputs): - with pytest.raises(ValidationError): - ViTConfig(**bad_inputs) - with pytest.raises(ValidationError): - ViTAutoEncConfig(**bad_inputs) - - -def test_fails_validations_ae(bad_inputs_ae): - with pytest.raises(ValidationError): - ViTAutoEncConfig(**bad_inputs_ae) - - -@pytest.fixture( - params=[ - { - "img_size": (16, 16, 16), - "patch_size": (4, 4, 4), - "dropout_rate": 0.5, - "hidden_size": 42, - "num_heads": 6, - }, - ] -) -def good_inputs(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -@pytest.fixture( - params=[ - {"img_size": 10, "patch_size": 3}, - ] -) -def good_inputs_vit(request, dummy_arguments): - return {**dummy_arguments, **request.param} - - -def test_passes_validations(good_inputs): - ViTConfig(**good_inputs) - ViTAutoEncConfig(**good_inputs) - - -def test_passes_validations_vit(good_inputs_vit): - ViTConfig(**good_inputs_vit) - - -def test_ViTConfig(): - config = ViTConfig( - in_channels=2, - img_size=16, - patch_size=4, - hidden_size=32, - mlp_dim=4, - num_layers=3, - num_heads=4, - proj_type="perceptron", - pos_embed_type="sincos", - classification=True, - num_classes=3, - dropout_rate=0.1, - spatial_dims=3, - post_activation=None, - qkv_bias=True, - ) - assert config.network == "ViT" - assert config.in_channels == 2 - assert config.img_size == 16 - assert config.patch_size == 4 - assert config.hidden_size == 32 - assert config.mlp_dim == 4 - assert config.num_layers == 3 - assert config.num_heads == 4 - assert config.proj_type == "perceptron" - assert config.pos_embed_type == "sincos" - assert config.classification - assert config.num_classes == 3 - assert config.dropout_rate == 0.1 - assert config.spatial_dims == 3 - assert 
config.post_activation is None - assert config.qkv_bias - assert config.save_attn == "DefaultFromLibrary" - - -def test_ViTAutoEncConfig(): - config = ViTAutoEncConfig( - in_channels=2, - img_size=16, - patch_size=4, - out_channels=2, - deconv_chns=7, - hidden_size=32, - mlp_dim=4, - num_layers=3, - num_heads=4, - proj_type="perceptron", - pos_embed_type="sincos", - dropout_rate=0.1, - spatial_dims=3, - qkv_bias=True, - ) - assert config.network == "ViTAutoEnc" - assert config.in_channels == 2 - assert config.img_size == 16 - assert config.patch_size == 4 - assert config.out_channels == 2 - assert config.deconv_chns == 7 - assert config.hidden_size == 32 - assert config.mlp_dim == 4 - assert config.num_layers == 3 - assert config.num_heads == 4 - assert config.proj_type == "perceptron" - assert config.pos_embed_type == "sincos" - assert config.dropout_rate == 0.1 - assert config.spatial_dims == 3 - assert config.qkv_bias - assert config.save_attn == "DefaultFromLibrary" diff --git a/tests/unittests/monai_networks/nn/__init__.py b/tests/unittests/monai_networks/nn/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unittests/monai_networks/nn/test_att_unet.py b/tests/unittests/monai_networks/nn/test_att_unet.py new file mode 100644 index 000000000..6f5786828 --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_att_unet.py @@ -0,0 +1,134 @@ +import pytest +import torch + +from clinicadl.networks.nn import AttentionUNet +from clinicadl.networks.nn.layers.utils import ActFunction + +INPUT_1D = torch.randn(2, 1, 16) +INPUT_2D = torch.randn(2, 2, 32, 64) +INPUT_3D = torch.randn(2, 3, 16, 32, 8) + + +@pytest.mark.parametrize( + "input_tensor,out_channels,channels,act,output_act,dropout,error", + [ + (INPUT_1D, 1, (2, 3, 4), "relu", "sigmoid", None, False), + (INPUT_2D, 1, (2, 4, 5), "relu", None, 0.0, False), + (INPUT_3D, 2, (2, 3), None, ("softmax", {"dim": 1}), 0.1, False), + ( + INPUT_3D, + 2, + (2,), + None, + ("softmax", {"dim": 1}), 
+ 0.1, + True, + ), # channels length is less than 2 + ], +) +def test_attentionunet( + input_tensor, out_channels, channels, act, output_act, dropout, error +): + batch_size, in_channels, *img_size = input_tensor.shape + spatial_dims = len(img_size) + if error: + with pytest.raises(ValueError): + AttentionUNet( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + channels=channels, + act=act, + output_act=output_act, + dropout=dropout, + ) + else: + net = AttentionUNet( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + channels=channels, + act=act, + output_act=output_act, + dropout=dropout, + ) + + out = net(input_tensor) + assert out.shape == (batch_size, out_channels, *img_size) + + if output_act: + assert net.output_act is not None + else: + assert net.output_act is None + + assert net.doubleconv[1].conv.out_channels == channels[0] + if dropout: + assert net.doubleconv[1].adn.D.p == dropout + else: + with pytest.raises(AttributeError): + net.doubleconv[1].conv.adn.D + + for i in range(1, len(channels)): + down = getattr(net, f"down{i}").doubleconv + up = getattr(net, f"doubleconv{i}") + att = getattr(net, f"attention{i}") + assert down[0].conv.in_channels == channels[i - 1] + assert down[1].conv.out_channels == channels[i] + assert att.W_g[0].out_channels == channels[i - 1] // 2 + assert att.W_x[0].out_channels == channels[i - 1] // 2 + assert up[0].conv.in_channels == channels[i - 1] * 2 + assert up[1].conv.out_channels == channels[i - 1] + for m in (down, up): + if dropout is not None: + assert m[1].adn.D.p == dropout + else: + with pytest.raises(AttributeError): + m[1].adn.D + with pytest.raises(AttributeError): + down = getattr(net, f"down{i+1}") + with pytest.raises(AttributeError): + getattr(net, f"doubleconv{i+1}") + with pytest.raises(AttributeError): + getattr(net, f"attention{i+1}") + + +@pytest.mark.parametrize("act", [act for act in ActFunction]) +def test_activations(act): + 
batch_size, in_channels, *img_size = INPUT_2D.shape + net = AttentionUNet( + spatial_dims=2, + in_channels=in_channels, + out_channels=2, + channels=(2, 4), + act=act, + output_act=act, + ) + assert net(INPUT_2D).shape == (batch_size, 2, *img_size) + + +def test_activation_parameters(): + in_channels = INPUT_2D.shape[1] + act = ("ELU", {"alpha": 0.1}) + output_act = ("ELU", {"alpha": 0.2}) + net = AttentionUNet( + spatial_dims=2, + in_channels=in_channels, + out_channels=2, + channels=(2, 4), + act=act, + output_act=output_act, + ) + assert isinstance(net.doubleconv[0].adn.A, torch.nn.ELU) + assert net.doubleconv[0].adn.A.alpha == 0.1 + + assert isinstance(net.down1.doubleconv[0].adn.A, torch.nn.ELU) + assert net.down1.doubleconv[0].adn.A.alpha == 0.1 + + assert isinstance(net.upsample1[1].adn.A, torch.nn.ELU) + assert net.upsample1[1].adn.A.alpha == 0.1 + + assert isinstance(net.doubleconv1[1].adn.A, torch.nn.ELU) + assert net.doubleconv1[1].adn.A.alpha == 0.1 + + assert isinstance(net.output_act, torch.nn.ELU) + assert net.output_act.alpha == 0.2 diff --git a/tests/unittests/monai_networks/nn/test_autoencoder.py b/tests/unittests/monai_networks/nn/test_autoencoder.py new file mode 100644 index 000000000..c3d54b458 --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_autoencoder.py @@ -0,0 +1,215 @@ +import pytest +import torch +from torch.nn import GELU, Sigmoid, Tanh + +from clinicadl.networks.nn import AutoEncoder +from clinicadl.networks.nn.layers.utils import ActFunction + + +@pytest.mark.parametrize( + "input_tensor,kernel_size,stride,padding,dilation,pooling,pooling_indices,unpooling_mode", + [ + (torch.randn(2, 1, 21), 3, 1, 0, 1, ("max", {"kernel_size": 2}), [0], "linear"), + ( + torch.randn(2, 1, 65, 85), + (3, 5), + (1, 2), + 0, + (1, 2), + ("max", {"kernel_size": 2, "stride": 1}), + [0], + "bilinear", + ), + ( + torch.randn(2, 1, 64, 62, 61), # to test output padding + 4, + 2, + (1, 1, 0), + 1, + ("avg", {"kernel_size": 3, "stride": 2}), + [-1], 
+ "convtranspose", + ), + ( + torch.randn(2, 1, 51, 55, 45), + 3, + 2, + 0, + 1, + ("max", {"kernel_size": 2, "ceil_mode": True}), + [0, 1, 2], + "convtranspose", + ), + ( + torch.randn(2, 1, 51, 55, 45), + 3, + 2, + 0, + 1, + [ + ("max", {"kernel_size": 2, "ceil_mode": True}), + ("adaptivemax", {"output_size": (5, 4, 3)}), + ], + [-1, 1], + "convtranspose", + ), + ], +) +def test_output_shape( + input_tensor, + kernel_size, + stride, + padding, + dilation, + pooling, + pooling_indices, + unpooling_mode, +): + net = AutoEncoder( + in_shape=input_tensor.shape[1:], + latent_size=3, + conv_args={ + "channels": [2, 4, 8], + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + "pooling": pooling, + "pooling_indices": pooling_indices, + }, + unpooling_mode=unpooling_mode, + ) + output = net(input_tensor) + assert output.shape == input_tensor.shape + + +def test_out_channels(): + input_tensor = torch.randn(2, 1, 64, 62, 61) + net = AutoEncoder( + in_shape=input_tensor.shape[1:], + latent_size=3, + conv_args={"channels": [2, 4, 8]}, + mlp_args={"hidden_channels": [8, 4]}, + out_channels=3, + ) + assert net(input_tensor).shape == (2, 3, 64, 62, 61) + assert net.decoder.convolutions.layer2.conv.in_channels == 2 + assert net.decoder.convolutions.layer2.conv.out_channels == 3 + + +@pytest.mark.parametrize( + "pooling,unpooling_mode", + [ + (("adaptivemax", {"output_size": (17, 16, 19)}), "nearest"), + (("adaptivemax", {"output_size": (17, 16, 19)}), "convtranspose"), + (("max", {"kernel_size": 2}), "nearest"), + (("max", {"kernel_size": 2}), "convtranspose"), + ( + ("max", {"kernel_size": 2, "stride": 1, "dilation": 2, "padding": 1}), + "nearest", + ), + ( + ("max", {"kernel_size": 2, "stride": 1, "dilation": 2, "padding": 1}), + "convtranspose", + ), + (("avg", {"kernel_size": 3, "ceil_mode": True}), "nearest"), + (("avg", {"kernel_size": 3, "ceil_mode": True}), "convtranspose"), + ], +) +def test_invert_pooling(pooling, 
unpooling_mode): + input_tensor = torch.randn(2, 1, 20, 27, 22) + net = AutoEncoder( + in_shape=(1, 20, 27, 22), + latent_size=1, + conv_args={"channels": [], "pooling": pooling, "pooling_indices": [-1]}, + mlp_args=None, + unpooling_mode=unpooling_mode, + ) + output = net(input_tensor) + assert output.shape == input_tensor.shape + + +@pytest.mark.parametrize( + "kernel_size,stride,padding,dilation", + [ + ((3, 2, 1), (1, 1, 2), (1, 1, 0), 1), + ((4, 5, 2), (3, 1, 1), (0, 0, 1), (2, 1, 1)), + ], +) +def test_invert_conv(kernel_size, stride, padding, dilation): + input_tensor = torch.randn(2, 1, 20, 27, 22) + net = AutoEncoder( + in_shape=(1, 20, 27, 22), + latent_size=1, + conv_args={ + "channels": [1], + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + }, + mlp_args=None, + ) + output = net(input_tensor) + assert output.shape == input_tensor.shape + + +@pytest.mark.parametrize("act", [act for act in ActFunction]) +def test_out_activation(act): + input_tensor = torch.randn(2, 1, 32, 32) + net = AutoEncoder( + in_shape=(1, 32, 32), + latent_size=3, + conv_args={"channels": [2]}, + output_act=act, + ) + assert net(input_tensor).shape == (2, 1, 32, 32) + + +def test_params(): + net = AutoEncoder( + in_shape=(1, 100, 100), + latent_size=3, + conv_args={"channels": [2], "act": "celu", "output_act": "sigmoid"}, + mlp_args={"hidden_channels": [2], "act": "relu", "output_act": "gelu"}, + output_act="tanh", + out_channels=2, + ) + assert net.encoder.convolutions.act == "celu" + assert net.decoder.convolutions.act == "celu" + assert net.encoder.mlp.act == "relu" + assert net.decoder.mlp.act == "relu" + assert isinstance(net.encoder.mlp.output.output_act, GELU) + assert isinstance(net.encoder.mlp.output.output_act, GELU) + assert isinstance(net.encoder.convolutions.output_act, Sigmoid) + assert isinstance(net.decoder.convolutions.output_act, Tanh) + + +@pytest.mark.parametrize( + "in_shape,upsampling_mode,error", + [ + ((1, 10), 
"bilinear", True), + ((1, 10, 10), "linear", True), + ((1, 10, 10), "trilinear", True), + ((1, 10, 10, 10), "bicubic", True), + ((1, 10), "linear", False), + ((1, 10, 10), "bilinear", False), + ((1, 10, 10, 10), "trilinear", False), + ], +) +def test_checks(in_shape, upsampling_mode, error): + if error: + with pytest.raises(ValueError): + AutoEncoder( + in_shape=in_shape, + latent_size=3, + conv_args={"channels": []}, + unpooling_mode=upsampling_mode, + ) + else: + AutoEncoder( + in_shape=in_shape, + latent_size=3, + conv_args={"channels": []}, + unpooling_mode=upsampling_mode, + ) diff --git a/tests/unittests/monai_networks/nn/test_cnn.py b/tests/unittests/monai_networks/nn/test_cnn.py new file mode 100644 index 000000000..a1c2d5585 --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_cnn.py @@ -0,0 +1,62 @@ +import pytest +import torch +from torch.nn import Flatten, Linear, Softmax + +from clinicadl.networks.nn import CNN, MLP, ConvEncoder + +INPUT_1D = torch.randn(3, 1, 16) +INPUT_2D = torch.randn(3, 1, 15, 16) +INPUT_3D = torch.randn(3, 3, 20, 21, 22) + + +@pytest.mark.parametrize("input_tensor", [INPUT_1D, INPUT_2D, INPUT_3D]) +@pytest.mark.parametrize("channels", [(), (2, 4)]) +@pytest.mark.parametrize( + "mlp_args", [None, {"hidden_channels": []}, {"hidden_channels": (2, 4)}] +) +def test_cnn(input_tensor, channels, mlp_args): + in_shape = input_tensor.shape[1:] + net = CNN( + in_shape=in_shape, + num_outputs=2, + conv_args={"channels": channels}, + mlp_args=mlp_args, + ) + output = net(input_tensor) + assert output.shape == (3, 2) + assert isinstance(net.convolutions, ConvEncoder) + assert isinstance(net.mlp, MLP) + + if mlp_args is None or mlp_args["hidden_channels"] == []: + children = net.mlp.children() + assert isinstance(next(children), Flatten) + assert isinstance(next(children).linear, Linear) + with pytest.raises(StopIteration): + next(children) + + if channels == []: + with pytest.raises(StopIteration): + next(net.convolutions.parameters()) + 
+ +@pytest.mark.parametrize( + "conv_args,mlp_args", + [ + (None, {"hidden_channels": [2]}), + ({"channels": [2]}, {}), + ], +) +def test_checks(conv_args, mlp_args): + with pytest.raises(ValueError): + CNN(in_shape=(1, 10, 10), num_outputs=2, conv_args=conv_args, mlp_args=mlp_args) + + +def test_params(): + conv_args = {"channels": [2], "act": "celu"} + mlp_args = {"hidden_channels": [2], "act": "relu", "output_act": "softmax"} + net = CNN( + in_shape=(1, 10, 10), num_outputs=2, conv_args=conv_args, mlp_args=mlp_args + ) + assert net.convolutions.act == "celu" + assert net.mlp.act == "relu" + assert isinstance(net.mlp.output.output_act, Softmax) diff --git a/tests/unittests/monai_networks/nn/test_conv_decoder.py b/tests/unittests/monai_networks/nn/test_conv_decoder.py new file mode 100644 index 000000000..44b0a76c2 --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_conv_decoder.py @@ -0,0 +1,407 @@ +import pytest +import torch +from torch.nn import ELU, ConvTranspose2d, Dropout, InstanceNorm2d, Upsample + +from clinicadl.networks.nn import ConvDecoder +from clinicadl.networks.nn.layers.utils import ActFunction + + +@pytest.fixture +def input_tensor(): + return torch.randn(2, 1, 8, 8) + + +@pytest.mark.parametrize("act", [act for act in ActFunction]) +def test_activations(input_tensor, act): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + net = ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + act=act, + output_act=act, + ) + output_shape = net(input_tensor).shape + return len(output_shape) == 4 and output_shape[1] == 1 + + +@pytest.mark.parametrize( + "kernel_size,stride,padding,output_padding,dilation,unpooling,unpooling_indices,norm,dropout,bias,adn_ordering", + [ + ( + 3, + 2, + 0, + 1, + 1, + ("upsample", {"scale_factor": 2}), + [2], + "batch", + None, + True, + "ADN", + ), + ( + (4, 4), + (2, 1), + 2, + (1, 0), + 2, + ("upsample", {"scale_factor": 2}), + [0, 1], + 
"instance", + 0.5, + False, + "DAN", + ), + ( + 5, + 1, + (2, 1), + 0, + 1, + [("upsample", {"size": (16, 16)}), ("convtranspose", {"kernel_size": 2})], + [0, 1], + "syncbatch", + 0.5, + True, + "NA", + ), + (5, 1, 0, 1, (2, 3), None, [0, 1], "instance", 0.0, True, "DN"), + ( + 5, + 1, + 2, + 0, + 1, + ("convtranspose", {"kernel_size": 2}), + None, + ("group", {"num_groups": 2}), + None, + True, + "N", + ), + ( + 5, + 3, + 2, + (2, 1), + 1, + ("convtranspose", {"kernel_size": 2}), + [0, 1], + None, + None, + True, + "", + ), + ], +) +def test_params( + input_tensor, + kernel_size, + stride, + padding, + output_padding, + dilation, + unpooling, + unpooling_indices, + norm, + dropout, + bias, + adn_ordering, +): + batch_size, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + # test size computation + net = ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + dilation=dilation, + unpooling=unpooling, + unpooling_indices=unpooling_indices, + dropout=dropout, + act=None, + norm=norm, + bias=bias, + adn_ordering=adn_ordering, + _input_size=input_size, + ) + output = net(input_tensor) + assert output.shape == (batch_size, 1, *net.final_size) + + # other checks + net = ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + dilation=dilation, + unpooling=unpooling, + unpooling_indices=unpooling_indices, + dropout=dropout, + act=None, + norm=norm, + bias=bias, + adn_ordering=adn_ordering, + ) + assert isinstance(net.layer2[0], ConvTranspose2d) + with pytest.raises(IndexError): + net.layer2[1] # no adn at the end + + named_layers = list(net.named_children()) + if unpooling and unpooling_indices and unpooling_indices != []: + for i, idx in enumerate(unpooling_indices): + name, 
layer = named_layers[idx + 1 + i]
+            if idx == -1:
+                assert name == "init_unpool"
+            else:
+                assert name == f"unpool{idx}"
+            if net.unpooling[i][0] == "upsample":
+                assert isinstance(layer, Upsample)
+            else:
+                assert isinstance(layer, ConvTranspose2d)
+    else:
+        for name, layer in named_layers:
+            assert not isinstance(layer, Upsample)
+            assert "unpool" not in name
+
+    assert net.layer0[0].kernel_size == (
+        kernel_size
+        if isinstance(kernel_size, tuple)
+        else (kernel_size, kernel_size)
+    )
+    assert net.layer0[0].stride == (
+        stride
+        if isinstance(stride, tuple)
+        else (stride, stride)
+    )
+    assert net.layer0[0].padding == (
+        padding
+        if isinstance(padding, tuple)
+        else (padding, padding)
+    )
+    assert net.layer0[0].output_padding == (
+        output_padding
+        if isinstance(output_padding, tuple)
+        else (output_padding, output_padding)
+    )
+    assert net.layer0[0].dilation == (
+        dilation
+        if isinstance(dilation, tuple)
+        else (dilation, dilation)
+    )
+
+    if bias:
+        assert len(net.layer0[0].bias) > 0
+        assert len(net.layer1[0].bias) > 0
+        assert len(net.layer2[0].bias) > 0
+    else:
+        assert net.layer0[0].bias is None
+        assert net.layer1[0].bias is None
+        assert net.layer2[0].bias is None
+    if isinstance(dropout, float) and "D" in adn_ordering:
+        assert net.layer0[1].D.p == dropout
+        assert net.layer1[1].D.p == dropout
+
+
+def test_activation_parameters(input_tensor):
+    _, in_channels, *input_size = input_tensor.shape
+    spatial_dims = len(input_size)
+
+    act = ("ELU", {"alpha": 0.1})
+    output_act = ("ELU", {"alpha": 0.2})
+    net = ConvDecoder(
+        spatial_dims=spatial_dims,
+        in_channels=in_channels,
+        channels=[2, 4, 1],
+        act=act,
+        output_act=output_act,
+    )
+    assert isinstance(net.layer0[1].A, ELU)
+    assert net.layer0[1].A.alpha == 0.1
+    assert isinstance(net.layer1[1].A, ELU)
+    assert net.layer1[1].A.alpha == 0.1
+    assert isinstance(net.output_act, ELU)
+    assert net.output_act.alpha == 0.2
+
+    net = ConvDecoder(
+        spatial_dims=spatial_dims, in_channels=in_channels,
channels=[2, 4, 1], act=None + ) + with pytest.raises(AttributeError): + net.layer0[1].A + with pytest.raises(AttributeError): + net.layer1[1].A + assert net.output_act is None + + +def test_norm_parameters(input_tensor): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + norm = ("instance", {"momentum": 1.0}) + net = ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + norm=norm, + ) + assert isinstance(net.layer0[1].N, InstanceNorm2d) + assert net.layer0[1].N.momentum == 1.0 + assert isinstance(net.layer1[1].N, InstanceNorm2d) + assert net.layer1[1].N.momentum == 1.0 + + net = ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + norm=None, + ) + with pytest.raises(AttributeError): + net.layer0[1].N + with pytest.raises(AttributeError): + net.layer1[1].N + + +def test_unpool_parameters(input_tensor): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + unpooling = ("convtranspose", {"kernel_size": 3, "stride": 2}) + net = ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + unpooling=unpooling, + unpooling_indices=[1], + ) + assert isinstance(net.unpool1, ConvTranspose2d) + assert net.unpool1.stride == (2, 2) + assert net.unpool1.kernel_size == (3, 3) + + +@pytest.mark.parametrize("adn_ordering", ["DAN", "NA", "A"]) +def test_adn_ordering(input_tensor, adn_ordering): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + net = ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + dropout=0.1, + adn_ordering=adn_ordering, + act="elu", + norm="instance", + ) + objects = {"D": Dropout, "N": InstanceNorm2d, "A": ELU} + for i, letter in enumerate(adn_ordering): + assert isinstance(net.layer0[1][i], objects[letter]) + assert isinstance(net.layer1[1][i], objects[letter]) + for letter in set(["A", "D", "N"]) - 
set(adn_ordering): + with pytest.raises(AttributeError): + getattr(net.layer0[1], letter) + with pytest.raises(AttributeError): + getattr(net.layer1[1], letter) + + +@pytest.mark.parametrize( + "input_tensor", [torch.randn(2, 1, 16), torch.randn(2, 3, 20, 21, 22)] +) +def test_other_dimensions(input_tensor): + batch_size, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + net = ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + _input_size=input_size, + ) + output = net(input_tensor) + assert output.shape == (batch_size, 1, *net.final_size) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"kernel_size": (3, 3, 3)}, + {"stride": [1, 1]}, + {"padding": [1, 1]}, + {"dilation": (1,)}, + {"unpooling_indices": [0, 1, 2, 3]}, + {"unpooling": "upsample", "unpooling_indices": [0]}, + {"norm": "group"}, + {"norm": "layer"}, + ], +) +def test_checks(input_tensor, kwargs): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + with pytest.raises(ValueError): + ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + **kwargs, + ) + + +@pytest.mark.parametrize( + "unpooling,error", + [ + (None, False), + ("abc", True), + ("upsample", True), + (("upsample",), True), + (("upsample", 2), True), + (("convtranspose", {"kernel_size": 2}), False), + (("upsample", {"scale_factor": 2}), False), + ( + [("upsample", {"scale_factor": 2}), ("convtranspose", {"kernel_size": 2})], + False, + ), + ([("upsample", {"scale_factor": 2}), None], True), + ([("upsample", {"scale_factor": 2}), "convtranspose"], True), + ([("upsample", {"scale_factor": 2}), ("convtranspose", 2)], True), + ( + [ + ("upsample", {"scale_factor": 2}), + ("convtranspose", {"kernel_size": 2}), + ("convtranspose", {"kernel_size": 2}), + ], + True, + ), + ], +) +def test_check_unpool_layer(input_tensor, unpooling, error): + _, in_channels, *input_size = input_tensor.shape + 
spatial_dims = len(input_size) + + if error: + with pytest.raises(ValueError): + ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + unpooling=unpooling, + unpooling_indices=[0, 1], + ) + else: + ConvDecoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + unpooling=unpooling, + unpooling_indices=[0, 1], + ) diff --git a/tests/unittests/monai_networks/nn/test_conv_encoder.py b/tests/unittests/monai_networks/nn/test_conv_encoder.py new file mode 100644 index 000000000..7239a7530 --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_conv_encoder.py @@ -0,0 +1,400 @@ +import pytest +import torch +from torch.nn import ( + ELU, + AdaptiveAvgPool2d, + AdaptiveMaxPool2d, + AvgPool2d, + Conv2d, + Dropout, + InstanceNorm2d, + MaxPool2d, +) + +from clinicadl.networks.nn import ConvEncoder +from clinicadl.networks.nn.layers.utils import ActFunction + + +@pytest.fixture +def input_tensor(): + return torch.randn(2, 1, 55, 54) + + +@pytest.mark.parametrize("act", [act for act in ActFunction]) +def test_activations(input_tensor, act): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + net = ConvEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + act=act, + output_act=act, + ) + output_shape = net(input_tensor).shape + assert len(output_shape) == 4 and output_shape[1] == 1 + + +@pytest.mark.parametrize( + "kernel_size,stride,padding,dilation,pooling,pooling_indices,norm,dropout,bias,adn_ordering", + [ + ( + 3, + 1, + 0, + 1, + ("adaptivemax", {"output_size": 1}), + [2], + "batch", + None, + True, + "ADN", + ), + ( + (4, 4), + (2, 1), + 2, + 2, + ("max", {"kernel_size": 2}), + [0, 1], + "instance", + 0.5, + False, + "DAN", + ), + ( + 5, + 1, + (2, 1), + 1, + [ + ("avg", {"kernel_size": 2}), + ("max", {"kernel_size": 2}), + ("adaptiveavg", {"output_size": (2, 3)}), + ], + [-1, 1, 2], + "syncbatch", + 0.5, + True, + "NA", + ), + (5, 
1, 0, (1, 2), None, [0, 1], "instance", 0.0, True, "DN"), + ( + 5, + 1, + 2, + 1, + ("avg", {"kernel_size": 2}), + None, + ("group", {"num_groups": 2}), + None, + True, + "N", + ), + (5, 1, 2, 1, ("avg", {"kernel_size": 2}), None, None, None, True, ""), + ], +) +def test_params( + input_tensor, + kernel_size, + stride, + padding, + dilation, + pooling, + pooling_indices, + norm, + dropout, + bias, + adn_ordering, +): + batch_size, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + # test output size + net = ConvEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + pooling=pooling, + pooling_indices=pooling_indices, + dropout=dropout, + act=None, + norm=norm, + bias=bias, + adn_ordering=adn_ordering, + _input_size=input_size, + ) + output = net(input_tensor) + assert output.shape == (batch_size, 1, *net.final_size) + + # other checks + net = ConvEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + pooling=pooling, + pooling_indices=pooling_indices, + dropout=dropout, + act=None, + norm=norm, + bias=bias, + adn_ordering=adn_ordering, + ) + assert isinstance(net.layer2.conv, Conv2d) + with pytest.raises(IndexError): + net.layer2[1] # no adn at the end + + named_layers = list(net.named_children()) + if pooling and pooling_indices and pooling_indices != []: + for i, idx in enumerate(pooling_indices): + name, layer = named_layers[idx + 1 + i] + if idx == -1: + assert name == "init_pool" + else: + assert name == f"pool{idx}" + pooling_mode = net.pooling[i][0] + if pooling_mode == "max": + assert isinstance(layer, MaxPool2d) + elif pooling_mode == "avg": + assert isinstance(layer, AvgPool2d) + elif pooling_mode == "adaptivemax": + assert isinstance(layer, AdaptiveMaxPool2d) + else: + assert 
isinstance(layer, AdaptiveAvgPool2d)
+    else:
+        for name, layer in named_layers:
+            assert not isinstance(layer, (AvgPool2d, MaxPool2d))
+            assert "pool" not in name
+
+    assert net.layer0.conv.kernel_size == (
+        kernel_size
+        if isinstance(kernel_size, tuple)
+        else (kernel_size, kernel_size)
+    )
+    assert net.layer0.conv.stride == (
+        stride
+        if isinstance(stride, tuple)
+        else (stride, stride)
+    )
+    assert net.layer0.conv.padding == (
+        padding
+        if isinstance(padding, tuple)
+        else (padding, padding)
+    )
+    assert net.layer0.conv.dilation == (
+        dilation
+        if isinstance(dilation, tuple)
+        else (dilation, dilation)
+    )
+
+    if bias:
+        assert len(net.layer0.conv.bias) > 0
+        assert len(net.layer1.conv.bias) > 0
+        assert len(net.layer2.conv.bias) > 0
+    else:
+        assert net.layer0.conv.bias is None
+        assert net.layer1.conv.bias is None
+        assert net.layer2.conv.bias is None
+    if isinstance(dropout, float) and "D" in adn_ordering:
+        assert net.layer0.adn.D.p == dropout
+        assert net.layer1.adn.D.p == dropout
+
+
+def test_activation_parameters(input_tensor):
+    _, in_channels, *input_size = input_tensor.shape
+    spatial_dims = len(input_size)
+
+    act = ("ELU", {"alpha": 0.1})
+    output_act = ("ELU", {"alpha": 0.2})
+    net = ConvEncoder(
+        spatial_dims=spatial_dims,
+        in_channels=in_channels,
+        channels=[2, 4, 1],
+        act=act,
+        output_act=output_act,
+    )
+    assert isinstance(net.layer0.adn.A, ELU)
+    assert net.layer0.adn.A.alpha == 0.1
+    assert isinstance(net.layer1.adn.A, ELU)
+    assert net.layer1.adn.A.alpha == 0.1
+    assert isinstance(net.output_act, ELU)
+    assert net.output_act.alpha == 0.2
+
+    net = ConvEncoder(
+        spatial_dims=spatial_dims, in_channels=in_channels, channels=[2, 4, 1], act=None
+    )
+    with pytest.raises(AttributeError):
+        net.layer0.adn.A
+    with pytest.raises(AttributeError):
+        net.layer1.adn.A
+    assert net.output_act is None
+
+
+def test_norm_parameters(input_tensor):
+    _, in_channels, *input_size = input_tensor.shape
+    spatial_dims =
len(input_size) + + norm = ("instance", {"momentum": 1.0}) + net = ConvEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + norm=norm, + ) + assert isinstance(net.layer0.adn.N, InstanceNorm2d) + assert net.layer0.adn.N.momentum == 1.0 + assert isinstance(net.layer1.adn.N, InstanceNorm2d) + assert net.layer1.adn.N.momentum == 1.0 + + net = ConvEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + norm=None, + ) + with pytest.raises(AttributeError): + net.layer0.adn.N + with pytest.raises(AttributeError): + net.layer1.adn.N + + +def test_pool_parameters(input_tensor): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + pooling = ("avg", {"kernel_size": 3, "stride": 2}) + net = ConvEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + pooling=pooling, + pooling_indices=[1], + ) + assert isinstance(net.pool1, AvgPool2d) + assert net.pool1.stride == 2 + assert net.pool1.kernel_size == 3 + + +@pytest.mark.parametrize("adn_ordering", ["DAN", "NA", "A"]) +def test_adn_ordering(input_tensor, adn_ordering): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + net = ConvEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + dropout=0.1, + adn_ordering=adn_ordering, + act="elu", + norm="instance", + ) + objects = {"D": Dropout, "N": InstanceNorm2d, "A": ELU} + for i, letter in enumerate(adn_ordering): + assert isinstance(net.layer0.adn[i], objects[letter]) + assert isinstance(net.layer1.adn[i], objects[letter]) + for letter in set(["A", "D", "N"]) - set(adn_ordering): + with pytest.raises(AttributeError): + getattr(net.layer0.adn, letter) + with pytest.raises(AttributeError): + getattr(net.layer1.adn, letter) + + +@pytest.mark.parametrize( + "input_tensor", [torch.randn(2, 1, 16), torch.randn(2, 3, 20, 21, 22)] +) +def test_other_dimensions(input_tensor): + batch_size, 
in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + net = ConvEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + _input_size=input_size, + ) + output = net(input_tensor) + assert output.shape == (batch_size, 1, *net.final_size) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"kernel_size": (3, 3, 3)}, + {"stride": [1, 1]}, + {"padding": [1, 1]}, + {"dilation": (1,)}, + {"pooling_indices": [0, 1, 2, 3]}, + {"pooling": "avg", "pooling_indices": [0]}, + {"norm": "group"}, + {"_input_size": (1, 10, 10), "stride": 2, "channels": [2, 4, 6, 8]}, + ], +) +def test_checks(kwargs): + if "channels" not in kwargs: + kwargs["channels"] = [2, 4, 1] + if "in_channels" not in kwargs: + kwargs["in_channels"] = 1 + if "spatial_dims" not in kwargs: + kwargs["spatial_dims"] = 2 + with pytest.raises(ValueError): + ConvEncoder(**kwargs) + + +@pytest.mark.parametrize( + "pooling,error", + [ + (None, False), + ("abc", True), + ("max", True), + (("max",), True), + (("max", 3), True), + (("avg", {"stride": 1}), True), + (("avg", {"kernel_size": 1}), False), + (("avg", {"kernel_size": 1, "stride": 1}), False), + (("abc", {"kernel_size": 1, "stride": 1}), True), + ([("avg", {"kernel_size": 1}), ("max", {"kernel_size": 1})], False), + ([("avg", {"kernel_size": 1}), None], True), + ([("avg", {"kernel_size": 1}), "max"], True), + ([("avg", {"kernel_size": 1}), ("max", 3)], True), + ([("avg", {"kernel_size": 1}), ("max", {"stride": 1})], True), + ( + [ + ("avg", {"kernel_size": 1}), + ("max", {"stride": 1}), + ("max", {"stride": 1}), + ], + True, + ), + ], +) +def test_check_pool_layers(input_tensor, pooling, error): + _, in_channels, *input_size = input_tensor.shape + spatial_dims = len(input_size) + + if error: + with pytest.raises(ValueError): + ConvEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + pooling=pooling, + pooling_indices=[0, 1], + ) + else: + ConvEncoder( + 
spatial_dims=spatial_dims, + in_channels=in_channels, + channels=[2, 4, 1], + pooling=pooling, + pooling_indices=[0, 1], + ) diff --git a/tests/unittests/monai_networks/nn/test_densenet.py b/tests/unittests/monai_networks/nn/test_densenet.py new file mode 100644 index 000000000..b7fdea50f --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_densenet.py @@ -0,0 +1,138 @@ +import pytest +import torch + +from clinicadl.networks.nn import DenseNet, get_densenet +from clinicadl.networks.nn.densenet import SOTADenseNet +from clinicadl.networks.nn.layers.utils import ActFunction + +INPUT_1D = torch.randn(3, 1, 16) +INPUT_2D = torch.randn(3, 2, 15, 16) +INPUT_3D = torch.randn(3, 3, 20, 21, 22) + + +@pytest.mark.parametrize( + "input_tensor,num_outputs,n_dense_layers,init_features,growth_rate,bottleneck_factor,act,output_act,dropout", + [ + (INPUT_1D, 2, (3, 4), 16, 8, 2, "relu", None, 0.1), + (INPUT_2D, None, (3, 4, 2), 9, 5, 3, "elu", "sigmoid", 0.0), + (INPUT_3D, 1, (2,), 4, 4, 2, "tanh", "sigmoid", 0.1), + ], +) +def test_densenet( + input_tensor, + num_outputs, + n_dense_layers, + init_features, + growth_rate, + bottleneck_factor, + act, + output_act, + dropout, +): + batch_size = input_tensor.shape[0] + net = DenseNet( + spatial_dims=len(input_tensor.shape[2:]), + in_channels=input_tensor.shape[1], + num_outputs=num_outputs, + n_dense_layers=n_dense_layers, + init_features=init_features, + growth_rate=growth_rate, + bottleneck_factor=bottleneck_factor, + act=act, + output_act=output_act, + dropout=dropout, + ) + output = net(input_tensor) + + if num_outputs: + assert output.shape == (batch_size, num_outputs) + else: + assert len(output.shape) == len(input_tensor.shape) + + if output_act and num_outputs: + assert net.fc.output_act is not None + elif output_act and num_outputs is None: + with pytest.raises(AttributeError): + net.fc.output_act + + features = net.features + for i, n in enumerate(n_dense_layers, start=1): + dense_block = getattr(features, 
f"denseblock{i}") + for k in range(1, n + 1): + dense_layer = getattr(dense_block, f"denselayer{k}").layers + assert dense_layer.conv1.out_channels == growth_rate * bottleneck_factor + assert dense_layer.conv2.out_channels == growth_rate + if dropout: + assert dense_layer.dropout.p == dropout + with pytest.raises(AttributeError): + getattr(dense_block, f"denseblock{n+1}") + with pytest.raises(AttributeError): + getattr(dense_block, f"denseblock{i+1}") + + assert features.conv0.out_channels == init_features + + +@pytest.mark.parametrize("act", [act for act in ActFunction]) +def test_activations(act): + batch_size = INPUT_2D.shape[0] + net = DenseNet( + spatial_dims=len(INPUT_2D.shape[2:]), + in_channels=INPUT_2D.shape[1], + n_dense_layers=(2, 2), + num_outputs=2, + act=act, + ) + assert net(INPUT_2D).shape == (batch_size, 2) + + +def test_activation_parameters(): + act = ("ELU", {"alpha": 0.1}) + output_act = ("ELU", {"alpha": 0.2}) + net = DenseNet( + spatial_dims=len(INPUT_2D.shape[2:]), + in_channels=INPUT_2D.shape[1], + num_outputs=2, + n_dense_layers=(2, 2), + act=act, + output_act=output_act, + ) + assert isinstance(net.features.denseblock1.denselayer1.layers.act1, torch.nn.ELU) + assert net.features.denseblock1.denselayer1.layers.act1.alpha == 0.1 + assert isinstance(net.fc.output_act, torch.nn.ELU) + assert net.fc.output_act.alpha == 0.2 + + +@pytest.mark.parametrize( + "name,num_outputs,output_act", + [ + (SOTADenseNet.DENSENET_121, 1, "sigmoid"), + (SOTADenseNet.DENSENET_161, 2, None), + (SOTADenseNet.DENSENET_169, None, "sigmoid"), + (SOTADenseNet.DENSENET_201, None, None), + ], +) +def test_get_densenet(name, num_outputs, output_act): + densenet = get_densenet( + name, num_outputs=num_outputs, output_act=output_act, pretrained=True + ) + if num_outputs: + assert densenet.fc.out.out_features == num_outputs + else: + assert densenet.fc is None + + if output_act and num_outputs: + assert densenet.fc.output_act is not None + elif output_act and num_outputs 
is None: + with pytest.raises(AttributeError): + densenet.fc.output_act + + +def test_get_densenet_output(): + from torchvision.models import densenet121 + + densenet = get_densenet( + SOTADenseNet.DENSENET_121, num_outputs=None, pretrained=True + ).features + gt = densenet121(weights="DEFAULT").features + x = torch.randn(1, 3, 128, 128) + assert (densenet(x) == gt(x)).all() diff --git a/tests/unittests/monai_networks/nn/test_generator.py b/tests/unittests/monai_networks/nn/test_generator.py new file mode 100644 index 000000000..0bc918a7d --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_generator.py @@ -0,0 +1,67 @@ +import pytest +import torch +from torch.nn import Flatten, Linear + +from clinicadl.networks.nn import MLP, ConvDecoder, Generator + + +@pytest.fixture +def input_tensor(): + return torch.randn(2, 8) + + +@pytest.mark.parametrize("channels", [(), (2, 4)]) +@pytest.mark.parametrize( + "mlp_args", [None, {"hidden_channels": []}, {"hidden_channels": (2, 4)}] +) +@pytest.mark.parametrize("start_shape", [(1, 5), (1, 5, 5), (1, 5, 5)]) +def test_generator(input_tensor, start_shape, channels, mlp_args): + latent_size = input_tensor.shape[1] + net = Generator( + latent_size=latent_size, + start_shape=start_shape, + conv_args={"channels": channels}, + mlp_args=mlp_args, + ) + output = net(input_tensor) + assert output.shape[1:] == net.output_shape + assert isinstance(net.convolutions, ConvDecoder) + assert isinstance(net.mlp, MLP) + + if mlp_args is None or mlp_args["hidden_channels"] == []: + children = net.mlp.children() + assert isinstance(next(children), Flatten) + assert isinstance(next(children).linear, Linear) + with pytest.raises(StopIteration): + next(children) + + if channels == []: + with pytest.raises(StopIteration): + next(net.convolutions.parameters()) + + +@pytest.mark.parametrize( + "conv_args,mlp_args", + [ + (None, {"hidden_channels": [2]}), + ({"channels": [2]}, {}), + ], +) +def test_checks(conv_args, mlp_args): + with 
pytest.raises(ValueError): + Generator( + latent_size=2, + start_shape=(1, 10, 10), + conv_args=conv_args, + mlp_args=mlp_args, + ) + + +def test_params(): + conv_args = {"channels": [2], "act": "celu"} + mlp_args = {"hidden_channels": [2], "act": "relu"} + net = Generator( + latent_size=2, start_shape=(1, 10, 10), conv_args=conv_args, mlp_args=mlp_args + ) + assert net.convolutions.act == "celu" + assert net.mlp.act == "relu" diff --git a/tests/unittests/monai_networks/nn/test_mlp.py b/tests/unittests/monai_networks/nn/test_mlp.py new file mode 100644 index 000000000..91ad682d1 --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_mlp.py @@ -0,0 +1,125 @@ +import pytest +import torch +from torch.nn import ELU, Dropout, InstanceNorm1d, Linear + +from clinicadl.networks.nn import MLP +from clinicadl.networks.nn.layers.utils import ActFunction + + +@pytest.fixture +def input_tensor(): + return torch.randn(8, 10) + + +@pytest.mark.parametrize("act", [act for act in ActFunction]) +def test_activations(input_tensor, act): + net = MLP( + in_channels=10, out_channels=2, hidden_channels=[6, 4], act=act, output_act=act + ) + assert net(input_tensor).shape == (8, 2) + + +@pytest.mark.parametrize( + "dropout,norm,bias,adn_ordering", + [ + (None, "batch", True, "ADN"), + (0.5, "layer", False, "DAN"), + (0.5, "syncbatch", True, "NA"), + (0.0, "instance", True, "DN"), + (None, ("group", {"num_groups": 2}), True, "N"), + (0.5, None, True, "ADN"), + (0.5, "batch", True, ""), + ], +) +def test_params(input_tensor, dropout, norm, bias, adn_ordering): + net = MLP( + in_channels=10, + out_channels=2, + hidden_channels=[6, 4], + dropout=dropout, + norm=norm, + act=None, + bias=bias, + adn_ordering=adn_ordering, + ) + assert net(input_tensor).shape == (8, 2) + assert isinstance(net.output.linear, Linear) + + if bias: + assert len(net.hidden0.linear.bias) > 0 + assert len(net.hidden1.linear.bias) > 0 + assert len(net.output.linear.bias) > 0 + else: + assert net.hidden0.linear.bias 
is None + assert net.hidden1.linear.bias is None + assert net.output.linear.bias is None + if isinstance(dropout, float) and "D" in adn_ordering: + assert net.hidden0.adn.D.p == dropout + assert net.hidden1.adn.D.p == dropout + + +def test_activation_parameters(): + act = ("ELU", {"alpha": 0.1}) + output_act = ("ELU", {"alpha": 0.2}) + net = MLP( + in_channels=10, + out_channels=2, + hidden_channels=[6, 4], + act=act, + output_act=output_act, + ) + assert isinstance(net.hidden0.adn.A, ELU) + assert net.hidden0.adn.A.alpha == 0.1 + assert isinstance(net.hidden1.adn.A, ELU) + assert net.hidden1.adn.A.alpha == 0.1 + assert isinstance(net.output.output_act, ELU) + assert net.output.output_act.alpha == 0.2 + + net = MLP(in_channels=10, out_channels=2, hidden_channels=[6, 4], act=None) + with pytest.raises(AttributeError): + net.hidden0.adn.A + with pytest.raises(AttributeError): + net.hidden1.adn.A + assert net.output.output_act is None + + +def test_norm_parameters(): + norm = ("instance", {"momentum": 1.0}) + net = MLP(in_channels=10, out_channels=2, hidden_channels=[6, 4], norm=norm) + assert isinstance(net.hidden0.adn.N, InstanceNorm1d) + assert net.hidden0.adn.N.momentum == 1.0 + assert isinstance(net.hidden1.adn.N, InstanceNorm1d) + assert net.hidden1.adn.N.momentum == 1.0 + + net = MLP(in_channels=10, out_channels=2, hidden_channels=[6, 4], act=None) + with pytest.raises(AttributeError): + net.layer_0[1].N + with pytest.raises(AttributeError): + net.layer_1[1].N + + +@pytest.mark.parametrize("adn_ordering", ["DAN", "NA", "A"]) +def test_adn_ordering(adn_ordering): + net = MLP( + in_channels=10, + out_channels=2, + hidden_channels=[6, 4], + dropout=0.1, + adn_ordering=adn_ordering, + act="elu", + norm="instance", + ) + objects = {"D": Dropout, "N": InstanceNorm1d, "A": ELU} + for i, letter in enumerate(adn_ordering): + assert isinstance(net.hidden0.adn[i], objects[letter]) + assert isinstance(net.hidden1.adn[i], objects[letter]) + for letter in set(["A", "D", 
"N"]) - set(adn_ordering): + with pytest.raises(AttributeError): + getattr(net.hidden0.adn, letter) + with pytest.raises(AttributeError): + getattr(net.hidden1.adn, letter) + + +def test_checks(): + with pytest.raises(ValueError): + MLP(in_channels=10, out_channels=2, hidden_channels=[6, 4], norm="group") diff --git a/tests/unittests/monai_networks/nn/test_resnet.py b/tests/unittests/monai_networks/nn/test_resnet.py new file mode 100644 index 000000000..a99ea06dc --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_resnet.py @@ -0,0 +1,173 @@ +import pytest +import torch + +from clinicadl.networks.nn import ResNet, get_resnet +from clinicadl.networks.nn.layers.resnet import ResNetBlock, ResNetBottleneck +from clinicadl.networks.nn.layers.utils import ActFunction +from clinicadl.networks.nn.resnet import SOTAResNet + +INPUT_1D = torch.randn(3, 1, 16) +INPUT_2D = torch.randn(3, 2, 15, 16) +INPUT_3D = torch.randn(3, 3, 20, 21, 22) + + +@pytest.mark.parametrize( + "input_tensor,num_outputs,block_type,n_res_blocks,n_features,init_conv_size,init_conv_stride,bottleneck_reduction,act,output_act", + [ + (INPUT_1D, 2, "basic", (2, 3), (4, 8), 7, 1, 2, "relu", None), + ( + INPUT_2D, + None, + "bottleneck", + (3, 2, 2), + (8, 12, 16), + 5, + (2, 1), + 4, + "elu", + "sigmoid", + ), + (INPUT_3D, 1, "bottleneck", (2,), (3,), (4, 3, 4), 2, 1, "tanh", "sigmoid"), + ], +) +def test_resnet( + input_tensor, + num_outputs, + block_type, + n_res_blocks, + n_features, + init_conv_size, + init_conv_stride, + bottleneck_reduction, + act, + output_act, +): + batch_size = input_tensor.shape[0] + spatial_dims = len(input_tensor.shape[2:]) + net = ResNet( + spatial_dims=spatial_dims, + in_channels=input_tensor.shape[1], + num_outputs=num_outputs, + block_type=block_type, + n_res_blocks=n_res_blocks, + n_features=n_features, + init_conv_size=init_conv_size, + init_conv_stride=init_conv_stride, + bottleneck_reduction=bottleneck_reduction, + act=act, + output_act=output_act, + ) + output = 
net(input_tensor) + + if num_outputs: + assert output.shape == (batch_size, num_outputs) + else: + assert len(output.shape) == len(input_tensor.shape) + + if output_act and num_outputs: + assert net.fc.output_act is not None + elif output_act and num_outputs is None: + with pytest.raises(AttributeError): + net.fc.output_act + + for i, (n_blocks, n_feats) in enumerate(zip(n_res_blocks, n_features), start=1): + layer = getattr(net, f"layer{i}") + for k in range(n_blocks): + res_block = layer[k] + if block_type == "basic": + assert isinstance(res_block, ResNetBlock) + else: + assert isinstance(res_block, ResNetBottleneck) + if block_type == "basic": + assert res_block.conv2.out_channels == n_feats + else: + assert res_block.conv1.out_channels == n_feats // bottleneck_reduction + assert res_block.conv3.out_channels == n_feats + with pytest.raises(IndexError): + layer[k + 1] + with pytest.raises(AttributeError): + getattr(net, f"layer{i+1}") + + assert ( + net.conv0.kernel_size == init_conv_size + if isinstance(init_conv_size, tuple) + else (init_conv_size,) * spatial_dims + ) + assert ( + net.conv0.stride == init_conv_stride + if isinstance(init_conv_stride, tuple) + else (init_conv_stride,) * spatial_dims + ) + + +@pytest.mark.parametrize("act", [act for act in ActFunction]) +def test_activations(act): + batch_size = INPUT_2D.shape[0] + net = ResNet( + spatial_dims=len(INPUT_2D.shape[2:]), + in_channels=INPUT_2D.shape[1], + num_outputs=2, + n_features=(8, 16), + n_res_blocks=(2, 2), + act=act, + ) + assert net(INPUT_2D).shape == (batch_size, 2) + + +def test_activation_parameters(): + act = ("ELU", {"alpha": 0.1}) + output_act = ("ELU", {"alpha": 0.2}) + net = ResNet( + spatial_dims=len(INPUT_2D.shape[2:]), + in_channels=INPUT_2D.shape[1], + num_outputs=2, + n_features=(8, 16), + n_res_blocks=(2, 2), + act=act, + output_act=output_act, + ) + assert isinstance(net.layer1[0].act1, torch.nn.ELU) + assert net.layer1[0].act1.alpha == 0.1 + assert 
isinstance(net.layer2[1].act2, torch.nn.ELU) + assert net.layer2[1].act2.alpha == 0.1 + assert isinstance(net.act0, torch.nn.ELU) + assert net.act0.alpha == 0.1 + assert isinstance(net.fc.output_act, torch.nn.ELU) + assert net.fc.output_act.alpha == 0.2 + + +@pytest.mark.parametrize( + "name,num_outputs,output_act", + [ + (SOTAResNet.RESNET_18, 1, "sigmoid"), + (SOTAResNet.RESNET_34, 2, None), + (SOTAResNet.RESNET_50, None, "sigmoid"), + (SOTAResNet.RESNET_101, None, None), + (SOTAResNet.RESNET_152, None, None), + ], +) +def test_get_resnet(name, num_outputs, output_act): + resnet = get_resnet( + name, num_outputs=num_outputs, output_act=output_act, pretrained=True + ) + if num_outputs: + assert resnet.fc.out.out_features == num_outputs + else: + assert resnet.fc is None + + if output_act and num_outputs: + assert resnet.fc.output_act is not None + elif output_act and num_outputs is None: + with pytest.raises(AttributeError): + resnet.fc.output_act + + +def test_get_resnet_output(): + from torchvision.models import resnet18 + + resnet = get_resnet(SOTAResNet.RESNET_18, num_outputs=None, pretrained=True) + gt = resnet18(weights="DEFAULT") + gt.avgpool = torch.nn.Identity() + gt.fc = torch.nn.Identity() + x = torch.randn(1, 3, 128, 128) + assert (torch.flatten(resnet(x), start_dim=1) == gt(x)).all() diff --git a/tests/unittests/monai_networks/nn/test_senet.py b/tests/unittests/monai_networks/nn/test_senet.py new file mode 100644 index 000000000..6e3527b38 --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_senet.py @@ -0,0 +1,172 @@ +import pytest +import torch + +from clinicadl.networks.nn import SEResNet, get_seresnet +from clinicadl.networks.nn.layers.senet import SEResNetBlock, SEResNetBottleneck +from clinicadl.networks.nn.layers.utils import ActFunction +from clinicadl.networks.nn.senet import SOTAResNet + +INPUT_1D = torch.randn(3, 1, 16) +INPUT_2D = torch.randn(3, 2, 15, 16) +INPUT_3D = torch.randn(3, 3, 20, 21, 22) + + +@pytest.mark.parametrize( + 
"input_tensor,num_outputs,block_type,n_res_blocks,n_features,init_conv_size,init_conv_stride,bottleneck_reduction,act,output_act,se_reduction", + [ + (INPUT_1D, 2, "basic", (2, 3), (4, 8), 7, 1, 2, "relu", None, 4), + ( + INPUT_2D, + None, + "bottleneck", + (3, 2, 2), + (8, 12, 16), + 5, + (2, 1), + 4, + "elu", + "sigmoid", + 2, + ), + (INPUT_3D, 1, "bottleneck", (2,), (3,), (4, 3, 4), 2, 1, "tanh", "sigmoid", 2), + ], +) +def test_seresnet( + input_tensor, + num_outputs, + block_type, + n_res_blocks, + n_features, + init_conv_size, + init_conv_stride, + bottleneck_reduction, + act, + output_act, + se_reduction, +): + batch_size = input_tensor.shape[0] + spatial_dims = len(input_tensor.shape[2:]) + net = SEResNet( + spatial_dims=spatial_dims, + in_channels=input_tensor.shape[1], + num_outputs=num_outputs, + block_type=block_type, + n_res_blocks=n_res_blocks, + n_features=n_features, + init_conv_size=init_conv_size, + init_conv_stride=init_conv_stride, + bottleneck_reduction=bottleneck_reduction, + act=act, + output_act=output_act, + se_reduction=se_reduction, + ) + output = net(input_tensor) + + if num_outputs: + assert output.shape == (batch_size, num_outputs) + else: + assert len(output.shape) == len(input_tensor.shape) + + if output_act and num_outputs: + assert net.fc.output_act is not None + elif output_act and num_outputs is None: + with pytest.raises(AttributeError): + net.fc.output_act + + for i, (n_blocks, n_feats) in enumerate(zip(n_res_blocks, n_features), start=1): + layer = getattr(net, f"layer{i}") + for k in range(n_blocks): + res_block = layer[k] + if block_type == "basic": + assert isinstance(res_block, SEResNetBlock) + else: + assert isinstance(res_block, SEResNetBottleneck) + if block_type == "basic": + assert res_block.conv2.out_channels == n_feats + else: + assert res_block.conv1.out_channels == n_feats // bottleneck_reduction + assert res_block.conv3.out_channels == n_feats + with pytest.raises(IndexError): + layer[k + 1] + with 
pytest.raises(AttributeError): + getattr(net, f"layer{i+1}") + + assert ( + net.conv0.kernel_size == init_conv_size + if isinstance(init_conv_size, tuple) + else (init_conv_size,) * spatial_dims + ) + assert ( + net.conv0.stride == init_conv_stride + if isinstance(init_conv_stride, tuple) + else (init_conv_stride,) * spatial_dims + ) + + +@pytest.mark.parametrize("act", [act for act in ActFunction]) +def test_activations(act): + batch_size = INPUT_2D.shape[0] + net = SEResNet( + spatial_dims=len(INPUT_2D.shape[2:]), + in_channels=INPUT_2D.shape[1], + num_outputs=2, + n_features=(8, 16), + n_res_blocks=(2, 2), + act=act, + se_reduction=2, + ) + assert net(INPUT_2D).shape == (batch_size, 2) + + +def test_activation_parameters(): + act = ("ELU", {"alpha": 0.1}) + output_act = ("ELU", {"alpha": 0.2}) + net = SEResNet( + spatial_dims=len(INPUT_2D.shape[2:]), + in_channels=INPUT_2D.shape[1], + num_outputs=2, + n_features=(8, 16), + n_res_blocks=(2, 2), + act=act, + output_act=output_act, + se_reduction=2, + ) + assert isinstance(net.layer1[0].act1, torch.nn.ELU) + assert net.layer1[0].act1.alpha == 0.1 + assert isinstance(net.layer2[1].act2, torch.nn.ELU) + assert net.layer2[1].act2.alpha == 0.1 + assert isinstance(net.act0, torch.nn.ELU) + assert net.act0.alpha == 0.1 + assert isinstance(net.fc.output_act, torch.nn.ELU) + assert net.fc.output_act.alpha == 0.2 + + +@pytest.mark.parametrize( + "name,num_outputs,output_act", + [ + (SOTAResNet.SE_RESNET_50, 1, "sigmoid"), + (SOTAResNet.SE_RESNET_101, 2, None), + (SOTAResNet.SE_RESNET_152, None, "sigmoid"), + ], +) +def test_get_seresnet(name, num_outputs, output_act): + seresnet = get_seresnet( + name, + num_outputs=num_outputs, + output_act=output_act, + ) + if num_outputs: + assert seresnet.fc.out.out_features == num_outputs + else: + assert seresnet.fc is None + + if output_act and num_outputs: + assert seresnet.fc.output_act is not None + elif output_act and num_outputs is None: + with pytest.raises(AttributeError): + 
seresnet.fc.output_act + + +def test_get_seresnet_error(): + with pytest.raises(ValueError): + get_seresnet(SOTAResNet.SE_RESNET_50, num_outputs=1, pretrained=True) diff --git a/tests/unittests/monai_networks/nn/test_unet.py b/tests/unittests/monai_networks/nn/test_unet.py new file mode 100644 index 000000000..b7f6349fc --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_unet.py @@ -0,0 +1,127 @@ +import pytest +import torch + +from clinicadl.networks.nn import UNet +from clinicadl.networks.nn.layers.utils import ActFunction + +INPUT_1D = torch.randn(2, 1, 16) +INPUT_2D = torch.randn(2, 2, 32, 64) +INPUT_3D = torch.randn(2, 3, 16, 32, 8) + + +@pytest.mark.parametrize( + "input_tensor,out_channels,channels,act,output_act,dropout,error", + [ + (INPUT_1D, 1, (2, 3, 4), "relu", "sigmoid", None, False), + (INPUT_2D, 1, (2, 4, 5), "relu", None, 0.0, False), + (INPUT_3D, 2, (2, 3), None, ("softmax", {"dim": 1}), 0.1, False), + ( + INPUT_3D, + 2, + (2,), + None, + ("softmax", {"dim": 1}), + 0.1, + True, + ), # channels length is less than 2 + ], +) +def test_unet(input_tensor, out_channels, channels, act, output_act, dropout, error): + batch_size, in_channels, *img_size = input_tensor.shape + spatial_dims = len(img_size) + if error: + with pytest.raises(ValueError): + UNet( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + channels=channels, + act=act, + output_act=output_act, + dropout=dropout, + ) + else: + net = UNet( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + channels=channels, + act=act, + output_act=output_act, + dropout=dropout, + ) + + out = net(input_tensor) + assert out.shape == (batch_size, out_channels, *img_size) + + if output_act: + assert net.output_act is not None + else: + assert net.output_act is None + + assert net.doubleconv[1].conv.out_channels == channels[0] + if dropout: + assert net.doubleconv[1].adn.D.p == dropout + else: + with pytest.raises(AttributeError): + 
net.doubleconv[1].conv.adn.D + + for i in range(1, len(channels)): + down = getattr(net, f"down{i}").doubleconv + up = getattr(net, f"doubleconv{i}") + assert down[0].conv.in_channels == channels[i - 1] + assert down[1].conv.out_channels == channels[i] + assert up[0].conv.in_channels == channels[i - 1] * 2 + assert up[1].conv.out_channels == channels[i - 1] + for m in (down, up): + if dropout is not None: + assert m[1].adn.D.p == dropout + else: + with pytest.raises(AttributeError): + m[1].adn.D + with pytest.raises(AttributeError): + down = getattr(net, f"down{i+1}") + with pytest.raises(AttributeError): + getattr(net, f"doubleconv{i+1}") + + +@pytest.mark.parametrize("act", [act for act in ActFunction]) +def test_activations(act): + batch_size, in_channels, *img_size = INPUT_2D.shape + net = UNet( + spatial_dims=2, + in_channels=in_channels, + out_channels=2, + channels=(2, 4), + act=act, + output_act=act, + ) + assert net(INPUT_2D).shape == (batch_size, 2, *img_size) + + +def test_activation_parameters(): + in_channels = INPUT_2D.shape[1] + act = ("ELU", {"alpha": 0.1}) + output_act = ("ELU", {"alpha": 0.2}) + net = UNet( + spatial_dims=2, + in_channels=in_channels, + out_channels=2, + channels=(2, 4), + act=act, + output_act=output_act, + ) + assert isinstance(net.doubleconv[0].adn.A, torch.nn.ELU) + assert net.doubleconv[0].adn.A.alpha == 0.1 + + assert isinstance(net.down1.doubleconv[0].adn.A, torch.nn.ELU) + assert net.down1.doubleconv[0].adn.A.alpha == 0.1 + + assert isinstance(net.upsample1[1].adn.A, torch.nn.ELU) + assert net.upsample1[1].adn.A.alpha == 0.1 + + assert isinstance(net.doubleconv1[1].adn.A, torch.nn.ELU) + assert net.doubleconv1[1].adn.A.alpha == 0.1 + + assert isinstance(net.output_act, torch.nn.ELU) + assert net.output_act.alpha == 0.2 diff --git a/tests/unittests/monai_networks/nn/test_vae.py b/tests/unittests/monai_networks/nn/test_vae.py new file mode 100644 index 000000000..6f2d0f279 --- /dev/null +++ 
b/tests/unittests/monai_networks/nn/test_vae.py @@ -0,0 +1,99 @@ +import pytest +import torch +from numpy import isclose +from torch.nn import ReLU + +from clinicadl.networks.nn import VAE + + +@pytest.mark.parametrize( + "input_tensor,kernel_size,stride,padding,dilation,pooling,pooling_indices", + [ + (torch.randn(2, 1, 21), 3, 1, 0, 1, ("max", {"kernel_size": 2}), [0]), + ( + torch.randn(2, 1, 65, 85), + (3, 5), + (1, 2), + 0, + (1, 2), + ("max", {"kernel_size": 2, "stride": 1}), + [0], + ), + ( + torch.randn(2, 1, 64, 62, 61), # to test output padding + 4, + 2, + (1, 1, 0), + 1, + ("avg", {"kernel_size": 3, "stride": 2}), + [0], + ), + ( + torch.randn(2, 1, 51, 55, 45), + 3, + 2, + 0, + 1, + ("max", {"kernel_size": 2, "ceil_mode": True}), + [0, 1], + ), + ( + torch.randn(2, 1, 51, 55, 45), + 3, + 2, + 0, + 1, + [ + ("max", {"kernel_size": 2, "ceil_mode": True}), + ("max", {"kernel_size": 2, "stride": 1, "ceil_mode": False}), + ], + [0, 1], + ), + ], +) +def test_output_shape( + input_tensor, kernel_size, stride, padding, dilation, pooling, pooling_indices +): + latent_size = 3 + net = VAE( + in_shape=input_tensor.shape[1:], + latent_size=latent_size, + conv_args={ + "channels": [2, 4, 8], + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + "pooling": pooling, + "pooling_indices": pooling_indices, + }, + ) + recon, mu, log_var = net(input_tensor) + assert recon.shape == input_tensor.shape + assert mu.shape == (input_tensor.shape[0], latent_size) + assert log_var.shape == (input_tensor.shape[0], latent_size) + + +def test_mu_log_var(): + net = VAE( + in_shape=(1, 5, 5), + latent_size=4, + conv_args={"channels": []}, + mlp_args={"hidden_channels": [12], "output_act": "relu", "act": "celu"}, + ) + assert net.mu.linear.in_features == 12 + assert net.log_var.linear.in_features == 12 + assert isinstance(net.mu.output_act, ReLU) + assert isinstance(net.log_var.output_act, ReLU) + assert net.encoder(torch.randn(2, 1, 5, 
5)).shape == (2, 12) + _, mu, log_var = net(torch.randn(2, 1, 5, 5)) + assert not isclose(mu.detach().numpy(), log_var.detach().numpy()).all() + + net = VAE( + in_shape=(1, 5, 5), + latent_size=4, + conv_args={"channels": []}, + mlp_args={"hidden_channels": [12]}, + ) + assert net.mu.linear.in_features == 12 + assert net.log_var.linear.in_features == 12 diff --git a/tests/unittests/monai_networks/nn/test_vit.py b/tests/unittests/monai_networks/nn/test_vit.py new file mode 100644 index 000000000..b8b5938b8 --- /dev/null +++ b/tests/unittests/monai_networks/nn/test_vit.py @@ -0,0 +1,279 @@ +import numpy as np +import pytest +import torch + +from clinicadl.networks.nn import ViT, get_vit +from clinicadl.networks.nn.layers.utils import ActFunction +from clinicadl.networks.nn.vit import SOTAViT + +INPUT_1D = torch.randn(2, 1, 16) +INPUT_2D = torch.randn(2, 2, 15, 16) +INPUT_3D = torch.randn(2, 3, 24, 24, 24) + + +@pytest.mark.parametrize( + "input_tensor,patch_size,num_outputs,embedding_dim,num_layers,num_heads,mlp_dim,pos_embed_type,output_act,dropout,error", + [ + (INPUT_1D, 4, 1, 25, 3, 5, 26, None, "softmax", None, False), + ( + INPUT_1D, + 5, + 1, + 25, + 3, + 5, + 26, + None, + "softmax", + None, + True, + ), # img not divisible by patch + ( + INPUT_1D, + 4, + 1, + 25, + 3, + 4, + 26, + None, + "softmax", + None, + True, + ), # embedding not divisible by num heads + (INPUT_1D, 4, 1, 24, 5, 4, 26, "sincos", "softmax", None, True), # sincos + (INPUT_2D, (3, 4), None, 24, 2, 4, 42, "learnable", "tanh", 0.1, False), + ( + INPUT_2D, + 4, + None, + 24, + 2, + 6, + 42, + "learnable", + "tanh", + 0.1, + True, + ), # img not divisible by patch + ( + INPUT_2D, + (3, 4), + None, + 24, + 2, + 5, + 42, + "learnable", + "tanh", + 0.1, + True, + ), # embedding not divisible by num heads + ( + INPUT_2D, + (3, 4), + None, + 18, + 2, + 6, + 42, + "sincos", + "tanh", + 0.1, + True, + ), # sincos : embedding not divisible by 4 + (INPUT_2D, (3, 4), None, 24, 2, 6, 42, "sincos", 
"tanh", 0.1, False), + ( + INPUT_3D, + 6, + 2, + 15, + 2, + 3, + 42, + "sincos", + None, + 0.0, + True, + ), # sincos : embedding not divisible by 6 + (INPUT_3D, 6, 2, 18, 2, 3, 42, "sincos", None, 0.0, False), + ], +) +def test_vit( + input_tensor, + patch_size, + num_outputs, + embedding_dim, + num_layers, + num_heads, + mlp_dim, + pos_embed_type, + output_act, + dropout, + error, +): + batch_size = input_tensor.shape[0] + img_size = input_tensor.shape[2:] + spatial_dims = len(img_size) + if error: + with pytest.raises(ValueError): + ViT( + in_shape=input_tensor.shape[1:], + patch_size=patch_size, + num_outputs=num_outputs, + embedding_dim=embedding_dim, + num_layers=num_layers, + num_heads=num_heads, + mlp_dim=mlp_dim, + pos_embed_type=pos_embed_type, + output_act=output_act, + dropout=dropout, + ) + else: + net = ViT( + in_shape=input_tensor.shape[1:], + patch_size=patch_size, + num_outputs=num_outputs, + embedding_dim=embedding_dim, + num_layers=num_layers, + num_heads=num_heads, + mlp_dim=mlp_dim, + pos_embed_type=pos_embed_type, + output_act=output_act, + dropout=dropout, + ) + output = net(input_tensor) + + if num_outputs: + assert output.shape == (batch_size, num_outputs) + else: + n_patches = int( + np.prod( + np.array(img_size) + // np.array( + patch_size + if isinstance(patch_size, tuple) + else (patch_size,) * spatial_dims + ) + ) + ) + assert output.shape == (batch_size, n_patches, embedding_dim) + + if output_act and num_outputs: + assert net.fc.output_act is not None + elif output_act and num_outputs is None: + with pytest.raises(AttributeError): + net.fc.output_act + + assert net.conv_proj.out_channels == embedding_dim + encoder = net.encoder.layers + for transformer_block in encoder: + assert isinstance(transformer_block.norm1, torch.nn.LayerNorm) + assert isinstance(transformer_block.norm2, torch.nn.LayerNorm) + assert transformer_block.self_attention.num_heads == num_heads + assert transformer_block.self_attention.dropout == ( + dropout if 
dropout is not None else 0.0 + ) + assert transformer_block.self_attention.embed_dim == embedding_dim + assert transformer_block.mlp[0].out_features == mlp_dim + assert transformer_block.mlp[2].p == ( + dropout if dropout is not None else 0.0 + ) + assert transformer_block.mlp[4].p == ( + dropout if dropout is not None else 0.0 + ) + assert net.encoder.dropout.p == (dropout if dropout is not None else 0.0) + assert isinstance(net.encoder.norm, torch.nn.LayerNorm) + + pos_embedding = net.encoder.pos_embedding + if pos_embed_type is None: + assert not pos_embedding.requires_grad + assert (pos_embedding == torch.zeros_like(pos_embedding)).all() + elif pos_embed_type == "sincos": + assert not pos_embedding.requires_grad + if num_outputs: + assert ( + pos_embedding[0, 1, 0] == 0.0 + ) # first element of of sincos embedding of first patch is zero + else: + assert pos_embedding[0, 0, 0] == 0.0 + else: + assert pos_embedding.requires_grad + if num_outputs: + assert pos_embedding[0, 1, 0] != 0.0 + else: + assert pos_embedding[0, 0, 0] != 0.0 + + with pytest.raises(IndexError): + encoder[num_layers] + + +@pytest.mark.parametrize("act", [act for act in ActFunction]) +def test_activations(act): + batch_size = INPUT_2D.shape[0] + net = ViT( + in_shape=INPUT_2D.shape[1:], + patch_size=(3, 4), + num_outputs=1, + embedding_dim=12, + num_layers=2, + num_heads=3, + mlp_dim=24, + output_act=act, + ) + assert net(INPUT_2D).shape == (batch_size, 1) + + +def test_activation_parameters(): + output_act = ("ELU", {"alpha": 0.2}) + net = ViT( + in_shape=(1, 12, 12), + patch_size=3, + num_outputs=1, + embedding_dim=12, + num_layers=2, + num_heads=3, + mlp_dim=24, + output_act=output_act, + ) + assert isinstance(net.fc.output_act, torch.nn.ELU) + assert net.fc.output_act.alpha == 0.2 + + +@pytest.mark.parametrize( + "name,num_outputs,output_act,img_size", + [ + (SOTAViT.B_16, 1, "sigmoid", (224, 224)), + (SOTAViT.B_32, 2, None, (224, 224)), + (SOTAViT.L_16, None, "sigmoid", (224, 224)), + 
(SOTAViT.L_32, None, None, (224, 224)), + ], +) +def test_get_vit(name, num_outputs, output_act, img_size): + input_tensor = torch.randn(1, 3, *img_size) + + vit = get_vit(name, num_outputs=num_outputs, output_act=output_act, pretrained=True) + if num_outputs: + assert vit.fc.out.out_features == num_outputs + else: + assert vit.fc is None + + if output_act and num_outputs: + assert vit.fc.output_act is not None + elif output_act and num_outputs is None: + assert vit.fc is None + + vit(input_tensor) + + +def test_get_vit_output(): + from torchvision.models import vit_b_16 + + gt = vit_b_16(weights="DEFAULT") + gt.heads = torch.nn.Identity() + x = torch.randn(1, 3, 224, 224) + + vit = get_vit(SOTAViT.B_16, num_outputs=1, pretrained=True) + vit.fc = torch.nn.Identity() + with torch.no_grad(): + assert (vit(x) == gt(x)).all() diff --git a/tests/unittests/monai_networks/nn/utils/__init__.py b/tests/unittests/monai_networks/nn/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unittests/monai_networks/nn/utils/test_checks.py b/tests/unittests/monai_networks/nn/utils/test_checks.py new file mode 100644 index 000000000..b90368662 --- /dev/null +++ b/tests/unittests/monai_networks/nn/utils/test_checks.py @@ -0,0 +1,127 @@ +import pytest + +from clinicadl.networks.nn.utils.checks import ( + _check_conv_parameter, + check_adn_ordering, + check_conv_args, + check_mlp_args, + check_norm_layer, + check_pool_indices, + ensure_list_of_tuples, +) + + +@pytest.mark.parametrize( + "adn,error", + [("ADN", False), ("ND", False), ("A", False), ("AAD", True), ("ADM", True)], +) +def test_check_adn_ordering(adn, error): + if error: + with pytest.raises(ValueError): + check_adn_ordering(adn) + else: + check_adn_ordering(adn) + + +@pytest.mark.parametrize( + "parameter,expected_output", + [ + (5, (5, 5, 5)), + ((5, 4, 4), (5, 4, 4)), + ([5, 4], [(5, 5, 5), (4, 4, 4)]), + ([5, (4, 3, 3)], [(5, 5, 5), (4, 3, 3)]), + ((5, 5), None), + ([5, 5, 5], None), + ([5, 
(4, 4)], None), + (5.0, None), + ], +) +def test_check_conv_parameter(parameter, expected_output): + if expected_output: + assert ( + _check_conv_parameter(parameter, dim=3, n_layers=2, name="abc") + == expected_output + ) + else: + with pytest.raises(ValueError): + _check_conv_parameter(parameter, dim=3, n_layers=2, name="abc") + + +@pytest.mark.parametrize( + "parameter,expected_output", + [ + (5, [(5, 5, 5), (5, 5, 5)]), + ((5, 4, 4), [(5, 4, 4), (5, 4, 4)]), + ([5, 4], [(5, 5, 5), (4, 4, 4)]), + ([5, (4, 3, 3)], [(5, 5, 5), (4, 3, 3)]), + ], +) +def test_ensure_list_of_tuples(parameter, expected_output): + assert ( + ensure_list_of_tuples(parameter, dim=3, n_layers=2, name="abc") + == expected_output + ) + + +@pytest.mark.parametrize( + "indices,n_layers,error", + [ + ([0, 1, 2], 4, False), + ([0, 1, 2], 3, False), + ([-1, 1, 2], 3, False), + ([0, 1, 2], 2, True), + ([-2, 1, 2], 3, True), + ], +) +def test_check_pool_indices(indices, n_layers, error): + if error: + with pytest.raises(ValueError): + _ = check_pool_indices(indices, n_layers) + else: + check_pool_indices(indices, n_layers) + + +@pytest.mark.parametrize( + "inputs,error", + [ + (None, False), + ("abc", True), + ("batch", False), + ("group", True), + (("batch",), True), + (("batch", 3), True), + (("batch", {"eps": 0.1}), False), + (("group", {"num_groups": 2}), False), + (("group", {"num_groups": 2, "eps": 0.1}), False), + ], +) +def test_check_norm_layer(inputs, error): + if error: + with pytest.raises(ValueError): + _ = check_norm_layer(inputs) + else: + assert check_norm_layer(inputs) == inputs + + +@pytest.mark.parametrize( + "conv_args,error", + [(None, True), ({"kernel_size": 3}, True), ({"channels": [2]}, False)], +) +def test_check_conv_args(conv_args, error): + if error: + with pytest.raises(ValueError): + check_conv_args(conv_args) + else: + check_conv_args(conv_args) + + +@pytest.mark.parametrize( + "mlp_args,error", + [({"act": "tanh"}, True), ({"hidden_channels": [2]}, False)], +) +def 
test_check_mlp_args(mlp_args, error): + if error: + with pytest.raises(ValueError): + check_mlp_args(mlp_args) + else: + check_mlp_args(mlp_args) diff --git a/tests/unittests/monai_networks/nn/utils/test_shapes.py b/tests/unittests/monai_networks/nn/utils/test_shapes.py new file mode 100644 index 000000000..a0116ffe8 --- /dev/null +++ b/tests/unittests/monai_networks/nn/utils/test_shapes.py @@ -0,0 +1,281 @@ +import pytest +import torch + +from clinicadl.networks.nn.utils.shapes import ( + _calculate_adaptivepool_out_shape, + _calculate_avgpool_out_shape, + _calculate_maxpool_out_shape, + _calculate_upsample_out_shape, + calculate_conv_out_shape, + calculate_convtranspose_out_shape, + calculate_pool_out_shape, + calculate_unpool_out_shape, +) + +INPUT_1D = torch.randn(2, 1, 10) +INPUT_2D = torch.randn(2, 1, 32, 32) +INPUT_3D = torch.randn(2, 1, 20, 21, 22) + + +@pytest.mark.parametrize( + "input_tensor,kernel_size,stride,padding,dilation", + [ + (INPUT_3D, 7, 2, (1, 2, 3), 3), + (INPUT_2D, (5, 3), 1, 0, (2, 2)), + (INPUT_1D, 3, 1, 2, 1), + ], +) +def test_calculate_conv_out_shape(input_tensor, kernel_size, stride, padding, dilation): + in_shape = input_tensor.shape[2:] + dim = len(input_tensor.shape[2:]) + args = { + "in_channels": 1, + "out_channels": 1, + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + } + if dim == 1: + conv = torch.nn.Conv1d + elif dim == 2: + conv = torch.nn.Conv2d + else: + conv = torch.nn.Conv3d + + output_shape = conv(**args)(input_tensor).shape[2:] + assert ( + calculate_conv_out_shape(in_shape, kernel_size, stride, padding, dilation) + == output_shape + ) + + +@pytest.mark.parametrize( + "input_tensor,kernel_size,stride,padding,dilation,output_padding", + [ + (INPUT_3D, 7, 2, (1, 2, 3), 3, 0), + (INPUT_2D, (5, 3), 1, 0, (2, 2), (1, 0)), + (INPUT_1D, 3, 3, 2, 1, 2), + ], +) +def test_calculate_convtranspose_out_shape( + input_tensor, kernel_size, stride, padding, dilation, output_padding +): 
+ in_shape = input_tensor.shape[2:] + dim = len(input_tensor.shape[2:]) + args = { + "in_channels": 1, + "out_channels": 1, + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + "output_padding": output_padding, + } + if dim == 1: + conv = torch.nn.ConvTranspose1d + elif dim == 2: + conv = torch.nn.ConvTranspose2d + else: + conv = torch.nn.ConvTranspose3d + + output_shape = conv(**args)(input_tensor).shape[2:] + assert ( + calculate_convtranspose_out_shape( + in_shape, kernel_size, stride, padding, output_padding, dilation + ) + == output_shape + ) + + +@pytest.mark.parametrize( + "input_tensor,kernel_size,stride,padding,dilation,ceil_mode", + [ + (INPUT_3D, 7, 2, (1, 2, 3), 3, False), + (INPUT_3D, 7, 2, (1, 2, 3), 3, True), + (INPUT_2D, (5, 3), 1, 0, (2, 2), False), + (INPUT_2D, (5, 3), 1, 0, (2, 2), True), + (INPUT_1D, 2, 1, 1, 1, False), + (INPUT_1D, 2, 1, 1, 1, True), + ], +) +def test_calculate_maxpool_out_shape( + input_tensor, kernel_size, stride, padding, dilation, ceil_mode +): + in_shape = input_tensor.shape[2:] + dim = len(input_tensor.shape[2:]) + args = { + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + "ceil_mode": ceil_mode, + } + if dim == 1: + max_pool = torch.nn.MaxPool1d + elif dim == 2: + max_pool = torch.nn.MaxPool2d + else: + max_pool = torch.nn.MaxPool3d + + output_shape = max_pool(**args)(input_tensor).shape[2:] + assert ( + _calculate_maxpool_out_shape( + in_shape, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode + ) + == output_shape + ) + + +@pytest.mark.parametrize( + "input_tensor,kernel_size,stride,padding,ceil_mode", + [ + (INPUT_3D, 7, 2, (1, 2, 3), False), + (INPUT_3D, 7, 2, (1, 2, 3), True), + (INPUT_2D, (5, 3), 1, 0, False), + (INPUT_2D, (5, 3), 1, 0, True), + (INPUT_1D, 2, 1, 1, False), + (INPUT_1D, 2, 1, 1, True), + ( + INPUT_1D, + 2, + 3, + 1, + True, + ), # special case with ceil_mode (see: 
https://pytorch.org/docs/stable/generated/torch.nn.AvgPool1d.html) + ], +) +def test_calculate_avgpool_out_shape( + input_tensor, kernel_size, stride, padding, ceil_mode +): + in_shape = input_tensor.shape[2:] + dim = len(in_shape) + args = { + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "ceil_mode": ceil_mode, + } + if dim == 1: + avg_pool = torch.nn.AvgPool1d + elif dim == 2: + avg_pool = torch.nn.AvgPool2d + else: + avg_pool = torch.nn.AvgPool3d + output_shape = avg_pool(**args)(input_tensor).shape[2:] + assert ( + _calculate_avgpool_out_shape( + in_shape, kernel_size, stride, padding, ceil_mode=ceil_mode + ) + == output_shape + ) + + +@pytest.mark.parametrize( + "input_tensor,kwargs", + [ + (INPUT_3D, {"output_size": 1}), + (INPUT_2D, {"output_size": (1, 2)}), + (INPUT_1D, {"output_size": 3}), + ], +) +def test_calculate_adaptivepool_out_shape(input_tensor, kwargs): + in_shape = input_tensor.shape[2:] + dim = len(in_shape) + if dim == 1: + avg_pool = torch.nn.AdaptiveAvgPool1d + max_pool = torch.nn.AdaptiveMaxPool1d + elif dim == 2: + avg_pool = torch.nn.AdaptiveAvgPool2d + max_pool = torch.nn.AdaptiveMaxPool2d + else: + avg_pool = torch.nn.AdaptiveAvgPool3d + max_pool = torch.nn.AdaptiveMaxPool3d + + output_shape = max_pool(**kwargs)(input_tensor).shape[2:] + assert _calculate_adaptivepool_out_shape(in_shape, **kwargs) == output_shape + + output_shape = avg_pool(**kwargs)(input_tensor).shape[2:] + assert _calculate_adaptivepool_out_shape(in_shape, **kwargs) == output_shape + + +def test_calculate_pool_out_shape(): + in_shape = INPUT_3D.shape[2:] + assert calculate_pool_out_shape( + pool_mode="max", + in_shape=in_shape, + kernel_size=7, + stride=2, + padding=(1, 2, 3), + dilation=3, + ceil_mode=True, + ) == (3, 4, 6) + assert calculate_pool_out_shape( + pool_mode="avg", + in_shape=in_shape, + kernel_size=7, + stride=2, + padding=(1, 2, 3), + ceil_mode=True, + ) == (9, 10, 12) + assert calculate_pool_out_shape( + 
pool_mode="adaptiveavg", + in_shape=in_shape, + output_size=(3, 4, 5), + ) == (3, 4, 5) + assert calculate_pool_out_shape( + pool_mode="adaptivemax", + in_shape=in_shape, + output_size=1, + ) == (1, 1, 1) + with pytest.raises(ValueError): + calculate_pool_out_shape( + pool_mode="abc", + in_shape=in_shape, + kernel_size=7, + stride=2, + padding=(1, 2, 3), + dilation=3, + ceil_mode=True, + ) + + +@pytest.mark.parametrize( + "input_tensor,kwargs", + [ + (INPUT_3D, {"scale_factor": 2}), + (INPUT_2D, {"size": (40, 41)}), + (INPUT_2D, {"size": 40}), + (INPUT_2D, {"scale_factor": (3, 2)}), + (INPUT_1D, {"scale_factor": 2}), + ], +) +def test_calculate_upsample_out_shape(input_tensor, kwargs): + in_shape = input_tensor.shape[2:] + unpool = torch.nn.Upsample(**kwargs) + + output_shape = unpool(input_tensor).shape[2:] + assert _calculate_upsample_out_shape(in_shape, **kwargs) == output_shape + + +def test_calculate_unpool_out_shape(): + in_shape = INPUT_3D.shape[2:] + assert calculate_unpool_out_shape( + unpool_mode="convtranspose", + in_shape=in_shape, + kernel_size=5, + stride=1, + padding=0, + output_padding=0, + dilation=1, + ) == (24, 25, 26) + assert calculate_unpool_out_shape( + unpool_mode="upsample", + in_shape=in_shape, + scale_factor=2, + ) == (40, 42, 44) + with pytest.raises(ValueError): + calculate_unpool_out_shape( + unpool_mode="abc", + in_shape=in_shape, + ) diff --git a/tests/unittests/monai_networks/test_factory.py b/tests/unittests/monai_networks/test_factory.py index 28e8113fe..d52f67871 100644 --- a/tests/unittests/monai_networks/test_factory.py +++ b/tests/unittests/monai_networks/test_factory.py @@ -1,10 +1,15 @@ import pytest -from monai.networks.nets import ResNet -from monai.networks.nets.resnet import ResNetBottleneck -from torch.nn import Conv2d -from clinicadl.monai_networks import get_network -from clinicadl.monai_networks.config import create_network_config +from clinicadl.networks import ( + ImplementedNetworks, + get_network, + 
get_network_from_config, +) +from clinicadl.networks.config.autoencoder import AutoEncoderConfig +from clinicadl.networks.factory import _update_config_with_defaults +from clinicadl.networks.nn import AutoEncoder + +tested = [] @pytest.mark.parametrize( @@ -13,124 +18,285 @@ ( "AutoEncoder", { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "channels": [2, 2], - "strides": [1, 1], + "in_shape": (1, 64, 65), + "latent_size": 1, + "conv_args": {"channels": [2, 4]}, }, ), ( - "VarAutoEncoder", + "VAE", { - "spatial_dims": 3, - "in_shape": (1, 16, 16, 16), - "out_channels": 1, - "latent_size": 16, - "channels": [2, 2], - "strides": [1, 1], + "in_shape": (1, 64, 65), + "latent_size": 1, + "conv_args": {"channels": [2, 4]}, }, ), ( - "Regressor", + "CNN", { - "in_shape": (1, 16, 16, 16), - "out_shape": (1, 16, 16, 16), - "channels": [2, 2], - "strides": [1, 1], + "in_shape": (1, 64, 65), + "num_outputs": 1, + "conv_args": {"channels": [2, 4]}, }, ), ( - "Classifier", + "Generator", { - "in_shape": (1, 16, 16, 16), - "classes": 2, - "channels": [2, 2], - "strides": [1, 1], + "latent_size": 1, + "start_shape": (1, 5, 5), + "conv_args": {"channels": [2, 4]}, }, ), ( - "Discriminator", - {"in_shape": (1, 16, 16, 16), "channels": [2, 2], "strides": [1, 1]}, + "ConvDecoder", + { + "spatial_dims": 2, + "in_channels": 1, + "channels": [2, 4], + }, ), ( - "Critic", - {"in_shape": (1, 16, 16, 16), "channels": [2, 2], "strides": [1, 1]}, + "ConvEncoder", + { + "spatial_dims": 2, + "in_channels": 1, + "channels": [2, 4], + }, ), - ("DenseNet", {"spatial_dims": 3, "in_channels": 1, "out_channels": 1}), ( - "FullyConnectedNet", - {"in_channels": 3, "out_channels": 1, "hidden_channels": [2, 3]}, + "MLP", + { + "in_channels": 1, + "out_channels": 2, + "hidden_channels": [2, 4], + }, ), ( - "VarFullyConnectedNet", + "AttentionUNet", { + "spatial_dims": 2, "in_channels": 1, - "out_channels": 1, - "latent_size": 16, - "encode_channels": [2, 2], - "decode_channels": [2, 2], 
+ "out_channels": 2, }, ), ( - "Generator", + "UNet", { - "latent_shape": (3,), - "start_shape": (1, 16, 16, 16), - "channels": [2, 2], - "strides": [1, 1], + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 2, }, ), ( "ResNet", { - "block": "bottleneck", - "layers": (4, 4, 4, 4), - "block_inplanes": (5, 5, 5, 5), "spatial_dims": 2, + "in_channels": 1, + "num_outputs": 1, }, ), - ("ResNetFeatures", {"model_name": "resnet10"}), - ("SegResNet", {}), ( - "UNet", + "DenseNet", { - "spatial_dims": 3, + "spatial_dims": 2, "in_channels": 1, - "out_channels": 1, - "channels": [2, 2, 2], - "strides": [1, 1], + "num_outputs": 1, }, ), ( - "AttentionUnet", + "SEResNet", { - "spatial_dims": 3, + "spatial_dims": 2, "in_channels": 1, - "out_channels": 1, - "channels": [2, 2, 2], - "strides": [1, 1], + "num_outputs": 1, + }, + ), + ( + "ViT", + { + "in_shape": (1, 64, 65), + "patch_size": (4, 5), + "num_outputs": 1, + }, + ), + ( + "ResNet-18", + { + "num_outputs": 1, + }, + ), + ( + "ResNet-34", + { + "num_outputs": 1, + }, + ), + ( + "ResNet-50", + { + "num_outputs": 1, + }, + ), + ( + "ResNet-101", + { + "num_outputs": 1, + }, + ), + ( + "ResNet-152", + { + "num_outputs": 1, + "pretrained": True, + }, + ), + ( + "DenseNet-121", + { + "num_outputs": 1, + }, + ), + ( + "DenseNet-161", + { + "num_outputs": 1, + }, + ), + ( + "DenseNet-169", + { + "num_outputs": 1, + }, + ), + ( + "DenseNet-201", + { + "num_outputs": 1, + "pretrained": True, + }, + ), + ( + "SEResNet-50", + { + "num_outputs": 1, + }, + ), + ( + "SEResNet-101", + { + "num_outputs": 1, + }, + ), + ( + "SEResNet-152", + { + "num_outputs": 1, + }, + ), + ( + "ViT-B/16", + { + "num_outputs": 1, + "pretrained": True, + }, + ), + ( + "ViT-B/32", + { + "num_outputs": 1, + }, + ), + ( + "ViT-L/16", + { + "num_outputs": 1, + }, + ), + ( + "ViT-L/32", + { + "num_outputs": 1, }, ), - ("ViT", {"in_channels": 3, "img_size": 16, "patch_size": 4}), - ("ViTAutoEnc", {"in_channels": 3, "img_size": 16, "patch_size": 4}), ], 
) def test_get_network(network_name, params): - config = create_network_config(network_name)(**params) - network, updated_config = get_network(config) + tested.append(network_name) + _ = get_network(name=network_name, **params) + if network_name == "ViT-L/32": # the last one + assert set(tested) == set( + net.value for net in ImplementedNetworks + ) # check we haven't miss a network + + +def test_update_config_with_defaults(): + config = AutoEncoderConfig( + latent_size=1, + in_shape=(1, 10, 10), + conv_args={"channels": [1, 2], "dropout": 0.2}, + mlp_args={"hidden_channels": [5], "act": "relu"}, + ) + _update_config_with_defaults(config, AutoEncoder.__init__) + assert config.in_shape == (1, 10, 10) + assert config.latent_size == 1 + assert config.conv_args.channels == [1, 2] + assert config.conv_args.dropout == 0.2 + assert config.conv_args.act == "prelu" + assert config.mlp_args.hidden_channels == [5] + assert config.mlp_args.act == "relu" + assert config.mlp_args.norm == "batch" + assert config.out_channels is None + + +def test_parameters(): + net, updated_config = get_network( + "AutoEncoder", + return_config=True, + latent_size=1, + in_shape=(1, 10, 10), + conv_args={"channels": [1, 2], "dropout": 0.2}, + mlp_args={"hidden_channels": [5], "act": "relu"}, + ) + assert isinstance(net, AutoEncoder) + assert net.encoder.mlp.out_channels == 1 + assert net.encoder.mlp.hidden_channels == [5] + assert net.encoder.mlp.act == "relu" + assert net.encoder.mlp.norm == "batch" + assert net.in_shape == (1, 10, 10) + assert net.encoder.convolutions.channels == (1, 2) + assert net.encoder.convolutions.dropout == 0.2 + assert net.encoder.convolutions.act == "prelu" + + assert updated_config.in_shape == (1, 10, 10) + assert updated_config.latent_size == 1 + assert updated_config.conv_args.channels == [1, 2] + assert updated_config.conv_args.dropout == 0.2 + assert updated_config.conv_args.act == "prelu" + assert updated_config.mlp_args.hidden_channels == [5] + assert 
updated_config.mlp_args.act == "relu" + assert updated_config.mlp_args.norm == "batch" + assert updated_config.out_channels is None + + +def test_without_return(): + net = get_network( + "AutoEncoder", + return_config=False, + latent_size=1, + in_shape=(1, 10, 10), + conv_args={"channels": [1, 2]}, + ) + assert isinstance(net, AutoEncoder) - if network_name == "ResNet": - assert isinstance(network, ResNet) - assert isinstance(network.layer1[0], ResNetBottleneck) - assert len(network.layer1) == 4 - assert network.layer1[0].conv1.in_channels == 5 - assert isinstance(network.layer1[0].conv1, Conv2d) - assert updated_config.network == "ResNet" - assert updated_config.block == "bottleneck" - assert updated_config.layers == (4, 4, 4, 4) - assert updated_config.block_inplanes == (5, 5, 5, 5) - assert updated_config.spatial_dims == 2 - assert updated_config.conv1_t_size == 7 - assert updated_config.act == ("relu", {"inplace": True}) +def test_get_network_from_config(): + config = AutoEncoderConfig( + latent_size=1, + in_shape=(1, 10, 10), + conv_args={"channels": [1, 2], "dropout": 0.2}, + mlp_args={"hidden_channels": [5], "act": "relu"}, + ) + net, updated_config = get_network_from_config(config) + assert isinstance(net, AutoEncoder) + assert updated_config.conv_args.act == "prelu" + assert config.conv_args.act == "DefaultFromLibrary" diff --git a/tests/unittests/nn/blocks/test_decoder.py b/tests/unittests/nn/blocks/test_decoder.py index 01bf7aef1..38a2a9e28 100644 --- a/tests/unittests/nn/blocks/test_decoder.py +++ b/tests/unittests/nn/blocks/test_decoder.py @@ -1,7 +1,7 @@ import pytest import torch -import clinicadl.nn.blocks.decoder as decoder +import clinicadl.networks.old_network.nn.blocks.decoder as decoder @pytest.fixture diff --git a/tests/unittests/nn/blocks/test_encoder.py b/tests/unittests/nn/blocks/test_encoder.py index dcb676f96..9149b731a 100644 --- a/tests/unittests/nn/blocks/test_encoder.py +++ b/tests/unittests/nn/blocks/test_encoder.py @@ -1,7 +1,7 @@ 
import pytest import torch -import clinicadl.nn.blocks.encoder as encoder +import clinicadl.networks.old_network.nn.blocks.encoder as encoder @pytest.fixture diff --git a/tests/unittests/nn/blocks/test_residual.py b/tests/unittests/nn/blocks/test_residual.py index 302051ee3..7db9800ed 100644 --- a/tests/unittests/nn/blocks/test_residual.py +++ b/tests/unittests/nn/blocks/test_residual.py @@ -1,6 +1,6 @@ import torch -from clinicadl.nn.blocks import ResBlock +from clinicadl.networks.old_network.nn.blocks import ResBlock def test_resblock(): diff --git a/tests/unittests/nn/blocks/test_se.py b/tests/unittests/nn/blocks/test_se.py index 2444bcc3a..fba558ade 100644 --- a/tests/unittests/nn/blocks/test_se.py +++ b/tests/unittests/nn/blocks/test_se.py @@ -8,7 +8,7 @@ def input_3d(): def test_SE_Block(input_3d): - from clinicadl.nn.blocks import SE_Block + from clinicadl.networks.old_network.nn.blocks import SE_Block layer = SE_Block(num_channels=input_3d.shape[1], ratio_channel=4) out = layer(input_3d) @@ -16,7 +16,7 @@ def test_SE_Block(input_3d): def test_ResBlock_SE(input_3d): - from clinicadl.nn.blocks import ResBlock_SE + from clinicadl.networks.old_network.nn.blocks import ResBlock_SE layer = ResBlock_SE( num_channels=input_3d.shape[1], diff --git a/tests/unittests/nn/blocks/test_unet.py b/tests/unittests/nn/blocks/test_unet.py index 4e7170d77..e1f11ab5d 100644 --- a/tests/unittests/nn/blocks/test_unet.py +++ b/tests/unittests/nn/blocks/test_unet.py @@ -13,7 +13,7 @@ def skip_input(): def test_UNetDown(input_3d): - from clinicadl.nn.blocks import UNetDown + from clinicadl.networks.old_network.nn.blocks import UNetDown layer = UNetDown(in_size=input_3d.shape[1], out_size=8) out = layer(input_3d) @@ -21,7 +21,7 @@ def test_UNetDown(input_3d): def test_UNetUp(input_3d, skip_input): - from clinicadl.nn.blocks import UNetUp + from clinicadl.networks.old_network.nn.blocks import UNetUp layer = UNetUp(in_size=input_3d.shape[1] * 2, out_size=2) out = layer(input_3d, 
skip_input=skip_input) @@ -29,7 +29,7 @@ def test_UNetUp(input_3d, skip_input): def test_UNetFinalLayer(input_3d, skip_input): - from clinicadl.nn.blocks import UNetFinalLayer + from clinicadl.networks.old_network.nn.blocks import UNetFinalLayer layer = UNetFinalLayer(in_size=input_3d.shape[1] * 2, out_size=2) out = layer(input_3d, skip_input=skip_input) diff --git a/tests/unittests/nn/layers/factory/test_factories.py b/tests/unittests/nn/layers/factory/test_factories.py index 7036cc724..0c1af2da5 100644 --- a/tests/unittests/nn/layers/factory/test_factories.py +++ b/tests/unittests/nn/layers/factory/test_factories.py @@ -3,7 +3,7 @@ def test_get_conv_layer(): - from clinicadl.nn.layers.factory import get_conv_layer + from clinicadl.networks.old_network.nn.layers.factory import get_conv_layer assert get_conv_layer(2) == nn.Conv2d assert get_conv_layer(3) == nn.Conv3d @@ -12,7 +12,7 @@ def test_get_conv_layer(): def test_get_norm_layer(): - from clinicadl.nn.layers.factory import get_norm_layer + from clinicadl.networks.old_network.nn.layers.factory import get_norm_layer assert get_norm_layer("InstanceNorm", 2) == nn.InstanceNorm2d assert get_norm_layer("BatchNorm", 3) == nn.BatchNorm3d @@ -20,8 +20,8 @@ def test_get_norm_layer(): def test_get_pool_layer(): - from clinicadl.nn.layers import PadMaxPool3d - from clinicadl.nn.layers.factory import get_pool_layer + from clinicadl.networks.old_network.nn.layers import PadMaxPool3d + from clinicadl.networks.old_network.nn.layers.factory import get_pool_layer assert get_pool_layer("MaxPool", 2) == nn.MaxPool2d assert get_pool_layer("PadMaxPool", 3) == PadMaxPool3d diff --git a/tests/unittests/nn/layers/test_layers.py b/tests/unittests/nn/layers/test_layers.py index e07eb1cf6..633beb423 100644 --- a/tests/unittests/nn/layers/test_layers.py +++ b/tests/unittests/nn/layers/test_layers.py @@ -1,7 +1,7 @@ import pytest import torch -import clinicadl.nn.layers as layers +import clinicadl.networks.old_network.nn.layers as layers 
@pytest.fixture diff --git a/tests/unittests/nn/networks/factory/test_ae_factory.py b/tests/unittests/nn/networks/factory/test_ae_factory.py index a4fe1a762..8f997b874 100644 --- a/tests/unittests/nn/networks/factory/test_ae_factory.py +++ b/tests/unittests/nn/networks/factory/test_ae_factory.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -from clinicadl.nn.layers import ( +from clinicadl.networks.old_network.nn.layers import ( PadMaxPool2d, PadMaxPool3d, ) @@ -58,8 +58,8 @@ def __init__(self, input_size): @pytest.mark.parametrize("input, cnn", [("input_3d", "cnn3d"), ("input_2d", "cnn2d")]) def test_autoencoder_from_cnn(input, cnn, request): - from clinicadl.nn.networks.ae import AE - from clinicadl.nn.networks.factory import autoencoder_from_cnn + from clinicadl.networks.old_network.nn.networks.ae import AE + from clinicadl.networks.old_network.nn.networks.factory import autoencoder_from_cnn input_ = request.getfixturevalue(input) cnn = request.getfixturevalue(cnn)(input_size=input_.shape[1:]) diff --git a/tests/unittests/nn/networks/factory/test_resnet_factory.py b/tests/unittests/nn/networks/factory/test_resnet_factory.py index 1468d37ad..ed2d17610 100644 --- a/tests/unittests/nn/networks/factory/test_resnet_factory.py +++ b/tests/unittests/nn/networks/factory/test_resnet_factory.py @@ -5,7 +5,7 @@ def test_ResNetDesigner(): from torchvision.models.resnet import BasicBlock - from clinicadl.nn.networks.factory import ResNetDesigner + from clinicadl.networks.old_network.nn.networks.factory import ResNetDesigner input_ = torch.randn(2, 3, 100, 100) @@ -43,7 +43,7 @@ def forward(self, x): def test_ResNetDesigner3D(): - from clinicadl.nn.networks.factory import ResNetDesigner3D + from clinicadl.networks.old_network.nn.networks.factory import ResNetDesigner3D input_ = torch.randn(2, 3, 100, 100, 100) diff --git a/tests/unittests/nn/networks/factory/test_secnn_factory.py b/tests/unittests/nn/networks/factory/test_secnn_factory.py index 96be92620..c5650dfc7 
100644 --- a/tests/unittests/nn/networks/factory/test_secnn_factory.py +++ b/tests/unittests/nn/networks/factory/test_secnn_factory.py @@ -3,7 +3,7 @@ def test_SECNNDesigner3D(): - from clinicadl.nn.networks.factory import SECNNDesigner3D + from clinicadl.networks.old_network.nn.networks.factory import SECNNDesigner3D input_ = torch.randn(2, 3, 100, 100, 100) diff --git a/tests/unittests/nn/networks/test_ae.py b/tests/unittests/nn/networks/test_ae.py index 9c6152d35..0f86ad24d 100644 --- a/tests/unittests/nn/networks/test_ae.py +++ b/tests/unittests/nn/networks/test_ae.py @@ -1,7 +1,7 @@ import pytest import torch -import clinicadl.nn.networks.ae as ae +import clinicadl.networks.old_network.nn.networks.ae as ae @pytest.mark.parametrize("network", [net.value for net in ae.AE2d]) diff --git a/tests/unittests/nn/networks/test_cnn.py b/tests/unittests/nn/networks/test_cnn.py index 3f6a0cb87..b09ff7bbf 100644 --- a/tests/unittests/nn/networks/test_cnn.py +++ b/tests/unittests/nn/networks/test_cnn.py @@ -1,7 +1,7 @@ import pytest import torch -import clinicadl.nn.networks.cnn as cnn +import clinicadl.networks.old_network.nn.networks.cnn as cnn @pytest.fixture diff --git a/tests/unittests/nn/networks/test_ssda.py b/tests/unittests/nn/networks/test_ssda.py deleted file mode 100644 index 06da85ff2..000000000 --- a/tests/unittests/nn/networks/test_ssda.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch - -from clinicadl.nn.networks.ssda import Conv5_FC3_SSDA - - -def test_UNet(): - input_ = torch.randn(2, 1, 64, 63, 62) - network = Conv5_FC3_SSDA(input_size=(1, 64, 63, 62), output_size=3) - output = network(input_) - for out in output: - assert out.shape == torch.Size((2, 3)) diff --git a/tests/unittests/nn/networks/test_unet.py b/tests/unittests/nn/networks/test_unet.py index ba0408cdb..4279205b8 100644 --- a/tests/unittests/nn/networks/test_unet.py +++ b/tests/unittests/nn/networks/test_unet.py @@ -1,6 +1,6 @@ import torch -from clinicadl.nn.networks.unet import UNet +from 
clinicadl.networks.old_network.nn.networks.unet import UNet def test_UNet(): diff --git a/tests/unittests/nn/networks/test_vae.py b/tests/unittests/nn/networks/test_vae.py index 308a2f185..890b0eacc 100644 --- a/tests/unittests/nn/networks/test_vae.py +++ b/tests/unittests/nn/networks/test_vae.py @@ -1,7 +1,7 @@ import pytest import torch -import clinicadl.nn.networks.vae as vae +import clinicadl.networks.old_network.nn.networks.vae as vae @pytest.fixture diff --git a/tests/unittests/nn/test_utils.py b/tests/unittests/nn/test_utils.py index bcd379613..f70b8c518 100644 --- a/tests/unittests/nn/test_utils.py +++ b/tests/unittests/nn/test_utils.py @@ -3,7 +3,7 @@ def test_compute_output_size(): - from clinicadl.nn.utils import compute_output_size + from clinicadl.networks.old_network.nn.utils import compute_output_size input_2d = torch.randn(3, 2, 100, 100) input_3d = torch.randn(3, 1, 100, 100, 100) diff --git a/tests/unittests/optim/early_stopping/test_config.py b/tests/unittests/optim/early_stopping/test_config.py index 574887a6f..4c12d8208 100644 --- a/tests/unittests/optim/early_stopping/test_config.py +++ b/tests/unittests/optim/early_stopping/test_config.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from clinicadl.optim.early_stopping import EarlyStoppingConfig +from clinicadl.optimization.early_stopping import EarlyStoppingConfig def test_EarlyStoppingConfig(): diff --git a/tests/unittests/optim/early_stopping/test_early_stopper.py b/tests/unittests/optim/early_stopping/test_early_stopper.py index 13f0d9f9c..91bbd7e09 100644 --- a/tests/unittests/optim/early_stopping/test_early_stopper.py +++ b/tests/unittests/optim/early_stopping/test_early_stopper.py @@ -1,6 +1,6 @@ import numpy as np -from clinicadl.optim.early_stopping import EarlyStopping, EarlyStoppingConfig +from clinicadl.optimization.early_stopping import EarlyStopping, EarlyStoppingConfig def test_EarlyStopping(): diff --git a/tests/unittests/optim/lr_scheduler/test_config.py 
b/tests/unittests/optim/lr_scheduler/test_config.py index dbf96ccc8..a2233055c 100644 --- a/tests/unittests/optim/lr_scheduler/test_config.py +++ b/tests/unittests/optim/lr_scheduler/test_config.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from clinicadl.optim.lr_scheduler.config import ( +from clinicadl.optimization.lr_scheduler.config import ( ConstantLRConfig, LinearLRConfig, MultiStepLRConfig, diff --git a/tests/unittests/optim/lr_scheduler/test_factory.py b/tests/unittests/optim/lr_scheduler/test_factory.py index cffb3d138..559f078ab 100644 --- a/tests/unittests/optim/lr_scheduler/test_factory.py +++ b/tests/unittests/optim/lr_scheduler/test_factory.py @@ -4,7 +4,7 @@ from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau -from clinicadl.optim.lr_scheduler import ( +from clinicadl.optimization.lr_scheduler import ( ImplementedLRScheduler, create_lr_scheduler_config, get_lr_scheduler, diff --git a/tests/unittests/optim/optimizer/test_config.py b/tests/unittests/optim/optimizer/test_config.py index bf1dbcd8f..7888f4eb8 100644 --- a/tests/unittests/optim/optimizer/test_config.py +++ b/tests/unittests/optim/optimizer/test_config.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from clinicadl.optim.optimizer.config import ( +from clinicadl.optimization.optimizer.config import ( AdadeltaConfig, AdagradConfig, AdamConfig, diff --git a/tests/unittests/optim/optimizer/test_factory.py b/tests/unittests/optim/optimizer/test_factory.py index 47b44a00a..12507134c 100644 --- a/tests/unittests/optim/optimizer/test_factory.py +++ b/tests/unittests/optim/optimizer/test_factory.py @@ -5,12 +5,12 @@ import torch.nn as nn from torch.optim import Adagrad -from clinicadl.optim.optimizer import ( +from clinicadl.optimization.optimizer import ( ImplementedOptimizer, create_optimizer_config, get_optimizer, ) -from clinicadl.optim.optimizer.factory import ( +from 
clinicadl.optimization.optimizer.factory import ( _get_params_in_group, _get_params_not_in_group, _regroup_args, diff --git a/tests/unittests/optim/optimizer/test_utils.py b/tests/unittests/optim/optimizer/test_utils.py index afa06a5d0..ba480af94 100644 --- a/tests/unittests/optim/optimizer/test_utils.py +++ b/tests/unittests/optim/optimizer/test_utils.py @@ -27,7 +27,7 @@ def network(): def test_get_params_in_groups(network): import torch - from clinicadl.optim.optimizer.utils import get_params_in_groups + from clinicadl.optimization.optimizer.utils import get_params_in_groups iterator, list_layers = get_params_in_groups(network, "dense1") assert next(iter(iterator)).shape == torch.Size((10, 10)) @@ -77,7 +77,7 @@ def test_get_params_in_groups(network): def test_find_params_not_in_group(network): import torch - from clinicadl.optim.optimizer.utils import get_params_not_in_groups + from clinicadl.optimization.optimizer.utils import get_params_not_in_groups iterator, list_layers = get_params_not_in_groups( network, diff --git a/tests/unittests/optim/test_config.py b/tests/unittests/optim/test_config.py index 9b980bb84..c74a13265 100644 --- a/tests/unittests/optim/test_config.py +++ b/tests/unittests/optim/test_config.py @@ -1,4 +1,4 @@ -from clinicadl.optim import OptimizationConfig +from clinicadl.optimization import OptimizationConfig def test_OptimizationConfig(): diff --git a/tests/unittests/train/test_utils.py b/tests/unittests/train/test_utils.py index 6b33787eb..2914d2d9b 100644 --- a/tests/unittests/train/test_utils.py +++ b/tests/unittests/train/test_utils.py @@ -7,7 +7,6 @@ expected_classification = { "architecture": "default", "multi_network": False, - "ssda_network": False, "dropout": 0.0, "latent_space_size": 128, "feature_size": 1024, @@ -65,7 +64,6 @@ expected_regression = { "architecture": "default", "multi_network": False, - "ssda_network": False, "dropout": 0.0, "latent_space_size": 128, "feature_size": 1024, @@ -121,7 +119,6 @@ 
expected_reconstruction = { "architecture": "default", "multi_network": False, - "ssda_network": False, "dropout": 0.0, "latent_space_size": 128, "feature_size": 1024, diff --git a/tests/unittests/train/trainer/test_training_config.py b/tests/unittests/train/trainer/test_training_config.py index 503b88ddf..061c64afe 100644 --- a/tests/unittests/train/trainer/test_training_config.py +++ b/tests/unittests/train/trainer/test_training_config.py @@ -3,11 +3,10 @@ import pytest from pydantic import ValidationError -from clinicadl.caps_dataset.data_config import DataConfig -from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig -from clinicadl.config.config.ssda import SSDAConfig -from clinicadl.network.config import NetworkConfig -from clinicadl.splitter.validation import ValidationConfig +from clinicadl.dataset.data_config import DataConfig +from clinicadl.dataset.dataloader_config import DataLoaderConfig +from clinicadl.networks.old_network.config import NetworkConfig +from clinicadl.predictor.validation import ValidationConfig from clinicadl.trainer.transfer_learning import TransferLearningConfig from clinicadl.transforms.config import TransformsConfig @@ -70,31 +69,6 @@ def test_model_config(): ) -def test_ssda_config(caps_example): - preprocessing_json_target = ( - caps_example / "tensor_extraction" / "preprocessing.json" - ) - c = SSDAConfig( - ssda_network=True, - preprocessing_json_target=preprocessing_json_target, - ) - expected_preprocessing_dict = { - "preprocessing": "t1-linear", - "mode": "image", - "use_uncropped_image": False, - "prepare_dl": False, - "extract_json": "t1-linear_mode-image.json", - "file_type": { - "pattern": "*space-MNI152NLin2009cSym_desc-Crop_res-1x1x1_T1w.nii.gz", - "description": "T1W Image registered using t1-linear and cropped (matrix size 169\u00d7208\u00d7179, 1 mm isotropic voxels)", - "needed_pipeline": "t1-linear", - }, - } - assert c.preprocessing_dict_target == expected_preprocessing_dict - c = SSDAConfig() - 
assert c.preprocessing_dict_target == {} - - def test_transferlearning_config(): c = TransferLearningConfig(transfer_path=False) assert c.transfer_path is None @@ -142,12 +116,10 @@ def network_task(self) -> str: params=[ {"gpu": "abc"}, {"n_splits": -1}, - {"optimizer": "abc"}, {"data_augmentation": ("abc",)}, {"diagnoses": "AD"}, {"batch_size": 0}, {"size_reduction_factor": 1}, - {"learning_rate": 0.0}, {"split": [-1]}, {"tolerance": -0.01}, ] @@ -161,7 +133,6 @@ def good_inputs(dummy_arguments): options = { "gpu": False, "n_splits": 7, - "optimizer": "Adagrad", "data_augmentation": ("Smoothing",), "diagnoses": ("AD",), "batch_size": 1, @@ -182,12 +153,10 @@ def test_passes_validations(good_inputs, training_config): c = training_config(**good_inputs) assert not c.computational.gpu assert c.split.n_splits == 7 - assert c.optimizer.optimizer == "Adagrad" assert c.transforms.data_augmentation == ("Smoothing",) assert c.data.diagnoses == ("AD",) assert c.dataloader.batch_size == 1 assert c.transforms.size_reduction_factor == 5 - assert c.optimizer.learning_rate == 1e-1 assert c.split.split == (0,) assert c.early_stopping.tolerance == 0.0 diff --git a/tests/unittests/utils/test_clinica_utils.py b/tests/unittests/utils/test_clinica_utils.py index 7b87ceacb..087441ff3 100644 --- a/tests/unittests/utils/test_clinica_utils.py +++ b/tests/unittests/utils/test_clinica_utils.py @@ -21,8 +21,8 @@ def test_pet_linear_nii( tracer, suvr_reference_region, uncropped_image, expected_pattern ): - from clinicadl.caps_dataset.preprocessing.config import PETPreprocessingConfig - from clinicadl.caps_dataset.preprocessing.utils import pet_linear_nii + from clinicadl.dataset.config.preprocessing import PETPreprocessingConfig + from clinicadl.dataset.utils import pet_linear_nii from clinicadl.utils.iotools.clinica_utils import FileType config = PETPreprocessingConfig(