Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sortir les étapes de validation du MapsManager #657

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions clinicadl/API_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig
from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData
from clinicadl.trainer.config.classification import ClassificationConfig
from clinicadl.trainer.trainer import Trainer
from clinicadl.utils.enum import ExtractionMethod, Preprocessing, Task
from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options

# Example script chaining three calls of the Python API: build a dataset
# configuration, pass it to DeepLearningPrepareData, then train with a
# Trainer built from a ClassificationConfig.
# NOTE(review): every statement runs at import time (no __main__ guard).

# Dataset configuration selecting the IMAGE extraction method and the
# T1_LINEAR preprocessing.
image_config = CapsDatasetConfig.from_preprocessing_and_extraction_method(
extraction=ExtractionMethod.IMAGE,
preprocessing_type=Preprocessing.T1_LINEAR,
)

# Presumably runs the data-preparation step for this configuration -- confirm.
DeepLearningPrepareData(image_config)

# NOTE(review): the training configuration is created with no arguments and is
# not linked to image_config above -- confirm the defaults are sufficient.
config = ClassificationConfig()
trainer = Trainer(config)
# Train on the splits listed in the configuration's cross-validation section.
# overwrite=True is forwarded to Trainer.train (presumably allows replacing
# existing results -- confirm).
trainer.train(split_list=config.cross_validation.split, overwrite=True)
273 changes: 0 additions & 273 deletions clinicadl/maps_manager/maps_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

import pandas as pd
import torch
import torch.distributed as dist
from torch.amp import autocast

