Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved crop_to_nonzero #2049

Merged
merged 1 commit into from
Apr 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 10 additions & 18 deletions nnunetv2/preprocessing/cropping/cropping.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np

from scipy.ndimage import binary_fill_holes

# Hello! crop_to_nonzero is the function you are looking for. Ignore the rest.
from acvl_utils.cropping_and_padding.bounding_boxes import get_bbox_from_mask, crop_to_bbox, bounding_box_to_slice
Expand All @@ -11,14 +11,11 @@ def create_nonzero_mask(data):
:param data:
:return: the mask is True where the data is nonzero
"""
from scipy.ndimage import binary_fill_holes
assert data.ndim in (3, 4), "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
nonzero_mask = np.zeros(data.shape[1:], dtype=bool)
for c in range(data.shape[0]):
this_mask = data[c] != 0
nonzero_mask = nonzero_mask | this_mask
nonzero_mask = binary_fill_holes(nonzero_mask)
return nonzero_mask
nonzero_mask = data[0] != 0
for c in range(1, data.shape[0]):
nonzero_mask |= data[c] != 0
return binary_fill_holes(nonzero_mask)


def crop_to_nonzero(data, seg=None, nonzero_label=-1):
Expand All @@ -31,21 +28,16 @@ def crop_to_nonzero(data, seg=None, nonzero_label=-1):
"""
nonzero_mask = create_nonzero_mask(data)
bbox = get_bbox_from_mask(nonzero_mask)

slicer = bounding_box_to_slice(bbox)
data = data[tuple([slice(None), *slicer])]

if seg is not None:
seg = seg[tuple([slice(None), *slicer])]

nonzero_mask = nonzero_mask[slicer][None]

slicer = (slice(None), ) + slicer
data = data[slicer]
if seg is not None:
seg = seg[slicer]
seg[(seg == 0) & (~nonzero_mask)] = nonzero_label
else:
nonzero_mask = nonzero_mask.astype(np.int8)
nonzero_mask[nonzero_mask == 0] = nonzero_label
nonzero_mask[nonzero_mask > 0] = 0
seg = nonzero_mask
seg = np.where(nonzero_mask, np.int8(0), np.int8(nonzero_label))
return data, seg, bbox


Loading