From 7485013943a3a2cd5209c6cfa206d255033e9c28 Mon Sep 17 00:00:00 2001 From: sophieloiz Date: Thu, 5 Oct 2023 11:43:42 +0200 Subject: [PATCH] Debug task utils with ssda parameters --- clinicadl/train/tasks/task_utils.py | 47 ++++++++++++++++------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/clinicadl/train/tasks/task_utils.py b/clinicadl/train/tasks/task_utils.py index f4cb98a46..348d640fb 100644 --- a/clinicadl/train/tasks/task_utils.py +++ b/clinicadl/train/tasks/task_utils.py @@ -84,11 +84,13 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs): / "tensor_extraction" / kwargs["preprocessing_json"] ) - preprocessing_json_target = ( - Path(kwargs["caps_target"]) - / "tensor_extraction" - / kwargs["preprocessing_dict_target"] - ) + + if train_dict["ssda_network"]: + preprocessing_json_target = ( + Path(kwargs["caps_target"]) + / "tensor_extraction" + / kwargs["preprocessing_dict_target"] + ) else: caps_dict = CapsDataset.create_caps_dict( train_dict["caps_directory"], train_dict["multi_cohort"] @@ -109,29 +111,32 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs): f"in {caps_dict}." ) # To CHECK AND CHANGE - caps_target = Path(kwargs["caps_target"]) - preprocessing_json_target = ( - caps_target / "tensor_extraction" / kwargs["preprocessing_dict_target"] - ) - - if preprocessing_json_target.is_file(): - logger.info( - f"Preprocessing JSON {preprocessing_json_target} found in CAPS {caps_target}." - ) - json_found = True - if not json_found: - raise ValueError( - f"Preprocessing JSON {kwargs['preprocessing_json_target']} was not found for any CAPS " - f"in {caps_target}." + if train_dict["ssda_network"]: + caps_target = Path(kwargs["caps_target"]) + preprocessing_json_target = ( + caps_target / "tensor_extraction" / kwargs["preprocessing_dict_target"] ) + if preprocessing_json_target.is_file(): + logger.info( + f"Preprocessing JSON {preprocessing_json_target} found in CAPS {caps_target}." + ) + json_found = True + if not json_found: + raise ValueError( + f"Preprocessing JSON {kwargs['preprocessing_json_target']} was not found for any CAPS " + f"in {caps_target}." + ) + # Mode and preprocessing preprocessing_dict = read_preprocessing(preprocessing_json) - preprocessing_dict_target = read_preprocessing(preprocessing_json_target) train_dict["preprocessing_dict"] = preprocessing_dict - train_dict["preprocessing_dict_target"] = preprocessing_dict_target train_dict["mode"] = preprocessing_dict["mode"] + if train_dict["ssda_network"]: + preprocessing_dict_target = read_preprocessing(preprocessing_json_target) + train_dict["preprocessing_dict_target"] = preprocessing_dict_target + # Add default values if missing if ( preprocessing_dict["mode"] == "roi"