diff --git a/dask_image/ndmeasure/__init__.py b/dask_image/ndmeasure/__init__.py index a235955d..20bcb4c7 100644 --- a/dask_image/ndmeasure/__init__.py +++ b/dask_image/ndmeasure/__init__.py @@ -6,6 +6,8 @@ import collections import functools +import itertools +import operator from warnings import warn import numpy @@ -127,6 +129,50 @@ def extrema(input, labels=None, index=None): return result +def find_object_chunk(x_chunk, offset, max_label): + """Wrapper around scipy.ndimage's find_object""" + + result = numpy.array(scipy.ndimage.find_objects(x_chunk, max_label), + dtype=object) + for i in range(len(result)): + box = result[i] + if box is not None: + r_sl = [] + for sl, o in zip(box, offset): + r_sl.append(slice(sl.start + o, sl.stop + o)) + r_sl = tuple(r_sl) + result[i] = r_sl + + return result + + +def find_object_chunk_delayed(x_chunk, offset, max_label): + """Delayed wrapper around ``find_object_chunk``""" + r = dask.delayed(find_object_chunk)(x_chunk, offset, max_label) + r = dask.array.from_delayed(r, shape=(numpy.nan, 2), dtype=object) + return r + + +def find_object_aggregate(multichunk_objs, axis, keepdims): + """Wrapper around scipy.ndimage's find_object""" + + r = [] + for obj in itertools.zip_longest(*multichunk_objs, fillvalue=None): + obj = list(itertools.filterfalse(lambda v: v is None, obj)) + if obj: + r_sl = [] + for slices in zip(*obj): + start_values = map(operator.attrgetter("start"), slices) + stop_values = map(operator.attrgetter("stop"), slices) + r_sl.append(slice(min(start_values), max(stop_values))) + r_sl = tuple(r_sl) + r.append(r_sl) + else: + r.append(None) + + return r + + def find_objects(input, max_label=0): """ Find the center of mass over an image at specified subregions. @@ -147,22 +193,24 @@ def find_objects(input, max_label=0): input = dask.array.asarray(input) max_label = int(max_label) - if max_label == 0: - raise NotImplementedError( - "Getting all labels is currently not supported." - ) - - positions = _utils._ravel_shape_indices( - input.shape, chunks=input.chunks - ) - - object_slices = [] - for i in range(1, max_label + 1): - object_slices.append(_utils._find_object( - positions[input == i], input.shape - )) - - return object_slices + mapped_blocks = numpy.empty(input.numblocks, dtype=object) + for index, cslice in zip(numpy.ndindex(*input.numblocks), + dask.array.core.slices_from_chunks(input.chunks)): + offset = tuple(map(operator.attrgetter("start"), cslice)) + input_block = input[cslice] + mapped_blocks[index] = find_object_chunk_delayed(input_block, + offset, + max_label) + mapped_blocks = dask.array.block(mapped_blocks.tolist(), + allow_unknown_chunksizes=True) + + r = dask.array.reductions.reduction(mapped_blocks, + lambda x_chunk, axis, keepdims: x_chunk, + find_object_aggregate, + dtype=object, + concatenate=False) + + return r def histogram(input,