Skip to content

Commit

Permalink
prune_twigs: add optional mask parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
schlegelp committed Oct 31, 2024
1 parent 75106df commit 7a4a278
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 32 deletions.
5 changes: 3 additions & 2 deletions R/nat.fastcore/src/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ pub fn connected_components(
pub fn prune_twigs(
parents: Vec<i32>,
threshold: f64,
weights: Option<Vec<f64>>,
weights: Option<Vec<f64>>
) -> Vec<i32> {
let parents = Array1::from_vec(parents);

Expand All @@ -169,7 +169,8 @@ pub fn prune_twigs(
Some(Array1::from_vec(weights.unwrap()))
};

fastcore::dag::prune_twigs(&parents.view(), threshold as f32, &weights)
// Mask is currently not supported - strangely, extendr does not seem to support Vec<bool>
fastcore::dag::prune_twigs(&parents.view(), threshold as f32, &weights, &None)
}

// Macro to generate exports.
Expand Down
50 changes: 26 additions & 24 deletions fastcore/src/dag.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use itertools::Itertools;
use ndarray::{s, Array, Array1, Array2, ArrayView1};
use ndarray::parallel::prelude::*;
use ndarray::{s, Array, Array1, Array2, ArrayView1};
use num::Float;
use std::collections::HashMap;
use std::collections::HashSet;
Expand Down Expand Up @@ -129,11 +129,9 @@ fn find_roots(parents: &ArrayView1<i32>) -> Vec<i32> {
///
/// A vector of vectors where each vector contains the nodes of a segment.
///
pub fn generate_segments<T>(
parents: &ArrayView1<i32>,
weights: Option<Array1<T>>
) -> Vec<Vec<i32>>
where T: Float + AddAssign,
pub fn generate_segments<T>(parents: &ArrayView1<i32>, weights: Option<Array1<T>>) -> Vec<Vec<i32>>
where
T: Float + AddAssign,
{
let mut all_segments: Vec<Vec<i32>> = vec![];
let mut current_segment = Array::from_elem(parents.len(), -1i32);
Expand Down Expand Up @@ -762,7 +760,6 @@ where
(source_dists, target_dists)
}


/// Compute geodesic distances between pairs of nodes.
///
///
Expand All @@ -784,8 +781,7 @@ pub fn geodesic_pairs(
pairs_target: &ArrayView1<i32>,
weights: &Option<Array1<f32>>,
directed: bool,
) -> Array1<f32>
{
) -> Array1<f32> {
// Make sure we have even number of sources/targets
if pairs_source.len() != pairs_target.len() {
panic!("Length of sources and targets not matching!");
Expand All @@ -799,7 +795,13 @@ pub fn geodesic_pairs(
.par_iter()
.zip(pairs_target.par_iter())
.map(|(idx1, idx2)| {
geodesic_distances_single_pair(parents, *idx1 as usize, *idx2 as usize, weights, directed)
geodesic_distances_single_pair(
parents,
*idx1 as usize,
*idx2 as usize,
weights,
directed,
)
})
.collect();

Expand All @@ -813,8 +815,7 @@ fn geodesic_distances_single_pair(
idx2: usize,
weights: &Option<Array1<f32>>,
directed: bool,
) -> f32
{
) -> f32 {
// Walk from idx1 to root node
let mut node: usize = idx1;
let mut d: f32 = 0.0;
Expand Down Expand Up @@ -849,11 +850,7 @@ fn geodesic_distances_single_pair(
}

// Track distance travelled
d += if let Some(w) = weights {
w[node]
} else {
1.0
};
d += if let Some(w) = weights { w[node] } else { 1.0 };

// If we hit the root node again, then idx1 and idx2 are on disconnected
// branches
Expand All @@ -862,24 +859,18 @@ fn geodesic_distances_single_pair(
}

node = parents[node] as usize;

}

break;
}
// Track distance travelled
d += if let Some(w) = weights {
w[node]
} else {
1.0
};
d += if let Some(w) = weights { w[node] } else { 1.0 };

node = parents[node] as usize;
}

// If we made it until here, then idx1 and idx2 are disconnected
return 1.0;

}

/// Calculate synapse flow centrality for each node.
Expand Down Expand Up @@ -1094,6 +1085,7 @@ pub fn prune_twigs<T>(
parents: &ArrayView1<i32>,
threshold: f32,
weights: &Option<Array1<T>>,
mask: &Option<Array1<bool>>,
) -> Vec<i32>
where
T: Float + AddAssign,
Expand All @@ -1114,6 +1106,11 @@ where
continue;
}

// Skip leaf nodes that are not in the mask
if mask.is_some() && !mask.as_ref().unwrap()[node as usize] {
continue;
}

// Reset distance and twig
d = T::zero();
twig.clear();
Expand All @@ -1123,6 +1120,11 @@ where
break;
}

