From d54a9479bc17726623309d8b909c527793661dbf Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Tue, 23 Jan 2024 12:09:54 +0100 Subject: [PATCH 01/24] initial rework complete, still wip and needs testing --- documentation/competitions/AutoPETII.md | 2 +- documentation/explanation_plans_files.md | 2 +- .../default_experiment_planner.py | 169 ++++++++++-------- .../experiment_planners/network_topology.py | 5 +- .../plan_and_preprocess_api.py | 1 - nnunetv2/inference/predict_from_raw_data.py | 13 +- .../training/nnUNetTrainer/nnUNetTrainer.py | 75 ++++---- .../network_architecture/nnUNetTrainerBN.py | 86 +++------ nnunetv2/utilities/get_network_from_plans.py | 33 +++- .../utilities/plans_handling/plans_handler.py | 47 +---- 10 files changed, 215 insertions(+), 218 deletions(-) diff --git a/documentation/competitions/AutoPETII.md b/documentation/competitions/AutoPETII.md index 075256a03..f15ec5ba1 100644 --- a/documentation/competitions/AutoPETII.md +++ b/documentation/competitions/AutoPETII.md @@ -46,7 +46,7 @@ Add the following to the 'configurations' dict in 'nnUNetPlans.json': ```json "3d_fullres_resenc": { "inherits_from": "3d_fullres", - "UNet_class_name": "ResidualEncoderUNet", + "network_arch_class_name": "ResidualEncoderUNet", "n_conv_per_stage_encoder": [ 1, 3, diff --git a/documentation/explanation_plans_files.md b/documentation/explanation_plans_files.md index 00f121648..13ccda810 100644 --- a/documentation/explanation_plans_files.md +++ b/documentation/explanation_plans_files.md @@ -74,7 +74,7 @@ nnunetv2.preprocessing.resampling resampling function must be callable(data, current_spacing, new_spacing, **kwargs). It must be located in nnunetv2.preprocessing.resampling - `resampling_fn_seg_kwargs`: kwargs for resampling_fn_seg -- `UNet_class_name`: UNet class name, can be used to integrate custom dynamic architectures +- `network_arch_class_name`: UNet class name, can be used to integrate custom dynamic architectures - `UNet_base_num_features`: The number of starting features for the UNet architecture. Default is 32. Default: Features are doubled with each downsampling - `unet_max_num_features`: Maximum number of features (default: capped at 320 for 3D and 512 for 2d). 
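For orientation before the code changes below: this rework replaces the flat per-configuration UNet keys with a single nested `architecture` entry. A sketch of what the planner writes into `nnUNetPlans.json`, with illustrative values for a typical 3D dataset (the concrete stages, kernel sizes and strides are dataset-dependent and come out of `get_pool_and_conv_props`):

```json
"architecture": {
    "network_class_name": "dynamic_network_architectures.architectures.unet.PlainConvUNet",
    "arch_kwargs": {
        "n_stages": 6,
        "features_per_stage": [32, 64, 128, 256, 320, 320],
        "conv_op": "torch.nn.modules.conv.Conv3d",
        "kernel_sizes": [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]],
        "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
        "n_conv_per_stage": [2, 2, 2, 2, 2, 2],
        "n_conv_per_stage_decoder": [2, 2, 2, 2, 2],
        "conv_bias": true,
        "norm_op": "torch.nn.modules.instancenorm.InstanceNorm3d",
        "norm_op_kwargs": {"eps": 1e-05, "affine": true},
        "dropout_op": null,
        "dropout_op_kwargs": null,
        "nonlin": "torch.nn.LeakyReLU",
        "nonlin_kwargs": {"inplace": true}
    },
    "_kw_requires_import": ["conv_op", "norm_op", "dropout_op", "nonlin"]
}
```

The dotted paths listed in `_kw_requires_import` are resolved back to actual classes at load time (see the changes to `get_network_from_plans.py` further down).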
The purpose is to diff --git a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py index ccb4a251e..0da198926 100644 --- a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py @@ -1,11 +1,11 @@ import shutil from copy import deepcopy -from functools import lru_cache -from typing import List, Union, Tuple, Type +from typing import List, Union, Tuple import numpy as np +import torch from batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p -from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet +from dynamic_network_architectures.architectures.unet import PlainConvUNet from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm from nnunetv2.configuration import ANISO_THRESHOLD @@ -15,9 +15,10 @@ from nnunetv2.preprocessing.normalization.map_channel_name_to_normalization import get_normalization_scheme from nnunetv2.preprocessing.resampling.default_resampling import resample_data_or_seg_to_shape, compute_new_shape from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA +from nnunetv2.utilities.get_network_from_plans import new_get_network from nnunetv2.utilities.json_export import recursive_fix_for_json_export -from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \ - get_filenames_of_train_images_and_targets +from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets class ExperimentPlanner(object): @@ -87,32 +88,22 @@ def determine_reader_writer(self): return determine_reader_writer_from_dataset_json(self.dataset_json, example_image) @staticmethod - @lru_cache(maxsize=None) def static_estimate_VRAM_usage(patch_size: Tuple[int], - n_stages: int, - strides: Union[int, List[int], Tuple[int, ...]], - UNet_class: Union[Type[PlainConvUNet], Type[ResidualEncoderUNet]], - num_input_channels: int, - features_per_stage: Tuple[int], - blocks_per_stage_encoder: Union[int, Tuple[int]], - blocks_per_stage_decoder: Union[int, Tuple[int]], - num_labels: int): + input_channels: int, + output_channels: int, + arch_class_name: str, + arch_kwargs: dict, + arch_kwargs_req_import: Tuple[str, ...]): """ Works for PlainConvUNet, ResidualEncoderUNet """ - dim = len(patch_size) - conv_op = convert_dim_to_conv_op(dim) - norm_op = get_matching_instancenorm(conv_op) - net = UNet_class(num_input_channels, n_stages, - features_per_stage, - conv_op, - 3, - strides, - blocks_per_stage_encoder, - num_labels, - blocks_per_stage_decoder, - norm_op=norm_op) - return net.compute_conv_feature_map_size(patch_size) + a = torch.get_num_threads() + torch.set_num_threads(get_allowed_n_proc_DA()) + net = new_get_network(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels, + allow_init=False) + ret = net.compute_conv_feature_map_size(patch_size) + torch.set_num_threads(a) + return ret def determine_resampling(self, *args, **kwargs): """ @@ -231,10 +222,24 @@ def determine_transpose(self): def get_plans_for_configuration(self, spacing: Union[np.ndarray, Tuple[float, ...], List[float]], - median_shape: Union[np.ndarray, Tuple[int, ...], List[int]], + median_shape: Union[np.ndarray, 
Tuple[int, ...]], data_identifier: str, - approximate_n_voxels_dataset: float) -> dict: + approximate_n_voxels_dataset: float, + _bad_patch_sizes: dict) -> dict: + def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: + return tuple([min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i) for + i in range(num_stages)]) + + def _keygen(patch_size, strides): + return str(patch_size) + '_' + str(strides) + assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}" + num_input_channels = len(self.dataset_json['channel_names'].keys() + if 'channel_names' in self.dataset_json.keys() + else self.dataset_json['modality'].keys()) + max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d + unet_conv_op = convert_dim_to_conv_op(len(spacing)) + # print(spacing, median_shape, approximate_n_voxels_dataset) # find an initial patch size # we first use the spacing to get an aspect ratio @@ -263,23 +268,38 @@ def get_plans_for_configuration(self, shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size, self.UNet_featuremap_min_edge_length, 999999) + num_stages = len(pool_op_kernel_sizes) + + norm = get_matching_instancenorm(unet_conv_op) + architecture_kwargs = { + 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, + 'arch_kwargs': { + 'n_stages': num_stages, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'n_conv_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + 'conv_bias': True, + 'norm_op': norm.__module__ + '.' + norm.__name__, + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, + 'dropout_op_kwargs': None, + 'nonlin': 'torch.nn.LeakyReLU', + 'nonlin_kwargs': {'inplace': True}, + }, + '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), + } # now estimate vram consumption - num_stages = len(pool_op_kernel_sizes) - estimate = self.static_estimate_VRAM_usage(tuple(patch_size), - num_stages, - tuple([tuple(i) for i in pool_op_kernel_sizes]), - self.UNet_class, - len(self.dataset_json['channel_names'].keys() - if 'channel_names' in self.dataset_json.keys() - else self.dataset_json['modality'].keys()), - tuple([min(self.UNet_max_features_2d if len(patch_size) == 2 else - self.UNet_max_features_3d, - self.UNet_reference_com_nfeatures * 2 ** i) for - i in range(len(pool_op_kernel_sizes))]), - self.UNet_blocks_per_stage_encoder[:num_stages], - self.UNet_blocks_per_stage_decoder[:num_stages - 1], - len(self.dataset_json['labels'].keys())) + estimate = self.static_estimate_VRAM_usage(patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) # how large is the reference for us here (batch size etc)? # adapt for our vram target @@ -287,10 +307,11 @@ def get_plans_for_configuration(self, (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) while estimate > reference: + _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # print(patch_size) # patch size seems to be too large, so we need to reduce it. 
Reduce the axis that currently violates the # aspect ratio the most (that is the largest relative to median shape) - axis_to_be_reduced = np.argsort(patch_size / median_shape[:len(spacing)])[-1] + axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1] # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256. @@ -298,6 +319,7 @@ def get_plans_for_configuration(self, # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first # subtract shape_must_be_divisible_by, then recompute it and then subtract the # recomputed shape_must_be_divisible_by. Annoying. + patch_size = list(patch_size) tmp = deepcopy(patch_size) tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] _, _, _, _, shape_must_be_divisible_by = \ get_pool_and_conv_props(spacing, tmp, self.UNet_featuremap_min_edge_length, 999999) patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] # now recompute topology network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size, self.UNet_featuremap_min_edge_length, 999999) num_stages = len(pool_op_kernel_sizes) - estimate = self.static_estimate_VRAM_usage(tuple(patch_size), - num_stages, - tuple([tuple(i) for i in pool_op_kernel_sizes]), - self.UNet_class, - len(self.dataset_json['channel_names'].keys() - if 'channel_names' in self.dataset_json.keys() - else self.dataset_json['modality'].keys()), - tuple([min(self.UNet_max_features_2d if len(patch_size) == 2 else - self.UNet_max_features_3d, - self.UNet_reference_com_nfeatures * 2 ** i) for - i in range(len(pool_op_kernel_sizes))]), - self.UNet_blocks_per_stage_encoder[:num_stages], - self.UNet_blocks_per_stage_decoder[:num_stages - 1], - len(self.dataset_json['labels'].keys())) + architecture_kwargs['arch_kwargs'].update({ + 'n_stages': num_stages, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'n_conv_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + }) + if _keygen(patch_size, pool_op_kernel_sizes) in _bad_patch_sizes.keys(): + estimate = _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] + else: + estimate = self.static_estimate_VRAM_usage( + patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was # executed.
If not, additional vram headroom is used to increase batch size @@ -344,7 +371,7 @@ def get_plans_for_configuration(self, normalization_schemes, mask_is_used_for_norm = \ self.determine_normalization_scheme_and_whether_mask_is_used_for_norm() - num_stages = len(pool_op_kernel_sizes) + plan = { 'data_identifier': data_identifier, 'preprocessor_name': self.preprocessor_name, @@ -354,20 +381,13 @@ def get_plans_for_configuration(self, 'spacing': spacing, 'normalization_schemes': normalization_schemes, 'use_mask_for_norm': mask_is_used_for_norm, - 'UNet_class_name': self.UNet_class.__name__, - 'UNet_base_num_features': self.UNet_base_num_features, - 'n_conv_per_stage_encoder': self.UNet_blocks_per_stage_encoder[:num_stages], - 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], - 'num_pool_per_axis': network_num_pool_per_axis, - 'pool_op_kernel_sizes': pool_op_kernel_sizes, - 'conv_kernel_sizes': conv_kernel_sizes, - 'unet_max_num_features': self.UNet_max_features_3d if len(spacing) == 3 else self.UNet_max_features_2d, 'resampling_fn_data': resampling_data.__name__, 'resampling_fn_seg': resampling_seg.__name__, 'resampling_fn_data_kwargs': resampling_data_kwargs, 'resampling_fn_seg_kwargs': resampling_seg_kwargs, 'resampling_fn_probabilities': resampling_softmax.__name__, 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs, + 'architecture': architecture_kwargs } return plan @@ -382,6 +402,8 @@ def plan_experiment(self): So for now if you want a different transpose_forward/backward you need to create a new planner. Also not too hard. """ + # we use this as a cache to prevent having to instantiate the architecture too often. Saves computation time + _tmp = {} # first get transpose transpose_forward, transpose_backward = self.determine_transpose() @@ -403,7 +425,7 @@ def plan_experiment(self): plan_3d_fullres = self.get_plans_for_configuration(fullres_spacing_transposed, new_median_shape_transposed, self.generate_data_identifier('3d_fullres'), - approximate_n_voxels_dataset) + approximate_n_voxels_dataset, _tmp) # maybe add 3d_lowres as well patch_size_fullres = plan_3d_fullres['patch_size'] median_num_voxels = np.prod(new_median_shape_transposed, dtype=np.float64) @@ -413,7 +435,6 @@ def plan_experiment(self): lowres_spacing = deepcopy(plan_3d_fullres['spacing']) spacing_increase_factor = 1.03 # used to be 1.01 but that is slow with new GPU memory estimation! - while num_voxels_in_patch / median_num_voxels < self.lowres_creation_threshold: # we incrementally increase the target spacing. We start with the anisotropic axis/axes until it/they # is/are similar (factor 2) to the other ax(i/e)s. @@ -426,11 +447,11 @@ def plan_experiment(self): dtype=np.float64) # print(lowres_spacing) plan_3d_lowres = self.get_plans_for_configuration(lowres_spacing, - [round(i) for i in plan_3d_fullres['spacing'] / - lowres_spacing * new_median_shape_transposed], + tuple([round(i) for i in plan_3d_fullres['spacing'] / + lowres_spacing * new_median_shape_transposed]), self.generate_data_identifier('3d_lowres'), float(np.prod(median_num_voxels) * - self.dataset_json['numTraining'])) + self.dataset_json['numTraining']), _tmp) num_voxels_in_patch = np.prod(plan_3d_lowres['patch_size'], dtype=np.int64) print(f'Attempting to find 3d_lowres config. ' f'\nCurrent spacing: {lowres_spacing}. 
' @@ -448,7 +469,7 @@ def plan_experiment(self): # 2D configuration plan_2d = self.get_plans_for_configuration(fullres_spacing_transposed[1:], new_median_shape_transposed[1:], - self.generate_data_identifier('2d'), approximate_n_voxels_dataset) + self.generate_data_identifier('2d'), approximate_n_voxels_dataset, _tmp) plan_2d['batch_dice'] = True print('2D U-Net configuration:') diff --git a/nnunetv2/experiment_planning/experiment_planners/network_topology.py b/nnunetv2/experiment_planning/experiment_planners/network_topology.py index 1ce6a4665..6922f7b5d 100644 --- a/nnunetv2/experiment_planning/experiment_planners/network_topology.py +++ b/nnunetv2/experiment_planning/experiment_planners/network_topology.py @@ -100,6 +100,9 @@ def get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpo must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis) patch_size = pad_shape(patch_size, must_be_divisible_by) + def _to_tuple(lst): + return tuple(_to_tuple(i) if isinstance(i, list) else i for i in lst) + # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here conv_kernel_sizes.append([3]*dim) - return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by + return num_pool_per_axis, _to_tuple(pool_op_kernel_sizes), _to_tuple(conv_kernel_sizes), tuple(patch_size), must_be_divisible_by diff --git a/nnunetv2/experiment_planning/plan_and_preprocess_api.py b/nnunetv2/experiment_planning/plan_and_preprocess_api.py index 8c74f7c61..961aafc01 100644 --- a/nnunetv2/experiment_planning/plan_and_preprocess_api.py +++ b/nnunetv2/experiment_planning/plan_and_preprocess_api.py @@ -127,7 +127,6 @@ def preprocess_dataset(dataset_id: int, update=True) - def preprocess(dataset_ids: List[int], plans_identifier: str = 'nnUNetPlans', configurations: Union[Tuple[str], List[str]] = ('2d', '3d_fullres', '3d_lowres'), diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py index cfc9e9c85..6ef927257 100644 --- a/nnunetv2/inference/predict_from_raw_data.py +++ b/nnunetv2/inference/predict_from_raw_data.py @@ -2,7 +2,6 @@ import itertools import multiprocessing import os -import traceback from copy import deepcopy from time import sleep from typing import Tuple, Union, List, Optional @@ -99,8 +98,16 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str, num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json) trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), trainer_name, 'nnunetv2.training.nnUNetTrainer') - network = trainer_class.build_network_architecture(plans_manager, dataset_json, configuration_manager, - num_input_channels, enable_deep_supervision=False) + + network = trainer_class.build_network_architecture( + configuration_manager.network_arch_class_name, + configuration_manager.network_arch_init_kwargs, + configuration_manager.network_arch_init_kwargs_req_import, + num_input_channels, + plans_manager.get_label_manager(dataset_json).num_segmentation_heads, + enable_deep_supervision=False + ) + self.plans_manager = plans_manager self.configuration_manager = configuration_manager self.list_of_parameters = parameters diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py index 690a15fb2..756235a62 100644 --- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py +++ 
b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -55,7 +55,7 @@ from nnunetv2.utilities.crossval_split import generate_crossval_split from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy -from nnunetv2.utilities.get_network_from_plans import get_network_from_plans +from nnunetv2.utilities.get_network_from_plans import new_get_network from nnunetv2.utilities.helpers import empty_cache, dummy_context from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager @@ -205,11 +205,12 @@ def initialize(self): self.dataset_json) self.network = self.build_network_architecture( - self.plans_manager, - self.dataset_json, - self.configuration_manager, + self.configuration_manager.network_arch_class_name, + self.configuration_manager.network_arch_init_kwargs, + self.configuration_manager.network_arch_init_kwargs_req_import, self.num_input_channels, - self.enable_deep_supervision, + self.label_manager.num_segmentation_heads, + self.enable_deep_supervision ).to(self.device) # compile network for free speedup if self._do_i_compile(): @@ -267,10 +268,11 @@ def _save_debug_information(self): save_json(dct, join(self.output_folder, "debug.json")) @staticmethod - def build_network_architecture(plans_manager: PlansManager, - dataset_json, - configuration_manager: ConfigurationManager, - num_input_channels, + def build_network_architecture(architecture_class_name: str, + arch_init_kwargs: dict, + arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]], + num_input_channels: int, + num_output_channels: int, enable_deep_supervision: bool = True) -> nn.Module: """ This is where you build the architecture according to the plans. There is no obligation to use @@ -291,8 +293,14 @@ def build_network_architecture(plans_manager: PlansManager, should be generated. label_manager takes care of all that for you.) """ - return get_network_from_plans(plans_manager, dataset_json, configuration_manager, - num_input_channels, deep_supervision=enable_deep_supervision) + return new_get_network( + architecture_class_name, + arch_init_kwargs, + arch_init_kwargs_req_import, + num_input_channels, + num_output_channels, + allow_init=True, + deep_supervision=enable_deep_supervision) def _get_deep_supervision_scales(self): if self.enable_deep_supervision: @@ -366,7 +374,7 @@ def _build_loss(self): if self.enable_deep_supervision: deep_supervision_scales = self._get_deep_supervision_scales() - weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) if self.is_ddp and not self._do_i_compile(): # very strange and stupid interaction. DDP crashes and complains about unused parameters due to # weights[-1] = 0. Interestingly this crash doesn't happen with torch.compile enabled. Strange stuff. 
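To make the deep supervision weighting in `_build_loss` above concrete, here is a minimal sketch assuming five supervision scales (illustrative; the real count is `len(deep_supervision_scales)` and follows the network topology):

```python
import numpy as np

num_scales = 5  # illustrative; taken from _get_deep_supervision_scales in the trainer
weights = np.array([1 / (2 ** i) for i in range(num_scales)])
# each coarser deep supervision output gets half the weight of the previous one:
# [1.0, 0.5, 0.25, 0.125, 0.0625]. The lowest-resolution entry is then set to 0
# (the weights[-1] = 0 referenced in the DDP comment above) before the trainer
# normalizes the vector and hands it to the deep supervision loss wrapper.
print(weights)
```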
@@ -674,19 +682,19 @@ def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int): @staticmethod def get_training_transforms( - patch_size: Union[np.ndarray, Tuple[int]], - rotation_for_DA: dict, - deep_supervision_scales: Union[List, Tuple, None], - mirror_axes: Tuple[int, ...], - do_dummy_2d_data_aug: bool, - order_resampling_data: int = 3, - order_resampling_seg: int = 1, - border_val_seg: int = -1, - use_mask_for_norm: List[bool] = None, - is_cascaded: bool = False, - foreground_labels: Union[Tuple[int, ...], List[int]] = None, - regions: List[Union[List[int], Tuple[int, ...], int]] = None, - ignore_label: int = None, + patch_size: Union[np.ndarray, Tuple[int]], + rotation_for_DA: dict, + deep_supervision_scales: Union[List, Tuple, None], + mirror_axes: Tuple[int, ...], + do_dummy_2d_data_aug: bool, + order_resampling_data: int = 3, + order_resampling_seg: int = 1, + border_val_seg: int = -1, + use_mask_for_norm: List[bool] = None, + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None, ) -> AbstractTransform: tr_transforms = [] if do_dummy_2d_data_aug: @@ -768,11 +776,11 @@ def get_training_transforms( @staticmethod def get_validation_transforms( - deep_supervision_scales: Union[List, Tuple, None], - is_cascaded: bool = False, - foreground_labels: Union[Tuple[int, ...], List[int]] = None, - regions: List[Union[List[int], Tuple[int, ...], int]] = None, - ignore_label: int = None, + deep_supervision_scales: Union[List, Tuple, None], + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None, ) -> AbstractTransform: val_transforms = [] val_transforms.append(RemoveLabelTransform(-1, 0)) @@ -1173,11 +1181,11 @@ def perform_actual_validation(self, save_probabilities: bool = False): for i, k in enumerate(dataset_val.keys()): proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, - allowed_num_queued=2) + allowed_num_queued=2) while not proceed: sleep(0.1) proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, - allowed_num_queued=2) + allowed_num_queued=2) self.print_to_log_file(f"predicting {k}") data, seg, properties = dataset_val.load_case(k) @@ -1262,7 +1270,8 @@ def perform_actual_validation(self, save_probabilities: bool = False): num_processes=default_num_processes * dist.get_world_size() if self.is_ddp else default_num_processes) self.print_to_log_file("Validation complete", also_print_to_console=True) - self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), also_print_to_console=True) + self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), + also_print_to_console=True) self.set_deep_supervision_enabled(True) compute_gaussian.cache_clear() diff --git a/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py b/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py index 5f6190c1b..6da3ca7e4 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py +++ b/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py @@ -1,73 +1,33 @@ -from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet, PlainConvUNet -from 
dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_batchnorm -from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0, InitWeights_He -from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer -from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from typing import Union, Tuple, List + +from dynamic_network_architectures.building_blocks.helper import get_matching_batchnorm from torch import nn +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer + class nnUNetTrainerBN(nnUNetTrainer): @staticmethod - def build_network_architecture(plans_manager: PlansManager, - dataset_json, - configuration_manager: ConfigurationManager, - num_input_channels, + def build_network_architecture(architecture_class_name: str, + arch_init_kwargs: dict, + arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]], + num_input_channels: int, + num_output_channels: int, enable_deep_supervision: bool = True) -> nn.Module: - num_stages = len(configuration_manager.conv_kernel_sizes) - dim = len(configuration_manager.conv_kernel_sizes[0]) - conv_op = convert_dim_to_conv_op(dim) + if 'norm_op' not in arch_init_kwargs.keys(): + raise RuntimeError("'norm_op' not found in arch_init_kwargs. This does not look like an architecture " + "I can hack BN into. This trainer only works with default nnU-Net architectures.") - label_manager = plans_manager.get_label_manager(dataset_json) + from pydoc import locate + conv_op = locate(arch_init_kwargs['conv_op']) + bn_class = get_matching_batchnorm(conv_op) + arch_init_kwargs['norm_op'] = bn_class.__module__ + '.' + bn_class.__name__ + arch_init_kwargs['norm_op_kwargs'] = {'eps': 1e-5, 'affine': True} - segmentation_network_class_name = configuration_manager.UNet_class_name - mapping = { - 'PlainConvUNet': PlainConvUNet, - 'ResidualEncoderUNet': ResidualEncoderUNet - } - kwargs = { - 'PlainConvUNet': { - 'conv_bias': True, - 'norm_op': get_matching_batchnorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - }, - 'ResidualEncoderUNet': { - 'conv_bias': True, - 'norm_op': get_matching_batchnorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - } - } - assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ - 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ - 'into either this ' \ - 'function (get_network_from_plans) or ' \ - 'the init of your nnUNetModule to accommodate that.' - network_class = mapping[segmentation_network_class_name] + return nnUNetTrainer.build_network_architecture(architecture_class_name, + arch_init_kwargs, + arch_init_kwargs_req_import, + num_input_channels, + num_output_channels, enable_deep_supervision) - conv_or_blocks_per_stage = { - 'n_conv_per_stage' - if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': configuration_manager.n_conv_per_stage_encoder, - 'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder - } - # network class name!! 
- model = network_class( - input_channels=num_input_channels, - n_stages=num_stages, - features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i, - configuration_manager.unet_max_num_features) for i in range(num_stages)], - conv_op=conv_op, - kernel_sizes=configuration_manager.conv_kernel_sizes, - strides=configuration_manager.pool_op_kernel_sizes, - num_classes=label_manager.num_segmentation_heads, - deep_supervision=enable_deep_supervision, - **conv_or_blocks_per_stage, - **kwargs[segmentation_network_class_name] - ) - model.apply(InitWeights_He(1e-2)) - if network_class == ResidualEncoderUNet: - model.apply(init_last_bn_before_add_to_0) - return model diff --git a/nnunetv2/utilities/get_network_from_plans.py b/nnunetv2/utilities/get_network_from_plans.py index 1dd1dd2ec..25a8471e4 100644 --- a/nnunetv2/utilities/get_network_from_plans.py +++ b/nnunetv2/utilities/get_network_from_plans.py @@ -1,9 +1,38 @@ +import pydoc +from typing import Union + from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 +from torch import nn + from nnunetv2.utilities.network_initialization import InitWeights_He from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager -from torch import nn + + +def new_get_network(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels, + allow_init=True, deep_supervision: Union[bool, None] = None): + network_class = arch_class_name + architecture_kwargs = dict(**arch_kwargs) + for ri in arch_kwargs_req_import: + if architecture_kwargs[ri] is not None: + architecture_kwargs[ri] = pydoc.locate(architecture_kwargs[ri]) + + nw_class = pydoc.locate(network_class) + + if deep_supervision is not None and 'deep_supervision' not in architecture_kwargs.keys(): + architecture_kwargs['deep_supervision'] = deep_supervision + + network = nw_class( + input_channels=input_channels, + num_classes=output_channels, + **architecture_kwargs + ) + + if hasattr(network, 'initialize') and allow_init: + network.apply(network.initialize) + + return network def get_network_from_plans(plans_manager: PlansManager, @@ -24,7 +53,7 @@ def get_network_from_plans(plans_manager: PlansManager, label_manager = plans_manager.get_label_manager(dataset_json) - segmentation_network_class_name = configuration_manager.UNet_class_name + segmentation_network_class_name = configuration_manager.network_arch_class_name mapping = { 'PlainConvUNet': PlainConvUNet, 'ResidualEncoderUNet': ResidualEncoderUNet diff --git a/nnunetv2/utilities/plans_handling/plans_handler.py b/nnunetv2/utilities/plans_handling/plans_handler.py index 6c39fd1ed..03601817d 100644 --- a/nnunetv2/utilities/plans_handling/plans_handler.py +++ b/nnunetv2/utilities/plans_handling/plans_handler.py @@ -9,8 +9,6 @@ import torch from nnunetv2.preprocessing.resampling.utils import recursive_find_resampling_fn_by_name -from torch import nn - import nnunetv2 from batchgenerators.utilities.file_and_folder_operations import load_json, join @@ -77,49 +75,20 @@ def use_mask_for_norm(self) -> List[bool]: return self.configuration['use_mask_for_norm'] @property - def UNet_class_name(self) -> str: - return self.configuration['UNet_class_name'] - - @property - @lru_cache(maxsize=1) - def UNet_class(self) -> Type[nn.Module]: - unet_class = 
recursive_find_python_class(join(dynamic_network_architectures.__path__[0], "architectures"), - self.UNet_class_name, - current_module="dynamic_network_architectures.architectures") - if unet_class is None: - raise RuntimeError('The network architecture specified by the plans file ' - 'is non-standard (maybe your own?). Fix this by not using ' - 'ConfigurationManager.UNet_class to instantiate ' - 'it (probably just overwrite build_network_architecture of your trainer.') - return unet_class - - @property - def UNet_base_num_features(self) -> int: - return self.configuration['UNet_base_num_features'] - - @property - def n_conv_per_stage_encoder(self) -> List[int]: - return self.configuration['n_conv_per_stage_encoder'] - - @property - def n_conv_per_stage_decoder(self) -> List[int]: - return self.configuration['n_conv_per_stage_decoder'] - - @property - def num_pool_per_axis(self) -> List[int]: - return self.configuration['num_pool_per_axis'] + def network_arch_class_name(self) -> str: + return self.configuration['architecture']['network_class_name'] @property - def pool_op_kernel_sizes(self) -> List[List[int]]: - return self.configuration['pool_op_kernel_sizes'] + def network_arch_init_kwargs(self) -> dict: + return self.configuration['architecture']['arch_kwargs'] @property - def conv_kernel_sizes(self) -> List[List[int]]: - return self.configuration['conv_kernel_sizes'] + def network_arch_init_kwargs_req_import(self) -> Union[Tuple[str, ...], List[str]]: + return self.configuration['architecture']['_kw_requires_import'] @property - def unet_max_num_features(self) -> int: - return self.configuration['unet_max_num_features'] + def pool_op_kernel_sizes(self) -> Tuple[Tuple[int, ...], ...]: + return self.configuration['architecture']['arch_kwargs']['strides'] @property @lru_cache(maxsize=1) From e15c2186d51d5890eb4e92243927abdf3fbfac52 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Tue, 23 Jan 2024 16:11:27 +0100 Subject: [PATCH 02/24] ironed out some kinks, added a bunch of residual unet variants for testing --- .../default_experiment_planner.py | 14 +- .../experiment_planners/resUNet_planner.py | 210 +++++++++++++++++ .../experiment_planners/resUNet_planner2.py | 16 ++ .../experiment_planners/resUNet_planner3.py | 192 ++++++++++++++++ .../resencUNetBottleneck_planner.py | 216 ++++++++++++++++++ .../experiment_planners/resencUNet_planner.py | 215 ++++++++++++++++- .../plan_and_preprocess_api.py | 28 ++- .../plan_and_preprocess_entrypoints.py | 6 +- .../training/nnUNetTrainer/nnUNetTrainer.py | 4 +- nnunetv2/utilities/get_network_from_plans.py | 85 +------ 10 files changed, 873 insertions(+), 113 deletions(-) create mode 100644 nnunetv2/experiment_planning/experiment_planners/resUNet_planner.py create mode 100644 nnunetv2/experiment_planning/experiment_planners/resUNet_planner2.py create mode 100644 nnunetv2/experiment_planning/experiment_planners/resUNet_planner3.py create mode 100644 nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py diff --git a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py index 0da198926..06b52962b 100644 --- a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py @@ -16,7 +16,7 @@ from nnunetv2.preprocessing.resampling.default_resampling import resample_data_or_seg_to_shape, compute_new_shape from 
nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA -from nnunetv2.utilities.get_network_from_plans import new_get_network +from nnunetv2.utilities.get_network_from_plans import get_network_from_plans from nnunetv2.utilities.json_export import recursive_fix_for_json_export from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets @@ -57,7 +57,6 @@ def __init__(self, dataset_name_or_id: Union[str, int], self.UNet_reference_val_corresp_GB = 8 self.UNet_reference_val_corresp_bs_2d = 12 self.UNet_reference_val_corresp_bs_3d = 2 - self.UNet_vram_target_GB = gpu_memory_target_in_gb self.UNet_featuremap_min_edge_length = 4 self.UNet_blocks_per_stage_encoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) self.UNet_blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) @@ -65,6 +64,8 @@ def __init__(self, dataset_name_or_id: Union[str, int], self.UNet_max_features_2d = 512 self.UNet_max_features_3d = 320 + self.UNet_vram_target_GB = gpu_memory_target_in_gb + self.lowres_creation_threshold = 0.25 # if the patch size of fullres is less than 25% of the voxels in the # median shape then we need a lowres config as well @@ -99,8 +100,8 @@ def static_estimate_VRAM_usage(patch_size: Tuple[int], """ a = torch.get_num_threads() torch.set_num_threads(get_allowed_n_proc_DA()) - net = new_get_network(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels, - allow_init=False) + net = get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels, + allow_init=False) ret = net.compute_conv_feature_map_size(patch_size) torch.set_num_threads(a) return ret @@ -457,6 +458,11 @@ def plan_experiment(self): f'\nCurrent spacing: {lowres_spacing}. ' f'\nCurrent patch size: {plan_3d_lowres["patch_size"]}. ' f'\nCurrent median shape: {plan_3d_fullres["spacing"] / lowres_spacing * new_median_shape_transposed}') + if np.prod(new_median_shape_transposed, dtype=np.float64) / median_num_voxels < 2: + print(f'Dropping 3d_lowres config because the image size difference to 3d_fullres is too small. 
' + f'3d_fullres: {new_median_shape_transposed}, ' + f'3d_lowres: {[round(i) for i in plan_3d_fullres["spacing"] / lowres_spacing * new_median_shape_transposed]}') + plan_3d_lowres = None if plan_3d_lowres is not None: plan_3d_lowres['batch_dice'] = False plan_3d_fullres['batch_dice'] = True diff --git a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/resUNet_planner.py new file mode 100644 index 000000000..42042f44a --- /dev/null +++ b/nnunetv2/experiment_planning/experiment_planners/resUNet_planner.py @@ -0,0 +1,210 @@ +from copy import deepcopy +from typing import Union, List, Tuple + +import numpy as np +from dynamic_network_architectures.architectures.residual_unet import ResidualUNet +from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm + +from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner +from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props + + +class ResUNetPlanner(ExperimentPlanner): + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 8, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResUNetPlans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + + self.UNet_class = ResidualUNet + # the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as + # much as possible + self.UNet_reference_val_3d = 680000000 + self.UNet_reference_val_2d = 135000000 + self.UNet_blocks_per_stage_encoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) + self.UNet_blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) + + def generate_data_identifier(self, configuration_name: str) -> str: + """ + configurations are unique within each plans file but different plans file can have configurations with the + same name. In order to distinguish the associated data we need a data identifier that reflects not just the + config but also the plans it originates from + """ + if configuration_name == '2d' or configuration_name == '3d_fullres': + # we do not deviate from ExperimentPlanner so we can reuse its data + return 'nnUNetPlans' + '_' + configuration_name + else: + return self.plans_identifier + '_' + configuration_name + + def get_plans_for_configuration(self, + spacing: Union[np.ndarray, Tuple[float, ...], List[float]], + median_shape: Union[np.ndarray, Tuple[int, ...]], + data_identifier: str, + approximate_n_voxels_dataset: float, + _bad_patch_sizes: dict) -> dict: + def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: + return tuple([min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i) for + i in range(num_stages)]) + + def _keygen(patch_size, strides): + return str(patch_size) + '_' + str(strides) + + assert all([i > 0 for i in spacing]), f"Spacing must be > 0! 
Spacing: {spacing}" + num_input_channels = len(self.dataset_json['channel_names'].keys() + if 'channel_names' in self.dataset_json.keys() + else self.dataset_json['modality'].keys()) + max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d + unet_conv_op = convert_dim_to_conv_op(len(spacing)) + + # print(spacing, median_shape, approximate_n_voxels_dataset) + # find an initial patch size + # we first use the spacing to get an aspect ratio + tmp = 1 / np.array(spacing) + + # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same + # volume as a patch of size 256 ** 3) + # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be + # ideal because large initial patch sizes increase computation time because more iterations in the while loop + # further down may be required. + if len(spacing) == 3: + initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)] + elif len(spacing) == 2: + initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)] + else: + raise RuntimeError() + + # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that + # this is different from how nnU-Net v1 does it! + # todo patch size can still get too large because we pad the patch size to a multiple of 2**n + initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])]) + + # use that to get the network topology. Note that this changes the patch_size depending on the number of + # pooling operations (must be divisible by 2**num_pool in each axis) + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + num_stages = len(pool_op_kernel_sizes) + + norm = get_matching_instancenorm(unet_conv_op) + architecture_kwargs = { + 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, + 'arch_kwargs': { + 'n_stages': num_stages, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + 'conv_bias': True, + 'norm_op': norm.__module__ + '.' + norm.__name__, + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, + 'dropout_op_kwargs': None, + 'nonlin': 'torch.nn.LeakyReLU', + 'nonlin_kwargs': {'inplace': True}, + }, + '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), + } + + # now estimate vram consumption + estimate = self.static_estimate_VRAM_usage(patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + + # how large is the reference for us here (batch size etc)? 
+ # adapt for our vram target + reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \ + (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) + + while estimate > reference: + _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + # print(patch_size) + # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the + # aspect ratio the most (that is the largest relative to median shape) + axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1] + + # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this + # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256. + # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size + # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first + # subtract shape_must_be_divisible_by, then recompute it and then subtract the + # recomputed shape_must_be_divisible_by. Annoying. + patch_size = list(patch_size) + tmp = deepcopy(patch_size) + tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + _, _, _, _, shape_must_be_divisible_by = \ + get_pool_and_conv_props(spacing, tmp, + self.UNet_featuremap_min_edge_length, + 999999) + patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + + # now recompute topology + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + + num_stages = len(pool_op_kernel_sizes) + architecture_kwargs['arch_kwargs'].update({ + 'n_stages': num_stages, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + }) + if _keygen(patch_size, pool_op_kernel_sizes) in _bad_patch_sizes.keys(): + estimate = _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] + else: + estimate = self.static_estimate_VRAM_usage( + patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + + # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was + # executed. If not, additional vram headroom is used to increase batch size + ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d + batch_size = round((reference / estimate) * ref_bs) + + # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution.
We cannot + # go smaller than self.UNet_min_batch_size though + bs_corresponding_to_5_percent = round( + approximate_n_voxels_dataset * 0.05 / np.prod(patch_size, dtype=np.float64)) + batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) + + resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() + resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn() + + normalization_schemes, mask_is_used_for_norm = \ + self.determine_normalization_scheme_and_whether_mask_is_used_for_norm() + + plan = { + 'data_identifier': data_identifier, + 'preprocessor_name': self.preprocessor_name, + 'batch_size': batch_size, + 'patch_size': patch_size, + 'median_image_size_in_voxels': median_shape, + 'spacing': spacing, + 'normalization_schemes': normalization_schemes, + 'use_mask_for_norm': mask_is_used_for_norm, + 'resampling_fn_data': resampling_data.__name__, + 'resampling_fn_seg': resampling_seg.__name__, + 'resampling_fn_data_kwargs': resampling_data_kwargs, + 'resampling_fn_seg_kwargs': resampling_seg_kwargs, + 'resampling_fn_probabilities': resampling_softmax.__name__, + 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs, + 'architecture': architecture_kwargs + } + return plan diff --git a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner2.py b/nnunetv2/experiment_planning/experiment_planners/resUNet_planner2.py new file mode 100644 index 000000000..8cffbae77 --- /dev/null +++ b/nnunetv2/experiment_planning/experiment_planners/resUNet_planner2.py @@ -0,0 +1,16 @@ +from typing import Union, List, Tuple + +from nnunetv2.experiment_planning.experiment_planners.resUNet_planner import ResUNetPlanner + + +class ResUNetPlanner2(ResUNetPlanner): + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 8, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResUNet2Plans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + + self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6) + self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) diff --git a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner3.py b/nnunetv2/experiment_planning/experiment_planners/resUNet_planner3.py new file mode 100644 index 000000000..89018c29f --- /dev/null +++ b/nnunetv2/experiment_planning/experiment_planners/resUNet_planner3.py @@ -0,0 +1,192 @@ +from copy import deepcopy +from typing import Union, List, Tuple + +import numpy as np +from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm + +from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props +from nnunetv2.experiment_planning.experiment_planners.resUNet_planner import ResUNetPlanner + + +class ResUNetPlanner3(ResUNetPlanner): + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 8, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResUNet3Plans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, 
plans_name, + overwrite_target_spacing, suppress_transpose) + + self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6) + self.UNet_blocks_per_stage_decoder = None + + def get_plans_for_configuration(self, + spacing: Union[np.ndarray, Tuple[float, ...], List[float]], + median_shape: Union[np.ndarray, Tuple[int, ...]], + data_identifier: str, + approximate_n_voxels_dataset: float, + _bad_patch_sizes: dict) -> dict: + def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: + return tuple([min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i) for + i in range(num_stages)]) + + def _keygen(patch_size, strides): + return str(patch_size) + '_' + str(strides) + + assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}" + num_input_channels = len(self.dataset_json['channel_names'].keys() + if 'channel_names' in self.dataset_json.keys() + else self.dataset_json['modality'].keys()) + max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d + unet_conv_op = convert_dim_to_conv_op(len(spacing)) + + # print(spacing, median_shape, approximate_n_voxels_dataset) + # find an initial patch size + # we first use the spacing to get an aspect ratio + tmp = 1 / np.array(spacing) + + # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same + # volume as a patch of size 256 ** 3) + # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be + # ideal because large initial patch sizes increase computation time because more iterations in the while loop + # further down may be required. + if len(spacing) == 3: + initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)] + elif len(spacing) == 2: + initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)] + else: + raise RuntimeError() + + # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that + # this is different from how nnU-Net v1 does it! + # todo patch size can still get too large because we pad the patch size to a multiple of 2**n + initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])]) + + # use that to get the network topology. Note that this changes the patch_size depending on the number of + # pooling operations (must be divisible by 2**num_pool in each axis) + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + num_stages = len(pool_op_kernel_sizes) + + norm = get_matching_instancenorm(unet_conv_op) + architecture_kwargs = { + 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, + 'arch_kwargs': { + 'n_stages': num_stages, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_encoder[:num_stages - 1][::-1], + 'conv_bias': True, + 'norm_op': norm.__module__ + '.' 
+ norm.__name__, + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, + 'dropout_op_kwargs': None, + 'nonlin': 'torch.nn.LeakyReLU', + 'nonlin_kwargs': {'inplace': True}, + }, + '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), + } + + # now estimate vram consumption + estimate = self.static_estimate_VRAM_usage(patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + + # how large is the reference for us here (batch size etc)? + # adapt for our vram target + reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \ + (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) + + while estimate > reference: + _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + # print(patch_size) + # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the + # aspect ratio the most (that is the largest relative to median shape) + axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1] + + # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this + # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256. + # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size + # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first + # subtract shape_must_be_divisible_by, then recompute it and then subtract the + # recomputed shape_must_be_divisible_by. Annoying. + patch_size = list(patch_size) + tmp = deepcopy(patch_size) + tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + _, _, _, _, shape_must_be_divisible_by = \ + get_pool_and_conv_props(spacing, tmp, + self.UNet_featuremap_min_edge_length, + 999999) + patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + + # now recompute topology + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + + num_stages = len(pool_op_kernel_sizes) + architecture_kwargs['arch_kwargs'].update({ + 'n_stages': num_stages, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_encoder[:num_stages - 1][::-1], + }) + if _keygen(patch_size, pool_op_kernel_sizes) in _bad_patch_sizes.keys(): + estimate = _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] + else: + estimate = self.static_estimate_VRAM_usage( + patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + + # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was + # executed.
If not, additional vram headroom is used to increase batch size + ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d + batch_size = round((reference / estimate) * ref_bs) + + # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. We cannot + # go smaller than self.UNet_min_batch_size though + bs_corresponding_to_5_percent = round( + approximate_n_voxels_dataset * 0.05 / np.prod(patch_size, dtype=np.float64)) + batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) + + resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() + resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn() + + normalization_schemes, mask_is_used_for_norm = \ + self.determine_normalization_scheme_and_whether_mask_is_used_for_norm() + + plan = { + 'data_identifier': data_identifier, + 'preprocessor_name': self.preprocessor_name, + 'batch_size': batch_size, + 'patch_size': patch_size, + 'median_image_size_in_voxels': median_shape, + 'spacing': spacing, + 'normalization_schemes': normalization_schemes, + 'use_mask_for_norm': mask_is_used_for_norm, + 'resampling_fn_data': resampling_data.__name__, + 'resampling_fn_seg': resampling_seg.__name__, + 'resampling_fn_data_kwargs': resampling_data_kwargs, + 'resampling_fn_seg_kwargs': resampling_seg_kwargs, + 'resampling_fn_probabilities': resampling_softmax.__name__, + 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs, + 'architecture': architecture_kwargs + } + return plan diff --git a/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py b/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py new file mode 100644 index 000000000..1f167237b --- /dev/null +++ b/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py @@ -0,0 +1,216 @@ +from copy import deepcopy +from typing import Union, List, Tuple + +import numpy as np +from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm +from dynamic_network_architectures.building_blocks.residual import BottleneckD +from torch import nn + +from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props +from nnunetv2.experiment_planning.experiment_planners.resencUNet_planner import ResEncUNetPlanner + + +class ResEncUNetBottleneckPlanner(ResEncUNetPlanner): + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 8, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResBottleneckEncUNetPlans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + + def get_plans_for_configuration(self, + spacing: Union[np.ndarray, Tuple[float, ...], List[float]], + median_shape: Union[np.ndarray, Tuple[int, ...]], + data_identifier: str, + approximate_n_voxels_dataset: float, + _bad_patch_sizes: dict) -> dict: + def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: + return tuple([min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i) for + i in 
range(num_stages)]) + + def _keygen(patch_size, strides): + return str(patch_size) + '_' + str(strides) + + assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}" + num_input_channels = len(self.dataset_json['channel_names'].keys() + if 'channel_names' in self.dataset_json.keys() + else self.dataset_json['modality'].keys()) + max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d + unet_conv_op = convert_dim_to_conv_op(len(spacing)) + + # print(spacing, median_shape, approximate_n_voxels_dataset) + # find an initial patch size + # we first use the spacing to get an aspect ratio + tmp = 1 / np.array(spacing) + + # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same + # volume as a patch of size 256 ** 3) + # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be + # ideal because large initial patch sizes increase computation time because more iterations in the while loop + # further down may be required. + if len(spacing) == 3: + initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)] + elif len(spacing) == 2: + initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)] + else: + raise RuntimeError() + + # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that + # this is different from how nnU-Net v1 does it! + # todo patch size can still get too large because we pad the patch size to a multiple of 2**n + initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])]) + + # use that to get the network topology. Note that this changes the patch_size depending on the number of + # pooling operations (must be divisible by 2**num_pool in each axis) + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + num_stages = len(pool_op_kernel_sizes) + + norm = get_matching_instancenorm(unet_conv_op) + architecture_kwargs = { + 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, + 'arch_kwargs': { + 'n_stages': num_stages, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + 'conv_bias': True, + 'norm_op': norm.__module__ + '.' + norm.__name__, + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, + 'dropout_op_kwargs': None, + 'nonlin': 'torch.nn.LeakyReLU', + 'nonlin_kwargs': {'inplace': True}, + 'block': BottleneckD.__module__ + '.' 
+ BottleneckD.__name__, + 'bottleneck_channels': [i // 4 for i in _features_per_stage(num_stages, max_num_features)] + }, + '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin', 'block'), + } + + # now estimate vram consumption + estimate = self.static_estimate_VRAM_usage(patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + + # how large is the reference for us here (batch size etc)? + # adapt for our vram target + reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \ + (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) + + while estimate > reference: + _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + # print(patch_size) + # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the + # aspect ratio the most (that is the largest relative to median shape) + axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1] + + # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this + # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256. + # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size + # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first + # subtract shape_must_be_divisible_by, then recompute it and then subtract the + # recomputed shape_must_be_divisible_by. Annoying. + patch_size = list(patch_size) + tmp = deepcopy(patch_size) + tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + _, _, _, _, shape_must_be_divisible_by = \ + get_pool_and_conv_props(spacing, tmp, + self.UNet_featuremap_min_edge_length, + 999999) + patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + + # now recompute topology + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + + num_stages = len(pool_op_kernel_sizes) + architecture_kwargs['arch_kwargs'].update({ + 'n_stages': num_stages, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + 'bottleneck_channels': [i // 4 for i in _features_per_stage(num_stages, max_num_features)] + }) + if _keygen(patch_size, pool_op_kernel_sizes) in _bad_patch_sizes.keys(): + _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + else: + estimate = self.static_estimate_VRAM_usage( + patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + + # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was + # executed. 
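For concreteness, here is what `_features_per_stage` and the bottleneck width rule above produce with the 3D defaults (base width 32, cap 320), as a worked example:

```python
# Worked numbers: widths double per stage and are capped at the 3D maximum;
# the bottleneck variant uses a quarter of each stage's width inside the block.
base, max_features, num_stages = 32, 320, 6
features = tuple(min(max_features, base * 2 ** i) for i in range(num_stages))
# features            -> (32, 64, 128, 256, 320, 320)
bottleneck_channels = [f // 4 for f in features]
# bottleneck_channels -> [8, 16, 32, 64, 80, 80]
```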
If not, additional vram headroom is used to increase batch size + ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d + batch_size = round((reference / estimate) * ref_bs) + + # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. We cannot + # go smaller than self.UNet_min_batch_size though + bs_corresponding_to_5_percent = round( + approximate_n_voxels_dataset * 0.05 / np.prod(patch_size, dtype=np.float64)) + batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) + + resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() + resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn() + + normalization_schemes, mask_is_used_for_norm = \ + self.determine_normalization_scheme_and_whether_mask_is_used_for_norm() + + plan = { + 'data_identifier': data_identifier, + 'preprocessor_name': self.preprocessor_name, + 'batch_size': batch_size, + 'patch_size': patch_size, + 'median_image_size_in_voxels': median_shape, + 'spacing': spacing, + 'normalization_schemes': normalization_schemes, + 'use_mask_for_norm': mask_is_used_for_norm, + 'resampling_fn_data': resampling_data.__name__, + 'resampling_fn_seg': resampling_seg.__name__, + 'resampling_fn_data_kwargs': resampling_data_kwargs, + 'resampling_fn_seg_kwargs': resampling_seg_kwargs, + 'resampling_fn_probabilities': resampling_softmax.__name__, + 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs, + 'architecture': architecture_kwargs + } + return plan + + +if __name__ == '__main__': + # we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively + net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320), + conv_op=nn.Conv3d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2), + n_blocks_per_stage=(1, 3, 4, 6, 6, 6), num_classes=3, + n_conv_per_stage_decoder=(1, 1, 1, 1, 1), + conv_bias=True, norm_op=nn.InstanceNorm3d, norm_op_kwargs={}, dropout_op=None, + nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) + print(net.compute_conv_feature_map_size((128, 128, 128))) # -> 558319104. 
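The number printed here is, roughly, the summed element count of the network's feature maps for one forward pass, which is the proxy these planners compare against their reference values. A hedged sketch of that comparison, using the 3D reference from `ResEncUNetPlanner`:

```python
# Sketch: does this topology fit the VRAM budget? The planner scales the 3D
# reference value by the ratio of the GPU target to the reference GPU (8 GB).
footprint = 558_319_104              # net.compute_conv_feature_map_size(...)
reference = 680_000_000 * (8 / 8)    # UNet_reference_val_3d at an 8 GB target
assert footprint <= reference        # fits; otherwise shrink the patch size
```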
The value you see above was finetuned + # from this one to match the regular nnunetplans more closely + + net = ResidualEncoderUNet(input_channels=1, n_stages=7, features_per_stage=(32, 64, 128, 256, 512, 512, 512), + conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2, 2), + n_blocks_per_stage=(1, 3, 4, 6, 6, 6, 6), num_classes=3, + n_conv_per_stage_decoder=(1, 1, 1, 1, 1, 1), + conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None, + nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) + print(net.compute_conv_feature_map_size((512, 512))) # -> 129793792 + diff --git a/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py index 52ca938ee..32684bf4e 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py @@ -1,9 +1,14 @@ +import numpy as np +from copy import deepcopy from typing import Union, List, Tuple +from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm from torch import nn from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner -from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet + +from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props class ResEncUNetPlanner(ExperimentPlanner): @@ -14,23 +19,211 @@ def __init__(self, dataset_name_or_id: Union[str, int], suppress_transpose: bool = False): super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, overwrite_target_spacing, suppress_transpose) - - self.UNet_base_num_features = 32 self.UNet_class = ResidualEncoderUNet # the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as # much as possible self.UNet_reference_val_3d = 680000000 self.UNet_reference_val_2d = 135000000 - self.UNet_reference_com_nfeatures = 32 - self.UNet_reference_val_corresp_GB = 8 - self.UNet_reference_val_corresp_bs_2d = 12 - self.UNet_reference_val_corresp_bs_3d = 2 - self.UNet_featuremap_min_edge_length = 4 self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6) self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) - self.UNet_min_batch_size = 2 - self.UNet_max_features_2d = 512 - self.UNet_max_features_3d = 320 + + def generate_data_identifier(self, configuration_name: str) -> str: + """ + configurations are unique within each plans file but different plans file can have configurations with the + same name. 
In order to distinguish the associated data we need a data identifier that reflects not just the + config but also the plans it originates from + """ + if configuration_name == '2d' or configuration_name == '3d_fullres': + # we do not deviate from ExperimentPlanner so we can reuse its data + return 'nnUNetPlans' + '_' + configuration_name + else: + return self.plans_identifier + '_' + configuration_name + + def get_plans_for_configuration(self, + spacing: Union[np.ndarray, Tuple[float, ...], List[float]], + median_shape: Union[np.ndarray, Tuple[int, ...]], + data_identifier: str, + approximate_n_voxels_dataset: float, + _bad_patch_sizes: dict) -> dict: + def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: + return tuple([min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i) for + i in range(num_stages)]) + + def _keygen(patch_size, strides): + return str(patch_size) + '_' + str(strides) + + assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}" + num_input_channels = len(self.dataset_json['channel_names'].keys() + if 'channel_names' in self.dataset_json.keys() + else self.dataset_json['modality'].keys()) + max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d + unet_conv_op = convert_dim_to_conv_op(len(spacing)) + + # print(spacing, median_shape, approximate_n_voxels_dataset) + # find an initial patch size + # we first use the spacing to get an aspect ratio + tmp = 1 / np.array(spacing) + + # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same + # volume as a patch of size 256 ** 3) + # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be + # ideal because large initial patch sizes increase computation time because more iterations in the while loop + # further down may be required. + if len(spacing) == 3: + initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)] + elif len(spacing) == 2: + initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)] + else: + raise RuntimeError() + + # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that + # this is different from how nnU-Net v1 does it! + # todo patch size can still get too large because we pad the patch size to a multiple of 2**n + initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])]) + + # use that to get the network topology. Note that this changes the patch_size depending on the number of + # pooling operations (must be divisible by 2**num_pool in each axis) + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + num_stages = len(pool_op_kernel_sizes) + + norm = get_matching_instancenorm(unet_conv_op) + architecture_kwargs = { + 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, + 'arch_kwargs': { + 'n_stages': num_stages, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'conv_op': unet_conv_op.__module__ + '.' 
+ unet_conv_op.__name__, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + 'conv_bias': True, + 'norm_op': norm.__module__ + '.' + norm.__name__, + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, + 'dropout_op_kwargs': None, + 'nonlin': 'torch.nn.LeakyReLU', + 'nonlin_kwargs': {'inplace': True}, + }, + '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), + } + + # now estimate vram consumption + estimate = self.static_estimate_VRAM_usage(patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + + # how large is the reference for us here (batch size etc)? + # adapt for our vram target + reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \ + (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) + + while estimate > reference: + _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + # print(patch_size) + # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the + # aspect ratio the most (that is the largest relative to median shape) + axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1] + + # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this + # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256. + # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size + # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first + # subtract shape_must_be_divisible_by, then recompute it and then subtract the + # recomputed shape_must_be_divisible_by. Annoying. 
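The comment's example, traced through as code (the divisor values are what `get_pool_and_conv_props` would return for these hypothetical shapes):

```python
# Hypothetical trace of the two-step reduction described above.
size, divisor = 256, 64    # current axis size and its divisibility constraint
tmp = size - divisor       # 192: probe shape used only to recompute the divisor
new_divisor = 32           # what get_pool_and_conv_props yields for 192 here
size -= new_divisor        # 256 - 32 = 224: the next valid size, not skipped
```

Subtracting the original divisor directly would have jumped from 256 to 192 and skipped the valid size 224.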
+ patch_size = list(patch_size) + tmp = deepcopy(patch_size) + tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + _, _, _, _, shape_must_be_divisible_by = \ + get_pool_and_conv_props(spacing, tmp, + self.UNet_featuremap_min_edge_length, + 999999) + patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + + # now recompute topology + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + + num_stages = len(pool_op_kernel_sizes) + architecture_kwargs['arch_kwargs'].update({ + 'n_stages': num_stages, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + }) + if _keygen(patch_size, pool_op_kernel_sizes) in _bad_patch_sizes.keys(): + _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + else: + estimate = self.static_estimate_VRAM_usage( + patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + + # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was + # executed. If not, additional vram headroom is used to increase batch size + ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d + batch_size = round((reference / estimate) * ref_bs) + + # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. 
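In illustrative numbers, the whole batch-size rule reads as follows (all values made up for the example):

```python
# Illustrative numbers only: leftover VRAM headroom scales the batch size up,
# the 5% dataset cap scales it down, and UNet_min_batch_size is the floor.
import numpy as np
reference, estimate, ref_bs = 680_000_000, 560_000_000, 2
batch_size = round((reference / estimate) * ref_bs)           # -> 2
n_voxels_dataset, patch_size = 2.0e9, (160, 128, 112)
bs_cap = round(n_voxels_dataset * 0.05 / np.prod(patch_size, dtype=np.float64))
batch_size = max(min(batch_size, bs_cap), 2)                  # -> 2
```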
We cannot + # go smaller than self.UNet_min_batch_size though + bs_corresponding_to_5_percent = round( + approximate_n_voxels_dataset * 0.05 / np.prod(patch_size, dtype=np.float64)) + batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) + + resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() + resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn() + + normalization_schemes, mask_is_used_for_norm = \ + self.determine_normalization_scheme_and_whether_mask_is_used_for_norm() + + plan = { + 'data_identifier': data_identifier, + 'preprocessor_name': self.preprocessor_name, + 'batch_size': batch_size, + 'patch_size': patch_size, + 'median_image_size_in_voxels': median_shape, + 'spacing': spacing, + 'normalization_schemes': normalization_schemes, + 'use_mask_for_norm': mask_is_used_for_norm, + 'resampling_fn_data': resampling_data.__name__, + 'resampling_fn_seg': resampling_seg.__name__, + 'resampling_fn_data_kwargs': resampling_data_kwargs, + 'resampling_fn_seg_kwargs': resampling_seg_kwargs, + 'resampling_fn_probabilities': resampling_softmax.__name__, + 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs, + 'architecture': architecture_kwargs + } + return plan + + +class ResEncUNetPlanner2(ResEncUNetPlanner): + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 8, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNet2Plans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + self.UNet_class = ResidualEncoderUNet + # this is supposed to give the same GPU memory requirement as the default nnU-Net + self.UNet_reference_val_3d = 600000000 + self.UNet_reference_val_2d = 115000000 + if __name__ == '__main__': diff --git a/nnunetv2/experiment_planning/plan_and_preprocess_api.py b/nnunetv2/experiment_planning/plan_and_preprocess_api.py index 961aafc01..c81e06a46 100644 --- a/nnunetv2/experiment_planning/plan_and_preprocess_api.py +++ b/nnunetv2/experiment_planning/plan_and_preprocess_api.py @@ -51,21 +51,24 @@ def plan_experiment_dataset(dataset_id: int, experiment_planner_class: Type[ExperimentPlanner] = ExperimentPlanner, gpu_memory_target_in_gb: float = 8, preprocess_class_name: str = 'DefaultPreprocessor', overwrite_target_spacing: Optional[Tuple[float, ...]] = None, - overwrite_plans_name: Optional[str] = None) -> dict: + overwrite_plans_name: Optional[str] = None) -> Tuple[dict, str]: """ overwrite_target_spacing ONLY applies to 3d_fullres and 3d_cascade fullres! 
""" kwargs = {} if overwrite_plans_name is not None: kwargs['plans_name'] = overwrite_plans_name - return experiment_planner_class(dataset_id, - gpu_memory_target_in_gb=gpu_memory_target_in_gb, - preprocessor_name=preprocess_class_name, - overwrite_target_spacing=[float(i) for i in overwrite_target_spacing] if - overwrite_target_spacing is not None else overwrite_target_spacing, - suppress_transpose=False, # might expose this later, - **kwargs - ).plan_experiment() + + planner = experiment_planner_class(dataset_id, + gpu_memory_target_in_gb=gpu_memory_target_in_gb, + preprocessor_name=preprocess_class_name, + overwrite_target_spacing=[float(i) for i in overwrite_target_spacing] if + overwrite_target_spacing is not None else overwrite_target_spacing, + suppress_transpose=False, # might expose this later, + **kwargs + ) + ret = planner.plan_experiment() + return ret, planner.plans_identifier def plan_experiments(dataset_ids: List[int], experiment_planner_class_name: str = 'ExperimentPlanner', @@ -78,9 +81,12 @@ def plan_experiments(dataset_ids: List[int], experiment_planner_class_name: str experiment_planner = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"), experiment_planner_class_name, current_module="nnunetv2.experiment_planning") + plans_identifier = None for d in dataset_ids: - plan_experiment_dataset(d, experiment_planner, gpu_memory_target_in_gb, preprocess_class_name, - overwrite_target_spacing, overwrite_plans_name) + _, plans_identifier = plan_experiment_dataset(d, experiment_planner, gpu_memory_target_in_gb, + preprocess_class_name, + overwrite_target_spacing, overwrite_plans_name) + return plans_identifier def preprocess_dataset(dataset_id: int, diff --git a/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py b/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py index 556f04a4f..88a37f024 100644 --- a/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py +++ b/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py @@ -149,7 +149,7 @@ def plan_and_preprocess_entry(): 'know what you are doing and NEVER use this without running the default nnU-Net first ' '(as a baseline). Changing the target spacing for the other configurations is currently ' 'not implemented. New target spacing must be a list of three numbers!') - parser.add_argument('-overwrite_plans_name', default='nnUNetPlans', required=False, + parser.add_argument('-overwrite_plans_name', default=None, required=False, help='[OPTIONAL] uSE A CUSTOM PLANS IDENTIFIER. 
If you used -gpu_memory_target, ' '-preprocessor_name or ' '-overwrite_target_spacing it is best practice to use -overwrite_plans_name to generate a ' @@ -183,7 +183,7 @@ def plan_and_preprocess_entry(): # experiment planning print('Experiment planning...') - plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name, args.overwrite_target_spacing, args.overwrite_plans_name) + plans_identifier = plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name, args.overwrite_target_spacing, args.overwrite_plans_name) # manage default np if args.np is None: @@ -194,7 +194,7 @@ def plan_and_preprocess_entry(): # preprocessing if not args.no_pp: print('Preprocessing...') - preprocess(args.d, args.overwrite_plans_name, args.c, np, args.verbose) + preprocess(args.d, plans_identifier, args.c, np, args.verbose) if __name__ == '__main__': diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py index 756235a62..a233597d4 100644 --- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py +++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -55,7 +55,7 @@ from nnunetv2.utilities.crossval_split import generate_crossval_split from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy -from nnunetv2.utilities.get_network_from_plans import new_get_network +from nnunetv2.utilities.get_network_from_plans import get_network_from_plans from nnunetv2.utilities.helpers import empty_cache, dummy_context from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager @@ -293,7 +293,7 @@ def build_network_architecture(architecture_class_name: str, should be generated. label_manager takes care of all that for you.) 
""" - return new_get_network( + return get_network_from_plans( architecture_class_name, arch_init_kwargs, arch_init_kwargs_req_import, diff --git a/nnunetv2/utilities/get_network_from_plans.py b/nnunetv2/utilities/get_network_from_plans.py index 25a8471e4..8d10cb4a3 100644 --- a/nnunetv2/utilities/get_network_from_plans.py +++ b/nnunetv2/utilities/get_network_from_plans.py @@ -1,17 +1,9 @@ import pydoc from typing import Union -from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet -from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op -from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 -from torch import nn -from nnunetv2.utilities.network_initialization import InitWeights_He -from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager - - -def new_get_network(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels, - allow_init=True, deep_supervision: Union[bool, None] = None): +def get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels, + allow_init=True, deep_supervision: Union[bool, None] = None): network_class = arch_class_name architecture_kwargs = dict(**arch_kwargs) for ri in arch_kwargs_req_import: @@ -32,75 +24,4 @@ def new_get_network(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_ if hasattr(network, 'initialize') and allow_init: network.apply(network.initialize) - return network - - -def get_network_from_plans(plans_manager: PlansManager, - dataset_json: dict, - configuration_manager: ConfigurationManager, - num_input_channels: int, - deep_supervision: bool = True): - """ - we may have to change this in the future to accommodate other plans -> network mappings - - num_input_channels can differ depending on whether we do cascade. Its best to make this info available in the - trainer rather than inferring it again from the plans here. - """ - num_stages = len(configuration_manager.conv_kernel_sizes) - - dim = len(configuration_manager.conv_kernel_sizes[0]) - conv_op = convert_dim_to_conv_op(dim) - - label_manager = plans_manager.get_label_manager(dataset_json) - - segmentation_network_class_name = configuration_manager.network_arch_class_name - mapping = { - 'PlainConvUNet': PlainConvUNet, - 'ResidualEncoderUNet': ResidualEncoderUNet - } - kwargs = { - 'PlainConvUNet': { - 'conv_bias': True, - 'norm_op': get_matching_instancenorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - }, - 'ResidualEncoderUNet': { - 'conv_bias': True, - 'norm_op': get_matching_instancenorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - } - } - assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ - 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ - 'into either this ' \ - 'function (get_network_from_plans) or ' \ - 'the init of your nnUNetModule to accommodate that.' 
- network_class = mapping[segmentation_network_class_name] - - conv_or_blocks_per_stage = { - 'n_conv_per_stage' - if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': configuration_manager.n_conv_per_stage_encoder, - 'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder - } - # network class name!! - model = network_class( - input_channels=num_input_channels, - n_stages=num_stages, - features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i, - configuration_manager.unet_max_num_features) for i in range(num_stages)], - conv_op=conv_op, - kernel_sizes=configuration_manager.conv_kernel_sizes, - strides=configuration_manager.pool_op_kernel_sizes, - num_classes=label_manager.num_segmentation_heads, - deep_supervision=deep_supervision, - **conv_or_blocks_per_stage, - **kwargs[segmentation_network_class_name] - ) - model.apply(InitWeights_He(1e-2)) - if network_class == ResidualEncoderUNet: - model.apply(init_last_bn_before_add_to_0) - return model + return network \ No newline at end of file From d405232bef589e4694492103bd4cd4627ba94a41 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Tue, 23 Jan 2024 17:34:46 +0100 Subject: [PATCH 03/24] bugfix --- .../batch_running/collect_results_custom_Decathlon.py | 2 +- .../experiment_planners/default_experiment_planner.py | 9 +++++---- .../experiment_planners/resUNet_planner.py | 9 +++++---- .../experiment_planners/resUNet_planner3.py | 9 +++++---- .../experiment_planners/resencUNetBottleneck_planner.py | 9 +++++---- .../experiment_planners/resencUNet_planner.py | 9 +++++---- 6 files changed, 26 insertions(+), 21 deletions(-) diff --git a/nnunetv2/batch_running/collect_results_custom_Decathlon.py b/nnunetv2/batch_running/collect_results_custom_Decathlon.py index b670661c5..d5d08cf24 100644 --- a/nnunetv2/batch_running/collect_results_custom_Decathlon.py +++ b/nnunetv2/batch_running/collect_results_custom_Decathlon.py @@ -99,7 +99,7 @@ def summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[st 'nnUNetTrainer_DASegOrd0': ('nnUNetPlans',), } all_results_file= join(nnUNet_results, 'customDecResults.csv') - datasets = [2, 3, 4, 17, 20, 24, 27, 38, 55, 64, 82] + datasets = [2, 3, 4, 17, 24, 27, 38, 55, 137, 217, 221] # amos post challenge, kits2023 collect_results(use_these_trainers, datasets, all_results_file) folds = (0, 1, 2, 3, 4) diff --git a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py index 06b52962b..eb61c92fd 100644 --- a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py @@ -226,7 +226,7 @@ def get_plans_for_configuration(self, median_shape: Union[np.ndarray, Tuple[int, ...]], data_identifier: str, approximate_n_voxels_dataset: float, - _bad_patch_sizes: dict) -> dict: + _cache: dict) -> dict: def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: return tuple([min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i) for i in range(num_stages)]) @@ -308,7 +308,7 @@ def _keygen(patch_size, strides): (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) while estimate > reference: - _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # print(patch_size) # patch size seems to be too large, so we need to reduce it. 
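The selection rule spelled out in the comment that follows, as a worked example:

```python
# Worked example: shrink the axis that is largest relative to the median shape.
import numpy as np
patch_size   = [192, 160, 160]
median_shape = [128, 512, 512]
axis_to_be_reduced = np.argsort(
    [p / m for p, m in zip(patch_size, median_shape)])[-1]  # -> 0 (ratio 1.5)
```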
Reduce the axis that currently violates the # aspect ratio the most (that is the largest relative to median shape) @@ -344,8 +344,8 @@ def _keygen(patch_size, strides): 'n_conv_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], }) - if _keygen(patch_size, pool_op_kernel_sizes) in _bad_patch_sizes.keys(): - _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] else: estimate = self.static_estimate_VRAM_usage( patch_size, @@ -355,6 +355,7 @@ def _keygen(patch_size, strides): architecture_kwargs['arch_kwargs'], architecture_kwargs['_kw_requires_import'], ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was # executed. If not, additional vram headroom is used to increase batch size diff --git a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/resUNet_planner.py index 42042f44a..f282acc9b 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/resUNet_planner.py @@ -43,7 +43,7 @@ def get_plans_for_configuration(self, median_shape: Union[np.ndarray, Tuple[int, ...]], data_identifier: str, approximate_n_voxels_dataset: float, - _bad_patch_sizes: dict) -> dict: + _cache: dict) -> dict: def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: return tuple([min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i) for i in range(num_stages)]) @@ -125,7 +125,7 @@ def _keygen(patch_size, strides): (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) while estimate > reference: - _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # print(patch_size) # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the # aspect ratio the most (that is the largest relative to median shape) @@ -161,8 +161,8 @@ def _keygen(patch_size, strides): 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], }) - if _keygen(patch_size, pool_op_kernel_sizes) in _bad_patch_sizes.keys(): - _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] else: estimate = self.static_estimate_VRAM_usage( patch_size, @@ -172,6 +172,7 @@ def _keygen(patch_size, strides): architecture_kwargs['arch_kwargs'], architecture_kwargs['_kw_requires_import'], ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was # executed. 
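The substance of this bugfix, isolated: on a cache hit the estimate is now read back, where the previous revision overwrote the cached entry with a stale value and left `estimate` unchanged. In minimal form (hypothetical values):

```python
def compute():                 # placeholder for static_estimate_VRAM_usage
    return 5.6e8

cache, key, estimate = {'k': 4.2e8}, 'k', 9.9e8  # 9.9e8 is stale from the last loop
# before the fix: `cache[key] = estimate` left `estimate` at its stale value
estimate = cache[key] if key in cache else compute()  # after the fix: -> 4.2e8
```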
If not, additional vram headroom is used to increase batch size diff --git a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner3.py b/nnunetv2/experiment_planning/experiment_planners/resUNet_planner3.py index 89018c29f..4b1fd110d 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner3.py +++ b/nnunetv2/experiment_planning/experiment_planners/resUNet_planner3.py @@ -25,7 +25,7 @@ def get_plans_for_configuration(self, median_shape: Union[np.ndarray, Tuple[int, ...]], data_identifier: str, approximate_n_voxels_dataset: float, - _bad_patch_sizes: dict) -> dict: + _cache: dict) -> dict: def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: return tuple([min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i) for i in range(num_stages)]) @@ -107,7 +107,7 @@ def _keygen(patch_size, strides): (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) while estimate > reference: - _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # print(patch_size) # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the # aspect ratio the most (that is the largest relative to median shape) @@ -143,8 +143,8 @@ def _keygen(patch_size, strides): 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_encoder[:num_stages - 1][::-1], }) - if _keygen(patch_size, pool_op_kernel_sizes) in _bad_patch_sizes.keys(): - _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] else: estimate = self.static_estimate_VRAM_usage( patch_size, @@ -154,6 +154,7 @@ def _keygen(patch_size, strides): architecture_kwargs['arch_kwargs'], architecture_kwargs['_kw_requires_import'], ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was # executed. If not, additional vram headroom is used to increase batch size diff --git a/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py b/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py index 1f167237b..1bd8ef477 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py @@ -25,7 +25,7 @@ def get_plans_for_configuration(self, median_shape: Union[np.ndarray, Tuple[int, ...]], data_identifier: str, approximate_n_voxels_dataset: float, - _bad_patch_sizes: dict) -> dict: + _cache: dict) -> dict: def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: return tuple([min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i) for i in range(num_stages)]) @@ -109,7 +109,7 @@ def _keygen(patch_size, strides): (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) while estimate > reference: - _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # print(patch_size) # patch size seems to be too large, so we need to reduce it. 
Reduce the axis that currently violates the # aspect ratio the most (that is the largest relative to median shape) @@ -146,8 +146,8 @@ def _keygen(patch_size, strides): 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], 'bottleneck_channels': [i // 4 for i in _features_per_stage(num_stages, max_num_features)] }) - if _keygen(patch_size, pool_op_kernel_sizes) in _bad_patch_sizes.keys(): - _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] else: estimate = self.static_estimate_VRAM_usage( patch_size, @@ -157,6 +157,7 @@ def _keygen(patch_size, strides): architecture_kwargs['arch_kwargs'], architecture_kwargs['_kw_requires_import'], ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was # executed. If not, additional vram headroom is used to increase batch size diff --git a/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py index 32684bf4e..cd27225b4 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py @@ -44,7 +44,7 @@ def get_plans_for_configuration(self, median_shape: Union[np.ndarray, Tuple[int, ...]], data_identifier: str, approximate_n_voxels_dataset: float, - _bad_patch_sizes: dict) -> dict: + _cache: dict) -> dict: def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: return tuple([min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i) for i in range(num_stages)]) @@ -126,7 +126,7 @@ def _keygen(patch_size, strides): (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) while estimate > reference: - _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # print(patch_size) # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the # aspect ratio the most (that is the largest relative to median shape) @@ -162,8 +162,8 @@ def _keygen(patch_size, strides): 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], }) - if _keygen(patch_size, pool_op_kernel_sizes) in _bad_patch_sizes.keys(): - _bad_patch_sizes[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] else: estimate = self.static_estimate_VRAM_usage( patch_size, @@ -173,6 +173,7 @@ def _keygen(patch_size, strides): architecture_kwargs['arch_kwargs'], architecture_kwargs['_kw_requires_import'], ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was # executed. 
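Since the same fix now appears in five near-identical copies of `get_plans_for_configuration`, a natural follow-up (a sketch only, not part of this patch series) would be hoisting the shared helpers into the base planner:

```python
# Sketch: the helpers duplicated across the planners, hoisted to a base class.
class PlannerHelpersSketch:
    UNet_reference_com_nfeatures = 32

    def _features_per_stage(self, num_stages: int, max_num_features: int):
        return tuple(min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i)
                     for i in range(num_stages))

    @staticmethod
    def _keygen(patch_size, strides) -> str:
        return f"{patch_size}_{strides}"
```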
If not, additional vram headroom is used to increase batch size From 52f9f04674f8eb27596956a03b078e83655bfeef Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Tue, 23 Jan 2024 17:49:00 +0100 Subject: [PATCH 04/24] speed improvement --- .../default_experiment_planner.py | 68 ++++++++++--------- .../experiment_planners/resUNet_planner.py | 59 ++++++++-------- .../experiment_planners/resUNet_planner3.py | 59 ++++++++-------- .../resencUNetBottleneck_planner.py | 64 ++++++++--------- .../experiment_planners/resencUNet_planner.py | 61 +++++++++-------- 5 files changed, 163 insertions(+), 148 deletions(-) diff --git a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py index eb61c92fd..90287ac18 100644 --- a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py @@ -100,7 +100,9 @@ def static_estimate_VRAM_usage(patch_size: Tuple[int], """ a = torch.get_num_threads() torch.set_num_threads(get_allowed_n_proc_DA()) - net = get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels, + print(f'instantiating network, patch size {patch_size}, pool op: {arch_kwargs["strides"]}') + net = get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, + output_channels, allow_init=False) ret = net.compute_conv_feature_map_size(patch_size) torch.set_num_threads(a) @@ -273,34 +275,38 @@ def _keygen(patch_size, strides): norm = get_matching_instancenorm(unet_conv_op) architecture_kwargs = { - 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, - 'arch_kwargs': { - 'n_stages': num_stages, - 'features_per_stage': _features_per_stage(num_stages, max_num_features), - 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, - 'kernel_sizes': conv_kernel_sizes, - 'strides': pool_op_kernel_sizes, - 'n_conv_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], - 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], - 'conv_bias': True, - 'norm_op': norm.__module__ + '.' + norm.__name__, - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, - 'dropout_op_kwargs': None, - 'nonlin': 'torch.nn.LeakyReLU', - 'nonlin_kwargs': {'inplace': True}, - }, - '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), - } + 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, + 'arch_kwargs': { + 'n_stages': num_stages, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'n_conv_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + 'conv_bias': True, + 'norm_op': norm.__module__ + '.' 
+ norm.__name__, + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, + 'dropout_op_kwargs': None, + 'nonlin': 'torch.nn.LeakyReLU', + 'nonlin_kwargs': {'inplace': True}, + }, + '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), + } # now estimate vram consumption - estimate = self.static_estimate_VRAM_usage(patch_size, - num_input_channels, - len(self.dataset_json['labels'].keys()), - architecture_kwargs['network_class_name'], - architecture_kwargs['arch_kwargs'], - architecture_kwargs['_kw_requires_import'], - ) + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] + else: + estimate = self.static_estimate_VRAM_usage(patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # how large is the reference for us here (batch size etc)? # adapt for our vram target @@ -308,7 +314,6 @@ def _keygen(patch_size, strides): (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) while estimate > reference: - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # print(patch_size) # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the # aspect ratio the most (that is the largest relative to median shape) @@ -355,7 +360,7 @@ def _keygen(patch_size, strides): architecture_kwargs['arch_kwargs'], architecture_kwargs['_kw_requires_import'], ) - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was # executed. If not, additional vram headroom is used to increase batch size @@ -450,7 +455,7 @@ def plan_experiment(self): # print(lowres_spacing) plan_3d_lowres = self.get_plans_for_configuration(lowres_spacing, tuple([round(i) for i in plan_3d_fullres['spacing'] / - lowres_spacing * new_median_shape_transposed]), + lowres_spacing * new_median_shape_transposed]), self.generate_data_identifier('3d_lowres'), float(np.prod(median_num_voxels) * self.dataset_json['numTraining']), _tmp) @@ -476,7 +481,8 @@ def plan_experiment(self): # 2D configuration plan_2d = self.get_plans_for_configuration(fullres_spacing_transposed[1:], new_median_shape_transposed[1:], - self.generate_data_identifier('2d'), approximate_n_voxels_dataset, _tmp) + self.generate_data_identifier('2d'), approximate_n_voxels_dataset, + _tmp) plan_2d['batch_dice'] = True print('2D U-Net configuration:') diff --git a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/resUNet_planner.py index f282acc9b..a26bcc70f 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/resUNet_planner.py @@ -90,34 +90,38 @@ def _keygen(patch_size, strides): norm = get_matching_instancenorm(unet_conv_op) architecture_kwargs = { - 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, - 'arch_kwargs': { - 'n_stages': num_stages, - 'features_per_stage': _features_per_stage(num_stages, max_num_features), - 'conv_op': unet_conv_op.__module__ + '.' 
+ unet_conv_op.__name__, - 'kernel_sizes': conv_kernel_sizes, - 'strides': pool_op_kernel_sizes, - 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], - 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], - 'conv_bias': True, - 'norm_op': norm.__module__ + '.' + norm.__name__, - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, - 'dropout_op_kwargs': None, - 'nonlin': 'torch.nn.LeakyReLU', - 'nonlin_kwargs': {'inplace': True}, - }, - '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), - } + 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, + 'arch_kwargs': { + 'n_stages': num_stages, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + 'conv_bias': True, + 'norm_op': norm.__module__ + '.' + norm.__name__, + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, + 'dropout_op_kwargs': None, + 'nonlin': 'torch.nn.LeakyReLU', + 'nonlin_kwargs': {'inplace': True}, + }, + '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), + } # now estimate vram consumption - estimate = self.static_estimate_VRAM_usage(patch_size, - num_input_channels, - len(self.dataset_json['labels'].keys()), - architecture_kwargs['network_class_name'], - architecture_kwargs['arch_kwargs'], - architecture_kwargs['_kw_requires_import'], - ) + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] + else: + estimate = self.static_estimate_VRAM_usage(patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # how large is the reference for us here (batch size etc)? # adapt for our vram target @@ -125,7 +129,6 @@ def _keygen(patch_size, strides): (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) while estimate > reference: - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # print(patch_size) # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the # aspect ratio the most (that is the largest relative to median shape) @@ -172,7 +175,7 @@ def _keygen(patch_size, strides): architecture_kwargs['arch_kwargs'], architecture_kwargs['_kw_requires_import'], ) - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was # executed. 
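`static_estimate_VRAM_usage` in this patch brackets the dry-run instantiation with `torch.set_num_threads`, saving and restoring the previous value by hand; the same bracketing can be written as a context manager, sketched here:

```python
# Sketch: save/restore torch's thread count around a CPU-heavy dry run.
from contextlib import contextmanager
import torch

@contextmanager
def limited_torch_threads(n: int):
    previous = torch.get_num_threads()
    torch.set_num_threads(n)
    try:
        yield
    finally:
        torch.set_num_threads(previous)  # restored even if the dry run raises
```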
If not, additional vram headroom is used to increase batch size diff --git a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner3.py b/nnunetv2/experiment_planning/experiment_planners/resUNet_planner3.py index 4b1fd110d..4c70c34ba 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner3.py +++ b/nnunetv2/experiment_planning/experiment_planners/resUNet_planner3.py @@ -72,34 +72,38 @@ def _keygen(patch_size, strides): norm = get_matching_instancenorm(unet_conv_op) architecture_kwargs = { - 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, - 'arch_kwargs': { - 'n_stages': num_stages, - 'features_per_stage': _features_per_stage(num_stages, max_num_features), - 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, - 'kernel_sizes': conv_kernel_sizes, - 'strides': pool_op_kernel_sizes, - 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], - 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_encoder[:num_stages - 1][::-1], - 'conv_bias': True, - 'norm_op': norm.__module__ + '.' + norm.__name__, - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, - 'dropout_op_kwargs': None, - 'nonlin': 'torch.nn.LeakyReLU', - 'nonlin_kwargs': {'inplace': True}, - }, - '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), - } + 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, + 'arch_kwargs': { + 'n_stages': num_stages, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_encoder[:num_stages - 1][::-1], + 'conv_bias': True, + 'norm_op': norm.__module__ + '.' + norm.__name__, + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, + 'dropout_op_kwargs': None, + 'nonlin': 'torch.nn.LeakyReLU', + 'nonlin_kwargs': {'inplace': True}, + }, + '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), + } # now estimate vram consumption - estimate = self.static_estimate_VRAM_usage(patch_size, - num_input_channels, - len(self.dataset_json['labels'].keys()), - architecture_kwargs['network_class_name'], - architecture_kwargs['arch_kwargs'], - architecture_kwargs['_kw_requires_import'], - ) + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] + else: + estimate = self.static_estimate_VRAM_usage(patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # how large is the reference for us here (batch size etc)? # adapt for our vram target @@ -107,7 +111,6 @@ def _keygen(patch_size, strides): (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) while estimate > reference: - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # print(patch_size) # patch size seems to be too large, so we need to reduce it. 
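Stepping back to where this loop's starting point comes from: every planner in the series seeds its patch size from the target spacing, rescaled to the volume of a 256³ patch. Worked numbers for an anisotropic spacing:

```python
# Worked example of the initial patch-size seed for spacing (3, 1, 1) mm.
import numpy as np
spacing = np.array([3.0, 1.0, 1.0])
tmp = 1 / spacing                                        # voxel aspect ratio
initial = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)]
# -> [123, 369, 369]: isotropic in millimetres, anisotropic in voxels
```

This seed is then clipped to the median shape before the topology search begins.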
Reduce the axis that currently violates the # aspect ratio the most (that is the largest relative to median shape) @@ -154,7 +157,7 @@ def _keygen(patch_size, strides): architecture_kwargs['arch_kwargs'], architecture_kwargs['_kw_requires_import'], ) - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was # executed. If not, additional vram headroom is used to increase batch size diff --git a/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py b/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py index 1bd8ef477..b278e69d9 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py @@ -72,36 +72,40 @@ def _keygen(patch_size, strides): norm = get_matching_instancenorm(unet_conv_op) architecture_kwargs = { - 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, - 'arch_kwargs': { - 'n_stages': num_stages, - 'features_per_stage': _features_per_stage(num_stages, max_num_features), - 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, - 'kernel_sizes': conv_kernel_sizes, - 'strides': pool_op_kernel_sizes, - 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], - 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], - 'conv_bias': True, - 'norm_op': norm.__module__ + '.' + norm.__name__, - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, - 'dropout_op_kwargs': None, - 'nonlin': 'torch.nn.LeakyReLU', - 'nonlin_kwargs': {'inplace': True}, - 'block': BottleneckD.__module__ + '.' + BottleneckD.__name__, - 'bottleneck_channels': [i // 4 for i in _features_per_stage(num_stages, max_num_features)] - }, - '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin', 'block'), - } + 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, + 'arch_kwargs': { + 'n_stages': num_stages, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + 'conv_bias': True, + 'norm_op': norm.__module__ + '.' + norm.__name__, + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, + 'dropout_op_kwargs': None, + 'nonlin': 'torch.nn.LeakyReLU', + 'nonlin_kwargs': {'inplace': True}, + 'block': BottleneckD.__module__ + '.' 
+ BottleneckD.__name__, + 'bottleneck_channels': [i // 4 for i in _features_per_stage(num_stages, max_num_features)] + }, + '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin', 'block'), + } # now estimate vram consumption - estimate = self.static_estimate_VRAM_usage(patch_size, - num_input_channels, - len(self.dataset_json['labels'].keys()), - architecture_kwargs['network_class_name'], - architecture_kwargs['arch_kwargs'], - architecture_kwargs['_kw_requires_import'], - ) + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] + else: + estimate = self.static_estimate_VRAM_usage(patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # how large is the reference for us here (batch size etc)? # adapt for our vram target @@ -109,7 +113,6 @@ def _keygen(patch_size, strides): (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) while estimate > reference: - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # print(patch_size) # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the # aspect ratio the most (that is the largest relative to median shape) @@ -157,7 +160,7 @@ def _keygen(patch_size, strides): architecture_kwargs['arch_kwargs'], architecture_kwargs['_kw_requires_import'], ) - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was # executed. If not, additional vram headroom is used to increase batch size @@ -214,4 +217,3 @@ def _keygen(patch_size, strides): conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None, nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) print(net.compute_conv_feature_map_size((512, 512))) # -> 129793792 - diff --git a/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py index cd27225b4..29a539333 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py @@ -91,34 +91,38 @@ def _keygen(patch_size, strides): norm = get_matching_instancenorm(unet_conv_op) architecture_kwargs = { - 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, - 'arch_kwargs': { - 'n_stages': num_stages, - 'features_per_stage': _features_per_stage(num_stages, max_num_features), - 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, - 'kernel_sizes': conv_kernel_sizes, - 'strides': pool_op_kernel_sizes, - 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], - 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], - 'conv_bias': True, - 'norm_op': norm.__module__ + '.' + norm.__name__, - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, - 'dropout_op_kwargs': None, - 'nonlin': 'torch.nn.LeakyReLU', - 'nonlin_kwargs': {'inplace': True}, - }, - '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), - } + 'network_class_name': self.UNet_class.__module__ + '.' 
+ self.UNet_class.__name__, + 'arch_kwargs': { + 'n_stages': num_stages, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + 'conv_bias': True, + 'norm_op': norm.__module__ + '.' + norm.__name__, + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, + 'dropout_op_kwargs': None, + 'nonlin': 'torch.nn.LeakyReLU', + 'nonlin_kwargs': {'inplace': True}, + }, + '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), + } # now estimate vram consumption - estimate = self.static_estimate_VRAM_usage(patch_size, - num_input_channels, - len(self.dataset_json['labels'].keys()), - architecture_kwargs['network_class_name'], - architecture_kwargs['arch_kwargs'], - architecture_kwargs['_kw_requires_import'], - ) + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] + else: + estimate = self.static_estimate_VRAM_usage(patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # how large is the reference for us here (batch size etc)? # adapt for our vram target @@ -126,7 +130,6 @@ def _keygen(patch_size, strides): (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) while estimate > reference: - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # print(patch_size) # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the # aspect ratio the most (that is the largest relative to median shape) @@ -173,7 +176,7 @@ def _keygen(patch_size, strides): architecture_kwargs['arch_kwargs'], architecture_kwargs['_kw_requires_import'], ) - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was # executed. 
If not, additional vram headroom is used to increase batch size @@ -226,7 +229,6 @@ def __init__(self, dataset_name_or_id: Union[str, int], self.UNet_reference_val_2d = 115000000 - if __name__ == '__main__': # we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320), @@ -245,4 +247,3 @@ def __init__(self, dataset_name_or_id: Union[str, int], conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None, nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) print(net.compute_conv_feature_map_size((512, 512))) # -> 129793792 - From 4f0fbd82f5c76d7adac0fd0065a9c263586a23cf Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Tue, 23 Jan 2024 17:56:24 +0100 Subject: [PATCH 05/24] remove print --- .../experiment_planners/default_experiment_planner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py index 90287ac18..92ef7dfd2 100644 --- a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py @@ -100,7 +100,7 @@ def static_estimate_VRAM_usage(patch_size: Tuple[int], """ a = torch.get_num_threads() torch.set_num_threads(get_allowed_n_proc_DA()) - print(f'instantiating network, patch size {patch_size}, pool op: {arch_kwargs["strides"]}') + # print(f'instantiating network, patch size {patch_size}, pool op: {arch_kwargs["strides"]}') net = get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels, allow_init=False) From 4802f44e045cc544aa862075bf382c764d58b303 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Tue, 23 Jan 2024 18:03:29 +0100 Subject: [PATCH 06/24] ResEncUNetBottleneckDeeperPlanner --- .../batch_running/collect_results_custom_Decathlon.py | 4 +--- .../resencUNetBottleneck_planner.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/nnunetv2/batch_running/collect_results_custom_Decathlon.py b/nnunetv2/batch_running/collect_results_custom_Decathlon.py index d5d08cf24..f52cb6e3d 100644 --- a/nnunetv2/batch_running/collect_results_custom_Decathlon.py +++ b/nnunetv2/batch_running/collect_results_custom_Decathlon.py @@ -94,9 +94,7 @@ def summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[st if __name__ == '__main__': use_these_trainers = { - 'nnUNetTrainer': ('nnUNetPlans',), - 'nnUNetTrainerDiceCELoss_noSmooth': ('nnUNetPlans',), - 'nnUNetTrainer_DASegOrd0': ('nnUNetPlans',), + 'nnUNetTrainer': ('nnUNetPlans', 'nnUNetResEncUNetPlans', 'nnUNetResEncUNet2Plans', 'nnUNetResBottleneckEncUNetPlans', 'nnUNetResUNetPlans', 'nnUNetResUNet2Plans', 'nnUNetResUNet3Plans', 'nnUNetDeeperResBottleneckEncUNetPlans'), } all_results_file= join(nnUNet_results, 'customDecResults.csv') datasets = [2, 3, 4, 17, 24, 27, 38, 55, 137, 217, 221] # amos post challenge, kits2023 diff --git a/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py b/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py index b278e69d9..1d04c0aa0 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py @@ 
-198,6 +198,17 @@ def _keygen(patch_size, strides): } return plan +class ResEncUNetBottleneckDeeperPlanner(ResEncUNetBottleneckPlanner): + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 8, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetDeeperResBottleneckEncUNetPlans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + self.UNet_blocks_per_stage_encoder = (1, 3, 6, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9) + self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) + if __name__ == '__main__': # we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively From b243163df89829d80e71597194191679985e46ad Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Tue, 23 Jan 2024 18:05:13 +0100 Subject: [PATCH 07/24] ResEncUNetBottleneckDeeperPlanner --- .../experiment_planners/resencUNetBottleneck_planner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py b/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py index 1d04c0aa0..d48ebfec6 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py @@ -206,7 +206,7 @@ def __init__(self, dataset_name_or_id: Union[str, int], suppress_transpose: bool = False): super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, overwrite_target_spacing, suppress_transpose) - self.UNet_blocks_per_stage_encoder = (1, 3, 6, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9) + self.UNet_blocks_per_stage_encoder = (2, 3, 6, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9) self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) From eed0e02814c2c399d91b35bf0b792045634bd4e7 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Wed, 24 Jan 2024 07:18:25 +0100 Subject: [PATCH 08/24] update experiments --- .../generate_lsf_runs_customDecathlon.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py b/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py index 0a75fbd4d..cb4805d7a 100644 --- a/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py +++ b/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py @@ -25,14 +25,15 @@ def merge(dict1, dict2): 3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 4: ("2d", "3d_fullres"), 17: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), - 20: ("2d", "3d_fullres"), 24: ("2d", "3d_fullres"), 27: ("2d", "3d_fullres"), 38: ("2d", "3d_fullres"), 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), - 64: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), - 82: ("2d", "3d_fullres"), - # 83: ("2d", "3d_fullres"), + 137: ("2d", "3d_fullres"), + 217: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + 221: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # kits 2023 + # amos post challenge } configurations_3d_fr_only = { @@ -52,22 +53,20 @@ def merge(dict1, dict2): } num_gpus = 1 - exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\" -R \"select[hname!='e230-dgx1-1']\" -R 
\"select[hname!='e230-dgxa100-1']\" -R \"select[hname!='e230-dgxa100-2']\" -R \"select[hname!='e230-dgxa100-3']\" -R \"select[hname!='e230-dgxa100-4']\"" - resources = "-R \"tensorcore\"" + exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\" -R \"select[hname!='e230-dgx1-1']\"" + resources = "" gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmem=33G" - queue = "-q gpu-lowprio" - preamble = "-L /bin/bash \"source ~/load_env_cluster4.sh && " - train_command = 'nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/nnUNet_results_remake_release nnUNetv2_train' + queue = "-q test.dgx" + preamble = "-L /bin/bash \"source ~/load_env_mamba_slumber.sh && " + train_command = 'nnUNetv2_train' folds = (0, ) # use_this = configurations_2d_only - use_this = merge(configurations_3d_fr_only, configurations_3d_lr_only) + use_this = configurations_3d_fr_only # use_this = merge(use_this, configurations_3d_c_only) use_these_modules = { - 'nnUNetTrainer': ('nnUNetPlans',), - 'nnUNetTrainerDiceCELoss_noSmooth': ('nnUNetPlans',), - # 'nnUNetTrainer_DASegOrd0': ('nnUNetPlans',), + 'nnUNetTrainer': ('nnUNetPlans', 'nnUNetResEncUNetPlans', 'nnUNetResEncUNet2Plans', 'nnUNetResBottleneckEncUNetPlans', 'nnUNetResUNetPlans', 'nnUNetResUNet2Plans', 'nnUNetResUNet3Plans', 'nnUNetDeeperResBottleneckEncUNetPlans'), } additional_arguments = f'--disable_checkpointing -num_gpus {num_gpus}' # '' From 84fd95c38743530ef10d422548288348c528cf83 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Thu, 25 Jan 2024 07:34:37 +0100 Subject: [PATCH 09/24] update experiments --- nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py b/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py index cb4805d7a..4f98521fb 100644 --- a/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py +++ b/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py @@ -30,10 +30,9 @@ def merge(dict1, dict2): 38: ("2d", "3d_fullres"), 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 137: ("2d", "3d_fullres"), - 217: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + 220: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 221: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), - # kits 2023 - # amos post challenge + 223: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), } configurations_3d_fr_only = { From 8f63c4f5f83c3766301c63fd8a288a653fcd6d1f Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Thu, 25 Jan 2024 08:52:22 +0100 Subject: [PATCH 10/24] update experiments --- nnunetv2/batch_running/collect_results_custom_Decathlon.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nnunetv2/batch_running/collect_results_custom_Decathlon.py b/nnunetv2/batch_running/collect_results_custom_Decathlon.py index f52cb6e3d..77e7dfb35 100644 --- a/nnunetv2/batch_running/collect_results_custom_Decathlon.py +++ b/nnunetv2/batch_running/collect_results_custom_Decathlon.py @@ -97,16 +97,16 @@ def summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[st 'nnUNetTrainer': ('nnUNetPlans', 'nnUNetResEncUNetPlans', 'nnUNetResEncUNet2Plans', 'nnUNetResBottleneckEncUNetPlans', 'nnUNetResUNetPlans', 'nnUNetResUNet2Plans', 'nnUNetResUNet3Plans', 'nnUNetDeeperResBottleneckEncUNetPlans'), } all_results_file= join(nnUNet_results, 'customDecResults.csv') - datasets = [2, 3, 4, 17, 24, 27, 38, 55, 137, 217, 221] 
# amos post challenge, kits2023 + datasets = [2, 3, 4, 17, 24, 27, 38, 55, 137, 217, 220, 221, 223] # amos post challenge, kits2023 collect_results(use_these_trainers, datasets, all_results_file) folds = (0, 1, 2, 3, 4) - configs = ("3d_fullres", "3d_lowres") + configs = ("3d_fullres", ) output_file = join(nnUNet_results, 'customDecResults_summary5fold.csv') summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) folds = (0, ) - configs = ("3d_fullres", "3d_lowres") + configs = ("3d_fullres", ) output_file = join(nnUNet_results, 'customDecResults_summaryfold0.csv') summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) From 83b938e2d71f4606a3aa12c21604ef78b4196897 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Tue, 30 Jan 2024 13:27:27 +0100 Subject: [PATCH 11/24] move residual unet planners in a new folder --- .../ResEncUNetBottleneck_planner.py} | 2 +- .../ResEncUNet_planner.py} | 0 .../{resUNet_planner.py => residual_unets/ResUNet_planner.py} | 0 .../{resUNet_planner2.py => residual_unets/ResUNet_planner2.py} | 2 +- .../{resUNet_planner3.py => residual_unets/ResUNet_planner3.py} | 2 +- .../experiment_planners/residual_unets/__init__.py | 0 6 files changed, 3 insertions(+), 3 deletions(-) rename nnunetv2/experiment_planning/experiment_planners/{resencUNetBottleneck_planner.py => residual_unets/ResEncUNetBottleneck_planner.py} (99%) rename nnunetv2/experiment_planning/experiment_planners/{resencUNet_planner.py => residual_unets/ResEncUNet_planner.py} (100%) rename nnunetv2/experiment_planning/experiment_planners/{resUNet_planner.py => residual_unets/ResUNet_planner.py} (100%) rename nnunetv2/experiment_planning/experiment_planners/{resUNet_planner2.py => residual_unets/ResUNet_planner2.py} (87%) rename nnunetv2/experiment_planning/experiment_planners/{resUNet_planner3.py => residual_unets/ResUNet_planner3.py} (99%) create mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/__init__.py diff --git a/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNetBottleneck_planner.py similarity index 99% rename from nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py rename to nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNetBottleneck_planner.py index d48ebfec6..0cfed5f88 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resencUNetBottleneck_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNetBottleneck_planner.py @@ -8,7 +8,7 @@ from torch import nn from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props -from nnunetv2.experiment_planning.experiment_planners.resencUNet_planner import ResEncUNetPlanner +from nnunetv2.experiment_planning.experiment_planners.residual_unets.ResEncUNet_planner import ResEncUNetPlanner class ResEncUNetBottleneckPlanner(ResEncUNetPlanner): diff --git a/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py similarity index 100% rename from nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py rename to nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py diff --git a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner.py
b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner.py similarity index 100% rename from nnunetv2/experiment_planning/experiment_planners/resUNet_planner.py rename to nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner.py diff --git a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner2.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner2.py similarity index 87% rename from nnunetv2/experiment_planning/experiment_planners/resUNet_planner2.py rename to nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner2.py index 8cffbae77..9806dbdf7 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner2.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner2.py @@ -1,6 +1,6 @@ from typing import Union, List, Tuple -from nnunetv2.experiment_planning.experiment_planners.resUNet_planner import ResUNetPlanner +from nnunetv2.experiment_planning.experiment_planners.residual_unets.ResUNet_planner import ResUNetPlanner class ResUNetPlanner2(ResUNetPlanner): diff --git a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner3.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner3.py similarity index 99% rename from nnunetv2/experiment_planning/experiment_planners/resUNet_planner3.py rename to nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner3.py index 4c70c34ba..335dd5c45 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resUNet_planner3.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner3.py @@ -5,7 +5,7 @@ from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props -from nnunetv2.experiment_planning.experiment_planners.resUNet_planner import ResUNetPlanner +from nnunetv2.experiment_planning.experiment_planners.residual_unets.ResUNet_planner import ResUNetPlanner class ResUNetPlanner3(ResUNetPlanner): diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/__init__.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/__init__.py new file mode 100644 index 000000000..e69de29bb From 18a47e419c4ba58bb78fb105a668f393b12feff7 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Tue, 30 Jan 2024 14:43:53 +0100 Subject: [PATCH 12/24] initial M L XL XLx8 planners --- .../residual_unets/ResEncUNet_planner.py | 3 +++ .../new_nnunet_presets/__init__.py | 0 .../new_nnunet_presets/nnUNetPlannerL.py | 23 +++++++++++++++++ .../new_nnunet_presets/nnUNetPlannerM.py | 23 +++++++++++++++++ .../new_nnunet_presets/nnUNetPlannerXL.py | 23 +++++++++++++++++ .../new_nnunet_presets/nnUNetPlannerXLx8.py | 25 +++++++++++++++++++ 6 files changed, 97 insertions(+) create mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/__init__.py create mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py create mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py create mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py create mode 100644 
nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXLx8.py diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py index 29a539333..adc215f5f 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py @@ -216,6 +216,9 @@ def _keygen(patch_size, strides): class ResEncUNetPlanner2(ResEncUNetPlanner): + """ + Same as nnUNetPlannerM (nnUNetPlannerM was built from this) + """ def __init__(self, dataset_name_or_id: Union[str, int], gpu_memory_target_in_gb: float = 8, preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNet2Plans', diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/__init__.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py new file mode 100644 index 000000000..1e9782bdb --- /dev/null +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py @@ -0,0 +1,23 @@ +from typing import Union, List, Tuple + +from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet + +from nnunetv2.experiment_planning.experiment_planners.residual_unets.ResEncUNet_planner import ResEncUNetPlanner + + +class nnUNetPlannerL(ResEncUNetPlanner): + """ + Target is ~24 GB VRAM max -> RTX 4090, Titan RTX, Quadro 6000 + """ + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 8, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + self.UNet_class = ResidualEncoderUNet + # this is supposed to give the same GPU memory requirement as the default nnU-Net + self.UNet_reference_val_3d = 2100000000 # 1840000000 + self.UNet_reference_val_2d = 403000000 # 352666667 + diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py new file mode 100644 index 000000000..207935539 --- /dev/null +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py @@ -0,0 +1,23 @@ +from typing import Union, List, Tuple + +from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet + +from nnunetv2.experiment_planning.experiment_planners.residual_unets.ResEncUNet_planner import ResEncUNetPlanner + + +class nnUNetPlannerM(ResEncUNetPlanner): + """ + Target is ~9-11 GB VRAM max -> older Titan, RTX 2080ti + """ + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 8, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetMPlans', + overwrite_target_spacing: 
Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + self.UNet_class = ResidualEncoderUNet + # this is supposed to give the same GPU memory requirement as the default nnU-Net + self.UNet_reference_val_3d = 600000000 + self.UNet_reference_val_2d = 115000000 + diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py new file mode 100644 index 000000000..05c9f6742 --- /dev/null +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py @@ -0,0 +1,23 @@ +from typing import Union, List, Tuple + +from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet + +from nnunetv2.experiment_planning.experiment_planners.residual_unets.ResEncUNet_planner import ResEncUNetPlanner + + +class nnUNetPlannerXL(ResEncUNetPlanner): + """ + Target is 40 GB VRAM max -> A100 40GB, RTX 6000 Ada Generation + """ + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 8, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNet2Plans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + self.UNet_class = ResidualEncoderUNet + # this is supposed to give the same GPU memory requirement as the default nnU-Net + self.UNet_reference_val_3d = 4500000000 + self.UNet_reference_val_2d = 250000000 + diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXLx8.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXLx8.py new file mode 100644 index 000000000..f91e26623 --- /dev/null +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXLx8.py @@ -0,0 +1,25 @@ +from typing import Union, List, Tuple + +from nnunetv2.experiment_planning.experiment_planners.residual_unets.new_nnunet_presets.nnUNetPlannerXL import \ + nnUNetPlannerXL + + +class nnUNetPlannerXLx8(nnUNetPlannerXL): + """ + Target is 8*40 GB VRAM max -> 8xA100 40GB or 4*A100 80GB + """ + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 8, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetXLx8Plans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + + def plan_experiment(self): + super(nnUNetPlannerXLx8, self).plan_experiment() + for configuration in ['2d', '3d_fullres', '3d_lowres']: + if configuration in self.plans['configurations']: + self.plans['configurations'][configuration]['batch_size'] *= 8 + self.save_plans(self.plans) + return self.plans From 4c08a5f3fe5355a393fc0b7cc54deb30373c7d34 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Tue, 30 Jan 2024 14:58:34 +0100 Subject: [PATCH 13/24] XLx8 default name fix --- 
.../residual_unets/new_nnunet_presets/nnUNetPlannerXL.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py index 05c9f6742..5761765a0 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py @@ -11,7 +11,7 @@ class nnUNetPlannerXL(ResEncUNetPlanner): """ def __init__(self, dataset_name_or_id: Union[str, int], gpu_memory_target_in_gb: float = 8, - preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNet2Plans', + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetXLPlans', overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, suppress_transpose: bool = False): super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, From c8d50010338eaf18fa82b96a84352fa97a84619b Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Tue, 30 Jan 2024 17:27:51 +0100 Subject: [PATCH 14/24] planners --- .../new_nnunet_presets/nnUNetPlannerL.py | 10 +++++++--- .../new_nnunet_presets/nnUNetPlannerM.py | 7 ++++++- .../new_nnunet_presets/nnUNetPlannerXL.py | 13 +++++++++---- .../new_nnunet_presets/nnUNetPlannerXLx8.py | 5 ++++- 4 files changed, 26 insertions(+), 9 deletions(-) diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py index 1e9782bdb..2daf5a644 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py @@ -10,14 +10,18 @@ class nnUNetPlannerL(ResEncUNetPlanner): Target is ~24 GB VRAM max -> RTX 4090, Titan RTX, Quadro 6000 """ def __init__(self, dataset_name_or_id: Union[str, int], - gpu_memory_target_in_gb: float = 8, + gpu_memory_target_in_gb: float = 24, preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans', overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, suppress_transpose: bool = False): + gpu_memory_target_in_gb = 24 super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, overwrite_target_spacing, suppress_transpose) self.UNet_class = ResidualEncoderUNet - # this is supposed to give the same GPU memory requirement as the default nnU-Net + + self.UNet_vram_target_GB = gpu_memory_target_in_gb + self.UNet_reference_val_corresp_GB = 24 + self.UNet_reference_val_3d = 2100000000 # 1840000000 - self.UNet_reference_val_2d = 403000000 # 352666667 + self.UNet_reference_val_2d = 380000000 # 352666667 diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py index 207935539..5127b99ef 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py @@ -14,10 +14,15 @@ def __init__(self, 
dataset_name_or_id: Union[str, int], preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetMPlans', overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, suppress_transpose: bool = False): + gpu_memory_target_in_gb = 8 super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, overwrite_target_spacing, suppress_transpose) self.UNet_class = ResidualEncoderUNet + + self.UNet_vram_target_GB = gpu_memory_target_in_gb + self.UNet_reference_val_corresp_GB = 8 + # this is supposed to give the same GPU memory requirement as the default nnU-Net self.UNet_reference_val_3d = 600000000 - self.UNet_reference_val_2d = 115000000 + self.UNet_reference_val_2d = 133000000 diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py index 5761765a0..df5a81ee8 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py @@ -10,14 +10,19 @@ class nnUNetPlannerXL(ResEncUNetPlanner): Target is 40 GB VRAM max -> A100 40GB, RTX 6000 Ada Generation """ def __init__(self, dataset_name_or_id: Union[str, int], - gpu_memory_target_in_gb: float = 8, + gpu_memory_target_in_gb: float = 40, preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetXLPlans', overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, suppress_transpose: bool = False): + gpu_memory_target_in_gb = 40 super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, overwrite_target_spacing, suppress_transpose) self.UNet_class = ResidualEncoderUNet - # this is supposed to give the same GPU memory requirement as the default nnU-Net - self.UNet_reference_val_3d = 4500000000 - self.UNet_reference_val_2d = 250000000 + + self.UNet_vram_target_GB = gpu_memory_target_in_gb + self.UNet_reference_val_corresp_GB = 40 + + self.UNet_reference_val_3d = 3600000000 + self.UNet_reference_val_2d = 560000000 + diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXLx8.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXLx8.py index f91e26623..c8b1d7631 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXLx8.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXLx8.py @@ -9,14 +9,17 @@ class nnUNetPlannerXLx8(nnUNetPlannerXL): Target is 8*40 GB VRAM max -> 8xA100 40GB or 4*A100 80GB """ def __init__(self, dataset_name_or_id: Union[str, int], - gpu_memory_target_in_gb: float = 8, + gpu_memory_target_in_gb: float = 40, # this needs to be 40 as we plan for the same size per GPU as XL preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetXLx8Plans', overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, suppress_transpose: bool = False): + gpu_memory_target_in_gb = 40 super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, overwrite_target_spacing, suppress_transpose) def plan_experiment(self): + print('DO NOT TRUST ANY PRINTED PLANS AS THE BATCH SIZE WILL NOT YET HAVE BEEN INCREASED!
FINAL BATCH SIZE IS ' + '8x OF WHAT YOU SEE') super(nnUNetPlannerXLx8, self).plan_experiment() for configuration in ['2d', '3d_fullres', '3d_lowres']: if configuration in self.plans['configurations']: From 610ee8ce6a7acd16a335e9166dad95fbfaa5e5dd Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Thu, 1 Feb 2024 07:27:00 +0100 Subject: [PATCH 15/24] stuff --- .../generate_lsf_runs_customDecathlon.py | 16 ++++++++-------- .../new_nnunet_presets/nnUNetPlannerM.py | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py b/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py index 4f98521fb..3b8840bf3 100644 --- a/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py +++ b/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py @@ -21,14 +21,14 @@ def merge(dict1, dict2): # after the Nature Methods paper we switch our evaluation to a different (more stable/high quality) set of # datasets for evaluation and future development configurations_all = { - 2: ("3d_fullres", "2d"), - 3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), - 4: ("2d", "3d_fullres"), + # 2: ("3d_fullres", "2d"), + # 3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 4: ("2d", "3d_fullres"), 17: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), - 24: ("2d", "3d_fullres"), - 27: ("2d", "3d_fullres"), - 38: ("2d", "3d_fullres"), - 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 24: ("2d", "3d_fullres"), + # 27: ("2d", "3d_fullres"), + # 38: ("2d", "3d_fullres"), + # 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 137: ("2d", "3d_fullres"), 220: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 221: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), @@ -65,7 +65,7 @@ def merge(dict1, dict2): # use_this = merge(use_this, configurations_3d_c_only) use_these_modules = { - 'nnUNetTrainer': ('nnUNetPlans', 'nnUNetResEncUNetPlans', 'nnUNetResEncUNet2Plans', 'nnUNetResBottleneckEncUNetPlans', 'nnUNetResUNetPlans', 'nnUNetResUNet2Plans', 'nnUNetResUNet3Plans', 'nnUNetDeeperResBottleneckEncUNetPlans'), + 'nnUNetTrainer': ('nnUNetResEncUNetMPlans', 'nnUNetResEncUNetLPlans', 'nnUNetResEncUNetXLPlans', 'nnUNetResEncUNetXLx8Plans'), } additional_arguments = f'--disable_checkpointing -num_gpus {num_gpus}' # '' diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py index 5127b99ef..9096f15e3 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py @@ -23,6 +23,6 @@ def __init__(self, dataset_name_or_id: Union[str, int], self.UNet_reference_val_corresp_GB = 8 # this is supposed to give the same GPU memory requirement as the default nnU-Net - self.UNet_reference_val_3d = 600000000 - self.UNet_reference_val_2d = 133000000 + self.UNet_reference_val_3d = 680000000 + self.UNet_reference_val_2d = 135000000 From 9795b0ef2730ac7cde78d5dda0917afd4b7c9aeb Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Fri, 2 Feb 2024 10:30:16 +0100 Subject: [PATCH 16/24] add morefilt variants --- .../default_experiment_planner.py | 2 +- .../ResEncUNetBottleneck_planner.py | 2 +- .../residual_unets/ResEncUNet_planner.py | 2 +- .../residual_unets/ResUNet_planner.py | 2 +- 
.../residual_unets/ResUNet_planner3.py | 2 +- .../residual_unets_moreFilt/__init__.py | 0 .../nnUNetPlannerLmoreFilt.py | 29 ++++++++++++++++++ .../nnUNetPlannerXLmoreFilt.py | 30 +++++++++++++++++++ .../nnUNetPlannerXLx8moreFilt.py | 28 +++++++++++++++++ 9 files changed, 92 insertions(+), 5 deletions(-) create mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/__init__.py create mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerLmoreFilt.py create mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLmoreFilt.py create mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLx8moreFilt.py diff --git a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py index d3b874d43..f578af6d4 100644 --- a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py @@ -230,7 +230,7 @@ def get_plans_for_configuration(self, approximate_n_voxels_dataset: float, _cache: dict) -> dict: def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: - return tuple([min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i) for + return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for i in range(num_stages)]) def _keygen(patch_size, strides): diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNetBottleneck_planner.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNetBottleneck_planner.py index 0cfed5f88..cfa36ba5b 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNetBottleneck_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNetBottleneck_planner.py @@ -27,7 +27,7 @@ def get_plans_for_configuration(self, approximate_n_voxels_dataset: float, _cache: dict) -> dict: def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: - return tuple([min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i) for + return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for i in range(num_stages)]) def _keygen(patch_size, strides): diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py index adc215f5f..c4f7993bd 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py @@ -46,7 +46,7 @@ def get_plans_for_configuration(self, approximate_n_voxels_dataset: float, _cache: dict) -> dict: def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: - return tuple([min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i) for + return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for i in range(num_stages)]) def _keygen(patch_size, strides): diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner.py index a26bcc70f..981c3eef5 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner.py +++ 
b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner.py @@ -45,7 +45,7 @@ def get_plans_for_configuration(self, approximate_n_voxels_dataset: float, _cache: dict) -> dict: def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: - return tuple([min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i) for + return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for i in range(num_stages)]) def _keygen(patch_size, strides): diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner3.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner3.py index 335dd5c45..1922d2f52 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner3.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner3.py @@ -27,7 +27,7 @@ def get_plans_for_configuration(self, approximate_n_voxels_dataset: float, _cache: dict) -> dict: def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: - return tuple([min(max_num_features, self.UNet_reference_com_nfeatures * 2 ** i) for + return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for i in range(num_stages)]) def _keygen(patch_size, strides): diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/__init__.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerLmoreFilt.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerLmoreFilt.py new file mode 100644 index 000000000..d1c2936ed --- /dev/null +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerLmoreFilt.py @@ -0,0 +1,29 @@ +from typing import Union, List, Tuple + +from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet + +from nnunetv2.experiment_planning.experiment_planners.residual_unets.ResEncUNet_planner import ResEncUNetPlanner + + +class nnUNetPlannerLmoreFilt(ResEncUNetPlanner): + """ + Target is ~24 GB VRAM max -> RTX 4090, Titan RTX, Quadro 6000 + """ + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 24, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLmoreFiltPlans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + gpu_memory_target_in_gb = 24 + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + self.UNet_class = ResidualEncoderUNet + + self.UNet_vram_target_GB = gpu_memory_target_in_gb + self.UNet_reference_val_corresp_GB = 24 + self.UNet_base_num_features = 48 + self.UNet_max_features_3d = self.UNet_base_num_features * 2 ** 4 + + self.UNet_reference_val_3d = 1900000000 # 1840000000 + self.UNet_reference_val_2d = 370000000 # 352666667 + diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLmoreFilt.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLmoreFilt.py new file mode 100644 index 000000000..551d4c86e --- /dev/null +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLmoreFilt.py @@ -0,0 +1,30 @@ 
+from typing import Union, List, Tuple + +from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet + +from nnunetv2.experiment_planning.experiment_planners.residual_unets.ResEncUNet_planner import ResEncUNetPlanner + + +class nnUNetPlannerXLmoreFilt(ResEncUNetPlanner): + """ + Target is 40 GB VRAM max -> A100 40GB, RTX 6000 Ada Generation + """ + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 40, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetXLmoreFiltPlans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + gpu_memory_target_in_gb = 40 + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + self.UNet_class = ResidualEncoderUNet + + self.UNet_vram_target_GB = gpu_memory_target_in_gb + self.UNet_reference_val_corresp_GB = 40 + self.UNet_base_num_features = 64 + self.UNet_max_features_3d = self.UNet_base_num_features * 2 ** 4 + + self.UNet_reference_val_3d = 3200000000 + self.UNet_reference_val_2d = 540000000 + + diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLx8moreFilt.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLx8moreFilt.py new file mode 100644 index 000000000..86c83dc8d --- /dev/null +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLx8moreFilt.py @@ -0,0 +1,28 @@ +from typing import Union, List, Tuple + +from nnunetv2.experiment_planning.experiment_planners.residual_unets_moreFilt.nnUNetPlannerXLmoreFilt import \ + nnUNetPlannerXLmoreFilt + + +class nnUNetPlannerXLx8moreFilt(nnUNetPlannerXLmoreFilt): + """ + Target is 8*40 GB VRAM max -> 8xA100 40GB or 4*A100 80GB + """ + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 40, # this needs to be 40 as we plan for the same size per GPU as XL + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetXLx8moreFiltPlans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + gpu_memory_target_in_gb = 40 + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + + def plan_experiment(self): + print('DO NOT TRUST ANY PRINTED PLANS AS THE BATCH SIZE WILL NOT YET HAVE BEEN INCREASED!
FINAL BATCH SIZE IS ' + '8x OF WHAT YOU SEE') + super(nnUNetPlannerXLmoreFilt, self).plan_experiment() + for configuration in ['2d', '3d_fullres', '3d_lowres']: + if configuration in self.plans['configurations']: + self.plans['configurations'][configuration]['batch_size'] *= 8 + self.save_plans(self.plans) + return self.plans From 2a8d247085993334ef537fb0d64b890d1c55070a Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Mon, 5 Feb 2024 07:15:53 +0100 Subject: [PATCH 17/24] moreFilt --- .../residual_unets_moreFilt/nnUNetPlannerXLmoreFilt.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLmoreFilt.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLmoreFilt.py index 551d4c86e..fbdd0e6d7 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLmoreFilt.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLmoreFilt.py @@ -24,7 +24,6 @@ def __init__(self, dataset_name_or_id: Union[str, int], self.UNet_base_num_features = 64 self.UNet_max_features_3d = self.UNet_base_num_features * 2 ** 4 - self.UNet_reference_val_3d = 3200000000 + self.UNet_reference_val_3d = 3100000000 self.UNet_reference_val_2d = 540000000 - From ac9a239ebbec35fb026c98f1ebb9af57c9f19203 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Wed, 7 Feb 2024 11:03:47 +0100 Subject: [PATCH 18/24] backwards compatibility wip --- .../utilities/plans_handling/plans_handler.py | 60 ++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/nnunetv2/utilities/plans_handling/plans_handler.py b/nnunetv2/utilities/plans_handling/plans_handler.py index 03601817d..a94ea24b7 100644 --- a/nnunetv2/utilities/plans_handling/plans_handler.py +++ b/nnunetv2/utilities/plans_handling/plans_handler.py @@ -1,5 +1,7 @@ from __future__ import annotations +import warnings + import dynamic_network_architectures from copy import deepcopy from functools import lru_cache, partial @@ -16,9 +18,9 @@ from nnunetv2.utilities.find_class_by_name import recursive_find_python_class from nnunetv2.utilities.label_handling.label_handling import get_labelmanager_class_from_plans - # see https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/ from typing import TYPE_CHECKING +from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm if TYPE_CHECKING: from nnunetv2.utilities.label_handling.label_handling import LabelManager @@ -31,6 +33,62 @@ class ConfigurationManager(object): def __init__(self, configuration_dict: dict): self.configuration = configuration_dict + # backwards compatibility + if 'architecture' not in self.configuration.keys(): + warnings.warn("Detected old nnU-Net plans format. Attempting to reconstruct network architecture " + "parameters. If this fails, rerun nnUNetv2_plan_experiment for your dataset. 
If you use a " + "custom architecture, please downgrade nnU-Net or update your plans.") + # try to build the architecture information from old plans, modify configuration dict to match new standard + unet_class_name = self.configuration["UNet_class_name"] + if unet_class_name == "PlainConvUNet": + network_class_name = "dynamic_network_architectures.architectures.unet.PlainConvUNet" + elif unet_class_name == 'ResidualEncoderUNet': + network_class_name = "dynamic_network_architectures.architectures.residual_unet.ResidualEncoderUNet" + else: + raise RuntimeError(f'Unknown architecture {unet_class_name}. This conversion only supports ' + f'PlainConvUNet and ResidualEncoderUNet') + + n_stages = len(self.configuration["n_conv_per_stage_encoder"]) + + dim = len(self.configuration["patch_size"]) + conv_op = convert_dim_to_conv_op(dim) + instnorm = get_matching_instancenorm(dimension=dim) + + arch_dict = { + 'network_class_name': network_class_name, + 'arch_kwargs': { + "n_stages": n_stages, + "features_per_stage": [min(self.configuration["UNet_base_num_features"] * 2 ** i, + self.configuration["unet_max_num_features"]) + for i in range(n_stages)], + "conv_op": conv_op.__module__ + '.' + conv_op.__name__, + "kernel_sizes": deepcopy(self.configuration["conv_kernel_sizes"]), + "strides": deepcopy(self.configuration["pool_op_kernel_sizes"]), + "n_conv_per_stage": deepcopy(self.configuration["n_conv_per_stage_encoder"]), + "n_conv_per_stage_decoder": deepcopy(self.configuration["n_conv_per_stage_decoder"]), + "conv_bias": True, + "norm_op": "torch.nn.modules.instancenorm.InstanceNorm3d", + "norm_op_kwargs": { + "eps": 1e-05, + "affine": True + }, + "dropout_op": None, + "dropout_op_kwargs": None, + "nonlin": "torch.nn.LeakyReLU", + "nonlin_kwargs": { + "inplace": True + } + }, + # these need to be imported with locate in order to use them: + # `conv_op = pydoc.locate(architecture_kwargs['conv_op'])` + "_kw_requires_import": [ + "conv_op", + "norm_op", + "dropout_op", + "nonlin" + ] + } + def __repr__(self): return self.configuration.__repr__() From 3d1288220bc3e894eece40e811ab5bbebd0eb4dd Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Wed, 7 Feb 2024 11:08:28 +0100 Subject: [PATCH 19/24] backwards compatibility complete --- nnunetv2/utilities/plans_handling/plans_handler.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/nnunetv2/utilities/plans_handling/plans_handler.py b/nnunetv2/utilities/plans_handling/plans_handler.py index a94ea24b7..079de6447 100644 --- a/nnunetv2/utilities/plans_handling/plans_handler.py +++ b/nnunetv2/utilities/plans_handling/plans_handler.py @@ -2,7 +2,6 @@ import warnings -import dynamic_network_architectures from copy import deepcopy from functools import lru_cache, partial from typing import Union, Tuple, List, Type, Callable @@ -67,7 +66,7 @@ def __init__(self, configuration_dict: dict): "n_conv_per_stage": deepcopy(self.configuration["n_conv_per_stage_encoder"]), "n_conv_per_stage_decoder": deepcopy(self.configuration["n_conv_per_stage_decoder"]), "conv_bias": True, - "norm_op": "torch.nn.modules.instancenorm.InstanceNorm3d", + "norm_op": instnorm.__module__ + '.' 
+ instnorm.__name__, "norm_op_kwargs": { "eps": 1e-05, "affine": True @@ -88,6 +87,11 @@ def __init__(self, configuration_dict: dict): "nonlin" ] } + del self.configuration["UNet_class_name"], self.configuration["UNet_base_num_features"], \ + self.configuration["n_conv_per_stage_encoder"], self.configuration["n_conv_per_stage_decoder"], \ + self.configuration["num_pool_per_axis"], self.configuration["pool_op_kernel_sizes"],\ + self.configuration["conv_kernel_sizes"], self.configuration["unet_max_num_features"] + self.configuration["architecture"] = arch_dict def __repr__(self): return self.configuration.__repr__() From 9f8c29ca08f64776e9b5fe0af0911a2ec70cdd61 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Mon, 19 Feb 2024 13:37:17 +0100 Subject: [PATCH 20/24] acdc splits --- .../generate_lsf_runs_customDecathlon.py | 12 ++++---- .../dataset_conversion/Dataset027_ACDC.py | 29 ++++++++++++++++++- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py b/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py index 3b8840bf3..7f9726ef4 100644 --- a/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py +++ b/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py @@ -31,7 +31,7 @@ def merge(dict1, dict2): # 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 137: ("2d", "3d_fullres"), 220: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), - 221: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 221: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 223: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), } @@ -52,23 +52,23 @@ def merge(dict1, dict2): } num_gpus = 1 - exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\" -R \"select[hname!='e230-dgx1-1']\"" + exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\"" resources = "" gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmem=33G" - queue = "-q test.dgx" + queue = "-q gpu" preamble = "-L /bin/bash \"source ~/load_env_mamba_slumber.sh && " train_command = 'nnUNetv2_train' - folds = (0, ) + folds = (1, 2, 3, 4) # use_this = configurations_2d_only use_this = configurations_3d_fr_only # use_this = merge(use_this, configurations_3d_c_only) use_these_modules = { - 'nnUNetTrainer': ('nnUNetResEncUNetMPlans', 'nnUNetResEncUNetLPlans', 'nnUNetResEncUNetXLPlans', 'nnUNetResEncUNetXLx8Plans'), + 'nnUNetTrainer': ('nnUNetPlans', 'nnUNetResEncUNetMPlans', 'nnUNetResEncUNetLPlans', 'nnUNetResEncUNetXLPlans'), } - additional_arguments = f'--disable_checkpointing -num_gpus {num_gpus}' # '' + additional_arguments = f' -num_gpus {num_gpus}' # '' output_file = "/home/isensee/deleteme.txt" with open(output_file, 'w') as f: diff --git a/nnunetv2/dataset_conversion/Dataset027_ACDC.py b/nnunetv2/dataset_conversion/Dataset027_ACDC.py index 569ff6f84..8ebc251cb 100644 --- a/nnunetv2/dataset_conversion/Dataset027_ACDC.py +++ b/nnunetv2/dataset_conversion/Dataset027_ACDC.py @@ -1,9 +1,12 @@ import os import shutil from pathlib import Path +from typing import List +from batchgenerators.utilities.file_and_folder_operations import nifti_files, join, maybe_mkdir_p, save_json from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json -from nnunetv2.paths import nnUNet_raw +from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed +import numpy as np def make_out_dirs(dataset_id: int, task_name="ACDC"): @@ -22,6 +25,22 @@ def 
make_out_dirs(dataset_id: int, task_name="ACDC"): return out_dir, out_train_dir, out_labels_dir, out_test_dir +def create_ACDC_split(labelsTr_folder: str, seed: int = 1234) -> List[dict[str, List]]: + # labelsTr_folder = '/home/isensee/drives/gpu_data_root/OE0441/isensee/nnUNet_raw/nnUNet_raw_remake/Dataset027_ACDC/labelsTr' + nii_files = nifti_files(labelsTr_folder, join=False) + patients = np.unique([i[:len('patient000')] for i in nii_files]) + rs = np.random.RandomState(seed) + rs.shuffle(patients) + splits = [] + for fold in range(5): + val_patients = patients[fold::5] + train_patients = [i for i in patients if i not in val_patients] + val_cases = [i[:-7] for i in nii_files for j in val_patients if i.startswith(j)] + train_cases = [i[:-7] for i in nii_files for j in train_patients if i.startswith(j)] + splits.append({'train': train_cases, 'val': val_cases}) + return splits + + def copy_files(src_data_folder: Path, train_dir: Path, labels_dir: Path, test_dir: Path): """Copy files from the ACDC dataset to the nnUNet dataset folder. Returns the number of training cases.""" patients_train = sorted([f for f in (src_data_folder / "training").iterdir() if f.is_dir()]) @@ -84,4 +103,12 @@ def convert_acdc(src_data_folder: str, dataset_id=27): args = parser.parse_args() print("Converting...") convert_acdc(args.input_folder, args.dataset_id) + + dataset_name = f"Dataset{args.dataset_id:03d}_{'ACDC'}" + labelsTr = join(nnUNet_raw, dataset_name, 'labelsTr') + preprocessed_folder = join(nnUNet_preprocessed, dataset_name) + maybe_mkdir_p(preprocessed_folder) + split = create_ACDC_split(labelsTr) + save_json(split, join(preprocessed_folder, 'splits_final.json'), sort_keys=False) + print("Done!") From 2b7d4930cd60fb9abdb4d2bf161a5e0c9e31ecef Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Mon, 19 Feb 2024 13:43:19 +0100 Subject: [PATCH 21/24] max_dataset_covered in experiment planning --- .../experiment_planners/default_experiment_planner.py | 4 +++- .../residual_unets/ResEncUNetBottleneck_planner.py | 2 +- .../experiment_planners/residual_unets/ResEncUNet_planner.py | 2 +- .../experiment_planners/residual_unets/ResUNet_planner.py | 2 +- .../experiment_planners/residual_unets/ResUNet_planner3.py | 2 +- .../residual_unets/new_nnunet_presets/nnUNetPlannerL.py | 1 + .../residual_unets/new_nnunet_presets/nnUNetPlannerM.py | 1 + .../residual_unets/new_nnunet_presets/nnUNetPlannerXL.py | 1 + 8 files changed, 10 insertions(+), 5 deletions(-) diff --git a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py index 9d9c0405f..c31f9d86b 100644 --- a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py @@ -65,6 +65,8 @@ def __init__(self, dataset_name_or_id: Union[str, int], self.UNet_min_batch_size = 2 self.UNet_max_features_2d = 512 self.UNet_max_features_3d = 320 + self.max_dataset_covered = 0.05 # we limit the batch size so that no more than 5% of the dataset can be seen + # in a single forward/backward pass self.UNet_vram_target_GB = gpu_memory_target_in_gb @@ -372,7 +374,7 @@ def _keygen(patch_size, strides): # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. 
We cannot # go smaller than self.UNet_min_batch_size though bs_corresponding_to_5_percent = round( - approximate_n_voxels_dataset * 0.05 / np.prod(patch_size, dtype=np.float64)) + approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64)) batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNetBottleneck_planner.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNetBottleneck_planner.py index cfa36ba5b..2911543b3 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNetBottleneck_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNetBottleneck_planner.py @@ -170,7 +170,7 @@ def _keygen(patch_size, strides): # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. We cannot # go smaller than self.UNet_min_batch_size though bs_corresponding_to_5_percent = round( - approximate_n_voxels_dataset * 0.05 / np.prod(patch_size, dtype=np.float64)) + approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64)) batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py index c4f7993bd..f89713857 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py @@ -186,7 +186,7 @@ def _keygen(patch_size, strides): # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. We cannot # go smaller than self.UNet_min_batch_size though bs_corresponding_to_5_percent = round( - approximate_n_voxels_dataset * 0.05 / np.prod(patch_size, dtype=np.float64)) + approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64)) batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner.py index 981c3eef5..2edaf68e3 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner.py @@ -185,7 +185,7 @@ def _keygen(patch_size, strides): # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. 
We cannot # go smaller than self.UNet_min_batch_size though bs_corresponding_to_5_percent = round( - approximate_n_voxels_dataset * 0.05 / np.prod(patch_size, dtype=np.float64)) + approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64)) batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner3.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner3.py index 1922d2f52..d0d5408b0 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner3.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner3.py @@ -167,7 +167,7 @@ def _keygen(patch_size, strides): # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. We cannot # go smaller than self.UNet_min_batch_size though bs_corresponding_to_5_percent = round( - approximate_n_voxels_dataset * 0.05 / np.prod(patch_size, dtype=np.float64)) + approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64)) batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py index 2daf5a644..2001ed69f 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py @@ -24,4 +24,5 @@ def __init__(self, dataset_name_or_id: Union[str, int], self.UNet_reference_val_3d = 2100000000 # 1840000000 self.UNet_reference_val_2d = 380000000 # 352666667 + self.max_dataset_covered = 1 diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py index 9096f15e3..d7b4b87af 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py @@ -25,4 +25,5 @@ def __init__(self, dataset_name_or_id: Union[str, int], # this is supposed to give the same GPU memory requirement as the default nnU-Net self.UNet_reference_val_3d = 680000000 self.UNet_reference_val_2d = 135000000 + self.max_dataset_covered = 1 diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py index df5a81ee8..7f59ab81f 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py @@ -24,5 +24,6 @@ def __init__(self, dataset_name_or_id: Union[str, int], self.UNet_reference_val_3d = 3600000000 self.UNet_reference_val_2d = 560000000 + 
self.max_dataset_covered = 1 From f13b8697029094fc8b6f435f23690d5f938989a5 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Mon, 19 Feb 2024 18:51:16 +0100 Subject: [PATCH 22/24] cleanup --- .../ResEncUNetBottleneck_planner.py | 230 ---------------- .../residual_unets/ResEncUNet_planner.py | 252 ------------------ .../residual_unets/ResUNet_planner.py | 214 --------------- .../residual_unets/ResUNet_planner2.py | 16 -- .../residual_unets/ResUNet_planner3.py | 196 -------------- .../residual_unets/__init__.py | 0 .../new_nnunet_presets/__init__.py | 0 .../new_nnunet_presets/nnUNetPlannerL.py | 28 -- .../new_nnunet_presets/nnUNetPlannerM.py | 29 -- .../new_nnunet_presets/nnUNetPlannerXL.py | 29 -- .../new_nnunet_presets/nnUNetPlannerXLx8.py | 28 -- .../residual_unets_moreFilt/__init__.py | 0 .../nnUNetPlannerLmoreFilt.py | 29 -- .../nnUNetPlannerXLmoreFilt.py | 29 -- .../nnUNetPlannerXLx8moreFilt.py | 28 -- 15 files changed, 1108 deletions(-) delete mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNetBottleneck_planner.py delete mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py delete mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner.py delete mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner2.py delete mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner3.py delete mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/__init__.py delete mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/__init__.py delete mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py delete mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py delete mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py delete mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXLx8.py delete mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/__init__.py delete mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerLmoreFilt.py delete mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLmoreFilt.py delete mode 100644 nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLx8moreFilt.py diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNetBottleneck_planner.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNetBottleneck_planner.py deleted file mode 100644 index 2911543b3..000000000 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNetBottleneck_planner.py +++ /dev/null @@ -1,230 +0,0 @@ -from copy import deepcopy -from typing import Union, List, Tuple - -import numpy as np -from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet -from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm -from dynamic_network_architectures.building_blocks.residual import BottleneckD -from torch import nn - -from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props -from 
nnunetv2.experiment_planning.experiment_planners.residual_unets.ResEncUNet_planner import ResEncUNetPlanner - - -class ResEncUNetBottleneckPlanner(ResEncUNetPlanner): - def __init__(self, dataset_name_or_id: Union[str, int], - gpu_memory_target_in_gb: float = 8, - preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResBottleneckEncUNetPlans', - overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, - suppress_transpose: bool = False): - super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, - overwrite_target_spacing, suppress_transpose) - - def get_plans_for_configuration(self, - spacing: Union[np.ndarray, Tuple[float, ...], List[float]], - median_shape: Union[np.ndarray, Tuple[int, ...]], - data_identifier: str, - approximate_n_voxels_dataset: float, - _cache: dict) -> dict: - def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: - return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for - i in range(num_stages)]) - - def _keygen(patch_size, strides): - return str(patch_size) + '_' + str(strides) - - assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}" - num_input_channels = len(self.dataset_json['channel_names'].keys() - if 'channel_names' in self.dataset_json.keys() - else self.dataset_json['modality'].keys()) - max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d - unet_conv_op = convert_dim_to_conv_op(len(spacing)) - - # print(spacing, median_shape, approximate_n_voxels_dataset) - # find an initial patch size - # we first use the spacing to get an aspect ratio - tmp = 1 / np.array(spacing) - - # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same - # volume as a patch of size 256 ** 3) - # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be - # ideal because large initial patch sizes increase computation time because more iterations in the while loop - # further down may be required. - if len(spacing) == 3: - initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)] - elif len(spacing) == 2: - initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)] - else: - raise RuntimeError() - - # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that - # this is different from how nnU-Net v1 does it! - # todo patch size can still get too large because we pad the patch size to a multiple of 2**n - initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])]) - - # use that to get the network topology. Note that this changes the patch_size depending on the number of - # pooling operations (must be divisible by 2**num_pool in each axis) - network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ - shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size, - self.UNet_featuremap_min_edge_length, - 999999) - num_stages = len(pool_op_kernel_sizes) - - norm = get_matching_instancenorm(unet_conv_op) - architecture_kwargs = { - 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, - 'arch_kwargs': { - 'n_stages': num_stages, - 'features_per_stage': _features_per_stage(num_stages, max_num_features), - 'conv_op': unet_conv_op.__module__ + '.' 
+ unet_conv_op.__name__, - 'kernel_sizes': conv_kernel_sizes, - 'strides': pool_op_kernel_sizes, - 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], - 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], - 'conv_bias': True, - 'norm_op': norm.__module__ + '.' + norm.__name__, - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, - 'dropout_op_kwargs': None, - 'nonlin': 'torch.nn.LeakyReLU', - 'nonlin_kwargs': {'inplace': True}, - 'block': BottleneckD.__module__ + '.' + BottleneckD.__name__, - 'bottleneck_channels': [i // 4 for i in _features_per_stage(num_stages, max_num_features)] - }, - '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin', 'block'), - } - - # now estimate vram consumption - if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): - estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] - else: - estimate = self.static_estimate_VRAM_usage(patch_size, - num_input_channels, - len(self.dataset_json['labels'].keys()), - architecture_kwargs['network_class_name'], - architecture_kwargs['arch_kwargs'], - architecture_kwargs['_kw_requires_import'], - ) - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate - - # how large is the reference for us here (batch size etc)? - # adapt for our vram target - reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \ - (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) - - while estimate > reference: - # print(patch_size) - # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the - # aspect ratio the most (that is the largest relative to median shape) - axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1] - - # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this - # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256. - # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size - # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first - # subtract shape_must_be_divisible_by, then recompute it and then subtract the - # recomputed shape_must_be_divisible_by. Annoying. 
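# A worked instance of the two-step subtraction, reusing the hypothetical
# numbers from the comment above (patch_size[axis] = 256,
# shape_must_be_divisible_by[axis] = 64):
#   step 1: tmp[axis] = 256 - 64 = 192; recomputing the topology for 192
#           shrinks shape_must_be_divisible_by[axis] to 32
#   step 2: patch_size[axis] = 256 - 32 = 224, which is still a valid patch
#           size (224 / 2**5 = 7), whereas a single naive subtraction would
#           have jumped straight from 256 to 192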
- patch_size = list(patch_size) - tmp = deepcopy(patch_size) - tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] - _, _, _, _, shape_must_be_divisible_by = \ - get_pool_and_conv_props(spacing, tmp, - self.UNet_featuremap_min_edge_length, - 999999) - patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] - - # now recompute topology - network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ - shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size, - self.UNet_featuremap_min_edge_length, - 999999) - - num_stages = len(pool_op_kernel_sizes) - architecture_kwargs['arch_kwargs'].update({ - 'n_stages': num_stages, - 'kernel_sizes': conv_kernel_sizes, - 'strides': pool_op_kernel_sizes, - 'features_per_stage': _features_per_stage(num_stages, max_num_features), - 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], - 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], - 'bottleneck_channels': [i // 4 for i in _features_per_stage(num_stages, max_num_features)] - }) - if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): - estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] - else: - estimate = self.static_estimate_VRAM_usage( - patch_size, - num_input_channels, - len(self.dataset_json['labels'].keys()), - architecture_kwargs['network_class_name'], - architecture_kwargs['arch_kwargs'], - architecture_kwargs['_kw_requires_import'], - ) - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate - - # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was - # executed. If not, additional vram headroom is used to increase batch size - ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d - batch_size = round((reference / estimate) * ref_bs) - - # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. 
We cannot - # go smaller than self.UNet_min_batch_size though - bs_corresponding_to_5_percent = round( - approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64)) - batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) - - resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() - resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn() - - normalization_schemes, mask_is_used_for_norm = \ - self.determine_normalization_scheme_and_whether_mask_is_used_for_norm() - - plan = { - 'data_identifier': data_identifier, - 'preprocessor_name': self.preprocessor_name, - 'batch_size': batch_size, - 'patch_size': patch_size, - 'median_image_size_in_voxels': median_shape, - 'spacing': spacing, - 'normalization_schemes': normalization_schemes, - 'use_mask_for_norm': mask_is_used_for_norm, - 'resampling_fn_data': resampling_data.__name__, - 'resampling_fn_seg': resampling_seg.__name__, - 'resampling_fn_data_kwargs': resampling_data_kwargs, - 'resampling_fn_seg_kwargs': resampling_seg_kwargs, - 'resampling_fn_probabilities': resampling_softmax.__name__, - 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs, - 'architecture': architecture_kwargs - } - return plan - -class ResEncUNetBottleneckDeeperPlanner(ResEncUNetBottleneckPlanner): - def __init__(self, dataset_name_or_id: Union[str, int], - gpu_memory_target_in_gb: float = 8, - preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetDeeperResBottleneckEncUNetPlans', - overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, - suppress_transpose: bool = False): - super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, - overwrite_target_spacing, suppress_transpose) - self.UNet_blocks_per_stage_encoder = (2, 3, 6, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9) - self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) - - -if __name__ == '__main__': - # we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively - net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320), - conv_op=nn.Conv3d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2), - n_blocks_per_stage=(1, 3, 4, 6, 6, 6), num_classes=3, - n_conv_per_stage_decoder=(1, 1, 1, 1, 1), - conv_bias=True, norm_op=nn.InstanceNorm3d, norm_op_kwargs={}, dropout_op=None, - nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) - print(net.compute_conv_feature_map_size((128, 128, 128))) # -> 558319104. 
The value you see above was finetuned - # from this one to match the regular nnunetplans more closely - - net = ResidualEncoderUNet(input_channels=1, n_stages=7, features_per_stage=(32, 64, 128, 256, 512, 512, 512), - conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2, 2), - n_blocks_per_stage=(1, 3, 4, 6, 6, 6, 6), num_classes=3, - n_conv_per_stage_decoder=(1, 1, 1, 1, 1, 1), - conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None, - nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) - print(net.compute_conv_feature_map_size((512, 512))) # -> 129793792 diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py deleted file mode 100644 index f89713857..000000000 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResEncUNet_planner.py +++ /dev/null @@ -1,252 +0,0 @@ -import numpy as np -from copy import deepcopy -from typing import Union, List, Tuple - -from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet -from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm -from torch import nn - -from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner - -from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props - - -class ResEncUNetPlanner(ExperimentPlanner): - def __init__(self, dataset_name_or_id: Union[str, int], - gpu_memory_target_in_gb: float = 8, - preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetPlans', - overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, - suppress_transpose: bool = False): - super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, - overwrite_target_spacing, suppress_transpose) - self.UNet_class = ResidualEncoderUNet - # the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as - # much as possible - self.UNet_reference_val_3d = 680000000 - self.UNet_reference_val_2d = 135000000 - self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6) - self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) - - def generate_data_identifier(self, configuration_name: str) -> str: - """ - configurations are unique within each plans file but different plans file can have configurations with the - same name. 
In order to distinguish the associated data we need a data identifier that reflects not just the - config but also the plans it originates from - """ - if configuration_name == '2d' or configuration_name == '3d_fullres': - # we do not deviate from ExperimentPlanner so we can reuse its data - return 'nnUNetPlans' + '_' + configuration_name - else: - return self.plans_identifier + '_' + configuration_name - - def get_plans_for_configuration(self, - spacing: Union[np.ndarray, Tuple[float, ...], List[float]], - median_shape: Union[np.ndarray, Tuple[int, ...]], - data_identifier: str, - approximate_n_voxels_dataset: float, - _cache: dict) -> dict: - def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: - return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for - i in range(num_stages)]) - - def _keygen(patch_size, strides): - return str(patch_size) + '_' + str(strides) - - assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}" - num_input_channels = len(self.dataset_json['channel_names'].keys() - if 'channel_names' in self.dataset_json.keys() - else self.dataset_json['modality'].keys()) - max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d - unet_conv_op = convert_dim_to_conv_op(len(spacing)) - - # print(spacing, median_shape, approximate_n_voxels_dataset) - # find an initial patch size - # we first use the spacing to get an aspect ratio - tmp = 1 / np.array(spacing) - - # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same - # volume as a patch of size 256 ** 3) - # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be - # ideal because large initial patch sizes increase computation time because more iterations in the while loop - # further down may be required. - if len(spacing) == 3: - initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)] - elif len(spacing) == 2: - initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)] - else: - raise RuntimeError() - - # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that - # this is different from how nnU-Net v1 does it! - # todo patch size can still get too large because we pad the patch size to a multiple of 2**n - initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])]) - - # use that to get the network topology. Note that this changes the patch_size depending on the number of - # pooling operations (must be divisible by 2**num_pool in each axis) - network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ - shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size, - self.UNet_featuremap_min_edge_length, - 999999) - num_stages = len(pool_op_kernel_sizes) - - norm = get_matching_instancenorm(unet_conv_op) - architecture_kwargs = { - 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, - 'arch_kwargs': { - 'n_stages': num_stages, - 'features_per_stage': _features_per_stage(num_stages, max_num_features), - 'conv_op': unet_conv_op.__module__ + '.' 
+ unet_conv_op.__name__, - 'kernel_sizes': conv_kernel_sizes, - 'strides': pool_op_kernel_sizes, - 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], - 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], - 'conv_bias': True, - 'norm_op': norm.__module__ + '.' + norm.__name__, - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, - 'dropout_op_kwargs': None, - 'nonlin': 'torch.nn.LeakyReLU', - 'nonlin_kwargs': {'inplace': True}, - }, - '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), - } - - # now estimate vram consumption - if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): - estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] - else: - estimate = self.static_estimate_VRAM_usage(patch_size, - num_input_channels, - len(self.dataset_json['labels'].keys()), - architecture_kwargs['network_class_name'], - architecture_kwargs['arch_kwargs'], - architecture_kwargs['_kw_requires_import'], - ) - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate - - # how large is the reference for us here (batch size etc)? - # adapt for our vram target - reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \ - (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) - - while estimate > reference: - # print(patch_size) - # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the - # aspect ratio the most (that is the largest relative to median shape) - axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1] - - # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this - # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256. - # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size - # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first - # subtract shape_must_be_divisible_by, then recompute it and then subtract the - # recomputed shape_must_be_divisible_by. Annoying. 
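# Net effect of this loop: each iteration shrinks the axis that is currently
# largest relative to the median shape, then re-estimates VRAM consumption
# (with estimates cached per (patch_size, strides) key via _keygen), so it
# terminates at the largest patch size whose estimate fits the scaled
# reference budget.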
- patch_size = list(patch_size) - tmp = deepcopy(patch_size) - tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] - _, _, _, _, shape_must_be_divisible_by = \ - get_pool_and_conv_props(spacing, tmp, - self.UNet_featuremap_min_edge_length, - 999999) - patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] - - # now recompute topology - network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ - shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size, - self.UNet_featuremap_min_edge_length, - 999999) - - num_stages = len(pool_op_kernel_sizes) - architecture_kwargs['arch_kwargs'].update({ - 'n_stages': num_stages, - 'kernel_sizes': conv_kernel_sizes, - 'strides': pool_op_kernel_sizes, - 'features_per_stage': _features_per_stage(num_stages, max_num_features), - 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], - 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], - }) - if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): - estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] - else: - estimate = self.static_estimate_VRAM_usage( - patch_size, - num_input_channels, - len(self.dataset_json['labels'].keys()), - architecture_kwargs['network_class_name'], - architecture_kwargs['arch_kwargs'], - architecture_kwargs['_kw_requires_import'], - ) - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate - - # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was - # executed. If not, additional vram headroom is used to increase batch size - ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d - batch_size = round((reference / estimate) * ref_bs) - - # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. 
We cannot - # go smaller than self.UNet_min_batch_size though - bs_corresponding_to_5_percent = round( - approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64)) - batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) - - resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() - resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn() - - normalization_schemes, mask_is_used_for_norm = \ - self.determine_normalization_scheme_and_whether_mask_is_used_for_norm() - - plan = { - 'data_identifier': data_identifier, - 'preprocessor_name': self.preprocessor_name, - 'batch_size': batch_size, - 'patch_size': patch_size, - 'median_image_size_in_voxels': median_shape, - 'spacing': spacing, - 'normalization_schemes': normalization_schemes, - 'use_mask_for_norm': mask_is_used_for_norm, - 'resampling_fn_data': resampling_data.__name__, - 'resampling_fn_seg': resampling_seg.__name__, - 'resampling_fn_data_kwargs': resampling_data_kwargs, - 'resampling_fn_seg_kwargs': resampling_seg_kwargs, - 'resampling_fn_probabilities': resampling_softmax.__name__, - 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs, - 'architecture': architecture_kwargs - } - return plan - - -class ResEncUNetPlanner2(ResEncUNetPlanner): - """ - Same as nnUNetPlannerM (nnUNetPlannerM was built from this) - """ - def __init__(self, dataset_name_or_id: Union[str, int], - gpu_memory_target_in_gb: float = 8, - preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNet2Plans', - overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, - suppress_transpose: bool = False): - super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, - overwrite_target_spacing, suppress_transpose) - self.UNet_class = ResidualEncoderUNet - # this is supposed to give the same GPU memory requirement as the default nnU-Net - self.UNet_reference_val_3d = 600000000 - self.UNet_reference_val_2d = 115000000 - - -if __name__ == '__main__': - # we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively - net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320), - conv_op=nn.Conv3d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2), - n_blocks_per_stage=(1, 3, 4, 6, 6, 6), num_classes=3, - n_conv_per_stage_decoder=(1, 1, 1, 1, 1), - conv_bias=True, norm_op=nn.InstanceNorm3d, norm_op_kwargs={}, dropout_op=None, - nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) - print(net.compute_conv_feature_map_size((128, 128, 128))) # -> 558319104. 
The value you see above was finetuned - # from this one to match the regular nnunetplans more closely - - net = ResidualEncoderUNet(input_channels=1, n_stages=7, features_per_stage=(32, 64, 128, 256, 512, 512, 512), - conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2, 2), - n_blocks_per_stage=(1, 3, 4, 6, 6, 6, 6), num_classes=3, - n_conv_per_stage_decoder=(1, 1, 1, 1, 1, 1), - conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None, - nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) - print(net.compute_conv_feature_map_size((512, 512))) # -> 129793792 diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner.py deleted file mode 100644 index 2edaf68e3..000000000 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner.py +++ /dev/null @@ -1,214 +0,0 @@ -from copy import deepcopy -from typing import Union, List, Tuple - -import numpy as np -from dynamic_network_architectures.architectures.residual_unet import ResidualUNet -from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm - -from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner -from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props - - -class ResUNetPlanner(ExperimentPlanner): - def __init__(self, dataset_name_or_id: Union[str, int], - gpu_memory_target_in_gb: float = 8, - preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResUNetPlans', - overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, - suppress_transpose: bool = False): - super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, - overwrite_target_spacing, suppress_transpose) - - self.UNet_class = ResidualUNet - # the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as - # much as possible - self.UNet_reference_val_3d = 680000000 - self.UNet_reference_val_2d = 135000000 - self.UNet_blocks_per_stage_encoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) - self.UNet_blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) - - def generate_data_identifier(self, configuration_name: str) -> str: - """ - configurations are unique within each plans file but different plans file can have configurations with the - same name. In order to distinguish the associated data we need a data identifier that reflects not just the - config but also the plans it originates from - """ - if configuration_name == '2d' or configuration_name == '3d_fullres': - # we do not deviate from ExperimentPlanner so we can reuse its data - return 'nnUNetPlans' + '_' + configuration_name - else: - return self.plans_identifier + '_' + configuration_name - - def get_plans_for_configuration(self, - spacing: Union[np.ndarray, Tuple[float, ...], List[float]], - median_shape: Union[np.ndarray, Tuple[int, ...]], - data_identifier: str, - approximate_n_voxels_dataset: float, - _cache: dict) -> dict: - def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: - return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for - i in range(num_stages)]) - - def _keygen(patch_size, strides): - return str(patch_size) + '_' + str(strides) - - assert all([i > 0 for i in spacing]), f"Spacing must be > 0! 
Spacing: {spacing}" - num_input_channels = len(self.dataset_json['channel_names'].keys() - if 'channel_names' in self.dataset_json.keys() - else self.dataset_json['modality'].keys()) - max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d - unet_conv_op = convert_dim_to_conv_op(len(spacing)) - - # print(spacing, median_shape, approximate_n_voxels_dataset) - # find an initial patch size - # we first use the spacing to get an aspect ratio - tmp = 1 / np.array(spacing) - - # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same - # volume as a patch of size 256 ** 3) - # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be - # ideal because large initial patch sizes increase computation time because more iterations in the while loop - # further down may be required. - if len(spacing) == 3: - initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)] - elif len(spacing) == 2: - initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)] - else: - raise RuntimeError() - - # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that - # this is different from how nnU-Net v1 does it! - # todo patch size can still get too large because we pad the patch size to a multiple of 2**n - initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])]) - - # use that to get the network topology. Note that this changes the patch_size depending on the number of - # pooling operations (must be divisible by 2**num_pool in each axis) - network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ - shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size, - self.UNet_featuremap_min_edge_length, - 999999) - num_stages = len(pool_op_kernel_sizes) - - norm = get_matching_instancenorm(unet_conv_op) - architecture_kwargs = { - 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, - 'arch_kwargs': { - 'n_stages': num_stages, - 'features_per_stage': _features_per_stage(num_stages, max_num_features), - 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, - 'kernel_sizes': conv_kernel_sizes, - 'strides': pool_op_kernel_sizes, - 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], - 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], - 'conv_bias': True, - 'norm_op': norm.__module__ + '.' + norm.__name__, - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, - 'dropout_op_kwargs': None, - 'nonlin': 'torch.nn.LeakyReLU', - 'nonlin_kwargs': {'inplace': True}, - }, - '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), - } - - # now estimate vram consumption - if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): - estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] - else: - estimate = self.static_estimate_VRAM_usage(patch_size, - num_input_channels, - len(self.dataset_json['labels'].keys()), - architecture_kwargs['network_class_name'], - architecture_kwargs['arch_kwargs'], - architecture_kwargs['_kw_requires_import'], - ) - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate - - # how large is the reference for us here (batch size etc)? 
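# (concretely: reference = UNet_reference_val_{2d|3d} * UNet_vram_target_GB /
#  UNet_reference_val_corresp_GB -- assuming the default 8 GB correspondence,
#  a 24 GB VRAM target triples the budget)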
- # adapt for our vram target - reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \ - (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) - - while estimate > reference: - # print(patch_size) - # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the - # aspect ratio the most (that is the largest relative to median shape) - axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1] - - # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this - # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256. - # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size - # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first - # subtract shape_must_be_divisible_by, then recompute it and then subtract the - # recomputed shape_must_be_divisible_by. Annoying. - patch_size = list(patch_size) - tmp = deepcopy(patch_size) - tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] - _, _, _, _, shape_must_be_divisible_by = \ - get_pool_and_conv_props(spacing, tmp, - self.UNet_featuremap_min_edge_length, - 999999) - patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] - - # now recompute topology - network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ - shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size, - self.UNet_featuremap_min_edge_length, - 999999) - - num_stages = len(pool_op_kernel_sizes) - architecture_kwargs['arch_kwargs'].update({ - 'n_stages': num_stages, - 'kernel_sizes': conv_kernel_sizes, - 'strides': pool_op_kernel_sizes, - 'features_per_stage': _features_per_stage(num_stages, max_num_features), - 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], - 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], - }) - if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): - estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] - else: - estimate = self.static_estimate_VRAM_usage( - patch_size, - num_input_channels, - len(self.dataset_json['labels'].keys()), - architecture_kwargs['network_class_name'], - architecture_kwargs['arch_kwargs'], - architecture_kwargs['_kw_requires_import'], - ) - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate - - # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was - # executed. If not, additional vram headroom is used to increase batch size - ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d - batch_size = round((reference / estimate) * ref_bs) - - # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. 
We cannot - # go smaller than self.UNet_min_batch_size though - bs_corresponding_to_5_percent = round( - approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64)) - batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) - - resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() - resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn() - - normalization_schemes, mask_is_used_for_norm = \ - self.determine_normalization_scheme_and_whether_mask_is_used_for_norm() - - plan = { - 'data_identifier': data_identifier, - 'preprocessor_name': self.preprocessor_name, - 'batch_size': batch_size, - 'patch_size': patch_size, - 'median_image_size_in_voxels': median_shape, - 'spacing': spacing, - 'normalization_schemes': normalization_schemes, - 'use_mask_for_norm': mask_is_used_for_norm, - 'resampling_fn_data': resampling_data.__name__, - 'resampling_fn_seg': resampling_seg.__name__, - 'resampling_fn_data_kwargs': resampling_data_kwargs, - 'resampling_fn_seg_kwargs': resampling_seg_kwargs, - 'resampling_fn_probabilities': resampling_softmax.__name__, - 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs, - 'architecture': architecture_kwargs - } - return plan diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner2.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner2.py deleted file mode 100644 index 9806dbdf7..000000000 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner2.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Union, List, Tuple - -from nnunetv2.experiment_planning.experiment_planners.residual_unets.ResUNet_planner import ResUNetPlanner - - -class ResUNetPlanner2(ResUNetPlanner): - def __init__(self, dataset_name_or_id: Union[str, int], - gpu_memory_target_in_gb: float = 8, - preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResUNet2Plans', - overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, - suppress_transpose: bool = False): - super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, - overwrite_target_spacing, suppress_transpose) - - self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6) - self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner3.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner3.py deleted file mode 100644 index d0d5408b0..000000000 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/ResUNet_planner3.py +++ /dev/null @@ -1,196 +0,0 @@ -from copy import deepcopy -from typing import Union, List, Tuple - -import numpy as np -from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm - -from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props -from nnunetv2.experiment_planning.experiment_planners.residual_unets.ResUNet_planner import ResUNetPlanner - - -class ResUNetPlanner3(ResUNetPlanner): - def __init__(self, dataset_name_or_id: Union[str, int], - gpu_memory_target_in_gb: float = 8, - preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResUNet3Plans', - overwrite_target_spacing: Union[List[float], 
Tuple[float, ...]] = None, - suppress_transpose: bool = False): - super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, - overwrite_target_spacing, suppress_transpose) - - self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6) - self.UNet_blocks_per_stage_decoder = None - - def get_plans_for_configuration(self, - spacing: Union[np.ndarray, Tuple[float, ...], List[float]], - median_shape: Union[np.ndarray, Tuple[int, ...]], - data_identifier: str, - approximate_n_voxels_dataset: float, - _cache: dict) -> dict: - def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: - return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for - i in range(num_stages)]) - - def _keygen(patch_size, strides): - return str(patch_size) + '_' + str(strides) - - assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}" - num_input_channels = len(self.dataset_json['channel_names'].keys() - if 'channel_names' in self.dataset_json.keys() - else self.dataset_json['modality'].keys()) - max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d - unet_conv_op = convert_dim_to_conv_op(len(spacing)) - - # print(spacing, median_shape, approximate_n_voxels_dataset) - # find an initial patch size - # we first use the spacing to get an aspect ratio - tmp = 1 / np.array(spacing) - - # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same - # volume as a patch of size 256 ** 3) - # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be - # ideal because large initial patch sizes increase computation time because more iterations in the while loop - # further down may be required. - if len(spacing) == 3: - initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)] - elif len(spacing) == 2: - initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)] - else: - raise RuntimeError() - - # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that - # this is different from how nnU-Net v1 does it! - # todo patch size can still get too large because we pad the patch size to a multiple of 2**n - initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])]) - - # use that to get the network topology. Note that this changes the patch_size depending on the number of - # pooling operations (must be divisible by 2**num_pool in each axis) - network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ - shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size, - self.UNet_featuremap_min_edge_length, - 999999) - num_stages = len(pool_op_kernel_sizes) - - norm = get_matching_instancenorm(unet_conv_op) - architecture_kwargs = { - 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, - 'arch_kwargs': { - 'n_stages': num_stages, - 'features_per_stage': _features_per_stage(num_stages, max_num_features), - 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, - 'kernel_sizes': conv_kernel_sizes, - 'strides': pool_op_kernel_sizes, - 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], - 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_encoder[:num_stages - 1][::-1], - 'conv_bias': True, - 'norm_op': norm.__module__ + '.' 
+ norm.__name__, - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, - 'dropout_op_kwargs': None, - 'nonlin': 'torch.nn.LeakyReLU', - 'nonlin_kwargs': {'inplace': True}, - }, - '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), - } - - # now estimate vram consumption - if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): - estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] - else: - estimate = self.static_estimate_VRAM_usage(patch_size, - num_input_channels, - len(self.dataset_json['labels'].keys()), - architecture_kwargs['network_class_name'], - architecture_kwargs['arch_kwargs'], - architecture_kwargs['_kw_requires_import'], - ) - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate - - # how large is the reference for us here (batch size etc)? - # adapt for our vram target - reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \ - (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) - - while estimate > reference: - # print(patch_size) - # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the - # aspect ratio the most (that is the largest relative to median shape) - axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1] - - # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this - # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256. - # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size - # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first - # subtract shape_must_be_divisible_by, then recompute it and then subtract the - # recomputed shape_must_be_divisible_by. Annoying. - patch_size = list(patch_size) - tmp = deepcopy(patch_size) - tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] - _, _, _, _, shape_must_be_divisible_by = \ - get_pool_and_conv_props(spacing, tmp, - self.UNet_featuremap_min_edge_length, - 999999) - patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] - - # now recompute topology - network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ - shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size, - self.UNet_featuremap_min_edge_length, - 999999) - - num_stages = len(pool_op_kernel_sizes) - architecture_kwargs['arch_kwargs'].update({ - 'n_stages': num_stages, - 'kernel_sizes': conv_kernel_sizes, - 'strides': pool_op_kernel_sizes, - 'features_per_stage': _features_per_stage(num_stages, max_num_features), - 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], - 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_encoder[:num_stages - 1][::-1], - }) - if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): - estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] - else: - estimate = self.static_estimate_VRAM_usage( - patch_size, - num_input_channels, - len(self.dataset_json['labels'].keys()), - architecture_kwargs['network_class_name'], - architecture_kwargs['arch_kwargs'], - architecture_kwargs['_kw_requires_import'], - ) - _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate - - # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was - # executed. 
If not, additional vram headroom is used to increase batch size - ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d - batch_size = round((reference / estimate) * ref_bs) - - # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. We cannot - # go smaller than self.UNet_min_batch_size though - bs_corresponding_to_5_percent = round( - approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64)) - batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) - - resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() - resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn() - - normalization_schemes, mask_is_used_for_norm = \ - self.determine_normalization_scheme_and_whether_mask_is_used_for_norm() - - plan = { - 'data_identifier': data_identifier, - 'preprocessor_name': self.preprocessor_name, - 'batch_size': batch_size, - 'patch_size': patch_size, - 'median_image_size_in_voxels': median_shape, - 'spacing': spacing, - 'normalization_schemes': normalization_schemes, - 'use_mask_for_norm': mask_is_used_for_norm, - 'resampling_fn_data': resampling_data.__name__, - 'resampling_fn_seg': resampling_seg.__name__, - 'resampling_fn_data_kwargs': resampling_data_kwargs, - 'resampling_fn_seg_kwargs': resampling_seg_kwargs, - 'resampling_fn_probabilities': resampling_softmax.__name__, - 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs, - 'architecture': architecture_kwargs - } - return plan diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/__init__.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/__init__.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py deleted file mode 100644 index 2001ed69f..000000000 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerL.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Union, List, Tuple - -from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet - -from nnunetv2.experiment_planning.experiment_planners.residual_unets.ResEncUNet_planner import ResEncUNetPlanner - - -class nnUNetPlannerL(ResEncUNetPlanner): - """ - Target is ~24 GB VRAM max -> RTX 4090, Titan RTX, Quadro 6000 - """ - def __init__(self, dataset_name_or_id: Union[str, int], - gpu_memory_target_in_gb: float = 24, - preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans', - overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, - suppress_transpose: bool = False): - gpu_memory_target_in_gb = 24 - super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, - overwrite_target_spacing, suppress_transpose) - self.UNet_class = ResidualEncoderUNet - - self.UNet_vram_target_GB = gpu_memory_target_in_gb - self.UNet_reference_val_corresp_GB = 24 - - 
self.UNet_reference_val_3d = 2100000000 # 1840000000 - self.UNet_reference_val_2d = 380000000 # 352666667 - self.max_dataset_covered = 1 - diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py deleted file mode 100644 index d7b4b87af..000000000 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerM.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Union, List, Tuple - -from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet - -from nnunetv2.experiment_planning.experiment_planners.residual_unets.ResEncUNet_planner import ResEncUNetPlanner - - -class nnUNetPlannerM(ResEncUNetPlanner): - """ - Target is ~9-11 GB VRAM max -> older Titan, RTX 2080ti - """ - def __init__(self, dataset_name_or_id: Union[str, int], - gpu_memory_target_in_gb: float = 8, - preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetMPlans', - overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, - suppress_transpose: bool = False): - gpu_memory_target_in_gb = 8 - super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, - overwrite_target_spacing, suppress_transpose) - self.UNet_class = ResidualEncoderUNet - - self.UNet_vram_target_GB = gpu_memory_target_in_gb - self.UNet_reference_val_corresp_GB = 8 - - # this is supposed to give the same GPU memory requirement as the default nnU-Net - self.UNet_reference_val_3d = 680000000 - self.UNet_reference_val_2d = 135000000 - self.max_dataset_covered = 1 - diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py deleted file mode 100644 index 7f59ab81f..000000000 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXL.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Union, List, Tuple - -from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet - -from nnunetv2.experiment_planning.experiment_planners.residual_unets.ResEncUNet_planner import ResEncUNetPlanner - - -class nnUNetPlannerXL(ResEncUNetPlanner): - """ - Target is 40 GB VRAM max -> A100 40GB, RTX 6000 Ada Generation - """ - def __init__(self, dataset_name_or_id: Union[str, int], - gpu_memory_target_in_gb: float = 40, - preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetXLPlans', - overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, - suppress_transpose: bool = False): - gpu_memory_target_in_gb = 40 - super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, - overwrite_target_spacing, suppress_transpose) - self.UNet_class = ResidualEncoderUNet - - self.UNet_vram_target_GB = gpu_memory_target_in_gb - self.UNet_reference_val_corresp_GB = 40 - - self.UNet_reference_val_3d = 3600000000 - self.UNet_reference_val_2d = 560000000 - self.max_dataset_covered = 1 - - diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXLx8.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXLx8.py deleted file mode 100644 index c8b1d7631..000000000 --- 
a/nnunetv2/experiment_planning/experiment_planners/residual_unets/new_nnunet_presets/nnUNetPlannerXLx8.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from typing import Union, List, Tuple
-
-from nnunetv2.experiment_planning.experiment_planners.residual_unets.new_nnunet_presets.nnUNetPlannerXL import \
-    nnUNetPlannerXL
-
-
-class nnUNetPlannerXLx8(nnUNetPlannerXL):
-    """
-    Target is 8*40 GB VRAM max -> 8xA100 40GB or 4*A100 80GB
-    """
-    def __init__(self, dataset_name_or_id: Union[str, int],
-                 gpu_memory_target_in_gb: float = 40,  # this needs to be 40 as we plan for the same size per GPU as XL
-                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetXLx8Plans',
-                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
-                 suppress_transpose: bool = False):
-        gpu_memory_target_in_gb = 40
-        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
-                         overwrite_target_spacing, suppress_transpose)
-
-    def plan_experiment(self):
-        print('DO NOT TRUST ANY PRINTED PLANS AS THE BATCH SIZE WILL NOT YET HAVE BEEN INCREASED! FINAL BATCH SIZE IS '
-              '8x OF WHAT YOU SEE')
-        super(nnUNetPlannerXLx8, self).plan_experiment()
-        for configuration in ['2d', '3d_fullres', '3d_lowres']:
-            if configuration in self.plans['configurations']:
-                self.plans['configurations'][configuration]['batch_size'] *= 8
-        self.save_plans(self.plans)
-        return self.plans
diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/__init__.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerLmoreFilt.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerLmoreFilt.py
deleted file mode 100644
index d1c2936ed..000000000
--- a/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerLmoreFilt.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from typing import Union, List, Tuple
-
-from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet
-
-from nnunetv2.experiment_planning.experiment_planners.residual_unets.ResEncUNet_planner import ResEncUNetPlanner
-
-
-class nnUNetPlannerLmoreFilt(ResEncUNetPlanner):
-    """
-    Target is ~24 GB VRAM max -> RTX 4090, Titan RTX, Quadro 6000
-    """
-    def __init__(self, dataset_name_or_id: Union[str, int],
-                 gpu_memory_target_in_gb: float = 24,
-                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLmoreFiltPlans',
-                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
-                 suppress_transpose: bool = False):
-        gpu_memory_target_in_gb = 24
-        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
-                         overwrite_target_spacing, suppress_transpose)
-        self.UNet_class = ResidualEncoderUNet
-
-        self.UNet_vram_target_GB = gpu_memory_target_in_gb
-        self.UNet_reference_val_corresp_GB = 24
-        self.UNet_base_num_features = 48
-        self.UNet_max_features_3d = self.UNet_base_num_features * 2 ** 4
-
-        self.UNet_reference_val_3d = 1900000000  # 1840000000
-        self.UNet_reference_val_2d = 370000000  # 352666667
-
diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLmoreFilt.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLmoreFilt.py
deleted file mode 100644
index fbdd0e6d7..000000000
--- a/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLmoreFilt.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from typing import Union, List, Tuple
-
-from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet
-
-from nnunetv2.experiment_planning.experiment_planners.residual_unets.ResEncUNet_planner import ResEncUNetPlanner
-
-
-class nnUNetPlannerXLmoreFilt(ResEncUNetPlanner):
-    """
-    Target is 40 GB VRAM max -> A100 40GB, RTX 6000 Ada Generation
-    """
-    def __init__(self, dataset_name_or_id: Union[str, int],
-                 gpu_memory_target_in_gb: float = 40,
-                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetXLmoreFiltPlans',
-                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
-                 suppress_transpose: bool = False):
-        gpu_memory_target_in_gb = 40
-        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
-                         overwrite_target_spacing, suppress_transpose)
-        self.UNet_class = ResidualEncoderUNet
-
-        self.UNet_vram_target_GB = gpu_memory_target_in_gb
-        self.UNet_reference_val_corresp_GB = 40
-        self.UNet_base_num_features = 64
-        self.UNet_max_features_3d = self.UNet_base_num_features * 2 ** 4
-
-        self.UNet_reference_val_3d = 3100000000
-        self.UNet_reference_val_2d = 540000000
-
diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLx8moreFilt.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLx8moreFilt.py
deleted file mode 100644
index 86c83dc8d..000000000
--- a/nnunetv2/experiment_planning/experiment_planners/residual_unets_moreFilt/nnUNetPlannerXLx8moreFilt.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from typing import Union, List, Tuple
-
-from nnunetv2.experiment_planning.experiment_planners.residual_unets_moreFilt.nnUNetPlannerXLmoreFilt import \
-    nnUNetPlannerXLmoreFilt
-
-
-class nnUNetPlannerXLx8moreFilt(nnUNetPlannerXLmoreFilt):
-    """
-    Target is 8*40 GB VRAM max -> 8xA100 40GB or 4*A100 80GB
-    """
-    def __init__(self, dataset_name_or_id: Union[str, int],
-                 gpu_memory_target_in_gb: float = 40,  # this needs to be 40 as we plan for the same size per GPU as XL
-                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetXLx8moreFiltPlans',
-                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
-                 suppress_transpose: bool = False):
-        gpu_memory_target_in_gb = 40
-        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
-                         overwrite_target_spacing, suppress_transpose)
-
-    def plan_experiment(self):
-        print('DO NOT TRUST ANY PRINTED PLANS AS THE BATCH SIZE WILL NOT YET HAVE BEEN INCREASED! 
FINAL BATCH SIZE IS ' - '8x OF WHAT YOU SEE') - super(nnUNetPlannerXLmoreFilt, self).plan_experiment() - for configuration in ['2d', '3d_fullres', '3d_lowres']: - if configuration in self.plans['configurations']: - self.plans['configurations'][configuration]['batch_size'] *= 8 - self.save_plans(self.plans) - return self.plans From 900e0b1c49fa8b3cbd9fb7b66b3a5c38b0233a28 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Mon, 19 Feb 2024 18:52:21 +0100 Subject: [PATCH 23/24] fix dynamic network arch imports --- .../experiment_planners/default_experiment_planner.py | 1 - .../experiment_planners/resencUNet_planner.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py index c31f9d86b..8d512c4f7 100644 --- a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py @@ -6,7 +6,6 @@ import torch from batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p -from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet from dynamic_network_architectures.architectures.unet import PlainConvUNet from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm diff --git a/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py index f06cc4245..50993b0b2 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py @@ -1,6 +1,6 @@ from typing import Union, List, Tuple -from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet +from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet from torch import nn from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner From e0137edc0a4145efaa5591b17154cb1c93a0d79e Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Mon, 19 Feb 2024 19:09:35 +0100 Subject: [PATCH 24/24] fix: patch size for 2d was unintentionally capped at 512 with the default configuration; feature: improved network architecture definition in plans. 
Backwards compatibility added in ConfigurationManager
---
 .../default_experiment_planner.py             |   8 +-
 .../experiment_planners/resencUNet_planner.py | 205 +++++++++++++++++-
 .../utilities/plans_handling/plans_handler.py |   3 +-
 3 files changed, 200 insertions(+), 16 deletions(-)

diff --git a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py
index 8d512c4f7..798a12f55 100644
--- a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py
+++ b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py
@@ -316,8 +316,11 @@ def _keygen(patch_size, strides):
         reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \
                     (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB)
 
-        while estimate > reference:
-            # print(patch_size)
+        ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d
+        # we enforce a batch size of at least two; the reference values may have been computed for different batch
+        # sizes. We correct for that in the while loop condition
+        while (estimate / ref_bs * 2) > reference:
+            # print(patch_size, estimate, reference)
             # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the
             # aspect ratio the most (that is the largest relative to median shape)
             axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1]
@@ -367,7 +370,6 @@ def _keygen(patch_size, strides):
 
         # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was
         # executed. If not, additional vram headroom is used to increase batch size
-        ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d
         batch_size = round((reference / estimate) * ref_bs)
 
         # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. 
We cannot diff --git a/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py b/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py index 50993b0b2..a7b2d2305 100644 --- a/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py @@ -1,10 +1,15 @@ +import numpy as np +from copy import deepcopy from typing import Union, List, Tuple from dynamic_network_architectures.architectures.residual_unet import ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm from torch import nn from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner +from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props + class ResEncUNetPlanner(ExperimentPlanner): def __init__(self, dataset_name_or_id: Union[str, int], @@ -14,23 +19,200 @@ def __init__(self, dataset_name_or_id: Union[str, int], suppress_transpose: bool = False): super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, overwrite_target_spacing, suppress_transpose) - - self.UNet_base_num_features = 32 self.UNet_class = ResidualEncoderUNet # the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as # much as possible self.UNet_reference_val_3d = 680000000 self.UNet_reference_val_2d = 135000000 - self.UNet_reference_com_nfeatures = 32 - self.UNet_reference_val_corresp_GB = 8 - self.UNet_reference_val_corresp_bs_2d = 12 - self.UNet_reference_val_corresp_bs_3d = 2 - self.UNet_featuremap_min_edge_length = 4 self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6) self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) - self.UNet_min_batch_size = 2 - self.UNet_max_features_2d = 512 - self.UNet_max_features_3d = 320 + + def generate_data_identifier(self, configuration_name: str) -> str: + """ + configurations are unique within each plans file but different plans file can have configurations with the + same name. In order to distinguish the associated data we need a data identifier that reflects not just the + config but also the plans it originates from + """ + if configuration_name == '2d' or configuration_name == '3d_fullres': + # we do not deviate from ExperimentPlanner so we can reuse its data + return 'nnUNetPlans' + '_' + configuration_name + else: + return self.plans_identifier + '_' + configuration_name + + def get_plans_for_configuration(self, + spacing: Union[np.ndarray, Tuple[float, ...], List[float]], + median_shape: Union[np.ndarray, Tuple[int, ...]], + data_identifier: str, + approximate_n_voxels_dataset: float, + _cache: dict) -> dict: + def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: + return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for + i in range(num_stages)]) + + def _keygen(patch_size, strides): + return str(patch_size) + '_' + str(strides) + + assert all([i > 0 for i in spacing]), f"Spacing must be > 0! 
Spacing: {spacing}" + num_input_channels = len(self.dataset_json['channel_names'].keys() + if 'channel_names' in self.dataset_json.keys() + else self.dataset_json['modality'].keys()) + max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d + unet_conv_op = convert_dim_to_conv_op(len(spacing)) + + # print(spacing, median_shape, approximate_n_voxels_dataset) + # find an initial patch size + # we first use the spacing to get an aspect ratio + tmp = 1 / np.array(spacing) + + # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same + # volume as a patch of size 256 ** 3) + # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be + # ideal because large initial patch sizes increase computation time because more iterations in the while loop + # further down may be required. + if len(spacing) == 3: + initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)] + elif len(spacing) == 2: + initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)] + else: + raise RuntimeError() + + # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that + # this is different from how nnU-Net v1 does it! + # todo patch size can still get too large because we pad the patch size to a multiple of 2**n + initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])]) + + # use that to get the network topology. Note that this changes the patch_size depending on the number of + # pooling operations (must be divisible by 2**num_pool in each axis) + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + num_stages = len(pool_op_kernel_sizes) + + norm = get_matching_instancenorm(unet_conv_op) + architecture_kwargs = { + 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, + 'arch_kwargs': { + 'n_stages': num_stages, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + 'conv_bias': True, + 'norm_op': norm.__module__ + '.' + norm.__name__, + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, + 'dropout_op_kwargs': None, + 'nonlin': 'torch.nn.LeakyReLU', + 'nonlin_kwargs': {'inplace': True}, + }, + '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), + } + + # now estimate vram consumption + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] + else: + estimate = self.static_estimate_VRAM_usage(patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + + # how large is the reference for us here (batch size etc)? 
+ # adapt for our vram target + reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \ + (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) + + while estimate > reference: + # print(patch_size) + # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the + # aspect ratio the most (that is the largest relative to median shape) + axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1] + + # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this + # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256. + # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size + # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first + # subtract shape_must_be_divisible_by, then recompute it and then subtract the + # recomputed shape_must_be_divisible_by. Annoying. + patch_size = list(patch_size) + tmp = deepcopy(patch_size) + tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + _, _, _, _, shape_must_be_divisible_by = \ + get_pool_and_conv_props(spacing, tmp, + self.UNet_featuremap_min_edge_length, + 999999) + patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + + # now recompute topology + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + + num_stages = len(pool_op_kernel_sizes) + architecture_kwargs['arch_kwargs'].update({ + 'n_stages': num_stages, + 'kernel_sizes': conv_kernel_sizes, + 'strides': pool_op_kernel_sizes, + 'features_per_stage': _features_per_stage(num_stages, max_num_features), + 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + }) + if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): + estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] + else: + estimate = self.static_estimate_VRAM_usage( + patch_size, + num_input_channels, + len(self.dataset_json['labels'].keys()), + architecture_kwargs['network_class_name'], + architecture_kwargs['arch_kwargs'], + architecture_kwargs['_kw_requires_import'], + ) + _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate + + # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was + # executed. If not, additional vram headroom is used to increase batch size + ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d + batch_size = round((reference / estimate) * ref_bs) + + # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. 
We cannot
+        # go smaller than self.UNet_min_batch_size though
+        bs_corresponding_to_5_percent = round(
+            approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64))
+        batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size)
+
+        resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling()
+        resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn()
+
+        normalization_schemes, mask_is_used_for_norm = \
+            self.determine_normalization_scheme_and_whether_mask_is_used_for_norm()
+
+        plan = {
+            'data_identifier': data_identifier,
+            'preprocessor_name': self.preprocessor_name,
+            'batch_size': batch_size,
+            'patch_size': patch_size,
+            'median_image_size_in_voxels': median_shape,
+            'spacing': spacing,
+            'normalization_schemes': normalization_schemes,
+            'use_mask_for_norm': mask_is_used_for_norm,
+            'resampling_fn_data': resampling_data.__name__,
+            'resampling_fn_seg': resampling_seg.__name__,
+            'resampling_fn_data_kwargs': resampling_data_kwargs,
+            'resampling_fn_seg_kwargs': resampling_seg_kwargs,
+            'resampling_fn_probabilities': resampling_softmax.__name__,
+            'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs,
+            'architecture': architecture_kwargs
+        }
+        return plan
 
 
 if __name__ == '__main__':
@@ -50,5 +232,4 @@ def __init__(self, dataset_name_or_id: Union[str, int],
                               n_conv_per_stage_decoder=(1, 1, 1, 1, 1, 1), conv_bias=True, norm_op=nn.InstanceNorm2d,
                               norm_op_kwargs={}, dropout_op=None, nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True},
                               deep_supervision=True)
-    print(net.compute_conv_feature_map_size((512, 512)))  # -> 129793792
-
+    print(net.compute_conv_feature_map_size((512, 512)))  # -> 129793792
\ No newline at end of file
diff --git a/nnunetv2/utilities/plans_handling/plans_handler.py b/nnunetv2/utilities/plans_handling/plans_handler.py
index 079de6447..518a462f9 100644
--- a/nnunetv2/utilities/plans_handling/plans_handler.py
+++ b/nnunetv2/utilities/plans_handling/plans_handler.py
@@ -36,7 +36,8 @@ def __init__(self, configuration_dict: dict):
         if 'architecture' not in self.configuration.keys():
             warnings.warn("Detected old nnU-Net plans format. Attempting to reconstruct network architecture "
                           "parameters. If this fails, rerun nnUNetv2_plan_experiment for your dataset. If you use a "
-                          "custom architecture, please downgrade nnU-Net or update your plans.")
+                          "custom architecture, please downgrade nnU-Net to v2.3 "
+                          "(https://github.com/MIC-DKFZ/nnUNet/releases/tag/v2.3) or update your plans.")
             # try to build the architecture information from old plans, modify configuration dict to match new standard
             unet_class_name = self.configuration["UNet_class_name"]
             if unet_class_name == "PlainConvUNet":
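
Note on the new plans format introduced by this series: each configuration now carries a self-contained 'architecture' entry ('network_class_name', 'arch_kwargs', '_kw_requires_import') instead of the old 'UNet_class_name' field. The keys listed under '_kw_requires_import' store dotted import paths as strings (e.g. 'torch.nn.LeakyReLU') and must be resolved to actual objects before the network can be instantiated. The sketch below illustrates that resolution step only; the helper name and the constructor keywords `input_channels` and `num_classes` are assumptions for illustration, and the actual loader (`new_get_network` in `nnunetv2.utilities.get_network_from_plans`) may differ in detail.

```python
import pydoc


def build_network_from_arch_entry(arch: dict, input_channels: int, num_classes: int):
    """Minimal sketch: turn a plans 'architecture' entry into a network instance.

    Assumes the target class accepts `input_channels` and `num_classes` as keyword
    arguments in addition to everything in `arch_kwargs`. Illustrative only.
    """
    # resolve the network class from its fully qualified name
    network_class = pydoc.locate(arch['network_class_name'])
    kwargs = dict(arch['arch_kwargs'])
    # entries such as 'conv_op' or 'norm_op' are serialized as dotted strings,
    # e.g. 'torch.nn.modules.conv.Conv3d', and must be imported before use
    for key in arch['_kw_requires_import']:
        if kwargs.get(key) is not None:
            kwargs[key] = pydoc.locate(kwargs[key])
    return network_class(input_channels=input_channels, num_classes=num_classes, **kwargs)
```

Because everything is stored as plain strings and JSON-serializable values, custom architectures no longer require code changes in nnU-Net itself, only a plans entry pointing at an importable class.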
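The batch size determination that closes get_plans_for_configuration is compact enough to spell out: the VRAM headroom left after the patch-size loop (reference / estimate) scales the reference batch size ref_bs, the result is capped so that a single batch covers at most max_dataset_covered (5%) of the dataset's voxels, and it is floored at UNet_min_batch_size. A self-contained sketch with made-up input numbers; only the formula mirrors the planner code above:

```python
import numpy as np


def determine_batch_size(reference: float, estimate: float, ref_bs: int,
                         patch_size, n_voxels_dataset: float,
                         max_dataset_covered: float = 0.05,
                         min_batch_size: int = 2) -> int:
    # VRAM headroom relative to the reference configuration scales the batch size
    batch_size = round((reference / estimate) * ref_bs)
    # one batch may cover at most ~5% of the dataset's voxels (overfitting precaution)
    cap = round(n_voxels_dataset * max_dataset_covered /
                np.prod(patch_size, dtype=np.float64))
    return max(min(batch_size, cap), min_batch_size)


# illustrative values only: 1.5x headroom over a reference batch size of 2 gives 3,
# well below the 5% cap of ~48 for this patch size and dataset size
print(determine_batch_size(680e6, 450e6, 2, (128, 128, 128), 2e9))  # -> 3
```

This also clarifies the reworked while loop condition `(estimate / ref_bs * 2) > reference`: the reference values correspond to different batch sizes in 2D and 3D, so the estimate is normalized to the enforced minimum batch size of two before being compared against the reference.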