from clinicadl.caps_dataset.caps_dataset_utils import read_json
from clinicadl.caps_dataset.data import (
Expand All @@ -17,16 +15,13 @@
from clinicadl.metrics.metric_module import MetricModule
from clinicadl.metrics.utils import (
check_selection_metric,
find_selection_metrics,
)
from clinicadl.predict.utils import get_prediction
from clinicadl.trainer.tasks_utils import (
ensemble_prediction,
evaluation_metrics,
generate_label_code,
output_size,
test,
test_da,
)
from clinicadl.transforms.config import TransformsConfig
from clinicadl.utils import cluster
Expand Down Expand Up @@ -149,274 +144,6 @@ def __getattr__(self, name):
###################################
# High-level functions templates #
###################################
def _test_loader(
    self,
    dataloader,
    criterion,
    data_group: str,
    split: int,
    selection_metrics,
    use_labels=True,
    gpu=None,
    amp=False,
    network=None,
    report_ci=True,
):
    """
    Test the best models of a split on a DataLoader and write prediction TSV files.

    For each selection metric, the matching best model of the split is
    reloaded, wrapped in DDP, evaluated with ``test`` and the resulting
    predictions and metrics are passed to ``_mode_level_to_tsv``. The
    description log and the TSV files are only written by the master process.

    Args:
        dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset.
        criterion (torch.nn.modules.loss._Loss): optimization criterion used during training.
        data_group (str): name of the data group used for the testing task.
        split (int): Index of the split used to train the model tested.
        selection_metrics (list[str]): List of metrics used to select the best models which are tested.
        use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed.
        gpu (bool): If given, a new value for the device of the model will be computed.
        amp (bool): If enabled, uses Automatic Mixed Precision (requires GPU usage).
        network (int): Index of the network tested (only used in multi-network setting).
        report_ci (bool): Forwarded to ``test``. Also selects which entry of the
            metrics is logged as the loss: the last item of ``"Metric_values"``
            when True, the ``"loss"`` entry otherwise.
    """
    for selection_metric in selection_metrics:
        if cluster.master:
            best_model_dir = (
                self.maps_path
                / f"{self.split_name}-{split}"
                / f"best-{selection_metric}"
            )
            dataset_config = dataloader.dataset.config
            self.write_description_log(
                best_model_dir / data_group,
                data_group,
                dataset_config.data.caps_dict,
                dataset_config.data.data_df,
            )

        # Reload the best model saved during training for this selection metric.
        best_model, _ = self._init_model(
            transfer_path=self.maps_path,
            split=split,
            transfer_selection=selection_metric,
            gpu=gpu,
            network=network,
        )
        best_model = DDP(
            best_model, fsdp=self.fully_sharded_data_parallel, amp=self.amp
        )

        predictions, scores = test(
            mode=self.mode,
            metrics_module=self.metrics_module,
            n_classes=self.n_classes,
            network_task=self.network_task,
            model=best_model,
            dataloader=dataloader,
            criterion=criterion,
            use_labels=use_labels,
            amp=amp,
            report_ci=report_ci,
        )

        if use_labels:
            if network is not None:
                scores[f"{self.mode}_id"] = network

            # The loss is stored under a different key depending on report_ci.
            if report_ci:
                loss_to_log = scores["Metric_values"][-1]
            else:
                loss_to_log = scores["loss"]

            logger.info(
                f"{self.mode} level {data_group} loss is {loss_to_log} for model selected on {selection_metric}"
            )

        if cluster.master:
            # Replace here
            self._mode_level_to_tsv(
                predictions,
                scores,
                split,
                selection_metric,
                data_group=data_group,
            )

def _test_loader_ssda(
    self,
    dataloader,
    criterion,
    alpha,
    data_group,
    split,
    selection_metrics,
    use_labels=True,
    gpu=None,
    network=None,
    target=False,
    report_ci=True,
):
    """
    Test the best models of a split with ``test_da`` and write prediction TSV files.

    For each selection metric, a description log is written, the matching
    best model of the split is reloaded and evaluated with ``test_da``, and
    the predictions and metrics are passed to ``_mode_level_to_tsv``.

    Args:
        dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset.
        criterion (torch.nn.modules.loss._Loss): optimization criterion used during training.
        alpha: accepted for interface compatibility; not read by this method.
        data_group (str): name of the data group used for the testing task.
        split (int): Index of the split used to train the model tested.
        selection_metrics (list[str]): List of metrics used to select the best models which are tested.
        use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed.
        gpu (bool): If given, a new value for the device of the model will be computed.
        network (int): Index of the network tested (only used in multi-network setting).
        target (bool): Forwarded unchanged to ``test_da``.
        report_ci (bool): Forwarded to ``test_da``. Also selects which entry of
            the metrics is logged as the loss: the last item of
            ``"Metric_values"`` when True, the ``"loss"`` entry otherwise.
    """
    for selection_metric in selection_metrics:
        best_model_dir = (
            self.maps_path
            / f"{self.split_name}-{split}"
            / f"best-{selection_metric}"
        )
        # NOTE(review): unlike _test_loader, the dataset attributes are read
        # directly (caps_dict / df) and no master-process guard is applied --
        # confirm this is intended.
        tested_dataset = dataloader.dataset
        self.write_description_log(
            best_model_dir / data_group,
            data_group,
            tested_dataset.caps_dict,
            tested_dataset.df,
        )

        # Reload the best model saved during training for this selection metric.
        best_model, _ = self._init_model(
            transfer_path=self.maps_path,
            split=split,
            transfer_selection=selection_metric,
            gpu=gpu,
            network=network,
        )
        predictions, scores = test_da(
            self.network_task,
            best_model,
            dataloader,
            criterion,
            target=target,
            report_ci=report_ci,
        )

        if use_labels:
            if network is not None:
                scores[f"{self.mode}_id"] = network

            # The loss is stored under a different key depending on report_ci.
            loss_to_log = (
                scores["Metric_values"][-1] if report_ci else scores["loss"]
            )

            logger.info(
                f"{self.mode} level {data_group} loss is {loss_to_log} for model selected on {selection_metric}"
            )

        # Replace here
        self._mode_level_to_tsv(
            predictions, scores, split, selection_metric, data_group=data_group
        )

@torch.no_grad()
def _compute_output_tensors(
    self,
    dataset,
    data_group,
    split,
    selection_metrics,
    nb_images=None,
    gpu=None,
    network=None,
):
    """
    Compute the output tensors and save them in the MAPS.

    For each selection metric, the matching best model of the split is
    reloaded and applied to the selected elements of ``dataset``. The input
    and output tensor of each element are saved as ``.pt`` files in the
    ``tensors`` folder of the data group. Elements are distributed over the
    processes by rank.

    Args:
        dataset (clinicadl.caps_dataset.data.CapsDataset): wrapper of the data set.
        data_group (str): name of the data group used for the task.
        split (int): split number.
        selection_metrics (list[str]): metrics used for model selection.
        nb_images (int): number of full images to write. Default computes the outputs of the whole data set.
        gpu (bool): If given, a new value for the device of the model will be computed.
        network (int): Index of the network tested (only used in multi-network setting).
    """
    for selection_metric in selection_metrics:
        # load the best trained model during the training
        model, _ = self._init_model(
            transfer_path=self.maps_path,
            split=split,
            transfer_selection=selection_metric,
            gpu=gpu,
            network=network,
            nb_unfrozen_layer=self.nb_unfrozen_layer,
        )
        model = DDP(model, fsdp=self.fully_sharded_data_parallel, amp=self.amp)
        model.eval()

        tensor_path = (
            self.maps_path
            / f"{self.split_name}-{split}"
            / f"best-{selection_metric}"
            / data_group
            / "tensors"
        )
        if cluster.master:
            tensor_path.mkdir(parents=True, exist_ok=True)
        # All processes synchronize here, presumably so that none of them
        # writes before the master has created the folder.
        dist.barrier()

        if nb_images is None:  # Compute outputs for the whole data set
            nb_modes = len(dataset)
        else:
            nb_modes = nb_images * dataset.elem_per_image

        # Each process handles the elements rank, rank + world_size, ...
        indices = list(range(cluster.rank, nb_modes, cluster.world_size))
        # When nb_modes is not a multiple of world_size, the processes whose
        # rank is >= the remainder receive one element fewer. They are padded
        # with element 0 so that every process runs the same number of
        # forward passes (presumably required by the collective operations
        # of the wrapped model). No padding is added when the division is
        # exact: the previous condition was also true for a remainder of 0,
        # which made every process (including a single-process run) compute
        # element 0 a second time, even when nb_modes was 0.
        if 0 < nb_modes % cluster.world_size <= cluster.rank:
            indices.append(0)

        for i in indices:
            data = dataset[i]
            image = data["image"]
            # Add a batch dimension of size 1 before the forward pass.
            x = image.unsqueeze(0).to(model.device)
            with autocast("cuda", enabled=self.std_amp):
                output = model(x)
            # Drop the batch dimension and store the result on CPU as float32.
            output = output.squeeze(0).cpu().float()
            participant_id = data["participant_id"]
            session_id = data["session_id"]
            mode_id = data[f"{self.mode}_id"]
            input_filename = (
                f"{participant_id}_{session_id}_{self.mode}-{mode_id}_input.pt"
            )
            output_filename = (
                f"{participant_id}_{session_id}_{self.mode}-{mode_id}_output.pt"
            )
            torch.save(image, tensor_path / input_filename)
            torch.save(output, tensor_path / output_filename)
            logger.debug(f"File saved at {[input_filename, output_filename]}")

def _ensemble_prediction(
    self,
    data_group,
    split,
    selection_metrics,
    use_labels=True,
    skip_leak_check=False,
):
    """
    Computes the results on the image-level.

    When ``selection_metrics`` is empty, the metrics are looked up with
    ``find_selection_metrics`` for the given split. For each metric,
    ``_ensemble_to_tsv`` is called in the multi-network case, otherwise
    ``_mode_to_image_tsv`` is called when the mode is not ``"image"``.
    Nothing is written when ``skip_leak_check`` is True.

    Args:
        data_group (str): name of the data group.
        split (int): index of the split.
        selection_metrics (list[str]): metrics used for model selection.
        use_labels (bool): forwarded to the TSV-writing method.
        skip_leak_check (bool): if True, both TSV-writing branches are skipped.
    """
    if not selection_metrics:
        selection_metrics = find_selection_metrics(
            self.maps_path, self.split_name, split
        )

    for selection_metric in selection_metrics:
        # Both writers take the same keyword arguments.
        tsv_options = {
            "selection": selection_metric,
            "data_group": data_group,
            "use_labels": use_labels,
        }
        #####################
        # Soft voting
        if self.num_networks > 1 and not skip_leak_check:
            self._ensemble_to_tsv(split, **tsv_options)
        elif self.mode != "image" and not skip_leak_check:
            self._mode_to_image_tsv(split, **tsv_options)

###############################
# Checks #
Expand Down
27 changes: 17 additions & 10 deletions clinicadl/predict/predict_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ClinicaDLDataLeakageError,
MAPSError,
)
from clinicadl.validator.validator import Validator

