diff --git a/nnUNetTrainer/nnUNetTrainer_NexToU.py b/nnUNetTrainer/nnUNetTrainer_NexToU.py index 0fc6159..21b6fd6 100644 --- a/nnUNetTrainer/nnUNetTrainer_NexToU.py +++ b/nnUNetTrainer/nnUNetTrainer_NexToU.py @@ -65,9 +65,9 @@ def build_network_architecture(plans_manager: PlansManager, network_class = mapping[segmentation_network_class_name] conv_or_blocks_per_stage = { - 'n_conv_per_stage' + 'n_blocks_per_stage' if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': configuration_manager.n_conv_per_stage_encoder, - 'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder + 'n_blocks_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder } # network class name!!