Skip to content

Commit

Permalink
Base for v2 (#676)
Browse files Browse the repository at this point in the history
* base for clinicadl v2
  • Loading branch information
camillebrianceau authored Oct 30, 2024
1 parent d17ce05 commit 68cf9da
Show file tree
Hide file tree
Showing 264 changed files with 2,958 additions and 4,820 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: 'Lint codebase'

on:
pull_request:
branches: [ "dev", "refactoring" ]
branches: [ "dev", "refactoring", "clinicadl_v2" ]
push:
branches: [ "dev", "refactoring" ]
branches: [ "dev", "refactoring", "clinicadl_v2" ]

permissions:
contents: read
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: Test

on:
push:
branches: ["dev", "refactoring"]
branches: ["dev", "refactoring", "clinicadl_v2"]
pull_request:
branches: ["dev", "refactoring"]
branches: ["dev", "refactoring", "clinicadl_v2"]

permissions:
contents: read
Expand Down
294 changes: 216 additions & 78 deletions clinicadl/API_test.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,225 @@
from pathlib import Path

from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig
from clinicadl.caps_dataset.data import return_dataset
from clinicadl.predictor.config import PredictConfig
import torchio

from clinicadl.dataset.caps_dataset import (
CapsDatasetPatch,
CapsDatasetRoi,
CapsDatasetSlice,
)
from clinicadl.dataset.caps_reader import CapsReader
from clinicadl.dataset.concat import ConcatDataset
from clinicadl.dataset.config.extraction import ExtractionConfig
from clinicadl.dataset.config.preprocessing import (
PreprocessingConfig,
T1PreprocessingConfig,
)
from clinicadl.experiment_manager.experiment_manager import ExperimentManager
from clinicadl.losses.config import CrossEntropyLossConfig
from clinicadl.losses.factory import get_loss_function
from clinicadl.model.clinicadl_model import ClinicaDLModel
from clinicadl.networks.config import ImplementedNetworks
from clinicadl.networks.factory import (
ConvEncoderOptions,
create_network_config,
get_network_from_config,
)
from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig
from clinicadl.optimization.optimizer.factory import get_optimizer
from clinicadl.predictor.predictor import Predictor
from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData
from clinicadl.splitter.config import SplitterConfig
from clinicadl.splitter.splitter import Splitter
from clinicadl.trainer.config.classification import ClassificationConfig
from clinicadl.splitter.kfold import KFolder
from clinicadl.splitter.split import get_single_split, split_tsv
from clinicadl.trainer.trainer import Trainer
from clinicadl.utils.enum import ExtractionMethod, Preprocessing, Task
from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options
from clinicadl.transforms.transforms import Transforms

# Create the Maps Manager / Read/write manager /
maps_path = Path("/")
manager = ExperimentManager(maps_path, overwrite=False)

caps_directory = Path("caps_directory") # output of clinica pipelines
caps_reader = CapsReader(caps_directory, manager=manager)

preprocessing_1 = caps_reader.get_preprocessing("t1-linear")
extraction_1 = caps_reader.extract_slice(preprocessing=preprocessing_1, arg_slice=2)
transforms_1 = Transforms(
data_augmentation=[torchio.t1, torchio.t2],
image_transforms=[torchio.t1, torchio.t2],
object_transforms=[torchio.t1, torchio.t2],
) # not mandatory

preprocessing_2 = caps_reader.get_preprocessing("pet-linear")
extraction_2 = caps_reader.extract_patch(preprocessing=preprocessing_2, arg_patch=2)
transforms_2 = Transforms(
data_augmentation=[torchio.t2],
image_transforms=[torchio.t1],
object_transforms=[torchio.t1, torchio.t2],
)

sub_ses_tsv = Path("")
split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv

dataset_t1_roi = caps_reader.get_dataset(
extraction=extraction_1,
preprocessing=preprocessing_1,
sub_ses_tsv=split_dir / "train.tsv",
transforms=transforms_1,
) # do we give config or object for transforms ?
dataset_pet_patch = caps_reader.get_dataset(
extraction=extraction_2,
preprocessing=preprocessing_2,
sub_ses_tsv=split_dir / "train.tsv",
transforms=transforms_2,
)

dataset_multi_modality_multi_extract = ConcatDataset(
[dataset_t1_roi, dataset_pet_patch]
) # 2 train.tsv en entrée qu'il faut concat et pareil pour les transforms à faire attention

config_file = Path("config_file")
trainer = Trainer.from_json(config_file=config_file, manager=manager)

# CAS CROSS-VALIDATION
splitter = KFolder(
n_splits=3, caps_dataset=dataset_multi_modality_multi_extract, manager=manager
)

for split in splitter.split_iterator(split_list=[0, 1]):
# bien définir ce qu'il y a dans l'objet split

loss, loss_config = get_loss_function(CrossEntropyLossConfig())
network_config = create_network_config(ImplementedNetworks.CNN)(
in_shape=[2, 2, 2],
num_outputs=1,
conv_args=ConvEncoderOptions(channels=[3, 2, 2]),
)
network, _ = get_network_from_config(network_config)
optimizer, _ = get_optimizer(network, AdamConfig())
model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer)

