Skip to content

Commit

Permalink
First draft for KFold (#684)
Browse files Browse the repository at this point in the history
* make_split
* make_kfold
* KFold
* SingleSplit

---------

Co-authored-by: camillebrianceau <[email protected]>
Co-authored-by: camillebrianceau <[email protected]>
  • Loading branch information
3 people authored Dec 18, 2024
1 parent 464dddf commit 621cc96
Show file tree
Hide file tree
Showing 54 changed files with 2,923 additions and 546 deletions.
137 changes: 67 additions & 70 deletions clinicadl/API/complicated_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,15 @@

import torchio.transforms as transforms

from clinicadl.dataset.caps_dataset import (
CapsDatasetPatch,
CapsDatasetRoi,
CapsDatasetSlice,
)
from clinicadl.dataset.caps_reader import CapsReader
from clinicadl.dataset.concat import ConcatDataset
from clinicadl.dataset.config.extraction import ExtractionConfig
from clinicadl.dataset.config.preprocessing import (
PreprocessingConfig,
T1PreprocessingConfig,
)
from clinicadl.dataset.dataloader_config import DataLoaderConfig
from clinicadl.dataset.datasets.caps_dataset import CapsDataset
from clinicadl.dataset.datasets.concat import ConcatDataset
from clinicadl.dataset.preprocessing import (
PreprocessingCustom,
PreprocessingPET,
PreprocessingT1,
)
from clinicadl.dataset.readers.caps_reader import CapsReader
from clinicadl.experiment_manager.experiment_manager import ExperimentManager
from clinicadl.losses.config import CrossEntropyLossConfig
from clinicadl.losses.factory import get_loss_function
Expand All @@ -31,6 +27,7 @@
from clinicadl.splitter.kfold import KFolder
from clinicadl.splitter.split import get_single_split, split_tsv
from clinicadl.trainer.trainer import Trainer
from clinicadl.transforms.extraction import ROI, BaseExtraction, Image, Patch, Slice
from clinicadl.transforms.transforms import Transforms

# Create the Maps Manager / Read/write manager /
Expand All @@ -40,60 +37,47 @@
) # a ajouter dans le manager: mlflow/ profiler/ etc ...

caps_directory = Path("caps_directory") # output of clinica pipelines
caps_reader = CapsReader(caps_directory, manager=manager)

preprocessing_1 = caps_reader.get_preprocessing("t1-linear")
caps_reader.prepare_data(
preprocessing=preprocessing_1, data_tsv=Path(""), n_proc=2
) # don't return anything -> just extract the image tensor and compute some information for each images

transforms_1 = Transforms(
object_augmentation=[transforms.Crop, transforms.Transform],
image_augmentation=[transforms.Crop, transforms.Transform],
extraction=ExtractionPatchConfig(patch_size=3),
image_transforms=[transforms.Blur, transforms.Ghosting],
object_transforms=[transforms.BiasField, transforms.Motion],
) # not mandatory

