diff --git a/napari_btrack/__init__.py b/napari_btrack/__init__.py index 7dd10698..3652513f 100644 --- a/napari_btrack/__init__.py +++ b/napari_btrack/__init__.py @@ -2,3 +2,12 @@ from ._version import version as __version__ except ImportError: __version__ = "unknown" + +import logging + +from napari_btrack import constants, main + +__all__ = [ + "constants", + "main", +] diff --git a/napari_btrack/_tests/test_dock_widget.py b/napari_btrack/_tests/test_dock_widget.py index ae82ece0..3d75559d 100644 --- a/napari_btrack/_tests/test_dock_widget.py +++ b/napari_btrack/_tests/test_dock_widget.py @@ -1,22 +1,23 @@ from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from magicgui.widgets import Container + import json from unittest.mock import patch +import btrack import napari import numpy as np import numpy.typing as npt import pytest from btrack import datasets -from btrack.config import load_config from btrack.datasets import cell_config, particle_config -from magicgui.widgets import Container -from napari_btrack.track import ( - _update_widgets_from_config, - _widgets_to_tracker_config, - track, -) +import napari_btrack +import napari_btrack.main OLD_WIDGET_LAYERS = 1 NEW_WIDGET_LAYERS = 2 @@ -32,14 +33,14 @@ def test_add_widget(make_napari_viewer): widget_name="Track", ) - assert len(list(viewer.window._dock_widgets)) == num_dw + 1 # noqa: S101 + assert len(list(viewer.window._dock_widgets)) == num_dw + 1 @pytest.fixture def track_widget(make_napari_viewer) -> Container: """Provides an instance of the track widget to test""" make_napari_viewer() # make sure there is a viewer available - return track() + return napari_btrack.main.create_btrack_widget() @pytest.mark.parametrize("config", [cell_config(), particle_config()]) @@ -47,64 +48,77 @@ def test_config_to_widgets_round_trip(track_widget, config): """Tests that going back and forth between config objects and widgets works as expected. """ - expected_config = load_config(config) - _update_widgets_from_config(track_widget, expected_config) - actual_config = _widgets_to_tracker_config(track_widget) - # use json.loads to avoid failure in string comparison because e.g "100.0" != "100" - assert json.loads(actual_config.json()) == json.loads( # noqa: S101 - expected_config.json() - ) + expected_config = btrack.config.load_config(config).json() -@pytest.fixture -def user_config_path() -> str: - """Provides a (dummy) string to represent a user-provided config path.""" - return "user_config.json" + unscaled_config = napari_btrack.config.UnscaledTrackerConfig(config) + napari_btrack.sync.update_widgets_from_config(unscaled_config, track_widget) + napari_btrack.sync.update_config_from_widgets(unscaled_config, track_widget) + + actual_config = unscaled_config.scale_config().json() + # use json.loads to avoid failure in string comparison because e.g "100.0" != "100" + assert json.loads(actual_config) == json.loads(expected_config) -def test_save_button(user_config_path, track_widget): + +def test_save_button(track_widget): """Tests that clicking the save configuration button triggers a call to btrack.config.save_config with expected arguments. """ - with patch("napari_btrack.track.save_config") as save_config, patch( - "napari_btrack.track.get_save_path" - ) as get_save_path: - get_save_path.return_value = user_config_path + + unscaled_config = napari_btrack.config.UnscaledTrackerConfig(cell_config()) + unscaled_config.tracker_config.name = "cell" # this is done in in the gui too + expected_config = unscaled_config.scale_config().json() + + with patch( + "napari_btrack.widgets.save_path_dialogue_box" + ) as save_path_dialogue_box: + save_path_dialogue_box.return_value = "user_config.json" track_widget.save_config_button.clicked() - assert save_config.call_args[0][0] == user_config_path # noqa: S101 + + actual_config = btrack.config.load_config("user_config.json").json() + # use json.loads to avoid failure in string comparison because e.g "100.0" != "100" - assert json.loads(save_config.call_args[0][1].json()) == json.loads( # noqa: S101 - load_config(cell_config()).json() - ) + assert json.loads(expected_config) == json.loads(actual_config) -def test_load_button(user_config_path, track_widget): - """Tests that clicking the load configuration button - triggers a call to btrack.config.load_config with the expected argument - """ - with patch("napari_btrack.track.load_config") as load_config, patch( - "napari_btrack.track.get_load_path" - ) as get_load_path: - get_load_path.return_value = user_config_path +def test_load_config(track_widget): + """Tests that another TrackerConfig can be loaded and made the current config.""" + + # this is set to be 'cell' rather than 'Default' + original_config_name = track_widget.config.current_choice + + with patch( + "napari_btrack.widgets.load_path_dialogue_box" + ) as load_path_dialogue_box: + load_path_dialogue_box.return_value = cell_config() track_widget.load_config_button.clicked() - assert load_config.call_args[0][0] == user_config_path # noqa: S101 + + # We didn't override the name, so it should be 'Default' + new_config_name = track_widget.config.current_choice + + assert track_widget.config.value == "Default" + assert new_config_name != original_config_name def test_reset_button(track_widget): - """Tests that clicking the reset button with - particle-config-populated widgets resets to the default (i.e. cell-config) - """ - # change config to particle - _update_widgets_from_config(track_widget, load_config(particle_config())) + """Tests that clicking the reset button restores the default config values""" - # click reset button (default is cell_config) + original_max_search_radius = track_widget.max_search_radius.value + original_relax = track_widget.relax.value + + # change some widget values + track_widget.max_search_radius.value += 10 + track_widget.relax.value = not track_widget.relax + + # click reset button - restores defaults of the currently-selected base config track_widget.reset_button.clicked() - config_after_reset = _widgets_to_tracker_config(track_widget) - # use json.loads to avoid failure in string comparison because e.g "100.0" != "100" - assert json.loads(config_after_reset.json()) == json.loads( # noqa: S101 - load_config(cell_config()).json() - ) + new_max_search_radius = track_widget.max_search_radius.value + new_relax = track_widget.relax.value + + assert new_max_search_radius == original_max_search_radius + assert new_relax == original_relax @pytest.fixture @@ -127,14 +141,12 @@ def test_run_button(track_widget, simplistic_tracker_outputs): """Tests that clicking the run button calls run_tracker, and that the napari viewer has an additional tracks layer after running. """ - with patch("napari_btrack.track.run_tracker") as run_tracker: + with patch("napari_btrack.main._run_tracker") as run_tracker: run_tracker.return_value = simplistic_tracker_outputs segmentation = datasets.example_segmentation() track_widget.viewer.add_labels(segmentation) - assert len(track_widget.viewer.layers) == OLD_WIDGET_LAYERS # noqa: S101 + assert len(track_widget.viewer.layers) == OLD_WIDGET_LAYERS track_widget.call_button.clicked() - assert run_tracker.called # noqa: S101 - assert len(track_widget.viewer.layers) == NEW_WIDGET_LAYERS # noqa: S101 - assert isinstance( # noqa: S101 - track_widget.viewer.layers[-1], napari.layers.Tracks - ) + assert run_tracker.called + assert len(track_widget.viewer.layers) == NEW_WIDGET_LAYERS + assert isinstance(track_widget.viewer.layers[-1], napari.layers.Tracks) diff --git a/napari_btrack/config.py b/napari_btrack/config.py new file mode 100644 index 00000000..c74c8d7b --- /dev/null +++ b/napari_btrack/config.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from btrack.config import TrackerConfig + +import copy +import os +from dataclasses import dataclass, field + +import btrack +import numpy as np +from btrack import datasets + +__all__ = [ + "create_default_configs", +] + + +@dataclass +class Sigmas: + """Values to scale TrackerConfig MotionModel matrices by. + + Args: + P: Scaling factor for the matrix ``P`` - the error in initial estimates. + G: Scaling factor for the matrix ``G`` - the error in the MotionModel process. + R: Scaling factor for the matrix ``R`` - the error in the measurements. + + """ + + P: float + G: float + R: float + + def __getitem__(self, matrix_name): + return self.__dict__[matrix_name] + + def __setitem__(self, matrix_name, sigma): + if matrix_name not in self.__dict__.keys(): + _msg = f"Unknown matrix name '{matrix_name}'" + raise ValueError(_msg) + self.__dict__[matrix_name] = sigma + + def __iter__(self): + yield from self.__dict__.keys() + + +@dataclass +class UnscaledTrackerConfig: + """Convert TrackerConfig MotionModel matrices from scaled to unscaled. + + This is needed because TrackerConfig stores "scaled" matrices, i.e. it + doesn't store sigma and the "unscaled" MotionModel matrices separately. + + Args: + filename: name of the json file containing the TrackerConfig to load. + + Attributes: + tracker_config: unscaled configuration based on the config in ``filename``. + sigmas: scaling factors to apply to the unscaled MotionModel matrices of + ``tracker_config``. + + """ + + filename: os.PathLike + tracker_config: TrackerConfig = field(init=False) + sigmas: Sigmas = field(init=False) + + def __post_init__(self): + """Create the TrackerConfig and un-scale the MotionModel indices""" + + config = btrack.config.load_config(self.filename) + self.tracker_config, self.sigmas = self._unscale_config(config) + + def _unscale_config(self, config: TrackerConfig) -> tuple[TrackerConfig, Sigmas]: + """Convert the matrices of a scaled TrackerConfig MotionModel to unscaled.""" + + P_sigma = np.max(config.motion_model.P) + config.motion_model.P /= P_sigma + + R_sigma = np.max(config.motion_model.R) + config.motion_model.R /= R_sigma + + # Use only G, not Q. If we use both G and Q, then Q_sigma must be updated + # when G_sigma is, and vice-versa + # Instead, use G if it exists. If not, determine G from Q, which we can + # do because Q = G.T @ G + if config.motion_model.G is None: + config.motion_model.G = config.motion_model.Q.diagonal() ** 0.5 + G_sigma = np.max(config.motion_model.G) + config.motion_model.G /= G_sigma + + sigmas = Sigmas( + P=P_sigma, + G=G_sigma, + R=R_sigma, + ) + + return config, sigmas + + def scale_config(self) -> TrackerConfig: + """Create a new TrackerConfig with scaled MotionModel matrices""" + + # Create a copy so that config values stay in sync with widget values + scaled_config = copy.deepcopy(self.tracker_config) + scaled_config.motion_model.P *= self.sigmas.P + scaled_config.motion_model.R *= self.sigmas.R + scaled_config.motion_model.G *= self.sigmas.G + scaled_config.motion_model.Q = ( + scaled_config.motion_model.G.T @ scaled_config.motion_model.G + ) + + return scaled_config + + +@dataclass +class TrackerConfigs: + """Store all loaded TrackerConfig configurations. + + Will load ``btrack``'s default 'cell' and 'particle' configurations on + initialisation. + + Attributes: + configs: dictionary of loaded configurations. The name of the config ( + TrackerConfig.name) is used as the key. + current_config: the currently-selected configuration. + + """ + + configs: dict[str, UnscaledTrackerConfig] = field(default_factory=dict) + current_config: str = field(init=False) + + def __post_init__(self): + """Add the default cell and particle configs.""" + + self.add_config( + filename=datasets.cell_config(), + name="cell", + overwrite=False, + ) + self.add_config( + filename=datasets.particle_config(), + name="particle", + overwrite=False, + ) + + self.current_config = "cell" + + def __getitem__(self, config_name): + return self.configs[config_name] + + def add_config( + self, + filename, + overwrite, + name=None, + ) -> str: + """Load a TrackerConfig and add it to the store.""" + + config = UnscaledTrackerConfig(filename) + config_name = config.tracker_config.name if name is None else name + config.tracker_config.name = config_name + + # TODO: Make the combobox editable so config names can be changed within the GUI + if config_name in self.configs and not overwrite: + _msg = ( + f"Config '{config_name}' already exists - config names must be unique." + ) + raise ValueError(_msg) + + self.configs[config_name] = config + + return config_name + + +def create_default_configs() -> TrackerConfigs: + """Create a set of default configurations.""" + + # TrackerConfigs automatically loads default cell and particle configs + return TrackerConfigs() diff --git a/napari_btrack/constants.py b/napari_btrack/constants.py new file mode 100644 index 00000000..df691731 --- /dev/null +++ b/napari_btrack/constants.py @@ -0,0 +1,30 @@ +""" +This module contains variables that are used throughout the +napari_btrack package. +""" + +HYPOTHESES = [ + "P_FP", + "P_init", + "P_term", + "P_link", + "P_branch", + "P_dead", + "P_merge", +] + +HYPOTHESIS_SCALING_FACTORS = [ + "lambda_time", + "lambda_dist", + "lambda_link", + "lambda_branch", +] + +HYPOTHESIS_THRESHOLDS = [ + "theta_dist", + "theta_time", + "dist_thresh", + "time_thresh", + "apop_thresh", + "relax", +] diff --git a/napari_btrack/main.py b/napari_btrack/main.py new file mode 100644 index 00000000..5fc4ab08 --- /dev/null +++ b/napari_btrack/main.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import numpy.typing as npt + from btrack.config import TrackerConfig + from magicgui.widgets import Container + + from napari_btrack.config import TrackerConfigs + +import logging + +import btrack +import magicgui.widgets +import napari +import qtpy.QtWidgets +from btrack.utils import segmentation_to_objects + +import napari_btrack.config +import napari_btrack.sync +import napari_btrack.widgets + +__all__ = [ + "create_btrack_widget", +] + +# get the logger instance +logger = logging.getLogger(__name__) + +# if we don't have any handlers, set one up +if not logger.handlers: + # configure stream handler + log_fmt = logging.Formatter( + "[%(levelname)s][%(asctime)s] %(message)s", + datefmt="%Y/%m/%d %I:%M:%S %p", + ) + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_fmt) + + logger.addHandler(console_handler) + logger.setLevel(logging.DEBUG) + + +def create_btrack_widget() -> Container: + """Create widgets for the btrack plugin.""" + + # First create our UI along with some default configs for the widgets + all_configs = napari_btrack.config.create_default_configs() + widgets = napari_btrack.widgets.create_widgets() + btrack_widget = magicgui.widgets.Container(widgets=widgets, scrollable=True) + btrack_widget.viewer = napari.current_viewer() + + # Set the cell_config defaults in the gui + napari_btrack.sync.update_widgets_from_config( + unscaled_config=all_configs["cell"], + container=btrack_widget, + ) + + # Now set the callbacks + btrack_widget.config.changed.connect( + lambda selected: select_config(btrack_widget, all_configs, selected), + ) + + btrack_widget.call_button.changed.connect( + lambda: run(btrack_widget, all_configs), + ) + + btrack_widget.reset_button.changed.connect( + lambda: restore_defaults(btrack_widget, all_configs), + ) + + btrack_widget.save_config_button.changed.connect( + lambda: save_config_to_json(btrack_widget, all_configs) + ) + + btrack_widget.load_config_button.changed.connect( + lambda: load_config_from_json(btrack_widget, all_configs) + ) + + # there are lots of widgets so make the container scrollable + scroll = qtpy.QtWidgets.QScrollArea() + scroll.setWidget(btrack_widget._widget._qwidget) + btrack_widget._widget._qwidget = scroll + + return btrack_widget + + +def select_config( + btrack_widget: Container, + configs: TrackerConfigs, + new_config_name: str, +) -> None: + """Set widget values from a newly-selected base config""" + + # first update the previous config with the current widget values + previous_config_name = configs.current_config + previous_config = configs[previous_config_name] + previous_config = napari_btrack.sync.update_config_from_widgets( + unscaled_config=previous_config, + container=btrack_widget, + ) + + # now load the newly-selected config and set widget values + configs.current_config = new_config_name + new_config = configs[new_config_name] + new_config = napari_btrack.sync.update_widgets_from_config( + unscaled_config=new_config, + container=btrack_widget, + ) + + +def run(btrack_widget: Container, configs: TrackerConfigs) -> None: + """ + Update the TrackerConfig from widget values, run tracking, + and add tracks to the viewer. + """ + + unscaled_config = configs[btrack_widget.config.current_choice] + unscaled_config = napari_btrack.sync.update_config_from_widgets( + unscaled_config=unscaled_config, + container=btrack_widget, + ) + + config = unscaled_config.scale_config() + segmentation = btrack_widget.segmentation.value + data, properties, graph = _run_tracker(segmentation, config) + + btrack_widget.viewer.add_tracks( + data=data, + properties=properties, + graph=graph, + name=f"{segmentation}_btrack", + ) + + +def _run_tracker( + segmentation: napari.layers.Image | napari.layers.Labels, + tracker_config: TrackerConfig, +) -> tuple[npt.NDArray, dict, dict]: + """ + Runs BayesianTracker with given segmentation and configuration. + """ + with btrack.BayesianTracker() as tracker: + tracker.configure(tracker_config) + + # append the objects to be tracked + segmented_objects = segmentation_to_objects(segmentation.data) + tracker.append(segmented_objects) + + # set the volume + # btrack order of dimensions is XY(Z) + # napari order of dimensions is T(Z)XY + # so we ignore the first dimension (time) and reverse the others + dimensions = segmentation.level_shapes[0, 1:] + tracker.volume = tuple((0, dimension) for dimension in reversed(dimensions)) + + # track them (in interactive mode) + tracker.track_interactive(step_size=100) + + # generate hypotheses and run the global optimizer + tracker.optimize() + + # get the tracks in a format for napari visualization + data, properties, graph = tracker.to_napari(ndim=2) + return data, properties, graph + + +def restore_defaults(btrack_widget: Container, configs: TrackerConfigs) -> None: + """Reload the config file then set widgets to the config's default values.""" + + config_name = configs.current_config + filename = configs[config_name].filename + configs.add_config( + filename=filename, + overwrite=True, + name=config_name, + ) + + config = configs[config_name] + config = napari_btrack.sync.update_widgets_from_config( + unscaled_config=config, + container=btrack_widget, + ) + + +def save_config_to_json(btrack_widget: Container, configs: TrackerConfigs) -> None: + """Save widget values to file""" + + save_path = napari_btrack.widgets.save_path_dialogue_box() + if save_path is None: + _msg = ( + "napari-btrack: Configuration not saved - operation cancelled by the user." + ) + logger.info(_msg) + return + + unscaled_config = configs[btrack_widget.config.current_choice] + napari_btrack.sync.update_config_from_widgets( + unscaled_config=unscaled_config, + container=btrack_widget, + ) + config = unscaled_config.scale_config() + + btrack.config.save_config(save_path, config) + + +def load_config_from_json(btrack_widget: Container, configs: TrackerConfigs) -> None: + """Load a config from file and set it as the selected base config""" + + load_path = napari_btrack.widgets.load_path_dialogue_box() + if load_path is None: + _msg = "napari-btrack: No file loaded - operation cancelled by the user." + logger.info(_msg) + return + + config_name = configs.add_config(filename=load_path, overwrite=False) + btrack_widget.config.options["choices"].append(config_name) + btrack_widget.config.reset_choices() + btrack_widget.config.value = config_name diff --git a/napari_btrack/napari.yaml b/napari_btrack/napari.yaml index d680efa9..44aac49e 100644 --- a/napari_btrack/napari.yaml +++ b/napari_btrack/napari.yaml @@ -4,7 +4,7 @@ contributions: commands: - id: napari-btrack.track title: Create Track - python_name: napari_btrack.track:track + python_name: napari_btrack.main:create_btrack_widget widgets: - command: napari-btrack.track display_name: Track diff --git a/napari_btrack/sync.py b/napari_btrack/sync.py new file mode 100644 index 00000000..a8a3dd91 --- /dev/null +++ b/napari_btrack/sync.py @@ -0,0 +1,102 @@ +""" +This module contains functions for syncing widget values with TrackerConfig +values. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from btrack.config import TrackerConfig + from magicgui.widgets import Container + + from napari_btrack.config import Sigmas, UnscaledTrackerConfig + +import napari_btrack.constants + + +def update_config_from_widgets( + unscaled_config: UnscaledTrackerConfig, + container: Container, +) -> TrackerConfig: + """Update an UnscaledTrackerConfig with the current widget values.""" + + # Update MotionModel matrix scaling factors + sigmas: Sigmas = unscaled_config.sigmas + for matrix_name in sigmas: + sigmas[matrix_name] = container[f"{matrix_name}_sigma"].value + + # Update TrackerConfig values + config = unscaled_config.tracker_config + update_method_name = container.update_method.current_choice + update_method_index = container.update_method.choices.index(update_method_name) + config.update_method = update_method_index + config.max_search_radius = container.max_search_radius.value + + # Update MotionModel values + motion_model = config.motion_model + motion_model.accuracy = container.accuracy.value + motion_model.max_lost = container.max_lost.value + + # Update HypothesisModel.hypotheses values + hypothesis_model = config.hypothesis_model + hypothesis_model.hypotheses = [ + hypothesis + for hypothesis in napari_btrack.constants.HYPOTHESES + if container[hypothesis].value + ] + + # Update HypothesisModel scaling factors + for scaling_factor in napari_btrack.constants.HYPOTHESIS_SCALING_FACTORS: + setattr(hypothesis_model, scaling_factor, container[scaling_factor].value) + + # Update HypothesisModel thresholds + for threshold in napari_btrack.constants.HYPOTHESIS_THRESHOLDS: + setattr(hypothesis_model, threshold, container[threshold].value) + + hypothesis_model.segmentation_miss_rate = container.segmentation_miss_rate.value + + return unscaled_config + + +def update_widgets_from_config( + unscaled_config: UnscaledTrackerConfig, + container: Container, +) -> Container: + """ + Update the widgets in a container with the values in an + UnscaledTrackerConfig. + """ + + # Update widgets from MotionModel matrix scaling factors + sigmas: Sigmas = unscaled_config.sigmas + for matrix_name in sigmas: + container[f"{matrix_name}_sigma"].value = sigmas[matrix_name] + + # Update widgets from TrackerConfig values + config = unscaled_config.tracker_config + container.update_method.value = config.update_method.name + container.max_search_radius.value = config.max_search_radius + + # Update widgets from MotionModel values + motion_model = config.motion_model + container.accuracy.value = motion_model.accuracy + container.max_lost.value = motion_model.max_lost + + # Update widgets from HypothesisModel.hypotheses values + hypothesis_model = config.hypothesis_model + for hypothesis in napari_btrack.constants.HYPOTHESES: + is_checked = hypothesis in hypothesis_model.hypotheses + container[hypothesis].value = is_checked + + # Update widgets from HypothesisModel scaling factors + for scaling_factor in napari_btrack.constants.HYPOTHESIS_SCALING_FACTORS: + container[scaling_factor].value = getattr(hypothesis_model, scaling_factor) + + # Update widgets from HypothesisModel thresholds + for threshold in napari_btrack.constants.HYPOTHESIS_THRESHOLDS: + container[threshold].value = getattr(hypothesis_model, threshold) + + container.segmentation_miss_rate.value = hypothesis_model.segmentation_miss_rate + + return container diff --git a/napari_btrack/track.py b/napari_btrack/track.py deleted file mode 100644 index a6f8bd43..00000000 --- a/napari_btrack/track.py +++ /dev/null @@ -1,442 +0,0 @@ -from __future__ import annotations - -import logging -from dataclasses import dataclass, field -from typing import Any - -import btrack -import napari -import numpy as np -import numpy.typing as npt -from btrack import datasets -from btrack.config import ( - HypothesisModel, - MotionModel, - TrackerConfig, - load_config, - save_config, -) -from btrack.utils import segmentation_to_objects -from magicgui.application import use_app -from magicgui.types import FileDialogMode -from magicgui.widgets import Container, PushButton, Widget, create_widget -from pydantic import BaseModel -from qtpy.QtWidgets import QScrollArea - -default_cell_config = load_config(datasets.cell_config()) - -# widgets for which the default widget type is incorrect -HIDDEN_VARIABLE_NAMES = [ - "name", - "measurements", - "states", - "dt", - "apoptosis_rate", - "prob_not_assign", - "eta", -] -ALL_HYPOTHESES = ["P_FP", "P_init", "P_term", "P_link", "P_branch", "P_dead"] - - -@dataclass -class Matrices: - """A helper dataclass to adapt matrix representation to and from pydantic. - This is needed because TrackerConfig stores "scaled" matrices, i.e. - doesn't store sigma and the "unscaled" matrix separately. - """ - - names: list[str] = field(default_factory=lambda: ["A", "H", "P", "G", "R", "Q"]) - widget_labels: list[str] = field( - default_factory=lambda: [ - "A_sigma", - "H_sigma", - "P_sigma", - "G_sigma", - "R_sigma", - "Q_sigma", - ] - ) - default_sigmas: list[float] = field( - default_factory=lambda: [1.0, 1.0, 150.0, 15.0, 5.0] - ) - unscaled_matrices: dict[str, npt.NDArray[np.float64]] = field( - default_factory=lambda: { - "A_cell": np.array( - [ - [1, 0, 0, 1, 0, 0], - [0, 1, 0, 0, 1, 0], - [0, 0, 1, 0, 0, 1], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 1], - ] - ), - "A_particle": np.array( - [ - [1, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 1], - ] - ), - "H": np.array([[1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0]]), - "P": np.array( - [ - [0.1, 0, 0, 0, 0, 0], - [0, 0.1, 0, 0, 0, 0], - [0, 0, 0.1, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 1], - ] - ), - "G": np.array([[0.5, 0.5, 0.5, 1, 1, 1]]), - "R": np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - "Q": np.array( - [ - [56.25, 56.25, 56.25, 112.5, 112.5, 112.5], - [56.25, 56.25, 56.25, 112.5, 112.5, 112.5], - [56.25, 56.25, 56.25, 112.5, 112.5, 112.5], - [112.5, 112.5, 112.5, 225.0, 225.0, 225.0], - [112.5, 112.5, 112.5, 225.0, 225.0, 225.0], - [112.5, 112.5, 112.5, 225.0, 225.0, 225.0], - ] - ), - } - ) - - @classmethod - def get_scaled_matrix( - cls, name: str, *, sigma: float, use_cell_config: bool = True - ) -> list[float]: - """Returns the scaled version (i.e. the unscaled matrix multiplied by sigma) - of the matrix. - - Keyword arguments: - name -- the matrix name (can be one of A, H, P, G, R) - sigma -- the factor to scale the matrix entries with - cell -- whether to use cell config matrices or not (default true) - """ - if name == "A": - name = "A_cell" if use_cell_config else "A_particle" - return (np.asarray(cls().unscaled_matrices[name]) * sigma).tolist() - - @classmethod - def get_sigma(cls, name: str, scaled_matrix: npt.NDArray[np.float64]) -> float: - """Returns the factor sigma which is the multiplier between the given scaled - matrix and the unscaled matrix of the given name. - - Note: The calculation is done with the top-left entry of the matrix, - and all other entries are ignored. - - Keyword arguments: - name -- the matrix name (can be one of A, H, P, G, R) - scaled_matrix -- the scaled matrix to find sigma from. - """ - if name == "A": - name = "A_cell" # doesn't matter which A we use here, as [0][0] is the same - return scaled_matrix[0][0] / cls().unscaled_matrices[name][0][0] - - -def run_tracker( - segmentation: napari.layers.Image | napari.layers.Labels, - tracker_config: TrackerConfig, -) -> tuple[npt.NDArray, dict, dict]: - """ - Runs BayesianTracker with given segmentation and configuration. - """ - with btrack.BayesianTracker() as tracker: - tracker.configure(tracker_config) - - # append the objects to be tracked - segmented_objects = segmentation_to_objects(segmentation.data) - tracker.append(segmented_objects) - - # set the volume - segmentation_size = segmentation.level_shapes[0] - # btrack order of dimensions is XY(Z) - # napari order of dimensions is T(Z)XY - # so we ignore the first entry and then iterate backwards - tracker.volume = tuple((0, s) for s in segmentation_size[1:][::-1]) - - # track them (in interactive mode) - tracker.track_interactive(step_size=100) - - # generate hypotheses and run the global optimizer - tracker.optimize() - - # get the tracks in a format for napari visualization - data, properties, graph = tracker.to_napari(ndim=2) - return data, properties, graph - - -def get_save_path(): - """Helper function to open a save configuration file dialog.""" - show_file_dialog = use_app().get_obj("show_file_dialog") - return show_file_dialog( - mode=FileDialogMode.OPTIONAL_FILE, - caption="Specify file to save btrack configuration", - start_path=None, - filter="*.json", - ) - - -def get_load_path(): - """Helper function to open a load configuration file dialog.""" - show_file_dialog = use_app().get_obj("show_file_dialog") - return show_file_dialog( - mode=FileDialogMode.EXISTING_FILE, - caption="Choose JSON file containing btrack configuration", - start_path=None, - filter="*.json", - ) - - -def html_label_widget(label: str, tag: str = "b") -> dict: - """ - Create a HMTL label widget. - """ - return { - "widget_type": "Label", - "label": f"<{tag}>{label}", - } - - -def _create_per_model_widgets(model: BaseModel) -> list[Widget]: - """ - For a given model create the required list of widgets. - The items "hypotheses" and the various matrices need customisation, - otherwise we can use the napari default. - """ - widgets: list[Widget] = [] - widgets.append(create_widget(**html_label_widget(type(model).__name__))) - for parameter, default_value in model: - if parameter in HIDDEN_VARIABLE_NAMES: - continue - if parameter in Matrices().names: - # only expose the scalar sigma to user - sigma = Matrices.get_sigma(parameter, default_value) - widgets.append( - create_widget(value=sigma, name=f"{parameter}_sigma", annotation=float) - ) - elif parameter == "hypotheses": - # the hypothesis list should be represented as a series of checkboxes - widgets.extend( - [ - create_widget( - value=(choice in default_value), name=choice, annotation=bool - ) - for choice in ALL_HYPOTHESES - ] - ) - else: # use napari default - widgets.append( - create_widget( - value=default_value, name=parameter, annotation=type(default_value) - ) - ) - return widgets - - -def _create_napari_specific_widgets(widgets: list[Widget]) -> None: - """ - Add the widgets which interact with napari itself - """ - widgets.append(create_widget(**html_label_widget("Segmentation"))) - segmentation_widget = create_widget( - name="segmentation", - annotation=napari.layers.Labels, - options={ - "tooltip": ( - "Should be a Labels layer. Convert an Image to Labels by right-clicking" - "on it in the layers list, and clicking on 'Convert to Labels'" - ), - }, - ) - widgets.append(segmentation_widget) - - -def _create_pydantic_default_widgets( - widgets: list[Widget], config: TrackerConfig -) -> None: - """ - Create the widgets which have a tracker config equivalent. - """ - widgets.append( - create_widget(name="max_search_radius", value=config.max_search_radius) - ) - model_configs = [config.motion_model, config.hypothesis_model] - model_widgets = [_create_per_model_widgets(model) for model in model_configs] - widgets.extend([item for sublist in model_widgets for item in sublist]) - - -def _create_cell_or_particle_widget(widgets: list[Widget]) -> None: - """Create a dropdown menu to choose between cell or particle mode.""" - widgets.append(create_widget(**html_label_widget("Mode"))) - widgets.append( - create_widget( - name="mode", value="cell", options={"choices": ["cell", "particle"]} - ) - ) - - -def _widgets_to_tracker_config(container: Container) -> TrackerConfig: - """Helper function to convert from the widgets to a tracker configuration.""" - motion_model_dict: dict[str, Any] = {} - hypothesis_model_dict = {} - - motion_model_keys = default_cell_config.motion_model.dict().keys() - hypothesis_model_keys = default_cell_config.hypothesis_model.dict().keys() - hypotheses = [] - for widget in container: - # setup motion model - # matrices need special treatment - if widget.name in Matrices().widget_labels: - sigma = getattr(container, widget.name).value - matrix_name = widget.name.rstrip("_sigma") - matrix = Matrices.get_scaled_matrix( - matrix_name, - sigma=sigma, - use_cell_config=(container.mode.value == "cell"), - ) - motion_model_dict[matrix_name] = matrix - elif widget.name in motion_model_keys: - motion_model_dict[widget.name] = widget.value - # setup hypothesis model - if widget.name in hypothesis_model_keys: - hypothesis_model_dict[widget.name] = widget.value - # hypotheses need special treatment - if widget.name in ALL_HYPOTHESES and getattr(container, widget.name).value: - hypotheses.append(widget.name) - - # add some non-exposed default values to the motion model - mode = container.mode.value - for default_name, default_value in zip( - ["measurements", "states", "dt", "prob_not_assign", "name"], - [3, 6, 1.0, 0.001, f"{mode}_motion"], - ): - motion_model_dict[default_name] = default_value - - # add some non-exposed default value to the hypothesis model - for default_name, default_value in zip( - ["apoptosis_rate", "eta", "name"], - [0.001, 1.0e-10, f"{mode}_hypothesis"], - ): - hypothesis_model_dict[default_name] = default_value - - # add hypotheses to hypothesis model - hypothesis_model_dict["hypotheses"] = hypotheses - motion_model = MotionModel(**motion_model_dict) - hypothesis_model = HypothesisModel(**hypothesis_model_dict) - - # add parameters outside the internal models - max_search_radius = container.max_search_radius.value - return TrackerConfig( - max_search_radius=max_search_radius, - motion_model=motion_model, - hypothesis_model=hypothesis_model, - ) - - -def _update_widgets_from_config(container: Container, config: TrackerConfig) -> None: - """Helper function to update a container's widgets - with the values in a given tracker config. - """ - container.max_search_radius.value = config.max_search_radius - for model in ["motion_model", "hypothesis_model", "object_model"]: - if model_config := getattr(config, model): - for parameter, value in model_config: - if parameter in HIDDEN_VARIABLE_NAMES: - continue - if parameter in Matrices().names: - sigma = Matrices.get_sigma(parameter, value) - getattr(container, f"{parameter}_sigma").value = sigma - elif parameter == "hypotheses": - for hypothesis in ALL_HYPOTHESES: - getattr(container, hypothesis).value = hypothesis in value - else: - getattr(container, parameter).value = value - # we can determine whether we are in particle or cell mode - # by checking whether the 4th entry of the first row of the - # A matrix is 1 or 0 (1 for cell mode) - mode_is_cell = config.motion_model.A[0, 3] == 1 - logging.info(f"mode is cell: {mode_is_cell}") - container.mode.value = "cell" if mode_is_cell else "particle" - - -def _create_button_widgets(widgets: list[Widget]) -> None: - """Create the set of button widgets needed: - run, save/load configuration and reset.""" - widget_names = [ - "load_config_button", - "save_config_button", - "reset_button", - "call_button", - ] - widget_labels = [ - "Load configuration", - "Save configuration", - "Reset defaults", - "Run", - ] - widgets.append(create_widget(**html_label_widget("Control buttons"))) - widgets.extend( - [ - create_widget(name=widget_name, label=widget_label, widget_type=PushButton) - for widget_name, widget_label in zip( - widget_names, - widget_labels, - ) - ] - ) - - -def track() -> Container: - """ - Create a series of widgets programatically - """ - # initialise a list for all widgets - widgets: list = [] - - # create all the widgets - _create_napari_specific_widgets(widgets) - _create_cell_or_particle_widget(widgets) - _create_pydantic_default_widgets(widgets, default_cell_config) - _create_button_widgets(widgets) - - btrack_widget = Container(widgets=widgets, scrollable=True) - btrack_widget.viewer = napari.current_viewer() - - @btrack_widget.reset_button.changed.connect - def restore_defaults() -> None: - _update_widgets_from_config(btrack_widget, default_cell_config) - - @btrack_widget.call_button.changed.connect - def run() -> None: - config = _widgets_to_tracker_config(btrack_widget) - segmentation = btrack_widget.segmentation.value - data, properties, graph = run_tracker(segmentation, config) - btrack_widget.viewer.add_tracks( - data=data, properties=properties, graph=graph, name=f"{segmentation}_btrack" - ) - - @btrack_widget.save_config_button.changed.connect - def save_config_to_json() -> None: - save_path = get_save_path() - if save_path: # save path is None if user cancels - save_config(save_path, _widgets_to_tracker_config(btrack_widget)) - - @btrack_widget.load_config_button.changed.connect - def load_config_from_json() -> None: - load_path = get_load_path() - if load_path: # load path is None if user cancels - config = load_config(load_path) - _update_widgets_from_config(btrack_widget, config) - - scroll = QScrollArea() - scroll.setWidget(btrack_widget._widget._qwidget) - btrack_widget._widget._qwidget = scroll - - return btrack_widget diff --git a/napari_btrack/widgets/__init__.py b/napari_btrack/widgets/__init__.py new file mode 100644 index 00000000..552b249d --- /dev/null +++ b/napari_btrack/widgets/__init__.py @@ -0,0 +1,5 @@ +from napari_btrack.widgets.create_ui import create_widgets +from napari_btrack.widgets.io import ( + load_path_dialogue_box, + save_path_dialogue_box, +) diff --git a/napari_btrack/widgets/_general.py b/napari_btrack/widgets/_general.py new file mode 100644 index 00000000..6c857121 --- /dev/null +++ b/napari_btrack/widgets/_general.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from magicgui.widgets import Widget + +import magicgui +import napari + + +def create_input_widgets() -> list[Widget]: + """Create widgets for selecting labels layer and TrackerConfig""" + + segmentation_tooltip = ( + "Select a 'Labels' layer to use for tracking.\n" + "To use an 'Image' layer, first convert 'Labels' by right-clicking " + "on it in the layers list, and clicking on 'Convert to Labels'" + ) + segmentation = magicgui.widgets.create_widget( + annotation=napari.layers.Labels, + name="segmentation", + label="segmentation", + options={"tooltip": segmentation_tooltip}, + ) + + config_tooltip = ( + "Select a loaded configuration.\nNote, this will update values set below." + ) + config = magicgui.widgets.create_widget( + value="cell", + name="config", + label="config name", + widget_type="ComboBox", + options={ + "choices": ["cell", "particle"], + "tooltip": config_tooltip, + }, + ) + + return [segmentation, config] + + +def create_update_method_widgets() -> list[Widget]: + """Create widgets for selecting the update method""" + + update_method_tooltip = ( + "Select the update method.\n" + "EXACT: exact calculation of Bayesian belief matrix.\n" + "APPROXIMATE: approximate the Bayesian belief matrix. Useful for datasets with " + "more than 1000 particles per frame." + ) + update_method = magicgui.widgets.create_widget( + value="EXACT", + name="update_method", + label="update method", + widget_type="ComboBox", + options={ + "choices": ["EXACT", "APPROXIMATE"], + "tooltip": update_method_tooltip, + }, + ) + + # TODO: this widget should be hidden when the update method is set to EXACT + max_search_radius_tooltip = ( + "The local spatial search radius (isotropic, pixels) used when the update " + "method is 'APPROXIMATE'" + ) + max_search_radius = magicgui.widgets.create_widget( + value=100, + name="max_search_radius", + label="search radius", + widget_type="SpinBox", + options={"tooltip": max_search_radius_tooltip}, + ) + + return [update_method, max_search_radius] + + +def create_control_widgets() -> list[Widget]: + """Create widgets for running the analysis or handling I/O. + + This includes widgets for running the tracking, saving and loading + configuration files, and resetting the widget values to those in + the selected config.""" + + names = [ + "load_config_button", + "save_config_button", + "reset_button", + "call_button", + ] + labels = [ + "Load configuration", + "Save configuration", + "Reset defaults", + "Run", + ] + tooltips = [ + "Load a TrackerConfig json file.", + "Export the current configuration to a TrackerConfig json file.", + "Reset the current configuration to the defaults stored in the corresponding json file.", # noqa: E501 + "Run the tracking analysis with the current configuration.", + ] + + control_buttons = [] + for name, label, tooltip in zip(names, labels, tooltips): + widget = magicgui.widgets.create_widget( + name=name, + label=label, + widget_type="PushButton", + options={"tooltip": tooltip}, + ) + control_buttons.append(widget) + + return control_buttons diff --git a/napari_btrack/widgets/_hypothesis.py b/napari_btrack/widgets/_hypothesis.py new file mode 100644 index 00000000..c690b754 --- /dev/null +++ b/napari_btrack/widgets/_hypothesis.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from magicgui.widgets import Widget + +import magicgui + +import napari_btrack.constants + + +def _create_hypotheses_widgets() -> list[Widget]: + """Create widgets for selecting which hypotheses to generate.""" + + hypotheses = napari_btrack.constants.HYPOTHESES + tooltips = [ + "Hypothesis that a tracklet is a false positive detection. Always required.", + "Hypothesis that a tracklet starts at the beginning of the movie or edge of the field of view.", # noqa: E501 + "Hypothesis that a tracklet ends at the end of the movie or edge of the field of view.", # noqa: E501 + "Hypothesis that two tracklets should be linked together.", + "Hypothesis that a tracklet can split into two daughter tracklets.", + "Hypothesis that a tracklet terminates without leaving the field of view.", + "Hypothesis that two tracklets merge into one tracklet.", + ] + + hypotheses_widgets = [] + for hypothesis, tooltip in zip(hypotheses, tooltips): + widget = magicgui.widgets.create_widget( + value=True, + name=hypothesis, + label=hypothesis, + widget_type="CheckBox", + options={"tooltip": tooltip}, + ) + hypotheses_widgets.append(widget) + + # P_FP is always required + P_FP_hypothesis = hypotheses_widgets[0] + P_FP_hypothesis.enabled = False + + # P_merge should be disabled by default + P_merge_hypothesis = hypotheses_widgets[-1] + P_merge_hypothesis.value = False + + return hypotheses_widgets + + +def _create_scaling_factor_widgets() -> list[Widget]: + """Create widgets for setting the scaling factors of the HypothesisModel""" + + widget_values = [5.0, 3.0, 10.0, 50.0] + names = [ + "lambda_time", + "lambda_dist", + "lambda_link", + "lambda_branch", + ] + labels = [ + "λ time", + "λ distance", + "λ linking", + "λ branching", + ] + tooltips = [ + "Scaling factor for the influence of time when determining initialization or termination hypotheses.", # noqa: E501 + "Scaling factor for the influence of distance at the border when determining initialization or termination hypotheses.", # noqa: E501 + "Scaling factor for the influence of track-to-track distance on linking probability.", # noqa: E501 + "Scaling factor for the influence of cell state and position on division (mitosis/branching) probability.", # noqa: E501 + ] + + scaling_factor_widgets = [] + for value, name, label, tooltip in zip(widget_values, names, labels, tooltips): + widget = magicgui.widgets.create_widget( + value=value, + name=name, + label=label, + widget_type="FloatSpinBox", + options={"tooltip": tooltip}, + ) + scaling_factor_widgets.append(widget) + + return scaling_factor_widgets + + +def _create_threshold_widgets() -> list[Widget]: + """Create widgets for setting thresholds for the HypothesisModel""" + + distance_threshold_tooltip = ( + "A threshold distance from the edge of the field of view to add an " + "initialization or termination hypothesis." + ) + distance_threshold = magicgui.widgets.create_widget( + value=20.0, + name="theta_dist", + label="distance threshold", + widget_type="FloatSpinBox", + options={"tooltip": distance_threshold_tooltip}, + ) + + time_threshold_tooltip = ( + "A threshold time from the beginning or end of movie to add " + "an initialization or termination hypothesis." + ) + time_threshold = magicgui.widgets.create_widget( + value=5.0, + name="theta_time", + label="time threshold", + widget_type="FloatSpinBox", + options={"tooltip": time_threshold_tooltip}, + ) + + apoptosis_threshold_tooltip = ( + "Number of apoptotic detections to be considered a genuine event.\n" + "Detections are counted consecutively from the back of the track" + ) + apoptosis_threshold = magicgui.widgets.create_widget( + value=5, + name="apop_thresh", + label="apoptosis threshold", + widget_type="SpinBox", + options={"tooltip": apoptosis_threshold_tooltip}, + ) + + return [ + distance_threshold, + time_threshold, + apoptosis_threshold, + ] + + +def _create_bin_size_widgets() -> list[Widget]: + """Create widget for setting bin sizes for the HypothesisModel""" + + distance_bin_size_tooltip = ( + "Isotropic spatial bin size for considering hypotheses.\n" + "Larger bin sizes generate more hypothesese for each tracklet." + ) + distance_bin_size = magicgui.widgets.create_widget( + value=40.0, + name="dist_thresh", + label="distance bin size", + widget_type="FloatSpinBox", + options={"tooltip": distance_bin_size_tooltip}, + ) + + time_bin_size_tooltip = ( + "Temporal bin size for considering hypotheses.\n" + "Larger bin sizes generate more hypothesese for each tracklet." + ) + time_bin_size = magicgui.widgets.create_widget( + value=2.0, + name="time_thresh", + label="time bin size", + widget_type="FloatSpinBox", + options={"tooltip": time_bin_size_tooltip}, + ) + + return [ + distance_bin_size, + time_bin_size, + ] + + +def create_hypothesis_model_widgets() -> list[Widget]: + """Create widgets for setting parameters of the MotionModel""" + + hypothesis_model_label = magicgui.widgets.create_widget( + label="Hypothesis model", # bold label + widget_type="Label", + gui_only=True, + ) + + hypotheses_widgets = _create_hypotheses_widgets() + scaling_factor_widgets = _create_scaling_factor_widgets() + threshold_widgets = _create_threshold_widgets() + bin_size_widgets = _create_bin_size_widgets() + + segmentation_miss_rate_tooltip = ( + "Miss rate for the segmentation.\n" + "e.g. 1/100 segmentations incorrect gives a segmentation miss rate of 0.01." + ) + segmentation_miss_rate = magicgui.widgets.create_widget( + value=0.1, + name="segmentation_miss_rate", + label="miss rate", + widget_type="FloatSpinBox", + options={"tooltip": segmentation_miss_rate_tooltip}, + ) + + relax_tooltip = ( + "Disable the time and distance thresholds.\n" + "This means that tracks can initialize or terminate anywhere and" + "at any time in the dataset." + ) + relax = magicgui.widgets.create_widget( + value=True, + name="relax", + label="relax thresholds", + widget_type="CheckBox", + options={"tooltip": relax_tooltip}, + ) + + return [ + hypothesis_model_label, + *hypotheses_widgets, + *scaling_factor_widgets, + *threshold_widgets, + *bin_size_widgets, + segmentation_miss_rate, + relax, + ] diff --git a/napari_btrack/widgets/_motion.py b/napari_btrack/widgets/_motion.py new file mode 100644 index 00000000..db0a3848 --- /dev/null +++ b/napari_btrack/widgets/_motion.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from magicgui.widgets import Widget + +import magicgui + + +def _make_label_bold(label: str) -> str: + """Generate html for a bold label""" + + return f"{label}" + + +def _create_sigma_widgets() -> list[Widget]: + """Create widgets for setting the magnitudes of the MotionModel matrices""" + + P_sigma_tooltip = ( + "Magnitude of error in initial estimates.\n Used to scale the matrix P." + ) + P_sigma = magicgui.widgets.create_widget( + value=150.0, + name="P_sigma", + label=f"max({_make_label_bold('P')})", + widget_type="FloatSpinBox", + options={"tooltip": P_sigma_tooltip}, + ) + + G_sigma_tooltip = "Magnitude of error in process.\n Used to scale the matrix G." + G_sigma = magicgui.widgets.create_widget( + value=15.0, + name="G_sigma", + label=f"max({_make_label_bold('G')})", + widget_type="FloatSpinBox", + options={"tooltip": G_sigma_tooltip}, + ) + + R_sigma_tooltip = ( + "Magnitude of error in measurements.\n Used to scale the matrix R." + ) + R_sigma = magicgui.widgets.create_widget( + value=5.0, + name="R_sigma", + label=f"max({_make_label_bold('R')})", + widget_type="FloatSpinBox", + options={"tooltip": R_sigma_tooltip}, + ) + + return [ + P_sigma, + G_sigma, + R_sigma, + ] + + +def create_motion_model_widgets() -> list[Widget]: + """Create widgets for setting parameters of the MotionModel""" + + motion_model_label = magicgui.widgets.create_widget( + label=_make_label_bold("Motion model"), + widget_type="Label", + gui_only=True, + ) + + sigma_widgets = _create_sigma_widgets() + + accuracy_tooltip = "Integration limits for calculating probabilities" + accuracy = magicgui.widgets.create_widget( + value=7.5, + name="accuracy", + label="accuracy", + widget_type="FloatSpinBox", + options={"tooltip": accuracy_tooltip}, + ) + + max_lost_frames_tooltip = ( + "Number of frames without observation before marking as lost" + ) + max_lost_frames = magicgui.widgets.create_widget( + value=5, + name="max_lost", + label="max lost", + widget_type="SpinBox", + options={"tooltip": max_lost_frames_tooltip}, + ) + + return [ + motion_model_label, + *sigma_widgets, + accuracy, + max_lost_frames, + ] diff --git a/napari_btrack/widgets/create_ui.py b/napari_btrack/widgets/create_ui.py new file mode 100644 index 00000000..11f4966b --- /dev/null +++ b/napari_btrack/widgets/create_ui.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from magicgui.widgets import Widget + +from napari_btrack.widgets._general import ( + create_control_widgets, + create_input_widgets, + create_update_method_widgets, +) +from napari_btrack.widgets._hypothesis import create_hypothesis_model_widgets +from napari_btrack.widgets._motion import create_motion_model_widgets + + +def create_widgets() -> list[Widget]: + """Create all the widgets for the plugin""" + + input_widgets = create_input_widgets() + update_method_widgets = create_update_method_widgets() + motion_model_widgets = create_motion_model_widgets() + hypothesis_model_widgets = create_hypothesis_model_widgets() + control_buttons = create_control_widgets() + + return [ + *input_widgets, + *update_method_widgets, + *motion_model_widgets, + *hypothesis_model_widgets, + *control_buttons, + ] diff --git a/napari_btrack/widgets/io.py b/napari_btrack/widgets/io.py new file mode 100644 index 00000000..6be57e65 --- /dev/null +++ b/napari_btrack/widgets/io.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import os + +import magicgui +from magicgui.types import FileDialogMode + + +def save_path_dialogue_box() -> os.PathLike: + """Helper function to open a save configuration file dialog.""" + + app = magicgui.application.use_app() + show_file_dialog = app.get_obj("show_file_dialog") + + return show_file_dialog( + mode=FileDialogMode.OPTIONAL_FILE, + caption="Specify file to save btrack configuration", + start_path=None, + filter="*.json", + ) + + +def load_path_dialogue_box() -> os.PathLike: + """Helper function to open a load configuration file dialog.""" + + app = magicgui.application.use_app() + show_file_dialog = app.get_obj("show_file_dialog") + + return show_file_dialog( + mode=FileDialogMode.EXISTING_FILE, + caption="Choose JSON file containing btrack configuration", + start_path=None, + filter="*.json", + ) diff --git a/pyproject.toml b/pyproject.toml index 23b3f052..532449a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ fix = true force-exclude = true ignore = [ "B905", # python>=3.10 + "N806", ] select = [ "A", @@ -107,6 +108,14 @@ select = [ isort.known-first-party = ["napari_btrack"] mccabe.max-complexity = 18 +[tool.ruff.per-file-ignores] +"__init__.py" = [ + "F401", # unused-import +] +"test_*.py" = [ + "S101", # use of 'assert' +] + [tool.setuptools_scm] write_to = "napari_btrack/_version.py"