trainer.train(model, split)
# le trainer va instancier un predictor/valdiator dans le train ou dans le init


# CAS SINGLE SPLIT
split = get_single_split(
n_subject_validation=0,
caps_dataset=dataset_multi_modality_multi_extract,
manager=manager,
)

loss, loss_config = get_loss_function(CrossEntropyLossConfig())
network_config = create_network_config(ImplementedNetworks.CNN)(
in_shape=[2, 2, 2], num_outputs=1, conv_args=ConvEncoderOptions(channels=[3, 2, 2])
)
network, _ = get_network_from_config(network_config)
optimizer, _ = get_optimizer(network, AdamConfig())
model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer)

trainer.train(model, split)
# le trainer va instancier un predictor/valdiator dans le train ou dans le init


image_config = CapsDatasetConfig.from_preprocessing_and_extraction_method(
extraction=ExtractionMethod.IMAGE,
preprocessing_type=Preprocessing.T1_LINEAR,
# TEST

preprocessing_test: PreprocessingConfig = caps_reader.get_preprocessing("pet-linear")
extraction_test: ExtractionConfig = caps_reader.extract_patch(
preprocessing=preprocessing_2, arg_patch=2
)
transforms_test = Transforms(
data_augmentation=[torchio.t2],
image_transforms=[torchio.t1],
object_transforms=[torchio.t1, torchio.t2],
)

dataset_test = caps_reader.get_dataset(
extraction=extraction_test,
preprocessing=preprocessing_test,
sub_ses_tsv=split_dir / "test.tsv",
transforms=transforms_test,
)

DeepLearningPrepareData(image_config)

dataset = return_dataset(
input_dir,
data_df,
preprocessing_dict,
transforms_config,
label,
label_code,
cnn_index,
label_presence,
multi_cohort,
predictor = Predictor(manager=manager)
predictor.predict(dataset_test=dataset_test, split=2)


# SIMPLE EXPERIMENT


maps_path = Path("/")
manager = ExperimentManager(maps_path, overwrite=False)

caps_directory = Path("caps_directory") # output of clinica pipelines
caps_reader = CapsReader(caps_directory, manager=manager)

extraction_1 = caps_reader.extract_image(preprocessing=T1PreprocessingConfig())
transforms_1 = Transforms(
data_augmentation=[torchio.transforms.RandomMotion]
) # not mandatory

sub_ses_tsv = Path("")
split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv

dataset_t1_image = caps_reader.get_dataset(
extraction=extraction_1,
preprocessing=T1PreprocessingConfig(),
sub_ses_tsv=split_dir / "train.tsv",
transforms=transforms_1,
) # do we give config or ob


config_file = Path("config_file")
trainer = Trainer.from_json(config_file=config_file, manager=manager)

# CAS CROSS-VALIDATION
splitter = KFolder(n_splits=3, caps_dataset=dataset_t1_image, manager=manager)

for split in splitter.split_iterator(split_list=[0, 1]):
# bien définir ce qu'il y a dans l'objet split

loss, loss_config = get_loss_function(CrossEntropyLossConfig())
network_config = create_network_config(ImplementedNetworks.CNN)(
in_shape=[2, 2, 2],
num_outputs=1,
conv_args=ConvEncoderOptions(channels=[3, 2, 2]),
)
network, _ = get_network_from_config(network_config)
optimizer, _ = get_optimizer(network, AdamConfig())
model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer)

trainer.train(model, split)
# le trainer va instancier un predictor/valdiator dans le train ou dans le init


# SIMPLE EXPERIMENT WITH A CAPS ALREADY EXISTING

maps_path = Path("/")
manager = ExperimentManager(maps_path, overwrite=False)

# sub_ses_tsv = Path("")
# split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv

dataset_t1_image = CapsDatasetPatch.from_json(
extraction=extract_json,
sub_ses_tsv=split_dir / "train.tsv",
)
config_file = Path("config_file")
trainer = Trainer.from_json(config_file=config_file, manager=manager)

# CAS CROSS-VALIDATION
splitter = KFolder(n_splits=3, caps_dataset=dataset_t1_image, manager=manager)

for split in splitter.split_iterator(split_list=[0, 1]):
# bien définir ce qu'il y a dans l'objet split

loss, loss_config = get_loss_function(CrossEntropyLossConfig())
network_config = create_network_config(ImplementedNetworks.CNN)(
in_shape=[2, 2, 2],
num_outputs=1,
conv_args=ConvEncoderOptions(channels=[3, 2, 2]),
)
network, _ = get_network_from_config(network_config)
optimizer, _ = get_optimizer(network, AdamConfig())
model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer)