preprocessing_2 = caps_reader.get_preprocessing("pet-linear")
caps_reader.prepare_data(
preprocessing=preprocessing_1, data_tsv=Path(""), n_proc=2
) # to extract the tensor of the PET file this time
transforms_2 = Transforms(
object_augmentation=[transforms.Crop, transforms.Transform],
image_augmentation=[transforms.Crop, transforms.Transform],
extraction=ExtractionSliceConfig(),
image_transforms=[transforms.Blur, transforms.Ghosting],
object_transforms=[transforms.BiasField, transforms.Motion],

sub_ses_t1 = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_t1.tsv")
preprocessing_t1 = PreprocessingT1()
transforms_image = Transforms(
image_augmentation=[transforms.RandomMotion()],
extraction=Image(),
image_transforms=[transforms.Blur((0.5, 0.6, 0.3))],
)

print("T1 and image ")

dataset_t1_image = CapsDataset(
caps_directory=caps_directory,
data=sub_ses_t1,
preprocessing=preprocessing_t1,
transforms=transforms_image,
)
dataset_t1_image.prepare_data(n_proc=2) # to extract the tensor of the T1 file


sub_ses_pet_45 = Path(
"/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_pet_18FAV45.tsv"
)
preprocessing_pet_45 = PreprocessingPET(tracer="18FAV45", suvr_reference_region="pons2")

sub_ses_tsv_1 = Path("")
split_dir_1 = split_tsv(sub_ses_tsv_1) # -> creer un test.tsv et un train.tsv

sub_ses_tsv_2 = Path("")
split_dir_2 = split_tsv(sub_ses_tsv_2)

dataset_t1_roi = caps_reader.get_dataset(
preprocessing=preprocessing_1,
sub_ses_tsv=split_dir_1 / "train.tsv",
transforms=transforms_1,
) # do we give config or object for transforms ?
dataset_pet_patch = caps_reader.get_dataset(
preprocessing=preprocessing_2,
sub_ses_tsv=split_dir_2 / "train.tsv",
transforms=transforms_2,
dataset_pet_image = CapsDataset(
caps_directory=caps_directory,
data=sub_ses_pet_45,
preprocessing=preprocessing_pet_45,
transforms=transforms_image,
)
dataset_t1_image.prepare_data(n_proc=2) # to extract the tensor of the PET file


dataset_multi_modality_multi_extract = ConcatDataset(
dataset_multi_modality = ConcatDataset(
[
dataset_t1_roi,
dataset_pet_patch,
caps_reader.get_dataset_from_json(json_path=Path("dataset.json")),
dataset_t1_image,
dataset_pet_image,
]
) # 3 train.tsv en entrée qu'il faut concat et pareil pour les transforms à faire attention

# TODO : think about adding transforms in extract_json


config_file = Path("config_file")
trainer = Trainer.from_json(config_file=config_file, manager=manager)
Expand All @@ -106,6 +90,26 @@

dataloader_config = DataLoaderConfig(n_procs=3, batch_size=10)


# CAS 1

# Prérequis : déjà avoir des fichiers avec les listes train et validation
split_dir = make_kfold(
"dataset.tsv"
) # lit dataset.tsv => fait le kfold => ecrit la sortie dans split_dir
splitter = KFolder(
dataset_multi_modality, split_dir
) # c'est plutôt un iterable de dataloader

# CAS 2
splitter = KFolder(caps_dataset=dataset_t1_image)
splitter.make_splits(n_splits=3)
splitter.write(split_dir)

# or
splitter = KFolder(caps_dataset=dataset_t1_image)
splitter.read(split_dir)

for split in splitter.get_splits(splits=(0, 3, 4), dataloader_config=dataloader_config):
# bien définir ce qu'il y a dans l'objet split

Expand All @@ -125,19 +129,12 @@

# TEST

preprocessing_test = caps_reader.get_preprocessing("pet-linear")
transforms_test = Transforms(
object_augmentation=[transforms.Crop, transforms.Transform],
image_augmentation=[transforms.Crop, transforms.Transform],
extraction=ExtractioImageConfig(),
image_transforms=[transforms.Blur, transforms.Ghosting],
object_transforms=[transforms.BiasField, transforms.Motion],
)

dataset_test = caps_reader.get_dataset(
preprocessing=preprocessing_test,
sub_ses_tsv=split_dir_1 / "test.tsv", # test only on data from the first dataset
transforms=transforms_test,
dataset_test = CapsDataset(
caps_directory=caps_directory,
preprocessing=preprocessing_t1,
sub_ses_tsv=Path("test.tsv"), # test only on data from the first dataset
transforms=transforms_image,
)

predictor = Predictor(model=model, manager=manager)
Expand Down
74 changes: 16 additions & 58 deletions clinicadl/API/cross_val.py
Original file line number Diff line number Diff line change
@@ -1,80 +1,38 @@
from pathlib import Path

import torchio.transforms as transforms

from clinicadl.dataset.caps_dataset import (
CapsDatasetPatch,
CapsDatasetRoi,
CapsDatasetSlice,
)
from clinicadl.dataset.caps_reader import CapsReader
from clinicadl.dataset.concat import ConcatDataset
from clinicadl.dataset.config.extraction import ExtractionConfig
from clinicadl.dataset.config.preprocessing import (
PreprocessingConfig,
T1PreprocessingConfig,
)
from clinicadl.dataset.dataloader_config import DataLoaderConfig
from clinicadl.dataset.datasets.caps_dataset import CapsDataset
from clinicadl.experiment_manager.experiment_manager import ExperimentManager
from clinicadl.losses.config import CrossEntropyLossConfig
from clinicadl.losses.factory import get_loss_function
from clinicadl.model.clinicadl_model import ClinicaDLModel
from clinicadl.networks.config import ImplementedNetworks
from clinicadl.networks.factory import (
ConvEncoderOptions,
create_network_config,
get_network_from_config,
)
from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig
from clinicadl.optimization.optimizer.factory import get_optimizer
from clinicadl.predictor.predictor import Predictor
from clinicadl.splitter.kfold import KFolder
from clinicadl.splitter.split import get_single_split, split_tsv
from clinicadl.splitter.new_splitter.dataloader import DataLoaderConfig
from clinicadl.splitter.new_splitter.splitter.kfold import KFold
from clinicadl.trainer.trainer import Trainer
from clinicadl.transforms.config import TransformsConfig

# SIMPLE EXPERIMENT WITH A CAPS ALREADY EXISTING

maps_path = Path("/")
manager = ExperimentManager(maps_path, overwrite=False)

dataset_t1_image = CapsDatasetPatch.from_json(
extraction=Path("test.json"),
sub_ses_tsv=Path("split_dir") / "train.tsv",
)
dataset_t1_image = CapsDataset.from_json(Path("json_path.json"))

config_file = Path("config_file")
trainer = Trainer.from_json(
config_file=config_file, manager=manager
) # gpu, amp, fsdp, seed

# CAS CROSS-VALIDATION
splitter = KFolder(caps_dataset=dataset_t1_image, manager=manager)
split_dir = splitter.make_splits(
n_splits=3,
output_dir=Path(""),
data_tsv=Path("labels.tsv"),
subset_name="validation",
stratification="",
) # Optional data tsv and output_dir
# n_splits must be >1
# for the single split case, this method output a path to the directory containing the train and test tsv files so we should have the same output here
splitter = KFold(dataset=dataset_t1_image)
splitter.make_splits(n_splits=3)
split_dir = Path("")
splitter.write(split_dir)

# CAS EXISTING CROSS-VALIDATION
splitter = KFolder.from_split_dir(caps_dataset=dataset_t1_image, manager=manager)
splitter.read(split_dir)

# define the needed parameters for the dataloader
dataloader_config = DataLoaderConfig(n_procs=3, batch_size=10)
dataloader_config = DataLoaderConfig(num_workers=3, batch_size=10)

for split in splitter.get_splits(splits=(0, 3, 4), dataloader_config=dataloader_config):
# bien définir ce qu'il y a dans l'objet split

network_config = create_network_config(ImplementedNetworks.CNN)(
in_shape=[2, 2, 2],
num_outputs=1,
conv_args=ConvEncoderOptions(channels=[3, 2, 2]),
)
optimizer, _ = get_optimizer(network, AdamConfig())
model = ClinicaDLModel(network=network_config, loss=nn.MSE(), optimizer=optimizer)
for split in splitter.get_splits(splits=(0, 3, 4)):
print(split)
split.build_train_loader(dataloader_config)
split.build_val_loader(num_workers=3, batch_size=10)

trainer.train(model, split)
# le trainer va instancier un predictor/valdiator dans le train ou dans le init
print(split)
9 changes: 5 additions & 4 deletions clinicadl/commandline/modules_options/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from clinicadl.config.config_utils import get_default_from_config_class as get_default
from clinicadl.config.config_utils import get_type_from_config_class as get_type
from clinicadl.splitter.config import SplitConfig
from clinicadl.splitter.splitter.kfold import KFoldConfig
from clinicadl.splitter.splitter.single_split import SingleSplitConfig

# Cross Validation
n_splits = click.option(
"--n_splits",
type=get_type("n_splits", SplitConfig),
default=get_default("n_splits", SplitConfig),
type=get_type("n_splits", KFoldConfig),
default=get_default("n_splits", KFoldConfig),
help="If a value is given for k will load data of a k-fold CV. "
"Default value (0) will load a single split.",
show_default=True,
Expand All @@ -17,7 +18,7 @@
"--split",
"-s",
type=int, # get_type("split", config.ValidationConfig),
default=get_default("split", SplitConfig),
default=get_default("split", SingleSplitConfig),
multiple=True,
help="Train the list of given splits. By default, all the splits are trained.",
show_default=True,
Expand Down
1 change: 1 addition & 0 deletions clinicadl/dataset/dataloader/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .config import DataLoaderConfig
Loading

0 comments on commit 621cc96

Please sign in to comment.