From 0dcacd34f486cf53fd95f67d4b8ddf7dabb56075 Mon Sep 17 00:00:00 2001
From: Philipp Schlegel
Date: Mon, 16 Sep 2024 16:20:47 +0100
Subject: [PATCH] add functions for geodesic distances between pairs of nodes

---
 fastcore/src/dag.rs       | 123 +++++++++++++++++++++++++++++++++++++-
 fastcore/src/lib.rs       |   2 -
 py/docs/Trees/geodesic.md |   4 +-
 py/docs/Trees/index.md    |   1 +
 py/navis_fastcore/dag.py  |  71 +++++++++++++++++++++-
 py/src/dag.rs             |  49 ++++++++++++++-
 py/src/lib.rs             |   1 +
 7 files changed, 243 insertions(+), 8 deletions(-)

diff --git a/fastcore/src/dag.rs b/fastcore/src/dag.rs
index 09df152..62f88ba 100644
--- a/fastcore/src/dag.rs
+++ b/fastcore/src/dag.rs
@@ -1,5 +1,6 @@
 use itertools::Itertools;
 use ndarray::{s, Array, Array1, Array2, ArrayView1};
+use ndarray::parallel::prelude::*;
 use num::Float;
 use std::collections::HashMap;
 use std::collections::HashSet;
@@ -358,7 +359,7 @@ where
     // We're basically brute forcing the "forward" (child -> parent) distances here
     // In theory we could be a bit more efficient by using leaf nodes as seeds,
     // tracking the distances as we go along and then filling the missing values.
-    // This requires a lot more book keeping though and I'm not (yet) sure that's worth it.
+    // This requires a lot more bookkeeping though and I'm not (yet) sure that's worth it.
     for idx1 in 0..parents.len() {
         node = idx1; // start with the distance between the node and itself
         d = T::zero();
@@ -761,6 +762,126 @@ where
 
     (source_dists, target_dists)
 }
+
+/// Compute geodesic distances between pairs of nodes.
+///
+/// Arguments:
+///
+/// - `parents`: array of parent indices
+/// - `pairs_source`: array of source indices
+/// - `pairs_target`: array of target indices
+/// - `weights`: optional array of weights for each child -> parent connection
+/// - `directed`: boolean indicating whether to return only the directed (child -> parent) distances
+///
+/// Returns:
+///
+/// A 1d array of f32 distances, one per queried pair (-1 if the pair is unreachable).
+///
+/// Panics if `pairs_source` and `pairs_target` differ in length.
+pub fn geodesic_pairs(
+    parents: &ArrayView1<i32>,
+    pairs_source: &ArrayView1<i32>,
+    pairs_target: &ArrayView1<i32>,
+    weights: &Option<Array1<f32>>,
+    directed: bool,
+) -> Array1<f32>
+{
+    // Make sure we have the same number of sources and targets
+    if pairs_source.len() != pairs_target.len() {
+        panic!("Length of sources and targets not matching!");
+    }
+
+    // Copy the index arrays into vectors for parallel processing
+    let pairs_source: Vec<i32> = pairs_source.iter().cloned().collect();
+    let pairs_target: Vec<i32> = pairs_target.iter().cloned().collect();
+
+    let dists: Vec<_> = pairs_source
+        .par_iter()
+        .zip(pairs_target.par_iter())
+        .map(|(idx1, idx2)| {
+            geodesic_distances_single_pair(parents, *idx1 as usize, *idx2 as usize, weights, directed)
+        })
+        .collect();
+
+    // Convert the vector to an array and return
+    Array::from(dists)
+}
+
+/// Geodesic distance for a single pair: walks `idx1` -> root and, unless `directed`,
+/// `idx2` -> root until both paths meet. Returns -1 if the pair is unreachable.
+fn geodesic_distances_single_pair(
+    parents: &ArrayView1<i32>,
+    idx1: usize,
+    idx2: usize,
+    weights: &Option<Array1<f32>>,
+    directed: bool,
+) -> f32
+{
+    // Walk from idx1 to root node
+    let mut node: usize = idx1;
+    let mut d: f32 = 0.0;
+    let mut seen: Array1<f32> = Array::from_elem(parents.len(), -1.0);
+
+    loop {
+        // If we come across the target node, return here
+        // (also happens if idx1 == idx2)
+        if node == idx2 {
+            return d;
+        };
+        seen[node] = d;
+
+        // Break if we hit the root node
+        if parents[node] < 0 {
+            // If we reached the root node without finding idx2 and we want only
+            // directed distances, then we return -1
+            if directed {
+                return -1.0;
+            };
+
+            // Now do the same for the target node
+            node = idx2;
+            d = 0.0;
+
+            loop {
+                // If this node has already been visited when walking from idx1 to the root
+                // we can just sum up the distances.
+                // This also covers cases where `idx2` is upstream `idx1`!
+                if seen[node] > -1.0 {
+                    return d + seen[node];
+                }
+
+                // Track distance travelled
+                d += if let Some(w) = weights {
+                    w[node]
+                } else {
+                    1.0
+                };
+
+                // If we hit the root node again, then idx1 and idx2 are on disconnected
+                // branches
+                if parents[node] < 0 {
+                    break;
+                }
+
+                node = parents[node] as usize;
+            }
+
+            break;
+        }
+        // Track distance travelled
+        d += if let Some(w) = weights {
+            w[node]
+        } else {
+            1.0
+        };
+
+        node = parents[node] as usize;
+    }
+
+    // Both walks hit a root without meeting: the pair is disconnected, i.e. unreachable (-1)
+    -1.0
+}
+
 /// Calculate synapse flow centrality for each node.
 ///
 /// This works by, for each connected component:
