Skip to content

Commit

Permalink
Remove cell and particle modes to support arbitrary TrackerConfig con…
Browse files Browse the repository at this point in the history
…figs (#64)

Fixes #63 

- handle arbitrary user configs
- when the plugin starts, the example `cell_config` and
`particle_config` files are loaded, and users can load more of their own

This ended up being a bigger refactor than I planned - I need to
refactor quite a few of the functions because they assume there are two
modes ('cell' and 'particle'). And at the same time I moved toward a
more modular structure as has been suggested
lowe-lab-ucl/napari-btrack#26.

**refactor `napari_btrack.track` into various sub-packages and modules**

- `napari_btrack.track`:
    - renamed to `napari_btrack.main`
- contains only code for launching the plugin, defining callback
functions, and running the analysis

- `napari_btrack.config`:
- a new module with classes to handle converting between scaled and
unscaled matrices in `MotionModel`s
- removed `Matrices` class (which hardcoded values for `cell` or
`particle`modes) and replaced with `UnscaledTrackerConfig` to handle
arbitrary user configs

- `napari_btrack.widgets`:
    - a new sub-package for creating the widgets for the plugin

- `napari_btrack.sync`:
- a new module to handle updating a config from widget values or vice
versa

**other changes**

- added per-file-ignores for some linting rules (allow `assert` in test
files, allow unused imports in `__init__.py` files)
- some classes / functions were being imported only for type checking -
move these imports in a check `if TYPE_CHECKING:` so that they're not
imported at runtime
- renamed the function that launches the plugin from
`napari_btrack.track.track` to `napari_btrack.main.create_btrack_widget`
- explicitly set the expected `widget_type` when using
`magicgui.widgets.create_widget` - knowing the widgets we're using
should make it easier to move to using `qt` directly at some point,
which would allow us to have separate tabs for each section
- add tooltips for every widget - they're based on the `btrack` api docs

---------

Co-authored-by: Patrick Roddy <[email protected]>
  • Loading branch information
p-j-smith and paddyroddy authored Mar 22, 2023
1 parent 173a8c8 commit 11e61ad
Show file tree
Hide file tree
Showing 15 changed files with 1,116 additions and 499 deletions.
9 changes: 9 additions & 0 deletions napari_btrack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,12 @@
from ._version import version as __version__
except ImportError:
__version__ = "unknown"

import logging

from napari_btrack import constants, main

__all__ = [
"constants",
"main",
]
124 changes: 68 additions & 56 deletions napari_btrack/_tests/test_dock_widget.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from magicgui.widgets import Container

import json
from unittest.mock import patch

import btrack
import napari
import numpy as np
import numpy.typing as npt
import pytest
from btrack import datasets
from btrack.config import load_config
from btrack.datasets import cell_config, particle_config
from magicgui.widgets import Container

from napari_btrack.track import (
_update_widgets_from_config,
_widgets_to_tracker_config,
track,
)
import napari_btrack
import napari_btrack.main

OLD_WIDGET_LAYERS = 1
NEW_WIDGET_LAYERS = 2
Expand All @@ -32,79 +33,92 @@ def test_add_widget(make_napari_viewer):
widget_name="Track",
)

assert len(list(viewer.window._dock_widgets)) == num_dw + 1 # noqa: S101
assert len(list(viewer.window._dock_widgets)) == num_dw + 1


@pytest.fixture
def track_widget(make_napari_viewer) -> Container:
"""Provides an instance of the track widget to test"""
make_napari_viewer() # make sure there is a viewer available
return track()
return napari_btrack.main.create_btrack_widget()


@pytest.mark.parametrize("config", [cell_config(), particle_config()])
def test_config_to_widgets_round_trip(track_widget, config):
"""Tests that going back and forth between
config objects and widgets works as expected.
"""
expected_config = load_config(config)
_update_widgets_from_config(track_widget, expected_config)
actual_config = _widgets_to_tracker_config(track_widget)
# use json.loads to avoid failure in string comparison because e.g "100.0" != "100"
assert json.loads(actual_config.json()) == json.loads( # noqa: S101
expected_config.json()
)

expected_config = btrack.config.load_config(config).json()

@pytest.fixture
def user_config_path() -> str:
"""Provides a (dummy) string to represent a user-provided config path."""
return "user_config.json"
unscaled_config = napari_btrack.config.UnscaledTrackerConfig(config)
napari_btrack.sync.update_widgets_from_config(unscaled_config, track_widget)
napari_btrack.sync.update_config_from_widgets(unscaled_config, track_widget)

actual_config = unscaled_config.scale_config().json()

