-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* timeseries 2d * tests for OptimizedStackedTimepointsData2D * sample_patches with one test * config with mypy * mypy for all datas * model_2d works with tests by ignored by mypy * adaptations as model_2d was added * more mypy * model_2d mypy * tests with tensorflow and reduced coverage * test for model_3d * test_model_3d mypy * remove no cover from data_3d * coverage =85%
- Loading branch information
Showing
19 changed files
with
2,329 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
"""Configuration for 2D stacked time frames modified directly from StarDist.""" | ||
from __future__ import annotations | ||
|
||
from typing import Optional | ||
from typing import Tuple | ||
|
||
from csbdeep.models import BaseConfig # type: ignore [import] | ||
from csbdeep.utils import _raise # type: ignore [import] | ||
from csbdeep.utils import backend_channels_last | ||
from csbdeep.utils.tf import keras_import # type: ignore [import] | ||
from distutils.version import LooseVersion | ||
from stardist.utils import _normalize_grid # type: ignore [import] | ||
|
||
|
||
keras = keras_import() | ||
|
||
|
||
class StackedTimepointsConfig2D(BaseConfig): # type: ignore [misc] | ||
"""Configuration for a 2D StarDist model based on stacked timepoints.""" | ||
|
||
def __init__( | ||
self, | ||
axes: str = "YX", | ||
n_rays: int = 32, | ||
len_t: int = 3, | ||
n_channel_in: int = 1, | ||
grid: Tuple[int, ...] = (1, 1), | ||
n_classes: Optional[int] = None, | ||
backbone: str = "unet", | ||
train_patch_size: Tuple[int, ...] = (256, 256), | ||
**kwargs: int, | ||
) -> None: | ||
"""Initialize with fixed length in time direction.""" | ||
super().__init__( | ||
axes=axes, | ||
n_channel_in=n_channel_in, | ||
n_channel_out=(1 + n_rays) * len_t, | ||
) | ||
|
||
n_classes is None or _raise(NotImplementedError("n_classes not implemented.")) | ||
|
||
# directly set by parameters | ||
self.len_t = len_t | ||
self.n_rays = int(n_rays) | ||
self.grid = _normalize_grid(grid, 2) | ||
self.backbone = str(backbone).lower() | ||
self.n_classes = None if n_classes is None else int(n_classes) | ||
self.train_patch_size = train_patch_size | ||
|
||
# default config (can be overwritten by kwargs below) | ||
if self.backbone == "unet": | ||
self.unet_n_depth = 3 | ||
self.unet_kernel_size = 3, 3 | ||
self.unet_n_filter_base = 32 | ||
self.unet_n_conv_per_depth = 2 | ||
self.unet_pool = 2, 2 | ||
self.unet_activation = "relu" | ||
self.unet_last_activation = "relu" | ||
self.unet_batch_norm = False | ||
self.unet_dropout = 0.0 | ||
self.unet_prefix = "" | ||
self.net_conv_after_unet = 128 | ||
else: | ||
# TODO: resnet backbone for 2D model? | ||
raise ValueError("backbone '%s' not supported." % self.backbone) | ||
|
||
# net_mask_shape not needed but kept for legacy reasons | ||
if backend_channels_last(): | ||
self.net_input_shape = None, None, self.n_channel_in * self.len_t | ||
# self.net_mask_shape = None, None, 1 | ||
else: | ||
self.net_input_shape = self.n_channel_in * self.len_t, None, None | ||
# self.net_mask_shape = 1, None, None | ||
|
||
self.train_shape_completion = False | ||
self.train_completion_crop = 32 | ||
self.train_background_reg = 1e-4 | ||
self.train_foreground_only = 0.9 | ||
self.train_sample_cache = True | ||
|
||
self.train_dist_loss = "mae" | ||
self.train_loss_weights = (1, 0.2) if self.n_classes is None else (1, 0.2, 1) | ||
self.train_class_weights = ( | ||
(1, 1) if self.n_classes is None else (1,) * (self.n_classes + 1) | ||
) | ||
self.train_epochs = 400 | ||
self.train_steps_per_epoch = 100 | ||
self.train_learning_rate = 0.0003 | ||
self.train_batch_size = 4 | ||
self.train_n_val_patches = None | ||
self.train_tensorboard = True | ||
# the parameter 'min_delta' was called 'epsilon' for keras<=2.1.5 | ||
min_delta_key = ( | ||
"epsilon" | ||
if LooseVersion(keras.__version__) <= LooseVersion("2.1.5") | ||
else "min_delta" | ||
) | ||
self.train_reduce_lr = {"factor": 0.5, "patience": 40, min_delta_key: 0} | ||
|
||
self.use_gpu = False | ||
|
||
# remove derived attributes that shouldn't be overwritten | ||
for k in ("n_dim", "n_channel_out"): | ||
try: | ||
del kwargs[k] | ||
except KeyError: | ||
pass | ||
|
||
self.update_parameters(False, **kwargs) | ||
|
||
# FIXME: put into is_valid() | ||
if not len(self.train_loss_weights) == (2 if self.n_classes is None else 3): | ||
raise ValueError( | ||
f"train_loss_weights {self.train_loss_weights} not compatible " | ||
f"with n_classes ({self.n_classes}): must be 3 weights if " | ||
"n_classes is not None, otherwise 2" | ||
) | ||
|
||
if not len(self.train_class_weights) == ( | ||
2 if self.n_classes is None else self.n_classes + 1 | ||
): | ||
raise ValueError( | ||
f"train_class_weights {self.train_class_weights} not compatible " | ||
f"with n_classes ({self.n_classes}): must be 'n_classes + 1' weights " | ||
"if n_classes is not None, otherwise 2" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
"""Data generator for 2d time stacks based on stardist's data generators.""" | ||
from __future__ import annotations | ||
|
||
from typing import List | ||
from typing import Optional | ||
from typing import Tuple | ||
from typing import TypeVar | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
|
||
from .data_base import AugmenterSignature | ||
from .data_base import StackedTimepointsDataBase | ||
from .sample_patches import sample_patches | ||
from .timeseries_2d import bordering_gaussian_weights_timeseries | ||
from .timeseries_2d import edt_prob_timeseries | ||
from .timeseries_2d import star_dist_timeseries | ||
from .timeseries_2d import touching_pixels_2d_timeseries | ||
|
||
T = TypeVar("T", bound=np.generic) | ||
|
||
|
||
class OptimizedStackedTimepointsData2D(StackedTimepointsDataBase): | ||
"""Uses better weights and stacked timepoints.""" | ||
|
||
def __init__( | ||
self, | ||
xs: List[npt.NDArray[T]], | ||
ys: List[npt.NDArray[T]], | ||
batch_size: int, | ||
n_rays: int, | ||
length: int, | ||
n_classes: Optional[int] = None, | ||
classes: Optional[List[npt.NDArray[T]]] = None, | ||
use_gpu: bool = False, | ||
patch_size: Tuple[int, ...] = (256, 256), | ||
b: int = 32, | ||
grid: Tuple[int, ...] = (1, 1), | ||
shape_completion: bool = False, | ||
augmenter: Optional[AugmenterSignature[T]] = None, | ||
foreground_prob: int = 0, | ||
maxfilter_patch_size: Optional[int] = None, | ||
sample_ind_cache: bool = True, | ||
) -> None: | ||
"""Initialize with arrays of shape (size, T, Y, X, channels).""" | ||
super().__init__( | ||
xs=xs, | ||
ys=ys, | ||
n_rays=n_rays, | ||
grid=grid, | ||
n_classes=n_classes, | ||
classes=classes, | ||
batch_size=batch_size, | ||
patch_size=patch_size, | ||
length=length, | ||
augmenter=augmenter, | ||
foreground_prob=foreground_prob, | ||
maxfilter_patch_size=maxfilter_patch_size, | ||
use_gpu=use_gpu, | ||
sample_ind_cache=sample_ind_cache, | ||
) | ||
|
||
self.shape_completion = bool(shape_completion) | ||
if self.shape_completion and b > 0: | ||
self.b = slice(None), slice(b, -b), slice(b, -b) | ||
else: | ||
self.b = slice(None), slice(None), slice(None) | ||
|
||
self.sd_mode = "opencl" if self.use_gpu else "cpp" | ||
|
||
def __getitem__( | ||
self, i: int | ||
) -> Tuple[List[npt.NDArray[np.double]], List[npt.NDArray[np.double]],]: | ||
"""Return batch i as numpy array.""" | ||
idx = self.batch(i) | ||
arrays = [ | ||
sample_patches( | ||
(self.ys[k],) + self.channels_as_tuple(self.xs[k]), | ||
patch_size=self.patch_size, | ||
n_samples=1, | ||
valid_inds=self.get_valid_inds(k), | ||
) | ||
for k in idx | ||
] | ||
|
||
if self.n_channel is None: | ||
xs, ys = list(zip(*[(x[0][self.b], y[0]) for y, x in arrays])) | ||
else: | ||
xs, ys = list( | ||
zip( | ||
*[ | ||
(np.stack([_x[0] for _x in x], axis=-1)[self.b], y[0]) | ||
for y, *x in arrays | ||
] | ||
) | ||
) | ||
|
||
xs, ys = tuple(zip(*tuple(self.augmenter(_x, _y) for _x, _y in zip(xs, ys)))) | ||
|
||
prob_ = np.stack([edt_prob_timeseries(lbl, self.b, self.ss_grid) for lbl in ys]) | ||
touching = np.stack( | ||
[touching_pixels_2d_timeseries(lbl, self.b, self.ss_grid) for lbl in ys] | ||
) | ||
touching_edt = np.stack( | ||
[ | ||
bordering_gaussian_weights_timeseries( | ||
mask, lbl, sigma=2, b=self.b, ss_grid=self.ss_grid | ||
) | ||
for mask, lbl in zip(touching, ys) | ||
] | ||
) | ||
prob = np.clip(prob_ - touching, 0, 1) | ||
dist_mask: npt.NDArray[np.double] = prob_ + touching_edt | ||
|
||
dists = np.stack( | ||
[ | ||
star_dist_timeseries( | ||
lbl, self.n_rays, mode=self.sd_mode, grid=self.grid | ||
) | ||
for lbl in ys | ||
] | ||
) | ||
|
||
if xs[0].ndim == 3: | ||
xs = [ | ||
np.expand_dims(x, axis=-1) for x in xs # type: ignore [no-untyped-call] | ||
] | ||
xs = np.stack( | ||
[ | ||
np.concatenate( # type: ignore [no-untyped-call] | ||
[x[i] for i in range(self.len_t)], axis=-1 | ||
) | ||
for x in xs | ||
] | ||
) | ||
|
||
# append dist_mask to dist as additional channel | ||
# dist_and_mask = np.concatenate([dist,dist_mask],axis=-1) | ||
# faster than concatenate | ||
dist_and_mask = np.empty( | ||
dists.shape[:-1] + (self.len_t * (self.n_rays + 1),), np.float32 | ||
) | ||
dist_and_mask[..., : -self.len_t] = dists | ||
dist_and_mask[..., -self.len_t :] = dist_mask | ||
|
||
return [xs], [prob, dist_and_mask] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.