Skip to content

Commit

Permalink
Implement find_objects for Dask Array
Browse files Browse the repository at this point in the history
Provides an implementation of `find_objects` that determines bounding
boxes for each label in the image. Currently requires the user to
specify the number of labels they would like to inspect. Raises a
`NotImplementedError` if the user wishes to collect all bounding boxes
for labels. Works by selecting the 1-D positions that correspond to the
label while ignoring all other points. Assumes that these positions
along with an intermediate array of the same size comfortably fit in
memory.

Within a utility function, determines whether any positions were found
for the corresponding label. If not, simply returns `None`. If positions
were found, it manually unravels the positions and finds the maximum and
minimum positions along each dimension. These are stored into `slice`s,
which are packed into a `tuple` and returned. Makes sure to use in-place
NumPy operations to avoid using additional memory.
  • Loading branch information
jakirkham committed Feb 4, 2019
1 parent edf5bab commit 6786d6f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
40 changes: 40 additions & 0 deletions dask_image/ndmeasure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,46 @@ def extrema(input, labels=None, index=None):
return result


def find_objects(input, max_label=0):
"""
Find the center of mass over an image at specified subregions.
Parameters
----------
input : ndarray
Image features noted by integers.
max_label : int, optional
Maximum label to look for in ``input``. If 0, look for all labels.
Returns
-------
object_slices : ``list`` of ``tuple``s
A ``list`` of ``tuple``s specifying the bounding boxes of each label.
"""

# Normalize arguments
input = dask.array.asarray(input)
max_label = int(max_label)

# Catch unsupported case
if max_label == 0:
raise NotImplementedError(
"Getting all labels is currently not supported."
)

positions = _utils._ravel_shape_indices(
input.shape, chunks=input.chunks
)

object_slices = []
for i in range(1, max_label + 1):
object_slices.append(_utils._find_object(
positions[input == i], input.shape
))

return object_slices


def histogram(input,
min,
max,
Expand Down
22 changes: 22 additions & 0 deletions dask_image/ndmeasure/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,28 @@ def _extrema(a, positions, shape, dtype):
return result[0]


@dask.delayed
def _find_object(pos, shape):
"""
Find bounding box for the given positions.
"""

if pos.size == 0:
return None

pos = numpy.require(pos, requirements='OW')

result = []
pos_i = numpy.empty_like(pos)
for s in reversed(shape):
numpy.mod(pos, s, out=pos_i)
numpy.subtract(pos, pos_i, out=pos)
result.insert(0, slice(pos_i.min(), pos_i.max() + 1))
result = tuple(result)

return result


def _histogram(input,
min,
max,
Expand Down

0 comments on commit 6786d6f

Please sign in to comment.