From 8764cd9920739ea81ccf7ced1cd74d6d8d793e05 Mon Sep 17 00:00:00 2001 From: Anna Grim <108307071+anna-grim@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:50:42 -0800 Subject: [PATCH] Update graph_util.py --- src/deep_neurographs/utils/graph_util.py | 34 +++++++++++------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index da4e68d..a375fc8 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -126,15 +126,10 @@ def run( irreducibles = self.schedule_processes(swc_dicts) # Build FragmentsGraph - if self.progress_bar: - pbar = tqdm(total=len(irreducibles), desc="Combine Graphs") - fragments_graph = FragmentsGraph(node_spacing=self.node_spacing) while len(irreducibles): irreducible_set = irreducibles.pop() fragments_graph.add_component(irreducible_set) - if self.progress_bar: - pbar.update(1) return fragments_graph # --- Graph structure extraction --- @@ -204,7 +199,8 @@ def get_irreducibles(self, swc_dict): # Extract irreducibles irreducibles = list() - if graph.number_of_nodes() > 1: + path_length = compute_path_length(graph) + if path_length > min_size and graph.number_of_nodes() > 1: for nodes in nx.connected_components(graph): if len(nodes) > 1: result = self.get_component_irreducibles( @@ -303,7 +299,6 @@ def get_component_irreducibles(self, graph, swc_dict): edges = dict() nbs = defaultdict(list) root = None - total_length = 0 branch_length = 0 for (i, j) in nx.dfs_edges(graph, source=util.sample_once(leafs)): # Check if starting new or continuing current path @@ -333,19 +328,15 @@ def get_component_irreducibles(self, graph, swc_dict): nbs[root].append(j) nbs[j].append(root) root = None - total_length += branch_length # Output - if total_length > self.min_size: - irreducibles = { - "leaf": set_node_attrs(swc_dict, leafs), - "branching": set_node_attrs(swc_dict, branchings), - "edge": edges, - "swc_id": swc_dict["swc_id"], - } - return irreducibles - else: - return False + irreducibles = { + "leaf": set_node_attrs(swc_dict, leafs), + "branching": set_node_attrs(swc_dict, branchings), + "edge": edges, + "swc_id": swc_dict["swc_id"], + } + return irreducibles # --- Utils --- @@ -630,6 +621,13 @@ def upd_node_attrs(swc_dict, leafs, branchings, i): return leafs, branchings +def compute_path_length(self, graph): + path_length = 0 + for i, j in nx.dfs_edges(graph): + path_length += compute_dist(graph, i, j) + return path_length + + def compute_dist(graph, i, j): """ Computes Euclidean distance between nodes i and j.