From f3f0b129b9cd245918902425b2f54624a54f566b Mon Sep 17 00:00:00 2001
From: camillebrianceau
Date: Thu, 3 Oct 2024 17:06:46 +0200
Subject: [PATCH] tests

---
 .../pipelines/train/classification/cli.py    |   2 +-
 .../pipelines/train/reconstruction/cli.py    |   2 +-
 .../pipelines/train/regression/cli.py        |   2 +-
 clinicadl/maps_manager/maps_manager.py       |   9 +-
 clinicadl/maps_manager/tmp_config.py         |   4 -
 clinicadl/random_search/random_search.py     |   2 +-
 clinicadl/splitter/splitter.py               |  17 ---
 clinicadl/trainer/trainer.py                 |  42 +++---
 tests/test_resume.py                         |   7 +-
 .../splitter/test_splitter_config.py         | 137 ++++++++++++++++++
 10 files changed, 169 insertions(+), 55 deletions(-)
 create mode 100644 tests/unittests/splitter/test_splitter_config.py

diff --git a/clinicadl/commandline/pipelines/train/classification/cli.py b/clinicadl/commandline/pipelines/train/classification/cli.py
index 6a7814255..539f6cd42 100644
--- a/clinicadl/commandline/pipelines/train/classification/cli.py
+++ b/clinicadl/commandline/pipelines/train/classification/cli.py
@@ -115,4 +115,4 @@ def cli(**kwargs):
     options = merge_cli_and_config_file_options(Task.CLASSIFICATION, **kwargs)
     config = ClassificationConfig(**options)
     trainer = Trainer(config)
-    trainer.train(split_list=config.validation.split, overwrite=True)
+    trainer.train(split_list=config.split.split, overwrite=True)
diff --git a/clinicadl/commandline/pipelines/train/reconstruction/cli.py b/clinicadl/commandline/pipelines/train/reconstruction/cli.py
index 37bd50b41..d63bf63f8 100644
--- a/clinicadl/commandline/pipelines/train/reconstruction/cli.py
+++ b/clinicadl/commandline/pipelines/train/reconstruction/cli.py
@@ -112,4 +112,4 @@ def cli(**kwargs):
     options = merge_cli_and_config_file_options(Task.RECONSTRUCTION, **kwargs)
     config = ReconstructionConfig(**options)
     trainer = Trainer(config)
-    trainer.train(split_list=config.validation.split, overwrite=True)
+    trainer.train(split_list=config.split.split, overwrite=True)
diff --git a/clinicadl/commandline/pipelines/train/regression/cli.py b/clinicadl/commandline/pipelines/train/regression/cli.py
index fbc48e5b9..ff6dd68ca 100644
--- a/clinicadl/commandline/pipelines/train/regression/cli.py
+++ b/clinicadl/commandline/pipelines/train/regression/cli.py
@@ -111,4 +111,4 @@ def cli(**kwargs):
    options = merge_cli_and_config_file_options(Task.REGRESSION, **kwargs)
    config = RegressionConfig(**options)
    trainer = Trainer(config)
-   trainer.train(split_list=config.validation.split, overwrite=True)
+   trainer.train(split_list=config.split.split, overwrite=True)
diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py
index 839a0044c..cb3cdf6d7 100644
--- a/clinicadl/maps_manager/maps_manager.py
+++ b/clinicadl/maps_manager/maps_manager.py
@@ -17,7 +17,8 @@
     check_selection_metric,
 )
 from clinicadl.predict.utils import get_prediction
-from clinicadl.splitter.splitter import init_splitter
+from clinicadl.splitter.config import SplitterConfig
+from clinicadl.splitter.splitter import Splitter
 from clinicadl.trainer.tasks_utils import (
     ensemble_prediction,
     evaluation_metrics,
@@ -169,8 +170,9 @@ def _check_args(self, parameters):
             size_reduction=self.size_reduction,
             size_reduction_factor=self.size_reduction_factor,
         )
+        splitter_config = SplitterConfig(**self.parameters)
+        split_manager = Splitter(splitter_config)
 
-        split_manager = init_splitter(parameters=self.parameters)
         train_df = split_manager[0]["train"]
         if "label" not in self.parameters:
             self.parameters["label"] = None
@@ -316,7 +318,8 @@ def _write_training_data(self):
     def _write_train_val_groups(self):
         """Defines the training and validation groups at the initialization"""
         logger.debug("Writing training and validation groups...")
-        split_manager = init_splitter(parameters=self.parameters)
+        splitter_config = SplitterConfig(**self.parameters)
+        split_manager = Splitter(splitter_config)
         for split in split_manager.split_iterator():
             for data_group in ["train", "validation"]:
                 df = split_manager[split][data_group]
diff --git a/clinicadl/maps_manager/tmp_config.py b/clinicadl/maps_manager/tmp_config.py
index 5fcae1592..a31af7edb 100644
--- a/clinicadl/maps_manager/tmp_config.py
+++ b/clinicadl/maps_manager/tmp_config.py
@@ -183,10 +183,6 @@ def check_args(self):
         )
 
         if self.network_task == "classification":
-            from clinicadl.splitter.splitter import (
-                init_splitter,
-            )
-
             if self.n_splits > 1 and self.validation == "SingleSplit":
                 self.validation = "KFoldSplit"
diff --git a/clinicadl/random_search/random_search.py b/clinicadl/random_search/random_search.py
index e4e944f7f..7929e9382 100755
--- a/clinicadl/random_search/random_search.py
+++ b/clinicadl/random_search/random_search.py
@@ -38,4 +38,4 @@ def launch_search(launch_directory: Path, job_name):
         output_maps_directory=maps_directory, **options
     )
     trainer = Trainer(training_config)
-    trainer.train(split_list=training_config.validation.split, overwrite=True)
+    trainer.train(split_list=training_config.split.split, overwrite=True)
diff --git a/clinicadl/splitter/splitter.py b/clinicadl/splitter/splitter.py
index 536c16ef9..83d1ab127 100644
--- a/clinicadl/splitter/splitter.py
+++ b/clinicadl/splitter/splitter.py
@@ -18,23 +18,6 @@
 logger = getLogger("clinicadl.split_manager")
 
 
-def init_splitter(
-    parameters,
-    split_list=None,
-):
-    data_config = DataConfig(**parameters)
-    validation_config = ValidationConfig(**parameters)
-    split_config = SplitConfig(**parameters)
-
-    splitter_config = SplitterConfig(
-        data_config=data_config,
-        validation_config=validation_config,
-        split_config=split_config,
-    )
-
-    return Splitter(splitter_config, split_list=split_list)
-
-
 class Splitter:
     def __init__(
         self,
diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py
index ec3e99eab..814aba2f8 100644
--- a/clinicadl/trainer/trainer.py
+++ b/clinicadl/trainer/trainer.py
@@ -34,7 +34,8 @@
 )
 from clinicadl.trainer.tasks_utils import create_training_config
 from clinicadl.validator.validator import Validator
-from clinicadl.splitter.splitter import init_splitter
+from clinicadl.splitter.splitter import Splitter
+from clinicadl.splitter.config import SplitterConfig
 from clinicadl.transforms.config import TransformsConfig
 
 if TYPE_CHECKING:
@@ -158,10 +159,10 @@ def resume(self, splits: List[int]) -> None:
         stopped_splits = set(find_stopped_splits(self.config.maps_manager.maps_dir))
         finished_splits = set(find_finished_splits(self.maps_manager.maps_path))
         # TODO : check these two lines. Why do we need a split_manager?
-        split_manager = init_splitter(
-            parameters=self.config.get_dict(),
-            split_list=splits,
-        )
+
+        splitter_config = SplitterConfig(**self.config.get_dict())
+        split_manager = Splitter(splitter_config, split_list=splits)
+
         split_iterator = split_manager.split_iterator()
         ###
         absent_splits = set(split_iterator) - stopped_splits - finished_splits
@@ -216,10 +217,9 @@ def train(
             self._train_ssda(split_list, resume=False)
 
         else:
-            split_manager = init_splitter(
-                parameters=self.config.get_dict(),
-                split_list=split_list,
-            )
+            splitter_config = SplitterConfig(**self.config.get_dict())
+            split_manager = Splitter(splitter_config, split_list=split_list)
+
             for split in split_manager.split_iterator():
                 logger.info(f"Training split {split}")
                 seed_everything(
@@ -241,10 +241,8 @@ def check_split_list(self, split_list, overwrite):
         existing_splits = []
-        split_manager = init_splitter(
-            parameters=self.config.get_dict(),
-            split_list=split_list,
-        )
+        splitter_config = SplitterConfig(**self.config.get_dict())
+        split_manager = Splitter(splitter_config, split_list=split_list)
         for split in split_manager.split_iterator():
             split_path = self.maps_manager.maps_path / f"split-{split}"
             if split_path.is_dir():
@@ -280,10 +278,8 @@ def _resume(
         If splits specified in input do not exist.
         """
         missing_splits = []
-        split_manager = init_splitter(
-            parameters=self.config.get_dict(),
-            split_list=split_list,
-        )
+        splitter_config = SplitterConfig(**self.config.get_dict())
+        split_manager = Splitter(splitter_config, split_list=split_list)
         for split in split_manager.split_iterator():
             if not (self.maps_manager.maps_path / f"split-{split}" / "tmp").is_dir():
                 missing_splits.append(split)
@@ -502,14 +498,10 @@ def _train_ssda(
             If True, the job is resumed from checkpoint.
         """
-        split_manager = init_splitter(
-            parameters=self.config.get_dict(),
-            split_list=split_list,
-        )
-        split_manager_target_lab = init_splitter(
-            parameters=self.config.get_dict(),
-            split_list=split_list,
-        )
+        splitter_config = SplitterConfig(**self.config.get_dict())
+
+        split_manager = Splitter(splitter_config, split_list=split_list)
+        split_manager_target_lab = Splitter(splitter_config, split_list=split_list)
 
         for split in split_manager.split_iterator():
             logger.info(f"Training split {split}")
diff --git a/tests/test_resume.py b/tests/test_resume.py
index ae3da2b99..1598267d8 100644
--- a/tests/test_resume.py
+++ b/tests/test_resume.py
@@ -7,7 +7,8 @@
 import pytest
 
 from clinicadl.maps_manager.maps_manager import MapsManager
-from clinicadl.splitter.splitter import init_splitter
+from clinicadl.splitter.config import SplitterConfig
+from clinicadl.splitter.splitter import Splitter
 
 from .testing_tools import modify_maps
@@ -49,7 +50,9 @@ def test_resume(cmdopt, tmp_path, test_name):
     assert flag_error
 
     maps_manager = MapsManager(maps_stopped)
-    split_manager = init_splitter(parameters=maps_manager.parameters)
+    splitter_config = SplitterConfig(**maps_manager.parameters)
+    split_manager = Splitter(splitter_config)
+
     for split in split_manager.split_iterator():
         performances_flag = (
             maps_stopped / f"split-{split}" / "best-loss" / "train"
diff --git a/tests/unittests/splitter/test_splitter_config.py b/tests/unittests/splitter/test_splitter_config.py
new file mode 100644
index 000000000..ec2a941e4
--- /dev/null
+++ b/tests/unittests/splitter/test_splitter_config.py
@@ -0,0 +1,137 @@
+from pathlib import Path
+
+import pytest
+from pydantic import ValidationError
+
+from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig
+from clinicadl.config.config.ssda import SSDAConfig
+from clinicadl.network.config import NetworkConfig
+from clinicadl.splitter.config import SplitConfig, SplitterConfig
+from clinicadl.splitter.validation import ValidationConfig
+from clinicadl.trainer.transfer_learning import TransferLearningConfig
+from clinicadl.transforms.config import TransformsConfig
+
+
+# Tests for custom validators #
+@pytest.fixture
+def caps_example():
+    dir_ = Path(__file__).parents[2] / "ressources" / "caps_example"
+    return dir_
+
+
+def test_split_config():
+    c = SplitConfig(
+        n_splits=3,
+        split=[0],
+        tsv_path="",
+    )
+    assert c.split == (0,)
+
+
+def test_validation_config():
+    c = ValidationConfig(
+        evaluation_steps=3,
+        valid_longitudinal=True,
+    )
+    assert not c.skip_leak_check
+    assert c.selection_metrics == ()
+
+
+# Global tests on the TrainingConfig class #
+@pytest.fixture
+def dummy_arguments(caps_example):
+    args = {
+        "caps_directory": caps_example,
+        "preprocessing_json": "preprocessing.json",
+        "tsv_path": "",
+        "maps_dir": "",
+        "gpu": False,
+        "architecture": "",
+        "loss": "",
+        "selection_metrics": (),
+    }
+    return args
+
+
+@pytest.fixture
+def training_config():
+    from pydantic import computed_field
+
+    from clinicadl.trainer.config.train import TrainConfig  # import path assumed
+
+    class TrainingConfig(TrainConfig):
+        @computed_field
+        @property
+        def network_task(self) -> str:
+            return ""
+
+    return TrainingConfig
+
+
+@pytest.fixture(
+    params=[
+        {"gpu": "abc"},
+        {"n_splits": -1},
+        {"optimizer": "abc"},
+        {"data_augmentation": ("abc",)},
+        {"diagnoses": "AD"},
+        {"batch_size": 0},
+        {"size_reduction_factor": 1},
+        {"learning_rate": 0.0},
+        {"split": [-1]},
+        {"tolerance": -0.01},
+    ]
+)
+def bad_inputs(request, dummy_arguments):
+    return {**dummy_arguments, **request.param}
+
+
+@pytest.fixture
+def good_inputs(dummy_arguments):
+    options = {
+        "gpu": False,
+        "n_splits": 7,
+        "optimizer": "Adagrad",
+        "data_augmentation": ("Smoothing",),
+        "diagnoses": ("AD",),
+        "batch_size": 1,
+        "size_reduction_factor": 5,
+        "learning_rate": 1e-1,
+        "split": [0],
+        "tolerance": 0.0,
+    }
+    return {**dummy_arguments, **options}
+
+
+def test_fails_validations(bad_inputs, training_config):
+    with pytest.raises(ValidationError):
+        training_config(**bad_inputs)
+
+
+def test_passes_validations(good_inputs, training_config):
+    c = training_config(**good_inputs)
+    assert not c.computational.gpu
+    assert c.split.n_splits == 7
+    assert c.optimizer.optimizer == "Adagrad"
+    assert c.transforms.data_augmentation == ("Smoothing",)
+    assert c.data.diagnoses == ("AD",)
+    assert c.dataloader.batch_size == 1
+    assert c.transforms.size_reduction_factor == 5
+    assert c.optimizer.learning_rate == 1e-1
+    assert c.split.split == (0,)
+    assert c.early_stopping.tolerance == 0.0
+
+
+# Test config manipulation #
+def test_assignment(dummy_arguments, training_config):
+    c = training_config(**dummy_arguments)
+    c.computational = {"gpu": False}
+    c.dataloader = DataLoaderConfig(**{"batch_size": 1})
+    c.dataloader.n_proc = 10
+    with pytest.raises(ValidationError):
+        c.computational = DataLoaderConfig()
+    with pytest.raises(ValidationError):
+        c.dataloader = {"sampler": "abc"}
+    assert not c.computational.gpu
+    assert c.dataloader.batch_size == 1
+    assert c.dataloader.n_proc == 10
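
Note on the API change carried by this patch: every call site that previously went through the removed init_splitter() helper now builds a SplitterConfig from the same parameter dictionary and passes it to Splitter directly. The sketch below is a minimal illustration of that pattern, assembled only from calls visible in the hunks above; `parameters` is a placeholder for whichever dict a given call site already has (maps_manager.parameters or config.get_dict() in this patch), and the split_list value is only an example.

    from clinicadl.splitter.config import SplitterConfig
    from clinicadl.splitter.splitter import Splitter

    # Build the config from the flat parameter dict, as the patched call sites do.
    splitter_config = SplitterConfig(**parameters)

    # split_list is optional; when given, it restricts which splits are handled,
    # mirroring the split_list argument of the removed init_splitter helper.
    split_manager = Splitter(splitter_config, split_list=[0])

    for split in split_manager.split_iterator():
        train_df = split_manager[split]["train"]        # same indexing as before the refactor
        valid_df = split_manager[split]["validation"]

Iteration (split_iterator) and indexing (split_manager[split][data_group]) are untouched by the patch; only the construction of the split manager changes.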