Skip to content

Commit

Permalink
tests ok ?
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Oct 9, 2024
1 parent 1d2bcf2 commit c0b424c
Show file tree
Hide file tree
Showing 23 changed files with 248 additions and 210 deletions.
14 changes: 9 additions & 5 deletions clinicadl/API_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig
from clinicadl.caps_dataset.data import return_dataset
from clinicadl.predictor.config import PredictConfig
from clinicadl.predictor.predictor import Predictor
from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData
from clinicadl.splitter.config import SplitterConfig
from clinicadl.splitter.splitter import Splitter
from clinicadl.trainer.config.classification import ClassificationConfig
from clinicadl.trainer.trainer import Trainer
from clinicadl.utils.enum import ExtractionMethod, Preprocessing, Task
Expand All @@ -27,11 +31,11 @@
multi_cohort,
)

split_config = SplitConfig()
split_config = SplitterConfig()
splitter = Splitter(split_config)

validator_config = ValidatorConfig()
validator = Validator(validator_config)
validator_config = PredictConfig()
validator = Predictor(validator_config)

train_config = ClassificationConfig()
trainer = Trainer(train_config, validator)
Expand Down Expand Up @@ -78,6 +82,6 @@
test_loader = trainer.get_dataloader(dataset, split, network, "test", config)
validator.predict(test_loader)

interpret_config = InterpretConfig(**kwargs)
predict_manager = PredictManager(interpret_config)
interpret_config = PredictConfig(**kwargs)
predict_manager = Predictor(interpret_config)
predict_manager.interpret()
2 changes: 1 addition & 1 deletion clinicadl/commandline/modules_options/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from clinicadl.config.config_utils import get_default_from_config_class as get_default
from clinicadl.config.config_utils import get_type_from_config_class as get_type
from clinicadl.validator.validation import ValidationConfig
from clinicadl.predictor.validation import ValidationConfig

