From d87fa5b84e138c536d1f5c2dba7be92856893554 Mon Sep 17 00:00:00 2001
From: ancestor-mithril <58839912+ancestor-mithril@users.noreply.github.com>
Date: Thu, 28 Mar 2024 22:18:23 +0200
Subject: [PATCH 1/8] Improved data iterators

---
 nnunetv2/inference/data_iterators.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/nnunetv2/inference/data_iterators.py b/nnunetv2/inference/data_iterators.py
index 1777fb934..068905d9f 100644
--- a/nnunetv2/inference/data_iterators.py
+++ b/nnunetv2/inference/data_iterators.py
@@ -38,7 +38,7 @@ def preprocess_fromfiles_save_to_queue(list_of_lists: List[List[str]],
                 seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)
                 data = np.vstack((data, seg_onehot))
 
-            data = torch.from_numpy(data).contiguous().float()
+            data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
 
             item = {'data': data, 'data_properties': data_properties,
                     'ofile': output_filenames_truncated[idx] if output_filenames_truncated is not None else None}
@@ -146,9 +146,7 @@ def __init__(self, list_of_lists: List[List[str]],
 
     def generate_train_batch(self):
         idx = self.get_indices()[0]
-        files = self._data[idx][0]
-        seg_prev_stage = self._data[idx][1]
-        ofile = self._data[idx][2]
+        files, seg_prev_stage, ofile = self._data[idx]
         # if we have a segmentation from the previous stage we have to process it together with the images so that we
         # can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after
         # preprocessing and then there might be misalignments
@@ -192,10 +190,7 @@ def __init__(self, list_of_images: List[np.ndarray],
 
     def generate_train_batch(self):
         idx = self.get_indices()[0]
-        image = self._data[idx][0]
-        seg_prev_stage = self._data[idx][1]
-        props = self._data[idx][2]
-        ofname = self._data[idx][3]
+        image, seg_prev_stage, props, ofname = self._data[idx]
         # if we have a segmentation from the previous stage we have to process it together with the images so that we
         # can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after
         # preprocessing and then there might be misalignments
@@ -238,7 +233,7 @@ def preprocess_fromnpy_save_to_queue(list_of_images: List[np.ndarray],
                 seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)
                 data = np.vstack((data, seg_onehot))
 
-            data = torch.from_numpy(data).contiguous().float()
+            data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
 
             item = {'data': data, 'data_properties': list_of_image_properties[idx],
                     'ofile': truncated_ofnames[idx] if truncated_ofnames is not None else None}

From ea88fb7be36efa90f095b46cde6eb59d85cedff7 Mon Sep 17 00:00:00 2001
From: ancestor-mithril <58839912+ancestor-mithril@users.noreply.github.com>
Date: Thu, 28 Mar 2024 22:21:16 +0200
Subject: [PATCH 2/8] Using backends.cudnn.benchmark during prediction for faster inference

---
 nnunetv2/inference/predict_from_raw_data.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py
index 2ac6600c2..3c2b3b25c 100644
--- a/nnunetv2/inference/predict_from_raw_data.py
+++ b/nnunetv2/inference/predict_from_raw_data.py
@@ -55,9 +55,8 @@ def __init__(self,
         self.use_gaussian = use_gaussian
         self.use_mirroring = use_mirroring
         if device.type == 'cuda':
-            # device = torch.device(type='cuda', index=0) # set the desired GPU with CUDA_VISIBLE_DEVICES!
-            pass
-        if device.type != 'cuda':
+            torch.backends.cudnn.benchmark = True
+        else:
             print(f'perform_everything_on_device=True is only supported for cuda devices! Setting this to False')
             perform_everything_on_device = False
         self.device = device
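Background on PATCH 2: `torch.backends.cudnn.benchmark = True` makes cuDNN time several convolution algorithms the first time it sees a given input shape and then reuse the fastest one. This pays off during sliding-window inference because the network is always fed fixed-size patches, so the shape never changes after the first tile. A minimal standalone sketch of the effect (the toy Conv3d model and shapes are illustrative, not nnU-Net code; requires a CUDA device):

    import time
    import torch
    import torch.nn as nn

    torch.backends.cudnn.benchmark = True  # autotune conv algorithms per input shape

    net = nn.Conv3d(4, 8, kernel_size=3, padding=1).cuda().eval()  # toy network
    patch = torch.randn(1, 4, 64, 64, 64, device='cuda')           # fixed patch shape

    with torch.inference_mode():
        for i in range(3):
            start = time.perf_counter()
            net(patch)
            torch.cuda.synchronize()
            # the first iteration includes the algorithm search and is slower;
            # later iterations with the same shape reuse the cached choice
            print(f'iteration {i}: {time.perf_counter() - start:.4f}s')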
From db7bec74ab70442a6280a5c0e8572bc2de29b9d3 Mon Sep 17 00:00:00 2001
From: ancestor-mithril <58839912+ancestor-mithril@users.noreply.github.com>
Date: Thu, 28 Mar 2024 22:25:20 +0200
Subject: [PATCH 3/8] Using torch.inference_mode for prediction

---
 nnunetv2/inference/predict_from_raw_data.py | 95 ++++++++++-----------
 1 file changed, 47 insertions(+), 48 deletions(-)

diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py
index 3c2b3b25c..c285c50cb 100644
--- a/nnunetv2/inference/predict_from_raw_data.py
+++ b/nnunetv2/inference/predict_from_raw_data.py
@@ -463,6 +463,7 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: di
         else:
             return ret
 
+    @torch.inference_mode()
     def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
         """
         IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
@@ -473,30 +474,28 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Ten
         """
         n_threads = torch.get_num_threads()
         torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads)
-        with torch.no_grad():
-            prediction = None
+        prediction = None
 
-            for params in self.list_of_parameters:
+        for params in self.list_of_parameters:
 
-                # messing with state dict names...
-                if not isinstance(self.network, OptimizedModule):
-                    self.network.load_state_dict(params)
-                else:
-                    self.network._orig_mod.load_state_dict(params)
+            # messing with state dict names...
+            if not isinstance(self.network, OptimizedModule):
+                self.network.load_state_dict(params)
+            else:
+                self.network._orig_mod.load_state_dict(params)
 
-                # why not leave prediction on device if perform_everything_on_device? Because this may cause the
-                # second iteration to crash due to OOM. Grabbing that with try except cause way more bloated code than
-                # this actually saves computation time
-                if prediction is None:
-                    prediction = self.predict_sliding_window_return_logits(data).to('cpu')
-                else:
-                    prediction += self.predict_sliding_window_return_logits(data).to('cpu')
+            # why not leave prediction on device if perform_everything_on_device? Because this may cause the
+            # second iteration to crash due to OOM. Catching that with try/except would bloat the code far more
+            # than it would save computation time
+            if prediction is None:
+                prediction = self.predict_sliding_window_return_logits(data).to('cpu')
+            else:
+                prediction += self.predict_sliding_window_return_logits(data).to('cpu')
 
-            if len(self.list_of_parameters) > 1:
-                prediction /= len(self.list_of_parameters)
+        if len(self.list_of_parameters) > 1:
+            prediction /= len(self.list_of_parameters)
 
-            if self.verbose: print('Prediction done')
-            prediction = prediction.to('cpu')
+        if self.verbose: print('Prediction done')
         torch.set_num_threads(n_threads)
         return prediction
@@ -617,38 +616,38 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
         # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False
         # is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
        # So autocast will only be active if we have a cuda device.
-        with torch.no_grad():
-            with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
-                assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'
-
-                if self.verbose: print(f'Input shape: {input_image.shape}')
-                if self.verbose: print("step_size:", self.tile_step_size)
-                if self.verbose: print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)
-
-                # if input_image is smaller than tile_size we need to pad it to tile_size.
-                data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
-                                                           'constant', {'value': 0}, True,
-                                                           None)
-
-                slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
-
-                if self.perform_everything_on_device and self.device != 'cpu':
-                    # we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device
-                    try:
-                        predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
-                                                                                               self.perform_everything_on_device)
-                    except RuntimeError:
-                        print(
-                            'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU')
-                        empty_cache(self.device)
-                        predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)
-                else:
-                    predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
-                                                                                           self.perform_everything_on_device)
-
-                empty_cache(self.device)
-                # revert padding
-                predicted_logits = predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])]
+        with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
+            assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'
+
+            if self.verbose:
+                print(f'Input shape: {input_image.shape}')
+                print("step_size:", self.tile_step_size)
+                print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)
+
+            # if input_image is smaller than tile_size we need to pad it to tile_size.
+            data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
+                                                       'constant', {'value': 0}, True,
+                                                       None)
+
+            slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
+
+            if self.perform_everything_on_device and self.device != 'cpu':
+                # we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device
+                try:
+                    predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
+                                                                                           self.perform_everything_on_device)
+                except RuntimeError:
+                    print(
+                        'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU')
+                    empty_cache(self.device)
+                    predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)
+            else:
+                predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
+                                                                                       self.perform_everything_on_device)
+
+            empty_cache(self.device)
+            # revert padding
+            predicted_logits = predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])]
 
         return predicted_logits

From a3e117023a0b8679a39bc24c3fd4a2fbcbe29bcc Mon Sep 17 00:00:00 2001
From: ancestor-mithril <58839912+ancestor-mithril@users.noreply.github.com>
Date: Thu, 28 Mar 2024 22:28:37 +0200
Subject: [PATCH 4/8] Removed tuple unpacking for mirroring and predicting

---
 nnunetv2/inference/predict_from_raw_data.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py
index c285c50cb..45b150661 100644
--- a/nnunetv2/inference/predict_from_raw_data.py
+++ b/nnunetv2/inference/predict_from_raw_data.py
@@ -542,11 +542,12 @@ def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
             # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3
             assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'
 
+            mirror_axes = [m + 2 for m in mirror_axes]
             axes_combinations = [
-                c for i in range(len(mirror_axes)) for c in itertools.combinations([m + 2 for m in mirror_axes], i + 1)
+                c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
             ]
             for axes in axes_combinations:
-                prediction += torch.flip(self.network(torch.flip(x, (*axes,))), (*axes,))
+                prediction += torch.flip(self.network(torch.flip(x, axes)), axes)
             prediction /= (len(axes_combinations) + 1)
         return prediction
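Two notes on PATCHES 3 and 4: `torch.inference_mode()` is a stricter drop-in for `torch.no_grad()` that additionally skips view and version-counter tracking, which is why the decorator can replace the nested `with` blocks above. And the mirroring loop averages the prediction over every non-empty subset of the allowed spatial axes, so hoisting `[m + 2 for m in mirror_axes]` out of `itertools.combinations` avoids rebuilding that list once per subset length. A standalone sketch of the enumeration, with an identity function standing in for the real network:

    import itertools
    import torch

    x = torch.randn(1, 2, 4, 4, 4)            # (b, c, x, y, z)
    mirror_axes = [m + 2 for m in (0, 1, 2)]  # shift past batch/channel dims

    # every non-empty subset of the mirror axes: (2,), (3,), (4,), (2, 3), ...
    axes_combinations = [
        c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
    ]
    assert len(axes_combinations) == 2 ** len(mirror_axes) - 1  # 7 for three axes

    network = lambda t: t  # identity stand-in for the real model
    prediction = network(x).clone()
    for axes in axes_combinations:
        # flip the input, predict, flip the prediction back, then accumulate
        prediction += torch.flip(network(torch.flip(x, axes)), axes)
    prediction /= (len(axes_combinations) + 1)  # average over all mirrored views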
From cf7696e333b54bcea018dfe44e795cf158dbdfe0 Mon Sep 17 00:00:00 2001
From: ancestor-mithril <58839912+ancestor-mithril@users.noreply.github.com>
Date: Thu, 28 Mar 2024 22:35:47 +0200
Subject: [PATCH 5/8] Using gaussian only when necessary + applied inplace addition

---
 nnunetv2/inference/predict_from_raw_data.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py
index 45b150661..064378778 100644
--- a/nnunetv2/inference/predict_from_raw_data.py
+++ b/nnunetv2/inference/predict_from_raw_data.py
@@ -573,24 +573,27 @@ def _internal_predict_sliding_window_return_logits(self,
             predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
                                            dtype=torch.half,
                                            device=results_device)
-            n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
             if self.use_gaussian:
                 gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
                                             value_scaling_factor=10,
                                             device=results_device)
+                n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
 
-            if self.verbose: print('running prediction')
-            if not self.allow_tqdm and self.verbose: print(f'{len(slicers)} steps')
+            if not self.allow_tqdm and self.verbose:
+                print(f'running prediction: {len(slicers)} steps')
             for sl in tqdm(slicers, disable=not self.allow_tqdm):
                 workon = data[sl][None]
-                workon = workon.to(self.device, non_blocking=False)
+                workon = workon.to(self.device)
 
                 prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)
 
-                predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction)
-                n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1)
+                if self.use_gaussian:
+                    prediction *= gaussian
+                    n_predictions[sl[1:]] += gaussian
+                predicted_logits[sl] += prediction
 
-            predicted_logits /= n_predictions
+            if self.use_gaussian:
+                predicted_logits /= n_predictions
             # check for infs
             if torch.any(torch.isinf(predicted_logits)):
                 raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
@@ -648,7 +651,7 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
 
             empty_cache(self.device)
             # revert padding
-            predicted_logits = predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])]
+            predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
 
         return predicted_logits
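On PATCH 5: the Gaussian importance map down-weights tile borders, where predictions are least reliable, and `n_predictions` accumulates the per-voxel sum of weights so overlapping tiles are averaged rather than summed. A 1-D toy version of the weighted accumulation (the shapes and constant fake logits are made up for illustration):

    import torch

    volume_len, tile_len, step = 10, 4, 2
    # bell curve over one tile, strictly positive so the division below is safe
    gaussian = torch.exp(-((torch.arange(tile_len) - (tile_len - 1) / 2) ** 2) / 2.0)

    predicted_logits = torch.zeros(volume_len)
    n_predictions = torch.zeros(volume_len)

    for start in range(0, volume_len - tile_len + 1, step):
        sl = slice(start, start + tile_len)
        prediction = torch.ones(tile_len)  # fake per-tile logits
        prediction *= gaussian             # damp the borders, keep the centre
        predicted_logits[sl] += prediction
        n_predictions[sl] += gaussian

    predicted_logits /= n_predictions  # weighted average over overlapping tiles
    print(predicted_logits)            # all ones: the weights cancel for constant input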
From c996ba2b3c0120f9230ac92b6d105d2c08bdfd6e Mon Sep 17 00:00:00 2001
From: ancestor-mithril <58839912+ancestor-mithril@users.noreply.github.com>
Date: Thu, 28 Mar 2024 22:38:43 +0200
Subject: [PATCH 6/8] Compute gaussian

* changing `lru_cache`'s `maxsize=2` to `maxsize=None` for faster access, and
  because the cache will not be filled with more than 2 values anyway

---
 nnunetv2/inference/sliding_window_prediction.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nnunetv2/inference/sliding_window_prediction.py b/nnunetv2/inference/sliding_window_prediction.py
index a6f8ebbae..cf03dcb93 100644
--- a/nnunetv2/inference/sliding_window_prediction.py
+++ b/nnunetv2/inference/sliding_window_prediction.py
@@ -7,7 +7,7 @@
 from scipy.ndimage import gaussian_filter
 
 
-@lru_cache(maxsize=2)
+@lru_cache(maxsize=None)
 def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: float = 1. / 8,
                      value_scaling_factor: float = 1, dtype=torch.float16, device=torch.device('cuda', 0)) \
         -> torch.Tensor:

From 2d951e3df4e44761b2b66e967b5a4616cb7167d6 Mon Sep 17 00:00:00 2001
From: ancestor-mithril <58839912+ancestor-mithril@users.noreply.github.com>
Date: Thu, 28 Mar 2024 22:44:15 +0200
Subject: [PATCH 7/8] Faster compute_gaussian

---
 nnunetv2/inference/sliding_window_prediction.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/nnunetv2/inference/sliding_window_prediction.py b/nnunetv2/inference/sliding_window_prediction.py
index cf03dcb93..64e6b1a5e 100644
--- a/nnunetv2/inference/sliding_window_prediction.py
+++ b/nnunetv2/inference/sliding_window_prediction.py
@@ -18,14 +18,12 @@ def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale:
     gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
 
     gaussian_importance_map = torch.from_numpy(gaussian_importance_map)
+    gaussian_importance_map = gaussian_importance_map.to(device=device, dtype=dtype)
 
-    gaussian_importance_map = gaussian_importance_map / torch.max(gaussian_importance_map) * value_scaling_factor
-    gaussian_importance_map = gaussian_importance_map.type(dtype).to(device)
-
+    gaussian_importance_map /= (torch.max(gaussian_importance_map) / value_scaling_factor)
     # gaussian_importance_map cannot be 0, otherwise we may end up with nans!
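On PATCHES 6 and 7: `tile_size` arrives as a hashable tuple, so `lru_cache` can return the already-built importance map for repeated calls with the same arguments instead of re-running `gaussian_filter` and the device transfer. A small demonstration of the caching pattern with a stand-alone toy function (not the real `compute_gaussian`):

    from functools import lru_cache

    @lru_cache(maxsize=2)  # maxsize=None would trade bounded memory for slightly cheaper lookups
    def expensive_map(tile_size: tuple) -> str:
        print(f'computing for {tile_size}')  # only printed on a cache miss
        return f'gaussian for {tile_size}'

    expensive_map((128, 128, 128))     # miss: computes
    expensive_map((128, 128, 128))     # hit: cached, nothing printed
    expensive_map((64, 160, 160))      # miss: fills the second cache slot
    print(expensive_map.cache_info())  # hits=1, misses=2, maxsize=2, currsize=2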
-    gaussian_importance_map[gaussian_importance_map == 0] = torch.min(
-        gaussian_importance_map[gaussian_importance_map != 0])
-
+    mask = gaussian_importance_map == 0
+    gaussian_importance_map[mask] = torch.min(gaussian_importance_map[~mask])
     return gaussian_importance_map

From 809277b11ad2e9ed367472246882d10578612ebc Mon Sep 17 00:00:00 2001
From: ancestor-mithril <58839912+ancestor-mithril@users.noreply.github.com>
Date: Wed, 10 Apr 2024 12:45:47 +0300
Subject: [PATCH 8/8] Fixed after review

* Added n_predictions back to replicate the previous behavior
* Setting gaussian to 1 if not using the gaussian
* Setting lru_cache size back to 2 to prevent OOM for unintended usage

---
 nnunetv2/inference/predict_from_raw_data.py     | 10 ++++++----
 nnunetv2/inference/sliding_window_prediction.py |  2 +-
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py
index 064378778..06d2dc607 100644
--- a/nnunetv2/inference/predict_from_raw_data.py
+++ b/nnunetv2/inference/predict_from_raw_data.py
@@ -573,11 +573,14 @@ def _internal_predict_sliding_window_return_logits(self,
             predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
                                            dtype=torch.half,
                                            device=results_device)
+            n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
+
             if self.use_gaussian:
                 gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
                                             value_scaling_factor=10,
                                             device=results_device)
-                n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
+            else:
+                gaussian = 1
 
             if not self.allow_tqdm and self.verbose:
                 print(f'running prediction: {len(slicers)} steps')
@@ -589,11 +592,10 @@ def _internal_predict_sliding_window_return_logits(self,
 
                 if self.use_gaussian:
                     prediction *= gaussian
-                    n_predictions[sl[1:]] += gaussian
                 predicted_logits[sl] += prediction
+                n_predictions[sl[1:]] += gaussian
 
-            if self.use_gaussian:
-                predicted_logits /= n_predictions
+            predicted_logits /= n_predictions
             # check for infs
             if torch.any(torch.isinf(predicted_logits)):
                 raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
diff --git a/nnunetv2/inference/sliding_window_prediction.py b/nnunetv2/inference/sliding_window_prediction.py
index 64e6b1a5e..93ed7b1e5 100644
--- a/nnunetv2/inference/sliding_window_prediction.py
+++ b/nnunetv2/inference/sliding_window_prediction.py
@@ -7,7 +7,7 @@
 from scipy.ndimage import gaussian_filter
 
 
-@lru_cache(maxsize=None)
+@lru_cache(maxsize=2)
 def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: float = 1. / 8,
                      value_scaling_factor: float = 1, dtype=torch.float16, device=torch.device('cuda', 0)) \
         -> torch.Tensor:
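On PATCH 8: accumulation is uniform again. `gaussian` is a tensor of border-damping weights when `use_gaussian` is set and the scalar `1` otherwise, so `n_predictions[sl[1:]] += gaussian` counts weighted or plain visits per voxel, and the unconditional `predicted_logits /= n_predictions` restores the averaging of overlapping tiles that PATCH 5 had dropped for the unweighted path. Reusing the toy setup from the PATCH 5 sketch, with the Gaussian disabled:

    import torch

    volume_len, tile_len, step = 10, 4, 2
    gaussian = 1  # scalar stand-in when use_gaussian is False

    predicted_logits = torch.zeros(volume_len)
    n_predictions = torch.zeros(volume_len)

    for start in range(0, volume_len - tile_len + 1, step):
        sl = slice(start, start + tile_len)
        predicted_logits[sl] += torch.ones(tile_len)  # fake per-tile logits
        n_predictions[sl] += gaussian                 # plain visit count per voxel

    # without this division, voxels covered by two tiles would read 2.0 instead of 1.0
    predicted_logits /= n_predictions
    print(predicted_logits)  # all ones regardless of overlap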