Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving nnUNet inference speed #2048

Merged
merged 8 commits into from
Apr 11, 2024
13 changes: 4 additions & 9 deletions nnunetv2/inference/data_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def preprocess_fromfiles_save_to_queue(list_of_lists: List[List[str]],
seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)
data = np.vstack((data, seg_onehot))

data = torch.from_numpy(data).contiguous().float()
data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)

item = {'data': data, 'data_properties': data_properties,
'ofile': output_filenames_truncated[idx] if output_filenames_truncated is not None else None}
Expand Down Expand Up @@ -146,9 +146,7 @@ def __init__(self, list_of_lists: List[List[str]],

def generate_train_batch(self):
idx = self.get_indices()[0]
files = self._data[idx][0]
seg_prev_stage = self._data[idx][1]
ofile = self._data[idx][2]
files, seg_prev_stage, ofile = self._data[idx][0]
# if we have a segmentation from the previous stage we have to process it together with the images so that we
# can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after
# preprocessing and then there might be misalignments
Expand Down Expand Up @@ -192,10 +190,7 @@ def __init__(self, list_of_images: List[np.ndarray],

def generate_train_batch(self):
idx = self.get_indices()[0]
image = self._data[idx][0]
seg_prev_stage = self._data[idx][1]
props = self._data[idx][2]
ofname = self._data[idx][3]
image, seg_prev_stage, props, ofname = self._data[idx][0]
# if we have a segmentation from the previous stage we have to process it together with the images so that we
# can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after
# preprocessing and then there might be misalignments
Expand Down Expand Up @@ -238,7 +233,7 @@ def preprocess_fromnpy_save_to_queue(list_of_images: List[np.ndarray],
seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)
data = np.vstack((data, seg_onehot))

data = torch.from_numpy(data).contiguous().float()
data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)

item = {'data': data, 'data_properties': list_of_image_properties[idx],
'ofile': truncated_ofnames[idx] if truncated_ofnames is not None else None}
Expand Down
122 changes: 62 additions & 60 deletions nnunetv2/inference/predict_from_raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ def __init__(self,
self.use_gaussian = use_gaussian
self.use_mirroring = use_mirroring
if device.type == 'cuda':
# device = torch.device(type='cuda', index=0) # set the desired GPU with CUDA_VISIBLE_DEVICES!
pass
if device.type != 'cuda':
torch.backends.cudnn.benchmark = True
else:
print(f'perform_everything_on_device=True is only supported for cuda devices! Setting this to False')
perform_everything_on_device = False
self.device = device
Expand Down Expand Up @@ -464,6 +463,7 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: di
else:
return ret

@torch.inference_mode()
def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
"""
IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
Expand All @@ -474,30 +474,28 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Ten
"""
n_threads = torch.get_num_threads()
torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads)
with torch.no_grad():
prediction = None
prediction = None

for params in self.list_of_parameters:
for params in self.list_of_parameters:

# messing with state dict names...
if not isinstance(self.network, OptimizedModule):
self.network.load_state_dict(params)
else:
self.network._orig_mod.load_state_dict(params)
# messing with state dict names...
if not isinstance(self.network, OptimizedModule):
self.network.load_state_dict(params)
else:
self.network._orig_mod.load_state_dict(params)

# why not leave prediction on device if perform_everything_on_device? Because this may cause the
# second iteration to crash due to OOM. Grabbing that with try except cause way more bloated code than
# this actually saves computation time
if prediction is None:
prediction = self.predict_sliding_window_return_logits(data).to('cpu')
else:
prediction += self.predict_sliding_window_return_logits(data).to('cpu')
# why not leave prediction on device if perform_everything_on_device? Because this may cause the
# second iteration to crash due to OOM. Grabbing that with try except cause way more bloated code than
# this actually saves computation time
if prediction is None:
prediction = self.predict_sliding_window_return_logits(data).to('cpu')
else:
prediction += self.predict_sliding_window_return_logits(data).to('cpu')

if len(self.list_of_parameters) > 1:
prediction /= len(self.list_of_parameters)
if len(self.list_of_parameters) > 1:
prediction /= len(self.list_of_parameters)

if self.verbose: print('Prediction done')
prediction = prediction.to('cpu')
if self.verbose: print('Prediction done')
torch.set_num_threads(n_threads)
return prediction

Expand Down Expand Up @@ -544,11 +542,12 @@ def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
# x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3
assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'

mirror_axes = [m + 2 for m in mirror_axes]
axes_combinations = [
c for i in range(len(mirror_axes)) for c in itertools.combinations([m + 2 for m in mirror_axes], i + 1)
c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
]
for axes in axes_combinations:
prediction += torch.flip(self.network(torch.flip(x, (*axes,))), (*axes,))
prediction += torch.flip(self.network(torch.flip(x, axes)), axes)
prediction /= (len(axes_combinations) + 1)
return prediction

Expand All @@ -574,24 +573,27 @@ def _internal_predict_sliding_window_return_logits(self,
predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
dtype=torch.half,
device=results_device)
n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
if self.use_gaussian:
gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
value_scaling_factor=10,
device=results_device)
n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
ancestor-mithril marked this conversation as resolved.
Show resolved Hide resolved

