Commit

cleanup inference code, better verbose messages
FabianIsensee committed Dec 12, 2023
1 parent cd8927d commit 1315d2f
Showing 2 changed files with 107 additions and 112 deletions.
213 changes: 104 additions & 109 deletions nnunetv2/inference/predict_from_raw_data.py
@@ -40,7 +40,7 @@ def __init__(self,
tile_step_size: float = 0.5,
use_gaussian: bool = True,
use_mirroring: bool = True,
perform_everything_on_gpu: bool = True,
perform_everything_on_device: bool = True,
device: torch.device = torch.device('cuda'),
verbose: bool = False,
verbose_preprocessing: bool = False,
@@ -60,10 +60,10 @@ def __init__(self,
# why would I ever want to do that. Stupid dobby. This kills DDP inference...
pass
if device.type != 'cuda':
print(f'perform_everything_on_gpu=True is only supported for cuda devices! Setting this to False')
perform_everything_on_gpu = False
print(f'perform_everything_on_device=True is only supported for cuda devices! Setting this to False')
perform_everything_on_device = False
self.device = device
self.perform_everything_on_gpu = perform_everything_on_gpu
self.perform_everything_on_device = perform_everything_on_device

def initialize_from_trained_model_folder(self, model_training_output_dir: str,
use_folds: Union[Tuple[Union[int, str]], None],
@@ -111,7 +111,7 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
self.label_manager = plans_manager.get_label_manager(dataset_json)
if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
and not isinstance(self.network, OptimizedModule):
print('compiling network')
print('Using torch.compile')
self.network = torch.compile(self.network)

def manual_initialization(self, network: nn.Module, plans_manager: PlansManager,
@@ -135,7 +135,7 @@ def manual_initialization(self, network: nn.Module, plans_manager: PlansManager,
if isinstance(self.network, DistributedDataParallel):
allow_compile = allow_compile and isinstance(self.network.module, OptimizedModule)
if allow_compile:
print('compiling network')
print('Using torch.compile')
self.network = torch.compile(self.network)
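
Both hunks above only change the log message printed when torch.compile is actually applied; compilation itself is gated on the nnUNet_compile environment variable, as the check in the first hunk shows. A minimal, purely illustrative sketch of opting in before the predictor is created (variable name and accepted values taken from that check, everything else is an assumption):

    import os

    # any of 'true', '1' or 't' (case-insensitive) enables compilation according to the check above
    os.environ['nnUNet_compile'] = 'true'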

@staticmethod
@@ -353,7 +353,7 @@ def predict_from_data_iterator(self,
else:
print(f'\nPredicting image of shape {data.shape}:')

print(f'perform_everything_on_gpu: {self.perform_everything_on_gpu}')
print(f'perform_everything_on_device: {self.perform_everything_on_device}')

properties = preprocessed['data_properties']

@@ -454,56 +454,33 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE.
SEE convert_predicted_logits_to_segmentation_with_correct_shape
"""
# we have some code duplication here but this allows us to run with perform_everything_on_gpu=True as
# default and not have the entire program crash in case of GPU out of memory. Neat. That should make
# things a lot faster for some datasets.
original_perform_everything_on_gpu = self.perform_everything_on_gpu
n_threads = torch.get_num_threads()
torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads)
with torch.no_grad():
prediction = None
if self.perform_everything_on_gpu:
try:
for params in self.list_of_parameters:

# messing with state dict names...
if not isinstance(self.network, OptimizedModule):
self.network.load_state_dict(params)
else:
self.network._orig_mod.load_state_dict(params)

if prediction is None:
prediction = self.predict_sliding_window_return_logits(data)
else:
prediction += self.predict_sliding_window_return_logits(data)

if len(self.list_of_parameters) > 1:
prediction /= len(self.list_of_parameters)

except RuntimeError:
print('Prediction with perform_everything_on_gpu=True failed due to insufficient GPU memory. '
'Falling back to perform_everything_on_gpu=False. Not a big deal, just slower...')
print('Error:')
traceback.print_exc()
prediction = None
self.perform_everything_on_gpu = False

if prediction is None:
for params in self.list_of_parameters:
# messing with state dict names...
if not isinstance(self.network, OptimizedModule):
self.network.load_state_dict(params)
else:
self.network._orig_mod.load_state_dict(params)

if prediction is None:
prediction = self.predict_sliding_window_return_logits(data)
else:
prediction += self.predict_sliding_window_return_logits(data)
if len(self.list_of_parameters) > 1:
prediction /= len(self.list_of_parameters)

print('Prediction done, transferring to CPU if needed')

for params in self.list_of_parameters:

# messing with state dict names...
if not isinstance(self.network, OptimizedModule):
self.network.load_state_dict(params)
else:
self.network._orig_mod.load_state_dict(params)

# why not leave prediction on device if perform_everything_on_device? Because this may cause the
# second iteration to crash due to OOM. Grabbing that with try/except causes way more bloated code than
# this actually saves computation time
if prediction is None:
prediction = self.predict_sliding_window_return_logits(data).to('cpu')
else:
prediction += self.predict_sliding_window_return_logits(data).to('cpu')

if len(self.list_of_parameters) > 1:
prediction /= len(self.list_of_parameters)

if self.verbose: print('Prediction done')
prediction = prediction.to('cpu')
self.perform_everything_on_gpu = original_perform_everything_on_gpu
torch.set_num_threads(n_threads)
return prediction
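
The repeated "messing with state dict names" block exists because torch.compile wraps the network in an OptimizedModule, which registers the parameters under an _orig_mod prefix; a checkpoint saved from the uncompiled network therefore has to be loaded into network._orig_mod. A small illustrative sketch of that idea (the helper name is made up; the isinstance check mirrors the code above):

    from torch._dynamo import OptimizedModule

    def load_checkpoint_params(network, params):
        # compiled networks keep their parameters under '_orig_mod.', so load the checkpoint
        # into the wrapped module rather than into the OptimizedModule wrapper itself
        target = network._orig_mod if isinstance(network, OptimizedModule) else network
        target.load_state_dict(params)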

def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]):
@@ -557,6 +534,48 @@ def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
prediction /= (len(axes_combinations) + 1)
return prediction

def _internal_predict_sliding_window_return_logits(self,
data: torch.Tensor,
slicers,
do_on_device: bool = True,
):
results_device = self.device if do_on_device else torch.device('cpu')

# move data to device
if self.verbose: print(f'move image to device {results_device}')
data = data.to(self.device)

# preallocate arrays
if self.verbose: print(f'preallocating results arrays on device {results_device}')
predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
dtype=torch.half,
device=results_device)
n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
if self.use_gaussian:
gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
value_scaling_factor=10,
device=results_device)
empty_cache(self.device)

if self.verbose: print('running prediction')
if not self.allow_tqdm and self.verbose: print(f'{len(slicers)} steps')
for sl in tqdm(slicers, disable=not self.allow_tqdm):
workon = data[sl][None]
workon = workon.to(self.device, non_blocking=False)

prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)

predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction)
n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1)

predicted_logits /= n_predictions
# check for infs
if torch.any(torch.isinf(predicted_logits)):
raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
'reduce value_scaling_factor in compute_gaussian or increase the dtype of '
'predicted_logits to fp32')
return predicted_logits
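
The new private helper above pre-allocates the logits and per-voxel weight arrays on the chosen results device, accumulates Gaussian-weighted patch predictions, and normalizes at the end. A self-contained sketch of that aggregation idea (shapes, the Gaussian construction and all names are simplified stand-ins, not the actual compute_gaussian implementation):

    import torch

    def gaussian_weights(patch_size, sigma_scale=1. / 8, value_scaling_factor=10.):
        # Gaussian centred on the patch, rescaled so its maximum equals value_scaling_factor
        grids = torch.meshgrid(*[torch.arange(s, dtype=torch.float32) for s in patch_size], indexing='ij')
        dist = sum(((g - (s - 1) / 2.) / (s * sigma_scale)) ** 2 for g, s in zip(grids, patch_size))
        w = torch.exp(-0.5 * dist)
        return (w / w.max() * value_scaling_factor).clamp_min(1e-8)  # avoid division by zero at the borders

    def aggregate_sliding_window(predict_patch, data, slicers, num_classes, patch_size):
        # predict_patch: callable mapping a (1, C, *patch_size) tensor to (num_classes, *patch_size) logits
        logits = torch.zeros((num_classes, *data.shape[1:]), dtype=torch.half)
        counts = torch.zeros(data.shape[1:], dtype=torch.half)
        gauss = gaussian_weights(patch_size).half()
        for sl in slicers:  # each slicer selects all channels plus one spatial patch
            pred = predict_patch(data[sl][None]).half()
            logits[sl] += pred * gauss
            counts[sl[1:]] += gauss
        return logits / counts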

def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
-> Union[np.ndarray, torch.Tensor]:
assert isinstance(input_image, torch.Tensor)
@@ -586,54 +605,21 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \

slicers = self._internal_get_sliding_window_slicers(data.shape[1:])

# preallocate results and num_predictions
results_device = self.device if self.perform_everything_on_gpu else torch.device('cpu')
if self.verbose: print('preallocating arrays')
try:
data = data.to(self.device)
predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
dtype=torch.half,
device=results_device)
n_predictions = torch.zeros(data.shape[1:], dtype=torch.half,
device=results_device)
if self.use_gaussian:
gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
value_scaling_factor=10,
device=results_device)
except RuntimeError:
# sometimes the stuff is too large for GPUs. In that case fall back to CPU
results_device = torch.device('cpu')
data = data.to(results_device)
predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
dtype=torch.half,
device=results_device)
n_predictions = torch.zeros(data.shape[1:], dtype=torch.half,
device=results_device)
if self.use_gaussian:
gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
value_scaling_factor=10,
device=results_device)
finally:
empty_cache(self.device)

if self.verbose: print('running prediction')
for sl in tqdm(slicers, disable=not self.allow_tqdm):
workon = data[sl][None]
workon = workon.to(self.device, non_blocking=False)

prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)

predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction)
n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1)

predicted_logits /= n_predictions
# check for infs
if torch.any(torch.isinf(predicted_logits)):
raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
'reduce value_scaling_factor in compute_gaussian or increase the dtype of '
'predicted_logits to fp32')
empty_cache(self.device)
return predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])]
if self.perform_everything_on_device and self.device != 'cpu':
# we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device
try:
predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, self.perform_everything_on_device)
except RuntimeError:
print('Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU')
empty_cache(self.device)
predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)
else:
predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, self.perform_everything_on_device)

empty_cache(self.device)
# revert padding
predicted_logits = predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])]
return predicted_logits


def predict_entry_point_modelfolder():
@@ -681,6 +667,10 @@ def predict_entry_point_modelfolder():
help="Use this to set the device the inference should run with. Available options are 'cuda' "
"(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
"Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!")
parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False,
help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive '
'jobs)')


print(
"\n#######################################################################\nPlease cite the following paper "
@@ -713,9 +703,10 @@ def predict_entry_point_modelfolder():
predictor = nnUNetPredictor(tile_step_size=args.step_size,
use_gaussian=True,
use_mirroring=not args.disable_tta,
perform_everything_on_gpu=True,
perform_everything_on_device=True,
device=device,
verbose=args.verbose)
verbose=args.verbose,
allow_tqdm=not args.disable_progress_bar)
predictor.initialize_from_trained_model_folder(args.m, args.f, args.chk)
predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities,
overwrite=not args.continue_prediction,
@@ -785,6 +776,9 @@ def predict_entry_point():
help="Use this to set the device the inference should run with. Available options are 'cuda' "
"(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
"Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!")
parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False,
help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive '
'jobs)')

print(
"\n#######################################################################\nPlease cite the following paper "
@@ -822,10 +816,11 @@ def predict_entry_point():
predictor = nnUNetPredictor(tile_step_size=args.step_size,
use_gaussian=True,
use_mirroring=not args.disable_tta,
perform_everything_on_gpu=True,
perform_everything_on_device=True,
device=device,
verbose=args.verbose,
verbose_preprocessing=False)
verbose_preprocessing=False,
allow_tqdm=not args.disable_progress_bar)
predictor.initialize_from_trained_model_folder(
model_folder,
args.f,
@@ -845,7 +840,7 @@
# args.step_size,
# use_gaussian=True,
# use_mirroring=not args.disable_tta,
# perform_everything_on_gpu=True,
# perform_everything_on_device=True,
# verbose=args.verbose,
# save_probabilities=args.save_probabilities,
# overwrite=not args.continue_prediction,
@@ -865,7 +860,7 @@
tile_step_size=0.5,
use_gaussian=True,
use_mirroring=True,
perform_everything_on_gpu=True,
perform_everything_on_device=True,
device=torch.device('cuda', 0),
verbose=False,
verbose_preprocessing=False,
@@ -895,7 +890,7 @@
# tile_step_size=0.5,
# use_gaussian=True,
# use_mirroring=True,
# perform_everything_on_gpu=True,
# perform_everything_on_device=True,
# device=torch.device('cuda', 0),
# verbose=False,
# allow_tqdm=True
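
For completeness, a hedged end-to-end usage sketch of the renamed option, assembled from the calls visible in this diff (folder paths and fold selection are placeholders; the checkpoint_name keyword is an assumption):

    import torch
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    predictor = nnUNetPredictor(tile_step_size=0.5,
                                use_gaussian=True,
                                use_mirroring=True,
                                perform_everything_on_device=True,  # results arrays fall back to CPU on OOM
                                device=torch.device('cuda', 0),
                                verbose=False,
                                verbose_preprocessing=False,
                                allow_tqdm=True)
    predictor.initialize_from_trained_model_folder('/path/to/trained_model_folder',
                                                   use_folds=(0,),
                                                   checkpoint_name='checkpoint_final.pth')
    predictor.predict_from_files('/path/to/input_images', '/path/to/output_folder',
                                 save_probabilities=False, overwrite=False)
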
6 changes: 3 additions & 3 deletions nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
@@ -211,7 +211,7 @@ def initialize(self):
).to(self.device)
# compile network for free speedup
if self._do_i_compile():
self.print_to_log_file('Compiling network...')
self.print_to_log_file('Using torch.compile...')
self.network = torch.compile(self.network)

self.optimizer, self.lr_scheduler = self.configure_optimizers()
@@ -1174,9 +1174,9 @@ def perform_actual_validation(self, save_probabilities: bool = False):
try:
prediction = predictor.predict_sliding_window_return_logits(data)
except RuntimeError:
predictor.perform_everything_on_gpu = False
predictor.perform_everything_on_device = False
prediction = predictor.predict_sliding_window_return_logits(data)
predictor.perform_everything_on_gpu = True
predictor.perform_everything_on_device = True

prediction = prediction.cpu()

