Skip to content

Commit

Permalink
Merge pull request #78 from lowe-lab-ucl/increase-performance
Browse files Browse the repository at this point in the history
Increase rendering performance
  • Loading branch information
quantumjot authored Aug 24, 2022
2 parents 8bc5ed6 + b388223 commit f2eec96
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 38 deletions.
2 changes: 1 addition & 1 deletion examples/show_large_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def generate_binary_tree(max_depth: int):
return nodes


nodes = generate_binary_tree(6)
nodes = generate_binary_tree(8)
print(f"{len(nodes)} total nodes")


Expand Down
5 changes: 3 additions & 2 deletions napari_arboretum/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class TreeNode:
"""TreeNode."""

ID: int
t: Tuple[int, int]
t: np.ndarray
generation: int
children: List[int] = field(default_factory=list)

Expand Down Expand Up @@ -154,7 +154,8 @@ def build_subgraph(layer: napari.layers.Tracks, search_node: int) -> List[TreeNo
def _node_from_graph(_id):

idx = np.where(layer.data[:, 0] == _id)[0]
t = (np.min(layer.data[idx, 1]), np.max(layer.data[idx, 1]))
# t = (np.min(layer.data[idx, 1]), np.max(layer.data[idx, 1]))
t = layer.data[idx, 1]
node = TreeNode(ID=_id, t=t, generation=1)

if _id in reverse_graph:
Expand Down
13 changes: 2 additions & 11 deletions napari_arboretum/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Edge:
y: Tuple[float, float]
color: np.ndarray = WHITE
id: Optional[int] = None
node: Optional[TreeNode] = None


def layout_tree(nodes: List[TreeNode]) -> Tuple[List[Edge], List[Annotation]]:
Expand Down Expand Up @@ -66,7 +67,7 @@ def layout_tree(nodes: List[TreeNode]) -> Tuple[List[Edge], List[Annotation]]:
y = y_pos.pop(0)

# draw the root of the tree
edges.append(Edge(y=(y, y), x=(node.t[0], node.t[-1]), id=node.ID))
edges.append(Edge(y=(y, y), x=(node.t[0], node.t[-1]), id=node.ID, node=node))

if node.is_root:
annotations.append(Annotation(y=y, x=node.t[0], label=str(node.ID)))
Expand Down Expand Up @@ -108,14 +109,4 @@ def layout_tree(nodes: List[TreeNode]) -> Tuple[List[Edge], List[Annotation]]:
)
)

# now that we have traversed the tree, calculate the span
tree_span = []
for edge in edges:
tree_span.append(edge.y[0])
tree_span.append(edge.y[1])

# # work out the span of the tree, we can modify positioning here
# min_x = min(tree_span)
# max_x = max(tree_span)

return edges, annotations
12 changes: 12 additions & 0 deletions napari_arboretum/visualisation/base_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from ..tree import Annotation, Edge, layout_tree
from ..util import TrackPropertyMixin

# from ..profiler import profiler

GUI_MAXIMUM_WIDTH = 600

__all__ = ["TreePlotterBase", "TreePlotterQWidgetBase"]
Expand Down Expand Up @@ -46,6 +48,7 @@ def draw_tree(self) -> None:
subgraph_nodes = build_subgraph(self.tracks, self.track_id)
self.draw_from_nodes(subgraph_nodes, self.track_id)

# @profiler("draw_from_nodes")
def draw_from_nodes(
self, tree_nodes: List[TreeNode], track_id: Optional[int] = None
):
Expand All @@ -61,6 +64,8 @@ def draw_from_nodes(
for a in self.annotations:
self.add_annotation(a)

self.draw_tree_visual()

def update_edge_colors(self, update_live: bool = True) -> None:
"""
Update tree edge colours from the track properties.
Expand Down Expand Up @@ -117,6 +122,13 @@ def draw_current_time_line(self, time: int) -> None:
"""
raise NotImplementedError

@abc.abstractmethod
def draw_tree_visual(self) -> None:
"""
Function to draw the visual after construction.
"""
raise NotImplementedError


class TreePlotterQWidgetBase(TreePlotterBase):
"""
Expand Down
131 changes: 107 additions & 24 deletions napari_arboretum/visualisation/vispy_plotter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Optional

import numpy as np
from qtpy.QtWidgets import QWidget
Expand All @@ -11,6 +10,10 @@
__all__ = ["VisPyPlotter"]


DEFAULT_TEXT_SIZE = 8
DEFAULT_BRANCH_WIDTH = 3


@dataclass
class Bounds:
xmin: float
Expand All @@ -19,6 +22,31 @@ class Bounds:
ymax: float


@dataclass
class TrackSubvisualProxy:
pos: np.ndarray
color: np.ndarray = np.array([1.0, 1.0, 1.0, 1.0])

@property
def connex(self):
connex = [True] * (self.pos.shape[0] - 1) + [False]
return connex

@property
def safe_color(self) -> np.ndarray:
if self.color.ndim != 2:
safe_color = np.repeat([self.color], self.pos.shape[0], axis=0)
return safe_color
return self.color


@dataclass
class AnnotationSubvisualProxy:
pos: np.ndarray
text: str
color: str = "white"


class VisPyPlotter(TreePlotterQWidgetBase):
"""
Tree plotter using pyqtgraph as the plotting backend.
Expand Down Expand Up @@ -71,6 +99,14 @@ def autoscale_view(self) -> None:
width * (1 + 2 * padding),
height * (1 + 2 * padding),
)