# use json.loads to avoid failure in string comparison because e.g "100.0" != "100"
assert json.loads(actual_config) == json.loads(expected_config)

def test_save_button(user_config_path, track_widget):

def test_save_button(track_widget):
"""Tests that clicking the save configuration button
triggers a call to btrack.config.save_config with expected arguments.
"""
with patch("napari_btrack.track.save_config") as save_config, patch(
"napari_btrack.track.get_save_path"
) as get_save_path:
get_save_path.return_value = user_config_path

unscaled_config = napari_btrack.config.UnscaledTrackerConfig(cell_config())
unscaled_config.tracker_config.name = "cell" # this is done in in the gui too
expected_config = unscaled_config.scale_config().json()

with patch(
"napari_btrack.widgets.save_path_dialogue_box"
) as save_path_dialogue_box:
save_path_dialogue_box.return_value = "user_config.json"
track_widget.save_config_button.clicked()
assert save_config.call_args[0][0] == user_config_path # noqa: S101

actual_config = btrack.config.load_config("user_config.json").json()

# use json.loads to avoid failure in string comparison because e.g "100.0" != "100"
assert json.loads(save_config.call_args[0][1].json()) == json.loads( # noqa: S101
load_config(cell_config()).json()
)
assert json.loads(expected_config) == json.loads(actual_config)


def test_load_button(user_config_path, track_widget):
"""Tests that clicking the load configuration button
triggers a call to btrack.config.load_config with the expected argument
"""
with patch("napari_btrack.track.load_config") as load_config, patch(
"napari_btrack.track.get_load_path"
) as get_load_path:
get_load_path.return_value = user_config_path
def test_load_config(track_widget):
"""Tests that another TrackerConfig can be loaded and made the current config."""

# this is set to be 'cell' rather than 'Default'
original_config_name = track_widget.config.current_choice

with patch(
"napari_btrack.widgets.load_path_dialogue_box"
) as load_path_dialogue_box:
load_path_dialogue_box.return_value = cell_config()
track_widget.load_config_button.clicked()
assert load_config.call_args[0][0] == user_config_path # noqa: S101

# We didn't override the name, so it should be 'Default'
new_config_name = track_widget.config.current_choice

assert track_widget.config.value == "Default"
assert new_config_name != original_config_name


def test_reset_button(track_widget):
"""Tests that clicking the reset button with
particle-config-populated widgets resets to the default (i.e. cell-config)
"""
# change config to particle
_update_widgets_from_config(track_widget, load_config(particle_config()))
"""Tests that clicking the reset button restores the default config values"""

# click reset button (default is cell_config)
original_max_search_radius = track_widget.max_search_radius.value
original_relax = track_widget.relax.value

# change some widget values
track_widget.max_search_radius.value += 10
track_widget.relax.value = not track_widget.relax

# click reset button - restores defaults of the currently-selected base config
track_widget.reset_button.clicked()
config_after_reset = _widgets_to_tracker_config(track_widget)

# use json.loads to avoid failure in string comparison because e.g "100.0" != "100"
assert json.loads(config_after_reset.json()) == json.loads( # noqa: S101
load_config(cell_config()).json()
)
new_max_search_radius = track_widget.max_search_radius.value
new_relax = track_widget.relax.value

assert new_max_search_radius == original_max_search_radius
assert new_relax == original_relax


@pytest.fixture
Expand All @@ -127,14 +141,12 @@ def test_run_button(track_widget, simplistic_tracker_outputs):
"""Tests that clicking the run button calls run_tracker,
and that the napari viewer has an additional tracks layer after running.
"""
with patch("napari_btrack.track.run_tracker") as run_tracker:
with patch("napari_btrack.main._run_tracker") as run_tracker:
run_tracker.return_value = simplistic_tracker_outputs
segmentation = datasets.example_segmentation()
track_widget.viewer.add_labels(segmentation)
assert len(track_widget.viewer.layers) == OLD_WIDGET_LAYERS # noqa: S101
assert len(track_widget.viewer.layers) == OLD_WIDGET_LAYERS
track_widget.call_button.clicked()
assert run_tracker.called # noqa: S101
assert len(track_widget.viewer.layers) == NEW_WIDGET_LAYERS # noqa: S101
assert isinstance( # noqa: S101
track_widget.viewer.layers[-1], napari.layers.Tracks
)
assert run_tracker.called
assert len(track_widget.viewer.layers) == NEW_WIDGET_LAYERS
assert isinstance(track_widget.viewer.layers[-1], napari.layers.Tracks)
181 changes: 181 additions & 0 deletions napari_btrack/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from btrack.config import TrackerConfig

import copy
import os
from dataclasses import dataclass, field

import btrack
import numpy as np
from btrack import datasets

