Skip to content

Commit

Permalink
add functions for geodesic distances between pairs of nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
schlegelp committed Sep 16, 2024
1 parent 652a8af commit 0dcacd3
Show file tree
Hide file tree
Showing 7 changed files with 243 additions and 8 deletions.
123 changes: 122 additions & 1 deletion fastcore/src/dag.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use itertools::Itertools;
use ndarray::{s, Array, Array1, Array2, ArrayView1};
use ndarray::parallel::prelude::*;
use num::Float;
use std::collections::HashMap;
use std::collections::HashSet;
Expand Down Expand Up @@ -358,7 +359,7 @@ where
// We're basically brute forcing the "forward" (child -> parent) distances here
// In theory we could be a bit more efficient by using leaf nodes as seeds,
// tracking the distances as we go along and then filling the missing values.
// This requires a lot more book keeping though and I'm not (yet) sure that's worth it.
// This requires a lot more bookkeeping though and I'm not (yet) sure that's worth it.
for idx1 in 0..parents.len() {
node = idx1; // start with the distance between the node and itself
d = T::zero();
Expand Down Expand Up @@ -761,6 +762,126 @@ where
(source_dists, target_dists)
}


/// Compute geodesic distances between pairs of nodes.
///
///
/// Arguments:
///
/// - `parents`: array of parent indices
/// - `pair_source`: array of source indices
/// - `pair_target`: array of target indices
/// - `weights`: optional array of weights for each child -> parent connection
/// - `directed`: boolean indicating whether to return only the directed (child -> parent) distances
///
/// Returns:
///
/// A 1d array of f32/f64 values indicating the distances between the queried pairs.
///
pub fn geodesic_pairs(
parents: &ArrayView1<i32>,
pairs_source: &ArrayView1<i32>,
pairs_target: &ArrayView1<i32>,
weights: &Option<Array1<f32>>,
directed: bool,
) -> Array1<f32>
{
// Make sure we have even number of sources/targets
if pairs_source.len() != pairs_target.len() {
panic!("Length of sources and targets not matching!");
}

// Convert `pairs_source` to a vector for parallel processing
let pairs_source: Vec<i32> = pairs_source.iter().cloned().collect();
let pairs_target: Vec<i32> = pairs_target.iter().cloned().collect();

let dists: Vec<_> = pairs_source
.par_iter()
.zip(pairs_target.par_iter())
.map(|(idx1, idx2)| {
geodesic_distances_single_pair(parents, *idx1 as usize, *idx2 as usize, weights, directed)
})
.collect();

// Convert the vector to an array and return
Array::from(dists)
}

fn geodesic_distances_single_pair(
parents: &ArrayView1<i32>,
idx1: usize,
idx2: usize,
weights: &Option<Array1<f32>>,
directed: bool,
) -> f32
{
// Walk from idx1 to root node
let mut node: usize = idx1;
let mut d: f32 = 0.0;
let mut seen: Array1<f32> = Array::from_elem(parents.len(), -1.0);

loop {
// If come across the target node, return here
// (also happens if idx1 == idx2)
if node == idx2 {
return d;
};
seen[node] = d;

// Break if we hit the root node
if parents[node] < 0 {
// If we reached the root node without finding idx2 and we want only
// directed distances, then we return -1
if directed {
return -1.0;
};

// Now do the same for the target node
node = idx2;
d = 0.0;

loop {
// If this node has already been visited in when walking from idx1 to the root
// we can just sum up the distances.
// This also covers cases where `idx2` is upstream `idx1`!
if seen[node] > -1.0 {
return d + seen[node];
}

// Track distance travelled
d += if let Some(w) = weights {
w[node]
} else {
1.0
};

// If we hit the root node again, then idx1 and idx2 are on disconnected
// branches
if parents[node] < 0 {
break;
}

node = parents[node] as usize;

}

break;
}
// Track distance travelled
d += if let Some(w) = weights {
w[node]
} else {
1.0
};

node = parents[node] as usize;
}

// If we made it until here, then idx1 and idx2 are disconnected
return 1.0;

}

