From ab34ef4952adc94d41ee654336c18cd3f971875b Mon Sep 17 00:00:00 2001 From: ancestor-mithril <58839912+ancestor-mithril@users.noreply.github.com> Date: Thu, 28 Mar 2024 22:56:46 +0200 Subject: [PATCH] Improved crop_to_nonzero The code path used during inference avoids multiple memory allocations which are costly for large volumes. --- nnunetv2/preprocessing/cropping/cropping.py | 28 ++++++++------------- 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/nnunetv2/preprocessing/cropping/cropping.py b/nnunetv2/preprocessing/cropping/cropping.py index 96fe7b7db..354f3d3bf 100644 --- a/nnunetv2/preprocessing/cropping/cropping.py +++ b/nnunetv2/preprocessing/cropping/cropping.py @@ -1,5 +1,5 @@ import numpy as np - +from scipy.ndimage import binary_fill_holes # Hello! crop_to_nonzero is the function you are looking for. Ignore the rest. from acvl_utils.cropping_and_padding.bounding_boxes import get_bbox_from_mask, crop_to_bbox, bounding_box_to_slice @@ -11,14 +11,11 @@ def create_nonzero_mask(data): :param data: :return: the mask is True where the data is nonzero """ - from scipy.ndimage import binary_fill_holes assert data.ndim in (3, 4), "data must have shape (C, X, Y, Z) or shape (C, X, Y)" - nonzero_mask = np.zeros(data.shape[1:], dtype=bool) - for c in range(data.shape[0]): - this_mask = data[c] != 0 - nonzero_mask = nonzero_mask | this_mask - nonzero_mask = binary_fill_holes(nonzero_mask) - return nonzero_mask + nonzero_mask = data[0] != 0 + for c in range(1, data.shape[0]): + nonzero_mask |= data[c] != 0 + return binary_fill_holes(nonzero_mask) def crop_to_nonzero(data, seg=None, nonzero_label=-1): @@ -31,21 +28,16 @@ def crop_to_nonzero(data, seg=None, nonzero_label=-1): """ nonzero_mask = create_nonzero_mask(data) bbox = get_bbox_from_mask(nonzero_mask) - slicer = bounding_box_to_slice(bbox) - data = data[tuple([slice(None), *slicer])] - - if seg is not None: - seg = seg[tuple([slice(None), *slicer])] - nonzero_mask = nonzero_mask[slicer][None] + + slicer = (slice(None), ) + slicer + data = data[slicer] if seg is not None: + seg = seg[slicer] seg[(seg == 0) & (~nonzero_mask)] = nonzero_label else: - nonzero_mask = nonzero_mask.astype(np.int8) - nonzero_mask[nonzero_mask == 0] = nonzero_label - nonzero_mask[nonzero_mask > 0] = 0 - seg = nonzero_mask + seg = np.where(nonzero_mask, np.int8(0), np.int8(nonzero_label)) return data, seg, bbox