Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Sep 28, 2023
1 parent 7b68885 commit a4bf53f
Show file tree
Hide file tree
Showing 10 changed files with 29 additions and 2 deletions.
1 change: 1 addition & 0 deletions clinicadl/resources/config/train_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ use_extracted_features = false
multi_cohort = false
diagnoses = ["AD", "CN"]
baseline = false
valid_longitudinal = false
normalize = true
data_augmentation = false
sampler = "random"
Expand Down
1 change: 1 addition & 0 deletions clinicadl/train/tasks/classification_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
@train_option.multi_cohort
@train_option.diagnoses
@train_option.baseline
@train_option.valid_longitudinal
@train_option.normalize
@train_option.data_augmentation
@train_option.sampler
Expand Down
1 change: 1 addition & 0 deletions clinicadl/train/tasks/reconstruction_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
@train_option.multi_cohort
@train_option.diagnoses
@train_option.baseline
@train_option.valid_longitudinal
@train_option.normalize
@train_option.data_augmentation
@train_option.sampler
Expand Down
1 change: 1 addition & 0 deletions clinicadl/train/tasks/regression_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
@train_option.multi_cohort
@train_option.diagnoses
@train_option.baseline
@train_option.valid_longitudinal
@train_option.normalize
@train_option.data_augmentation
@train_option.sampler
Expand Down
1 change: 1 addition & 0 deletions clinicadl/train/tasks/task_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs):
"track_exp",
"transfer_path",
"transfer_selection_metric",
"valid_longitudinal",
"weight_decay",
"sampler",
"seed",
Expand Down
6 changes: 6 additions & 0 deletions clinicadl/tsvtools/kfold/kfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def write_splits(
test_df = baseline_df.iloc[test_index]
train_df = baseline_df.iloc[train_index]
long_train_df = retrieve_longitudinal(train_df, diagnosis_df)
long_test_df = retrieve_longitudinal(test_df, diagnosis_df)

train_df.reset_index(inplace=True, drop=True)
test_df.reset_index(inplace=True, drop=True)
Expand All @@ -86,6 +87,11 @@ def write_splits(
sep="\t",
index=False,
)
long_test_df.to_csv(
results_directory / f"split-{i}" / f"{subset_name}.tsv",
sep="\t",
index=False,
)


def split_diagnoses(
Expand Down
6 changes: 6 additions & 0 deletions clinicadl/utils/cli_param/train_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@
default=None,
help="If provided, only the baseline sessions are used for training.",
)
valid_longitudinal = cli_param.option_group.data_group.option(
"--valid_longitudinal/--not_valid_longitudinal",
type=bool,
default=None,
help="If provided, not only the baseline sessions are used for validation (careful with this bad habits).",
)
normalize = cli_param.option_group.data_group.option(
"--normalize/--unnormalize",
type=bool,
Expand Down
2 changes: 2 additions & 0 deletions clinicadl/utils/split_manager/kfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def __init__(
diagnoses,
n_splits,
baseline=False,
valid_longitudinal=False,
multi_cohort=False,
split_list=None,
):
Expand All @@ -19,6 +20,7 @@ def __init__(
tsv_path,
diagnoses,
baseline,
valid_longitudinal,
multi_cohort,
split_list,
)
Expand Down
2 changes: 2 additions & 0 deletions clinicadl/utils/split_manager/single_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def __init__(
tsv_path,
diagnoses,
baseline=False,
valid_longitudinal=False,
multi_cohort=False,
split_list=None,
):
Expand All @@ -18,6 +19,7 @@ def __init__(
tsv_path,
diagnoses,
baseline,
valid_longitudinal,
multi_cohort,
split_list,
)
Expand Down
10 changes: 8 additions & 2 deletions clinicadl/utils/split_manager/split_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(
tsv_path: Path,
diagnoses,
baseline=False,
valid_longitudinal=False,
multi_cohort=False,
split_list=None,
):
Expand All @@ -36,6 +37,8 @@ def __init__(
List of diagnosis
baseline: bool
if True, split only on baseline sessions
valid_longitudinal: bool
if True, split validation on longitudinal sessions
multi-cohort: bool
split_list: List[str]
Expand All @@ -46,6 +49,7 @@ def __init__(
self.multi_cohort = multi_cohort
self.diagnoses = diagnoses
self.baseline = baseline
self.valid_longitudinal = valid_longitudinal
self.split_list = split_list

@abc.abstractmethod
Expand Down Expand Up @@ -139,8 +143,10 @@ def concatenate_diagnoses(
train_path = train_path / "train_baseline.tsv"
else:
train_path = train_path / "train.tsv"

valid_path = valid_path / "validation_baseline.tsv"
if self.valid_longitudinal:
valid_path = valid_path / "validation.tsv"
else:
valid_path = valid_path / "validation_baseline.tsv"

train_df = pd.read_csv(train_path, sep="\t")
valid_df = pd.read_csv(valid_path, sep="\t")
Expand Down

0 comments on commit a4bf53f

Please sign in to comment.