diff --git a/examples/show_large_tree.py b/examples/show_large_tree.py index 22c0eb7..42351a2 100644 --- a/examples/show_large_tree.py +++ b/examples/show_large_tree.py @@ -24,7 +24,7 @@ def generate_binary_tree(max_depth: int): return nodes -nodes = generate_binary_tree(6) +nodes = generate_binary_tree(8) print(f"{len(nodes)} total nodes") diff --git a/napari_arboretum/graph.py b/napari_arboretum/graph.py index 81178ed..78e9836 100644 --- a/napari_arboretum/graph.py +++ b/napari_arboretum/graph.py @@ -16,7 +16,7 @@ class TreeNode: """TreeNode.""" ID: int - t: Tuple[int, int] + t: np.ndarray generation: int children: List[int] = field(default_factory=list) @@ -154,7 +154,8 @@ def build_subgraph(layer: napari.layers.Tracks, search_node: int) -> List[TreeNo def _node_from_graph(_id): idx = np.where(layer.data[:, 0] == _id)[0] - t = (np.min(layer.data[idx, 1]), np.max(layer.data[idx, 1])) + # t = (np.min(layer.data[idx, 1]), np.max(layer.data[idx, 1])) + t = layer.data[idx, 1] node = TreeNode(ID=_id, t=t, generation=1) if _id in reverse_graph: diff --git a/napari_arboretum/tree.py b/napari_arboretum/tree.py index b3dafe1..2516d65 100644 --- a/napari_arboretum/tree.py +++ b/napari_arboretum/tree.py @@ -30,6 +30,7 @@ class Edge: y: Tuple[float, float] color: np.ndarray = WHITE id: Optional[int] = None + node: Optional[TreeNode] = None def layout_tree(nodes: List[TreeNode]) -> Tuple[List[Edge], List[Annotation]]: @@ -66,7 +67,7 @@ def layout_tree(nodes: List[TreeNode]) -> Tuple[List[Edge], List[Annotation]]: y = y_pos.pop(0) # draw the root of the tree - edges.append(Edge(y=(y, y), x=(node.t[0], node.t[-1]), id=node.ID)) + edges.append(Edge(y=(y, y), x=(node.t[0], node.t[-1]), id=node.ID, node=node)) if node.is_root: annotations.append(Annotation(y=y, x=node.t[0], label=str(node.ID))) @@ -108,14 +109,4 @@ def layout_tree(nodes: List[TreeNode]) -> Tuple[List[Edge], List[Annotation]]: ) ) - # now that we have traversed the tree, calculate the span - tree_span = [] - for edge in edges: - tree_span.append(edge.y[0]) - tree_span.append(edge.y[1]) - - # # work out the span of the tree, we can modify positioning here - # min_x = min(tree_span) - # max_x = max(tree_span) - return edges, annotations diff --git a/napari_arboretum/visualisation/base_plotter.py b/napari_arboretum/visualisation/base_plotter.py index 522f190..28fe216 100644 --- a/napari_arboretum/visualisation/base_plotter.py +++ b/napari_arboretum/visualisation/base_plotter.py @@ -9,6 +9,8 @@ from ..tree import Annotation, Edge, layout_tree from ..util import TrackPropertyMixin +# from ..profiler import profiler + GUI_MAXIMUM_WIDTH = 600 __all__ = ["TreePlotterBase", "TreePlotterQWidgetBase"] @@ -46,6 +48,7 @@ def draw_tree(self) -> None: subgraph_nodes = build_subgraph(self.tracks, self.track_id) self.draw_from_nodes(subgraph_nodes, self.track_id) + # @profiler("draw_from_nodes") def draw_from_nodes( self, tree_nodes: List[TreeNode], track_id: Optional[int] = None ): @@ -61,6 +64,8 @@ def draw_from_nodes( for a in self.annotations: self.add_annotation(a) + self.draw_tree_visual() + def update_edge_colors(self, update_live: bool = True) -> None: """ Update tree edge colours from the track properties. @@ -117,6 +122,13 @@ def draw_current_time_line(self, time: int) -> None: """ raise NotImplementedError + @abc.abstractmethod + def draw_tree_visual(self) -> None: + """ + Function to draw the visual after construction. + """ + raise NotImplementedError + class TreePlotterQWidgetBase(TreePlotterBase): """ diff --git a/napari_arboretum/visualisation/vispy_plotter.py b/napari_arboretum/visualisation/vispy_plotter.py index abca114..20fd496 100644 --- a/napari_arboretum/visualisation/vispy_plotter.py +++ b/napari_arboretum/visualisation/vispy_plotter.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional import numpy as np from qtpy.QtWidgets import QWidget @@ -11,6 +10,10 @@ __all__ = ["VisPyPlotter"] +DEFAULT_TEXT_SIZE = 8 +DEFAULT_BRANCH_WIDTH = 3 + + @dataclass class Bounds: xmin: float @@ -19,6 +22,31 @@ class Bounds: ymax: float +@dataclass +class TrackSubvisualProxy: + pos: np.ndarray + color: np.ndarray = np.array([1.0, 1.0, 1.0, 1.0]) + + @property + def connex(self): + connex = [True] * (self.pos.shape[0] - 1) + [False] + return connex + + @property + def safe_color(self) -> np.ndarray: + if self.color.ndim != 2: + safe_color = np.repeat([self.color], self.pos.shape[0], axis=0) + return safe_color + return self.color + + +@dataclass +class AnnotationSubvisualProxy: + pos: np.ndarray + text: str + color: str = "white" + + class VisPyPlotter(TreePlotterQWidgetBase): """ Tree plotter using pyqtgraph as the plotting backend. @@ -71,6 +99,14 @@ def autoscale_view(self) -> None: width * (1 + 2 * padding), height * (1 + 2 * padding), ) + + # change the aspect ratio of the camera if we have just a single branch + # this will centre the camera on the single branch, otherwise, set the + # aspect ratio to match the data + if width == 0: + self.view.camera.aspect = 1.0 + else: + self.view.camera.aspect = None self.view.camera.rect = rect def update_colors(self) -> None: @@ -85,7 +121,8 @@ def add_branch(self, e: Edge) -> None: """ Add a single branch to the tree. """ - self.tree.add_track(e.id, np.column_stack((e.y, e.x)), e.color) + # self.tree.add_track(e.id, np.column_stack((e.y, e.x)), e.color) + self.tree.add_track(e) self.autoscale_view() def add_annotation(self, a: Annotation) -> None: @@ -104,6 +141,12 @@ def draw_current_time_line(self, time: int) -> None: pos=np.array([[bounds.xmin - padding, time], [bounds.xmax + padding, time]]) ) + def draw_tree_visual(self) -> None: + """ + Draw the whole tree. + """ + self.tree.draw_tree() + class TreeVisual(scene.visuals.Compound): """ @@ -116,7 +159,22 @@ def __init__(self, parent): self.unfreeze() # Keep a reference to tracks we add so their colour can be changed later self.tracks = {} - self.subvisuals = [] + self.edges = [] + self.annotations = [] + + subvisuals = [ + scene.visuals.Line(color="white", width=DEFAULT_BRANCH_WIDTH), + scene.visuals.Text( + anchor_x="left", + anchor_y="top", + rotation=90, + font_size=DEFAULT_TEXT_SIZE, + color="white", + ), + ] + + for visual in subvisuals: + self.add_subvisual(visual) def get_branch_color(self, branch_id: int) -> np.ndarray: return self.tracks[branch_id].color @@ -125,9 +183,12 @@ def set_branch_color(self, branch_id: int, color: np.ndarray) -> None: """ Set the color of an individual branch. """ - self.tracks[branch_id].set_data(color=color) + self.tracks[branch_id].color = color + self._subvisuals[0].set_data( + color=np.row_stack([e.safe_color for e in self.edges]), + ) - def add_track(self, id: Optional[int], pos: np.ndarray, color: np.ndarray) -> None: + def add_track(self, e: Edge) -> None: """ Parameters ---------- @@ -139,36 +200,58 @@ def add_track(self, id: Optional[int], pos: np.ndarray, color: np.ndarray) -> No Array of shape (n, 4) specifying RGBA values in range [0, 1] along the track. """ - if id is None: - visual = scene.visuals.Line(pos=pos, color=color, width=3) + color = e.color + pos = np.column_stack((e.y, e.x)) + + if e.node is None: + subvisual_proxy = TrackSubvisualProxy( + pos=pos, + color=np.array([1.0, 1.0, 1.0, 1.0]), + ) else: # Split up line into individual time steps so color can vary # along the line - ys = np.arange(pos[0, 1], pos[1, 1] + 1) + ys = np.asarray(e.node.t) # np.arange(pos[0, 1], pos[1, 1] + 1) xs = np.ones(ys.size) * pos[0, 0] - visual = scene.visuals.Line( - pos=np.column_stack((xs, ys)), color=color, width=3 + subvisual_proxy = TrackSubvisualProxy( + pos=np.column_stack((xs, ys)), + color=color, ) - self.tracks[id] = visual + # store a reference to this subvisual proxy + self.tracks[e.id] = subvisual_proxy - self.add_subvisual(visual) - self.subvisuals.append(visual) + self.edges.append(subvisual_proxy) def add_annotation(self, x: float, y: float, label: str, color): - visual = scene.visuals.Text( + + subvisual_proxy = AnnotationSubvisualProxy( text=label, - color=color, pos=[y, x, 0], - anchor_x="left", - anchor_y="top", - font_size=6, - rotation=90, ) - self.add_subvisual(visual) - self.subvisuals.append(visual) + + self.annotations.append(subvisual_proxy) def clear(self) -> None: """Remove all tracks.""" - while self.subvisuals: - subvisual = self.subvisuals.pop() - self.remove_subvisual(subvisual) + self.tracks = {} + self.edges = [] + self.annotations = [] + + for visual in self._subvisuals: + visual._pos = None + + if hasattr(visual, "_text"): + visual._text = None + + def draw_tree(self) -> None: + """Once the data is added, draw the tree.""" + + self._subvisuals[0].set_data( + pos=np.row_stack([e.pos for e in self.edges]), + color=np.row_stack([e.safe_color for e in self.edges]), + connect=np.concatenate([e.connex for e in self.edges]), + ) + + # TextVisual does not have a ``set_data`` method + self._subvisuals[1].pos = np.asarray([a.pos for a in self.annotations]) + self._subvisuals[1].text = [a.text for a in self.annotations]