diff --git a/nnunetv2/configuration.py b/nnunetv2/configuration.py index cdc8cb69a..288b6e453 100644 --- a/nnunetv2/configuration.py +++ b/nnunetv2/configuration.py @@ -1,10 +1,6 @@ import os -from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA - default_num_processes = 8 if 'nnUNet_def_n_proc' not in os.environ else int(os.environ['nnUNet_def_n_proc']) ANISO_THRESHOLD = 3 # determines when a sample is considered anisotropic (3 means that the spacing in the low # resolution axis must be 3x as large as the next largest spacing) - -default_n_proc_DA = get_allowed_n_proc_DA() diff --git a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py index c0ac2d37e..06b1e9541 100644 --- a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py @@ -15,7 +15,6 @@ from nnunetv2.preprocessing.normalization.map_channel_name_to_normalization import get_normalization_scheme from nnunetv2.preprocessing.resampling.default_resampling import resample_data_or_seg_to_shape, compute_new_shape from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name -from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA from nnunetv2.utilities.get_network_from_plans import get_network_from_plans from nnunetv2.utilities.json_export import recursive_fix_for_json_export from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets @@ -100,14 +99,10 @@ def static_estimate_VRAM_usage(patch_size: Tuple[int], """ Works for PlainConvUNet, ResidualEncoderUNet """ - a = torch.get_num_threads() - torch.set_num_threads(get_allowed_n_proc_DA()) - # print(f'instantiating network, patch size {patch_size}, pool op: {arch_kwargs["strides"]}') net = get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels, allow_init=False) ret = net.compute_conv_feature_map_size(patch_size) - torch.set_num_threads(a) return ret def determine_resampling(self, *args, **kwargs): diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py index 426bbf047..387989f62 100644 --- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py +++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -53,7 +53,7 @@ from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler from nnunetv2.utilities.collate_outputs import collate_outputs from nnunetv2.utilities.crossval_split import generate_crossval_split -from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA +from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA, get_allowed_n_proc_DA_val from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy from nnunetv2.utilities.get_network_from_plans import get_network_from_plans from nnunetv2.utilities.helpers import empty_cache, dummy_context @@ -635,16 +635,23 @@ def get_dataloaders(self): dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) - allowed_num_processes = get_allowed_n_proc_DA() - if allowed_num_processes == 0: + return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms) + + def init_dataloaders(self, dl_tr, tr_transforms, dl_val, val_transforms): + num_processes_train = get_allowed_n_proc_DA() + if num_processes_train == 0: mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) - mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) else: mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_tr, transform=tr_transforms, - num_processes=allowed_num_processes, num_cached=6, seeds=None, + num_processes=num_processes_train, num_cached=6, seeds=None, pin_memory=self.device.type == 'cuda', wait_time=0.02) + + num_processes_val = get_allowed_n_proc_DA_val() + if num_processes_val == 0: + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, data_loader=dl_val, - transform=val_transforms, num_processes=max(1, allowed_num_processes // 2), + transform=val_transforms, num_processes=num_processes_val, num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda', wait_time=0.02) return mt_gen_train, mt_gen_val diff --git a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py index a96cb2bda..fb01bd8b0 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py +++ b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py @@ -2,7 +2,6 @@ import numpy as np import torch -from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose from batchgenerators.transforms.color_transforms import BrightnessTransform, ContrastAugmentationTransform, \ GammaTransform @@ -21,15 +20,12 @@ ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \ DownsampleSegForDSTransform2 -from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \ - LimitedLenWrapper from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \ ConvertSegmentationToRegionsTransform from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert3DTo2DTransform, \ Convert2DTo3DTransform from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer -from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA class nnUNetTrainerDA5(nnUNetTrainer): @@ -338,17 +334,7 @@ def get_dataloaders(self): dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) - allowed_num_processes = get_allowed_n_proc_DA() - if allowed_num_processes == 0: - mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) - mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) - else: - mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms, - allowed_num_processes, 6, None, True, 0.02) - mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms, - max(1, allowed_num_processes // 2), 3, None, True, 0.02) - - return mt_gen_train, mt_gen_val + return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms) def _brightnessadditive_localgamma_transform_scale(x, y): @@ -399,17 +385,7 @@ def get_dataloaders(self): dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) - allowed_num_processes = get_allowed_n_proc_DA() - if allowed_num_processes == 0: - mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) - mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) - else: - mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms, - allowed_num_processes, 6, None, True, 0.02) - mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms, - max(1, allowed_num_processes // 2), 3, None, True, 0.02) - - return mt_gen_train, mt_gen_val + return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms) class nnUNetTrainerDA5_10epochs(nnUNetTrainerDA5): diff --git a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py index be31857b3..d7e642a89 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py +++ b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py @@ -1,9 +1,4 @@ -from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter - -from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \ - LimitedLenWrapper from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer -from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA class nnUNetTrainerDAOrd0(nnUNetTrainer): @@ -42,17 +37,7 @@ def get_dataloaders(self): dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) - allowed_num_processes = get_allowed_n_proc_DA() - if allowed_num_processes == 0: - mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) - mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) - else: - mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms, - allowed_num_processes, 6, None, True, 0.02) - mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms, - max(1, allowed_num_processes // 2), 3, None, True, 0.02) - - return mt_gen_train, mt_gen_val + return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms) class nnUNetTrainer_DASegOrd0(nnUNetTrainer): @@ -91,17 +76,7 @@ def get_dataloaders(self): dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) - allowed_num_processes = get_allowed_n_proc_DA() - if allowed_num_processes == 0: - mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) - mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) - else: - mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms, - allowed_num_processes, 6, None, True, 0.02) - mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms, - max(1, allowed_num_processes // 2), 3, None, True, 0.02) - - return mt_gen_train, mt_gen_val + return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms) class nnUNetTrainer_DASegOrd0_NoMirroring(nnUNetTrainer): @@ -144,14 +119,4 @@ def get_dataloaders(self): dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) - allowed_num_processes = get_allowed_n_proc_DA() - if allowed_num_processes == 0: - mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) - mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) - else: - mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms, - allowed_num_processes, 6, None, True, 0.02) - mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms, - max(1, allowed_num_processes // 2), 3, None, True, 0.02) - - return mt_gen_train, mt_gen_val + return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms) diff --git a/nnunetv2/utilities/default_n_proc_DA.py b/nnunetv2/utilities/default_n_proc_DA.py index 3ecc92282..e12040824 100644 --- a/nnunetv2/utilities/default_n_proc_DA.py +++ b/nnunetv2/utilities/default_n_proc_DA.py @@ -2,6 +2,16 @@ import os +def get_allowed_n_proc_DA_val(): + """ + This function is used to set the number of processes used for the validation data loader. When nnUNet_n_proc_DA_val + is 0, the validation data is loaded sequentially in the main process. + """ + if 'nnUNet_n_proc_DA_val' in os.environ.keys(): + return int(os.environ['nnUNet_n_proc_DA_val']) + return get_allowed_n_proc_DA() // 2 + + def get_allowed_n_proc_DA(): """ This function is used to set the number of processes used on different Systems. It is specific to our cluster