// Stop if this nodes is masked out
if mask.is_some() && !mask.as_ref().unwrap()[node as usize] {
break;
}

// Stop if this node has more than one child (i.e. it's a branch point)
if n_children[node as usize] > 1 {
break;
Expand Down
19 changes: 16 additions & 3 deletions py/navis_fastcore/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def _ids_to_indices(node_ids, to_map):
raise ValueError("IDs must be int32 or int64")


def prune_twigs(node_ids, parent_ids, threshold, weights=None):
def prune_twigs(node_ids, parent_ids, threshold, weights=None, mask=None):
"""Prune twigs shorter than a given threshold.
Parameters
Expand All @@ -540,6 +540,10 @@ def prune_twigs(node_ids, parent_ids, threshold, weights=None):
weights : (N, ) float32 array, optional
Array of distances for each child -> parent connection.
If ``None`` all node-to-node distances are set to 1.
mask : (N, ) bool array, optional
Array of booleans to mask nodes that should not be pruned.
Importantly, twigs with _any_ masked node will not be pruned.
Returns
-------
Expand All @@ -554,6 +558,8 @@ def prune_twigs(node_ids, parent_ids, threshold, weights=None):
>>> parent_ids = np.array([-1, 0, 1, 2, 1, 4, 5])
>>> fastcore.prune_twigs(node_ids, parent_ids, 2)
array([0, 1, 4, 5, 6])
>>> mask = np.array([True, True, True, False, True, True, True])
>>> fastcore.prune_twigs(node_ids, parent_ids, 2, mask=mask)
"""
# Convert parent IDs into indices
Expand All @@ -566,8 +572,15 @@ def prune_twigs(node_ids, parent_ids, threshold, weights=None):
node_ids
), "`weights` must have the same length as `node_ids`"

# Get the segments (this will be a list of arrays of node indices)
keep_idx = _fastcore.prune_twigs(parent_ix, threshold, weights=weights)
# Make sure mask is boolean
if mask is not None:
mask = np.asarray(mask, dtype=bool, order="C")
assert len(mask) == len(
node_ids
), "`mask` must have the same length as `node_ids"

# Get the nodes to keep
keep_idx = _fastcore.prune_twigs(parent_ix, threshold, weights=weights, mask=mask)

# Map node indices back to IDs
return node_ids[keep_idx]
Expand Down
9 changes: 8 additions & 1 deletion py/src/dag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,21 @@ pub fn prune_twigs_py(
parents: PyReadonlyArray1<i32>,
threshold: f32,
weights: Option<PyReadonlyArray1<f32>>,
mask: Option<PyReadonlyArray1<bool>>,
) -> Vec<i32> {
let weights: Option<Array1<f32>> = if weights.is_some() {
Some(weights.unwrap().as_array().to_owned())
} else {
None
};
let mask: Option<Array1<bool>> = if mask.is_some() {
Some(mask.unwrap().as_array().to_owned())
} else {
None
};


prune_twigs(&parents.as_array(), threshold, &weights)
prune_twigs(&parents.as_array(), threshold, &weights, &mask)
}

/// Calculate Strahler Index.
Expand Down
7 changes: 5 additions & 2 deletions py/tests/test_fastcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,13 @@ def test_connected_components(swc):
@pytest.mark.parametrize("swc", [swc32(), swc64()])
@pytest.mark.parametrize("threshold", [5, 10])
@pytest.mark.parametrize("weights", [None, np.random.rand(N_NODES)])
def test_prune_twigs(swc, threshold, weights):
@pytest.mark.parametrize("mask", [None, (np.random.rand(N_NODES) > 0.5).astype(bool)])
def test_prune_twigs(swc, threshold, weights, mask):
nodes, parents, _ = swc
start = time.time()
pruned = fastcore.prune_twigs(nodes, parents, threshold=threshold, weights=weights)
pruned = fastcore.prune_twigs(
nodes, parents, threshold=threshold, weights=weights, mask=mask
)
dur = time.time() - start

print("Pruned nodes:", pruned)
Expand Down

0 comments on commit 7a4a278

Please sign in to comment.