split_config = SplitterConfig()
splitter = Splitter(split_config)

validator_config = PredictConfig()
validator = Predictor(validator_config)

train_config = ClassificationConfig()
trainer = Trainer(train_config, validator)

for split in splitter.split_iterator():
for network in range(
first_network, self.maps_manager.num_networks
): # for multi_network
###### actual _train_single method of the Trainer ############
train_loader = trainer.get_dataloader(dataset, split, network, "train", config)
valid_loader = validator.get_dataloader(
dataset, split, network, "valid", config
) # ?? validatior, trainer ?

trainer._train(
train_loader,
valid_loader,
split=split,
network=network,
resume=resume, # in a config class
callbacks=[CodeCarbonTracker], # in a config class ?
)

validator._ensemble_prediction(
self.maps_manager,
"train",
split,
self.config.validation.selection_metrics,
)
validator._ensemble_prediction(
self.maps_manager,
"validation",
split,
self.config.validation.selection_metrics,
)
###### end ############


for split in splitter.split_iterator():
for network in range(
first_network, self.maps_manager.num_networks
): # for multi_network
###### actual _train_single method of the Trainer ############
test_loader = trainer.get_dataloader(dataset, split, network, "test", config)
validator.predict(test_loader)

interpret_config = PredictConfig(**kwargs)
predict_manager = Predictor(interpret_config)
predict_manager.interpret()
trainer.train(model, split)
# le trainer va instancier un predictor/valdiator dans le train ou dans le init
2 changes: 1 addition & 1 deletion clinicadl/commandline/modules_options/data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import click

from clinicadl.caps_dataset.data_config import DataConfig
from clinicadl.config.config_utils import get_default_from_config_class as get_default
from clinicadl.config.config_utils import get_type_from_config_class as get_type
from clinicadl.dataset.data_config import DataConfig

# Data
baseline = click.option(
Expand Down
2 changes: 1 addition & 1 deletion clinicadl/commandline/modules_options/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import click

from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig
from clinicadl.config.config_utils import get_default_from_config_class as get_default
from clinicadl.config.config_utils import get_type_from_config_class as get_type
from clinicadl.dataset.dataloader_config import DataLoaderConfig

# DataLoader
batch_size = click.option(
Expand Down
6 changes: 3 additions & 3 deletions clinicadl/commandline/modules_options/extraction.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import click

from clinicadl.caps_dataset.extraction.config import (
from clinicadl.config.config_utils import get_default_from_config_class as get_default
from clinicadl.config.config_utils import get_type_from_config_class as get_type
from clinicadl.dataset.config.extraction import (
ExtractionConfig,
ExtractionImageConfig,
ExtractionPatchConfig,
ExtractionROIConfig,
ExtractionSliceConfig,
)
from clinicadl.config.config_utils import get_default_from_config_class as get_default
from clinicadl.config.config_utils import get_type_from_config_class as get_type

extract_json = click.option(
"-ej",
Expand Down
2 changes: 1 addition & 1 deletion clinicadl/commandline/modules_options/maps_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import click

from clinicadl.config.config_utils import get_type_from_config_class as get_type
from clinicadl.maps_manager.config import MapsManagerConfig
from clinicadl.experiment_manager.config import MapsManagerConfig

maps_dir = click.argument("maps_dir", type=get_type("maps_dir", MapsManagerConfig))
data_group = click.option("data_group", type=get_type("data_group", MapsManagerConfig))
Expand Down
2 changes: 1 addition & 1 deletion clinicadl/commandline/modules_options/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from clinicadl.config.config_utils import get_default_from_config_class as get_default
from clinicadl.config.config_utils import get_type_from_config_class as get_type
from clinicadl.network.config import NetworkConfig
from clinicadl.networks.old_network.config import NetworkConfig

# Model
multi_network = click.option(
Expand Down
2 changes: 1 addition & 1 deletion clinicadl/commandline/modules_options/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from clinicadl.config.config_utils import get_default_from_config_class as get_default
from clinicadl.config.config_utils import get_type_from_config_class as get_type
from clinicadl.optimizer.optimization import OptimizationConfig
from clinicadl.optimization.config import OptimizationConfig

# Optimization
accumulation_steps = click.option(
Expand Down
Loading

0 comments on commit 68cf9da

Please sign in to comment.