From c71e90ed605c34ac8fd94f73448533be749b9f38 Mon Sep 17 00:00:00 2001 From: Loizillon Sophie <68893000+sophieloiz@users.noreply.github.com> Date: Fri, 6 Oct 2023 14:43:20 +0200 Subject: [PATCH] Sl dart ssda (#485) * Add proposed SSDA method for MICCAI --- clinicadl/generate/generate.py | 4 + clinicadl/mlflow_test.py | 83 +++ .../random_search/random_search_utils.py | 1 + clinicadl/resources/config/train_config.toml | 7 +- clinicadl/train/tasks/classification_cli.py | 5 + clinicadl/train/tasks/reconstruction_cli.py | 5 + clinicadl/train/tasks/regression_cli.py | 5 + clinicadl/train/tasks/task_utils.py | 33 + clinicadl/tsvtools/split/split.py | 2 + clinicadl/utils/caps_dataset/data.py | 39 +- clinicadl/utils/cli_param/option.py | 7 + clinicadl/utils/cli_param/train_option.py | 34 + clinicadl/utils/maps_manager/logwriter.py | 8 +- clinicadl/utils/maps_manager/maps_manager.py | 629 +++++++++++++++++- clinicadl/utils/network/__init__.py | 1 + clinicadl/utils/network/cnn/models.py | 96 ++- clinicadl/utils/network/network_utils.py | 12 + clinicadl/utils/network/sub_network.py | 140 ++++ clinicadl/utils/network/vae/base_vae.py | 1 + .../utils/split_manager/split_manager.py | 7 +- .../utils/task_manager/classification.py | 29 + clinicadl/utils/task_manager/task_manager.py | 50 ++ poetry.lock | 36 + 23 files changed, 1217 insertions(+), 17 deletions(-) create mode 100644 clinicadl/mlflow_test.py diff --git a/clinicadl/generate/generate.py b/clinicadl/generate/generate.py index 3c70c6039..bfda1347f 100644 --- a/clinicadl/generate/generate.py +++ b/clinicadl/generate/generate.py @@ -165,6 +165,8 @@ def create_random_image(subject_id): write_missing_mods(output_dir, output_df) logger.info(f"Random dataset was generated at {output_dir}") + logger.info(f"Random dataset was generated at {output_dir}") + def generate_trivial_dataset( caps_directory: Path, @@ -355,6 +357,8 @@ def create_trivial_image(subject_id, output_df): write_missing_mods(output_dir, output_df) 
logger.info(f"Trivial dataset was generated at {output_dir}") + logger.info(f"Trivial dataset was generated at {output_dir}") + def generate_shepplogan_dataset( output_dir: Path, diff --git a/clinicadl/mlflow_test.py b/clinicadl/mlflow_test.py new file mode 100644 index 000000000..35a2c0995 --- /dev/null +++ b/clinicadl/mlflow_test.py @@ -0,0 +1,83 @@ +import logging +import os +import sys +import warnings +from urllib.parse import urlparse + +import mlflow +import mlflow.sklearn +import numpy as np +import pandas as pd +from sklearn.linear_model import ElasticNet +from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score +from sklearn.model_selection import train_test_split + +logging.basicConfig(level=logging.WARN) +logger = logging.getLogger(__name__) + + +def eval_metrics(actual, pred): + rmse = np.sqrt(mean_squared_error(actual, pred)) + mae = mean_absolute_error(actual, pred) + r2 = r2_score(actual, pred) + return rmse, mae, r2 + + +if __name__ == "__main__": + warnings.filterwarnings("ignore") + np.random.seed(40) + + # Read the wine-quality csv file from the URL + csv_url = "https://raw.githubusercontent.com/mlflow/mlflow/master/tests/data/winequality-red.csv" + try: + data = pd.read_csv(csv_url, sep=";") + except Exception as e: + logger.exception( + "Unable to download training & test CSV, check your internet connection. Error: %s", + e, + ) + + # Split the data into training and test sets. (0.75, 0.25) split. 
+ train, test = train_test_split(data) + + # The predicted column is "quality" which is a scalar from [3, 9] + train_x = train.drop(["quality"], axis=1) + test_x = test.drop(["quality"], axis=1) + train_y = train[["quality"]] + test_y = test[["quality"]] + + alpha = float(sys.argv[1]) if len(sys.argv) > 1 else 0.5 + l1_ratio = float(sys.argv[2]) if len(sys.argv) > 2 else 0.5 + + with mlflow.start_run(): + lr = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, random_state=42) + lr.fit(train_x, train_y) + + predicted_qualities = lr.predict(test_x) + + (rmse, mae, r2) = eval_metrics(test_y, predicted_qualities) + + print("Elasticnet model (alpha={:f}, l1_ratio={:f}):".format(alpha, l1_ratio)) + print(" RMSE: %s" % rmse) + print(" MAE: %s" % mae) + print(" R2: %s" % r2) + + mlflow.log_param("alpha", alpha) + mlflow.log_param("l1_ratio", l1_ratio) + mlflow.log_metric("rmse", rmse) + mlflow.log_metric("r2", r2) + mlflow.log_metric("mae", mae) + + tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme + + # Model registry does not work with file store + if tracking_url_type_store != "file": + # Register the model + # There are other ways to use the Model Registry, which depends on the use case, + # please refer to the doc for more information: + # https://mlflow.org/docs/latest/model-registry.html#api-workflow + mlflow.sklearn.log_model( + lr, "model", registered_model_name="ElasticnetWineModel" + ) + else: + mlflow.sklearn.log_model(lr, "model") diff --git a/clinicadl/random_search/random_search_utils.py b/clinicadl/random_search/random_search_utils.py index 18c752bc6..8c44fc524 100644 --- a/clinicadl/random_search/random_search_utils.py +++ b/clinicadl/random_search/random_search_utils.py @@ -128,6 +128,7 @@ def random_sampling(rs_options: Dict[str, Any]) -> Dict[str, Any]: "mode": "fixed", "multi_cohort": "fixed", "multi_network": "choice", + "ssda_network": "fixed", "n_fcblocks": "randint", "n_splits": "fixed", "n_proc": "fixed", diff --git 
a/clinicadl/resources/config/train_config.toml b/clinicadl/resources/config/train_config.toml index a9c1f5d7f..a6783514f 100644 --- a/clinicadl/resources/config/train_config.toml +++ b/clinicadl/resources/config/train_config.toml @@ -4,6 +4,7 @@ [Model] architecture = "default" # ex : Conv5_FC3 multi_network = false +ssda_network = false [Architecture] # CNN @@ -66,6 +67,10 @@ data_augmentation = false sampler = "random" size_reduction=false size_reduction_factor=2 +caps_target = "" +tsv_target_lab = "" +tsv_target_unlab = "" +preprocessing_dict_target = "" [Cross_validation] n_splits = 0 @@ -82,4 +87,4 @@ accumulation_steps = 1 profiler = false [Informations] -emissions_calculator = false \ No newline at end of file +emissions_calculator = false diff --git a/clinicadl/train/tasks/classification_cli.py b/clinicadl/train/tasks/classification_cli.py index b170dd7a1..d3914ecec 100644 --- a/clinicadl/train/tasks/classification_cli.py +++ b/clinicadl/train/tasks/classification_cli.py @@ -27,6 +27,7 @@ # Model @train_option.architecture @train_option.multi_network +@train_option.ssda_network # Data @train_option.multi_cohort @train_option.diagnoses @@ -34,6 +35,10 @@ @train_option.normalize @train_option.data_augmentation @train_option.sampler +@train_option.caps_target +@train_option.tsv_target_lab +@train_option.tsv_target_unlab +@train_option.preprocessing_dict_target # Cross validation @train_option.n_splits @train_option.split diff --git a/clinicadl/train/tasks/reconstruction_cli.py b/clinicadl/train/tasks/reconstruction_cli.py index e821cd5f9..baa078ca3 100644 --- a/clinicadl/train/tasks/reconstruction_cli.py +++ b/clinicadl/train/tasks/reconstruction_cli.py @@ -27,6 +27,7 @@ # Model @train_option.architecture @train_option.multi_network +@train_option.ssda_network # Data @train_option.multi_cohort @train_option.diagnoses @@ -34,6 +35,10 @@ @train_option.normalize @train_option.data_augmentation @train_option.sampler +@train_option.caps_target 
+@train_option.tsv_target_lab +@train_option.tsv_target_unlab +@train_option.preprocessing_dict_target # Cross validation @train_option.n_splits @train_option.split diff --git a/clinicadl/train/tasks/regression_cli.py b/clinicadl/train/tasks/regression_cli.py index a76d38bf2..19ed07254 100644 --- a/clinicadl/train/tasks/regression_cli.py +++ b/clinicadl/train/tasks/regression_cli.py @@ -27,6 +27,7 @@ # Model @train_option.architecture @train_option.multi_network +@train_option.ssda_network # Data @train_option.multi_cohort @train_option.diagnoses @@ -34,6 +35,10 @@ @train_option.normalize @train_option.data_augmentation @train_option.sampler +@train_option.caps_target +@train_option.tsv_target_lab +@train_option.tsv_target_unlab +@train_option.preprocessing_dict_target # Cross validation @train_option.n_splits @train_option.split diff --git a/clinicadl/train/tasks/task_utils.py b/clinicadl/train/tasks/task_utils.py index 6293d3131..348d640fb 100644 --- a/clinicadl/train/tasks/task_utils.py +++ b/clinicadl/train/tasks/task_utils.py @@ -50,6 +50,7 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs): "learning_rate", "multi_cohort", "multi_network", + "ssda_network", "n_proc", "n_splits", "nb_unfrozen_layer", @@ -65,6 +66,10 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs): "sampler", "seed", "split", + "caps_target", + "tsv_target_lab", + "tsv_target_unlab", + "preprocessing_dict_target", ] all_options_list = standard_options_list + task_options_list @@ -79,6 +84,13 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs): / "tensor_extraction" / kwargs["preprocessing_json"] ) + + if train_dict["ssda_network"]: + preprocessing_json_target = ( + Path(kwargs["caps_target"]) + / "tensor_extraction" + / kwargs["preprocessing_dict_target"] + ) else: caps_dict = CapsDataset.create_caps_dict( train_dict["caps_directory"], train_dict["multi_cohort"] @@ -98,12 +110,33 @@ def 
task_launcher(network_task: str, task_options_list: List[str], **kwargs): f"Preprocessing JSON {kwargs['preprocessing_json']} was not found for any CAPS " f"in {caps_dict}." ) + # To CHECK AND CHANGE + if train_dict["ssda_network"]: + caps_target = Path(kwargs["caps_target"]) + preprocessing_json_target = ( + caps_target / "tensor_extraction" / kwargs["preprocessing_dict_target"] + ) + + if preprocessing_json_target.is_file(): + logger.info( + f"Preprocessing JSON {preprocessing_json_target} found in CAPS {caps_target}." + ) + json_found = True + if not json_found: + raise ValueError( + f"Preprocessing JSON {kwargs['preprocessing_json_target']} was not found for any CAPS " + f"in {caps_target}." + ) # Mode and preprocessing preprocessing_dict = read_preprocessing(preprocessing_json) train_dict["preprocessing_dict"] = preprocessing_dict train_dict["mode"] = preprocessing_dict["mode"] + if train_dict["ssda_network"]: + preprocessing_dict_target = read_preprocessing(preprocessing_json_target) + train_dict["preprocessing_dict_target"] = preprocessing_dict_target + # Add default values if missing if ( preprocessing_dict["mode"] == "roi" diff --git a/clinicadl/tsvtools/split/split.py b/clinicadl/tsvtools/split/split.py index 81ea7606f..158819f14 100644 --- a/clinicadl/tsvtools/split/split.py +++ b/clinicadl/tsvtools/split/split.py @@ -281,6 +281,8 @@ def split_diagnoses( if categorical_split_variable is None: categorical_split_variable = "diagnosis" + else: + categorical_split_variable.append("diagnosis") # Read files diagnosis_df_path = data_tsv.name diff --git a/clinicadl/utils/caps_dataset/data.py b/clinicadl/utils/caps_dataset/data.py index 027163b19..0542085fa 100644 --- a/clinicadl/utils/caps_dataset/data.py +++ b/clinicadl/utils/caps_dataset/data.py @@ -71,8 +71,11 @@ def __init__( raise AttributeError("Child class of CapsDataset, must set mode attribute.") self.df = data_df - - mandatory_col = {"participant_id", "session_id", "cohort"} + mandatory_col = { + 
"participant_id", + "session_id", + "cohort", + } if self.label_presence and self.label is not None: mandatory_col.add(self.label) @@ -108,6 +111,18 @@ def label_fn(self, target: Union[str, float, int]) -> Union[float, int]: else: return self.label_code[str(target)] + def domain_fn(self, target: Union[str, float, int]) -> Union[float, int]: + """ + Returns the label value usable in criterion. + + Args: + target: value of the target. + Returns: + label: value of the label usable in criterion. + """ + domain_code = {"t1": 0, "flair": 1} + return domain_code[str(target)] + def __len__(self) -> int: return len(self.df) * self.elem_per_image @@ -209,7 +224,12 @@ def _get_meta_data(self, idx: int) -> Tuple[str, str, str, int, int]: else: label = -1 - return participant, session, cohort, elem_idx, label + if "domain" in self.df.columns: + domain = self.df.loc[image_idx, "domain"] + domain = self.domain_fn(domain) + else: + domain = "" # TO MODIFY + return participant, session, cohort, elem_idx, label, domain def _get_full_image(self) -> torch.Tensor: """ @@ -323,7 +343,7 @@ def elem_index(self): return None def __getitem__(self, idx): - participant, session, cohort, _, label = self._get_meta_data(idx) + participant, session, cohort, _, label, domain = self._get_meta_data(idx) image_path = self._get_image_path(participant, session, cohort) image = torch.load(image_path) @@ -341,6 +361,7 @@ def __getitem__(self, idx): "session_id": session, "image_id": 0, "image_path": image_path.as_posix(), + "domain": domain, } return sample @@ -400,7 +421,9 @@ def elem_index(self): return self.patch_index def __getitem__(self, idx): - participant, session, cohort, patch_idx, label = self._get_meta_data(idx) + participant, session, cohort, patch_idx, label, domain = self._get_meta_data( + idx + ) image_path = self._get_image_path(participant, session, cohort) if self.prepare_dl: @@ -507,7 +530,7 @@ def elem_index(self): return self.roi_index def __getitem__(self, idx): - participant, 
session, cohort, roi_idx, label = self._get_meta_data(idx) + participant, session, cohort, roi_idx, label, domain = self._get_meta_data(idx) image_path = self._get_image_path(participant, session, cohort) if self.roi_list is None: @@ -672,7 +695,9 @@ def elem_index(self): return self.slice_index def __getitem__(self, idx): - participant, session, cohort, slice_idx, label = self._get_meta_data(idx) + participant, session, cohort, slice_idx, label, domain = self._get_meta_data( + idx + ) slice_idx = slice_idx + self.discarded_slices[0] image_path = self._get_image_path(participant, session, cohort) diff --git a/clinicadl/utils/cli_param/option.py b/clinicadl/utils/cli_param/option.py index 8ae3ab7e3..ce28850e7 100644 --- a/clinicadl/utils/cli_param/option.py +++ b/clinicadl/utils/cli_param/option.py @@ -57,6 +57,13 @@ multiple=True, default=None, ) +ssda_network = click.option( + "--ssda_network", + type=bool, + default=False, + show_default=True, + help="ssda training.", +) # GENERATE participant_list = click.option( "--participants_tsv", diff --git a/clinicadl/utils/cli_param/train_option.py b/clinicadl/utils/cli_param/train_option.py index 82d2b2b09..d64c7fe2c 100644 --- a/clinicadl/utils/cli_param/train_option.py +++ b/clinicadl/utils/cli_param/train_option.py @@ -95,6 +95,12 @@ default=None, help="If provided uses a multi-network framework.", ) +ssda_network = cli_param.option_group.model_group.option( + "--ssda_network/--single_network", + type=bool, + default=None, + help="If provided uses a ssda-network framework.", +) # Task label = cli_param.option_group.task_group.option( "--label", @@ -206,6 +212,34 @@ # default="random", help="Sampler used to load the training data set.", ) +caps_target = cli_param.option_group.data_group.option( + "--caps_target", + "-d", + type=str, + default=None, + help="CAPS of target data.", +) +tsv_target_lab = cli_param.option_group.data_group.option( + "--tsv_target_lab", + "-d", + type=str, + default=None, + help="TSV of 
labeled target data.", +) +tsv_target_unlab = cli_param.option_group.data_group.option( + "--tsv_target_unlab", + "-d", + type=str, + default=None, + help="TSV of unllabeled target data.", +) +preprocessing_dict_target = cli_param.option_group.data_group.option( + "--preprocessing_dict_target", + "-d", + type=str, + default=None, + help="Path to json taget.", +) # Cross validation n_splits = cli_param.option_group.cross_validation.option( "--n_splits", diff --git a/clinicadl/utils/maps_manager/logwriter.py b/clinicadl/utils/maps_manager/logwriter.py index 5cdae637b..b55739b2c 100644 --- a/clinicadl/utils/maps_manager/logwriter.py +++ b/clinicadl/utils/maps_manager/logwriter.py @@ -64,7 +64,7 @@ def __init__( self.writer_train = SummaryWriter(self.file_dir / "tensorboard" / "train") self.writer_valid = SummaryWriter(self.file_dir / "tensorboard" / "validation") - def step(self, epoch, i, metrics_train, metrics_valid, len_epoch): + def step(self, epoch, i, metrics_train, metrics_valid, len_epoch, file_name=None): """ Write a new row on the output file training.tsv. 
@@ -77,8 +77,10 @@ def step(self, epoch, i, metrics_train, metrics_valid, len_epoch): """ from time import time - # Write TSV file - tsv_path = self.file_dir / "training.tsv" + if file_name: + tsv_path = self.file_dir / file_name + else: + tsv_path = self.file_dir / "training.tsv" t_current = time() - self.beginning_time general_row = [epoch, i, t_current] diff --git a/clinicadl/utils/maps_manager/maps_manager.py b/clinicadl/utils/maps_manager/maps_manager.py index c72d6072c..ef776d234 100644 --- a/clinicadl/utils/maps_manager/maps_manager.py +++ b/clinicadl/utils/maps_manager/maps_manager.py @@ -82,6 +82,7 @@ def __init__( test_parameters = self.get_parameters() test_parameters = change_str_to_path(test_parameters) self.parameters = add_default_values(test_parameters) + self.ssda_network = False # A MODIFIER self.task_manager = self._init_task_manager(n_classes=self.output_size) self.split_name = ( self._check_split_wording() @@ -156,6 +157,8 @@ def train(self, split_list: List[int] = None, overwrite: bool = False): if self.multi_network: self._train_multi(split_list, resume=False) + elif self.ssda_network: + self._train_ssda(split_list, resume=False) else: self._train_single(split_list, resume=False) @@ -185,6 +188,8 @@ def resume(self, split_list: List[int] = None): if self.multi_network: self._train_multi(split_list, resume=True) + elif self.ssda_network: + self._train_ssda(split_list, resume=True) else: self._train_single(split_list, resume=True) @@ -848,6 +853,214 @@ def _train_multi(self, split_list: List[int] = None, resume: bool = False): self._erase_tmp(split) + def _train_ssda(self, split_list=None, resume=False): + """ + Trains a single CNN for a source and target domain using semi-supervised domain adaptation. + + Args: + split_list (list[int]): list of splits that are trained. + resume (bool): If True the job is resumed from checkpoint. 
+ """ + from torch.utils.data import DataLoader + + train_transforms, all_transforms = get_transforms( + normalize=self.normalize, + data_augmentation=self.data_augmentation, + size_reduction=self.size_reduction, + size_reduction_factor=self.size_reduction_factor, + ) + + split_manager = self._init_split_manager(split_list) + + split_manager_target_lab = self._init_split_manager_ssda( + self.caps_target, self.tsv_target_lab, split_list + ) + + for split in split_manager.split_iterator(): + logger.info(f"Training split {split}") + seed_everything(self.seed, self.deterministic, self.compensation) + + split_df_dict = split_manager[split] + split_df_dict_target_lab = split_manager_target_lab[split] + + logger.debug("Loading source training data...") + data_train_source = return_dataset( + self.caps_directory, + split_df_dict["train"], + self.preprocessing_dict, + train_transformations=train_transforms, + all_transformations=all_transforms, + multi_cohort=self.multi_cohort, + label=self.label, + label_code=self.label_code, + ) + + logger.debug("Loading target labelled training data...") + data_train_target_labeled = return_dataset( + Path(self.caps_target), # TO CHECK + split_df_dict_target_lab["train"], + self.preprocessing_dict_target, + train_transformations=train_transforms, + all_transformations=all_transforms, + multi_cohort=False, # A checker + label=self.label, + label_code=self.label_code, + ) + from torch.utils.data import ConcatDataset, DataLoader + + combined_dataset = ConcatDataset( + [data_train_source, data_train_target_labeled] + ) + + logger.debug("Loading target unlabelled training data...") + data_target_unlabeled = return_dataset( + Path(self.caps_target), + pd.read_csv(self.tsv_target_unlab, sep="\t"), + self.preprocessing_dict_target, + train_transformations=train_transforms, + all_transformations=all_transforms, + multi_cohort=False, # A checker + label=self.label, + label_code=self.label_code, + ) + + logger.debug("Loading validation source 
data...") + data_valid_source = return_dataset( + self.caps_directory, + split_df_dict["validation"], + self.preprocessing_dict, + train_transformations=train_transforms, + all_transformations=all_transforms, + multi_cohort=self.multi_cohort, + label=self.label, + label_code=self.label_code, + ) + logger.debug("Loading validation target labelled data...") + data_valid_target_labeled = return_dataset( + Path(self.caps_target), + split_df_dict_target_lab["validation"], + self.preprocessing_dict_target, + train_transformations=train_transforms, + all_transformations=all_transforms, + multi_cohort=False, + label=self.label, + label_code=self.label_code, + ) + train_source_sampler = self.task_manager.generate_sampler( + data_train_source, self.sampler + ) + + logger.info( + f"Getting train and validation loader with batch size {self.batch_size}" + ) + + ## Oversampling of the target dataset + from torch.utils.data import SubsetRandomSampler + + # Create index lists for target labeled dataset + labeled_indices = list(range(len(data_train_target_labeled))) + + # Oversample the indices for the target labeld dataset to match the size of the labeled source dataset + data_train_source_size = len(data_train_source) // self.batch_size + labeled_oversampled_indices = labeled_indices * ( + data_train_source_size // len(labeled_indices) + ) + + # Append remaining indices to match the size of the largest dataset + labeled_oversampled_indices += labeled_indices[ + : data_train_source_size % len(labeled_indices) + ] + + # Create SubsetRandomSamplers using the oversampled indices + labeled_sampler = SubsetRandomSampler(labeled_oversampled_indices) + + train_source_loader = DataLoader( + data_train_source, + batch_size=self.batch_size, + sampler=train_source_sampler, + # shuffle=True, # len(data_train_source) < len(data_train_target_labeled), + num_workers=self.n_proc, + worker_init_fn=pl_worker_init_function, + drop_last=True, + ) + logger.info( + f"Train source loader size is 
{len(train_source_loader)*self.batch_size}" + ) + train_target_loader = DataLoader( + data_train_target_labeled, + batch_size=1, # To limit the need of oversampling + # sampler=train_target_sampler, + sampler=labeled_sampler, + num_workers=self.n_proc, + worker_init_fn=pl_worker_init_function, + # shuffle=True, # len(data_train_target_labeled) < len(data_train_source), + drop_last=True, + ) + logger.info( + f"Train target labeled loader size oversample is {len(train_target_loader)}" + ) + + data_train_target_labeled.df = data_train_target_labeled.df[ + ["participant_id", "session_id", "diagnosis", "cohort", "domain"] + ] + + train_target_unl_loader = DataLoader( + data_target_unlabeled, + batch_size=self.batch_size, + num_workers=self.n_proc, + # sampler=unlabeled_sampler, + worker_init_fn=pl_worker_init_function, + shuffle=True, + drop_last=True, + ) + + logger.info( + f"Train target unlabeled loader size is {len(train_target_unl_loader)*self.batch_size}" + ) + + valid_loader_source = DataLoader( + data_valid_source, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.n_proc, + ) + logger.info( + f"Validation loader source size is {len(valid_loader_source)*self.batch_size}" + ) + + valid_loader_target = DataLoader( + data_valid_target_labeled, + batch_size=self.batch_size, # To check + shuffle=False, + num_workers=self.n_proc, + ) + logger.info( + f"Validation loader target size is {len(valid_loader_target)*self.batch_size}" + ) + + self._train_ssdann( + train_source_loader, + train_target_loader, + train_target_unl_loader, + valid_loader_target, + valid_loader_source, + split, + resume=resume, + ) + + self._ensemble_prediction( + "train", + split, + self.selection_metrics, + ) + self._ensemble_prediction( + "validation", + split, + self.selection_metrics, + ) + + self._erase_tmp(split) + def _train( self, train_loader, @@ -1146,9 +1359,334 @@ def _train( nb_images=1, network=network, ) - self.callback_handler.on_train_end(self.parameters) + def 
_train_ssdann( + self, + train_source_loader, + train_target_loader, + train_target_unl_loader, + valid_loader, + valid_source_loader, + split, + network=None, + resume=False, + evaluate_source=True, # TO MODIFY + ): + """ + Core function shared by train and resume. + + Args: + train_loader (torch.utils.data.DataLoader): DataLoader wrapping the training set. + valid_loader (torch.utils.data.DataLoader): DataLoader wrapping the validation set. + split (int): Index of the split trained. + network (int): Index of the network trained (used in multi-network setting only). + resume (bool): If True the job is resumed from the checkpoint. + """ + + model, beginning_epoch = self._init_model( + split=split, + resume=resume, + transfer_path=self.transfer_path, + transfer_selection=self.transfer_selection_metric, + ) + + criterion = self.task_manager.get_criterion(self.loss) + logger.debug(f"Criterion for {self.network_task} is {criterion}") + optimizer = self._init_optimizer(model, split=split, resume=resume) + + logger.debug(f"Optimizer used for training is optimizer") + + model.train() + train_source_loader.dataset.train() + train_target_loader.dataset.train() + train_target_unl_loader.dataset.train() + + early_stopping = EarlyStopping( + "min", min_delta=self.tolerance, patience=self.patience + ) + + metrics_valid_target = {"loss": None} + metrics_valid_source = {"loss": None} + + log_writer = LogWriter( + self.maps_path, + self.task_manager.evaluation_metrics + ["loss"], + split, + resume=resume, + beginning_epoch=beginning_epoch, + network=network, + ) + epoch = log_writer.beginning_epoch + + retain_best = RetainBest(selection_metrics=list(self.selection_metrics)) + import numpy as np + + while epoch < self.epochs and not early_stopping.step( + metrics_valid_target["loss"] + ): + logger.info(f"Beginning epoch {epoch}.") + + model.zero_grad() + evaluation_flag, step_flag = True, True + + for i, (data_source, data_target, data_target_unl) in enumerate( + 
zip(train_source_loader, train_target_loader, train_target_unl_loader) + ): + p = ( + float(epoch * len(train_target_loader)) + / 10 + / len(train_target_loader) + ) + alpha = 2.0 / (1.0 + np.exp(-10 * p)) - 1 + # alpha = 0 + _, _, loss_dict = model.compute_outputs_and_loss( + data_source, data_target, data_target_unl, criterion, alpha + ) # TO CHECK + logger.debug(f"Train loss dictionnary {loss_dict}") + loss = loss_dict["loss"] + loss.backward() + if (i + 1) % self.accumulation_steps == 0: + step_flag = False + optimizer.step() + optimizer.zero_grad() + + del loss + + # Evaluate the model only when no gradients are accumulated + if ( + self.evaluation_steps != 0 + and (i + 1) % self.evaluation_steps == 0 + ): + evaluation_flag = False + + # Evaluate on taget data + logger.info("Evaluation on target data") + _, metrics_train_target = self.task_manager.test_da( + model, + train_target_loader, + criterion, + alpha, + target=True, + ) # TO CHECK + + _, metrics_valid_target = self.task_manager.test_da( + model, + valid_loader, + criterion, + alpha, + target=True, + ) + + model.train() + train_target_loader.dataset.train() + + log_writer.step( + epoch, + i, + metrics_train_target, + metrics_valid_target, + len(train_target_loader), + "training_target.tsv", + ) + logger.info( + f"{self.mode} level training loss for target data is {metrics_train_target['loss']} " + f"at the end of iteration {i}" + ) + logger.info( + f"{self.mode} level validation loss for target data is {metrics_valid_target['loss']} " + f"at the end of iteration {i}" + ) + + # Evaluate on source data + logger.info("Evaluation on source data") + _, metrics_train_source = self.task_manager.test_da( + model, train_source_loader, criterion, alpha + ) + _, metrics_valid_source = self.task_manager.test_da( + model, valid_source_loader, criterion, alpha + ) + + model.train() + train_source_loader.dataset.train() + + log_writer.step( + epoch, + i, + metrics_train_source, + metrics_valid_source, + 
len(train_source_loader), + ) + logger.info( + f"{self.mode} level training loss for source data is {metrics_train_source['loss']} " + f"at the end of iteration {i}" + ) + logger.info( + f"{self.mode} level validation loss for source data is {metrics_valid_source['loss']} " + f"at the end of iteration {i}" + ) + + # If no step has been performed, raise Exception + if step_flag: + raise Exception( + "The model has not been updated once in the epoch. The accumulation step may be too large." + ) + + # If no evaluation has been performed, warn the user + elif evaluation_flag and self.evaluation_steps != 0: + logger.warning( + f"Your evaluation steps {self.evaluation_steps} are too big " + f"compared to the size of the dataset. " + f"The model is evaluated only once at the end epochs." + ) + + # Update weights one last time if gradients were computed without update + if (i + 1) % self.accumulation_steps != 0: + optimizer.step() + optimizer.zero_grad() + # Always test the results and save them once at the end of the epoch + model.zero_grad() + logger.debug(f"Last checkpoint at the end of the epoch {epoch}") + + if evaluate_source: + logger.info( + f"Evaluate source data at the end of the epoch {epoch} with alpha: {alpha}." 
+ ) + _, metrics_train_source = self.task_manager.test_da( + model, + train_source_loader, + criterion, + alpha, + True, + False, + ) + _, metrics_valid_source = self.task_manager.test_da( + model, + valid_source_loader, + criterion, + alpha, + True, + False, + ) + + log_writer.step( + epoch, + i, + metrics_train_source, + metrics_valid_source, + len(train_source_loader), + ) + + logger.info( + f"{self.mode} level training loss for source data is {metrics_train_source['loss']} " + f"at the end of iteration {i}" + ) + logger.info( + f"{self.mode} level validation loss for source data is {metrics_valid_source['loss']} " + f"at the end of iteration {i}" + ) + + _, metrics_train_target = self.task_manager.test_da( + model, + train_target_loader, + criterion, + alpha, + target=True, + ) + _, metrics_valid_target = self.task_manager.test_da( + model, + valid_loader, + criterion, + alpha, + target=True, + ) + + model.train() + train_source_loader.dataset.train() + train_target_loader.dataset.train() + + log_writer.step( + epoch, + i, + metrics_train_target, + metrics_valid_target, + len(train_target_loader), + "training_target.tsv", + ) + + logger.info( + f"{self.mode} level training loss for target data is {metrics_train_target['loss']} " + f"at the end of iteration {i}" + ) + logger.info( + f"{self.mode} level validation loss for target data is {metrics_valid_target['loss']} " + f"at the end of iteration {i}" + ) + + # Save checkpoints and best models + best_dict = retain_best.step(metrics_valid_target) + self._write_weights( + { + "model": model.state_dict(), + "epoch": epoch, + "name": self.architecture, + }, + best_dict, + split, + network=network, + ) + self._write_weights( + { + "optimizer": optimizer.state_dict(), # TO MODIFY + "epoch": epoch, + "name": self.optimizer, + }, + None, + split, + filename="optimizer.pth.tar", + ) + + epoch += 1 + + self._test_loader_ssda( + train_target_loader, + criterion, + data_group="train", + split=split, + 
selection_metrics=self.selection_metrics, + network=network, + target=True, + alpha=0, + ) + self._test_loader_ssda( + valid_loader, + criterion, + data_group="validation", + split=split, + selection_metrics=self.selection_metrics, + network=network, + target=True, + alpha=0, + ) + + if self.task_manager.save_outputs: + self._compute_output_tensors( + train_target_loader.dataset, + "train", + split, + self.selection_metrics, + nb_images=1, + network=network, + ) + self._compute_output_tensors( + train_target_loader.dataset, + "validation", + split, + self.selection_metrics, + nb_images=1, + network=network, + ) + def _test_loader( self, dataloader, @@ -1220,6 +1758,72 @@ def _test_loader( data_group=data_group, ) + def _test_loader_ssda( + self, + dataloader, + criterion, + alpha, + data_group, + split, + selection_metrics, + use_labels=True, + gpu=None, + network=None, + target=False, + ): + """ + Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files. + + Args: + dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset. + criterion (torch.nn.modules.loss._Loss): optimization criterion used during training. + data_group (str): name of the data group used for the testing task. + split (int): Index of the split used to train the model tested. + selection_metrics (list[str]): List of metrics used to select the best models which are tested. + use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed. + gpu (bool): If given, a new value for the device of the model will be computed. + network (int): Index of the network tested (only used in multi-network setting). 
+ """ + for selection_metric in selection_metrics: + log_dir = ( + self.maps_path + / f"{self.split_name}-{split}" + / f"best-{selection_metric}" + / data_group + ) + self.write_description_log( + log_dir, + data_group, + dataloader.dataset.caps_dict, + dataloader.dataset.df, + ) + + # load the best trained model during the training + model, _ = self._init_model( + transfer_path=self.maps_path, + split=split, + transfer_selection=selection_metric, + gpu=gpu, + network=network, + ) + prediction_df, metrics = self.task_manager.test_da( + model, + dataloader, + criterion, + target=target, + ) + if use_labels: + if network is not None: + metrics[f"{self.mode}_id"] = network + logger.info( + f"{self.mode} level {data_group} loss is {metrics['loss']} for model selected on {selection_metric}" + ) + + # Replace here + self._mode_level_to_tsv( + prediction_df, metrics, split, selection_metric, data_group=data_group + ) + @torch.no_grad() def _compute_output_nifti( self, @@ -1490,7 +2094,6 @@ def _check_args(self, parameters): split_manager = self._init_split_manager(None) train_df = split_manager[0]["train"] - if "label" not in self.parameters: self.parameters["label"] = None @@ -1507,6 +2110,7 @@ def _check_args(self, parameters): self.parameters["label_code"] = self.task_manager.generate_label_code( train_df, self.label ) + full_dataset = return_dataset( self.caps_directory, train_df, @@ -2201,6 +2805,27 @@ def _init_split_manager(self, split_list=None): kwargs[arg] = self.parameters[arg] return split_class(**kwargs) + def _init_split_manager_ssda(self, caps_dir, tsv_dir, split_list=None): + # A intégrer directement dans _init_split_manager + from clinicadl.utils import split_manager + + split_class = getattr(split_manager, self.validation) + args = list( + split_class.__init__.__code__.co_varnames[ + : split_class.__init__.__code__.co_argcount + ] + ) + args.remove("self") + args.remove("split_list") + kwargs = {"split_list": split_list} + for arg in args: + kwargs[arg] 
= self.parameters[arg] + + kwargs["caps_directory"] = Path(caps_dir) + kwargs["tsv_path"] = Path(tsv_dir) + + return split_class(**kwargs) + def _init_task_manager(self, df=None, n_classes=None): from clinicadl.utils.task_manager import ( ClassificationManager, diff --git a/clinicadl/utils/network/__init__.py b/clinicadl/utils/network/__init__.py index 730bc5446..33c0765f5 100644 --- a/clinicadl/utils/network/__init__.py +++ b/clinicadl/utils/network/__init__.py @@ -2,6 +2,7 @@ from .cnn.models import ( Conv4_FC3, Conv5_FC3, + Conv5_FC3_SSDA, ResNet3D, SqueezeExcitationCNN, Stride_Conv5_FC3, diff --git a/clinicadl/utils/network/cnn/models.py b/clinicadl/utils/network/cnn/models.py index 6c336794a..def687ce9 100644 --- a/clinicadl/utils/network/cnn/models.py +++ b/clinicadl/utils/network/cnn/models.py @@ -8,7 +8,7 @@ from clinicadl.utils.network.cnn.resnet3D import ResNetDesigner3D from clinicadl.utils.network.cnn.SECNN import SECNNDesigner3D from clinicadl.utils.network.network_utils import PadMaxPool2d, PadMaxPool3d -from clinicadl.utils.network.sub_network import CNN +from clinicadl.utils.network.sub_network import CNN, CNN_SSDA def get_layers_fn(input_size): @@ -375,3 +375,97 @@ def get_dimension(): @staticmethod def get_task(): return ["classification"] + + +class Conv5_FC3_SSDA(CNN_SSDA): + """ + Reduce the 2D or 3D input image to an array of size output_size. 
+ """ + + def __init__(self, input_size, gpu=True, output_size=2, dropout=0.5): + conv, norm, pool = get_layers_fn(input_size) + # fmt: off + convolutions = nn.Sequential( + conv(input_size[0], 8, 3, padding=1), + norm(8), + nn.ReLU(), + pool(2, 2), + + conv(8, 16, 3, padding=1), + norm(16), + nn.ReLU(), + pool(2, 2), + + conv(16, 32, 3, padding=1), + norm(32), + nn.ReLU(), + pool(2, 2), + + conv(32, 64, 3, padding=1), + norm(64), + nn.ReLU(), + pool(2, 2), + + conv(64, 128, 3, padding=1), + norm(128), + nn.ReLU(), + pool(2, 2), + + # conv(128, 256, 3, padding=1), + # norm(256), + # nn.ReLU(), + # pool(2, 2), + ) + + # Compute the size of the first FC layer + input_tensor = torch.zeros(input_size).unsqueeze(0) + output_convolutions = convolutions(input_tensor) + + fc_class_source = nn.Sequential( + nn.Flatten(), + nn.Dropout(p=dropout), + + nn.Linear(np.prod(list(output_convolutions.shape)).item(), 1300), + nn.ReLU(), + + nn.Linear(1300, 50), + nn.ReLU(), + + nn.Linear(50, output_size) + ) + + + fc_class_target= nn.Sequential( + nn.Flatten(), + nn.Dropout(p=dropout), + + nn.Linear(np.prod(list(output_convolutions.shape)).item(), 1300), + nn.ReLU(), + + nn.Linear(1300, 50), + nn.ReLU(), + + nn.Linear(50, output_size) + ) + + fc_domain = nn.Sequential( + nn.Flatten(), + nn.Dropout(p=dropout), + + nn.Linear(np.prod(list(output_convolutions.shape)).item(), 1300), + nn.ReLU(), + + nn.Linear(1300, 50), + nn.ReLU(), + + nn.Linear(50, output_size) + ) + # fmt: on + super().__init__( + convolutions=convolutions, + fc_class_source=fc_class_source, + fc_class_target=fc_class_target, + fc_domain=fc_domain, + n_classes=output_size, + gpu=gpu, + ) diff --git a/clinicadl/utils/network/network_utils.py b/clinicadl/utils/network/network_utils.py index dee93108c..569596491 100644 --- a/clinicadl/utils/network/network_utils.py +++ b/clinicadl/utils/network/network_utils.py @@ -3,6 +3,7 @@ """ import torch.nn as nn +from torch.autograd import Function class Reshape(nn.Module): @@ 
-156,3 +157,16 @@ def torch_summarize(model, show_weights=True, show_parameters=True):
     tmpstr = tmpstr + ")"
 
     return tmpstr
