From 1315d2f05a2797438b19be6ceae71de48f2b8452 Mon Sep 17 00:00:00 2001
From: Fabian Isensee
Date: Tue, 12 Dec 2023 11:48:54 +0100
Subject: [PATCH] cleanup inference code, better verbose messages

---
 nnunetv2/inference/predict_from_raw_data.py   | 213 +++++++++---------
 .../training/nnUNetTrainer/nnUNetTrainer.py   |   6 +-
 2 files changed, 107 insertions(+), 112 deletions(-)

diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py
index 713a02954..e57156467 100644
--- a/nnunetv2/inference/predict_from_raw_data.py
+++ b/nnunetv2/inference/predict_from_raw_data.py
@@ -40,7 +40,7 @@ def __init__(self,
                  tile_step_size: float = 0.5,
                  use_gaussian: bool = True,
                  use_mirroring: bool = True,
-                 perform_everything_on_gpu: bool = True,
+                 perform_everything_on_device: bool = True,
                  device: torch.device = torch.device('cuda'),
                  verbose: bool = False,
                  verbose_preprocessing: bool = False,
@@ -60,10 +60,10 @@ def __init__(self,
             # why would I ever want to do that. Stupid dobby. This kills DDP inference...
             pass
         if device.type != 'cuda':
-            print(f'perform_everything_on_gpu=True is only supported for cuda devices! Setting this to False')
-            perform_everything_on_gpu = False
+            print(f'perform_everything_on_device=True is only supported for cuda devices! Setting this to False')
+            perform_everything_on_device = False
         self.device = device
-        self.perform_everything_on_gpu = perform_everything_on_gpu
+        self.perform_everything_on_device = perform_everything_on_device
 
     def initialize_from_trained_model_folder(self, model_training_output_dir: str,
                                              use_folds: Union[Tuple[Union[int, str]], None],
@@ -111,7 +111,7 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
         self.label_manager = plans_manager.get_label_manager(dataset_json)
         if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
                 and not isinstance(self.network, OptimizedModule):
-            print('compiling network')
+            print('Using torch.compile')
             self.network = torch.compile(self.network)
 
     def manual_initialization(self, network: nn.Module, plans_manager: PlansManager,
@@ -135,7 +135,7 @@ def manual_initialization(self, network: nn.Module, plans_manager: PlansManager,
             if isinstance(self.network, DistributedDataParallel):
                 allow_compile = allow_compile and isinstance(self.network.module, OptimizedModule)
             if allow_compile:
-                print('compiling network')
+                print('Using torch.compile')
                 self.network = torch.compile(self.network)
 
     @staticmethod
@@ -353,7 +353,7 @@ def predict_from_data_iterator(self,
             else:
                 print(f'\nPredicting image of shape {data.shape}:')
 
-            print(f'perform_everything_on_gpu: {self.perform_everything_on_gpu}')
+            print(f'perform_everything_on_device: {self.perform_everything_on_device}')
 
             properties = preprocessed['data_properties']
 
@@ -454,56 +454,33 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Ten
         RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE.
         SEE convert_predicted_logits_to_segmentation_with_correct_shape
         """
-        # we have some code duplication here but this allows us to run with perform_everything_on_gpu=True as
-        # default and not have the entire program crash in case of GPU out of memory. Neat. That should make
-        # things a lot faster for some datasets.
-        original_perform_everything_on_gpu = self.perform_everything_on_gpu
+        n_threads = torch.get_num_threads()
+        torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads)
         with torch.no_grad():
             prediction = None
-            if self.perform_everything_on_gpu:
-                try:
-                    for params in self.list_of_parameters:
-
-                        # messing with state dict names...
-                        if not isinstance(self.network, OptimizedModule):
-                            self.network.load_state_dict(params)
-                        else:
-                            self.network._orig_mod.load_state_dict(params)
-
-                        if prediction is None:
-                            prediction = self.predict_sliding_window_return_logits(data)
-                        else:
-                            prediction += self.predict_sliding_window_return_logits(data)
-
-                    if len(self.list_of_parameters) > 1:
-                        prediction /= len(self.list_of_parameters)
-
-                except RuntimeError:
-                    print('Prediction with perform_everything_on_gpu=True failed due to insufficient GPU memory. '
-                          'Falling back to perform_everything_on_gpu=False. Not a big deal, just slower...')
-                    print('Error:')
-                    traceback.print_exc()
-                    prediction = None
-                    self.perform_everything_on_gpu = False
-
-            if prediction is None:
-                for params in self.list_of_parameters:
-                    # messing with state dict names...
-                    if not isinstance(self.network, OptimizedModule):
-                        self.network.load_state_dict(params)
-                    else:
-                        self.network._orig_mod.load_state_dict(params)
-
-                    if prediction is None:
-                        prediction = self.predict_sliding_window_return_logits(data)
-                    else:
-                        prediction += self.predict_sliding_window_return_logits(data)
-                if len(self.list_of_parameters) > 1:
-                    prediction /= len(self.list_of_parameters)
-
-            print('Prediction done, transferring to CPU if needed')
+
+            for params in self.list_of_parameters:
+
+                # messing with state dict names...
+                if not isinstance(self.network, OptimizedModule):
+                    self.network.load_state_dict(params)
+                else:
+                    self.network._orig_mod.load_state_dict(params)
+
+                # why not leave prediction on device if perform_everything_on_device? Because this may cause the
+                # second iteration to crash due to OOM. Grabbing that with try except causes way more bloated code than
+                # this actually saves computation time
+                if prediction is None:
+                    prediction = self.predict_sliding_window_return_logits(data).to('cpu')
+                else:
+                    prediction += self.predict_sliding_window_return_logits(data).to('cpu')
+
+            if len(self.list_of_parameters) > 1:
+                prediction /= len(self.list_of_parameters)
+
+            if self.verbose: print('Prediction done')
             prediction = prediction.to('cpu')
-        self.perform_everything_on_gpu = original_perform_everything_on_gpu
+        torch.set_num_threads(n_threads)
         return prediction
 
     def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]):
@@ -557,6 +534,48 @@ def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
         prediction /= (len(axes_combinations) + 1)
         return prediction
 
+    def _internal_predict_sliding_window_return_logits(self,
+                                                       data: torch.Tensor,
+                                                       slicers,
+                                                       do_on_device: bool = True,
+                                                       ):
+        results_device = self.device if do_on_device else torch.device('cpu')
+
+        # move data to device
+        if self.verbose: print(f'move image to device {results_device}')
+        data = data.to(self.device)
+
+        # preallocate arrays
+        if self.verbose: print(f'preallocating results arrays on device {results_device}')
+        predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
+                                       dtype=torch.half,
+                                       device=results_device)
+        n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
+        if self.use_gaussian:
+            gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
+                                        value_scaling_factor=10,
+                                        device=results_device)
+        empty_cache(self.device)
+
+        if self.verbose: print('running prediction')
+        if not self.allow_tqdm and self.verbose: print(f'{len(slicers)} steps')
+        for sl in tqdm(slicers, disable=not self.allow_tqdm):
+            workon = data[sl][None]
+            workon = workon.to(self.device, non_blocking=False)
+
+            prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)
+
+            predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction)
+            n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1)
+
+        predicted_logits /= n_predictions
+        # check for infs
+        if torch.any(torch.isinf(predicted_logits)):
+            raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
+                               'reduce value_scaling_factor in compute_gaussian or increase the dtype of '
+                               'predicted_logits to fp32')
+        return predicted_logits
+
     def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
             -> Union[np.ndarray, torch.Tensor]:
         assert isinstance(input_image, torch.Tensor)
@@ -586,54 +605,21 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
 
                 slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
 
-                # preallocate results and num_predictions
-                results_device = self.device if self.perform_everything_on_gpu else torch.device('cpu')
-                if self.verbose: print('preallocating arrays')
-                try:
-                    data = data.to(self.device)
-                    predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
-                                                   dtype=torch.half,
-                                                   device=results_device)
-                    n_predictions = torch.zeros(data.shape[1:], dtype=torch.half,
-                                                device=results_device)
-                    if self.use_gaussian:
-                        gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
-                                                    value_scaling_factor=10,
-                                                    device=results_device)
-                except RuntimeError:
-                    # sometimes the stuff is too large for GPUs. In that case fall back to CPU
-                    results_device = torch.device('cpu')
-                    data = data.to(results_device)
-                    predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
-                                                   dtype=torch.half,
-                                                   device=results_device)
-                    n_predictions = torch.zeros(data.shape[1:], dtype=torch.half,
-                                                device=results_device)
-                    if self.use_gaussian:
-                        gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
-                                                    value_scaling_factor=10,
-                                                    device=results_device)
-                finally:
-                    empty_cache(self.device)
-
-                if self.verbose: print('running prediction')
-                for sl in tqdm(slicers, disable=not self.allow_tqdm):
-                    workon = data[sl][None]
-                    workon = workon.to(self.device, non_blocking=False)
-
-                    prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)
-
-                    predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction)
-                    n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1)
-
-                predicted_logits /= n_predictions
-                # check for infs
-                if torch.any(torch.isinf(predicted_logits)):
-                    raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
-                                       'reduce value_scaling_factor in compute_gaussian or increase the dtype of '
-                                       'predicted_logits to fp32')
-                empty_cache(self.device)
-                return predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])]
+                if self.perform_everything_on_device and self.device != 'cpu':
+                    # we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device
+                    try:
+                        predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, self.perform_everything_on_device)
+                    except RuntimeError:
+                        print('Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU')
+                        empty_cache(self.device)
+                        predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)
+                else:
+                    predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, self.perform_everything_on_device)
+
+                empty_cache(self.device)
+                # revert padding
+                predicted_logits = predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])]
+        return predicted_logits
 
 
 def predict_entry_point_modelfolder():
@@ -681,6 +667,10 @@ def predict_entry_point_modelfolder():
                         help="Use this to set the device the inference should run with. Available options are 'cuda' "
                              "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
                              "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!")
+    parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False,
+                        help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive '
+                             'jobs)')
+
 
     print(
         "\n#######################################################################\nPlease cite the following paper "
@@ -713,9 +703,10 @@ def predict_entry_point_modelfolder():
     predictor = nnUNetPredictor(tile_step_size=args.step_size,
                                 use_gaussian=True,
                                 use_mirroring=not args.disable_tta,
-                                perform_everything_on_gpu=True,
+                                perform_everything_on_device=True,
                                 device=device,
-                                verbose=args.verbose)
+                                verbose=args.verbose,
+                                allow_tqdm=not args.disable_progress_bar)
     predictor.initialize_from_trained_model_folder(args.m, args.f, args.chk)
     predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities,
                                  overwrite=not args.continue_prediction,
@@ -785,6 +776,9 @@ def predict_entry_point():
                         help="Use this to set the device the inference should run with. Available options are 'cuda' "
                              "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
                              "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!")
+    parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False,
+                        help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive '
+                             'jobs)')
 
     print(
         "\n#######################################################################\nPlease cite the following paper "
@@ -822,10 +816,11 @@ def predict_entry_point():
     predictor = nnUNetPredictor(tile_step_size=args.step_size,
                                 use_gaussian=True,
                                 use_mirroring=not args.disable_tta,
-                                perform_everything_on_gpu=True,
+                                perform_everything_on_device=True,
                                 device=device,
                                 verbose=args.verbose,
-                                verbose_preprocessing=False)
+                                verbose_preprocessing=False,
+                                allow_tqdm=not args.disable_progress_bar)
     predictor.initialize_from_trained_model_folder(
         model_folder,
         args.f,
@@ -845,7 +840,7 @@ def predict_entry_point():
    #                       args.step_size,
    #                       use_gaussian=True,
    #                       use_mirroring=not args.disable_tta,
-   #                       perform_everything_on_gpu=True,
+   #                       perform_everything_on_device=True,
    #                       verbose=args.verbose,
    #                       save_probabilities=args.save_probabilities,
    #                       overwrite=not args.continue_prediction,
@@ -865,7 +860,7 @@ def predict_entry_point():
         tile_step_size=0.5,
         use_gaussian=True,
         use_mirroring=True,
-        perform_everything_on_gpu=True,
+        perform_everything_on_device=True,
         device=torch.device('cuda', 0),
         verbose=False,
         verbose_preprocessing=False,
@@ -895,7 +890,7 @@ def predict_entry_point():
    #     tile_step_size=0.5,
    #     use_gaussian=True,
    #     use_mirroring=True,
-   #     perform_everything_on_gpu=True,
+   #     perform_everything_on_device=True,
    #     device=torch.device('cuda', 0),
    #     verbose=False,
    #     allow_tqdm=True
diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
index 8bb0efaa2..27439b170 100644
--- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
+++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
@@ -211,7 +211,7 @@ def initialize(self):
             ).to(self.device)
             # compile network for free speedup
             if self._do_i_compile():
-                self.print_to_log_file('Compiling network...')
+                self.print_to_log_file('Using torch.compile...')
                 self.network = torch.compile(self.network)
 
             self.optimizer, self.lr_scheduler = self.configure_optimizers()
@@ -1174,9 +1174,9 @@ def perform_actual_validation(self, save_probabilities: bool = False):
                 try:
                     prediction = predictor.predict_sliding_window_return_logits(data)
                 except RuntimeError:
-                    predictor.perform_everything_on_gpu = False
+                    predictor.perform_everything_on_device = False
                     prediction = predictor.predict_sliding_window_return_logits(data)
-                    predictor.perform_everything_on_gpu = True
+                    predictor.perform_everything_on_device = True
 
                 prediction = prediction.cpu()
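
For reference, a minimal sketch of how the predictor is driven after this change. The model folder, fold
selection and input/output paths below are placeholders, and the checkpoint filename is only the usual
nnU-Net default; the constructor arguments and method calls mirror those shown in the diff above.

    import torch
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    # perform_everything_on_device replaces the old perform_everything_on_gpu flag;
    # allow_tqdm=False corresponds to the new --disable_progress_bar CLI flag
    predictor = nnUNetPredictor(tile_step_size=0.5,
                                use_gaussian=True,
                                use_mirroring=True,
                                perform_everything_on_device=True,
                                device=torch.device('cuda', 0),
                                verbose=False,
                                verbose_preprocessing=False,
                                allow_tqdm=False)

    # placeholder model folder and folds; 'checkpoint_final.pth' is the usual default checkpoint name
    predictor.initialize_from_trained_model_folder('/path/to/trained_model_folder',
                                                   use_folds=(0,),
                                                   checkpoint_name='checkpoint_final.pth')

    # placeholder input/output folders
    predictor.predict_from_files('/path/to/input_images', '/path/to/output_segmentations',
                                 save_probabilities=False, overwrite=True)

If the preallocated results arrays do not fit in GPU memory, predict_sliding_window_return_logits now falls
back to the CPU as results device on its own (see the try/except added above), so an out-of-memory error
during inference no longer aborts the prediction.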