From 0d76a8c7a05d4a6ab426d77519173d6d55322df5 Mon Sep 17 00:00:00 2001 From: camillebrianceau Date: Thu, 26 Sep 2024 09:59:39 +0200 Subject: [PATCH] pass tests --- clinicadl/trainer/trainer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index fe8c1ae9a..3c279d155 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -44,8 +44,6 @@ generate_sampler, get_criterion, save_outputs, - test, - test_da, ) logger = getLogger("clinicadl.trainer") @@ -869,7 +867,7 @@ def _train( ): evaluation_flag = False - _, metrics_train = test( + _, metrics_train = self.validator.test( mode=self.maps_manager.mode, metrics_module=self.maps_manager.metrics_module, n_classes=self.maps_manager.n_classes, @@ -879,7 +877,7 @@ def _train( criterion=criterion, amp=self.maps_manager.std_amp, ) - _, metrics_valid = test( + _, metrics_valid = self.validator.test( mode=self.maps_manager.mode, metrics_module=self.maps_manager.metrics_module, n_classes=self.maps_manager.n_classes, @@ -936,7 +934,7 @@ def _train( model.zero_grad(set_to_none=True) logger.debug(f"Last checkpoint at the end of the epoch {epoch}") - _, metrics_train = test( + _, metrics_train = self.validator.test( mode=self.maps_manager.mode, metrics_module=self.maps_manager.metrics_module, n_classes=self.maps_manager.n_classes, @@ -946,7 +944,7 @@ def _train( criterion=criterion, amp=self.maps_manager.std_amp, ) - _, metrics_valid = test( + _, metrics_valid = self.validator.test( mode=self.maps_manager.mode, metrics_module=self.maps_manager.metrics_module, n_classes=self.maps_manager.n_classes,