+
+
+class ReverseLayerF(Function):
+    @staticmethod
+    def forward(ctx, x, alpha):
+        ctx.alpha = alpha
+        return x.view_as(x)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        output = grad_output.neg() * ctx.alpha
+
+        return output, None
diff --git a/clinicadl/utils/network/sub_network.py b/clinicadl/utils/network/sub_network.py
index 0e2c08c29..46270f4e7 100644
--- a/clinicadl/utils/network/sub_network.py
+++ b/clinicadl/utils/network/sub_network.py
@@ -11,6 +11,7 @@
     CropMaxUnpool3d,
     PadMaxPool2d,
     PadMaxPool3d,
+    ReverseLayerF,
 )
 
 logger = getLogger("clinicadl.networks")
@@ -133,3 +134,142 @@ def compute_outputs_and_loss(self, input_dict, criterion, use_labels=True):
             loss = torch.Tensor([0])
 
         return train_output, {"loss": loss}
+
+
+class CNN_SSDA(Network):
+    def __init__(
+        self,
+        convolutions,
+        fc_class_source,
+        fc_class_target,
+        fc_domain,
+        n_classes,
+        gpu=False,
+    ):
+        super().__init__(gpu=gpu)
+        self.convolutions = convolutions.to(self.device)
+        self.fc_class_source = fc_class_source.to(self.device)
+        self.fc_class_target = fc_class_target.to(self.device)
+        self.fc_domain = fc_domain.to(self.device)
+        self.n_classes = n_classes
+
+    @property
+    def layers(self):
+        return nn.Sequential(
+            self.convolutions,
+            self.fc_class_source,
+            self.fc_class_target,
+            self.fc_domain,
+        )
+
+    def transfer_weights(self, state_dict, transfer_class):
+        if issubclass(transfer_class, CNN_SSDA):
+            self.load_state_dict(state_dict)
+        elif issubclass(transfer_class, AutoEncoder):
+            convolutions_dict = OrderedDict(
+                [
+                    (k.replace("encoder.", ""), v)
+                    for k, v in state_dict.items()
+                    if "encoder" in k
+                ]
+            )
+            self.convolutions.load_state_dict(convolutions_dict)
+        else:
+            raise ClinicaDLNetworksError(
+                f"Cannot transfer weights from {transfer_class} to CNN."
+            )
+
+    def forward(self, x, alpha):
+        x = self.convolutions(x)
+        x_class_source = self.fc_class_source(x)
+        x_class_target = self.fc_class_target(x)
+        x_reverse = ReverseLayerF.apply(x, alpha)
+        x_domain = self.fc_domain(x_reverse)
+        return x_class_source, x_class_target, x_domain
+
+    def predict(self, x):
+        return self.forward(x, 0)
+
+    def compute_outputs_and_loss_test(self, input_dict, criterion, alpha, target):
+        images, labels = input_dict["image"].to(self.device), input_dict["label"].to(
+            self.device
+        )
+        train_output_source, train_output_target, _ = self.forward(images, alpha)
+
+        if target:
+            out = train_output_target
+            loss_bce = criterion(train_output_target, labels)
+
+        else:
+            out = train_output_source
+            loss_bce = criterion(train_output_source, labels)
+
+        return out, {"loss": loss_bce}
+
+    def compute_outputs_and_loss(
+        self, data_source, data_target, data_target_unl, criterion, alpha
+    ):
+        images, labels = (
+            data_source["image"].to(self.device),
+            data_source["label"].to(self.device),
+        )
+
+        images_target, labels_target = (
+            data_target["image"].to(self.device),
+            data_target["label"].to(self.device),
+        )
+
+        images_target_unl = data_target_unl["image"].to(self.device)
+
+        (
+            train_output_class_source,
+            _,
+            train_output_domain_s,
+        ) = self.forward(images, alpha)
+
+        (
+            _,
+            train_output_class_target,
+            train_output_domain_t,
+        ) = self.forward(images_target, alpha)
+
+        _, _, train_output_domain_target_unlab = self.forward(images_target_unl, alpha)
+
+        loss_classif_source = criterion(train_output_class_source, labels)
+        loss_classif_target = criterion(train_output_class_target, labels_target)
+
+        loss_classif = loss_classif_source + loss_classif_target
+
+        labels_domain_s = (
+            torch.zeros(data_source["image"].shape[0]).long().to(self.device)
+        )
+
+        labels_domain_tl = (
+            torch.ones(data_target["image"].shape[0]).long().to(self.device)
+        )
+
+        labels_domain_tu = (
+            torch.ones(data_target_unl["image"].shape[0]).long().to(self.device)
+        )
+
+ loss_domain_lab = criterion(train_output_domain_s, labels_domain_s) + loss_domain_lab_t = criterion(train_output_domain_t, labels_domain_tl) + loss_domain_t_unl = criterion( + train_output_domain_target_unlab, labels_domain_tu + ) + + loss_domain = loss_domain_lab + loss_domain_lab_t + loss_domain_t_unl + + total_loss = loss_classif + 0.1 * loss_domain + + return ( + train_output_class_source, + train_output_class_target, + {"loss": total_loss}, + ) + + def lr_scheduler(self, lr, optimizer, p): + lr = lr / (1 + 10 * p) ** 0.75 + for param_group in optimizer.param_groups: + param_group["lr"] = lr + return optimizer diff --git a/clinicadl/utils/network/vae/base_vae.py b/clinicadl/utils/network/vae/base_vae.py index 3e2dfeba8..19f7cf30e 100644 --- a/clinicadl/utils/network/vae/base_vae.py +++ b/clinicadl/utils/network/vae/base_vae.py @@ -39,6 +39,7 @@ def predict(self, x): output, _, _ = self.forward(x) return output + # Forward def forward(self, x): mu, logVar = self.encode(x) z = self.reparameterize(mu, logVar) diff --git a/clinicadl/utils/split_manager/split_manager.py b/clinicadl/utils/split_manager/split_manager.py index 1f0e1a86f..0a0fc751f 100644 --- a/clinicadl/utils/split_manager/split_manager.py +++ b/clinicadl/utils/split_manager/split_manager.py @@ -134,7 +134,6 @@ def concatenate_diagnoses( logger.debug(f"Validation data loaded at {valid_path}") if cohort_diagnoses is None: cohort_diagnoses = self.diagnoses - if self.baseline: train_path = train_path / "train_baseline.tsv" else: @@ -142,6 +141,7 @@ def concatenate_diagnoses( valid_path = valid_path / "validation_baseline.tsv" + print(train_path) train_df = pd.read_csv(train_path, sep="\t") valid_df = pd.read_csv(valid_path, sep="\t") @@ -195,8 +195,9 @@ def concatenate_diagnoses( ) except: pass - - train_df = train_df[train_df.diagnosis.isin(cohort_diagnoses)] + train_df = train_df[ + train_df.diagnosis.isin(cohort_diagnoses) + ] # TO MODIFY with train valid_df = 
valid_df[valid_df.diagnosis.isin(cohort_diagnoses)] train_df.reset_index(inplace=True, drop=True) diff --git a/clinicadl/utils/task_manager/classification.py b/clinicadl/utils/task_manager/classification.py index 11c191419..3265698fa 100644 --- a/clinicadl/utils/task_manager/classification.py +++ b/clinicadl/utils/task_manager/classification.py @@ -9,6 +9,9 @@ from torch.utils.data.distributed import DistributedSampler from clinicadl.utils.exceptions import ClinicaDLArgumentError + +logger = getLogger("clinicadl.task_manager") + from clinicadl.utils.task_manager.task_manager import TaskManager logger = getLogger("clinicadl.task_manager") @@ -119,6 +122,32 @@ def generate_sampler( f"The option {sampler_option} for sampler on classification task is not implemented" ) + @staticmethod + def generate_sampler_ssda(dataset, df, sampler_option="random", n_bins=5): + n_labels = df["diagnosis_train"].nunique() + count = np.zeros(n_labels) + + for idx in df.index: + label = df.loc[idx, "diagnosis_train"] + key = dataset.label_fn(label) + count[key] += 1 + + weight_per_class = 1 / np.array(count) + weights = [] + + for idx, label in enumerate(df["diagnosis_train"].values): + key = dataset.label_fn(label) + weights += [weight_per_class[key]] * dataset.elem_per_image + + if sampler_option == "random": + return sampler.RandomSampler(weights) + elif sampler_option == "weighted": + return sampler.WeightedRandomSampler(weights, len(weights)) + else: + raise NotImplementedError( + f"The option {sampler_option} for sampler on classification task is not implemented" + ) + def ensemble_prediction( self, performance_df, diff --git a/clinicadl/utils/task_manager/task_manager.py b/clinicadl/utils/task_manager/task_manager.py index 28b56edf8..b5c7602b1 100644 --- a/clinicadl/utils/task_manager/task_manager.py +++ b/clinicadl/utils/task_manager/task_manager.py @@ -238,3 +238,53 @@ def test( torch.cuda.empty_cache() return results_df, metrics_dict + + def test_da( + self, + model: Network, + 
dataloader: DataLoader, + criterion: _Loss, + alpha: float = 0, + use_labels: bool = True, + target: bool = True, + ) -> Tuple[pd.DataFrame, Dict[str, float]]: + """ + Computes the predictions and evaluation metrics. + + Args: + model: the model trained. + dataloader: wrapper of a CapsDataset. + criterion: function to calculate the loss. + use_labels: If True the true_label will be written in output DataFrame + and metrics dict will be created. + Returns: + the results and metrics on the image level. + """ + model.eval() + dataloader.dataset.eval() + results_df = pd.DataFrame(columns=self.columns) + total_loss = 0 + with torch.no_grad(): + for i, data in enumerate(dataloader): + outputs, loss_dict = model.compute_outputs_and_loss_test( + data, criterion, alpha, target + ) + total_loss += loss_dict["loss"].item() + + # Generate detailed DataFrame + for idx in range(len(data["participant_id"])): + row = self.generate_test_row(idx, data, outputs) + row_df = pd.DataFrame(row, columns=self.columns) + results_df = pd.concat([results_df, row_df]) + + del outputs, loss_dict + results_df.reset_index(inplace=True, drop=True) + + if not use_labels: + metrics_dict = None + else: + metrics_dict = self.compute_metrics(results_df) + metrics_dict["loss"] = total_loss + torch.cuda.empty_cache() + + return results_df, metrics_dict diff --git a/poetry.lock b/poetry.lock index 5e3fd5361..ce30d7d93 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3139,6 +3139,42 @@ test = ["altair", "bsmschema", "coverage[toml]", "pytest (>=3.3)", "pytest-cov"] tests = ["pybids[test]"] tutorial = ["ipykernel", "jinja2", "jupyter-client", "markupsafe", "nbconvert"] +[[package]] +name = "pybids" +version = "0.15.6" +description = "bids: interface with datasets conforming to BIDS" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pybids-0.15.6-py3-none-any.whl", hash = "sha256:8b3257379138669c2a995c65e0f08e8f9a007784aebba115815cf79e0d90bb5f"}, + {file = 
"pybids-0.15.6.tar.gz", hash = "sha256:3a3596d3cb725431e41745f73dafb5d07603ffb1b2b29e43904b7d2da571b1d3"}, +] + +[package.dependencies] +bids-validator = "*" +click = ">=8.0" +formulaic = ">=0.2.4,<0.6" +nibabel = ">=2.1" +num2words = "*" +numpy = [ + {version = "*", markers = "python_version >= \"3.9\""}, + {version = "<1.25.0.dev0", markers = "python_version < \"3.9\""}, +] +pandas = ">=0.23" +scipy = "*" +sqlalchemy = "<1.4.0.dev0" + +[package.extras] +ci-tests = ["codecov", "pybids[test]", "pytest-xdist"] +dev = ["pybids[doc,plotting,test]"] +doc = ["jupytext", "myst-nb", "numpydoc", "sphinx (>=2.2,!=5.1.0)", "sphinx-rtd-theme"] +docs = ["pybids[doc]"] +plotting = ["graphviz"] +test = ["bsmschema", "coverage[toml]", "pytest (>=3.3)", "pytest-cov"] +tests = ["pybids[test]"] +tutorial = ["ipykernel", "jinja2 (<3)", "jupyter-client", "markupsafe (<2.1)", "nbconvert"] + [[package]] name = "pydicom" version = "2.4.3"