Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Oct 3, 2024
1 parent 299b7dd commit f3f0b12
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,4 @@ def cli(**kwargs):
options = merge_cli_and_config_file_options(Task.CLASSIFICATION, **kwargs)
config = ClassificationConfig(**options)
trainer = Trainer(config)
trainer.train(split_list=config.validation.split, overwrite=True)
trainer.train(split_list=config.split.split, overwrite=True)
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,4 @@ def cli(**kwargs):
options = merge_cli_and_config_file_options(Task.RECONSTRUCTION, **kwargs)
config = ReconstructionConfig(**options)
trainer = Trainer(config)
trainer.train(split_list=config.validation.split, overwrite=True)
trainer.train(split_list=config.split.split, overwrite=True)
2 changes: 1 addition & 1 deletion clinicadl/commandline/pipelines/train/regression/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,4 @@ def cli(**kwargs):
options = merge_cli_and_config_file_options(Task.REGRESSION, **kwargs)
config = RegressionConfig(**options)
trainer = Trainer(config)
trainer.train(split_list=config.validation.split, overwrite=True)
trainer.train(split_list=config.split.split, overwrite=True)
9 changes: 6 additions & 3 deletions clinicadl/maps_manager/maps_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
check_selection_metric,
)
from clinicadl.predict.utils import get_prediction
from clinicadl.splitter.splitter import init_splitter
from clinicadl.splitter.config import SplitterConfig
from clinicadl.splitter.splitter import Splitter
from clinicadl.trainer.tasks_utils import (
ensemble_prediction,
evaluation_metrics,
Expand Down Expand Up @@ -169,8 +170,9 @@ def _check_args(self, parameters):
size_reduction=self.size_reduction,
size_reduction_factor=self.size_reduction_factor,
)
splitter_config = SplitterConfig(**self.parameters)
split_manager = Splitter(splitter_config)

split_manager = init_splitter(parameters=self.parameters)
train_df = split_manager[0]["train"]
if "label" not in self.parameters:
self.parameters["label"] = None
Expand Down Expand Up @@ -316,7 +318,8 @@ def _write_training_data(self):
def _write_train_val_groups(self):
"""Defines the training and validation groups at the initialization"""
logger.debug("Writing training and validation groups...")
split_manager = init_splitter(parameters=self.parameters)
splitter_config = SplitterConfig(**self.parameters)
split_manager = Splitter(splitter_config)
for split in split_manager.split_iterator():
for data_group in ["train", "validation"]:
df = split_manager[split][data_group]
Expand Down
4 changes: 0 additions & 4 deletions clinicadl/maps_manager/tmp_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,6 @@ def check_args(self):
)

if self.network_task == "classification":
from clinicadl.splitter.splitter import (
init_splitter,
)

if self.n_splits > 1 and self.validation == "SingleSplit":
self.validation = "KFoldSplit"

Expand Down
2 changes: 1 addition & 1 deletion clinicadl/random_search/random_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ def launch_search(launch_directory: Path, job_name):
output_maps_directory=maps_directory, **options
)
trainer = Trainer(training_config)
trainer.train(split_list=training_config.validation.split, overwrite=True)
trainer.train(split_list=training_config.split.split, overwrite=True)
17 changes: 0 additions & 17 deletions clinicadl/splitter/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,6 @@
logger = getLogger("clinicadl.split_manager")


def init_splitter(
    parameters,
    split_list=None,
):
    """Build a :class:`Splitter` from a flat parameter mapping.

    The mapping is fanned out into the three sub-configurations that
    ``SplitterConfig`` expects, and the assembled config is handed to
    ``Splitter`` together with the optional explicit split list.
    """
    config = SplitterConfig(
        data_config=DataConfig(**parameters),
        validation_config=ValidationConfig(**parameters),
        split_config=SplitConfig(**parameters),
    )
    return Splitter(config, split_list=split_list)


class Splitter:
def __init__(
self,
Expand Down
42 changes: 17 additions & 25 deletions clinicadl/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
)
from clinicadl.trainer.tasks_utils import create_training_config
from clinicadl.validator.validator import Validator
from clinicadl.splitter.splitter import init_splitter
from clinicadl.splitter.splitter import Splitter
from clinicadl.splitter.config import SplitterConfig
from clinicadl.transforms.config import TransformsConfig

