Skip to content

Commit

Permalink
improved target edges
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Oct 15, 2023
1 parent d20db2b commit eabae95
Show file tree
Hide file tree
Showing 8 changed files with 312 additions and 185 deletions.
100 changes: 100 additions & 0 deletions src/deep_neurographs/densegraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os
import networkx as nx
import numpy as np
from deep_neurographs import swc_utils, utils
from deep_neurographs.geometry_utils import dist, make_line
from more_itertools import zip_broadcast
from scipy.spatial import KDTree


class DenseGraph:

def __init__(self, swc_dir):
self.xyz_to_node = dict()
self.xyz_to_swc = dict()
self.init_graphs(swc_dir)
self.init_kdtree()

def init_graphs(self, swc_dir):
self.graphs = dict()
for f in utils.listdir(swc_dir, ext=".swc"):
# Extract info
path = os.path.join(swc_dir, f)

# Construct Graph
swc_dict = swc_utils.parse(swc_utils.read_swc(path))
graph, xyz_to_node = swc_utils.file_to_graph(
swc_dict, set_attrs=True, return_dict=True,
)

# Store
xyz_to_id = dict(zip_broadcast(swc_dict["xyz"], f))
self.graphs[f] = graph
self.xyz_to_node[f] = xyz_to_node
self.xyz_to_swc.update(xyz_to_id)

def init_kdtree(self):
self.kdtree = KDTree(list(self.xyz_to_swc.keys()))

def get_projection(self, xyz):
_, idx = self.kdtree.query(xyz, k=1)
proj_xyz = tuple(self.kdtree.data[idx])
proj_dist = dist(proj_xyz, xyz)
return proj_xyz, proj_dist

def connect_nodes(self, graph_id, xyz_i, xyz_j, return_dist=True):
i = self.xyz_to_node[graph_id][xyz_i]
j = self.xyz_to_node[graph_id][xyz_j]
path = nx.shortest_path(self.graphs[graph_id], source=i, target=j)
if return_dist:
dist = self.compute_dist(graph_id, path)
return path, dist
else:
return path

def compute_dist(self, graph_id, path):
d = 0
for i in range(1, len(path)):
xyz_1 = self.graphs[graph_id].nodes[i]["xyz"]
xyz_2 = self.graphs[graph_id].nodes[i-1]["xyz"]
d += dist(xyz_1, xyz_2)
return d

def check_aligned(self, pred_xyz_i, pred_xyz_j):
# Get target graph
xyz_i, _ = self.get_projection(pred_xyz_i)
xyz_j, _ = self.get_projection(pred_xyz_j)
graph_id = self.xyz_to_swc[xyz_i]
if self.xyz_to_swc[xyz_i] != self.xyz_to_swc[xyz_j]:
return False

# Compare pred and target distances
pred_xyz_i = np.array(pred_xyz_i)
pred_xyz_j = np.array(pred_xyz_j)
pred_dist = dist(pred_xyz_i, pred_xyz_j)

target_path, target_dist = self.connect_nodes(graph_id, xyz_i, xyz_j)
target_dist = max(target_dist, 1)

ratio = min(pred_dist, target_dist) / max(pred_dist, target_dist)
if ratio < 0.7 and pred_dist > 25:
return False
elif ratio < 0.25:
return False

# Compare projected predicted path
proj_dists = []
proj_nodes = set()
for xyz in make_line(pred_xyz_i, pred_xyz_j, len(target_path)):
proj_xyz, proj_d = self.get_projection(xyz)
swc = self.xyz_to_swc[tuple(proj_xyz)]
proj_nodes.add(self.xyz_to_node[swc][tuple(proj_xyz)])
proj_dists.append(proj_d)

intersection = proj_nodes.intersection(set(target_path))
overlap = len(intersection) / len(target_path)
if overlap < 0.4 and pred_dist > 25:
return False
elif overlap < 0.2:
return False
return True
7 changes: 4 additions & 3 deletions src/deep_neurographs/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ def generate_mutable_skel_features(neurograph):
for edge in neurograph.mutable_edges:
length = compute_length(neurograph, edge)
radius_i, radius_j = get_radii(neurograph, edge)

dot1, dot2, dot3 = get_directionals(neurograph, edge, 5)
ddot1, ddot2, ddot3 = get_directionals(neurograph, edge, 5)
features[edge] = np.concatenate((length, radius_i, radius_j, dot1, dot2, dot3), axis=None)
ddot1, ddot2, ddot3 = get_directionals(neurograph, edge, 10)
features[edge] = np.concatenate(
(length, radius_i, radius_j, dot1, dot2, dot3, ddot1, ddot2, ddot3), axis=None
)
return features


