diff --git a/src/deep_neurographs/densegraph.py b/src/deep_neurographs/densegraph.py index 0231cbe..fc92672 100644 --- a/src/deep_neurographs/densegraph.py +++ b/src/deep_neurographs/densegraph.py @@ -67,9 +67,7 @@ def init_graphs(self, swc_dir): # Construct Graph path = os.path.join(swc_dir, f) swc_dict = swc_utils.parse_local_swc(path) - graph, xyz_to_node = swc_utils.file_to_graph( - swc_dict, set_attrs=True, return_dict=True - ) + graph, xyz_to_node = swc_utils.to_graph(swc_dict, set_attrs=True) # Store xyz_to_id = dict(zip_broadcast(swc_dict["xyz"], f)) diff --git a/src/deep_neurographs/graph_utils.py b/src/deep_neurographs/graph_utils.py index 3aa79cb..b9f3247 100644 --- a/src/deep_neurographs/graph_utils.py +++ b/src/deep_neurographs/graph_utils.py @@ -5,143 +5,175 @@ @email: anna.grim@alleninstitute.org -Routines for working with graphs. +Routines that extract the irreducible components of a graph. """ +from copy import deepcopy +from random import sample + import networkx as nx +import numpy as np + +from deep_neurographs import geometry_utils, swc_utils, utils + + +def get_irreducibles(swc_dict, prune=True, depth=16, smooth=True): + """ + Gets irreducible components of the graph stored in "swc_dict". The + irreducible components consist of the leaf and junction nodes along with + the edges among this set of nodes. + + Parameters + ---------- + swc_dict : dict + Contents of an swc file. + prune : True + Indication of whether to prune short branches. + depth : int + Path length that determines whether a branch is short. + smooth : bool + Indication of whether to smooth each branch. + + Returns + ------- + leafs : set + Nodes with degreee 1. + junctions : set + Nodes with degree > 2. + edges : dict + Set of edges connecting nodes in leafs and junctions. The keys are + pairs of nodes connected by an edge and values are a dictionary of + attributes. + + """ + # Initializations + dense_graph = swc_utils.to_graph(swc_dict) + if prune: + dense_graph = prune_short_branches(dense_graph, depth) + + # Extract irreducibles + leafs, junctions = get_irreducible_nodes(dense_graph) + source = sample(leafs, 1)[0] + root = None + edges = dict() + nbs = dict() + for (i, j) in nx.dfs_edges(dense_graph, source=source): + # Check if start of path is valid + if root is None: + root = i + attrs = __init_edge_attrs(swc_dict, i) + + # Visit j + attrs = __upd_edge_attrs(swc_dict, attrs, j) + if j in leafs or j in junctions: + if smooth: + swc_dict, edges = __smooth_branch( + swc_dict, attrs, edges, nbs, root, j + ) + else: + edges[(root, j)] = attrs + nbs = append_value(nbs, root, j) + nbs = append_value(nbs, j, root) + root = None + return leafs, junctions, edges -from deep_neurographs import swc_utils, utils +def get_irreducible_nodes(graph): + """ + Gets irreducible nodes (i.e. leafs and junctions) of a graph. -def extract_irreducible_graph(swc_dict, prune=True, prune_depth=16): - graph = swc_utils.to_graph(swc_dict) - leafs, junctions = get_irreducibles(graph) - irreducible_edges, leafs = extract_irreducible_edges( - graph, leafs, junctions, swc_dict, prune=prune, prune_depth=prune_depth - ) - if prune: - irreducible_edges, junctions = check_irreducibility( - junctions, irreducible_edges - ) - return leafs, junctions, irreducible_edges + Parameters + ---------- + graph : networkx.Graph + Graph to be searched. + Returns + ------- + leafs : set + Nodes with degreee 1. + junctions : set + Nodes with degree > 2. -def get_irreducibles(graph): - leafs = [] - junctions = [] + """ + leafs = set() + junctions = set() for i in graph.nodes: if graph.degree[i] == 1: - leafs.append(i) + leafs.add(i) elif graph.degree[i] > 2: - junctions.append(i) + junctions.add(i) return leafs, junctions -def extract_irreducible_edges( - graph, leafs, junctions, swc_dict, prune=True, prune_depth=16 -): - root = None - irreducible_edges = dict() - for (i, j) in nx.dfs_edges(graph, source=leafs[0]): - # Check start of path is valid - if root is None: - root = i - edge = _init_edge(swc_dict=swc_dict, node=i) - path_length = 0 +def prune_short_branches(graph, depth): + remove_nodes = [] + for leaf in get_leafs(graph): + remove_nodes.extend(inspect_branch(graph, leaf, depth)) + graph.remove_nodes_from(remove_nodes) + return graph - # Add to path - edge["radius"].append(swc_dict["radius"][j]) - edge["xyz"].append(swc_dict["xyz"][j]) - path_length += 1 - # Check whether to end path - if j in leafs or j in junctions: - if prune and path_length <= prune_depth: - condition1 = j in leafs and root in junctions - condition2 = root in leafs and j in junctions - if condition1 or condition2: - leafs.remove(j if condition1 else root) - else: - irreducible_edges[(root, j)] = edge - else: - irreducible_edges[(root, j)] = edge - root = None - return irreducible_edges, leafs - - -def check_irreducibility(junctions, irreducible_edges): - graph = nx.Graph() - graph.add_edges_from(irreducible_edges.keys()) - nx.set_edge_attributes(graph, irreducible_edges) - for j in junctions: - if j not in graph.nodes: - junctions.remove(j) - elif graph.degree[j] == 2: - # Get join edges - nbs = list(graph.neighbors(j)) - edge1 = graph.get_edge_data(j, nbs[0]) - edge2 = graph.get_edge_data(j, nbs[1]) - edge = join_edges(edge1, edge2) - - # Update irreducible edges - junctions.remove(j) - irreducible_edges = utils.remove_key( - irreducible_edges, (j, nbs[0]) - ) - irreducible_edges = utils.remove_key( - irreducible_edges, (j, nbs[1]) - ) - irreducible_edges[tuple(nbs)] = edge +def inspect_branch(graph, leaf, depth): + path = [leaf] + for (i, j) in nx.dfs_edges(graph, source=leaf, depth_limit=depth): + if graph.degree(j) > 2: + return path + elif graph.degree(j) == 2: + path.append(j) + return [] + - graph.remove_edge(j, nbs[0]) - graph.remove_edge(j, nbs[1]) - graph.remove_node(j) - graph.add_edge(*tuple(nbs), xyz=edge["xyz"], radius=edge["radius"]) - if graph.degree[nbs[0]] > 2: - junctions.append(nbs[0]) +def get_leafs(graph): + return [i for i in graph.nodes if graph.degree[i] == 1] - if graph.degree[nbs[1]] > 2: - junctions.append(nbs[1]) - return irreducible_edges, junctions +def __smooth_branch(swc_dict, attrs, edges, nbs, root, j): + attrs["xyz"] = geometry_utils.smooth_branch(np.array(attrs["xyz"])) + swc_dict, edges = upd_xyz(swc_dict, attrs, edges, nbs, root, 0) + swc_dict, edges = upd_xyz(swc_dict, attrs, edges, nbs, j, -1) + edges[(root, j)] = attrs + return swc_dict, edges -def join_edges(edge1, edge2): - # Last point in edge1 must connect to first point in edge2 - if edge1["xyz"][0] == edge2["xyz"][0]: - edge1 = reverse_edge(edge1) - elif edge1["xyz"][-1] == edge2["xyz"][-1]: - edge2 = reverse_edge(edge2) - elif edge1["xyz"][0] == edge2["xyz"][-1]: - edge1 = reverse_edge(edge1) - edge2 = reverse_edge(edge2) - edge = { - "xyz": edge1["xyz"] + edge2["xyz"][1:], - "radius": edge1["radius"] + edge2["radius"], - } - return edge +def upd_xyz(swc_dict, attrs, edges, nbs, i, start_or_end): + if i in nbs.keys(): + for j in nbs[i]: + key = (i, j) if (i, j) in edges.keys() else (j, i) + edges = upd_branch_endpoint( + edges, key, swc_dict["xyz"][i], attrs["xyz"][start_or_end] + ) + swc_dict["xyz"][i] = attrs["xyz"][start_or_end] + return swc_dict, edges -def reverse_edge(edge): - edge["xyz"].reverse() - edge["radius"].reverse() - return edge +def append_value(my_dict, key, value): + if key in my_dict.keys(): + my_dict[key].append(value) + else: + my_dict[key] = [value] + return my_dict -def _init_edge(swc_dict=None, node=None): - edge = {"radius": [], "xyz": []} - if node is not None: - edge["radius"].append(swc_dict["radius"][node]) - edge["xyz"].append(swc_dict["xyz"][node]) - return edge +def upd_branch_endpoint(edges, key, old_xyz, new_xyz): + if all(edges[key]["xyz"][0] == old_xyz): + edges[key]["xyz"][0] = new_xyz + else: + edges[key]["xyz"][-1] = new_xyz + return edges + + +# -- attribute utils -- +def __init_edge_attrs(swc_dict, i): + return {"radius": [swc_dict["radius"][i]], "xyz": [swc_dict["xyz"][i]]} + + +def __upd_edge_attrs(swc_dict, attrs, i): + attrs["radius"].append(swc_dict["radius"][i]) + attrs["xyz"].append(swc_dict["xyz"][i]) + return attrs def get_edge_attr(graph, edge, attr): edge_data = graph.get_edge_data(*edge) return edge_data[attr] - - -def is_leaf(graph, i): - nbs = [j for j in graph.neighbors(i)] - return True if len(nbs) == 1 else False diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index 032b1cc..5f98381 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -17,7 +17,8 @@ from google.cloud import storage -from deep_neurographs import graph_utils as gutils, swc_utils, utils +from deep_neurographs import graph_utils as gutils +from deep_neurographs import swc_utils, utils from deep_neurographs.neurograph import NeuroGraph from deep_neurographs.swc_utils import parse_gcs_zip @@ -56,6 +57,7 @@ def build_neurograph_from_local( origin=img_patch_origin, shape=img_patch_shape, ) + print("Build Graph...") neurograph = init_immutables_from_local( neurograph, swc_dir=swc_dir, @@ -65,11 +67,13 @@ def build_neurograph_from_local( min_size=min_size, smooth=smooth, ) + t0 = time() if search_radius > 0: neurograph.generate_proposals( n_proposals_per_leaf=n_proposals_per_leaf, search_radius=search_radius, ) + print(f" generate_proposals(): {time() - t0} seconds") return neurograph @@ -113,7 +117,7 @@ def build_neurograph_from_gcs_zips( signal. optimize_depth : int Distance from each edge proposal end point that is search during - alignment optimization. + alignment optimization. smooth : bool Indication of whether to smooth branches from swc files. @@ -150,11 +154,17 @@ def init_immutables_from_local( min_size=MIN_SIZE, smooth=SMOOTH, ): + neurograph.extraction = 0 + neurograph.add_edges_timer = 0 swc_paths = get_paths(swc_dir) if swc_dir else swc_paths for path in swc_paths: neurograph.ingest_swc_from_local( path, prune=True, prune_depth=16, smooth=smooth ) + print( + f" extract_irreducible_graph(): {neurograph.extraction} seconds" + ) + print(f" add_edges(): {neurograph.add_edges_timer} seconds") return neurograph @@ -165,11 +175,7 @@ def get_paths(swc_dir): return paths -def download_gcs_zips( - bucket_name, - cloud_path, - min_size=0, -): +def download_gcs_zips(bucket_name, cloud_path, min_size=0): """ Downloads swc files from zips stored in a GCS bucket. @@ -192,7 +198,7 @@ def download_gcs_zips( bucket = storage_client.bucket(bucket_name) zip_paths = list_gcs_filenames(bucket, cloud_path, ".zip") chunk_size = int(len(zip_paths) * 0.1) - print(f"# zip files: {len(zip_paths)} \n\n", ) + print(f"# zip files: {len(zip_paths)} \n\n") # Parse cnt = 1 @@ -257,7 +263,7 @@ def list_gcs_filenames(bucket, cloud_path, extension): def report_runtimes( - n_files, n_files_completed, chunk_size, start, start_chunk, + n_files, n_files_completed, chunk_size, start, start_chunk ): runtime = time() - start chunk_runtime = time() - start_chunk @@ -281,7 +287,7 @@ def build_neurograph( optimize_depth=OPTIMIZE_DEPTH, prune=PRUNE, prune_depth=PRUNE_DEPTH, - smooth=SMOOTH + smooth=SMOOTH, ): graph_list = build_graphs(swc_dicts, prune, prune_depth, smooth) start_ids = get_start_ids(swc_dicts) @@ -291,7 +297,8 @@ def build_neurograph( img_path=img_path, optimize_alignment=optimize_alignment, optimize_depth=optimize_depth, - ) + ) + def build_graphs(swc_dicts, prune, prune_depth, smooth): t0 = time() @@ -307,7 +314,7 @@ def build_subgraph(swc_dict): graph = nx.Graph() graph.add_edges_from(zip(swc_dict["id"][1:], swc_dict["pid"][1:])) return graph - + def get_start_ids(swc_dicts): # runtime: ~ 1 minute diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index 0010a7d..0cb9b3a 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -9,6 +9,7 @@ """ from copy import deepcopy +from time import time import networkx as nx import numpy as np @@ -104,17 +105,21 @@ def init_densegraph(self): def ingest_swc_from_local( self, path, prune=True, prune_depth=16, smooth=True ): + # Parse swc swc_id = utils.get_id(path) - swc_dict = swc_utils.parse_local_swc( - path, bbox=self.bbox, img_shape=self.shape - ) + swc_dict = swc_utils.parse_local_swc(path, bbox=self.bbox) if len(swc_dict["xyz"]) > self.size_threshold: - swc_dict = swc_utils.smooth(swc_dict) if smooth else swc_dict - self.add_immutables( - swc_id, swc_dict, prune=prune, prune_depth=prune_depth - ) + return None + + # Build neurograph + t0 = time() + leafs, junctions, edges = gutils.get_irreducibles( + swc_dict, prune=prune, depth=prune_depth, smooth=smooth + ) + self.extraction += time() - t0 + self.add_immutables(swc_id, swc_dict, leafs, junctions, edges) - def add_immutables(self, swc_id, swc_dict, prune=True, prune_depth=16): + def add_immutables(self, swc_id, swc_dict, leafs, junctions, edges): """ Adds nodes to graph from a dictionary generated from an swc files. @@ -132,11 +137,9 @@ def add_immutables(self, swc_id, swc_dict, prune=True, prune_depth=16): """ # Add nodes - leafs, junctions, edges = gutils.extract_irreducible_graph( - swc_dict, prune=prune, prune_depth=prune_depth - ) + node_id = dict() - for i in leafs + junctions: + for i in list(leafs) + list(junctions): node_id[i] = len(self.nodes) self.add_node( node_id[i], @@ -146,6 +149,7 @@ def add_immutables(self, swc_id, swc_dict, prune=True, prune_depth=16): ) # Add edges + t0 = time() for i, j in edges.keys(): # Get edge edge = (node_id[i], node_id[j]) @@ -164,6 +168,7 @@ def add_immutables(self, swc_id, swc_dict, prune=True, prune_depth=16): for xyz in collisions: del xyz_to_edge[xyz] self.xyz_to_edge.update(xyz_to_edge) + self.add_edges_timer += time() - t0 # Update leafs and junctions for l in leafs: diff --git a/src/deep_neurographs/swc_utils.py b/src/deep_neurographs/swc_utils.py index 20c39de..9dcf7c1 100644 --- a/src/deep_neurographs/swc_utils.py +++ b/src/deep_neurographs/swc_utils.py @@ -11,10 +11,10 @@ import os from copy import deepcopy as cp +from itertools import repeat import networkx as nx import numpy as np -from more_itertools import zip_broadcast from deep_neurographs import geometry_utils from deep_neurographs import graph_utils as gutils @@ -22,23 +22,18 @@ # -- io utils -- -def parse_local_swc(path, bbox=None, img_shape=None, min_size=0): +def parse_local_swc(path, bbox=None, min_size=0): swc_contents = read_from_local(path) - if len(swc_contents) > min_size: - return parse(swc_contents, bbox=bbox, img_shape=img_shape) - else: - return [] + parse_bool = len(swc_contents) > min_size + return parse(swc_contents, bbox=bbox) if parse_bool else [] -def parse_gcs_zip(zip_file, path, bbox=None, img_shape=None, min_size=0): +def parse_gcs_zip(zip_file, path, min_size=0): swc_contents = read_from_gcs_zip(zip_file, path) - if len(swc_contents) > min_size: - return parse(swc_contents, bbox=bbox, img_shape=img_shape) - else: - return [] + return parse(swc_contents) if len(swc_contents) > min_size else [] -def parse(swc_contents, bbox=None, img_shape=None): +def parse(swc_contents, bbox=None): """ Parses an swc file to extract the contents which is stored in a dict. Note that node_ids from swc are refactored to index from 0 to n-1 where n is @@ -65,8 +60,8 @@ def parse(swc_contents, bbox=None, img_shape=None): if not line.startswith("#") and len(line) > 0: parts = line.split() xyz = read_xyz(parts[2:5], offset=offset) - if bbox and img_shape: - if not utils.is_contained(bbox, img_shape, xyz): + if bbox: + if not utils.is_contained(bbox, xyz): break swc_dict["id"].append(int(parts[0])) @@ -131,7 +126,7 @@ def read_xyz(xyz, offset=[0, 0, 0]): return tuple([float(xyz[i]) + offset[i] for i in range(3)]) -def write_swc(path, contents): +def write(path, contents): if type(contents) is list: write_list(path, contents) elif type(contents) is dict: @@ -246,19 +241,22 @@ def make_entry(graph, i, parent, r, reindex): # -- Conversions -- -def to_graph(swc_dict, graph_id=None, set_attrs=False, return_dict=False): +def to_graph(swc_dict, graph_id=None, set_attrs=False): graph = nx.Graph(graph_id=graph_id) graph.add_edges_from(zip(swc_dict["id"][1:], swc_dict["pid"][1:])) - xyz_to_node = dict() if set_attrs: - for i in graph.nodes: - graph.nodes[i]["xyz"] = swc_dict["xyz"][i] - graph.nodes[i]["radius"] = swc_dict["radius"][i] - xyz_to_node[swc_dict["xyz"][i]] = i - if return_dict: + graph = __add_attributes(swc_dict, graph) + xyz_to_node = dict(zip(swc_dict["xyz"], swc_dict["id"])) return graph, xyz_to_node - else: - return graph + return graph + + +def __add_attributes(swc_dict, graph): + xyz = swc_dict["xyz"] + radii = swc_dict["radius"] + attrs = [{"xyz": xyz[i], "radius": radii[i]} for i in graph.nodes] + nx.set_node_attributes(graph, dict(zip(swc_dict["id"], attrs))) + return graph # -- miscellaneous -- @@ -266,7 +264,7 @@ def smooth(swc_dict): if len(swc_dict["xyz"]) > 10: xyz = np.array(swc_dict["xyz"]) graph = to_graph(swc_dict) - leafs, junctions = gutils.get_irreducibles(graph) + leafs, junctions = gutils.get_irreducible_nodes(graph) if len(junctions) == 0: xyz = geometry_utils.smooth_branch(xyz) else: @@ -292,28 +290,3 @@ def upd_edge(xyz, idxs): idxs = np.array(idxs) xyz[idxs] = geometry_utils.smooth_branch(xyz[idxs], s=10) return xyz - - -def upd_dict(upd, path, radius, pid): - for k in range(len(path)): - next_id = len(upd["id"]) - upd["id"].append(next_id) - upd["xyz"].append(path[k]) - upd["radius"].append(radius[k]) - if len(upd["pid"]) == 0: - upd["pid"].extend([1, 0]) - elif k == 0: - upd["pid"].append(pid) - else: - upd["pid"].append(next_id - 1) - return upd, next_id - - -def generate_coords(center, r): - xyz = [] - for x in range(-r, r + 1): - for y in range(-r, r + 1): - for z in range(-r, r + 1): - if abs(x) + abs(y) + abs(z) <= r: - xyz.append((center[0] + x, center[1] + y, center[2] + z)) - return xyz diff --git a/src/deep_neurographs/utils.py b/src/deep_neurographs/utils.py index 8ca1eea..e10be55 100644 --- a/src/deep_neurographs/utils.py +++ b/src/deep_neurographs/utils.py @@ -411,11 +411,12 @@ def get_avg_std(data, weights=None): return avg, math.sqrt(var) -def is_contained(bbox, img_shape, xyz): +def is_contained(bbox, xyz): xyz = apply_anisotropy(xyz - bbox["min"]) + dims = bbox["max"] - bbox["min"] for i in range(3): lower_bool = xyz[i] < 0 - upper_bool = xyz[i] >= img_shape[i] + upper_bool = xyz[i] >= dims[i] if lower_bool or upper_bool: return False return True