From 7a4a278257348bb8317a850f4f8314942a02c0e4 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Thu, 31 Oct 2024 12:38:34 +0000 Subject: [PATCH] prune_twigs: add optional `mask` parameter --- R/nat.fastcore/src/rust/src/lib.rs | 5 +-- fastcore/src/dag.rs | 50 ++++++++++++++++-------------- py/navis_fastcore/dag.py | 19 ++++++++++-- py/src/dag.rs | 9 +++++- py/tests/test_fastcore.py | 7 +++-- 5 files changed, 58 insertions(+), 32 deletions(-) diff --git a/R/nat.fastcore/src/rust/src/lib.rs b/R/nat.fastcore/src/rust/src/lib.rs index dae047f..8eb6467 100644 --- a/R/nat.fastcore/src/rust/src/lib.rs +++ b/R/nat.fastcore/src/rust/src/lib.rs @@ -159,7 +159,7 @@ pub fn connected_components( pub fn prune_twigs( parents: Vec, threshold: f64, - weights: Option>, + weights: Option> ) -> Vec { let parents = Array1::from_vec(parents); @@ -169,7 +169,8 @@ pub fn prune_twigs( Some(Array1::from_vec(weights.unwrap())) }; - fastcore::dag::prune_twigs(&parents.view(), threshold as f32, &weights) + // Mask is currently not supported - strangely, extendr does not seem to support Vec + fastcore::dag::prune_twigs(&parents.view(), threshold as f32, &weights, &None) } // Macro to generate exports. diff --git a/fastcore/src/dag.rs b/fastcore/src/dag.rs index 62f88ba..07d9914 100644 --- a/fastcore/src/dag.rs +++ b/fastcore/src/dag.rs @@ -1,6 +1,6 @@ use itertools::Itertools; -use ndarray::{s, Array, Array1, Array2, ArrayView1}; use ndarray::parallel::prelude::*; +use ndarray::{s, Array, Array1, Array2, ArrayView1}; use num::Float; use std::collections::HashMap; use std::collections::HashSet; @@ -129,11 +129,9 @@ fn find_roots(parents: &ArrayView1) -> Vec { /// /// A vector of vectors where each vector contains the nodes of a segment. 
/// -pub fn generate_segments( - parents: &ArrayView1, - weights: Option> -) -> Vec> -where T: Float + AddAssign, +pub fn generate_segments(parents: &ArrayView1, weights: Option>) -> Vec> +where + T: Float + AddAssign, { let mut all_segments: Vec> = vec![]; let mut current_segment = Array::from_elem(parents.len(), -1i32); @@ -762,7 +760,6 @@ where (source_dists, target_dists) } - /// Compute geodesic distances between pairs of nodes. /// /// @@ -784,8 +781,7 @@ pub fn geodesic_pairs( pairs_target: &ArrayView1, weights: &Option>, directed: bool, -) -> Array1 -{ +) -> Array1 { // Make sure we have even number of sources/targets if pairs_source.len() != pairs_target.len() { panic!("Length of sources and targets not matching!"); @@ -799,7 +795,13 @@ pub fn geodesic_pairs( .par_iter() .zip(pairs_target.par_iter()) .map(|(idx1, idx2)| { - geodesic_distances_single_pair(parents, *idx1 as usize, *idx2 as usize, weights, directed) + geodesic_distances_single_pair( + parents, + *idx1 as usize, + *idx2 as usize, + weights, + directed, + ) }) .collect(); @@ -813,8 +815,7 @@ fn geodesic_distances_single_pair( idx2: usize, weights: &Option>, directed: bool, -) -> f32 -{ +) -> f32 { // Walk from idx1 to root node let mut node: usize = idx1; let mut d: f32 = 0.0; @@ -849,11 +850,7 @@ fn geodesic_distances_single_pair( } // Track distance travelled - d += if let Some(w) = weights { - w[node] - } else { - 1.0 - }; + d += if let Some(w) = weights { w[node] } else { 1.0 }; // If we hit the root node again, then idx1 and idx2 are on disconnected // branches @@ -862,24 +859,18 @@ fn geodesic_distances_single_pair( } node = parents[node] as usize; - } break; } // Track distance travelled - d += if let Some(w) = weights { - w[node] - } else { - 1.0 - }; + d += if let Some(w) = weights { w[node] } else { 1.0 }; node = parents[node] as usize; } // If we made it until here, then idx1 and idx2 are disconnected return 1.0; - } /// Calculate synapse flow centrality for each node. 
@@ -1094,6 +1085,7 @@ pub fn prune_twigs( parents: &ArrayView1, threshold: f32, weights: &Option>, + mask: &Option>, ) -> Vec where T: Float + AddAssign, @@ -1114,6 +1106,11 @@ where continue; } + // Skip leaf nodes that are not in the mask + if mask.is_some() && !mask.as_ref().unwrap()[node as usize] { + continue; + } + // Reset distance and twig d = T::zero(); twig.clear(); @@ -1123,6 +1120,11 @@ where break; } + // Stop if this node is masked out + if mask.is_some() && !mask.as_ref().unwrap()[node as usize] { + break; + } + // Stop if this node has more than one child (i.e. it's a branch point) if n_children[node as usize] > 1 { break; diff --git a/py/navis_fastcore/dag.py b/py/navis_fastcore/dag.py index 1f4a27c..44c84b2 100644 --- a/py/navis_fastcore/dag.py +++ b/py/navis_fastcore/dag.py @@ -525,7 +525,7 @@ def _ids_to_indices(node_ids, to_map): raise ValueError("IDs must be int32 or int64") -def prune_twigs(node_ids, parent_ids, threshold, weights=None): +def prune_twigs(node_ids, parent_ids, threshold, weights=None, mask=None): """Prune twigs shorter than a given threshold. Parameters @@ -540,6 +540,10 @@ def prune_twigs(node_ids, parent_ids, threshold, weights=None): weights : (N, ) float32 array, optional Array of distances for each child -> parent connection. If ``None`` all node-to-node distances are set to 1. + mask : (N, ) bool array, optional + Array of booleans to mask nodes that should not be pruned. + Importantly, twigs with _any_ masked node will not be pruned. 
+ Returns ------- @@ -554,6 +558,8 @@ def prune_twigs(node_ids, parent_ids, threshold, weights=None): >>> parent_ids = np.array([-1, 0, 1, 2, 1, 4, 5]) >>> fastcore.prune_twigs(node_ids, parent_ids, 2) array([0, 1, 4, 5, 6]) + >>> mask = np.array([True, True, True, False, True, True, True]) + >>> fastcore.prune_twigs(node_ids, parent_ids, 2, mask=mask) """ # Convert parent IDs into indices @@ -566,8 +572,15 @@ def prune_twigs(node_ids, parent_ids, threshold, weights=None): node_ids ), "`weights` must have the same length as `node_ids`" - # Get the segments (this will be a list of arrays of node indices) - keep_idx = _fastcore.prune_twigs(parent_ix, threshold, weights=weights) + # Make sure mask is boolean + if mask is not None: + mask = np.asarray(mask, dtype=bool, order="C") + assert len(mask) == len( + node_ids + ), "`mask` must have the same length as `node_ids" + + # Get the nodes to keep + keep_idx = _fastcore.prune_twigs(parent_ix, threshold, weights=weights, mask=mask) # Map node indices back to IDs return node_ids[keep_idx] diff --git a/py/src/dag.rs b/py/src/dag.rs index e0ec959..ec54239 100644 --- a/py/src/dag.rs +++ b/py/src/dag.rs @@ -345,14 +345,21 @@ pub fn prune_twigs_py( parents: PyReadonlyArray1, threshold: f32, weights: Option>, + mask: Option>, ) -> Vec { let weights: Option> = if weights.is_some() { Some(weights.unwrap().as_array().to_owned()) } else { None }; + let mask: Option> = if mask.is_some() { + Some(mask.unwrap().as_array().to_owned()) + } else { + None + }; + - prune_twigs(&parents.as_array(), threshold, &weights) + prune_twigs(&parents.as_array(), threshold, &weights, &mask) } /// Calculate Strahler Index. 
diff --git a/py/tests/test_fastcore.py b/py/tests/test_fastcore.py index e485312..f217b35 100644 --- a/py/tests/test_fastcore.py +++ b/py/tests/test_fastcore.py @@ -214,10 +214,13 @@ def test_connected_components(swc): @pytest.mark.parametrize("swc", [swc32(), swc64()]) @pytest.mark.parametrize("threshold", [5, 10]) @pytest.mark.parametrize("weights", [None, np.random.rand(N_NODES)]) -def test_prune_twigs(swc, threshold, weights): +@pytest.mark.parametrize("mask", [None, (np.random.rand(N_NODES) > 0.5).astype(bool)]) +def test_prune_twigs(swc, threshold, weights, mask): nodes, parents, _ = swc start = time.time() - pruned = fastcore.prune_twigs(nodes, parents, threshold=threshold, weights=weights) + pruned = fastcore.prune_twigs( + nodes, parents, threshold=threshold, weights=weights, mask=mask + ) dur = time.time() - start print("Pruned nodes:", pruned)