Skip to content

Commit

Permalink
WIP: Implement find_objects with a reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham committed Feb 10, 2019
1 parent 52847e6 commit d434f70
Showing 1 changed file with 65 additions and 16 deletions.
81 changes: 65 additions & 16 deletions dask_image/ndmeasure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

import collections
import functools
import itertools
import operator

import numpy
import scipy

import dask.array

Expand Down Expand Up @@ -126,6 +129,50 @@ def extrema(input, labels=None, index=None):
return result


def find_object_chunk(x_chunk, offset, max_label):
    """
    Find labeled objects in one chunk and shift their boxes by ``offset``.

    Thin wrapper around ``scipy.ndimage.find_objects`` that translates the
    chunk-local bounding boxes so they can be compared with boxes coming
    from other chunks.

    Parameters
    ----------
    x_chunk : array-like
        Label data of a single chunk.
    offset : sequence of int
        Value added to the ``start`` and ``stop`` of each slice, paired
        with the slices dimension by dimension.
    max_label : int
        Forwarded unchanged to ``scipy.ndimage.find_objects``.

    Returns
    -------
    numpy.ndarray
        Object array built from the list returned by
        ``scipy.ndimage.find_objects``; entries that are ``None`` are left
        untouched, all others have their slices shifted by ``offset``.

    Notes
    -----
    NOTE(review): ``numpy.array(..., dtype=object)`` appears to yield a 2-D
    ``(n_labels, ndim)`` array of slices when every label is present, but a
    1-D array of tuples/``None`` when at least one label is missing --
    confirm that downstream code copes with both layouts.
    """

    # A bare ``import scipy`` does not guarantee that the ``ndimage``
    # submodule is loaded; import it explicitly before using it.
    import scipy.ndimage

    result = numpy.array(
        scipy.ndimage.find_objects(x_chunk, max_label), dtype=object
    )

    for i, box in enumerate(result):
        if box is None:
            # Nothing to shift for this entry.
            continue
        # Build the shifted box completely before writing it back so the
        # original slices are read before they are replaced.
        result[i] = tuple(
            slice(sl.start + o, sl.stop + o) for sl, o in zip(box, offset)
        )

    return result


def find_object_chunk_delayed(x_chunk, offset, max_label):
    """
    Lazily apply ``find_object_chunk`` and expose the result as a Dask array.

    The three arguments are forwarded unchanged to ``find_object_chunk``
    through ``dask.delayed``; the delayed value is then wrapped with
    ``dask.array.from_delayed`` using an object dtype.
    """
    lazy_boxes = dask.delayed(find_object_chunk)(x_chunk, offset, max_label)

    # The first dimension is declared as NaN (unknown).
    # NOTE(review): the second dimension is hard-coded to 2, which looks
    # like an assumption of 2-D input -- confirm for N-D images.
    return dask.array.from_delayed(
        lazy_boxes, shape=(numpy.nan, 2), dtype=object
    )


def find_object_aggregate(multichunk_objs, axis, keepdims):
    """
    Merge bounding boxes from several chunks, position by position.

    Parameters
    ----------
    multichunk_objs : iterable of sequences
        One sequence per chunk. Each entry is either ``None`` or a sequence
        of slices. Sequences are aligned by position; shorter ones are
        padded with ``None``.
    axis, keepdims
        Accepted but not used by this function.

    Returns
    -------
    list
        One entry per position: ``None`` if no chunk supplied a box there,
        otherwise a tuple of slices running from the smallest ``start`` to
        the largest ``stop`` found in each dimension.
    """

    merged = []
    for candidates in itertools.zip_longest(*multichunk_objs, fillvalue=None):
        found = [box for box in candidates if box is not None]
        if not found:
            merged.append(None)
            continue

        # ``zip(*found)`` groups the slices of all boxes by dimension.
        merged.append(tuple(
            slice(
                min(s.start for s in dim_slices),
                max(s.stop for s in dim_slices),
            )
            for dim_slices in zip(*found)
        ))

    return merged


def find_objects(input, max_label=0):
"""
Find the center of mass over an image at specified subregions.
Expand All @@ -146,22 +193,24 @@ def find_objects(input, max_label=0):
input = dask.array.asarray(input)
max_label = int(max_label)

if max_label == 0:
raise NotImplementedError(
"Getting all labels is currently not supported."
)

positions = _utils._ravel_shape_indices(
input.shape, chunks=input.chunks
)

object_slices = []
for i in range(1, max_label + 1):
object_slices.append(_utils._find_object(
positions[input == i], input.shape
))

return object_slices
mapped_blocks = numpy.empty(input.numblocks, dtype=object)
for index, cslice in zip(numpy.ndindex(*input.numblocks),
dask.array.core.slices_from_chunks(input.chunks)):
offset = tuple(map(operator.attrgetter("start"), cslice))
input_block = input[cslice]
mapped_blocks[index] = find_object_chunk_delayed(input_block,
offset,
max_label)
mapped_blocks = dask.array.block(mapped_blocks.tolist(),
allow_unknown_chunksizes=True)

r = dask.array.reductions.reduction(mapped_blocks,
lambda x_chunk, axis, keepdims: x_chunk,
find_object_aggregate,
dtype=object,
concatenate=False)

return r


def histogram(input,
Expand Down

0 comments on commit d434f70

Please sign in to comment.