/// Calculate synapse flow centrality for each node.
///
/// This works by, for each connected component:
Expand Down
2 changes: 0 additions & 2 deletions fastcore/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@

pub mod nblast;

pub mod dag;

4 changes: 3 additions & 1 deletion py/docs/Trees/geodesic.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Geodesic Distances

::: navis_fastcore.geodesic_matrix
::: navis_fastcore.geodesic_matrix

::: navis_fastcore.geodesic_pairs
1 change: 1 addition & 0 deletions py/docs/Trees/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ implement anything related to graph traversal.
The Python bindings for `navis-fastcore` currently cover the following functions:

- [`fastcore.geodesic_matrix`](geodesic.md): calculate geodesic ("along-the-arbor") distances either between all pairs of nodes or between specific sources and targets
- [`fastcore.geodesic_pairs`](geodesic.md): calculate geodesic ("along-the-arbor") distances between given pairs of nodes
- [`fastcore.connected_components`](cc.md): generate the connected components
- [`fastcore.synapse_flow_centrality`](flow.md): calculate synapse flow centrality ([Schneider-Mizell, eLife, 2016](https://elifesciences.org/articles/12059))
- [`fastcore.generate_segments`](segments.md#generate-segments): break the neuron into linear segments
Expand Down
71 changes: 70 additions & 1 deletion py/navis_fastcore/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

__all__ = [
"geodesic_matrix",
"geodesic_pairs",
"connected_components",
"synapse_flow_centrality",
"generate_segments",
Expand Down Expand Up @@ -214,7 +215,7 @@ def geodesic_matrix(
Returns
-------
matrix : float32 (double) array
matrix : float32 (single) array
Geodesic distances. Unreachable nodes are set to -1.
Examples
Expand Down Expand Up @@ -268,6 +269,74 @@ def geodesic_matrix(
return dists


def geodesic_pairs(
    node_ids,
    parent_ids,
    pairs,
    directed=False,
    weights=None,
):
    """Calculate geodesic ("along-the-arbor") distances between pairs of nodes.

    Each pair is resolved independently by walking up the tree from both
    nodes. The approach is brute force but fast in practice because the pairs
    are processed in parallel.

    Parameters
    ----------
    node_ids :   (N, ) array
                 Array of node IDs.
    parent_ids : (N, ) array
                 Array of parent IDs for each node. Root nodes' parents
                 must be -1.
    pairs :      (M, 2) array
                 Pairs of node IDs (source, target) for which to calculate
                 distances.
    directed :   bool, optional
                 If ``True`` will only return distances in the direction of
                 the child -> parent (i.e. towards the root) relationship.
    weights :    (N, ) float32 array, optional
                 Array of distances for each child -> parent connection.
                 If ``None`` all node to node distances are set to 1.

    Returns
    -------
    dists :      (M, ) float32 (single) array
                 Geodesic distance for each pair, in the order given.
                 Unreachable nodes are set to -1.

    Examples
    --------
    >>> import navis_fastcore as fastcore
    >>> import numpy as np
    >>> node_ids = np.arange(7)
    >>> parent_ids = np.array([-1, 0, 1, 2, 1, 4, 5])
    >>> pairs = np.array([(0, 1), (0, 2)])
    >>> fastcore.geodesic_pairs(node_ids, parent_ids, pairs)
    array([1., 2.], dtype=float32)

    """
    # Translate parent IDs into positional indices
    parent_indices = _ids_to_indices(node_ids, parent_ids)

    # Validate the shape of the requested pairs
    pairs = np.asarray(pairs)
    assert pairs.ndim == 2 and pairs.shape[1] == 2, "`pairs` must be of shape (N, 2)"

    # Normalise the weights (if any) to a C-contiguous float32 array
    if weights is not None:
        weights = np.asarray(weights, dtype=np.float32, order="C")
        assert len(weights) == len(
            node_ids
        ), "`weights` must have the same length as `node_ids`"

    # Translate both columns of node IDs into positional indices
    source_indices = _ids_to_indices(node_ids, pairs[:, 0])
    target_indices = _ids_to_indices(node_ids, pairs[:, 1])

    # Hand over to the compiled implementation
    return _fastcore.geodesic_pairs(
        parent_indices,
        pairs_source=source_indices,
        pairs_target=target_indices,
        weights=weights,
        directed=directed,
    )


def connected_components(node_ids, parent_ids):
"""Get the connected components for this neuron.
Expand Down
49 changes: 46 additions & 3 deletions py/src/dag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use pyo3::prelude::*;
use std::collections::HashMap;

use fastcore::dag::{
all_dists_to_root, break_segments, dist_to_root, generate_segments,
geodesic_distances_all_by_all, geodesic_distances_partial, synapse_flow_centrality,
connected_components, classify_nodes, strahler_index, prune_twigs
all_dists_to_root, break_segments, classify_nodes, connected_components, dist_to_root,
generate_segments, geodesic_distances_all_by_all, geodesic_distances_partial, geodesic_pairs,
prune_twigs, strahler_index, synapse_flow_centrality,
};

/// For each node ID in `parents` find its index in `nodes`.
Expand Down Expand Up @@ -240,6 +240,49 @@ pub fn geodesic_distances_py<'py>(
dists.into_pyarray(py)
}

/// Compute geodesic distances along the tree for pairs of nodes.
///
/// This function wrangles the Python arrays into Rust arrays and then calls
/// `fastcore::dag::geodesic_pairs`.
///
/// Arguments:
///
/// - `parents`: array of parent indices
/// - `pairs_source`: array of source indices, one per pair
/// - `pairs_target`: array of target indices, one per pair
/// - `weights`: optional array of weights for each node (may be `None`)
/// - `directed`: boolean indicating whether to return only the directed (child -> parent) distances
///
/// Returns:
///
/// A 1D array of f32 values indicating the distances between the pairs of nodes.
///
#[pyfunction]
#[pyo3(name = "geodesic_pairs", signature = (parents, pairs_source, pairs_target, weights, directed=false))]
pub fn geodesic_pairs_py<'py>(
    py: Python<'py>,
    parents: PyReadonlyArray1<i32>,
    pairs_source: PyReadonlyArray1<i32>,
    pairs_target: PyReadonlyArray1<i32>,
    weights: Option<PyReadonlyArray1<f32>>,
    directed: bool,
) -> &'py PyArray1<f32> {
    // Copy the (optional) weights into an owned array, as expected by the core function
    let weights: Option<Array1<f32>> = weights.map(|w| w.as_array().to_owned());

    let dists = geodesic_pairs(
        &parents.as_array(),
        &pairs_source.as_array(),
        &pairs_target.as_array(),
        &weights,
        directed,
    );
    dists.into_pyarray(py)
}

/// Compute synapse flow centrality for each node.
///
/// Arguments:
Expand Down
1 change: 1 addition & 0 deletions py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ fn fastcore(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(dist_to_root_py, m)?)?;
m.add_function(wrap_pyfunction!(top_nn_py, m)?)?;
m.add_function(wrap_pyfunction!(geodesic_distances_py, m)?)?;
m.add_function(wrap_pyfunction!(geodesic_pairs_py, m)?)?;
m.add_function(wrap_pyfunction!(nblast_single_py, m)?)?;
m.add_function(wrap_pyfunction!(nblast_allbyall_py, m)?)?;
m.add_function(wrap_pyfunction!(synapse_flow_centrality_py, m)?)?;
Expand Down

0 comments on commit 0dcacd3

Please sign in to comment.