-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* make_split * make_kfold * KFold *SingleSplit --------- Co-authored-by: camillebrianceau <[email protected]> Co-authored-by: camillebrianceau <[email protected]>
- Loading branch information
1 parent
464dddf
commit 621cc96
Showing
54 changed files
with
2,923 additions
and
546 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,80 +1,38 @@ | ||
from pathlib import Path | ||
|
||
import torchio.transforms as transforms | ||
|
||
from clinicadl.dataset.caps_dataset import ( | ||
CapsDatasetPatch, | ||
CapsDatasetRoi, | ||
CapsDatasetSlice, | ||
) | ||
from clinicadl.dataset.caps_reader import CapsReader | ||
from clinicadl.dataset.concat import ConcatDataset | ||
from clinicadl.dataset.config.extraction import ExtractionConfig | ||
from clinicadl.dataset.config.preprocessing import ( | ||
PreprocessingConfig, | ||
T1PreprocessingConfig, | ||
) | ||
from clinicadl.dataset.dataloader_config import DataLoaderConfig | ||
from clinicadl.dataset.datasets.caps_dataset import CapsDataset | ||
from clinicadl.experiment_manager.experiment_manager import ExperimentManager | ||
from clinicadl.losses.config import CrossEntropyLossConfig | ||
from clinicadl.losses.factory import get_loss_function | ||
from clinicadl.model.clinicadl_model import ClinicaDLModel | ||
from clinicadl.networks.config import ImplementedNetworks | ||
from clinicadl.networks.factory import ( | ||
ConvEncoderOptions, | ||
create_network_config, | ||
get_network_from_config, | ||
) | ||
from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig | ||
from clinicadl.optimization.optimizer.factory import get_optimizer | ||
from clinicadl.predictor.predictor import Predictor | ||
from clinicadl.splitter.kfold import KFolder | ||
from clinicadl.splitter.split import get_single_split, split_tsv | ||
from clinicadl.splitter.new_splitter.dataloader import DataLoaderConfig | ||
from clinicadl.splitter.new_splitter.splitter.kfold import KFold | ||
from clinicadl.trainer.trainer import Trainer | ||
from clinicadl.transforms.config import TransformsConfig | ||
|
||
# SIMPLE EXPERIMENT WITH A CAPS ALREADY EXISTING | ||
|
||
maps_path = Path("/") | ||
manager = ExperimentManager(maps_path, overwrite=False) | ||
|
||
dataset_t1_image = CapsDatasetPatch.from_json( | ||
extraction=Path("test.json"), | ||
sub_ses_tsv=Path("split_dir") / "train.tsv", | ||
) | ||
dataset_t1_image = CapsDataset.from_json(Path("json_path.json")) | ||
|
||
config_file = Path("config_file") | ||
trainer = Trainer.from_json( | ||
config_file=config_file, manager=manager | ||
) # gpu, amp, fsdp, seed | ||
|
||
# CAS CROSS-VALIDATION | ||
splitter = KFolder(caps_dataset=dataset_t1_image, manager=manager) | ||
split_dir = splitter.make_splits( | ||
n_splits=3, | ||
output_dir=Path(""), | ||
data_tsv=Path("labels.tsv"), | ||
subset_name="validation", | ||
stratification="", | ||
) # Optional data tsv and output_dir | ||
# n_splits must be >1 | ||
# for the single split case, this method output a path to the directory containing the train and test tsv files so we should have the same output here | ||
splitter = KFold(dataset=dataset_t1_image) | ||
splitter.make_splits(n_splits=3) | ||
split_dir = Path("") | ||
splitter.write(split_dir) | ||
|
||
# CAS EXISTING CROSS-VALIDATION | ||
splitter = KFolder.from_split_dir(caps_dataset=dataset_t1_image, manager=manager) | ||
splitter.read(split_dir) | ||
|
||
# define the needed parameters for the dataloader | ||
dataloader_config = DataLoaderConfig(n_procs=3, batch_size=10) | ||
dataloader_config = DataLoaderConfig(num_workers=3, batch_size=10) | ||
|
||
for split in splitter.get_splits(splits=(0, 3, 4), dataloader_config=dataloader_config): | ||
# bien définir ce qu'il y a dans l'objet split | ||
|
||
network_config = create_network_config(ImplementedNetworks.CNN)( | ||
in_shape=[2, 2, 2], | ||
num_outputs=1, | ||
conv_args=ConvEncoderOptions(channels=[3, 2, 2]), | ||
) | ||
optimizer, _ = get_optimizer(network, AdamConfig()) | ||
model = ClinicaDLModel(network=network_config, loss=nn.MSE(), optimizer=optimizer) | ||
for split in splitter.get_splits(splits=(0, 3, 4)): | ||
print(split) | ||
split.build_train_loader(dataloader_config) | ||
split.build_val_loader(num_workers=3, batch_size=10) | ||
|
||
trainer.train(model, split) | ||
# le trainer va instancier un predictor/valdiator dans le train ou dans le init | ||
print(split) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .config import DataLoaderConfig |
Oops, something went wrong.