Skip to content

Commit

Permalink
pass tests
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Sep 26, 2024
1 parent d831968 commit 0d76a8c
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions clinicadl/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@
generate_sampler,
get_criterion,
save_outputs,
test,
test_da,
)

logger = getLogger("clinicadl.trainer")
Expand Down Expand Up @@ -869,7 +867,7 @@ def _train(
):
evaluation_flag = False

_, metrics_train = test(
_, metrics_train = self.validator.test(
mode=self.maps_manager.mode,
metrics_module=self.maps_manager.metrics_module,
n_classes=self.maps_manager.n_classes,
Expand All @@ -879,7 +877,7 @@ def _train(
criterion=criterion,
amp=self.maps_manager.std_amp,
)
_, metrics_valid = test(
_, metrics_valid = self.validator.test(
mode=self.maps_manager.mode,
metrics_module=self.maps_manager.metrics_module,
n_classes=self.maps_manager.n_classes,
Expand Down Expand Up @@ -936,7 +934,7 @@ def _train(
model.zero_grad(set_to_none=True)
logger.debug(f"Last checkpoint at the end of the epoch {epoch}")

_, metrics_train = test(
_, metrics_train = self.validator.test(
mode=self.maps_manager.mode,
metrics_module=self.maps_manager.metrics_module,
n_classes=self.maps_manager.n_classes,
Expand All @@ -946,7 +944,7 @@ def _train(
criterion=criterion,
amp=self.maps_manager.std_amp,
)
_, metrics_valid = test(
_, metrics_valid = self.validator.test(
mode=self.maps_manager.mode,
metrics_module=self.maps_manager.metrics_module,
n_classes=self.maps_manager.n_classes,
Expand Down

0 comments on commit 0d76a8c

Please sign in to comment.