diff --git a/model.py b/model.py
index b5b5a5f..08f1505 100644
--- a/model.py
+++ b/model.py
@@ -436,7 +436,7 @@ def edges_df(self):
         )
 
     @cached_property
-    @disk_cache("v1")
+    @disk_cache("v2")
     def nodes_df(self):
         ts = self.ts
         child_left, child_right = self.child_bounds(
@@ -449,6 +449,8 @@ def nodes_df(self):
                 "time": ts.nodes_time,
                 "num_mutations": self.nodes_num_mutations,
                 "ancestors_span": child_right - child_left,
+                "child_left": child_left,
+                "child_right": child_right,
                 "is_sample": is_sample,
             }
         )
@@ -458,6 +460,8 @@ def nodes_df(self):
                 "time": "float64",
                 "num_mutations": "int",
                 "ancestors_span": "float64",
+                "child_left": "float64",
+                "child_right": "float64",
                 "is_sample": "bool",
             }
         )
@@ -584,3 +588,70 @@ def calc_mutations_per_tree(self):
         mutations_per_tree = np.zeros(self.ts.num_trees, dtype=np.int64)
         mutations_per_tree[unique_values] = counts
         return mutations_per_tree
+
+    def compute_ancestor_spans_heatmap_data(self, num_x_bins, num_y_bins):
+        """
+        Counts the nodes whose [child_left, child_right) interval overlaps
+        each genomic-time bin.
+
+        Returns a data frame with one row per (position, time) bin holding
+        the node count and its log10; empty bins report a count of 0 and
+        a log10 of 0.
+        """
+        if self.ts.time_units == tskit.TIME_UNITS_UNCALIBRATED:
+            logger.warning(
+                "Cannot compute ancestor spans for uncalibrated tree sequence"
+            )
+            return pd.DataFrame(
+                {
+                    "position": [],
+                    "time": [],
+                    "overlapping_node_count_log10": [],
+                    "overlapping_node_count": [],
+                }
+            )
+        else:
+            nodes_df = self.nodes_df[self.nodes_df.ancestors_span != -np.inf]
+            nodes_df = nodes_df.reset_index(drop=True)
+            nodes_left = nodes_df.child_left
+            nodes_right = nodes_df.child_right
+            nodes_time = nodes_df.time
+
+            x_bins = np.linspace(nodes_left.min(), nodes_right.max(), num_x_bins + 1)
+            y_bins = np.linspace(0, nodes_time.max(), num_y_bins + 1)
+            heatmap_counts = np.zeros((num_x_bins, num_y_bins))
+
+            # The right end is treated as exclusive below, so the left end is
+            # inclusive: a span starting on a bin edge belongs to that bin.
+            x_starts = np.digitize(nodes_left, x_bins, right=False)
+            x_ends = np.digitize(nodes_right, x_bins, right=True)
+            y_starts = np.digitize(nodes_time, y_bins, right=True)
+
+            for u in range(len(nodes_left)):
+                x_start = max(0, x_starts[u] - 1)
+                x_end = max(0, x_ends[u] - 1)
+                y_bin = max(0, y_starts[u] - 1)
+                heatmap_counts[x_start : x_end + 1, y_bin] += 1
+
+            x_coords = np.repeat(x_bins[:-1], num_y_bins)
+            y_coords = np.tile(y_bins[:-1], num_x_bins)
+            overlapping_node_count = heatmap_counts.flatten()
+            # Clamp to 1 for the logarithm only, so that empty bins plot as 0
+            # while the reported count stays 0.
+            log_counts = np.log10(np.maximum(overlapping_node_count, 1))
+            df = pd.DataFrame(
+                {
+                    "position": x_coords.flatten(),
+                    "time": y_coords.flatten(),
+                    "overlapping_node_count_log10": log_counts,
+                    "overlapping_node_count": overlapping_node_count,
+                }
+            )
+            return df.astype(
+                {
+                    "position": "int",
+                    "time": "int",
+                    "overlapping_node_count_log10": "int",
+                    "overlapping_node_count": "int",
+                }
+            )
diff --git a/pages/nodes.py b/pages/nodes.py
index 81954cd..8e37e63 100644
--- a/pages/nodes.py
+++ b/pages/nodes.py
@@ -3,6 +3,7 @@
 import hvplot.pandas  # noqa
 import numpy as np
 import panel as pn
+from bokeh.models import HoverTool
 
 import config
 from plot_helpers import filter_points
@@ -40,8 +41,15 @@ def make_node_hist_panel(tsm, log_y):
     points = df_nodes.hvplot.scatter(
         x="ancestors_span",
         y="time",
-        hover_cols=["ancestors_span", "time"],
-    ).opts(width=config.PLOT_WIDTH, height=config.PLOT_HEIGHT)
+        hover_cols=["ancestors_span", "time"],  # TODO: add node ID
+    ).opts(
+        width=config.PLOT_WIDTH,
+        height=config.PLOT_HEIGHT,
+        title="Node span by time",
+        xlabel="width of genome spanned by node ancestors",
+        ylabel="node time",
+        axiswise=True,
+    )
 
     range_stream = hv.streams.RangeXY(source=points)
     streams = [range_stream]
@@ -54,7 +62,48 @@ def make_node_hist_panel(tsm, log_y):
     )
 
     plot_options = pn.Column(
-        pn.pane.Markdown("# Plot Options"),
         log_y_checkbox,
     )
-    return pn.Column(main, hist_panel, plot_options)
+
+    def make_heatmap(num_x_bins, num_y_bins):
+        anc_span_data = tsm.compute_ancestor_spans_heatmap_data(num_x_bins, num_y_bins)
+        tooltips = [
+            ("position", "@position"),
+            ("time", "@time"),
+            ("overlapping_nodes", "@overlapping_node_count"),
+        ]
+        hover = HoverTool(tooltips=tooltips)
+        heatmap = hv.HeatMap(anc_span_data).opts(
+            width=config.PLOT_WIDTH,
+            height=config.PLOT_HEIGHT,
+            tools=[hover],
+            colorbar=True,
+            title="Average ancestor length in time and genome bins",
+            axiswise=True,
+        )
+        return heatmap
+
+    max_x_bins = max(1, int(np.sqrt(df_nodes.child_right.max())))
+    x_bin_input = pn.widgets.IntInput(
+        name="genome bins",
+        value=min(50, max_x_bins),
+        start=1,
+        end=max_x_bins,
+    )
+    max_y_bins = max(1, int(np.sqrt(df_nodes.time.max())))
+    y_bin_input = pn.widgets.IntInput(
+        name="time bins", value=min(50, max_y_bins), start=1, end=max_y_bins
+    )
+    hm_options = pn.Column(x_bin_input, y_bin_input)
+
+    hm_panel = pn.bind(
+        make_heatmap,
+        num_x_bins=x_bin_input,
+        num_y_bins=y_bin_input,
+    )
+
+    return pn.Column(
+        pn.Column(main),
+        pn.Column(hist_panel, plot_options),
+        pn.Column(hm_panel, hm_options),
+    )
diff --git a/tests/test_data_model.py b/tests/test_data_model.py
index a88a211..61d6fce 100644
--- a/tests/test_data_model.py
+++ b/tests/test_data_model.py
@@ -168,6 +168,8 @@ def test_single_tree_example(self):
         nt.assert_array_equal(df.time, [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 2.0])
         nt.assert_array_equal(df.num_mutations, [1, 1, 1, 1, 1, 1, 1])
         nt.assert_array_equal(df.ancestors_span, [10, 10, 10, 10, 10, 10, -np.inf])
+        nt.assert_array_equal(df.child_left, [0, 0, 0, 0, 0, 0, np.inf])
+        nt.assert_array_equal(df.child_right, [10, 10, 10, 10, 10, 10, 0])
         nt.assert_array_equal(df.is_sample, [1, 1, 1, 1, 0, 0, 0])
 
     def test_multiple_tree_example(self):
@@ -178,6 +180,8 @@ def test_multiple_tree_example(self):
         nt.assert_array_equal(df.time, [0.0, 0.0, 0.0, 1.0, 2.0])
         nt.assert_array_equal(df.num_mutations, [0, 0, 0, 0, 0])
         nt.assert_array_equal(df.ancestors_span, [10, 10, 10, 10, -np.inf])
+        nt.assert_array_equal(df.child_left, [0, 0, 0, 0, np.inf])
+        nt.assert_array_equal(df.child_right, [10, 10, 10, 10, 0])
         nt.assert_array_equal(df.is_sample, [1, 1, 1, 0, 0])
 
 