# change the aspect ratio of the camera if we have just a single branch
# this will centre the camera on the single branch, otherwise, set the
# aspect ratio to match the data
if width == 0:
self.view.camera.aspect = 1.0
else:
self.view.camera.aspect = None
self.view.camera.rect = rect

def update_colors(self) -> None:
Expand All @@ -85,7 +121,8 @@ def add_branch(self, e: Edge) -> None:
"""
Add a single branch to the tree.
"""
self.tree.add_track(e.id, np.column_stack((e.y, e.x)), e.color)
# self.tree.add_track(e.id, np.column_stack((e.y, e.x)), e.color)
self.tree.add_track(e)
self.autoscale_view()

def add_annotation(self, a: Annotation) -> None:
Expand All @@ -104,6 +141,12 @@ def draw_current_time_line(self, time: int) -> None:
pos=np.array([[bounds.xmin - padding, time], [bounds.xmax + padding, time]])
)

def draw_tree_visual(self) -> None:
"""
Draw the whole tree.
"""
self.tree.draw_tree()


class TreeVisual(scene.visuals.Compound):
"""
Expand All @@ -116,7 +159,22 @@ def __init__(self, parent):
self.unfreeze()
# Keep a reference to tracks we add so their colour can be changed later
self.tracks = {}
self.subvisuals = []
self.edges = []
self.annotations = []

subvisuals = [
scene.visuals.Line(color="white", width=DEFAULT_BRANCH_WIDTH),
scene.visuals.Text(
anchor_x="left",
anchor_y="top",
rotation=90,
font_size=DEFAULT_TEXT_SIZE,
color="white",
),
]

for visual in subvisuals:
self.add_subvisual(visual)

def get_branch_color(self, branch_id: int) -> np.ndarray:
return self.tracks[branch_id].color
Expand All @@ -125,9 +183,12 @@ def set_branch_color(self, branch_id: int, color: np.ndarray) -> None:
"""
Set the color of an individual branch.
"""
self.tracks[branch_id].set_data(color=color)
self.tracks[branch_id].color = color
self._subvisuals[0].set_data(
color=np.row_stack([e.safe_color for e in self.edges]),
)

def add_track(self, id: Optional[int], pos: np.ndarray, color: np.ndarray) -> None:
def add_track(self, e: Edge) -> None:
"""
Parameters
----------
Expand All @@ -139,36 +200,58 @@ def add_track(self, id: Optional[int], pos: np.ndarray, color: np.ndarray) -> No
Array of shape (n, 4) specifying RGBA values in range [0, 1] along
the track.
"""
if id is None:
visual = scene.visuals.Line(pos=pos, color=color, width=3)
color = e.color
pos = np.column_stack((e.y, e.x))

if e.node is None:
subvisual_proxy = TrackSubvisualProxy(
pos=pos,
color=np.array([1.0, 1.0, 1.0, 1.0]),
)
else:
# Split up line into individual time steps so color can vary
# along the line
ys = np.arange(pos[0, 1], pos[1, 1] + 1)
ys = np.asarray(e.node.t) # np.arange(pos[0, 1], pos[1, 1] + 1)
xs = np.ones(ys.size) * pos[0, 0]
visual = scene.visuals.Line(
pos=np.column_stack((xs, ys)), color=color, width=3
subvisual_proxy = TrackSubvisualProxy(
pos=np.column_stack((xs, ys)),
color=color,
)
self.tracks[id] = visual
# store a reference to this subvisual proxy
self.tracks[e.id] = subvisual_proxy

self.add_subvisual(visual)
self.subvisuals.append(visual)
self.edges.append(subvisual_proxy)

def add_annotation(self, x: float, y: float, label: str, color):
visual = scene.visuals.Text(

subvisual_proxy = AnnotationSubvisualProxy(
text=label,
color=color,
pos=[y, x, 0],
anchor_x="left",
anchor_y="top",
font_size=6,
rotation=90,
)
self.add_subvisual(visual)
self.subvisuals.append(visual)

self.annotations.append(subvisual_proxy)

def clear(self) -> None:
"""Remove all tracks."""
while self.subvisuals:
subvisual = self.subvisuals.pop()
self.remove_subvisual(subvisual)
self.tracks = {}
self.edges = []
self.annotations = []

for visual in self._subvisuals:
visual._pos = None

if hasattr(visual, "_text"):
visual._text = None

def draw_tree(self) -> None:
"""Once the data is added, draw the tree."""

self._subvisuals[0].set_data(
pos=np.row_stack([e.pos for e in self.edges]),
color=np.row_stack([e.safe_color for e in self.edges]),
connect=np.concatenate([e.connex for e in self.edges]),
)

# TextVisual does not have a ``set_data`` method
self._subvisuals[1].pos = np.asarray([a.pos for a in self.annotations])
self._subvisuals[1].text = [a.text for a in self.annotations]

0 comments on commit f2eec96

Please sign in to comment.