diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py index 064378778..06d2dc607 100644 --- a/nnunetv2/inference/predict_from_raw_data.py +++ b/nnunetv2/inference/predict_from_raw_data.py @@ -573,11 +573,14 @@ def _internal_predict_sliding_window_return_logits(self, predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]), dtype=torch.half, device=results_device) + n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device) + if self.use_gaussian: gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8, value_scaling_factor=10, device=results_device) - n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device) + else: + gaussian = 1 if not self.allow_tqdm and self.verbose: print(f'running prediction: {len(slicers)} steps') @@ -589,11 +592,10 @@ def _internal_predict_sliding_window_return_logits(self, if self.use_gaussian: prediction *= gaussian - n_predictions[sl[1:]] += gaussian predicted_logits[sl] += prediction + n_predictions[sl[1:]] += gaussian - if self.use_gaussian: - predicted_logits /= n_predictions + predicted_logits /= n_predictions # check for infs if torch.any(torch.isinf(predicted_logits)): raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, ' diff --git a/nnunetv2/inference/sliding_window_prediction.py b/nnunetv2/inference/sliding_window_prediction.py index 64e6b1a5e..93ed7b1e5 100644 --- a/nnunetv2/inference/sliding_window_prediction.py +++ b/nnunetv2/inference/sliding_window_prediction.py @@ -7,7 +7,7 @@ from scipy.ndimage import gaussian_filter -@lru_cache(maxsize=None) +@lru_cache(maxsize=2) def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: float = 1. / 8, value_scaling_factor: float = 1, dtype=torch.float16, device=torch.device('cuda', 0)) \ -> torch.Tensor: