replace inference mode with torch.no_grad
FabianIsensee committed Jul 26, 2024
1 parent fee8c2d · commit 96253e9
Showing 1 changed file with 40 additions and 40 deletions.
nnunetv2/inference/predict_from_raw_data.py (80 changes: 40 additions & 40 deletions)
@@ -609,53 +609,53 @@ def _internal_predict_sliding_window_return_logits(self,
             raise e
         return predicted_logits
 
-    @torch.inference_mode()
     def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
             -> Union[np.ndarray, torch.Tensor]:
-        assert isinstance(input_image, torch.Tensor)
-        self.network = self.network.to(self.device)
-        self.network.eval()
+        with torch.no_grad():
+            assert isinstance(input_image, torch.Tensor)
+            self.network = self.network.to(self.device)
+            self.network.eval()
 
-        empty_cache(self.device)
+            empty_cache(self.device)
 
-        # Autocast can be annoying
-        # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection)
-        # and needs to be disabled.
-        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False
-        # is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
-        # So autocast will only be active if we have a cuda device.
-        with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
-            assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'
-
-            if self.verbose:
-                print(f'Input shape: {input_image.shape}')
-                print("step_size:", self.tile_step_size)
-                print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)
-
-            # if input_image is smaller than tile_size we need to pad it to tile_size.
-            data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
-                                                       'constant', {'value': 0}, True,
-                                                       None)
-
-            slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
-
-            if self.perform_everything_on_device and self.device != 'cpu':
-                # we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device
-                try:
-                    predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
-                                                                                           self.perform_everything_on_device)
-                except RuntimeError:
-                    print(
-                        'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU')
-                    empty_cache(self.device)
-                    predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)
-            else:
-                predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
-                                                                                       self.perform_everything_on_device)
-
-            empty_cache(self.device)
-            # revert padding
-            predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
+            # Autocast can be annoying
+            # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection)
+            # and needs to be disabled.
+            # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False
+            # is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
+            # So autocast will only be active if we have a cuda device.
+            with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
+                assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'
+
+                if self.verbose:
+                    print(f'Input shape: {input_image.shape}')
+                    print("step_size:", self.tile_step_size)
+                    print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)
+
+                # if input_image is smaller than tile_size we need to pad it to tile_size.
+                data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
+                                                           'constant', {'value': 0}, True,
+                                                           None)
+
+                slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
+
+                if self.perform_everything_on_device and self.device != 'cpu':
+                    # we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device
+                    try:
+                        predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
+                                                                                               self.perform_everything_on_device)
+                    except RuntimeError:
+                        print(
+                            'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU')
+                        empty_cache(self.device)
+                        predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)
+                else:
+                    predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
+                                                                                           self.perform_everything_on_device)
+
+                empty_cache(self.device)
+                # revert padding
+                predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
         return predicted_logits
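Why this swap matters (an aside for readers, not part of the commit): tensors created under torch.inference_mode() are permanently marked as inference tensors, and PyTorch refuses to use them in any later computation that records gradients, whereas torch.no_grad() only suspends gradient recording while the block is active. A minimal standalone sketch of the difference:

    import torch

    net = torch.nn.Linear(4, 2)

    # no_grad: the output is an ordinary tensor and can feed autograd later.
    with torch.no_grad():
        out_ng = net(torch.randn(1, 4))
    w = torch.ones(1, 2, requires_grad=True)
    (out_ng * w).sum().backward()  # works

    # inference_mode: the output is an inference tensor; autograd use fails.
    with torch.inference_mode():
        out_im = net(torch.randn(1, 4))
    try:
        (out_im * w).sum().backward()
    except RuntimeError as e:
        print('rejected by autograd:', e)

Returning plain no_grad tensors keeps the predicted logits usable by downstream code that may run with gradients enabled.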

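The autocast comment in the diff describes a device-conditional context: enter torch.autocast only on CUDA and fall back to a no-op context elsewhere. dummy_context is nnU-Net's own helper; contextlib.nullcontext from the standard library behaves the same way for this purpose, which this sketch assumes:

    import torch
    from contextlib import nullcontext  # stand-in for nnU-Net's dummy_context

    def maybe_autocast(device: torch.device):
        # CPU autocast can be very slow, and MPS complains even when
        # enabled=False, so autocast is only entered for CUDA devices.
        return torch.autocast(device.type, enabled=True) if device.type == 'cuda' else nullcontext()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with maybe_autocast(device):
        x = torch.randn(2, 3, device=device)
        y = x @ x.T  # fp16/bf16 under CUDA autocast, fp32 otherwise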

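The pad_nd_image helper pads inputs smaller than the patch size and also returns a slicer that undoes the padding, which the last lines of the function apply while leaving the channel axis untouched. A self-contained sketch of the same pad-then-crop bookkeeping; pad_to_size is a hypothetical stand-in, not pad_nd_image itself:

    import torch
    import torch.nn.functional as F

    def pad_to_size(x: torch.Tensor, target):  # x has shape (c, x, y, z)
        # Pad each spatial dim up to at least `target`, recording how to crop back.
        pads, slicers = [], [slice(None)]  # channel axis stays untouched
        for dim, t in zip(x.shape[1:], target):
            extra = max(t - dim, 0)
            lo = extra // 2
            pads = [lo, extra - lo] + pads  # F.pad takes pairs, last dim first
            slicers.append(slice(lo, lo + dim))
        return F.pad(x, pads, mode='constant', value=0), tuple(slicers)

    data, revert = pad_to_size(torch.randn(1, 20, 30, 40), (32, 32, 32))
    assert data.shape == (1, 32, 32, 40)
    assert data[revert].shape == (1, 20, 30, 40)  # padding cropped back off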
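_internal_get_sliding_window_slicers tiles the padded volume into overlapping patches whose overlap is governed by tile_step_size (printed as step_size above). A simplified sketch of how such slicers can be generated; compute_steps here is an illustration, not nnU-Net's exact implementation:

    import math
    from itertools import product

    def compute_steps(image_size, patch_size, step_size=0.5):
        # Per-axis start positions: advance by patch*step_size and place the
        # last tile flush with the border so the whole axis is covered.
        steps = []
        for img, patch in zip(image_size, patch_size):
            max_start = img - patch
            num = math.ceil(max_start / (patch * step_size)) + 1 if max_start > 0 else 1
            steps.append([round(i * max_start / (num - 1)) if num > 1 else 0
                          for i in range(num)])
        return steps

    def sliding_window_slicers(image_size, patch_size, step_size=0.5):
        # Cartesian product of per-axis starts, one slice tuple per tile.
        for starts in product(*compute_steps(image_size, patch_size, step_size)):
            yield tuple(slice(s, s + p) for s, p in zip(starts, patch_size))

    print(compute_steps((64, 48), (32, 32)))  # [[0, 16, 32], [0, 16]]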

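The try/except in the diff implements graceful degradation: keep intermediate result arrays on the GPU if possible, and on a RuntimeError (in practice usually CUDA running out of memory) free the cache and rerun with results stored on the CPU. The same pattern in isolation; run is a hypothetical callable standing in for _internal_predict_sliding_window_return_logits:

    import torch

    def predict_with_fallback(run, device: torch.device):
        try:
            # First attempt: keep the (large) results tensor on the device.
            return run(results_device=device)
        except RuntimeError:
            # Typically an out-of-memory error: release cached blocks and
            # retry with results accumulated on the CPU instead.
            if device.type == 'cuda':
                torch.cuda.empty_cache()
            return run(results_device=torch.device('cpu'))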