__all__ = [
"create_default_configs",
]


@dataclass
class Sigmas:
"""Values to scale TrackerConfig MotionModel matrices by.
Args:
P: Scaling factor for the matrix ``P`` - the error in initial estimates.
G: Scaling factor for the matrix ``G`` - the error in the MotionModel process.
R: Scaling factor for the matrix ``R`` - the error in the measurements.
"""

P: float
G: float
R: float

def __getitem__(self, matrix_name):
return self.__dict__[matrix_name]

def __setitem__(self, matrix_name, sigma):
if matrix_name not in self.__dict__.keys():
_msg = f"Unknown matrix name '{matrix_name}'"
raise ValueError(_msg)
self.__dict__[matrix_name] = sigma

def __iter__(self):
yield from self.__dict__.keys()


@dataclass
class UnscaledTrackerConfig:
"""Convert TrackerConfig MotionModel matrices from scaled to unscaled.
This is needed because TrackerConfig stores "scaled" matrices, i.e. it
doesn't store sigma and the "unscaled" MotionModel matrices separately.
Args:
filename: name of the json file containing the TrackerConfig to load.
Attributes:
tracker_config: unscaled configuration based on the config in ``filename``.
sigmas: scaling factors to apply to the unscaled MotionModel matrices of
``tracker_config``.
"""

filename: os.PathLike
tracker_config: TrackerConfig = field(init=False)
sigmas: Sigmas = field(init=False)

def __post_init__(self):
"""Create the TrackerConfig and un-scale the MotionModel indices"""

config = btrack.config.load_config(self.filename)
self.tracker_config, self.sigmas = self._unscale_config(config)

def _unscale_config(self, config: TrackerConfig) -> tuple[TrackerConfig, Sigmas]:
"""Convert the matrices of a scaled TrackerConfig MotionModel to unscaled."""

P_sigma = np.max(config.motion_model.P)
config.motion_model.P /= P_sigma

R_sigma = np.max(config.motion_model.R)
config.motion_model.R /= R_sigma

# Use only G, not Q. If we use both G and Q, then Q_sigma must be updated
# when G_sigma is, and vice-versa
# Instead, use G if it exists. If not, determine G from Q, which we can
# do because Q = G.T @ G
if config.motion_model.G is None:
config.motion_model.G = config.motion_model.Q.diagonal() ** 0.5
G_sigma = np.max(config.motion_model.G)
config.motion_model.G /= G_sigma

sigmas = Sigmas(
P=P_sigma,
G=G_sigma,
R=R_sigma,
)

return config, sigmas

def scale_config(self) -> TrackerConfig:
"""Create a new TrackerConfig with scaled MotionModel matrices"""

# Create a copy so that config values stay in sync with widget values
scaled_config = copy.deepcopy(self.tracker_config)
scaled_config.motion_model.P *= self.sigmas.P
scaled_config.motion_model.R *= self.sigmas.R
scaled_config.motion_model.G *= self.sigmas.G
scaled_config.motion_model.Q = (
scaled_config.motion_model.G.T @ scaled_config.motion_model.G
)

return scaled_config


@dataclass
class TrackerConfigs:
"""Store all loaded TrackerConfig configurations.
Will load ``btrack``'s default 'cell' and 'particle' configurations on
initialisation.
Attributes:
configs: dictionary of loaded configurations. The name of the config (
TrackerConfig.name) is used as the key.
current_config: the currently-selected configuration.
"""

configs: dict[str, UnscaledTrackerConfig] = field(default_factory=dict)
current_config: str = field(init=False)

def __post_init__(self):
"""Add the default cell and particle configs."""

self.add_config(
filename=datasets.cell_config(),
name="cell",
overwrite=False,
)
self.add_config(
filename=datasets.particle_config(),
name="particle",
overwrite=False,
)

self.current_config = "cell"

def __getitem__(self, config_name):
return self.configs[config_name]

def add_config(
self,
filename,
overwrite,
name=None,
) -> str:
"""Load a TrackerConfig and add it to the store."""

config = UnscaledTrackerConfig(filename)
config_name = config.tracker_config.name if name is None else name
config.tracker_config.name = config_name

# TODO: Make the combobox editable so config names can be changed within the GUI
if config_name in self.configs and not overwrite:
_msg = (
f"Config '{config_name}' already exists - config names must be unique."
)
raise ValueError(_msg)

self.configs[config_name] = config

return config_name


def create_default_configs() -> TrackerConfigs:
"""Create a set of default configurations."""

# TrackerConfigs automatically loads default cell and particle configs
return TrackerConfigs()
Loading

0 comments on commit 11e61ad

Please sign in to comment.