Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement TTA Batch Processing to Improve Inference Speed #2153

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nnunetv2/inference/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
tile_step_size=0.5,
use_gaussian=True,
use_mirroring=True,
use_batch_tta=True,
perform_everything_on_device=True,
device=torch.device('cuda', 0),
verbose=False,
Expand Down
56 changes: 46 additions & 10 deletions nnunetv2/inference/predict_from_raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self,
tile_step_size: float = 0.5,
use_gaussian: bool = True,
use_mirroring: bool = True,
use_batch_tta: bool = True,
perform_everything_on_device: bool = True,
device: torch.device = torch.device('cuda'),
verbose: bool = False,
Expand All @@ -54,6 +55,7 @@ def __init__(self,
self.tile_step_size = tile_step_size
self.use_gaussian = use_gaussian
self.use_mirroring = use_mirroring
self.use_batch_tta = use_batch_tta
if device.type == 'cuda':
torch.backends.cudnn.benchmark = True
else:
Expand Down Expand Up @@ -536,20 +538,43 @@ def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]):

def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None
prediction = self.network(x)

if mirror_axes is not None:
# check for invalid numbers in mirror_axes
# x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3
assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'
if mirror_axes is None:
return self.network(x)

mirror_axes = [m + 2 for m in mirror_axes]
axes_combinations = [
c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
]
# check for invalid numbers in mirror_axes
# x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3
assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'

mirror_axes = [m + 2 for m in mirror_axes]
axes_combinations = [
c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
]

if self.use_batch_tta:
tta_batch_size = 4 if len(mirror_axes) == 3 else 2

assert (len(axes_combinations) + 1) % tta_batch_size == 0, '(len(axes_combinations) + 1) must be divisible by tta_batch_size'

x_combinations = [torch.flip(x, axes) for axes in axes_combinations]
x_combinations.insert(0, x)

prediction = 0
for i in range(0, len(x_combinations), tta_batch_size):
batch_x = torch.cat(x_combinations[i:i+tta_batch_size], dim=0)
batch_prediction = self.network(batch_x)

for j in range(batch_prediction.shape[0]):
original_idx = i + j
axes_to_flip_back = axes_combinations[original_idx - 1] if original_idx != 0 else []
prediction += torch.flip(batch_prediction[j:j+1], axes_to_flip_back)

else:
prediction = self.network(x)
for axes in axes_combinations:
prediction += torch.flip(self.network(torch.flip(x, axes)), axes)
prediction /= (len(axes_combinations) + 1)

prediction /= (len(axes_combinations) + 1)
return prediction

def _internal_predict_sliding_window_return_logits(self,
Expand Down Expand Up @@ -683,6 +708,9 @@ def predict_entry_point_modelfolder():
parser.add_argument('--disable_tta', action='store_true', required=False, default=False,
help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
'but less accurate inference. Not recommended.')
parser.add_argument('--disable_batch_tta', action='store_true', required=False, default=False,
help='Set this flag to disable batched test time data augmentation. This will slow down inference, '
'but may help with out-of-VRAM issues.')
parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have "
"to be a good listener/reader.")
parser.add_argument('--save_probabilities', action='store_true',
Expand Down Expand Up @@ -739,6 +767,7 @@ def predict_entry_point_modelfolder():
predictor = nnUNetPredictor(tile_step_size=args.step_size,
use_gaussian=True,
use_mirroring=not args.disable_tta,
use_batch_tta=not args.disable_batch_tta,
perform_everything_on_device=True,
device=device,
verbose=args.verbose,
Expand Down Expand Up @@ -784,6 +813,9 @@ def predict_entry_point():
parser.add_argument('--disable_tta', action='store_true', required=False, default=False,
help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
'but less accurate inference. Not recommended.')
parser.add_argument('--disable_batch_tta', action='store_true', required=False, default=False,
help='Set this flag to disable batched test time data augmentation. This will slow down inference, '
'but may help with out-of-VRAM issues.')
parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have "
"to be a good listener/reader.")
parser.add_argument('--save_probabilities', action='store_true',
Expand Down Expand Up @@ -853,6 +885,7 @@ def predict_entry_point():
predictor = nnUNetPredictor(tile_step_size=args.step_size,
use_gaussian=True,
use_mirroring=not args.disable_tta,
use_batch_tta=not args.disable_batch_tta,
perform_everything_on_device=True,
device=device,
verbose=args.verbose,
Expand All @@ -877,6 +910,7 @@ def predict_entry_point():
# args.step_size,
# use_gaussian=True,
# use_mirroring=not args.disable_tta,
# use_batch_tta=not args.disable_batch_tta,
# perform_everything_on_device=True,
# verbose=args.verbose,
# save_probabilities=args.save_probabilities,
Expand All @@ -898,6 +932,7 @@ def predict_entry_point():
tile_step_size=0.5,
use_gaussian=True,
use_mirroring=True,
use_batch_tta=True,
perform_everything_on_device=True,
device=torch.device('cuda', 0),
verbose=False,
Expand Down Expand Up @@ -928,6 +963,7 @@ def predict_entry_point():
# tile_step_size=0.5,
# use_gaussian=True,
# use_mirroring=True,
# use_batch_tta=True,
# perform_everything_on_device=True,
# device=torch.device('cuda', 0),
# verbose=False,
Expand Down
1 change: 1 addition & 0 deletions nnunetv2/inference/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ Example:
tile_step_size=0.5,
use_gaussian=True,
use_mirroring=True,
use_batch_tta=True,
perform_everything_on_device=True,
device=torch.device('cuda', 0),
verbose=False,
Expand Down
2 changes: 1 addition & 1 deletion nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,7 +1228,7 @@ def perform_actual_validation(self, save_probabilities: bool = False):
"forward pass (where compile is triggered) already has deep supervision disabled. "
"This is exactly what we need in perform_actual_validation")

predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True,
predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True, use_batch_tta=True,
perform_everything_on_device=True, device=self.device, verbose=False,
verbose_preprocessing=False, allow_tqdm=False)
predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None,
Expand Down