From ca236284bfc14e444b97b672d23a39a62b22e58b Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Thu, 18 Jan 2024 10:13:08 +0100 Subject: [PATCH] fix: check class of dataloader before blindly calling _finish() (#1850) --- nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py index 318be58e9..690a15fb2 100644 --- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py +++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -11,6 +11,8 @@ import numpy as np import torch +from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter +from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \ @@ -865,9 +867,11 @@ def on_train_end(self): old_stdout = sys.stdout with open(os.devnull, 'w') as f: sys.stdout = f - if self.dataloader_train is not None: + if self.dataloader_train is not None and \ + isinstance(self.dataloader_train, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)): self.dataloader_train._finish() - if self.dataloader_val is not None: + if self.dataloader_val is not None and \ + isinstance(self.dataloader_train, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)): self.dataloader_val._finish() sys.stdout = old_stdout