Skip to content

Commit

Permalink
timeseries 2d (#284)
Browse files Browse the repository at this point in the history
* timeseries 2d

* tests for OptimizedStackedTimepointsData2D

* sample_patches with one test

* config with mypy

* mypy for all datas

* model_2d works with tests by ignored by mypy

* adaptations as model_2d was added

* more mypy

* model_2d mypy

* tests with tensorflow and reduced coverage

* test for model_3d

* test_model_3d mypy

* remove no cover from data_3d

* coverage =85%
  • Loading branch information
gatoniel authored Mar 7, 2023
1 parent cd6c94c commit 126e170
Show file tree
Hide file tree
Showing 19 changed files with 2,329 additions and 52 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[flake8]
select = B,B9,C,D,DAR,E,F,N,RST,S,W
ignore = E203,E501,RST201,RST203,RST301,W503,B905
ignore = E203,E501,RST201,RST203,RST301,W503,B905,S101
max-line-length = 80
max-complexity = 17
docstring-convention = google
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def mypy(session: Session) -> None:
def tests(session: Session) -> None:
"""Run the test suite."""
session.install(".")
session.install("coverage[toml]", "pytest", "pygments", "pytest-mock")
session.install("coverage[toml]", "pytest", "pygments", "pytest-mock", "tensorflow")
try:
session.run("coverage", "run", "--parallel", "-m", "pytest", *session.posargs)
finally:
Expand Down
736 changes: 707 additions & 29 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Pygments = "^2.13.0"
pyupgrade = "^3.3.0"
furo = ">=2021.11.12"
pytest-mock = "^3.10.0"
tensorflow = "^2.11.0"

[tool.coverage.paths]
source = ["src", "*/site-packages"]
Expand All @@ -56,7 +57,7 @@ source = ["merge_stardist_masks", "tests"]

[tool.coverage.report]
show_missing = true
fail_under = 95
fail_under = 85

[tool.mypy]
strict = true
Expand All @@ -66,6 +67,7 @@ show_column_numbers = true
show_error_codes = true
show_error_context = true
plugins = "numpy.typing.mypy_plugin"
exclude = "src/merge_stardist_masks/model_2d.py"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
126 changes: 126 additions & 0 deletions src/merge_stardist_masks/config_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Configuration for 2D stacked time frames modified directly from StarDist."""
from __future__ import annotations

from typing import Optional
from typing import Tuple

from csbdeep.models import BaseConfig # type: ignore [import]
from csbdeep.utils import _raise # type: ignore [import]
from csbdeep.utils import backend_channels_last
from csbdeep.utils.tf import keras_import # type: ignore [import]
from distutils.version import LooseVersion
from stardist.utils import _normalize_grid # type: ignore [import]


keras = keras_import()


class StackedTimepointsConfig2D(BaseConfig): # type: ignore [misc]
"""Configuration for a 2D StarDist model based on stacked timepoints."""

def __init__(
self,
axes: str = "YX",
n_rays: int = 32,
len_t: int = 3,
n_channel_in: int = 1,
grid: Tuple[int, ...] = (1, 1),
n_classes: Optional[int] = None,
backbone: str = "unet",
train_patch_size: Tuple[int, ...] = (256, 256),
**kwargs: int,
) -> None:
"""Initialize with fixed length in time direction."""
super().__init__(
axes=axes,
n_channel_in=n_channel_in,
n_channel_out=(1 + n_rays) * len_t,
)

n_classes is None or _raise(NotImplementedError("n_classes not implemented."))

# directly set by parameters
self.len_t = len_t
self.n_rays = int(n_rays)
self.grid = _normalize_grid(grid, 2)
self.backbone = str(backbone).lower()
self.n_classes = None if n_classes is None else int(n_classes)
self.train_patch_size = train_patch_size

# default config (can be overwritten by kwargs below)
if self.backbone == "unet":
self.unet_n_depth = 3
self.unet_kernel_size = 3, 3
self.unet_n_filter_base = 32
self.unet_n_conv_per_depth = 2
self.unet_pool = 2, 2
self.unet_activation = "relu"
self.unet_last_activation = "relu"
self.unet_batch_norm = False
self.unet_dropout = 0.0
self.unet_prefix = ""
self.net_conv_after_unet = 128
else:
# TODO: resnet backbone for 2D model?
raise ValueError("backbone '%s' not supported." % self.backbone)

# net_mask_shape not needed but kept for legacy reasons
if backend_channels_last():
self.net_input_shape = None, None, self.n_channel_in * self.len_t
# self.net_mask_shape = None, None, 1
else:
self.net_input_shape = self.n_channel_in * self.len_t, None, None
# self.net_mask_shape = 1, None, None

self.train_shape_completion = False
self.train_completion_crop = 32
self.train_background_reg = 1e-4
self.train_foreground_only = 0.9
self.train_sample_cache = True

self.train_dist_loss = "mae"
self.train_loss_weights = (1, 0.2) if self.n_classes is None else (1, 0.2, 1)
self.train_class_weights = (
(1, 1) if self.n_classes is None else (1,) * (self.n_classes + 1)
)
self.train_epochs = 400
self.train_steps_per_epoch = 100
self.train_learning_rate = 0.0003
self.train_batch_size = 4
self.train_n_val_patches = None
self.train_tensorboard = True
# the parameter 'min_delta' was called 'epsilon' for keras<=2.1.5
min_delta_key = (
"epsilon"
if LooseVersion(keras.__version__) <= LooseVersion("2.1.5")
else "min_delta"
)
self.train_reduce_lr = {"factor": 0.5, "patience": 40, min_delta_key: 0}

self.use_gpu = False

# remove derived attributes that shouldn't be overwritten
for k in ("n_dim", "n_channel_out"):
try:
del kwargs[k]
except KeyError:
pass

self.update_parameters(False, **kwargs)

# FIXME: put into is_valid()
if not len(self.train_loss_weights) == (2 if self.n_classes is None else 3):
raise ValueError(
f"train_loss_weights {self.train_loss_weights} not compatible "
f"with n_classes ({self.n_classes}): must be 3 weights if "
"n_classes is not None, otherwise 2"
)

if not len(self.train_class_weights) == (
2 if self.n_classes is None else self.n_classes + 1
):
raise ValueError(
f"train_class_weights {self.train_class_weights} not compatible "
f"with n_classes ({self.n_classes}): must be 'n_classes + 1' weights "
"if n_classes is not None, otherwise 2"
)
146 changes: 146 additions & 0 deletions src/merge_stardist_masks/data_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""Data generator for 2d time stacks based on stardist's data generators."""
from __future__ import annotations

from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar

import numpy as np
import numpy.typing as npt

from .data_base import AugmenterSignature
from .data_base import StackedTimepointsDataBase
from .sample_patches import sample_patches
from .timeseries_2d import bordering_gaussian_weights_timeseries
from .timeseries_2d import edt_prob_timeseries
from .timeseries_2d import star_dist_timeseries
from .timeseries_2d import touching_pixels_2d_timeseries

T = TypeVar("T", bound=np.generic)


class OptimizedStackedTimepointsData2D(StackedTimepointsDataBase):
"""Uses better weights and stacked timepoints."""

def __init__(
self,
xs: List[npt.NDArray[T]],
ys: List[npt.NDArray[T]],
batch_size: int,
n_rays: int,
length: int,
n_classes: Optional[int] = None,
classes: Optional[List[npt.NDArray[T]]] = None,
use_gpu: bool = False,
patch_size: Tuple[int, ...] = (256, 256),
b: int = 32,
grid: Tuple[int, ...] = (1, 1),
shape_completion: bool = False,
augmenter: Optional[AugmenterSignature[T]] = None,
foreground_prob: int = 0,
maxfilter_patch_size: Optional[int] = None,
sample_ind_cache: bool = True,
) -> None:
"""Initialize with arrays of shape (size, T, Y, X, channels)."""
super().__init__(
xs=xs,
ys=ys,
n_rays=n_rays,
grid=grid,
n_classes=n_classes,
classes=classes,
batch_size=batch_size,
patch_size=patch_size,
length=length,
augmenter=augmenter,
foreground_prob=foreground_prob,
maxfilter_patch_size=maxfilter_patch_size,
use_gpu=use_gpu,
sample_ind_cache=sample_ind_cache,
)

self.shape_completion = bool(shape_completion)
if self.shape_completion and b > 0:
self.b = slice(None), slice(b, -b), slice(b, -b)
else:
self.b = slice(None), slice(None), slice(None)

self.sd_mode = "opencl" if self.use_gpu else "cpp"

def __getitem__(
self, i: int
) -> Tuple[List[npt.NDArray[np.double]], List[npt.NDArray[np.double]],]:
"""Return batch i as numpy array."""
idx = self.batch(i)
arrays = [
sample_patches(
(self.ys[k],) + self.channels_as_tuple(self.xs[k]),
patch_size=self.patch_size,
n_samples=1,
valid_inds=self.get_valid_inds(k),
)
for k in idx
]

if self.n_channel is None:
xs, ys = list(zip(*[(x[0][self.b], y[0]) for y, x in arrays]))
else:
xs, ys = list(
zip(
*[
(np.stack([_x[0] for _x in x], axis=-1)[self.b], y[0])
for y, *x in arrays
]
)
)

xs, ys = tuple(zip(*tuple(self.augmenter(_x, _y) for _x, _y in zip(xs, ys))))

prob_ = np.stack([edt_prob_timeseries(lbl, self.b, self.ss_grid) for lbl in ys])
touching = np.stack(
[touching_pixels_2d_timeseries(lbl, self.b, self.ss_grid) for lbl in ys]
)
touching_edt = np.stack(
[
bordering_gaussian_weights_timeseries(
mask, lbl, sigma=2, b=self.b, ss_grid=self.ss_grid
)
for mask, lbl in zip(touching, ys)
]
)
prob = np.clip(prob_ - touching, 0, 1)
dist_mask: npt.NDArray[np.double] = prob_ + touching_edt

dists = np.stack(
[
star_dist_timeseries(
lbl, self.n_rays, mode=self.sd_mode, grid=self.grid
)
for lbl in ys
]
)

if xs[0].ndim == 3:
xs = [
np.expand_dims(x, axis=-1) for x in xs # type: ignore [no-untyped-call]
]
xs = np.stack(
[
np.concatenate( # type: ignore [no-untyped-call]
[x[i] for i in range(self.len_t)], axis=-1
)
for x in xs
]
)

# append dist_mask to dist as additional channel
# dist_and_mask = np.concatenate([dist,dist_mask],axis=-1)
# faster than concatenate
dist_and_mask = np.empty(
dists.shape[:-1] + (self.len_t * (self.n_rays + 1),), np.float32
)
dist_and_mask[..., : -self.len_t] = dists
dist_and_mask[..., -self.len_t :] = dist_mask

return [xs], [prob, dist_and_mask]
32 changes: 16 additions & 16 deletions src/merge_stardist_masks/data_3d.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
"""Stardist 3D data generator for new weights and probability maps."""
from typing import List # pragma: no cover
from typing import Tuple # pragma: no cover
from typing import TypeVar # pragma: no cover

import numpy as np # pragma: no cover
import numpy.typing as npt # pragma: no cover
from scipy.ndimage import zoom # type: ignore [import] # pragma: no cover
from stardist.geometry import star_dist3D # type: ignore [import] # pragma: no cover
from stardist.models.model3d import ( # type: ignore [import] # pragma: no cover
from typing import List
from typing import Tuple
from typing import TypeVar

import numpy as np
import numpy.typing as npt
from scipy.ndimage import zoom # type: ignore [import]
from stardist.geometry import star_dist3D # type: ignore [import]
from stardist.models.model3d import ( # type: ignore [import]
StarDistData3D,
)
from stardist.sample_patches import ( # type: ignore [import] # pragma: no cover
from stardist.sample_patches import ( # type: ignore [import]
sample_patches,
)
from stardist.utils import edt_prob # type: ignore [import] # pragma: no cover
from stardist.utils import mask_to_categorical # pragma: no cover
from stardist.utils import edt_prob # type: ignore [import]
from stardist.utils import mask_to_categorical

from .touching_pixels import bordering_gaussian_weights # pragma: no cover
from .touching_pixels import touching_pixels_3d # pragma: no cover
from .touching_pixels import bordering_gaussian_weights
from .touching_pixels import touching_pixels_3d


T = TypeVar("T", bound=np.generic) # pragma: no cover
T = TypeVar("T", bound=np.generic)


class OptimizedStarDistData3D(StarDistData3D): # type: ignore [misc] # pragma: no cover
class OptimizedStarDistData3D(StarDistData3D): # type: ignore [misc]
"""Overwrite __getitem__ function to use different prob and weights."""

def __getitem__(self, i: int) -> Tuple[List[npt.NDArray[T]], List[npt.NDArray[T]]]:
Expand Down
Loading

0 comments on commit 126e170

Please sign in to comment.