Skip to content

Commit

Permalink
test validation attr
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Sep 30, 2024
1 parent 72b3f0a commit d184a51
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 7 deletions.
1 change: 1 addition & 0 deletions clinicadl/maps_manager/maps_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def _check_args(self, parameters):
)

split_manager = self._init_split_manager(None)

train_df = split_manager[0]["train"]
if "label" not in self.parameters:
self.parameters["label"] = None
Expand Down
10 changes: 5 additions & 5 deletions clinicadl/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def resume(self, splits: List[int]) -> None:
)
# TODO : check these two lines. Why do we need a split_manager?
split_manager = init_split_manager(
validation=self.config.validation,
validation=self.maps_manager.validation,
parameters=self.config.model_dump(),
split_list=splits,
)
Expand Down Expand Up @@ -225,7 +225,7 @@ def train(

else:
split_manager = init_split_manager(
self.config.validation, self.config.model_dump(), split_list
self.maps_manager.validation, self.config.model_dump(), split_list
)
for split in split_manager.split_iterator():
logger.info(f"Training split {split}")
Expand All @@ -249,7 +249,7 @@ def train(
def check_split_list(self, split_list, overwrite):
existing_splits = []
split_manager = init_split_manager(
self.config.validation, self.config.model_dump(), split_list
self.maps_manager.validation, self.config.model_dump(), split_list
)
for split in split_manager.split_iterator():
split_path = (
Expand Down Expand Up @@ -289,7 +289,7 @@ def _resume(
"""
missing_splits = []
split_manager = init_split_manager(
self.config.validation, self.config.model_dump(), split_list
self.maps_manager.validation, self.config.model_dump(), split_list
)

for split in split_manager.split_iterator():
Expand All @@ -310,7 +310,7 @@ def _resume(
self._train_ssda(split_list, resume=True)
else:
split_manager = init_split_manager(
self.config.validation, self.config.model_dump(), split_list
self.maps_manager.validation, self.config.model_dump(), split_list
)
for split in split_manager.split_iterator():
logger.info(f"Training split {split}")
Expand Down
1 change: 0 additions & 1 deletion clinicadl/validation/split_manager/kfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(
split_list,
)
self.n_splits = n_splits
self.validation = "Kfold"

def max_length(self) -> int:
return self.n_splits
Expand Down
1 change: 0 additions & 1 deletion clinicadl/validation/split_manager/single_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(
multi_cohort,
split_list,
)
self.validation = "SingleSplit"

def max_length(self) -> int:
return 1
Expand Down

0 comments on commit d184a51

Please sign in to comment.