Skip to content

Commit

Permalink
first try to take out the validator
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Oct 1, 2024
1 parent 9a52b88 commit 954dce7
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 53 deletions.
18 changes: 8 additions & 10 deletions clinicadl/predict/predict_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,10 @@ def _predict_multi(
criterion=criterion,
data_group=self._config.data_group,
split=split,
selection_metrics=split_selection_metrics,
use_labels=self._config.use_labels,
gpu=self._config.gpu,
amp=self._config.amp,
# selection_metrics=split_selection_metrics,
# use_labels=self._config.use_labels,
# gpu=self._config.gpu,
# amp=self._config.amp,
network=network,
)
if self._config.save_tensor:
Expand All @@ -310,8 +310,6 @@ def _predict_multi(
data_test,
self._config.data_group,
split,
self._config.selection_metrics,
gpu=self._config.gpu,
network=network,
)
if self._config.save_nifti:
Expand Down Expand Up @@ -427,10 +425,10 @@ def _predict_single(
criterion,
self._config.data_group,
split,
split_selection_metrics,
use_labels=self._config.use_labels,
gpu=self._config.gpu,
amp=self._config.amp,
# split_selection_metrics,
# use_labels=self._config.use_labels,
# gpu=self._config.gpu,
# amp=self._config.amp,
)
if self._config.save_tensor:
logger.debug("Saving tensors")
Expand Down
11 changes: 1 addition & 10 deletions clinicadl/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def train(
amp=self.maps_manager.std_amp,
use_labels=use_labels,
report_ci=report_ci,
selection_metrics=self.config.validation.selection_metrics,
)

validator = Validator(validator_config=validator_config)
Expand Down Expand Up @@ -526,13 +527,11 @@ def _train_single(
self.maps_manager,
"train",
split,
self.config.validation.selection_metrics,
)
validator._ensemble_prediction(
self.maps_manager,
"validation",
split,
self.config.validation.selection_metrics,
)

self._erase_tmp(split)
Expand Down Expand Up @@ -740,13 +739,11 @@ def _train_ssda(
self.maps_manager,
"train",
split,
self.config.validation.selection_metrics,
)
validator._ensemble_prediction(
self.maps_manager,
"validation",
split,
self.config.validation.selection_metrics,
)

self._erase_tmp(split)
Expand Down Expand Up @@ -1021,8 +1018,6 @@ def _train(
criterion,
"train",
split,
self.config.validation.selection_metrics,
amp=self.maps_manager.std_amp,
network=network,
)
validator._test_loader(
Expand All @@ -1031,8 +1026,6 @@ def _train(
criterion,
"validation",
split,
self.config.validation.selection_metrics,
amp=self.maps_manager.std_amp,
network=network,
)

Expand All @@ -1042,7 +1035,6 @@ def _train(
train_loader.dataset,
"train",
split,
self.config.validation.selection_metrics,
nb_images=1,
network=network,
)
Expand All @@ -1051,7 +1043,6 @@ def _train(
valid_loader.dataset,
"validation",
split,
self.config.validation.selection_metrics,
nb_images=1,
network=network,
)
Expand Down
14 changes: 7 additions & 7 deletions clinicadl/validator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@ class ValidatorConfig(BaseModel):

# maps_path: Path
mode: str

metrics_module: Optional = None
report_ci: bool = False
selection_metrics: list

n_classes: int = 1
network_task: str
# model: Network
# dataloader: DataLoader
# criterion: _Loss
num_networks: int = 1
use_labels: bool = True

gpu: Optional[bool] = None
amp: bool = False
fsdp: bool = False
report_ci = False
gpu: Optional[bool] = None
selection_metrics: list

split_name: Optional[str] = None
num_networks: Optional[int] = None
nb_unfrozen_layers: Optional[int] = None
std_amp: Optional[bool] = None

Expand Down
42 changes: 16 additions & 26 deletions clinicadl/validator/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,6 @@ def test(
dataloader: DataLoader,
criterion: _Loss,
metrics_module: MetricModule,
# mode: str,
# n_classes: int,
# network_task,
# use_labels: bool = True,
# amp: bool = False,
# report_ci=False,
) -> Tuple[pd.DataFrame, Dict[str, float]]:
"""
Computes the predictions and evaluation metrics.
Expand Down Expand Up @@ -248,10 +242,8 @@ def _compute_output_tensors(
dataset,
data_group,
split,
selection_metrics,
nb_images=None,
gpu=None,
network=None,
network: Optional[int] = None,
):
"""
Compute the output tensors and saves them in the MAPS.
Expand All @@ -265,20 +257,20 @@ def _compute_output_tensors(
gpu (bool): If given, a new value for the device of the model will be computed.
network (int): Index of the network tested (only used in multi-network setting).
"""
for selection_metric in selection_metrics:
for selection_metric in self.config.selection_metrics:
# load the best trained model during the training
model, _ = maps_manager._init_model(
transfer_path=maps_manager.maps_path,
split=split,
transfer_selection=selection_metric,
gpu=gpu,
gpu=self.config.gpu,
network=network,
nb_unfrozen_layer=maps_manager.nb_unfrozen_layer,
)
model = DDP(
model,
fsdp=maps_manager.fully_sharded_data_parallel,
amp=maps_manager.amp,
fsdp=self.config.fsdp,
amp=self.config.amp,
)
model.eval()

Expand All @@ -305,50 +297,48 @@ def _compute_output_tensors(
data = dataset[i]
image = data["image"]
x = image.unsqueeze(0).to(model.device)
with autocast("cuda", enabled=maps_manager.std_amp):
with autocast("cuda", enabled=self.config.std_amp):
output = model(x)
output = output.squeeze(0).cpu().float()
participant_id = data["participant_id"]
session_id = data["session_id"]
mode_id = data[f"{maps_manager.mode}_id"]
input_filename = f"{participant_id}_{session_id}_{maps_manager.mode}-{mode_id}_input.pt"
output_filename = f"{participant_id}_{session_id}_{maps_manager.mode}-{mode_id}_output.pt"
mode_id = data[f"{self.config.mode}_id"]
input_filename = f"{participant_id}_{session_id}_{self.config.mode}-{mode_id}_input.pt"
output_filename = f"{participant_id}_{session_id}_{self.config.mode}-{mode_id}_output.pt"
torch.save(image, tensor_path / input_filename)
torch.save(output, tensor_path / output_filename)
logger.debug(f"File saved at {[input_filename, output_filename]}")

def _ensemble_prediction(
self,
maps_manager: MapsManager,
data_group,
split,
selection_metrics,
use_labels=True,
data_group: str,
split: int,
skip_leak_check=False,
):
"""Computes the results on the image-level."""

if not selection_metrics:
if not self.config.selection_metrics:
selection_metrics = find_selection_metrics(
maps_manager.maps_path, maps_manager.split_name, split
)

for selection_metric in selection_metrics:
#####################
# Soft voting
if maps_manager.num_networks > 1 and not skip_leak_check:
if self.config.num_networks > 1 and not skip_leak_check:
maps_manager._ensemble_to_tsv(
split,
selection=selection_metric,
data_group=data_group,
use_labels=use_labels,
use_labels=self.config.use_labels,
)
elif maps_manager.mode != "image" and not skip_leak_check:
elif self.config.mode != "image" and not skip_leak_check:
maps_manager._mode_to_image_tsv(
split,
selection=selection_metric,
data_group=data_group,
use_labels=use_labels,
use_labels=self.config.use_labels,
)

def test_da(
Expand Down

0 comments on commit 954dce7

Please sign in to comment.