From b0d54d06e6d550c562f8b1028d15e1ea225d69bc Mon Sep 17 00:00:00 2001 From: camillebrianceau <57992134+camillebrianceau@users.noreply.github.com> Date: Tue, 29 Oct 2024 17:46:13 +0100 Subject: [PATCH] Update clinicadl/API_test_v2.py --- clinicadl/API_test_v2.py | 142 +++++++++++++++++++++------------------ 1 file changed, 76 insertions(+), 66 deletions(-) diff --git a/clinicadl/API_test_v2.py b/clinicadl/API_test_v2.py index de24565c7..0bf730d3c 100644 --- a/clinicadl/API_test_v2.py +++ b/clinicadl/API_test_v2.py @@ -1,102 +1,112 @@ -# %% class -class MapsIO: +from pathlib import Path +from clinicadl.caps_dataset2.config.preprocessing import PreprocessingConfig +from clinicadl.caps_dataset2.config.extraction import ExtractionConfig +from clinicadl.caps_dataset2.data import CapsDatasetRoi, CapsDatasetPatch, CapsDatasetSlice +from clinicadl.transforms.config import TransformsConfig +import torchio +from clinicadl import tsvtools +from clinicadl.trainer.trainer import Trainer + +class ExperimentManager: pass - -class CapsDataset: +class CapsReader: pass - -class Splitter: +class Transforms: pass - -class ClinicaDLModels: +class Predictor: pass +class ClinicaDLModel: + pass -class Networks: +class KFolder: pass +def get_loss_function(): + pass -class VAE(Networks): +def get_network_from_config(): pass +def create_network_config(): + pass -class Optimizer: +def get_single_split(): pass +# Create the Maps Manager / Read/write manager / +maps_path = Path("/") +manager = ExperimentManager(maps_path, overwrite = False) -class Loss: - pass +caps_directory = Path("caps_directory") # output of clinica pipelines +caps_reader = CapsReader(caps_directory, manager= manager) +preprocessing_1: PreprocessingConfig = caps_reader.get_preprocessing("t1-linear") +extraction_1: ExtractionConfig = caps_reader.extract_slice(preprocessing=preprocessing_1, arg_slice = 2) +transforms_1 = Transforms(data_augmentation = [torchio.t1, torchio.t2], image_transforms= [torchio.t1, torchio.t2], object_transforms=[torchio.t1, torchio.t2]) # not mandatory -class Metrics: - pass +preprocessing_2: PreprocessingConfig = caps_reader.get_preprocessing("pet-linear") +extraction_2: ExtractionConfig= caps_reader.extract_patch(preprocessing=preprocessing_2, arg_patch = 2) +transforms_2 = Transforms(data_augmentation = [torchio.t2], image_transforms= [torchio.t1], object_transforms=[torchio.t1, torchio.t2]) +sub_ses_tsv = Path("") +split_dir = tsvtools.split_tsv(sub_ses_tsv) #-> creer un test.tsv et un train.tsv -class Trainer: - pass +dataset_t1_roi: CapsDatasetRoi = caps_reader.get_dataset( extraction = extraction_1, preprocessing = preprocessing_1, sub_ses_tsv = split_dir / "train.tsv", transforms = transforms_1) +dataset_pet_patch: CapsDatasetPatch = caps_reader.get_dataset(extraction = extraction_2, preprocessing= preprocessing_2, sub_ses_tsv = split_dir / "train.tsv", transforms = transforms_2) +dataset_multi_modality_multi_extract = concat_dataset(dataset_t1, dataset_pet) # 2 train.tsv en entrée qu'il faut concat et pareil pour les transforms à faire attention -class Validator: - pass +config_file = Path("config_file") +trainer = Trainer.from_json(config_file=config_file, manager=manager) +# CAS CROSS-VALIDATION +splitter = KFolder(n_splits = 3, caps_dataset=dataset_multi_modality_multi_extract, manager = manager) -# %% maps -maps = MapsIO("/path/to/maps") # Crée un dossier +for split in splitter.split_iterator(split_list = [0,1]): + # bien définir ce qu'il y a dans l'objet split + + loss, loss_config = get_loss_function(CrossEntropyLossConfig()) + network_config = create_network_config(ImplementedNetworks.CNN)(in_shape = [2,2,2], num_outputs = 1, conv_args = ConvEncoderOptions(channels = [3, 2, 2])) + network = get_network_from_config(network_config) + + model = ClinicaDLModel( + network= network, + loss=loss, + optimizer= AdamConfig() + ) -# %% Dataset -DataConfig = { - "caps_dir": "", - "tsv": "", - "mode": "", -} -capsdataset = CapsDataset(DataConfig, maps) # torch.dataset + trainer.train(model, split) + # le trainer va instancier un predictor/valdiator dans le train ou dans le init -# %% Model -network = VAE() # nn.module -loss = Loss() -optimizer = ClinicaDLOptim( - Adam() -) # get_optimizer({"name": "Adam", "par1": 0.5}) # torch.optim -# model = ClinicaDLModels( -# network, -# loss, -# optimizer, -# ) -# %% Cross val -SplitConfig = SplitterConfig() -splitter = Splitter(SplitConfig, capsdataset) +# CAS SINGLE SPLIT +split = get_single_split(n_subject_validation = 0, caps_dataset=dataset_multi_modality_multi_extract, manager = manager) -# %% Metrics -metrics1 = Metrics("MAE") # monai.metric -metrics2 = Metrics("MSE") # monai.metric +loss, loss_config = get_loss_function(CrossEntropyLossConfig()) +network_config = create_network_config(ImplementedNetworks.CNN)(in_shape = [2,2,2], num_outputs = 1, conv_args = ConvEncoderOptions(channels = [3, 2, 2])) +network = get_network_from_config(network_config) + +model = ClinicaDLModel( + network= network, + loss=loss, + optimizer= AdamConfig() +) +trainer.train(model, split) +# le trainer va instancier un predictor/valdiator dans le train ou dans le init -# %% Option 1 -for split in splitter.iterate(): - trainer = Trainer(split, maps, (optimizer)) - validator = Validator(split, [metrics1, metrics2], maps) - trainer.train(validator, model) +# TEST +preprocessing_test: PreprocessingConfig = caps_reader.get_preprocessing("pet-linear") +extraction_test: ExtractionConfig= caps_reader.extract_patch(preprocessing=preprocessing_2, arg_patch = 2) +transforms_test = Transforms(data_augmentation = [torchio.t2], image_transforms= [torchio.t1], object_transforms=[torchio.t1, torchio.t2]) -# %% Option 2 -val = Validator([metrics1, metrics2], maps) -trainer = Trainer(validator, maps) -for split in splitter.iterate(): - trainer.train(model, split) +dataset_test: CapsDatasetROI = caps_reader.get_dataset( extraction = extraction_test, preprocessing = preprocessing_test, sub_ses_tsv = split_dir / "test.tsv", transforms = transforms_test) -# %% Option 3 -trainer = Trainer( - maps, [metrics1, metrics2] -) # Initialise un maps manager + initialise un validator -for split in splitter.iterate(): - model = ClinicaDLModels( - network, - loss, - optimizer, - ) - trainer.train(model, split) +predictor = Predictor(manager= manager) +predictor.predict(dataset_test= dataset_test, split = 2)