From d675153e041eadbd02e0eaaca603a61044939167 Mon Sep 17 00:00:00 2001
From: Pengcheng Shi
Date: Thu, 2 May 2024 21:44:37 +0800
Subject: [PATCH 1/2] Implement TTA batch processing for improved inference
 efficiency

---
 nnunetv2/inference/predict_from_raw_data.py | 44 ++++++++++++++++-----
 1 file changed, 34 insertions(+), 10 deletions(-)

diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py
index 8d4096482..eb6c5ac74 100644
--- a/nnunetv2/inference/predict_from_raw_data.py
+++ b/nnunetv2/inference/predict_from_raw_data.py
@@ -536,20 +536,44 @@ def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]):
 
     def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
         mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None
-        prediction = self.network(x)
 
-        if mirror_axes is not None:
-            # check for invalid numbers in mirror_axes
-            # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3
-            assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'
+        if mirror_axes is None:
+            return self.network(x)
 
-            mirror_axes = [m + 2 for m in mirror_axes]
-            axes_combinations = [
-                c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
-            ]
+        # check for invalid numbers in mirror_axes
+        # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3
+        assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'
+
+        mirror_axes = [m + 2 for m in mirror_axes]
+        axes_combinations = [
+            c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
+        ]
+
+        self.tta_batch = True
+        if self.tta_batch:
+            self.tta_batch_size = 4 if len(mirror_axes) == 3 else 2
+
+            assert (len(axes_combinations) + 1) % self.tta_batch_size == 0, '(len(axes_combinations) + 1) must be divisible by self.tta_batch_size'
+
+            x_combinations = [torch.flip(x, axes) for axes in axes_combinations]
+            x_combinations.insert(0, x)
+
+            prediction = 0
+            for i in range(0, len(x_combinations), self.tta_batch_size):
+                batch_x = torch.cat(x_combinations[i:i+self.tta_batch_size], dim=0)
+                batch_prediction = self.network(batch_x)
+
+                for j in range(batch_prediction.shape[0]):
+                    original_idx = i + j
+                    axes_to_flip_back = axes_combinations[original_idx - 1] if original_idx != 0 else []
+                    prediction += torch.flip(batch_prediction[j:j+1], axes_to_flip_back)
+
+        else:
+            prediction = self.network(x)
             for axes in axes_combinations:
                 prediction += torch.flip(self.network(torch.flip(x, axes)), axes)
-            prediction /= (len(axes_combinations) + 1)
+
+        prediction /= (len(axes_combinations) + 1)
         return prediction
 
     def _internal_predict_sliding_window_return_logits(self,

From 9aee47e3f29123c8f30846fa1f692b72e030345e Mon Sep 17 00:00:00 2001
From: unknown
Date: Fri, 3 May 2024 22:26:23 +0800
Subject: [PATCH 2/2] Implement TTA batch processing for improved inference
 efficiency

---
 nnunetv2/inference/examples.py              |  1 +
 nnunetv2/inference/predict_from_raw_data.py | 24 ++++++++++++++-----
 nnunetv2/inference/readme.md                |  1 +
 .../training/nnUNetTrainer/nnUNetTrainer.py |  2 +-
 4 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/nnunetv2/inference/examples.py b/nnunetv2/inference/examples.py
index a66d98f8b..44b243f7f 100644
--- a/nnunetv2/inference/examples.py
+++ b/nnunetv2/inference/examples.py
@@ -12,6 +12,7 @@
         tile_step_size=0.5,
         use_gaussian=True,
         use_mirroring=True,
+        use_batch_tta=True,
         perform_everything_on_device=True,
         device=torch.device('cuda', 0),
         verbose=False,
diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py
index eb6c5ac74..3a3761861 100644
--- a/nnunetv2/inference/predict_from_raw_data.py
+++ b/nnunetv2/inference/predict_from_raw_data.py
@@ -39,6 +39,7 @@ def __init__(self,
                  tile_step_size: float = 0.5,
                  use_gaussian: bool = True,
                  use_mirroring: bool = True,
+                 use_batch_tta: bool = True,
                  perform_everything_on_device: bool = True,
                  device: torch.device = torch.device('cuda'),
                  verbose: bool = False,
@@ -54,6 +55,7 @@ def __init__(self,
         self.tile_step_size = tile_step_size
         self.use_gaussian = use_gaussian
         self.use_mirroring = use_mirroring
+        self.use_batch_tta = use_batch_tta
         if device.type == 'cuda':
             torch.backends.cudnn.benchmark = True
         else:
@@ -549,18 +551,17 @@ def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
             c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
         ]
 
-        self.tta_batch = True
-        if self.tta_batch:
-            self.tta_batch_size = 4 if len(mirror_axes) == 3 else 2
+        if self.use_batch_tta:
+            tta_batch_size = 4 if len(mirror_axes) == 3 else 2
 
-            assert (len(axes_combinations) + 1) % self.tta_batch_size == 0, '(len(axes_combinations) + 1) must be divisible by self.tta_batch_size'
+            assert (len(axes_combinations) + 1) % tta_batch_size == 0, '(len(axes_combinations) + 1) must be divisible by tta_batch_size'
 
             x_combinations = [torch.flip(x, axes) for axes in axes_combinations]
             x_combinations.insert(0, x)
 
             prediction = 0
-            for i in range(0, len(x_combinations), self.tta_batch_size):
-                batch_x = torch.cat(x_combinations[i:i+self.tta_batch_size], dim=0)
+            for i in range(0, len(x_combinations), tta_batch_size):
+                batch_x = torch.cat(x_combinations[i:i+tta_batch_size], dim=0)
                 batch_prediction = self.network(batch_x)
 
                 for j in range(batch_prediction.shape[0]):
@@ -707,6 +708,9 @@ def predict_entry_point_modelfolder():
     parser.add_argument('--disable_tta', action='store_true', required=False, default=False,
                         help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
                              'but less accurate inference. Not recommended.')
+    parser.add_argument('--disable_batch_tta', action='store_true', required=False, default=False,
+                        help='Set this flag to disable batched test time data augmentation. This will slow down inference, '
+                             'but may help with out-of-VRAM issues.')
     parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have "
                                                                "to be a good listener/reader.")
     parser.add_argument('--save_probabilities', action='store_true',
@@ -763,6 +767,7 @@ def predict_entry_point_modelfolder():
     predictor = nnUNetPredictor(tile_step_size=args.step_size,
                                 use_gaussian=True,
                                 use_mirroring=not args.disable_tta,
+                                use_batch_tta=not args.disable_batch_tta,
                                 perform_everything_on_device=True,
                                 device=device,
                                 verbose=args.verbose,
@@ -808,6 +813,9 @@ def predict_entry_point():
     parser.add_argument('--disable_tta', action='store_true', required=False, default=False,
                         help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
                              'but less accurate inference. Not recommended.')
+    parser.add_argument('--disable_batch_tta', action='store_true', required=False, default=False,
+                        help='Set this flag to disable batched test time data augmentation. This will slow down inference, '
+                             'but may help with out-of-VRAM issues.')
     parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have "
                                                                "to be a good listener/reader.")
     parser.add_argument('--save_probabilities', action='store_true',
@@ -877,6 +885,7 @@ def predict_entry_point():
     predictor = nnUNetPredictor(tile_step_size=args.step_size,
                                 use_gaussian=True,
                                 use_mirroring=not args.disable_tta,
+                                use_batch_tta=not args.disable_batch_tta,
                                 perform_everything_on_device=True,
                                 device=device,
                                 verbose=args.verbose,
@@ -901,6 +910,7 @@ def predict_entry_point():
     #                       args.step_size,
     #                       use_gaussian=True,
     #                       use_mirroring=not args.disable_tta,
+    #                       use_batch_tta=not args.disable_batch_tta,
     #                       perform_everything_on_device=True,
     #                       verbose=args.verbose,
     #                       save_probabilities=args.save_probabilities,
@@ -922,6 +932,7 @@ def predict_entry_point():
         tile_step_size=0.5,
         use_gaussian=True,
         use_mirroring=True,
+        use_batch_tta=True,
         perform_everything_on_device=True,
         device=torch.device('cuda', 0),
         verbose=False,
@@ -952,6 +963,7 @@ def predict_entry_point():
     #     tile_step_size=0.5,
     #     use_gaussian=True,
     #     use_mirroring=True,
+    #     use_batch_tta=True,
     #     perform_everything_on_device=True,
     #     device=torch.device('cuda', 0),
     #     verbose=False,
diff --git a/nnunetv2/inference/readme.md b/nnunetv2/inference/readme.md
index d0acc6b2a..3a62c23ba 100644
--- a/nnunetv2/inference/readme.md
+++ b/nnunetv2/inference/readme.md
@@ -57,6 +57,7 @@ Example:
         tile_step_size=0.5,
         use_gaussian=True,
         use_mirroring=True,
+        use_batch_tta=True,
         perform_everything_on_device=True,
         device=torch.device('cuda', 0),
         verbose=False,
diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
index c709b6de6..f948c2513 100644
--- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
+++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
@@ -1204,7 +1204,7 @@ def perform_actual_validation(self, save_probabilities: bool = False):
                                    "forward pass (where compile is triggered) already has deep supervision disabled. "
                                    "This is exactly what we need in perform_actual_validation")
 
-        predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True,
+        predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True, use_batch_tta=True,
                                     perform_everything_on_device=True, device=self.device, verbose=False,
                                     verbose_preprocessing=False, allow_tqdm=False)
         predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None,
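Taken together, the first patch turns the one-forward-pass-per-mirror loop into a loop over concatenated mini-batches of mirrored inputs, and the second patch makes that behavior opt-out via use_batch_tta. The following minimal, self-contained sketch (not part of the patch) checks that the batched loop is numerically equivalent to the sequential one; toy_net, the tensor shape, and all variable names are illustrative stand-ins for self.network and the real 5D patches, chosen only so the script runs on its own:

import itertools
import torch

def toy_net(t: torch.Tensor) -> torch.Tensor:
    # elementwise stand-in for the segmentation network; any network that
    # treats batch items independently behaves the same way
    return t * 2 + 1

x = torch.randn(1, 1, 4, 4, 4)   # one patch: (batch, channel, x, y, z)
mirror_axes = [2, 3, 4]          # spatial dims, already offset by +2
axes_combinations = [c for i in range(len(mirror_axes))
                     for c in itertools.combinations(mirror_axes, i + 1)]

# sequential reference (the old code path): one forward pass per mirrored copy
reference = toy_net(x)
for axes in axes_combinations:
    reference += torch.flip(toy_net(torch.flip(x, axes)), axes)
reference /= len(axes_combinations) + 1

# batched variant (the new code path): concatenate mirrored copies along the
# batch dimension and run tta_batch_size of them per forward pass
tta_batch_size = 4 if len(mirror_axes) == 3 else 2
variants = [x] + [torch.flip(x, axes) for axes in axes_combinations]
batched = 0
for i in range(0, len(variants), tta_batch_size):
    batch_pred = toy_net(torch.cat(variants[i:i + tta_batch_size], dim=0))
    for j in range(batch_pred.shape[0]):
        idx = i + j
        pred_j = batch_pred[j:j + 1]
        if idx != 0:  # variant 0 is the unflipped input, nothing to undo
            pred_j = torch.flip(pred_j, axes_combinations[idx - 1])
        batched += pred_j
batched /= len(axes_combinations) + 1

print(torch.allclose(reference, batched))  # True

Both paths evaluate the network on the same eight inputs and flip each output back before averaging, so they agree exactly; with three mirror axes the batched path needs 2 forward calls instead of 8. Note that the patch handles the unflipped case with torch.flip(..., []) rather than a branch; the conditional above is used only to keep this sketch independent of how a given PyTorch version treats an empty dims list.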
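Assuming both patches are applied, batched TTA is enabled by default (use_batch_tta: bool = True) and the CLI entry points expose --disable_batch_tta as the opt-out; disabling it trades inference speed for lower peak VRAM, since the batched path feeds the network tta_batch_size patches per forward pass instead of one. In Python the opt-out would mirror the existing use_mirroring switch, e.g.:

import torch
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

# instantiate the predictor with batched TTA turned off (falls back to one
# forward pass per mirrored copy); the other arguments follow examples.py
predictor = nnUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=True,
    use_batch_tta=False,
    perform_everything_on_device=True,
    device=torch.device('cuda', 0),
    verbose=False,
)

On the command line the equivalent is the new flag, e.g. nnUNetv2_predict ... --disable_batch_tta.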