diff --git a/clinicadl/resources/config/train_config.toml b/clinicadl/resources/config/train_config.toml index a9c1f5d7f..5d2be72d2 100644 --- a/clinicadl/resources/config/train_config.toml +++ b/clinicadl/resources/config/train_config.toml @@ -61,6 +61,7 @@ use_extracted_features = false multi_cohort = false diagnoses = ["AD", "CN"] baseline = false +valid_longitudinal = false normalize = true data_augmentation = false sampler = "random" diff --git a/clinicadl/train/tasks/classification_cli.py b/clinicadl/train/tasks/classification_cli.py index b170dd7a1..0633e22f0 100644 --- a/clinicadl/train/tasks/classification_cli.py +++ b/clinicadl/train/tasks/classification_cli.py @@ -31,6 +31,7 @@ @train_option.multi_cohort @train_option.diagnoses @train_option.baseline +@train_option.valid_longitudinal @train_option.normalize @train_option.data_augmentation @train_option.sampler diff --git a/clinicadl/train/tasks/reconstruction_cli.py b/clinicadl/train/tasks/reconstruction_cli.py index e821cd5f9..948535ba1 100644 --- a/clinicadl/train/tasks/reconstruction_cli.py +++ b/clinicadl/train/tasks/reconstruction_cli.py @@ -31,6 +31,7 @@ @train_option.multi_cohort @train_option.diagnoses @train_option.baseline +@train_option.valid_longitudinal @train_option.normalize @train_option.data_augmentation @train_option.sampler diff --git a/clinicadl/train/tasks/regression_cli.py b/clinicadl/train/tasks/regression_cli.py index a76d38bf2..f320c4ed5 100644 --- a/clinicadl/train/tasks/regression_cli.py +++ b/clinicadl/train/tasks/regression_cli.py @@ -31,6 +31,7 @@ @train_option.multi_cohort @train_option.diagnoses @train_option.baseline +@train_option.valid_longitudinal @train_option.normalize @train_option.data_augmentation @train_option.sampler diff --git a/clinicadl/train/tasks/task_utils.py b/clinicadl/train/tasks/task_utils.py index 6293d3131..26ec5c55a 100644 --- a/clinicadl/train/tasks/task_utils.py +++ b/clinicadl/train/tasks/task_utils.py @@ -61,6 +61,7 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs): "track_exp", "transfer_path", "transfer_selection_metric", + "valid_longitudinal", "weight_decay", "sampler", "seed", diff --git a/clinicadl/tsvtools/kfold/kfold.py b/clinicadl/tsvtools/kfold/kfold.py index 059f40964..02f34ce99 100644 --- a/clinicadl/tsvtools/kfold/kfold.py +++ b/clinicadl/tsvtools/kfold/kfold.py @@ -60,6 +60,7 @@ def write_splits( test_df = baseline_df.iloc[test_index] train_df = baseline_df.iloc[train_index] long_train_df = retrieve_longitudinal(train_df, diagnosis_df) + long_test_df = retrieve_longitudinal(test_df, diagnosis_df) train_df.reset_index(inplace=True, drop=True) test_df.reset_index(inplace=True, drop=True) @@ -86,6 +87,11 @@ def write_splits( sep="\t", index=False, ) + long_test_df.to_csv( + results_directory / f"split-{i}" / f"{subset_name}.tsv", + sep="\t", + index=False, + ) def split_diagnoses( diff --git a/clinicadl/utils/cli_param/train_option.py b/clinicadl/utils/cli_param/train_option.py index 82d2b2b09..4f565df25 100644 --- a/clinicadl/utils/cli_param/train_option.py +++ b/clinicadl/utils/cli_param/train_option.py @@ -171,6 +171,12 @@ default=None, help="If provided, only the baseline sessions are used for training.", ) +valid_longitudinal = cli_param.option_group.data_group.option( + "--valid_longitudinal/--not_valid_longitudinal", + type=bool, + default=None, + help="If provided, not only the baseline sessions are used for validation (careful with this bad habits).", +) normalize = cli_param.option_group.data_group.option( "--normalize/--unnormalize", type=bool, diff --git a/clinicadl/utils/split_manager/kfold.py b/clinicadl/utils/split_manager/kfold.py index 5749452e9..5ee560862 100644 --- a/clinicadl/utils/split_manager/kfold.py +++ b/clinicadl/utils/split_manager/kfold.py @@ -11,6 +11,7 @@ def __init__( diagnoses, n_splits, baseline=False, + valid_longitudinal=False, multi_cohort=False, split_list=None, ): @@ -19,6 +20,7 @@ def __init__( tsv_path, diagnoses, baseline, + valid_longitudinal, multi_cohort, split_list, ) diff --git a/clinicadl/utils/split_manager/single_split.py b/clinicadl/utils/split_manager/single_split.py index 64d4babf1..1b72f82c7 100644 --- a/clinicadl/utils/split_manager/single_split.py +++ b/clinicadl/utils/split_manager/single_split.py @@ -10,6 +10,7 @@ def __init__( tsv_path, diagnoses, baseline=False, + valid_longitudinal=False, multi_cohort=False, split_list=None, ): @@ -18,6 +19,7 @@ def __init__( tsv_path, diagnoses, baseline, + valid_longitudinal, multi_cohort, split_list, ) diff --git a/clinicadl/utils/split_manager/split_manager.py b/clinicadl/utils/split_manager/split_manager.py index 1f0e1a86f..683c654b3 100644 --- a/clinicadl/utils/split_manager/split_manager.py +++ b/clinicadl/utils/split_manager/split_manager.py @@ -21,6 +21,7 @@ def __init__( tsv_path: Path, diagnoses, baseline=False, + valid_longitudinal=False, multi_cohort=False, split_list=None, ): @@ -36,6 +37,8 @@ def __init__( List of diagnosis baseline: bool if True, split only on baseline sessions + valid_longitudinal: bool + if True, split validation on longitudinal sessions multi-cohort: bool split_list: List[str] @@ -46,6 +49,7 @@ def __init__( self.multi_cohort = multi_cohort self.diagnoses = diagnoses self.baseline = baseline + self.valid_longitudinal = valid_longitudinal self.split_list = split_list @abc.abstractmethod @@ -139,8 +143,10 @@ def concatenate_diagnoses( train_path = train_path / "train_baseline.tsv" else: train_path = train_path / "train.tsv" - - valid_path = valid_path / "validation_baseline.tsv" + if self.valid_longitudinal: + valid_path = valid_path / "validation.tsv" + else: + valid_path = valid_path / "validation_baseline.tsv" train_df = pd.read_csv(train_path, sep="\t") valid_df = pd.read_csv(valid_path, sep="\t")