if self.verbose: print('running prediction')
if not self.allow_tqdm and self.verbose: print(f'{len(slicers)} steps')
if not self.allow_tqdm and self.verbose:
print(f'running prediction: {len(slicers)} steps')
for sl in tqdm(slicers, disable=not self.allow_tqdm):
workon = data[sl][None]
workon = workon.to(self.device, non_blocking=False)
workon = workon.to(self.device)

prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)

predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction)
n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1)
if self.use_gaussian:
prediction *= gaussian
n_predictions[sl[1:]] += gaussian
ancestor-mithril marked this conversation as resolved.
Show resolved Hide resolved
predicted_logits[sl] += prediction

predicted_logits /= n_predictions
if self.use_gaussian:
ancestor-mithril marked this conversation as resolved.
Show resolved Hide resolved
predicted_logits /= n_predictions
# check for infs
if torch.any(torch.isinf(predicted_logits)):
raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
Expand All @@ -618,38 +620,38 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
# If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False
# is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
# So autocast will only be active if we have a cuda device.
with torch.no_grad():
with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'

if self.verbose: print(f'Input shape: {input_image.shape}')
if self.verbose: print("step_size:", self.tile_step_size)
if self.verbose: print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)

# if input_image is smaller than tile_size we need to pad it to tile_size.
data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
'constant', {'value': 0}, True,
None)

slicers = self._internal_get_sliding_window_slicers(data.shape[1:])

if self.perform_everything_on_device and self.device != 'cpu':
# we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device
try:
predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
self.perform_everything_on_device)
except RuntimeError:
print(
'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU')
empty_cache(self.device)
predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)
else:
with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'

if self.verbose:
print(f'Input shape: {input_image.shape}')
print("step_size:", self.tile_step_size)
print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)

# if input_image is smaller than tile_size we need to pad it to tile_size.
data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
'constant', {'value': 0}, True,
None)

slicers = self._internal_get_sliding_window_slicers(data.shape[1:])

if self.perform_everything_on_device and self.device != 'cpu':
# we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device
try:
predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
self.perform_everything_on_device)
except RuntimeError:
print(
'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU')
empty_cache(self.device)
predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)
else:
predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
self.perform_everything_on_device)

empty_cache(self.device)
# revert padding
predicted_logits = predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])]
empty_cache(self.device)
# revert padding
predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
return predicted_logits


Expand Down
12 changes: 5 additions & 7 deletions nnunetv2/inference/sliding_window_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from scipy.ndimage import gaussian_filter


@lru_cache(maxsize=2)
@lru_cache(maxsize=None)
ancestor-mithril marked this conversation as resolved.
Show resolved Hide resolved
def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: float = 1. / 8,
value_scaling_factor: float = 1, dtype=torch.float16, device=torch.device('cuda', 0)) \
-> torch.Tensor:
Expand All @@ -18,14 +18,12 @@ def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale:
gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)

gaussian_importance_map = torch.from_numpy(gaussian_importance_map)
gaussian_importance_map = gaussian_importance_map.to(device=device, dtype=dtype)

gaussian_importance_map = gaussian_importance_map / torch.max(gaussian_importance_map) * value_scaling_factor
gaussian_importance_map = gaussian_importance_map.type(dtype).to(device)

gaussian_importance_map /= (torch.max(gaussian_importance_map) / value_scaling_factor)
# gaussian_importance_map cannot be 0, otherwise we may end up with nans!
gaussian_importance_map[gaussian_importance_map == 0] = torch.min(
gaussian_importance_map[gaussian_importance_map != 0])

mask = gaussian_importance_map == 0
gaussian_importance_map[mask] = torch.min(gaussian_importance_map[~mask])
return gaussian_importance_map


Expand Down
Loading