diff --git a/clinicadl/resources/config/train_config.toml b/clinicadl/resources/config/train_config.toml index 2a21b70ba..368b60e7b 100644 --- a/clinicadl/resources/config/train_config.toml +++ b/clinicadl/resources/config/train_config.toml @@ -62,6 +62,7 @@ use_extracted_features = false multi_cohort = false diagnoses = ["AD", "CN"] baseline = false +valid_longitudinal = false normalize = true data_augmentation = false sampler = "random" diff --git a/clinicadl/train/tasks/classification_cli.py b/clinicadl/train/tasks/classification_cli.py index a8c2075da..2f470fd02 100644 --- a/clinicadl/train/tasks/classification_cli.py +++ b/clinicadl/train/tasks/classification_cli.py @@ -33,6 +33,7 @@ @train_option.multi_cohort @train_option.diagnoses @train_option.baseline +@train_option.valid_longitudinal @train_option.normalize @train_option.data_augmentation @train_option.sampler diff --git a/clinicadl/train/tasks/reconstruction_cli.py b/clinicadl/train/tasks/reconstruction_cli.py index f6e874338..95816a116 100644 --- a/clinicadl/train/tasks/reconstruction_cli.py +++ b/clinicadl/train/tasks/reconstruction_cli.py @@ -33,6 +33,7 @@ @train_option.multi_cohort @train_option.diagnoses @train_option.baseline +@train_option.valid_longitudinal @train_option.normalize @train_option.data_augmentation @train_option.sampler diff --git a/clinicadl/train/tasks/regression_cli.py b/clinicadl/train/tasks/regression_cli.py index de987e336..2533db406 100644 --- a/clinicadl/train/tasks/regression_cli.py +++ b/clinicadl/train/tasks/regression_cli.py @@ -33,6 +33,7 @@ @train_option.multi_cohort @train_option.diagnoses @train_option.baseline +@train_option.valid_longitudinal @train_option.normalize @train_option.data_augmentation @train_option.sampler diff --git a/clinicadl/train/tasks/task_utils.py b/clinicadl/train/tasks/task_utils.py index 658269bcc..0331468df 100644 --- a/clinicadl/train/tasks/task_utils.py +++ b/clinicadl/train/tasks/task_utils.py @@ -63,6 +63,7 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs): "track_exp", "transfer_path", "transfer_selection_metric", + "valid_longitudinal", "weight_decay", "sampler", "save_all_models", diff --git a/clinicadl/tsvtools/kfold/kfold.py b/clinicadl/tsvtools/kfold/kfold.py index fecdb6037..e60615cf9 100644 --- a/clinicadl/tsvtools/kfold/kfold.py +++ b/clinicadl/tsvtools/kfold/kfold.py @@ -21,6 +21,7 @@ def write_splits( n_splits: int, subset_name: str, results_directory: Path, + valid_longitudinal: bool = False, ): """ Split data at the subject-level in training and test to have equivalent distributions in split_label. @@ -57,12 +58,11 @@ def write_splits( for i, indices in enumerate(splits.split(np.zeros(len(y)), y)): train_index, test_index = indices - test_df = baseline_df.iloc[test_index] train_df = baseline_df.iloc[train_index] long_train_df = retrieve_longitudinal(train_df, diagnosis_df) - train_df.reset_index(inplace=True, drop=True) - test_df.reset_index(inplace=True, drop=True) + + test_df = baseline_df.iloc[test_index] # train_df = train_df[["participant_id", "session_id"]] # test_df = test_df[["participant_id", "session_id"]] @@ -86,6 +86,15 @@ def write_splits( sep="\t", index=False, ) + if valid_longitudinal: + long_test_df = retrieve_longitudinal(test_df, diagnosis_df) + test_df.reset_index(inplace=True, drop=True) + + long_test_df.to_csv( + results_directory / f"split-{i}" / f"{subset_name}.tsv", + sep="\t", + index=False, + ) def split_diagnoses( @@ -94,6 +103,7 @@ def split_diagnoses( subset_name: str = None, stratification: str = None, merged_tsv: Path = None, + valid_longitudinal: bool = False, ): """ Performs a k-fold split for each label independently on the subject level. @@ -176,6 +186,13 @@ def split_diagnoses( how="inner", on=["participant_id", "session_id"], ) - write_splits(diagnosis_df, stratification, n_splits, subset_name, results_directory) + write_splits( + diagnosis_df, + stratification, + n_splits, + subset_name, + results_directory, + valid_longitudinal=valid_longitudinal, + ) logger.info(f"K-fold split is done.") diff --git a/clinicadl/tsvtools/kfold/kfold_cli.py b/clinicadl/tsvtools/kfold/kfold_cli.py index 99b16159b..35fb99831 100644 --- a/clinicadl/tsvtools/kfold/kfold_cli.py +++ b/clinicadl/tsvtools/kfold/kfold_cli.py @@ -26,12 +26,14 @@ type=str, default=None, ) +@cli_param.option.valid_longitudinal def cli( data_tsv, n_splits, subset_name, stratification, merged_tsv, + valid_longitudinal, ): """Performs a k-fold split to prepare training. @@ -47,6 +49,7 @@ def cli( subset_name=subset_name, stratification=stratification, merged_tsv=merged_tsv, + valid_longitudinal=valid_longitudinal, ) diff --git a/clinicadl/tsvtools/prepare_experiment/prepare_experiment_cli.py b/clinicadl/tsvtools/prepare_experiment/prepare_experiment_cli.py index b0733683e..c7ee44d2f 100644 --- a/clinicadl/tsvtools/prepare_experiment/prepare_experiment_cli.py +++ b/clinicadl/tsvtools/prepare_experiment/prepare_experiment_cli.py @@ -58,7 +58,7 @@ def cli( p_age_threshold = 0.80 p_sex_threshold = 0.80 ignore_demographics = False - flag_not_baseline = False + valid_longitudinal = False split_diagnoses( data_tsv, n_test=n_test, @@ -67,7 +67,7 @@ def cli( p_sex_threshold=p_sex_threshold, ignore_demographics=ignore_demographics, categorical_split_variable=None, - not_only_baseline=flag_not_baseline, + valid_longitudinal=valid_longitudinal, ) parents_path = data_tsv.parents[0] @@ -110,6 +110,7 @@ def cli( p_sex_threshold=p_sex_threshold, ignore_demographics=ignore_demographics, categorical_split_variable=None, + valid_longitudinal=valid_longitudinal, ) elif validation_type == "kfold": @@ -120,6 +121,7 @@ def cli( n_splits=int(n_validation), subset_name="validation", stratification=None, + valid_longitudinal=valid_longitudinal, ) diff --git a/clinicadl/tsvtools/split/split.py b/clinicadl/tsvtools/split/split.py index c2610a8a6..082b9bd49 100644 --- a/clinicadl/tsvtools/split/split.py +++ b/clinicadl/tsvtools/split/split.py @@ -212,9 +212,8 @@ def split_diagnoses( p_sex_threshold=0.80, categorical_split_variable=None, ignore_demographics=False, - verbose=0, - not_only_baseline=True, multi_diagnoses=False, + valid_longitudinal=False, ): """ Performs a single split for each label independently on the subject level. @@ -297,10 +296,9 @@ def split_diagnoses( name = f"{subset_name}_baseline.tsv" df_to_tsv(name, results_path, test_df, baseline=True) - if not_only_baseline: + if valid_longitudinal: long_test_df = retrieve_longitudinal(test_df, diagnosis_df) name = f"{subset_name}.tsv" - # long_test_df = long_test_df[["participant_id", "session_id"]] df_to_tsv(name, results_path, long_test_df) elif n_test > 0: @@ -345,7 +343,7 @@ def split_diagnoses( name = f"{subset_name}_baseline.tsv" df_to_tsv(name, results_path, test_df, baseline=True) - if not_only_baseline: + if valid_longitudinal: name = f"{subset_name}.tsv" long_test_df = retrieve_longitudinal(test_df, diagnosis_df) # long_test_df = long_test_df[["participant_id", "session_id"]] @@ -354,7 +352,7 @@ def split_diagnoses( else: train_df = extract_baseline(diagnosis_df) # train_df = train_df[["participant_id", "session_id"]] - if not_only_baseline: + if valid_longitudinal: long_train_df = diagnosis_df name = "train_baseline.tsv" diff --git a/clinicadl/tsvtools/split/split_cli.py b/clinicadl/tsvtools/split/split_cli.py index 984d3a2f9..89e0dbdb4 100644 --- a/clinicadl/tsvtools/split/split_cli.py +++ b/clinicadl/tsvtools/split/split_cli.py @@ -6,6 +6,7 @@ @click.command(name="split", no_args_is_help=True) @cli_param.argument.data_tsv @cli_param.option.subset_name +@cli_param.option.valid_longitudinal @click.option( "--n_test", help="- If >= 1, number of subjects to put in set with name 'subset_name'.\n\n " @@ -43,13 +44,6 @@ default=None, type=str, ) -@click.option( - "--not_only_keep_baseline", - help="If given will store the file with all subjects", - default=False, - is_flag=True, - type=bool, -) @click.option( "--multi-diagnoses", help="If given, all columns are used to balance the split, not only age and sex", @@ -60,12 +54,12 @@ def cli( data_tsv, subset_name, + valid_longitudinal, n_test, p_sex_threshold, p_age_threshold, ignore_demographics, categorical_split_variable, - not_only_keep_baseline, multi_diagnoses, ): """Performs a single split to prepare training. @@ -84,8 +78,8 @@ def cli( p_sex_threshold=p_sex_threshold, ignore_demographics=ignore_demographics, categorical_split_variable=categorical_split_variable, - not_only_baseline=not_only_keep_baseline, multi_diagnoses=multi_diagnoses, + valid_longitudinal=valid_longitudinal, ) diff --git a/clinicadl/utils/cli_param/option.py b/clinicadl/utils/cli_param/option.py index 86a23dff6..5677dee1e 100644 --- a/clinicadl/utils/cli_param/option.py +++ b/clinicadl/utils/cli_param/option.py @@ -65,6 +65,12 @@ show_default=True, help="ssda training.", ) +valid_longitudinal = click.option( + "--valid_longitudinal/--valid_baseline", + type=bool, + default=None, + help="If provided, not only the baseline sessions are used for validation (careful with this bad habit).", +) # GENERATE participant_list = click.option( "--participants_tsv", diff --git a/clinicadl/utils/cli_param/train_option.py b/clinicadl/utils/cli_param/train_option.py index 64a00e0b9..c053277a5 100644 --- a/clinicadl/utils/cli_param/train_option.py +++ b/clinicadl/utils/cli_param/train_option.py @@ -183,6 +183,12 @@ default=None, help="If provided, only the baseline sessions are used for training.", ) +valid_longitudinal = cli_param.option_group.data_group.option( + "--valid_longitudinal/--valid_baseline", + type=bool, + default=None, + help="If provided, not only the baseline sessions are used for validation (careful with this bad habit).", +) normalize = cli_param.option_group.data_group.option( "--normalize/--unnormalize", type=bool, diff --git a/clinicadl/utils/maps_manager/iotools.py b/clinicadl/utils/maps_manager/iotools.py index 32f5936af..ed012c614 100644 --- a/clinicadl/utils/maps_manager/iotools.py +++ b/clinicadl/utils/maps_manager/iotools.py @@ -252,6 +252,7 @@ def set_default(params_dict, default_dict): "transfer_learning_path": "", "transfer_learning_selection": "best_loss", "gpu": True, + "valid_longitudinal": False, "wd_bool": True, "weight_decay": 4, "sampler": "random", diff --git a/clinicadl/utils/split_manager/kfold.py b/clinicadl/utils/split_manager/kfold.py index 5749452e9..5ee560862 100644 --- a/clinicadl/utils/split_manager/kfold.py +++ b/clinicadl/utils/split_manager/kfold.py @@ -11,6 +11,7 @@ def __init__( diagnoses, n_splits, baseline=False, + valid_longitudinal=False, multi_cohort=False, split_list=None, ): @@ -19,6 +20,7 @@ def __init__( tsv_path, diagnoses, baseline, + valid_longitudinal, multi_cohort, split_list, ) diff --git a/clinicadl/utils/split_manager/single_split.py b/clinicadl/utils/split_manager/single_split.py index 64d4babf1..1b72f82c7 100644 --- a/clinicadl/utils/split_manager/single_split.py +++ b/clinicadl/utils/split_manager/single_split.py @@ -10,6 +10,7 @@ def __init__( tsv_path, diagnoses, baseline=False, + valid_longitudinal=False, multi_cohort=False, split_list=None, ): @@ -18,6 +19,7 @@ def __init__( tsv_path, diagnoses, baseline, + valid_longitudinal, multi_cohort, split_list, ) diff --git a/clinicadl/utils/split_manager/split_manager.py b/clinicadl/utils/split_manager/split_manager.py index 225665231..af05a4b31 100644 --- a/clinicadl/utils/split_manager/split_manager.py +++ b/clinicadl/utils/split_manager/split_manager.py @@ -21,6 +21,7 @@ def __init__( tsv_path: Path, diagnoses, baseline=False, + valid_longitudinal=False, multi_cohort=False, split_list=None, ): @@ -36,6 +37,8 @@ def __init__( List of diagnosis baseline: bool if True, split only on baseline sessions + valid_longitudinal: bool + if True, split validation on longitudinal sessions multi-cohort: bool split_list: List[str] @@ -46,6 +49,7 @@ def __init__( self.multi_cohort = multi_cohort self.diagnoses = diagnoses self.baseline = baseline + self.valid_longitudinal = valid_longitudinal self.split_list = split_list @abc.abstractmethod @@ -138,8 +142,10 @@ def concatenate_diagnoses( train_path = train_path / "train_baseline.tsv" else: train_path = train_path / "train.tsv" - - valid_path = valid_path / "validation_baseline.tsv" + if self.valid_longitudinal: + valid_path = valid_path / "validation.tsv" + else: + valid_path = valid_path / "validation_baseline.tsv" train_df = pd.read_csv(train_path, sep="\t") valid_df = pd.read_csv(valid_path, sep="\t") diff --git a/clinicadl/utils/tsvtools_utils.py b/clinicadl/utils/tsvtools_utils.py index 39d890100..2ad5478a1 100644 --- a/clinicadl/utils/tsvtools_utils.py +++ b/clinicadl/utils/tsvtools_utils.py @@ -85,19 +85,24 @@ def extract_baseline(diagnosis_df, set_index=True): from copy import deepcopy if set_index: - all_df = diagnosis_df.set_index(["participant_id", "session_id"]) + all_df = deepcopy(diagnosis_df) + all_df.set_index(["participant_id", "session_id"], inplace=True) else: all_df = deepcopy(diagnosis_df) result_df = pd.DataFrame() for subject, subject_df in all_df.groupby(level=0): - baseline = first_session(subject_df) - subject_baseline_df = pd.DataFrame( - data=[[subject, baseline] + subject_df.loc[(subject, baseline)].tolist()], - columns=["participant_id", "session_id"] - + subject_df.columns.values.tolist(), - ) - result_df = pd.concat([result_df, subject_baseline_df]) + if subject != "participant_id": + baseline = first_session(subject_df) + + subject_baseline_df = pd.DataFrame( + data=[ + [subject, baseline] + subject_df.loc[(subject, baseline)].tolist() + ], + columns=["participant_id", "session_id"] + + subject_df.columns.values.tolist(), + ) + result_df = pd.concat([result_df, subject_baseline_df]) result_df.reset_index(inplace=True, drop=True) return result_df diff --git a/poetry.lock b/poetry.lock index c76b9d5c8..0f744ba42 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4015,3 +4015,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p lock-version = "2.0" python-versions = ">=3.8,<3.12" content-hash = "3e97ad0f601217720d44712124f8f27a3086b7f7bcb2077921e7053eacc65800" +