logger = getLogger("clinicadl.predict_manager")
level_list: List[str] = ["warning", "info", "debug"]
Expand All @@ -38,6 +39,7 @@ class PredictManager:
def __init__(self, _config: Union[PredictConfig, InterpretConfig]) -> None:
self.maps_manager = MapsManager(_config.maps_dir)
self._config = _config
self.validator = Validator()

def predict(
self,
Expand Down Expand Up @@ -183,7 +185,8 @@ def predict(
split_selection_metrics,
)
if cluster.master:
self.maps_manager._ensemble_prediction(
self.validator._ensemble_prediction(
self.maps_manager,
self._config.data_group,
split,
self._config.selection_metrics,
Expand Down Expand Up @@ -288,20 +291,22 @@ def _predict_multi(
if self._config.n_proc is not None
else self.maps_manager.n_proc,
)
self.maps_manager._test_loader(
test_loader,
criterion,
self._config.data_group,
split,
split_selection_metrics,
self.validator._test_loader(
maps_manager=self.maps_manager,
dataloader=test_loader,
criterion=criterion,
data_group=self._config.data_group,
split=split,
selection_metrics=split_selection_metrics,
use_labels=self._config.use_labels,
gpu=self._config.gpu,
amp=self._config.amp,
network=network,
)
if self._config.save_tensor:
logger.debug("Saving tensors")
self.maps_manager._compute_output_tensors(
self.validator._compute_output_tensors(
self.maps_manager,
data_test,
self._config.data_group,
split,
Expand Down Expand Up @@ -416,7 +421,8 @@ def _predict_single(
if self._config.n_proc is not None
else self.maps_manager.n_proc,
)
self.maps_manager._test_loader(
self.validator._test_loader(
self.maps_manager,
test_loader,
criterion,
self._config.data_group,
Expand All @@ -428,7 +434,8 @@ def _predict_single(
)
if self._config.save_tensor:
logger.debug("Saving tensors")
self.maps_manager._compute_output_tensors(
self.validator._compute_output_tensors(
self.maps_manager,
data_test,
self._config.data_group,
split,
Expand Down
Loading
Loading