
Commit

Merge remote-tracking branch 'ancestor-mithril/dev2' into release/stage_pull_requests
FabianIsensee committed Apr 11, 2024
2 parents 31de31c + 809277b commit 3a2d870
Showing 3 changed files with 70 additions and 73 deletions.
13 changes: 4 additions & 9 deletions nnunetv2/inference/data_iterators.py
@@ -38,7 +38,7 @@ def preprocess_fromfiles_save_to_queue(list_of_lists: List[List[str]],
                 seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)
                 data = np.vstack((data, seg_onehot))
 
-            data = torch.from_numpy(data).contiguous().float()
+            data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
 
             item = {'data': data, 'data_properties': data_properties,
                     'ofile': output_filenames_truncated[idx] if output_filenames_truncated is not None else None}
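
Note (not part of the diff): the fused .to() call should yield the same contiguous float32 tensor as the old .contiguous().float() chain, while avoiding a possible intermediate copy when the input is non-contiguous. A minimal sketch to check the equivalence; the array shape here is arbitrary:

    import numpy as np
    import torch

    arr = np.random.rand(2, 8, 8, 8).astype(np.float16)[:, ::2]  # non-contiguous view
    a = torch.from_numpy(arr).contiguous().float()                # two chained ops
    b = torch.from_numpy(arr).to(dtype=torch.float32,
                                 memory_format=torch.contiguous_format)  # one fused op
    assert torch.equal(a, b) and b.is_contiguous() and b.dtype == torch.float32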
@@ -146,9 +146,7 @@ def __init__(self, list_of_lists: List[List[str]],
 
     def generate_train_batch(self):
         idx = self.get_indices()[0]
-        files = self._data[idx][0]
-        seg_prev_stage = self._data[idx][1]
-        ofile = self._data[idx][2]
+        files, seg_prev_stage, ofile = self._data[idx]
         # if we have a segmentation from the previous stage we have to process it together with the images so that we
         # can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after
         # preprocessing and then there might be misalignments
@@ -192,10 +190,7 @@ def __init__(self, list_of_images: List[np.ndarray],
 
     def generate_train_batch(self):
         idx = self.get_indices()[0]
-        image = self._data[idx][0]
-        seg_prev_stage = self._data[idx][1]
-        props = self._data[idx][2]
-        ofname = self._data[idx][3]
+        image, seg_prev_stage, props, ofname = self._data[idx]
         # if we have a segmentation from the previous stage we have to process it together with the images so that we
         # can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after
         # preprocessing and then there might be misalignments
@@ -238,7 +233,7 @@ def preprocess_fromnpy_save_to_queue(list_of_images: List[np.ndarray],
                 seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)
                 data = np.vstack((data, seg_onehot))
 
-            data = torch.from_numpy(data).contiguous().float()
+            data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
 
             item = {'data': data, 'data_properties': list_of_image_properties[idx],
                     'ofile': truncated_ofnames[idx] if truncated_ofnames is not None else None}
120 changes: 62 additions & 58 deletions nnunetv2/inference/predict_from_raw_data.py
@@ -55,9 +55,8 @@ def __init__(self,
         self.use_gaussian = use_gaussian
         self.use_mirroring = use_mirroring
         if device.type == 'cuda':
-            # device = torch.device(type='cuda', index=0) # set the desired GPU with CUDA_VISIBLE_DEVICES!
-            pass
-        if device.type != 'cuda':
+            torch.backends.cudnn.benchmark = True
+        else:
             print(f'perform_everything_on_device=True is only supported for cuda devices! Setting this to False')
             perform_everything_on_device = False
         self.device = device
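
Context for the change above: torch.backends.cudnn.benchmark = True makes cuDNN time several convolution algorithms on the first forward pass for each new input shape and cache the fastest one. Sliding-window inference feeds the network fixed-size patches, so every call after the first is a cache hit. A small sketch, assuming a CUDA device is available (the network and shapes are invented):

    import torch

    torch.backends.cudnn.benchmark = True  # autotune conv algorithms per input shape
    net = torch.nn.Conv3d(1, 8, kernel_size=3, padding=1).cuda()
    x = torch.randn(1, 1, 128, 128, 128, device='cuda')  # fixed patch size
    for _ in range(3):
        # first iteration pays the benchmarking cost; later ones reuse the cached choice
        with torch.inference_mode():
            _ = net(x)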
@@ -464,6 +463,7 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
         else:
             return ret
 
+    @torch.inference_mode()
     def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
         """
         IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
@@ -474,30 +474,28 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
         """
         n_threads = torch.get_num_threads()
         torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads)
-        with torch.no_grad():
-            prediction = None
+        prediction = None
 
-            for params in self.list_of_parameters:
+        for params in self.list_of_parameters:
 
-                # messing with state dict names...
-                if not isinstance(self.network, OptimizedModule):
-                    self.network.load_state_dict(params)
-                else:
-                    self.network._orig_mod.load_state_dict(params)
+            # messing with state dict names...
+            if not isinstance(self.network, OptimizedModule):
+                self.network.load_state_dict(params)
+            else:
+                self.network._orig_mod.load_state_dict(params)
 
-                # why not leave prediction on device if perform_everything_on_device? Because this may cause the
-                # second iteration to crash due to OOM. Grabbing that with try except cause way more bloated code than
-                # this actually saves computation time
-                if prediction is None:
-                    prediction = self.predict_sliding_window_return_logits(data).to('cpu')
-                else:
-                    prediction += self.predict_sliding_window_return_logits(data).to('cpu')
+            # why not leave prediction on device if perform_everything_on_device? Because this may cause the
+            # second iteration to crash due to OOM. Grabbing that with try except cause way more bloated code than
+            # this actually saves computation time
+            if prediction is None:
+                prediction = self.predict_sliding_window_return_logits(data).to('cpu')
+            else:
+                prediction += self.predict_sliding_window_return_logits(data).to('cpu')
 
-            if len(self.list_of_parameters) > 1:
-                prediction /= len(self.list_of_parameters)
+        if len(self.list_of_parameters) > 1:
+            prediction /= len(self.list_of_parameters)
 
-            if self.verbose: print('Prediction done')
-            prediction = prediction.to('cpu')
+        if self.verbose: print('Prediction done')
         torch.set_num_threads(n_threads)
         return prediction

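The @torch.inference_mode() decorator added above is a stricter, slightly faster variant of the torch.no_grad() block it replaces: view and version tracking are also disabled, and tensors created inside cannot later participate in autograd, which is fine for a pure inference path. A minimal illustration (the model here is invented):

    import torch

    model = torch.nn.Linear(4, 2)

    @torch.inference_mode()
    def predict(x: torch.Tensor) -> torch.Tensor:
        # no autograd graph is built here
        return model(x)

    y = predict(torch.randn(3, 4))
    print(y.requires_grad, y.is_inference())  # False True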
@@ -544,11 +542,12 @@ def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
             # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3
             assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'
 
+            mirror_axes = [m + 2 for m in mirror_axes]
             axes_combinations = [
-                c for i in range(len(mirror_axes)) for c in itertools.combinations([m + 2 for m in mirror_axes], i + 1)
+                c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
             ]
             for axes in axes_combinations:
-                prediction += torch.flip(self.network(torch.flip(x, (*axes,))), (*axes,))
+                prediction += torch.flip(self.network(torch.flip(x, axes)), axes)
             prediction /= (len(axes_combinations) + 1)
             return prediction

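Hoisting the m + 2 offset out of the comprehension maps the 0-based spatial mirror axes onto tensor dimensions once (dims 0 and 1 hold batch and channel); the combinations then enumerate every non-empty subset of mirror axes for test-time augmentation. For a 3d input with mirror_axes=(0, 1, 2):

    import itertools

    mirror_axes = (0, 1, 2)                      # spatial axes of a (b, c, x, y, z) tensor
    mirror_axes = [m + 2 for m in mirror_axes]   # -> tensor dims [2, 3, 4]
    axes_combinations = [
        c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
    ]
    print(axes_combinations)
    # [(2,), (3,), (4,), (2, 3), (2, 4), (3, 4), (2, 3, 4)] -> 7 flipped passes + 1 plain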
@@ -575,21 +574,26 @@ def _internal_predict_sliding_window_return_logits(self,
                                            dtype=torch.half,
                                            device=results_device)
             n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
 
             if self.use_gaussian:
                 gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
                                             value_scaling_factor=10,
                                             device=results_device)
+            else:
+                gaussian = 1
 
-            if self.verbose: print('running prediction')
-            if not self.allow_tqdm and self.verbose: print(f'{len(slicers)} steps')
+            if not self.allow_tqdm and self.verbose:
+                print(f'running prediction: {len(slicers)} steps')
             for sl in tqdm(slicers, disable=not self.allow_tqdm):
                 workon = data[sl][None]
-                workon = workon.to(self.device, non_blocking=False)
+                workon = workon.to(self.device)
 
                 prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)
 
-                predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction)
-                n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1)
+                if self.use_gaussian:
+                    prediction *= gaussian
+                predicted_logits[sl] += prediction
+                n_predictions[sl[1:]] += gaussian
 
             predicted_logits /= n_predictions
             # check for infs
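
With gaussian = 1 in the new else branch, the weight accumulation in n_predictions works uniformly for both paths; the gaussian itself down-weights patch borders, where sliding-window predictions are least reliable. A toy 1d sketch of this overlap-averaging, with all shapes and values invented:

    import torch

    logits = torch.zeros(10)
    n_pred = torch.zeros(10)
    kernel = torch.tensor([0.5, 1., 2., 2., 1., 0.5])  # stand-in for compute_gaussian
    for start, tile_pred in [(0, torch.ones(6)), (4, 2 * torch.ones(6))]:
        sl = slice(start, start + 6)
        logits[sl] += tile_pred * kernel   # weighted tile prediction
        n_pred[sl] += kernel               # accumulated weights
    logits /= n_pred  # tile centres dominate where the two tiles overlap
    print(logits)     # 1.0 on [0, 4), blended on [4, 6), 2.0 on [6, 10)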
@@ -618,38 +622,38 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
         # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False
         # is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
         # So autocast will only be active if we have a cuda device.
-        with torch.no_grad():
-            with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
-                assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'
-
-                if self.verbose: print(f'Input shape: {input_image.shape}')
-                if self.verbose: print("step_size:", self.tile_step_size)
-                if self.verbose: print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)
-
-                # if input_image is smaller than tile_size we need to pad it to tile_size.
-                data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
-                                                           'constant', {'value': 0}, True,
-                                                           None)
-
-                slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
-
-                if self.perform_everything_on_device and self.device != 'cpu':
-                    # we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device
-                    try:
-                        predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
-                                                                                               self.perform_everything_on_device)
-                    except RuntimeError:
-                        print(
-                            'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU')
-                        empty_cache(self.device)
-                        predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)
-                else:
+        with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
+            assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'
+
+            if self.verbose:
+                print(f'Input shape: {input_image.shape}')
+                print("step_size:", self.tile_step_size)
+                print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)
+
+            # if input_image is smaller than tile_size we need to pad it to tile_size.
+            data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
+                                                       'constant', {'value': 0}, True,
+                                                       None)
+
+            slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
+
+            if self.perform_everything_on_device and self.device != 'cpu':
+                # we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device
+                try:
+                    predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
+                                                                                           self.perform_everything_on_device)
+                except RuntimeError:
+                    print(
+                        'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU')
+                    empty_cache(self.device)
+                    predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)
+            else:
                 predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
                                                                                        self.perform_everything_on_device)
 
-            empty_cache(self.device)
-            # revert padding
-            predicted_logits = predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])]
+        empty_cache(self.device)
+        # revert padding
+        predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
         return predicted_logits

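The conditional context manager kept above enables autocast only on CUDA (per the comment, mps complains even with enabled=False). dummy_context is nnU-Net's no-op context manager; a sketch of the same pattern, using contextlib.nullcontext as a stand-in and assuming only that both are no-ops:

    import contextlib
    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ctx = torch.autocast(device.type, enabled=True) if device.type == 'cuda' \
        else contextlib.nullcontext()
    with ctx:
        x = torch.randn(2, 3, device=device)
        y = x @ x.T  # reduced precision under autocast on cuda, float32 otherwise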

10 changes: 4 additions & 6 deletions nnunetv2/inference/sliding_window_prediction.py
@@ -18,14 +18,12 @@ def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale:
     gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
 
     gaussian_importance_map = torch.from_numpy(gaussian_importance_map)
-    gaussian_importance_map = gaussian_importance_map.type(dtype).to(device)
+    gaussian_importance_map = gaussian_importance_map.to(device=device, dtype=dtype)
 
-    gaussian_importance_map /= (torch.max(gaussian_importance_map) / value_scaling_factor)
+    gaussian_importance_map = gaussian_importance_map / torch.max(gaussian_importance_map) * value_scaling_factor
 
     # gaussian_importance_map cannot be 0, otherwise we may end up with nans!
-    gaussian_importance_map[gaussian_importance_map == 0] = torch.min(
-        gaussian_importance_map[gaussian_importance_map != 0])
+    mask = gaussian_importance_map == 0
+    gaussian_importance_map[mask] = torch.min(gaussian_importance_map[~mask])
     return gaussian_importance_map

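The rewritten masking computes the zero mask once and reuses it on both sides of the assignment. Replacing exact zeros with the smallest nonzero weight matters because the map later serves as a divisor (predicted_logits /= n_predictions): half-precision underflow at the map's far edges would otherwise produce infs. The step in isolation, with toy values rather than the real gaussian:

    import torch

    w = torch.tensor([0., 0.1, 1., 0.1, 0.], dtype=torch.half)
    mask = w == 0
    w[mask] = torch.min(w[~mask])  # smallest nonzero weight replaces exact zeros
    print(w)  # tensor([0.1000, 0.1000, 1.0000, 0.1000, 0.1000], dtype=torch.float16)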

