From 39650786450df7acc136d4e644b0c16039dafcb1 Mon Sep 17 00:00:00 2001 From: Pengcheng Shi Date: Fri, 29 Mar 2024 23:04:48 +0800 Subject: [PATCH] Delete nnUNetTrainer/nnUNetTrainer_NexToU_Res.py --- nnUNetTrainer/nnUNetTrainer_NexToU_Res.py | 92 ----------------------- 1 file changed, 92 deletions(-) delete mode 100644 nnUNetTrainer/nnUNetTrainer_NexToU_Res.py diff --git a/nnUNetTrainer/nnUNetTrainer_NexToU_Res.py b/nnUNetTrainer/nnUNetTrainer_NexToU_Res.py deleted file mode 100644 index a6e69de..0000000 --- a/nnUNetTrainer/nnUNetTrainer_NexToU_Res.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -import torch -from torch import autocast, nn -from typing import Union, Tuple, List -from torch import distributed as dist -from torch.nn.parallel import DistributedDataParallel as DDP - -from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet, PlainConvUNet -from nnunetv2.training.nnUNetTrainer.variants.network_architecture.NexToU_Res import NexToU_Res -from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_batchnorm -from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0, InitWeights_He -from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer -from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager -from nnunetv2.utilities.get_network_from_plans import get_network_from_plans -from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels - -class nnUNetTrainer_NexToU_Res(nnUNetTrainer): - @staticmethod - def build_network_architecture(plans_manager: PlansManager, - dataset_json, - configuration_manager: ConfigurationManager, - num_input_channels, - enable_deep_supervision: bool = True) -> nn.Module: - num_stages = len(configuration_manager.conv_kernel_sizes) - - dim = len(configuration_manager.conv_kernel_sizes[0]) - conv_op = convert_dim_to_conv_op(dim) - - label_manager = plans_manager.get_label_manager(dataset_json) - - segmentation_network_class_name = 'NexToU_Res' #configuration_manager.UNet_class_name - mapping = { - 'PlainConvUNet': PlainConvUNet, - 'ResidualEncoderUNet': ResidualEncoderUNet, - 'NexToU_Res': NexToU_Res - } - kwargs = { - 'PlainConvUNet': { - 'conv_bias': True, - 'norm_op': get_matching_batchnorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - }, - 'ResidualEncoderUNet': { - 'conv_bias': True, - 'norm_op': get_matching_batchnorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - }, - 'NexToU_Res': { - 'conv_bias': True, - 'norm_op': get_matching_batchnorm(conv_op), - 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, - 'dropout_op': None, 'dropout_op_kwargs': None, - 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, - } - } - assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ - 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ - 'into either this ' \ - 'function (get_network_from_plans) or ' \ - 'the init of your nnUNetModule to accomodate that.' - network_class = mapping[segmentation_network_class_name] - - conv_or_blocks_per_stage = { - 'n_blocks_per_stage' - if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': configuration_manager.n_conv_per_stage_encoder, - 'n_blocks_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder - } - - # network class name!! - model = network_class( - input_channels=num_input_channels, - patch_size=configuration_manager.patch_size, - n_stages=num_stages, - features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i, - configuration_manager.unet_max_num_features) for i in range(num_stages)], - conv_op=conv_op, - kernel_sizes=configuration_manager.conv_kernel_sizes, - strides=configuration_manager.pool_op_kernel_sizes, - num_classes=label_manager.num_segmentation_heads, - deep_supervision=enable_deep_supervision, - **conv_or_blocks_per_stage, - **kwargs[segmentation_network_class_name] - ) - model.apply(InitWeights_He(1e-2)) - if network_class == ResidualEncoderUNet: - model.apply(init_last_bn_before_add_to_0) - return model -