diff --git a/nnunetv2/inference/examples.py b/nnunetv2/inference/examples.py index a66d98f8b..44b243f7f 100644 --- a/nnunetv2/inference/examples.py +++ b/nnunetv2/inference/examples.py @@ -12,6 +12,7 @@ tile_step_size=0.5, use_gaussian=True, use_mirroring=True, + use_batch_tta=True, perform_everything_on_device=True, device=torch.device('cuda', 0), verbose=False, diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py index 8d4096482..3a3761861 100644 --- a/nnunetv2/inference/predict_from_raw_data.py +++ b/nnunetv2/inference/predict_from_raw_data.py @@ -39,6 +39,7 @@ def __init__(self, tile_step_size: float = 0.5, use_gaussian: bool = True, use_mirroring: bool = True, + use_batch_tta: bool = True, perform_everything_on_device: bool = True, device: torch.device = torch.device('cuda'), verbose: bool = False, @@ -54,6 +55,7 @@ def __init__(self, self.tile_step_size = tile_step_size self.use_gaussian = use_gaussian self.use_mirroring = use_mirroring + self.use_batch_tta = use_batch_tta if device.type == 'cuda': torch.backends.cudnn.benchmark = True else: @@ -536,20 +538,43 @@ def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]): def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor: mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None - prediction = self.network(x) - if mirror_axes is not None: - # check for invalid numbers in mirror_axes - # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3 - assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!' + if mirror_axes is None: + return self.network(x) - mirror_axes = [m + 2 for m in mirror_axes] - axes_combinations = [ - c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1) - ] + # check for invalid numbers in mirror_axes + # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3 + assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!' + + mirror_axes = [m + 2 for m in mirror_axes] + axes_combinations = [ + c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1) + ] + + if self.use_batch_tta: + tta_batch_size = 4 if len(mirror_axes) == 3 else 2 + + assert (len(axes_combinations) + 1) % tta_batch_size == 0, '(len(axes_combinations) + 1) must be divisible by tta_batch_size' + + x_combinations = [torch.flip(x, axes) for axes in axes_combinations] + x_combinations.insert(0, x) + + prediction = 0 + for i in range(0, len(x_combinations), tta_batch_size): + batch_x = torch.cat(x_combinations[i:i+tta_batch_size], dim=0) + batch_prediction = self.network(batch_x) + + for j in range(batch_prediction.shape[0]): + original_idx = i + j + axes_to_flip_back = axes_combinations[original_idx - 1] if original_idx != 0 else [] + prediction += torch.flip(batch_prediction[j:j+1], axes_to_flip_back) + + else: + prediction = self.network(x) for axes in axes_combinations: prediction += torch.flip(self.network(torch.flip(x, axes)), axes) - prediction /= (len(axes_combinations) + 1) + + prediction /= (len(axes_combinations) + 1) return prediction def _internal_predict_sliding_window_return_logits(self, @@ -683,6 +708,9 @@ def predict_entry_point_modelfolder(): parser.add_argument('--disable_tta', action='store_true', required=False, default=False, help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, ' 'but less accurate inference. Not recommended.') + parser.add_argument('--disable_batch_tta', action='store_true', required=False, default=False, + help='Set this flag to disable batched test time data augmentation. This will slow down inference, ' + 'but may help with out-of-VRAM issues.') parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have " "to be a good listener/reader.") parser.add_argument('--save_probabilities', action='store_true', @@ -739,6 +767,7 @@ def predict_entry_point_modelfolder(): predictor = nnUNetPredictor(tile_step_size=args.step_size, use_gaussian=True, use_mirroring=not args.disable_tta, + use_batch_tta=not args.disable_batch_tta, perform_everything_on_device=True, device=device, verbose=args.verbose, @@ -784,6 +813,9 @@ def predict_entry_point(): parser.add_argument('--disable_tta', action='store_true', required=False, default=False, help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, ' 'but less accurate inference. Not recommended.') + parser.add_argument('--disable_batch_tta', action='store_true', required=False, default=False, + help='Set this flag to disable batched test time data augmentation. This will slow down inference, ' + 'but may help with out-of-VRAM issues.') parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have " "to be a good listener/reader.") parser.add_argument('--save_probabilities', action='store_true', @@ -853,6 +885,7 @@ def predict_entry_point(): predictor = nnUNetPredictor(tile_step_size=args.step_size, use_gaussian=True, use_mirroring=not args.disable_tta, + use_batch_tta=not args.disable_batch_tta, perform_everything_on_device=True, device=device, verbose=args.verbose, @@ -877,6 +910,7 @@ def predict_entry_point(): # args.step_size, # use_gaussian=True, # use_mirroring=not args.disable_tta, + # use_batch_tta=not args.disable_batch_tta, # perform_everything_on_device=True, # verbose=args.verbose, # save_probabilities=args.save_probabilities, @@ -898,6 +932,7 @@ def predict_entry_point(): tile_step_size=0.5, use_gaussian=True, use_mirroring=True, + use_batch_tta=True, perform_everything_on_device=True, device=torch.device('cuda', 0), verbose=False, @@ -928,6 +963,7 @@ def predict_entry_point(): # tile_step_size=0.5, # use_gaussian=True, # use_mirroring=True, + # use_batch_tta=True, # perform_everything_on_device=True, # device=torch.device('cuda', 0), # verbose=False, diff --git a/nnunetv2/inference/readme.md b/nnunetv2/inference/readme.md index d0acc6b2a..3a62c23ba 100644 --- a/nnunetv2/inference/readme.md +++ b/nnunetv2/inference/readme.md @@ -57,6 +57,7 @@ Example: tile_step_size=0.5, use_gaussian=True, use_mirroring=True, + use_batch_tta=True, perform_everything_on_device=True, device=torch.device('cuda', 0), verbose=False, diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py index b23847cb2..db55a93fc 100644 --- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py +++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -1228,7 +1228,7 @@ def perform_actual_validation(self, save_probabilities: bool = False): "forward pass (where compile is triggered) already has deep supervision disabled. " "This is exactly what we need in perform_actual_validation") - predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True, + predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True, use_batch_tta=True, perform_everything_on_device=True, device=self.device, verbose=False, verbose_preprocessing=False, allow_tqdm=False) predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None,