# Validation
valid_longitudinal = click.option(
Expand Down
4 changes: 2 additions & 2 deletions clinicadl/commandline/pipelines/interpret/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from clinicadl.commandline.pipelines.interpret import options
from clinicadl.interpret.config import InterpretConfig
from clinicadl.validator.validator import Validator
from clinicadl.predictor.predictor import Predictor


@click.command("interpret", no_args_is_help=True)
Expand Down Expand Up @@ -42,7 +42,7 @@ def cli(**kwargs):
"""

interpret_config = InterpretConfig(**kwargs)
predict_manager = Validator(interpret_config)
predict_manager = Predictor(interpret_config)
predict_manager.interpret()


Expand Down
14 changes: 7 additions & 7 deletions clinicadl/commandline/pipelines/interpret/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,28 @@

from clinicadl.config.config_utils import get_default_from_config_class as get_default
from clinicadl.config.config_utils import get_type_from_config_class as get_type
from clinicadl.interpret.config import InterpretConfig
from clinicadl.interpret.config import InterpretBaseConfig

# interpret specific
name = click.argument(
"name",
type=get_type("name", InterpretConfig),
type=get_type("name", InterpretBaseConfig),
)
method = click.argument(
"method",
type=get_type("method", InterpretConfig), # ["gradients", "grad-cam"]
type=get_type("method", InterpretBaseConfig), # ["gradients", "grad-cam"]
)
level = click.option(
"--level_grad_cam",
type=get_type("level", InterpretConfig),
default=get_default("level", InterpretConfig),
type=get_type("level", InterpretBaseConfig),
default=get_default("level", InterpretBaseConfig),
help="level of the feature map (after the layer corresponding to the number) chosen for grad-cam.",
show_default=True,
)
target_node = click.option(
"--target_node",
type=get_type("target_node", InterpretConfig),
default=get_default("target_node", InterpretConfig),
type=get_type("target_node", InterpretBaseConfig),
default=get_default("target_node", InterpretBaseConfig),
help="Which target node the gradients explain. Default takes the first output node.",
show_default=True,
)
Expand Down
6 changes: 3 additions & 3 deletions clinicadl/commandline/pipelines/predict/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
validation,
)
from clinicadl.commandline.pipelines.predict import options
from clinicadl.validator.config import PredictConfig
from clinicadl.validator.validator import Validator
from clinicadl.predictor.config import PredictConfig
from clinicadl.predictor.predictor import Predictor


@click.command(name="predict", no_args_is_help=True)
Expand Down Expand Up @@ -61,7 +61,7 @@ def cli(input_maps_directory, data_group, **kwargs):
"""

predict_config = PredictConfig(**kwargs)
predict_manager = Validator(predict_config)
predict_manager = Predictor(predict_config)
predict_manager.predict()


Expand Down
4 changes: 1 addition & 3 deletions clinicadl/commandline/pipelines/predict/options.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import click

from clinicadl.config.config_utils import get_default_from_config_class as get_default
from clinicadl.validator.config import PredictConfig
from clinicadl.predictor.config import PredictConfig

# predict specific
use_labels = click.option(
"--use_labels/--no_labels",
show_default=True,
default=get_default("use_labels", PredictConfig),
help="Set this option to --no_labels if your dataset does not contain ground truth labels.",
)
save_tensor = click.option(
Expand Down
20 changes: 10 additions & 10 deletions clinicadl/interpret/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig
from clinicadl.interpret.gradients import GradCam, Gradients, VanillaBackProp
from clinicadl.maps_manager.config import MapsManagerConfig
from clinicadl.predictor.validation import ValidationConfig
from clinicadl.splitter.config import SplitConfig
from clinicadl.utils.computational.computational import ComputationalConfig
from clinicadl.utils.enum import InterpretationMethod
from clinicadl.validator.validation import ValidationConfig

logger = getLogger("clinicadl.interpret_config")

Expand Down Expand Up @@ -44,13 +44,13 @@ def get_method(self) -> Gradients:
raise ValueError(f"The method {self.method.value} is not implemented")


class InterpretConfig(
MapsManagerConfig,
InterpretBaseConfig,
DataConfig,
ValidationConfig,
ComputationalConfig,
DataLoaderConfig,
SplitConfig,
):
class InterpretConfig(BaseModel):
"""Config class to perform Transfer Learning."""

maps_manager: MapsManagerConfig
data: DataConfig
validation: ValidationConfig
computational: ComputationalConfig
dataloader: DataLoaderConfig
split: SplitConfig
interpret: InterpretBaseConfig
6 changes: 1 addition & 5 deletions clinicadl/maps_manager/maps_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from clinicadl.metrics.utils import (
check_selection_metric,
)
from clinicadl.predictor.utils import get_prediction
from clinicadl.splitter.config import SplitterConfig
from clinicadl.splitter.splitter import Splitter
from clinicadl.trainer.tasks_utils import (
Expand All @@ -37,7 +38,6 @@
add_default_values,
)
from clinicadl.utils.iotools.utils import path_encoder
from clinicadl.validator.utils import get_prediction

logger = getLogger("clinicadl.maps_manager")
level_list: List[str] = ["warning", "info", "debug"]
Expand Down Expand Up @@ -139,10 +139,6 @@ def __getattr__(self, name):
else:
raise AttributeError(f"'MapsManager' object has no attribute '{name}'")

###################################
# High-level functions templates #
###################################

###############################
# Checks #
###############################
Expand Down
File renamed without changes.
100 changes: 100 additions & 0 deletions clinicadl/predictor/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from logging import getLogger
from typing import Any, Dict

from pydantic import BaseModel, ConfigDict, computed_field

from clinicadl.caps_dataset.data_config import DataConfig as DataBaseConfig
from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig
from clinicadl.maps_manager.config import (
MapsManagerConfig as MapsManagerBaseConfig,
)
from clinicadl.maps_manager.maps_manager import MapsManager
from clinicadl.predictor.validation import ValidationConfig
from clinicadl.splitter.config import SplitConfig
from clinicadl.transforms.config import TransformsConfig
from clinicadl.utils.computational.computational import ComputationalConfig
from clinicadl.utils.enum import Task
from clinicadl.utils.exceptions import ClinicaDLArgumentError # type: ignore

logger = getLogger("clinicadl.predict_config")


class MapsManagerConfig(MapsManagerBaseConfig):
save_tensor: bool = False
save_latent_tensor: bool = False

def check_output_saving_tensor(self, network_task: str) -> None:
# Check if task is reconstruction for "save_tensor" and "save_nifti"
if self.save_tensor and network_task != "reconstruction":
raise ClinicaDLArgumentError(
"Cannot save tensors if the network task is not reconstruction. Please remove --save_tensor option."
)


class DataConfig(DataBaseConfig):
use_labels: bool = True


class PredictConfig(BaseModel):
"""Config class to perform Transfer Learning."""

maps_manager: MapsManagerConfig
data: DataConfig
validation: ValidationConfig
computational: ComputationalConfig
dataloader: DataLoaderConfig
split: SplitConfig
transforms: TransformsConfig

model_config = ConfigDict(validate_assignment=True)

def __init__(self, **kwargs):
super().__init__(
maps_manager=kwargs,
computational=kwargs,
dataloader=kwargs,
data=kwargs,
split=kwargs,
validation=kwargs,
transforms=kwargs,
)

def _update(self, config_dict: Dict[str, Any]) -> None:
"""Updates the configs with a dict given by the user."""
self.data.__dict__.update(config_dict)
self.split.__dict__.update(config_dict)
self.validation.__dict__.update(config_dict)
self.maps_manager.__dict__.update(config_dict)
self.split.__dict__.update(config_dict)
self.computational.__dict__.update(config_dict)
self.dataloader.__dict__.update(config_dict)
self.transforms.__dict__.update(config_dict)

def adapt_with_maps_manager_info(self, maps_manager: MapsManager):
self.maps_manager.check_output_saving_nifti(maps_manager.network_task)
self.data.diagnoses = (
maps_manager.diagnoses
if self.data.diagnoses is None or len(self.data.diagnoses) == 0
else self.data.diagnoses
)

self.dataloader.batch_size = (
maps_manager.batch_size
if not self.dataloader.batch_size
else self.dataloader.batch_size
)
self.dataloader.n_proc = (
maps_manager.n_proc
if not self.dataloader.n_proc
else self.dataloader.n_proc
)

self.split.adapt_cross_val_with_maps_manager_info(maps_manager)
self.maps_manager.check_output_saving_tensor(maps_manager.network_task)

self.transforms = TransformsConfig(
normalize=maps_manager.normalize,
data_augmentation=maps_manager.data_augmentation,
size_reduction=maps_manager.size_reduction,
size_reduction_factor=maps_manager.size_reduction_factor,
)
Loading

0 comments on commit c0b424c

Please sign in to comment.