From 3c499ac23c447445b2aac6a4fc31c2c8df62d35e Mon Sep 17 00:00:00 2001 From: Roman Fitzjalen Date: Tue, 22 Oct 2024 11:50:07 +0300 Subject: [PATCH] better in-training checkpoint selector --- nnunetv2/run/run_training.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/nnunetv2/run/run_training.py b/nnunetv2/run/run_training.py index fff2ecd2e..0f1d0f593 100644 --- a/nnunetv2/run/run_training.py +++ b/nnunetv2/run/run_training.py @@ -74,15 +74,18 @@ def maybe_load_checkpoint(nnunet_trainer: nnUNetTrainer, continue_training: bool raise RuntimeError('Cannot both continue a training AND load pretrained weights. Pretrained weights can only ' 'be used at the beginning of the training.') if continue_training: - expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth') - if not isfile(expected_checkpoint_file): - expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_latest.pth') - # special case where --c is used to run a previously aborted validation - if not isfile(expected_checkpoint_file): - expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_best.pth') - if not isfile(expected_checkpoint_file): - print(f"WARNING: Cannot continue training because there seems to be no checkpoint available to " - f"continue from. Starting a new training...") + checkpoint_files = [ + join(nnunet_trainer.output_folder, 'checkpoint_final.pth'), + join(nnunet_trainer.output_folder, 'checkpoint_latest.pth'), + join(nnunet_trainer.output_folder, 'checkpoint_best.pth'), + ] + # Filter out the files that actually exist + existing_checkpoints = [ckpt for ckpt in checkpoint_files if isfile(ckpt)] + if existing_checkpoints: + # Select the checkpoint with the most recent modification time + expected_checkpoint_file = max(existing_checkpoints, key=os.path.getmtime) + else: + print("WARNING: Cannot continue training because there seems to be no checkpoint available to continue from. Starting a new training...") expected_checkpoint_file = None elif validation_only: expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')