Skip to content

Commit

Permalink
API ideas (#681)
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau authored Nov 25, 2024
1 parent 3cd23e4 commit 0cf01a6
Show file tree
Hide file tree
Showing 3 changed files with 314 additions and 0 deletions.
144 changes: 144 additions & 0 deletions clinicadl/API/complicated_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""API sketch: multi-modality, multi-extraction experiment (T1 + PET, patch/ROI/slice)."""

from pathlib import Path

import torchio.transforms as transforms

from clinicadl.dataset.caps_dataset import (
    CapsDatasetPatch,
    CapsDatasetRoi,
    CapsDatasetSlice,
)
from clinicadl.dataset.caps_reader import CapsReader
from clinicadl.dataset.concat import ConcatDataset

# NOTE(review): the specialized extraction configs below were used but never
# imported — assuming they live next to ExtractionConfig; confirm the module path.
from clinicadl.dataset.config.extraction import (
    ExtractionConfig,
    ExtractionImageConfig,
    ExtractionPatchConfig,
    ExtractionSliceConfig,
)
from clinicadl.dataset.config.preprocessing import (
    PreprocessingConfig,
    T1PreprocessingConfig,
)
from clinicadl.dataset.dataloader_config import DataLoaderConfig
from clinicadl.experiment_manager.experiment_manager import ExperimentManager
from clinicadl.losses.config import CrossEntropyLossConfig
from clinicadl.losses.factory import get_loss_function
from clinicadl.model.clinicadl_model import ClinicaDLModel
from clinicadl.networks.config import ImplementedNetworks
from clinicadl.networks.factory import (
    ConvEncoderOptions,
    create_network_config,
    get_network_from_config,
)
from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig
from clinicadl.optimization.optimizer.factory import get_optimizer
from clinicadl.predictor.predictor import Predictor
from clinicadl.splitter.kfold import KFolder
from clinicadl.splitter.split import get_single_split, split_tsv
from clinicadl.trainer.trainer import Trainer
from clinicadl.transforms.transforms import Transforms

# Create the MAPS manager (read/write manager for the experiment directory).
maps_path = Path("/")
manager = ExperimentManager(
    maps_path, overwrite=False
)  # TODO: add to the manager: mlflow / profiler / etc.

caps_directory = Path("caps_directory")  # output of the clinica pipelines
caps_reader = CapsReader(caps_directory, manager=manager)

preprocessing_1 = caps_reader.get_preprocessing("t1-linear")
caps_reader.prepare_data(
    preprocessing=preprocessing_1, data_tsv=Path(""), n_proc=2
)  # returns nothing: extracts the image tensors and computes per-image information

transforms_1 = Transforms(
    object_augmentation=[transforms.Crop, transforms.Transform],
    image_augmentation=[transforms.Crop, transforms.Transform],
    extraction=ExtractionPatchConfig(patch_size=3),
    image_transforms=[transforms.Blur, transforms.Ghosting],
    object_transforms=[transforms.BiasField, transforms.Motion],
)  # not mandatory

preprocessing_2 = caps_reader.get_preprocessing("pet-linear")
caps_reader.prepare_data(
    preprocessing=preprocessing_2, data_tsv=Path(""), n_proc=2
)  # extract the PET tensors this time (fixed: previously re-ran preprocessing_1)
transforms_2 = Transforms(
    object_augmentation=[transforms.Crop, transforms.Transform],
    image_augmentation=[transforms.Crop, transforms.Transform],
    extraction=ExtractionSliceConfig(),
    image_transforms=[transforms.Blur, transforms.Ghosting],
    object_transforms=[transforms.BiasField, transforms.Motion],
)

sub_ses_tsv_1 = Path("")
split_dir_1 = split_tsv(sub_ses_tsv_1)  # creates a test.tsv and a train.tsv

sub_ses_tsv_2 = Path("")
split_dir_2 = split_tsv(sub_ses_tsv_2)

dataset_t1_roi = caps_reader.get_dataset(
    preprocessing=preprocessing_1,
    sub_ses_tsv=split_dir_1 / "train.tsv",
    transforms=transforms_1,
)  # open question: do we pass a config or an object for the transforms?
dataset_pet_patch = caps_reader.get_dataset(
    preprocessing=preprocessing_2,
    sub_ses_tsv=split_dir_2 / "train.tsv",
    transforms=transforms_2,
)

dataset_multi_modality_multi_extract = ConcatDataset(
    [
        dataset_t1_roi,
        dataset_pet_patch,
        caps_reader.get_dataset_from_json(json_path=Path("dataset.json")),
    ]
)  # three train.tsv as input that must be concatenated; the same care is needed for the transforms

# TODO: think about adding transforms in extract_json


config_file = Path("config_file")
trainer = Trainer.from_json(config_file=config_file, manager=manager)

# CROSS-VALIDATION CASE
splitter = KFolder(caps_dataset=dataset_multi_modality_multi_extract, manager=manager)
split_dir = splitter.make_splits(
    n_splits=3, output_dir=Path(""), subset_name="validation", stratification=""
)  # optional data tsv and output_dir

dataloader_config = DataLoaderConfig(n_procs=3, batch_size=10)

for split in splitter.get_splits(splits=(0, 3, 4), dataloader_config=dataloader_config):
    # TODO: define precisely what the split object contains

    network_config = create_network_config(ImplementedNetworks.CNN)(
        in_shape=[2, 2, 2],
        num_outputs=1,
        conv_args=ConvEncoderOptions(channels=[3, 2, 2]),
    )
    # TODO(review): ClinicaDLModelClassif is not imported anywhere in this file
    # (only ClinicaDLModel is) — confirm its module path before running.
    model = ClinicaDLModelClassif.from_config(
        network_config=network_config,
        loss_config=CrossEntropyLossConfig(),
        optimizer_config=AdamConfig(),
    )

    trainer.train(model, split)
    # the trainer will instantiate a predictor/validator in train() or in __init__

# TEST

preprocessing_test = caps_reader.get_preprocessing("pet-linear")
transforms_test = Transforms(
    object_augmentation=[transforms.Crop, transforms.Transform],
    image_augmentation=[transforms.Crop, transforms.Transform],
    extraction=ExtractionImageConfig(),  # fixed typo: was "ExtractioImageConfig"
    image_transforms=[transforms.Blur, transforms.Ghosting],
    object_transforms=[transforms.BiasField, transforms.Motion],
)

dataset_test = caps_reader.get_dataset(
    preprocessing=preprocessing_test,
    sub_ses_tsv=split_dir_1 / "test.tsv",  # test only on data from the first dataset
    transforms=transforms_test,
)

predictor = Predictor(model=model, manager=manager)
predictor.predict(dataset_test=dataset_test, split_number=2)
80 changes: 80 additions & 0 deletions clinicadl/API/cross_val.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""API sketch: simple cross-validation experiment on an already-existing CAPS."""

from pathlib import Path

import torch.nn as nn  # needed for the MSE loss used below
import torchio.transforms as transforms

from clinicadl.dataset.caps_dataset import (
    CapsDatasetPatch,
    CapsDatasetRoi,
    CapsDatasetSlice,
)
from clinicadl.dataset.caps_reader import CapsReader
from clinicadl.dataset.concat import ConcatDataset
from clinicadl.dataset.config.extraction import ExtractionConfig
from clinicadl.dataset.config.preprocessing import (
    PreprocessingConfig,
    T1PreprocessingConfig,
)
from clinicadl.dataset.dataloader_config import DataLoaderConfig
from clinicadl.experiment_manager.experiment_manager import ExperimentManager
from clinicadl.losses.config import CrossEntropyLossConfig
from clinicadl.losses.factory import get_loss_function
from clinicadl.model.clinicadl_model import ClinicaDLModel
from clinicadl.networks.config import ImplementedNetworks
from clinicadl.networks.factory import (
    ConvEncoderOptions,
    create_network_config,
    get_network_from_config,
)
from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig
from clinicadl.optimization.optimizer.factory import get_optimizer
from clinicadl.predictor.predictor import Predictor
from clinicadl.splitter.kfold import KFolder
from clinicadl.splitter.split import get_single_split, split_tsv
from clinicadl.trainer.trainer import Trainer
from clinicadl.transforms.config import TransformsConfig

# SIMPLE EXPERIMENT WITH A CAPS ALREADY EXISTING

maps_path = Path("/")
manager = ExperimentManager(maps_path, overwrite=False)

dataset_t1_image = CapsDatasetPatch.from_json(
    extraction=Path("test.json"),
    sub_ses_tsv=Path("split_dir") / "train.tsv",
)
config_file = Path("config_file")
trainer = Trainer.from_json(
    config_file=config_file, manager=manager
)  # gpu, amp, fsdp, seed

# CROSS-VALIDATION CASE
splitter = KFolder(caps_dataset=dataset_t1_image, manager=manager)
split_dir = splitter.make_splits(
    n_splits=3,
    output_dir=Path(""),
    data_tsv=Path("labels.tsv"),
    subset_name="validation",
    stratification="",
)  # optional data tsv and output_dir
# n_splits must be > 1
# for the single-split case, this method outputs a path to the directory containing
# the train and test tsv files, so we should have the same output here

# EXISTING CROSS-VALIDATION CASE
splitter = KFolder.from_split_dir(caps_dataset=dataset_t1_image, manager=manager)

# define the parameters needed by the dataloader
dataloader_config = DataLoaderConfig(n_procs=3, batch_size=10)

for split in splitter.get_splits(splits=(0, 3, 4), dataloader_config=dataloader_config):
    # TODO: define precisely what the split object contains

    network_config = create_network_config(ImplementedNetworks.CNN)(
        in_shape=[2, 2, 2],
        num_outputs=1,
        conv_args=ConvEncoderOptions(channels=[3, 2, 2]),
    )
    # fixed: `network` was used without ever being built from its config
    network, _ = get_network_from_config(network_config)
    optimizer, _ = get_optimizer(network, AdamConfig())
    # fixed: torch.nn has no `MSE`; the mean-squared-error loss is `MSELoss`
    model = ClinicaDLModel(network=network, loss=nn.MSELoss(), optimizer=optimizer)

    trainer.train(model, split)
    # the trainer will instantiate a predictor/validator in train() or in __init__
90 changes: 90 additions & 0 deletions clinicadl/API/single_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""API sketch: simple single-split experiment (prepare data, split, train)."""

from pathlib import Path

import torchio.transforms as transforms

from clinicadl.dataset.caps_reader import CapsReader
from clinicadl.dataset.concat import ConcatDataset
from clinicadl.dataset.config.extraction import ExtractionConfig, ExtractionPatchConfig
from clinicadl.dataset.config.preprocessing import (
    PreprocessingConfig,
    T1PreprocessingConfig,
)
from clinicadl.experiment_manager.experiment_manager import ExperimentManager
from clinicadl.losses.config import CrossEntropyLossConfig
from clinicadl.losses.factory import get_loss_function
from clinicadl.model.clinicadl_model import ClinicaDLModel
from clinicadl.networks.config import ImplementedNetworks
from clinicadl.networks.factory import (
    ConvEncoderOptions,
    create_network_config,
    get_network_from_config,
)
from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig
from clinicadl.optimization.optimizer.factory import get_optimizer
from clinicadl.predictor.predictor import Predictor
from clinicadl.splitter.kfold import KFolder
from clinicadl.splitter.split import get_single_split, split_tsv
from clinicadl.trainer.trainer import Trainer
from clinicadl.transforms.config import TransformsConfig
from clinicadl.transforms.transforms import Transforms
from clinicadl.utils.enum import ExtractionMethod

# SIMPLE EXPERIMENT


caps_directory = Path("caps_directory")  # output of the clinica pipelines
caps_reader = CapsReader(caps_directory)
# design question: it is odd to pass a maps_path here through the manager,
# because we may not necessarily want to run a training

preprocessing_t1 = caps_reader.get_preprocessing("t1-linear")
caps_reader.prepare_data(
    preprocessing=preprocessing_t1,
    data_tsv=Path(""),
    n_proc=2,
    use_uncropped_images=False,
)
transforms_1 = Transforms(
    object_augmentation=[transforms.RandomMotion()],  # default = no transforms
    image_augmentation=[transforms.RandomMotion()],  # default = no transforms
    object_transforms=[transforms.Blur((0.4, 0.5, 0.6))],  # default = none
    image_transforms=[transforms.Noise(0.2, 0.5, 3)],  # default = MiniMax
    extraction=ExtractionPatchConfig(patch_size=30, stride_size=20),  # default = Image
)  # not mandatory

sub_ses_tsv = Path("")
split_dir = split_tsv(sub_ses_tsv)  # creates a test.tsv and a train.tsv

dataset_t1_image = caps_reader.get_dataset(
    preprocessing=preprocessing_t1,
    sub_ses_tsv=split_dir / "train.tsv",
    transforms=transforms_1,
)  # do we give a config or an object? -> dataset.json
# we could create a dataset.json in the CAPS — or elsewhere?
# maybe we need a json file holding the dataset info
# (preprocessing, tsv file, transforms options and caps_directory)
dataset_t1_image = caps_reader.get_dataset_from_json("dataset.json")

# SINGLE-SPLIT CASE
split = get_single_split(
    n_subject_validation=0,
    caps_dataset=dataset_t1_image,
    # manager=manager,
)  # as discussed, maybe we do not need to pass the manager to this function

maps_path = Path("/")
manager = ExperimentManager(maps_path, overwrite=False)

config_file = Path("config_file")
trainer = Trainer.from_json(config_file=config_file, manager=manager)
# open question: how to create the trainer other than from a config file?

network_config = create_network_config(ImplementedNetworks.CNN)(
    in_shape=[2, 2, 2], num_outputs=1, conv_args=ConvEncoderOptions(channels=[3, 2, 2])
)
# TODO(review): ClinicaDLModelClassif is not imported anywhere in this file
# (only ClinicaDLModel is) — confirm its module path before running.
model = ClinicaDLModelClassif.from_config(
    network_config=network_config,
    loss_config=CrossEntropyLossConfig(),
    optimizer_config=AdamConfig(),
)

trainer.train(model, split)
# the trainer will instantiate a predictor/validator in train() or in __init__

0 comments on commit 0cf01a6

Please sign in to comment.