-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'clinicadl_v2' into caps_dataset_transforms
- Loading branch information
Showing
162 changed files
with
3,821 additions
and
3,332 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,80 +1,35 @@ | ||
from pathlib import Path | ||
|
||
import torchio.transforms as transforms | ||
|
||
from clinicadl.dataset.caps_dataset import ( | ||
CapsDatasetPatch, | ||
CapsDatasetRoi, | ||
CapsDatasetSlice, | ||
) | ||
from clinicadl.dataset.caps_reader import CapsReader | ||
from clinicadl.dataset.concat import ConcatDataset | ||
from clinicadl.dataset.config.extraction import ExtractionConfig | ||
from clinicadl.dataset.config.preprocessing import ( | ||
PreprocessingConfig, | ||
T1PreprocessingConfig, | ||
) | ||
from clinicadl.dataset.dataloader_config import DataLoaderConfig | ||
from clinicadl.experiment_manager.experiment_manager import ExperimentManager | ||
from clinicadl.losses.config import CrossEntropyLossConfig | ||
from clinicadl.losses.factory import get_loss_function | ||
from clinicadl.model.clinicadl_model import ClinicaDLModel | ||
from clinicadl.networks.config import ImplementedNetworks | ||
from clinicadl.networks.factory import ( | ||
ConvEncoderOptions, | ||
create_network_config, | ||
get_network_from_config, | ||
) | ||
from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig | ||
from clinicadl.optimization.optimizer.factory import get_optimizer | ||
from clinicadl.predictor.predictor import Predictor | ||
from clinicadl.splitter.kfold import KFolder | ||
from clinicadl.splitter.split import get_single_split, split_tsv | ||
from clinicadl.trainer.trainer import Trainer | ||
from clinicadl.transforms.config import TransformsConfig | ||
from clinicadl.data.dataloader import DataLoaderConfig | ||
from clinicadl.data.datasets import CapsDataset | ||
from clinicadl.experiment_manager import ExperimentManager | ||
from clinicadl.splitter import KFold, make_kfold, make_split | ||
from clinicadl.trainer import Trainer | ||
|
||
# SIMPLE EXPERIMENT WITH A CAPS ALREADY EXISTING | ||
|
||
maps_path = Path("/") | ||
manager = ExperimentManager(maps_path, overwrite=False) | ||
|
||
dataset_t1_image = CapsDatasetPatch.from_json( | ||
extraction=Path("test.json"), | ||
sub_ses_tsv=Path("split_dir") / "train.tsv", | ||
) | ||
dataset_t1_image = CapsDataset.from_json(Path("json_path.json")) | ||
|
||
config_file = Path("config_file") | ||
trainer = Trainer.from_json( | ||
config_file=config_file, manager=manager | ||
) # gpu, amp, fsdp, seed | ||
|
||
# CAS CROSS-VALIDATION | ||
splitter = KFolder(caps_dataset=dataset_t1_image, manager=manager) | ||
split_dir = splitter.make_splits( | ||
n_splits=3, | ||
output_dir=Path(""), | ||
data_tsv=Path("labels.tsv"), | ||
subset_name="validation", | ||
stratification="", | ||
split_dir = make_split( | ||
dataset_t1_image.df, n_test=0.2, subset_name="validation", output_dir="test" | ||
) # Optional data tsv and output_dir | ||
# n_splits must be >1 | ||
# for the single split case, this method output a path to the directory containing the train and test tsv files so we should have the same output here | ||
fold_dir = make_kfold(split_dir / "train.tsv", n_splits=2) | ||
|
||
# CAS EXISTING CROSS-VALIDATION | ||
splitter = KFolder.from_split_dir(caps_dataset=dataset_t1_image, manager=manager) | ||
splitter = KFold(fold_dir) | ||
|
||
# define the needed parameters for the dataloader | ||
dataloader_config = DataLoaderConfig(n_procs=3, batch_size=10) | ||
|
||
for split in splitter.get_splits(splits=(0, 3, 4), dataloader_config=dataloader_config): | ||
# bien définir ce qu'il y a dans l'objet split | ||
# define the needed parameters for the dataloader | ||
dataloader_config = DataLoaderConfig(num_workers=3, batch_size=10) | ||
|
||
network_config = create_network_config(ImplementedNetworks.CNN)( | ||
in_shape=[2, 2, 2], | ||
num_outputs=1, | ||
conv_args=ConvEncoderOptions(channels=[3, 2, 2]), | ||
) | ||
optimizer, _ = get_optimizer(network, AdamConfig()) | ||
model = ClinicaDLModel(network=network_config, loss=nn.MSE(), optimizer=optimizer) | ||
|
||
trainer.train(model, split) | ||
# le trainer va instancier un predictor/valdiator dans le train ou dans le init | ||
for split in splitter.get_splits(dataset=dataset_t1_image): | ||
split.build_train_loader(dataloader_config) | ||
split.build_val_loader(num_workers=3, batch_size=10) |
Oops, something went wrong.