Skip to content

Commit

Permalink
Fix local (#32)
Browse files Browse the repository at this point in the history
* bug: node attributes

* bug: fix smoothing issue

---------

Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Jan 17, 2024
1 parent 2db5e4f commit 43602c4
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 42 deletions.
2 changes: 1 addition & 1 deletion src/deep_neurographs/geometry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_directional(
elif branch.shape[0] <= d:
xyz = deepcopy(branch)
else:
xyz = deepcopy(branch[d: window + d, :])
xyz = deepcopy(branch[d : window + d, :])
directionals.append(compute_tangent(xyz))

# Determine best
Expand Down
32 changes: 24 additions & 8 deletions src/deep_neurographs/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,10 @@ def get_irreducibles(swc_dict, swc_id=None, prune=True, depth=16, smooth=True):
# Extract irreducibles
leafs, junctions = get_irreducible_nodes(dense_graph, swc_dict)
assert len(leafs) > 0, "Error: swc with no leaf nodes!"
source = sample(leafs.keys(), 1)[0]
root = None
edges = dict()
nbs = dict()
for (i, j) in nx.dfs_edges(dense_graph, source=source):
for (i, j) in nx.dfs_edges(dense_graph, source=sample(leafs, 1)[0]):
# Check if start of path is valid
if root is None:
root = i
Expand All @@ -90,6 +89,8 @@ def get_irreducibles(swc_dict, swc_id=None, prune=True, depth=16, smooth=True):
root = None

# Output
leafs = set_node_attrs(swc_dict, leafs)
junctions = set_node_attrs(swc_dict, junctions)
irreducibles = {"leafs": leafs, "junctions": junctions, "edges": edges}
return swc_id, irreducibles

Expand All @@ -111,13 +112,13 @@ def get_irreducible_nodes(graph, swc_dict):
Nodes with degree > 2.
"""
leafs = dict()
junctions = dict()
leafs = set()
junctions = set()
for i in graph.nodes:
if graph.degree[i] == 1:
leafs[i] = init_node_attrs(swc_dict, i)
leafs.add(i)
elif graph.degree[i] > 2:
junctions[i] = init_node_attrs(swc_dict, i)
junctions.add(i)
return leafs, junctions


Expand Down Expand Up @@ -194,7 +195,7 @@ def get_leafs(graph):


def __smooth_branch(swc_dict, attrs, edges, nbs, root, j):
attrs["xyz"] = geometry_utils.smooth_branch(np.array(attrs["xyz"]))
attrs["xyz"] = geometry_utils.smooth_branch(np.array(attrs["xyz"]), s=10)
swc_dict, edges = upd_xyz(swc_dict, attrs, edges, nbs, root, 0)
swc_dict, edges = upd_xyz(swc_dict, attrs, edges, nbs, j, -1)
edges[(root, j)] = attrs
Expand Down Expand Up @@ -230,7 +231,6 @@ def upd_branch_endpoint(edges, key, old_xyz, new_xyz):

# -- attribute utils --
def init_edge_attrs(swc_dict, i):
#print(len(swc_dict["radius"]), i)
return {"radius": [swc_dict["radius"][i]], "xyz": [swc_dict["xyz"][i]]}


Expand All @@ -247,3 +247,19 @@ def get_edge_attr(graph, edge, attr):

def init_node_attrs(swc_dict, i):
return {"radius": swc_dict["radius"][i], "xyz": swc_dict["xyz"][i]}


def set_node_attrs(swc_dict, nodes):
node_attrs = dict()
for i in nodes:
node_attrs[i] = init_node_attrs(swc_dict, i)
return node_attrs


def upd_node_attrs(swc_dict, leafs, junctions, i):
upd_attrs = {"radius": swc_dict["radius"][i], "xyz": swc_dict["xyz"][i]}
if i in leafs:
leafs[i] = upd_attrs
else:
junctions[i] = upd_attrs
return leafs, junctions
28 changes: 8 additions & 20 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
ProcessPoolExecutor,
ThreadPoolExecutor,
as_completed,
wait,
)
from time import time

Expand All @@ -22,7 +21,7 @@
from deep_neurographs import graph_utils as gutils
from deep_neurographs import utils
from deep_neurographs.neurograph import NeuroGraph
from deep_neurographs.swc_utils import process_local_paths, process_gsc_zip
from deep_neurographs.swc_utils import process_gsc_zip, process_local_paths

N_PROPOSALS_PER_LEAF = 3
OPTIMIZE_PROPOSALS = False
Expand Down Expand Up @@ -57,7 +56,6 @@ def build_neurograph_from_local(
swc_dicts = process_local_paths(paths, min_size, bbox=bbox)

# Build neurograph
t0 = time()
neurograph = build_neurograph(
swc_dicts,
bbox=bbox,
Expand All @@ -67,19 +65,15 @@ def build_neurograph_from_local(
prune_depth=prune_depth,
smooth=smooth,
)
print(f"build_neurograph(): {time() - t0} seconds")

# Generate proposals
t0 = time()
if search_radius > 0:
neurograph.generate_proposals(
search_radius,
n_proposals_per_leaf=n_proposals_per_leaf,
optimize=optimize_proposals,
optimization_depth=optimization_depth,
)
print(f"generate_proposals(): {time() - t0} seconds")

return neurograph


Expand Down Expand Up @@ -204,7 +198,9 @@ def download_gcs_zips(bucket_name, cloud_path, min_size):
for i, path in enumerate(zip_paths):
swc_dicts.update(process_gsc_zip(bucket, path, min_size=min_size))
if i > cnt * chunk_size:
cnt, t1 = report_progress(i, len(zip_paths), chunk_size, cnt, t0, t1)
cnt, t1 = report_progress(
i, len(zip_paths), chunk_size, cnt, t0, t1
)
return swc_dicts


Expand All @@ -231,10 +227,7 @@ def build_neurograph(
print("Extract irreducible nodes and edges...")
print("# connected components:", utils.reformat_number(n_components))
irreducibles, n_nodes, n_edges = get_irreducibles(
swc_dicts,
prune=prune,
prune_depth=prune_depth,
smooth=smooth,
swc_dicts, prune=prune, prune_depth=prune_depth, smooth=smooth
)

# Build neurograph
Expand All @@ -247,9 +240,7 @@ def build_neurograph(
cnt, i = 1, 0
while len(irreducibles):
key, irreducible_set = irreducibles.popitem()
neurograph.add_immutables(
irreducible_set, key
)
neurograph.add_immutables(irreducible_set, key)
if i > cnt * chunk_size:
cnt, t1 = report_progress(i, n_components, chunk_size, cnt, t0, t1)
print(f"add_irreducibles(): {time() - t0} seconds")
Expand All @@ -261,18 +252,15 @@ def build_neurograph(
futures = {
executor.submit(
neurograph.add_immutables, irreducibles[key], swc_dicts[key], key, start_ids[key]): key for key in swc_dicts.keys()
}
}
wait(futures)
print(f" --> asynchronous - add_irreducibles(): {time() - t0} seconds")
"""
return neurograph


def get_irreducibles(
swc_dicts,
prune=PRUNE,
prune_depth=PRUNE_DEPTH,
smooth=SMOOTH,
swc_dicts, prune=PRUNE, prune_depth=PRUNE_DEPTH, smooth=SMOOTH
):
n_components = len(swc_dicts)
chunk_size = max(int(n_components * 0.02), 1)
Expand Down
4 changes: 1 addition & 3 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
import networkx as nx
import numpy as np
import tensorstore as ts
from operator import itemgetter
from scipy.spatial import KDTree
from time import time

from deep_neurographs import geometry_utils
from deep_neurographs import graph_utils as gutils
Expand Down Expand Up @@ -113,7 +111,7 @@ def add_immutables(self, irreducibles, swc_id, start_id=None):
for xyz in collisions:
del xyz_to_edge[xyz]
self.xyz_to_edge.update(xyz_to_edge)

def __add_nodes(self, nodes, key, node_ids, cur_id, swc_id):
for i in nodes[key].keys():
node_ids[i] = cur_id
Expand Down
20 changes: 10 additions & 10 deletions src/deep_neurographs/swc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
"""

from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from zipfile import ZipFile

import networkx as nx
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
from time import time
from zipfile import ZipFile

from deep_neurographs import geometry_utils
from deep_neurographs import graph_utils as gutils
Expand Down Expand Up @@ -60,7 +60,9 @@ def parse_local_swc(path, bbox=None, min_size=0):

def parse_gcs_zip(zip_file, path, min_size=0):
contents = read_from_gcs_zip(zip_file, path)
swc_dict = fast_parse(contents) if len(contents) > min_size else {"id": [-1]}
swc_dict = (
fast_parse(contents) if len(contents) > min_size else {"id": [-1]}
)
return utils.get_swc_id(path), swc_dict


Expand All @@ -81,10 +83,9 @@ def parse(contents, bbox=None):
...
"""
contents, offset = get_contents(contents)
min_id = np.inf
offset = [0, 0, 0]
swc_dict = {"id": [], "radius": [], "pid": [], "xyz": []}
contents, swc_id = get_contents(contents)
for line in contents:
parts = line.split()
xyz = read_xyz(parts[2:5], offset=offset)
Expand Down Expand Up @@ -124,14 +125,13 @@ def fast_parse(contents):
"""
contents, offset = get_contents(contents)
min_id = np.inf
swc_dict = {
"id": np.zeros((len(contents)), dtype=int),
"radius": np.zeros((len(contents)), dtype=float),
"pid": np.zeros((len(contents)), dtype=int),
"xyz": []
"xyz": [],
}

min_id = np.inf
for i, line in enumerate(contents):
parts = line.split()
xyz = read_xyz(parts[2:5], offset=offset)
Expand Down Expand Up @@ -368,4 +368,4 @@ def smooth(swc_dict):
def upd_edge(xyz, idxs):
idxs = np.array(idxs)
xyz[idxs] = geometry_utils.smooth_branch(xyz[idxs], s=10)
return xyz
return xyz

0 comments on commit 43602c4

Please sign in to comment.