Skip to content

Commit

Permalink
Overwriting (#171)
Browse files Browse the repository at this point in the history
* lbl is not needed in isotropic version

* create lbls with probs

* works

* count full overlaps from zero again when new object was found
  • Loading branch information
gatoniel authored Aug 31, 2022
1 parent 04067bb commit 9342fa5
Show file tree
Hide file tree
Showing 3 changed files with 262 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
select = B,B9,C,D,DAR,E,F,N,RST,S,W
ignore = E203,E501,RST201,RST203,RST301,W503
max-line-length = 80
max-complexity = 12
max-complexity = 14
docstring-convention = google
per-file-ignores = tests/*:S101
rst-roles = class,const,func,meth,mod,ref
Expand Down
108 changes: 96 additions & 12 deletions src/merge_stardist_masks/naive_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,28 @@ def my_polygons_list_to_label(
]


def poly_list_with_probs(
dists_: ArrayLike,
points_: ArrayLike,
probs_: ArrayLike,
shape: Tuple[int, ...],
poly_list_func: PolyToLabelSignature,
) -> Tuple[npt.NDArray[np.int_], npt.NDArray[np.single]]:
"""Return labels and according probabilities."""
inds = np.argsort(probs_)
probs = np.array(probs_)[inds]
dists = np.array(dists_)[inds]
points = np.array(points_)[inds]

lbl: npt.NDArray[np.int_] = poly_list_func(dists, points, shape)

prob_array = np.zeros_like(lbl, dtype=float)
for i in range(1, len(probs) + 1):
prob_array[lbl == i] = probs[i - 1]

return (lbl, prob_array)


def get_poly_to_label(
shape: Tuple[int, ...], rays: Optional[Rays_Base]
) -> PolyToLabelSignature:
Expand Down Expand Up @@ -177,6 +199,29 @@ def paint_in_without_overlaps(
return paint_in


def paint_in_without_overlaps_check_probs(
paint_in: npt.NDArray[T],
shape: npt.NDArray[np.bool_],
old_probs: npt.NDArray[np.single],
new_probs: npt.NDArray[np.single],
paint_id: int,
) -> Tuple[npt.NDArray[T], npt.NDArray[np.single]]:
"""Set and overwrite entries of array to paint_id respecting their probabilities."""
to_be_painted = paint_in[shape]
to_be_painted_old_probs = old_probs[shape]
to_be_painted_new_probs = new_probs[shape]

overwrite_inds = to_be_painted_new_probs > to_be_painted_old_probs

to_be_painted[overwrite_inds] = paint_id
to_be_painted_old_probs[overwrite_inds] = to_be_painted_new_probs[overwrite_inds]

paint_in[shape] = to_be_painted
old_probs[shape] = to_be_painted_old_probs

return (paint_in, old_probs)


def paint_in_with_overlaps(
paint_in: npt.NDArray[T], shape: npt.NDArray[np.bool_], paint_id: int
) -> npt.NDArray[T]:
Expand All @@ -202,6 +247,7 @@ def naive_fusion(
max_full_overlaps: int = 2,
erase_probs_at_full_overlap: bool = False,
show_overlaps: bool = False,
respect_probs: bool = False,
) -> Union[npt.NDArray[np.uint16], npt.NDArray[np.intc]]:
"""Merge overlapping masks given by dists, probs, rays.
Expand Down Expand Up @@ -231,6 +277,8 @@ def naive_fusion(
erase_probs_at_full_overlap: If set to ``True`` probs are set to -1 whenever
a full overlap is detected.
show_overlaps: If set to true, overlaps are set to ``-1``.
respect_probs: If set to true, overlapping elements are overwritten by
considering their probabilities. Only works with uniform grid.
Returns:
The label image with uint16 labels. For 2D, the shape is
Expand All @@ -240,6 +288,7 @@ def naive_fusion(
Raises:
ValueError: If `rays` is ``None`` and 3D inputs are given or when
``probs.ndim != len(grid)``. # noqa: DAR402 ValueError
NotImplementedError: If grid is anisotropic and respect_probs is set to true.
Example:
>>> from merge_stardist_masks.naive_fusion import naive_fusion
Expand All @@ -259,8 +308,13 @@ def naive_fusion(
max_full_overlaps,
erase_probs_at_full_overlap=erase_probs_at_full_overlap,
show_overlaps=show_overlaps,
respect_probs=respect_probs,
)
else:
if respect_probs:
raise NotImplementedError(
"respect_probs=True is only available for isotropic grid."
)
return naive_fusion_anisotropic_grid(
dists,
probs,
Expand All @@ -284,6 +338,7 @@ def naive_fusion_isotropic_grid(
max_full_overlaps: int = 2,
erase_probs_at_full_overlap: bool = False,
show_overlaps: bool = False,
respect_probs: bool = False,
) -> Union[npt.NDArray[np.uint16], npt.NDArray[np.intc]]:
"""Merge overlapping masks given by dists, probs, rays.
Expand Down Expand Up @@ -312,6 +367,8 @@ def naive_fusion_isotropic_grid(
erase_probs_at_full_overlap: If set to ``True`` probs are set to -1 whenever
a full overlap is detected.
show_overlaps: If set to true, overlaps are set to ``-1``.
respect_probs: If set to true, overlapping elements are overwritten by
considering their probabilities.
Returns:
The label image with uint16 labels. For 2D, the shape is
Expand Down Expand Up @@ -356,12 +413,14 @@ def naive_fusion_isotropic_grid(

if show_overlaps:
paint_in = paint_in_with_overlaps
lbl = np.zeros(shape, dtype=np.intc)
# lbl = np.zeros(shape, dtype=np.intc)
big_lbl = np.zeros(big_shape, dtype=np.intc)
else:
paint_in = paint_in_without_overlaps
lbl = np.zeros(shape, dtype=np.uint16)
big_lbl = np.zeros(big_shape, dtype=np.uint16)
if respect_probs:
old_probs = np.zeros_like(big_lbl, dtype=np.single)
paint_in = paint_in_without_overlaps
# lbl = np.zeros(shape, dtype=np.uint16)

sorted_probs_j = 0
current_id = 1
Expand All @@ -379,12 +438,13 @@ def naive_fusion_isotropic_grid(
break

max_ind = np.unravel_index(max_ind, new_probs.shape)
big_new_shape_prob = [float(new_probs[tuple(max_ind)])]
new_probs[tuple(max_ind)] = -1

ind = tuple(max_ind) + (slice(None),)

slices, point = this_slice_point(points[ind], max_dist)
shape_paint = lbl[slices].shape
shape_paint = new_probs[slices].shape

new_shape = (
poly_to_label(
Expand Down Expand Up @@ -421,6 +481,7 @@ def naive_fusion_isotropic_grid(
break

max_ind_within = np.argmax(probs_within)
this_prob = float(probs_within[max_ind_within])
probs_within[max_ind_within] = -1

current_probs[new_shape] = probs_within
Expand Down Expand Up @@ -448,27 +509,48 @@ def naive_fusion_isotropic_grid(
if erase_probs_at_full_overlap:
current_probs[additional_shape] = -1
else:
full_overlaps = 0
big_new_shape_dists.append(this_dist)
big_new_shape_points.append(
big_point + (this_point - points[ind]) * grid
)
big_new_shape_prob.append(this_prob)

big_new_shape: npt.NDArray[np.bool_] = (
poly_list_to_label(
if respect_probs:
big_new_shape_, shape_probs = poly_list_with_probs(
big_new_shape_dists,
big_new_shape_points,
big_new_shape_prob,
big_shape_paint,
poly_list_to_label,
)
big_new_shape1: npt.NDArray[np.bool_] = big_new_shape_ > 0
(
big_lbl[big_slices],
old_probs[big_slices],
) = paint_in_without_overlaps_check_probs(
big_lbl[big_slices],
big_new_shape1,
old_probs[big_slices],
shape_probs,
current_id,
)
else:
big_new_shape: npt.NDArray[np.bool_] = (
poly_list_to_label(
big_new_shape_dists,
big_new_shape_points,
big_shape_paint,
)
> 0
)
big_lbl[big_slices] = paint_in(
big_lbl[big_slices], big_new_shape, current_id
)
> 0
)

current_probs[new_shape] = -1
new_probs[slices] = current_probs

lbl[slices] = paint_in(lbl[slices], new_shape, current_id)

big_lbl[big_slices] = paint_in(big_lbl[big_slices], big_new_shape, current_id)

current_id += 1

return big_lbl
Expand Down Expand Up @@ -644,6 +726,8 @@ def naive_fusion_anisotropic_grid(
full_overlaps += 1
if erase_probs_at_full_overlap:
current_probs[additional_shape] = -1
else:
full_overlaps = 0

current_probs[new_shape] = -1
new_probs[slices] = current_probs
Expand Down
Loading

0 comments on commit 9342fa5

Please sign in to comment.