diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py index 3b32486b5..a9d3bfc67 100644 --- a/clinicadl/maps_manager/maps_manager.py +++ b/clinicadl/maps_manager/maps_manager.py @@ -175,6 +175,7 @@ def _check_args(self, parameters): ) split_manager = self._init_split_manager(None) + train_df = split_manager[0]["train"] if "label" not in self.parameters: self.parameters["label"] = None diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 9606477b6..9af58175d 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -166,7 +166,7 @@ def resume(self, splits: List[int]) -> None: ) # TODO : check these two lines. Why do we need a split_manager? split_manager = init_split_manager( - validation=self.config.validation, + validation=self.maps_manager.validation, parameters=self.config.model_dump(), split_list=splits, ) @@ -225,7 +225,7 @@ def train( else: split_manager = init_split_manager( - self.config.validation, self.config.model_dump(), split_list + self.maps_manager.validation, self.config.model_dump(), split_list ) for split in split_manager.split_iterator(): logger.info(f"Training split {split}") @@ -249,7 +249,7 @@ def train( def check_split_list(self, split_list, overwrite): existing_splits = [] split_manager = init_split_manager( - self.config.validation, self.config.model_dump(), split_list + self.maps_manager.validation, self.config.model_dump(), split_list ) for split in split_manager.split_iterator(): split_path = ( @@ -289,7 +289,7 @@ def _resume( """ missing_splits = [] split_manager = init_split_manager( - self.config.validation, self.config.model_dump(), split_list + self.maps_manager.validation, self.config.model_dump(), split_list ) for split in split_manager.split_iterator(): @@ -310,7 +310,7 @@ def _resume( self._train_ssda(split_list, resume=True) else: split_manager = init_split_manager( - self.config.validation, self.config.model_dump(), split_list + self.maps_manager.validation, self.config.model_dump(), split_list ) for split in split_manager.split_iterator(): logger.info(f"Training split {split}") diff --git a/clinicadl/validation/split_manager/kfold.py b/clinicadl/validation/split_manager/kfold.py index a87314c7e..a3c26baaa 100644 --- a/clinicadl/validation/split_manager/kfold.py +++ b/clinicadl/validation/split_manager/kfold.py @@ -25,7 +25,6 @@ def __init__( split_list, ) self.n_splits = n_splits - self.validation = "Kfold" def max_length(self) -> int: return self.n_splits diff --git a/clinicadl/validation/split_manager/single_split.py b/clinicadl/validation/split_manager/single_split.py index 92458d409..6ff282bb2 100644 --- a/clinicadl/validation/split_manager/single_split.py +++ b/clinicadl/validation/split_manager/single_split.py @@ -23,7 +23,6 @@ def __init__( multi_cohort, split_list, ) - self.validation = "SingleSplit" def max_length(self) -> int: return 1