if TYPE_CHECKING:
Expand Down Expand Up @@ -158,10 +159,10 @@ def resume(self, splits: List[int]) -> None:
stopped_splits = set(find_stopped_splits(self.config.maps_manager.maps_dir))
finished_splits = set(find_finished_splits(self.maps_manager.maps_path))
# TODO : check these two lines. Why do we need a split_manager?
split_manager = init_splitter(
parameters=self.config.get_dict(),
split_list=splits,
)

splitter_config = SplitterConfig(**self.config.get_dict())
split_manager = Splitter(splitter_config, split_list=splits)

split_iterator = split_manager.split_iterator()
###
absent_splits = set(split_iterator) - stopped_splits - finished_splits
Expand Down Expand Up @@ -216,10 +217,9 @@ def train(
self._train_ssda(split_list, resume=False)

else:
split_manager = init_splitter(
parameters=self.config.get_dict(),
split_list=split_list,
)
splitter_config = SplitterConfig(**self.config.get_dict())
split_manager = Splitter(splitter_config, split_list=split_list)

for split in split_manager.split_iterator():
logger.info(f"Training split {split}")
seed_everything(
Expand All @@ -241,10 +241,8 @@ def train(

def check_split_list(self, split_list, overwrite):
existing_splits = []
split_manager = init_splitter(
parameters=self.config.get_dict(),
split_list=split_list,
)
splitter_config = SplitterConfig(**self.config.get_dict())
split_manager = Splitter(splitter_config, split_list=split_list)
for split in split_manager.split_iterator():
split_path = self.maps_manager.maps_path / f"split-{split}"
if split_path.is_dir():
Expand Down Expand Up @@ -280,10 +278,8 @@ def _resume(
If splits specified in input do not exist.
"""
missing_splits = []
split_manager = init_splitter(
parameters=self.config.get_dict(),
split_list=split_list,
)
splitter_config = SplitterConfig(**self.config.get_dict())
split_manager = Splitter(splitter_config, split_list=split_list)
for split in split_manager.split_iterator():
if not (self.maps_manager.maps_path / f"split-{split}" / "tmp").is_dir():
missing_splits.append(split)
Expand Down Expand Up @@ -502,14 +498,10 @@ def _train_ssda(
If True, the job is resumed from checkpoint.
"""

split_manager = init_splitter(
parameters=self.config.get_dict(),
split_list=split_list,
)
split_manager_target_lab = init_splitter(
parameters=self.config.get_dict(),
split_list=split_list,
)
splitter_config = SplitterConfig(**self.config.get_dict())

split_manager = Splitter(splitter_config, split_list=split_list)
split_manager_target_lab = Splitter(splitter_config, split_list=split_list)

for split in split_manager.split_iterator():
logger.info(f"Training split {split}")
Expand Down
7 changes: 5 additions & 2 deletions tests/test_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import pytest

from clinicadl.maps_manager.maps_manager import MapsManager
from clinicadl.splitter.splitter import init_splitter
from clinicadl.splitter.config import SplitterConfig
from clinicadl.splitter.splitter import Splitter

from .testing_tools import modify_maps

Expand Down Expand Up @@ -49,7 +50,9 @@ def test_resume(cmdopt, tmp_path, test_name):
assert flag_error

maps_manager = MapsManager(maps_stopped)
split_manager = init_splitter(parameters=maps_manager.parameters)
splitter_config = SplitterConfig(**maps_manager.parameters)
split_manager = Splitter(splitter_config)

for split in split_manager.split_iterator():
performances_flag = (
maps_stopped / f"split-{split}" / "best-loss" / "train"
Expand Down
137 changes: 137 additions & 0 deletions tests/unittests/splitter/test_splitter_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from pathlib import Path

import pytest
from pydantic import ValidationError

from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig
from clinicadl.config.config.ssda import SSDAConfig
from clinicadl.network.config import NetworkConfig
from clinicadl.splitter.config import SplitConfig, SplitterConfig
from clinicadl.splitter.validation import ValidationConfig
from clinicadl.trainer.transfer_learning import TransferLearningConfig
from clinicadl.transforms.config import TransformsConfig


# Tests for custom validators #
@pytest.fixture
def caps_example():
    """Path to the CAPS example directory shipped with the test resources."""
    return Path(__file__).parents[2] / "ressources" / "caps_example"


def test_split_config():
    """SplitConfig coerces the ``split`` list into a tuple."""
    config = SplitConfig(n_splits=3, split=[0], tsv_path="")
    assert config.split == (0,)


def test_validation_config():
    """ValidationConfig defaults: leak check enabled, no selection metrics."""
    config = ValidationConfig(evaluation_steps=3, valid_longitudinal=True)
    assert not config.skip_leak_check
    assert config.selection_metrics == ()


# Global tests on the TrainingConfig class #
@pytest.fixture
def dummy_arguments(caps_example):
    """Minimal set of keyword arguments accepted by every config under test."""
    return {
        "caps_directory": caps_example,
        "preprocessing_json": "preprocessing.json",
        "tsv_path": "",
        "maps_dir": "",
        "gpu": False,
        "architecture": "",
        "loss": "",
        "selection_metrics": (),
    }


@pytest.fixture
def training_config():
    """Concrete training-config class with the abstract ``network_task`` filled in.

    Fix: every test in this module requests a fixture named ``training_config``
    (see test_fails_validations / test_passes_validations / test_assignment),
    but this fixture was named ``splitter_config``, so the tests could not
    resolve it. Renamed accordingly; the unused in-fixture import of
    ``SplitterConfig`` was dropped.
    """
    from pydantic import computed_field

    # NOTE(review): TrainConfig is referenced but never imported in this
    # module — presumably it lives in clinicadl's trainer config package;
    # confirm the correct import path and add it at the top of the file.
    class TrainingConfig(TrainConfig):
        @computed_field
        @property
        def network_task(self) -> str:
            # Dummy task name: enough to make the abstract config instantiable.
            return ""

    return TrainingConfig


@pytest.fixture(
    params=[
        {"gpu": "abc"},
        {"n_splits": -1},
        {"optimizer": "abc"},
        {"data_augmentation": ("abc",)},
        {"diagnoses": "AD"},
        {"batch_size": 0},
        {"size_reduction_factor": 1},
        {"learning_rate": 0.0},
        {"split": [-1]},
        {"tolerance": -0.01},
    ]
)
def bad_inputs(request, dummy_arguments):
    """Valid baseline arguments with exactly one invalid option laid on top."""
    merged = dict(dummy_arguments)
    merged.update(request.param)
    return merged


@pytest.fixture
def good_inputs(dummy_arguments):
    """Valid baseline arguments with a full set of valid options laid on top."""
    valid_options = {
        "gpu": False,
        "n_splits": 7,
        "optimizer": "Adagrad",
        "data_augmentation": ("Smoothing",),
        "diagnoses": ("AD",),
        "batch_size": 1,
        "size_reduction_factor": 5,
        "learning_rate": 1e-1,
        "split": [0],
        "tolerance": 0.0,
    }
    merged = dict(dummy_arguments)
    merged.update(valid_options)
    return merged


def test_fails_validations(bad_inputs, training_config):
    """Every invalid option set must be rejected with a ValidationError."""
    with pytest.raises(ValidationError):
        _ = training_config(**bad_inputs)


def test_passes_validations(good_inputs, training_config):
    """Valid options are accepted and routed to the right sub-config."""
    config = training_config(**good_inputs)
    # Grouped by sub-config rather than input order.
    assert not config.computational.gpu
    assert config.split.n_splits == 7
    assert config.split.split == (0,)
    assert config.optimizer.optimizer == "Adagrad"
    assert config.optimizer.learning_rate == 1e-1
    assert config.transforms.data_augmentation == ("Smoothing",)
    assert config.transforms.size_reduction_factor == 5
    assert config.data.diagnoses == ("AD",)
    assert config.dataloader.batch_size == 1
    assert config.early_stopping.tolerance == 0.0


# Test config manipulation #
def test_assignment(dummy_arguments, training_config):
    """Sub-configs accept valid reassignment and reject mismatched values."""
    config = training_config(**dummy_arguments)

    # Valid assignments: a plain dict is validated into the sub-config,
    # and individual sub-config fields can be set directly.
    config.computational = {"gpu": False}
    config.dataloader = DataLoaderConfig(batch_size=1)
    config.dataloader.n_proc = 10

    # Invalid assignments are rejected by pydantic's assignment validation.
    with pytest.raises(ValidationError):
        config.computational = DataLoaderConfig()
    with pytest.raises(ValidationError):
        config.dataloader = {"sampler": "abc"}

    assert not config.computational.gpu
    assert config.dataloader.batch_size == 1
    assert config.dataloader.n_proc == 10

0 comments on commit f3f0b12

Please sign in to comment.