Expand Down
57 changes: 31 additions & 26 deletions src/deep_neurographs/geometry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,31 +89,32 @@ def compute_tangent(xyz):

# Smoothing
def smooth_branch(xyz):
t = np.arange(len(xyz[:, 0]) + 12)
s = len(t) / 10
cs_x = UnivariateSpline(t, extend_boundary(xyz[:, 0]), s=s, k=3)
cs_y = UnivariateSpline(t, extend_boundary(xyz[:, 1]), s=s, k=3)
cs_z = UnivariateSpline(t, extend_boundary(xyz[:, 2]), s=s, k=3)
smoothed_x = trim_boundary(cs_x(t))
smoothed_y = trim_boundary(cs_y(t))
smoothed_z = trim_boundary(cs_z(t))
smoothed = np.column_stack((smoothed_x, smoothed_y, smoothed_z))
return smoothed


def extend_boundary(x, num_boundary_points=6):
extended_x = np.concatenate(
(
np.linspace(x[0], x[1], num_boundary_points, endpoint=False),
x,
np.linspace(x[-2], x[-1], num_boundary_points, endpoint=False),
)
)
return extended_x


def trim_boundary(x, num_boundary_points=6):
return x[num_boundary_points:-num_boundary_points]
if xyz.shape[0] > 5:
spl_x, spl_y, spl_z = fit_spline(xyz)
t = np.arange(xyz.shape[0])
xyz = np.column_stack((spl_x(t), spl_y(t), spl_z(t)))
return xyz


def fit_spline(xyz):
s = xyz.shape[0] / 10
t = np.arange(xyz.shape[0])
cs_x = UnivariateSpline(t, xyz[:, 0], s=s, k=3)
cs_y = UnivariateSpline(t, xyz[:, 1], s=s, k=3)
cs_z = UnivariateSpline(t, xyz[:, 2], s=s, k=3)
return cs_x, cs_y, cs_z


def smooth_end(branch_xyz, radii, ref_xyz, num_pts=8):
smooth_bool = branch_xyz.shape[0] > 10
if all(branch_xyz[0] == ref_xyz) and smooth_bool:
return branch_xyz[num_pts:-1, :], radii[num_pts:-1], 0
elif all(branch_xyz[-1] == ref_xyz) and smooth_bool:
branch_xyz = branch_xyz[:-num_pts]
radii = radii[:-num_pts]
return branch_xyz, radii, -1
else:
return branch_xyz, radii, None


# Miscellaneous
Expand All @@ -138,4 +139,8 @@ def dist(x, y, metric="l2"):
if metric == "l1":
return np.linalg.norm(np.subtract(x, y), ord=1)
else:
return np.linalg.norm(np.subtract(x, y), ord=2)
return np.linalg.norm(np.subtract(x, y), ord=2)

def make_line(xyz_1, xyz_2, num_steps):
t_steps = np.linspace(0, 1, num_steps)
return [(1 - t) * xyz_1 + t * xyz_2 for t in t_steps]
15 changes: 5 additions & 10 deletions src/deep_neurographs/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,6 @@
from deep_neurographs import swc_utils, utils


def init_dense_graphs(swc_dir):
dense_graphs = dict()
for f in utils.listdir(swc_dir, ext=".swc"):
raw_txt = swc_utils.read_swc(os.path.join(swc_dir, f))
swc_dict = swc_utils.parse(raw_txt)
graph_id = f.replace(".0.swc", "")
graph = swc_utils.file_to_graph(swc_dict, graph_id=graph_id, set_attrs=True)
dense_graphs[graph_id] = graph
return dense_graphs

def get_irreducibles(graph):
leafs = []
junctions = []
Expand Down Expand Up @@ -157,3 +147,8 @@ def _init_edge(swc_dict=None, node=None):
def get_edge_attr(graph, edge, attr):
edge_data = graph.get_edge_data(*edge)
return edge_data[attr]

def is_leaf(graph, i):
nbs = [j for j in graph.neighbors(i)]
return True if len(nbs) == 1 else False

2 changes: 1 addition & 1 deletion src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def init_immutables_from_local(
anisotropy=[1.0, 1.0, 1.0],
prune=True,
prune_depth=16,
smooth=False,
smooth=True,
):
"""
To do...
Expand Down
Loading

0 comments on commit eabae95

Please sign in to comment.