Skip to content

Commit

Permalink
Fixed after review
Browse files Browse the repository at this point in the history
* Added n_predictions back to replicate the previous behavior
* Setting gaussian to 1 if not using the gaussian
* setting lru cache size back to 2 to prevent OOM for unintended usage
  • Loading branch information
ancestor-mithril committed Apr 10, 2024
1 parent 2d951e3 commit 809277b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 6 additions & 4 deletions nnunetv2/inference/predict_from_raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,11 +573,14 @@ def _internal_predict_sliding_window_return_logits(self,
predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
dtype=torch.half,
device=results_device)
n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)

if self.use_gaussian:
gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
value_scaling_factor=10,
device=results_device)
n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
else:
gaussian = 1

if not self.allow_tqdm and self.verbose:
print(f'running prediction: {len(slicers)} steps')
Expand All @@ -589,11 +592,10 @@ def _internal_predict_sliding_window_return_logits(self,

if self.use_gaussian:
prediction *= gaussian
n_predictions[sl[1:]] += gaussian
predicted_logits[sl] += prediction
n_predictions[sl[1:]] += gaussian

if self.use_gaussian:
predicted_logits /= n_predictions
predicted_logits /= n_predictions
# check for infs
if torch.any(torch.isinf(predicted_logits)):
raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
Expand Down
2 changes: 1 addition & 1 deletion nnunetv2/inference/sliding_window_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from scipy.ndimage import gaussian_filter


@lru_cache(maxsize=None)
@lru_cache(maxsize=2)
def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: float = 1. / 8,
value_scaling_factor: float = 1, dtype=torch.float16, device=torch.device('cuda', 0)) \
-> torch.Tensor:
Expand Down

0 comments on commit 809277b

Please sign in to comment.