Skip to content

Commit

Permalink
Add option to use longitudinal data for validation (#483)
Browse files Browse the repository at this point in the history
Add option to use longitudinal data for validation  (#483)
  • Loading branch information
camillebrianceau authored Apr 3, 2024
1 parent a8f2ea5 commit 6bfaccd
Show file tree
Hide file tree
Showing 18 changed files with 79 additions and 31 deletions.
1 change: 1 addition & 0 deletions clinicadl/resources/config/train_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ use_extracted_features = false
multi_cohort = false
diagnoses = ["AD", "CN"]
baseline = false
valid_longitudinal = false
normalize = true
data_augmentation = false
sampler = "random"
Expand Down
1 change: 1 addition & 0 deletions clinicadl/train/tasks/classification_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
@train_option.multi_cohort
@train_option.diagnoses
@train_option.baseline
@train_option.valid_longitudinal
@train_option.normalize
@train_option.data_augmentation
@train_option.sampler
Expand Down
1 change: 1 addition & 0 deletions clinicadl/train/tasks/reconstruction_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
@train_option.multi_cohort
@train_option.diagnoses
@train_option.baseline
@train_option.valid_longitudinal
@train_option.normalize
@train_option.data_augmentation
@train_option.sampler
Expand Down
1 change: 1 addition & 0 deletions clinicadl/train/tasks/regression_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
@train_option.multi_cohort
@train_option.diagnoses
@train_option.baseline
@train_option.valid_longitudinal
@train_option.normalize
@train_option.data_augmentation
@train_option.sampler
Expand Down
1 change: 1 addition & 0 deletions clinicadl/train/tasks/task_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs):
"track_exp",
"transfer_path",
"transfer_selection_metric",
"valid_longitudinal",
"weight_decay",
"sampler",
"save_all_models",
Expand Down
25 changes: 21 additions & 4 deletions clinicadl/tsvtools/kfold/kfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def write_splits(
n_splits: int,
subset_name: str,
results_directory: Path,
valid_longitudinal: bool = False,
):
"""
Split data at the subject-level in training and test to have equivalent distributions in split_label.
Expand Down Expand Up @@ -57,12 +58,11 @@ def write_splits(
for i, indices in enumerate(splits.split(np.zeros(len(y)), y)):
train_index, test_index = indices

test_df = baseline_df.iloc[test_index]
train_df = baseline_df.iloc[train_index]
long_train_df = retrieve_longitudinal(train_df, diagnosis_df)

train_df.reset_index(inplace=True, drop=True)
test_df.reset_index(inplace=True, drop=True)

test_df = baseline_df.iloc[test_index]

# train_df = train_df[["participant_id", "session_id"]]
# test_df = test_df[["participant_id", "session_id"]]
Expand All @@ -86,6 +86,15 @@ def write_splits(
sep="\t",
index=False,
)
if valid_longitudinal:
long_test_df = retrieve_longitudinal(test_df, diagnosis_df)
test_df.reset_index(inplace=True, drop=True)

long_test_df.to_csv(
results_directory / f"split-{i}" / f"{subset_name}.tsv",
sep="\t",
index=False,
)


def split_diagnoses(
Expand All @@ -94,6 +103,7 @@ def split_diagnoses(
subset_name: str = None,
stratification: str = None,
merged_tsv: Path = None,
valid_longitudinal: bool = False,
):
"""
Performs a k-fold split for each label independently on the subject level.
Expand Down Expand Up @@ -176,6 +186,13 @@ def split_diagnoses(
how="inner",
on=["participant_id", "session_id"],
)
write_splits(diagnosis_df, stratification, n_splits, subset_name, results_directory)
write_splits(
diagnosis_df,
stratification,
n_splits,
subset_name,
results_directory,
valid_longitudinal=valid_longitudinal,
)

logger.info(f"K-fold split is done.")
3 changes: 3 additions & 0 deletions clinicadl/tsvtools/kfold/kfold_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
type=str,
default=None,
)
@cli_param.option.valid_longitudinal
def cli(
data_tsv,
n_splits,
subset_name,
stratification,
merged_tsv,
valid_longitudinal,
):
"""Performs a k-fold split to prepare training.
Expand All @@ -47,6 +49,7 @@ def cli(
subset_name=subset_name,
stratification=stratification,
merged_tsv=merged_tsv,
valid_longitudinal=valid_longitudinal,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def cli(
p_age_threshold = 0.80
p_sex_threshold = 0.80
ignore_demographics = False
flag_not_baseline = False
valid_longitudinal = False
split_diagnoses(
data_tsv,
n_test=n_test,
Expand All @@ -67,7 +67,7 @@ def cli(
p_sex_threshold=p_sex_threshold,
ignore_demographics=ignore_demographics,
categorical_split_variable=None,
not_only_baseline=flag_not_baseline,
valid_longitudinal=valid_longitudinal,
)

parents_path = data_tsv.parents[0]
Expand Down Expand Up @@ -110,6 +110,7 @@ def cli(
p_sex_threshold=p_sex_threshold,
ignore_demographics=ignore_demographics,
categorical_split_variable=None,
valid_longitudinal=valid_longitudinal,
)

elif validation_type == "kfold":
Expand All @@ -120,6 +121,7 @@ def cli(
n_splits=int(n_validation),
subset_name="validation",
stratification=None,
valid_longitudinal=valid_longitudinal,
)


Expand Down
10 changes: 4 additions & 6 deletions clinicadl/tsvtools/split/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,8 @@ def split_diagnoses(
p_sex_threshold=0.80,
categorical_split_variable=None,
ignore_demographics=False,
verbose=0,
not_only_baseline=True,
multi_diagnoses=False,
valid_longitudinal=False,
):
"""
Performs a single split for each label independently on the subject level.
Expand Down Expand Up @@ -297,10 +296,9 @@ def split_diagnoses(
name = f"{subset_name}_baseline.tsv"
df_to_tsv(name, results_path, test_df, baseline=True)

if not_only_baseline:
if valid_longitudinal:
long_test_df = retrieve_longitudinal(test_df, diagnosis_df)
name = f"{subset_name}.tsv"
# long_test_df = long_test_df[["participant_id", "session_id"]]
df_to_tsv(name, results_path, long_test_df)

elif n_test > 0:
Expand Down Expand Up @@ -345,7 +343,7 @@ def split_diagnoses(
name = f"{subset_name}_baseline.tsv"
df_to_tsv(name, results_path, test_df, baseline=True)

if not_only_baseline:
if valid_longitudinal:
name = f"{subset_name}.tsv"
long_test_df = retrieve_longitudinal(test_df, diagnosis_df)
# long_test_df = long_test_df[["participant_id", "session_id"]]
Expand All @@ -354,7 +352,7 @@ def split_diagnoses(
else:
train_df = extract_baseline(diagnosis_df)
# train_df = train_df[["participant_id", "session_id"]]
if not_only_baseline:
if valid_longitudinal:
long_train_df = diagnosis_df

name = "train_baseline.tsv"
Expand Down
12 changes: 3 additions & 9 deletions clinicadl/tsvtools/split/split_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
@click.command(name="split", no_args_is_help=True)
@cli_param.argument.data_tsv
@cli_param.option.subset_name
@cli_param.option.valid_longitudinal
@click.option(
"--n_test",
help="- If >= 1, number of subjects to put in set with name 'subset_name'.\n\n "
Expand Down Expand Up @@ -43,13 +44,6 @@
default=None,
type=str,
)
@click.option(
"--not_only_keep_baseline",
help="If given will store the file with all subjects",
default=False,
is_flag=True,
type=bool,
)
@click.option(
"--multi-diagnoses",
help="If given, all columns are used to balance the split, not only age and sex",
Expand All @@ -60,12 +54,12 @@
def cli(
data_tsv,
subset_name,
valid_longitudinal,
n_test,
p_sex_threshold,
p_age_threshold,
ignore_demographics,
categorical_split_variable,
not_only_keep_baseline,
multi_diagnoses,
):
"""Performs a single split to prepare training.
Expand All @@ -84,8 +78,8 @@ def cli(
p_sex_threshold=p_sex_threshold,
ignore_demographics=ignore_demographics,
categorical_split_variable=categorical_split_variable,
not_only_baseline=not_only_keep_baseline,
multi_diagnoses=multi_diagnoses,
valid_longitudinal=valid_longitudinal,
)


Expand Down
6 changes: 6 additions & 0 deletions clinicadl/utils/cli_param/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@
show_default=True,
help="ssda training.",
)
# Reusable click option declaring the paired boolean flags
# `--valid_longitudinal` / `--valid_baseline` (click's slash syntax: the first
# name turns the flag on, the second turns it off).
# NOTE(review): default=None presumably means the command receives None when
# neither flag is given — confirm that callers treat None like False.
valid_longitudinal = click.option(
"--valid_longitudinal/--valid_baseline",
type=bool,
default=None,
help="If provided, not only the baseline sessions are used for validation (careful with this bad habit).",
)
# GENERATE
participant_list = click.option(
"--participants_tsv",
Expand Down
6 changes: 6 additions & 0 deletions clinicadl/utils/cli_param/train_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@
default=None,
help="If provided, only the baseline sessions are used for training.",
)
# Training option, registered in the data option group, declaring the paired
# boolean flags `--valid_longitudinal` / `--valid_baseline` (click's slash
# syntax: the first name turns the flag on, the second turns it off).
# NOTE(review): default=None presumably lets an unset flag fall back to the
# value from the training config / defaults — confirm against the config loader.
valid_longitudinal = cli_param.option_group.data_group.option(
"--valid_longitudinal/--valid_baseline",
type=bool,
default=None,
help="If provided, not only the baseline sessions are used for validation (careful with this bad habit).",
)
normalize = cli_param.option_group.data_group.option(
"--normalize/--unnormalize",
type=bool,
Expand Down
1 change: 1 addition & 0 deletions clinicadl/utils/maps_manager/iotools.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def set_default(params_dict, default_dict):
"transfer_learning_path": "",
"transfer_learning_selection": "best_loss",
"gpu": True,
"valid_longitudinal": False,
"wd_bool": True,
"weight_decay": 4,
"sampler": "random",
Expand Down
2 changes: 2 additions & 0 deletions clinicadl/utils/split_manager/kfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def __init__(
diagnoses,
n_splits,
baseline=False,
valid_longitudinal=False,
multi_cohort=False,
split_list=None,
):
Expand All @@ -19,6 +20,7 @@ def __init__(
tsv_path,
diagnoses,
baseline,
valid_longitudinal,
multi_cohort,
split_list,
)
Expand Down
2 changes: 2 additions & 0 deletions clinicadl/utils/split_manager/single_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def __init__(
tsv_path,
diagnoses,
baseline=False,
valid_longitudinal=False,
multi_cohort=False,
split_list=None,
):
Expand All @@ -18,6 +19,7 @@ def __init__(
tsv_path,
diagnoses,
baseline,
valid_longitudinal,
multi_cohort,
split_list,
)
Expand Down
10 changes: 8 additions & 2 deletions clinicadl/utils/split_manager/split_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(
tsv_path: Path,
diagnoses,
baseline=False,
valid_longitudinal=False,
multi_cohort=False,
split_list=None,
):
Expand All @@ -36,6 +37,8 @@ def __init__(
List of diagnosis
baseline: bool
if True, split only on baseline sessions
valid_longitudinal: bool
if True, split validation on longitudinal sessions
multi-cohort: bool
split_list: List[str]
Expand All @@ -46,6 +49,7 @@ def __init__(
self.multi_cohort = multi_cohort
self.diagnoses = diagnoses
self.baseline = baseline
self.valid_longitudinal = valid_longitudinal
self.split_list = split_list

@abc.abstractmethod
Expand Down Expand Up @@ -138,8 +142,10 @@ def concatenate_diagnoses(
train_path = train_path / "train_baseline.tsv"
else:
train_path = train_path / "train.tsv"

valid_path = valid_path / "validation_baseline.tsv"
if self.valid_longitudinal:
valid_path = valid_path / "validation.tsv"
else:
valid_path = valid_path / "validation_baseline.tsv"

train_df = pd.read_csv(train_path, sep="\t")
valid_df = pd.read_csv(valid_path, sep="\t")
Expand Down
21 changes: 13 additions & 8 deletions clinicadl/utils/tsvtools_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,24 @@ def extract_baseline(diagnosis_df, set_index=True):
from copy import deepcopy

if set_index:
all_df = diagnosis_df.set_index(["participant_id", "session_id"])
all_df = deepcopy(diagnosis_df)
all_df.set_index(["participant_id", "session_id"], inplace=True)
else:
all_df = deepcopy(diagnosis_df)

result_df = pd.DataFrame()
for subject, subject_df in all_df.groupby(level=0):
baseline = first_session(subject_df)
subject_baseline_df = pd.DataFrame(
data=[[subject, baseline] + subject_df.loc[(subject, baseline)].tolist()],
columns=["participant_id", "session_id"]
+ subject_df.columns.values.tolist(),
)
result_df = pd.concat([result_df, subject_baseline_df])
if subject != "participant_id":
baseline = first_session(subject_df)

subject_baseline_df = pd.DataFrame(
data=[
[subject, baseline] + subject_df.loc[(subject, baseline)].tolist()
],
columns=["participant_id", "session_id"]
+ subject_df.columns.values.tolist(),
)
result_df = pd.concat([result_df, subject_baseline_df])

result_df.reset_index(inplace=True, drop=True)
return result_df
Expand Down
1 change: 1 addition & 0 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 6bfaccd

Please sign in to comment.