diff --git a/fastcore/src/lib.rs b/fastcore/src/lib.rs
index 089b737..a354af6 100644
--- a/fastcore/src/lib.rs
+++ b/fastcore/src/lib.rs
@@ -1,5 +1,3 @@
-
 pub mod nblast;
 
 pub mod dag;
-
diff --git a/py/docs/Trees/geodesic.md b/py/docs/Trees/geodesic.md
index fff1d8e..4961771 100644
--- a/py/docs/Trees/geodesic.md
+++ b/py/docs/Trees/geodesic.md
@@ -1,3 +1,5 @@
 # Geodesic Distances
 
-::: navis_fastcore.geodesic_matrix
\ No newline at end of file
+::: navis_fastcore.geodesic_matrix
+
+::: navis_fastcore.geodesic_pairs
diff --git a/py/docs/Trees/index.md b/py/docs/Trees/index.md
index 6e2980e..547cdbc 100644
--- a/py/docs/Trees/index.md
+++ b/py/docs/Trees/index.md
@@ -29,6 +29,7 @@ implement anything related to graph traversal.
The Python bindings for `navis-fastcore` currently cover the following functions: - [`fastcore.geodesic_matrix`](geodesic.md): calculate geodesic ("along-the-arbor") distances either between all pairs of nodes or between specific sources and targets +- [`fastcore.geodesic_pairs`](geodesic.md): calculate geodesic ("along-the-arbor") distances between given pairs of nodes - [`fastcore.connected_components`](cc.md): generate the connected components - [`fastcore.synapse_flow_centrality`](flow.md): calculate synapse flow centrality ([Schneider-Mizell, eLife, 2016](https://elifesciences.org/articles/12059)) - [`fastcore.generate_segments`](segments.md#generate-segments): break the neuron into linear segments diff --git a/py/navis_fastcore/dag.py b/py/navis_fastcore/dag.py index 2b41142..347e895 100644 --- a/py/navis_fastcore/dag.py +++ b/py/navis_fastcore/dag.py @@ -4,6 +4,7 @@ __all__ = [ "geodesic_matrix", + "geodesic_pairs", "connected_components", "synapse_flow_centrality", "generate_segments", @@ -214,7 +215,7 @@ def geodesic_matrix( Returns ------- - matrix : float32 (double) array + matrix : float32 (single) array Geodesic distances. Unreachable nodes are set to -1. Examples @@ -268,6 +269,74 @@ def geodesic_matrix( return dists +def geodesic_pairs( + node_ids, + parent_ids, + pairs, + directed=False, + weights=None, +): + """Calculate geodesic ("along-the-arbor") distances between pairs of nodes. + + This uses a pretty simple algorithm that calculates distances using brute + force. It's pretty fast because we parallelize the calculation of each pair + of nodes. + + Parameters + ---------- + node_ids : (N, ) array + Array of node IDs. + parent_ids : (N, ) array + Array of parent IDs for each node. Root nodes' parents + must be -1. + pairs : (N, 2) array + Pairs of node IDs for which to calculate distances. + directed : bool, optional + If ``True`` will only return distances in the direction of + the child -> parent (i.e. towards the root) relationship. 
+ weights : (N, ) float32 array, optional + Array of distances for each child -> parent connection. + If ``None`` all node to node distances are set to 1. + + Returns + ------- + matrix : float32 (single) array + Geodesic distances. Unreachable nodes are set to -1. + + Examples + -------- + >>> import navis_fastcore as fastcore + >>> import numpy as np + >>> node_ids = np.arange(7) + >>> parent_ids = np.array([-1, 0, 1, 2, 1, 4, 5]) + >>> pairs = np.array([(0, 1), (0, 2)]) + >>> fastcore.geodesic_pairs(node_ids, parent_ids, pairs) + + """ + # Convert parent IDs into indices + parent_ix = _ids_to_indices(node_ids, parent_ids) + + pairs = np.asarray(pairs) + assert pairs.ndim == 2 and pairs.shape[1] == 2, "`pairs` must be of shape (N, 2)" + + if weights is not None: + weights = np.asarray(weights, dtype=np.float32, order="C") + assert len(weights) == len( + node_ids + ), "`weights` must have the same length as `node_ids`" + + # Calculate distances + dists = _fastcore.geodesic_pairs( + parent_ix, + pairs_source=_ids_to_indices(node_ids, pairs[:, 0]), + pairs_target=_ids_to_indices(node_ids, pairs[:, 1]), + weights=weights, + directed=directed, + ) + + return dists + + def connected_components(node_ids, parent_ids): """Get the connected components for this neuron. 
 
diff --git a/py/src/dag.rs b/py/src/dag.rs
index 8e4ce3a..e0ec959 100644
--- a/py/src/dag.rs
+++ b/py/src/dag.rs
@@ -5,9 +5,9 @@ use pyo3::prelude::*;
 use std::collections::HashMap;
 
 use fastcore::dag::{
-    all_dists_to_root, break_segments, dist_to_root, generate_segments,
-    geodesic_distances_all_by_all, geodesic_distances_partial, synapse_flow_centrality,
-    connected_components, classify_nodes, strahler_index, prune_twigs
+    all_dists_to_root, break_segments, classify_nodes, connected_components, dist_to_root,
+    generate_segments, geodesic_distances_all_by_all, geodesic_distances_partial, geodesic_pairs,
+    prune_twigs, strahler_index, synapse_flow_centrality,
 };
 
 /// For each node ID in `parents` find its index in `nodes`.
@@ -240,6 +240,49 @@ pub fn geodesic_distances_py<'py>(
     dists.into_pyarray(py)
 }
 
+/// Compute geodesic distances along the tree for pairs of nodes.
+///
+/// This function wrangles the Python arrays into Rust arrays and then calls the
+/// appropriate geodesic distance function.
+///
+/// Arguments:
+///
+/// - `parents`: array of parent indices
+/// - `pairs_source`: array of source indices for pairs
+/// - `pairs_target`: array of target indices for pairs
+/// - `weights`: optional array of weights for each node
+/// - `directed`: boolean indicating whether to return only the directed (child -> parent) distances
+///
+/// Returns:
+///
+/// A 1D array of f32 distances between the pairs of nodes (-1 if a pair is unreachable).
+///
+#[pyfunction]
+#[pyo3(name = "geodesic_pairs", signature = (parents, pairs_source, pairs_target, weights, directed=false))]
+pub fn geodesic_pairs_py<'py>(
+    py: Python<'py>,
+    parents: PyReadonlyArray1<i32>,
+    pairs_source: PyReadonlyArray1<i32>,
+    pairs_target: PyReadonlyArray1<i32>,
+    weights: Option<PyReadonlyArray1<f32>>,
+    directed: bool,
+) -> &'py PyArray1<f32> {
+    // Copy the optional weights out of the numpy buffer; `geodesic_pairs`
+    // takes them as an owned `Option<Array1<f32>>`.
+    let weights: Option<Array1<f32>> = weights
+        .as_ref()
+        .map(|w| w.as_array().to_owned());
+
+    let dists = geodesic_pairs(
+        &parents.as_array(),
+        &pairs_source.as_array(),
+        &pairs_target.as_array(),
+        &weights,
+        directed,
+    );
+    dists.into_pyarray(py)
+}
+
 /// Compute synapse flow centrality for each node.
 ///
 /// Arguments:
diff --git a/py/src/lib.rs b/py/src/lib.rs
index 40fcb32..de55b1b 100644
--- a/py/src/lib.rs
+++ b/py/src/lib.rs
@@ -18,6 +18,7 @@ fn fastcore(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(dist_to_root_py, m)?)?;
     m.add_function(wrap_pyfunction!(top_nn_py, m)?)?;
     m.add_function(wrap_pyfunction!(geodesic_distances_py, m)?)?;
+    m.add_function(wrap_pyfunction!(geodesic_pairs_py, m)?)?;
     m.add_function(wrap_pyfunction!(nblast_single_py, m)?)?;
     m.add_function(wrap_pyfunction!(nblast_allbyall_py, m)?)?;
     m.add_function(wrap_pyfunction!(synapse_flow_centrality_py, m)?)?;