From a7377539268f744881788aa15d7030f11da320fe Mon Sep 17 00:00:00 2001
From: Fabian Isensee
Date: Fri, 7 Jun 2024 13:07:39 +0200
Subject: [PATCH] allow resampling with torch

---
 .../residual_encoder_unet_planners.py      | 57 -------------------
 .../preprocessors/default_preprocessor.py  |  1 -
 2 files changed, 58 deletions(-)

diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py
index f7026e311..012950b82 100644
--- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py
+++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py
@@ -294,63 +294,6 @@ def __init__(self, dataset_name_or_id: Union[str, int],
         self.max_dataset_covered = 1
 
 
-class nnUNetPlannerResEncL_torchres(nnUNetPlannerResEncL):
-    def __init__(self, dataset_name_or_id: Union[str, int],
-                 gpu_memory_target_in_gb: float = 24,
-                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres',
-                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
-                 suppress_transpose: bool = False):
-        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
-                         overwrite_target_spacing, suppress_transpose)
-
-    def generate_data_identifier(self, configuration_name: str) -> str:
-        """
-        configurations are unique within each plans file but different plans file can have configurations with the
-        same name. In order to distinguish the associated data we need a data identifier that reflects not just the
-        config but also the plans it originates from
-        """
-        return self.plans_identifier + '_' + configuration_name
-
-    def determine_resampling(self, *args, **kwargs):
-        """
-        returns what functions to use for resampling data and seg, respectively. Also returns kwargs
-        resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
-
-        determine_resampling is called within get_plans_for_configuration to allow for different functions for each
-        configuration
-        """
-        resampling_data = resample_torch_fornnunet
-        resampling_data_kwargs = {
-            "is_seg": False,
-            'force_separate_z': False,
-            'memefficient_seg_resampling': False
-        }
-        resampling_seg = resample_torch_fornnunet
-        resampling_seg_kwargs = {
-            "is_seg": True,
-            'force_separate_z': False,
-            'memefficient_seg_resampling': False
-        }
-        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs
-
-    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
-        """
-        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
-        used as target. current_spacing and new_spacing are merely there in case we want to use it somehow
-
-        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
-        functions for each configuration
-
-        """
-        resampling_fn = resample_torch_fornnunet
-        resampling_fn_kwargs = {
-            "is_seg": False,
-            'force_separate_z': False,
-            'memefficient_seg_resampling': False
-        }
-        return resampling_fn, resampling_fn_kwargs
-
-
 if __name__ == '__main__':
     # we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively
     net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320),
diff --git a/nnunetv2/preprocessing/preprocessors/default_preprocessor.py b/nnunetv2/preprocessing/preprocessors/default_preprocessor.py
index 400d81acd..8b1abf7b2 100644
--- a/nnunetv2/preprocessing/preprocessors/default_preprocessor.py
+++ b/nnunetv2/preprocessing/preprocessors/default_preprocessor.py
@@ -260,7 +260,6 @@ def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plan
                     pbar.update()
                 remaining = [i for i in remaining if i not in done]
                 sleep(0.1)
-        _ = [i.get() for i in r]
 
     def modify_seg_fn(self, seg: np.ndarray, plans_manager: PlansManager, dataset_json: dict,
                       configuration_manager: ConfigurationManager) -> np.ndarray:
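Note: the planner removed above documents the resampling contract in its docstrings:
the data/seg resampler must be callable(data, current_spacing, new_spacing, **kwargs),
and the softmax export function must be callable(data, new_shape, current_spacing,
new_spacing, **kwargs). For illustration only, below is a minimal sketch of a
torch-based resampler satisfying the first contract. It is not the actual
resample_torch_fornnunet implementation (whose body does not appear in this patch);
the name resample_torch_sketch and the nearest/trilinear mode choice are assumptions,
and the force_separate_z / memefficient_seg_resampling options are ignored.

    import numpy as np
    import torch
    import torch.nn.functional as F

    def resample_torch_sketch(data: np.ndarray,
                              current_spacing,
                              new_spacing,
                              is_seg: bool = False,
                              **kwargs) -> np.ndarray:
        # data is expected as (c, x, y, z); hypothetical stand-in for
        # resample_torch_fornnunet, for illustration only
        old_shape = np.array(data.shape[1:])
        zoom = np.array(current_spacing, dtype=float) / np.array(new_spacing, dtype=float)
        new_shape = tuple(int(i) for i in np.round(old_shape * zoom))
        t = torch.from_numpy(data.astype(np.float32))[None]  # (1, c, x, y, z)
        if is_seg:
            # nearest neighbour so integer label values are preserved
            out = F.interpolate(t, size=new_shape, mode='nearest')
        else:
            out = F.interpolate(t, size=new_shape, mode='trilinear', align_corners=False)
        return out[0].numpy().astype(data.dtype)

The preprocessor invokes whatever the planner returns as fn(data, current_spacing,
new_spacing, **kwargs), which is why determine_resampling returns (function, kwargs)
pairs for data and seg rather than a single bound method.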