diff --git a/nnunetv2/run/run_training.py b/nnunetv2/run/run_training.py
index fff2ecd2e..0f1d0f593 100644
--- a/nnunetv2/run/run_training.py
+++ b/nnunetv2/run/run_training.py
@@ -74,15 +74,18 @@ def maybe_load_checkpoint(nnunet_trainer: nnUNetTrainer, continue_training: bool
         raise RuntimeError('Cannot both continue a training AND load pretrained weights. Pretrained weights can only '
                            'be used at the beginning of the training.')
     if continue_training:
-        expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
-        if not isfile(expected_checkpoint_file):
-            expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_latest.pth')
-        # special case where --c is used to run a previously aborted validation
-        if not isfile(expected_checkpoint_file):
-            expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_best.pth')
-        if not isfile(expected_checkpoint_file):
-            print(f"WARNING: Cannot continue training because there seems to be no checkpoint available to "
-                  f"continue from. Starting a new training...")
+        checkpoint_files = [
+            join(nnunet_trainer.output_folder, 'checkpoint_final.pth'),
+            join(nnunet_trainer.output_folder, 'checkpoint_latest.pth'),
+            join(nnunet_trainer.output_folder, 'checkpoint_best.pth'),
+        ]
+        # Keep only the checkpoint files that actually exist
+        existing_checkpoints = [ckpt for ckpt in checkpoint_files if isfile(ckpt)]
+        if existing_checkpoints:
+            # Select the checkpoint with the most recent modification time
+            expected_checkpoint_file = max(existing_checkpoints, key=os.path.getmtime)
+        else:
+            print("WARNING: Cannot continue training because there seems to be no checkpoint available to continue from. Starting a new training...")
             expected_checkpoint_file = None
     elif validation_only:
         expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
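
The change replaces the fixed final -> latest -> best fallback cascade with a selection by modification time: `--c` now resumes from whichever existing checkpoint was written most recently. Below is a minimal, self-contained sketch of that selection logic for illustration; the helper name `pick_resume_checkpoint` is hypothetical and not part of the diff, and the sketch assumes only the three candidate filenames shown above (the diff itself relies on `os` already being imported in run_training.py).

```python
import os
from os.path import isfile, join


def pick_resume_checkpoint(output_folder: str):
    """Return the most recently modified checkpoint in output_folder, or None if none exist."""
    candidates = [
        join(output_folder, 'checkpoint_final.pth'),
        join(output_folder, 'checkpoint_latest.pth'),
        join(output_folder, 'checkpoint_best.pth'),
    ]
    existing = [c for c in candidates if isfile(c)]
    # max() over mtime picks the newest file; on an exact tie the first
    # candidate in list order (final, latest, best) wins.
    return max(existing, key=os.path.getmtime) if existing else None
```

The practical effect, as far as the diff shows, is that a stale checkpoint_final.pth left over from an earlier run no longer shadows a newer checkpoint_latest.pth when continuing training.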