Skip to content

Commit

Permalink
training updates
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Oct 22, 2023
1 parent eb0500b commit 5b62912
Show file tree
Hide file tree
Showing 11 changed files with 184 additions and 360 deletions.
17 changes: 9 additions & 8 deletions src/deep_neurographs/densegraph.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import os

import networkx as nx
import numpy as np
from deep_neurographs import swc_utils, utils
from deep_neurographs.geometry_utils import dist, make_line
from more_itertools import zip_broadcast
from scipy.spatial import KDTree

from deep_neurographs import swc_utils, utils
from deep_neurographs.geometry_utils import dist, make_line


class DenseGraph:

def __init__(self, swc_dir):
self.xyz_to_node = dict()
self.xyz_to_swc = dict()
Expand All @@ -24,7 +25,7 @@ def init_graphs(self, swc_dir):
# Construct Graph
swc_dict = swc_utils.parse(swc_utils.read_swc(path))
graph, xyz_to_node = swc_utils.file_to_graph(
swc_dict, set_attrs=True, return_dict=True,
swc_dict, set_attrs=True, return_dict=True
)

# Store
Expand All @@ -41,7 +42,7 @@ def get_projection(self, xyz):
proj_xyz = tuple(self.kdtree.data[idx])
proj_dist = dist(proj_xyz, xyz)
return proj_xyz, proj_dist

def connect_nodes(self, graph_id, xyz_i, xyz_j, return_dist=True):
i = self.xyz_to_node[graph_id][xyz_i]
j = self.xyz_to_node[graph_id][xyz_j]
Expand All @@ -56,7 +57,7 @@ def compute_dist(self, graph_id, path):
d = 0
for i in range(1, len(path)):
xyz_1 = self.graphs[graph_id].nodes[i]["xyz"]
xyz_2 = self.graphs[graph_id].nodes[i-1]["xyz"]
xyz_2 = self.graphs[graph_id].nodes[i - 1]["xyz"]
d += dist(xyz_1, xyz_2)
return d

Expand All @@ -77,7 +78,7 @@ def check_aligned(self, pred_xyz_i, pred_xyz_j):
target_dist = max(target_dist, 1)

ratio = min(pred_dist, target_dist) / max(pred_dist, target_dist)
if ratio < 0.7 and pred_dist > 25:
if ratio < 0.6 and pred_dist > 25:
return False
elif ratio < 0.25:
return False
Expand All @@ -93,7 +94,7 @@ def check_aligned(self, pred_xyz_i, pred_xyz_j):

intersection = proj_nodes.intersection(set(target_path))
overlap = len(intersection) / len(target_path)
if overlap < 0.4 and pred_dist > 25:
if overlap < 0.5 and pred_dist > 25:
return False
elif overlap < 0.2:
return False
Expand Down
116 changes: 68 additions & 48 deletions src/deep_neurographs/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,74 +12,87 @@
from random import sample

import numpy as np
from scipy.linalg import svd

from deep_neurographs import geometry_utils, utils

NUM_IMG_FEATURES = 0
NUM_SKEL_FEATURES = 9
NUM_PC_FEATURES = 0
NUM_POINTS = 10
WINDOW_SIZE = [6, 6, 6]

NUM_IMG_FEATURES = NUM_POINTS
NUM_SKEL_FEATURES = 11


# -- Wrappers --
def generate_mutable_features(
neurograph, img=True, pointcloud=True, skel=True
neurograph, anisotropy=[1.0, 1.0, 1.0], img_path=None
):
features = dict()
if img:
features["img"] = generate_img_features(neurograph)
if skel:
features["skel"] = generate_mutable_skel_features(neurograph)
features = combine_feature_vecs(features)
return features


# -- Node feature extraction --
def generate_img_features(neurograph):
img_features = np.zeros((neurograph.num_nodes(), NUM_IMG_FEATURES))
for node in neurograph.nodes:
img_features[node] = _generate_node_img_features()
return img_features


def _generate_node_img_features():
pass


def generate_skel_features(neurograph):
skel_features = np.zeros((neurograph.num_nodes(), NUM_SKEL_FEATURES))
for node in neurograph.nodes:
skel_features[node] = _generate_node_skel_features(neurograph, node)
return skel_features


def _generate_node_skel_features(neurograph, node):
radius = neurograph.nodes[node]["radius"]
xyz = neurograph.nodes[node]["xyz"]
return np.append(xyz, radius)


def generate_pointcloud_features(neurograph):
pc_features = np.zeros((neurograph.num_nodes(), NUM_PC_FEATURES))
for node in neurograph.nodes:
pc_features[node] = _generate_pointcloud_node_features()
return pc_features
"""
Generates feature vectors for every mutable edge in a neurograph.
Parameters
----------
neurograph : NeuroGraph
NeuroGraph generated from a directory of swcs generated from a
predicted segmentation.
anisotropy : list[float]
Real-world to image coordinates scaling factor for (x, y, z).
img_path : str, optional
Path to image volume.
Returns
-------
Dictionary where each key-value pair corresponds to a type of feature
vector and the numerical vector.
"""
features = {"skel": generate_mutable_skel_features(neurograph)}
if img_path is not None:
features["img"] = generate_mutable_img_features(
neurograph, img_path, anisotropy=anisotropy
)
return combine_feature_vecs(features)


def _generate_pointcloud_node_features():
pass
# -- Edge feature extraction --
def generate_mutable_img_features(
neurograph, path, anisotropy=[1.0, 1.0, 1.0]
):
img = utils.open_zarr(path)
features = dict()
for edge in neurograph.mutable_edges:
xyz = neurograph.edges[edge]["xyz"]
line = geometry_utils.make_line(xyz[0], xyz[1], NUM_POINTS)
features[edge] = geometry_utils.get_profile(
img, line, anisotropy=anisotropy, window_size=WINDOW_SIZE
)
return features


# -- Edge feature extraction --
def generate_mutable_skel_features(neurograph):
features = dict()
for edge in neurograph.mutable_edges:
i, j = tuple(edge)
deg_i = len(list(neurograph.neighbors(i)))
deg_j = len(list(neurograph.neighbors(j)))
length = compute_length(neurograph, edge)
radius_i, radius_j = get_radii(neurograph, edge)
dot1, dot2, dot3 = get_directionals(neurograph, edge, 5)
ddot1, ddot2, ddot3 = get_directionals(neurograph, edge, 10)
features[edge] = np.concatenate(
(length, radius_i, radius_j, dot1, dot2, dot3, ddot1, ddot2, ddot3), axis=None
(
length,
deg_i,
deg_j,
radius_i,
radius_j,
dot1,
dot2,
dot3,
ddot1,
ddot2,
ddot3,
),
axis=None,
)
return features

Expand Down Expand Up @@ -221,4 +234,11 @@ def generate_immutable_skel_features(neurograph):
def _generate_immutable_skel_features(neurograph, edge):
mean_radius = np.mean(neurograph.edges[edge]["radius"], axis=0)
return np.concatenate((mean_radius), axis=None)
def generate_skel_features(neurograph):
skel_features = np.zeros((neurograph.num_nodes(), NUM_SKEL_FEATURES))
for node in neurograph.nodes:
skel_features[node] = _generate_node_skel_features(neurograph, node)
return skel_features
"""
32 changes: 29 additions & 3 deletions src/deep_neurographs/geometry_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
from scipy.interpolate import CubicSpline, UnivariateSpline
from scipy.interpolate import UnivariateSpline
from scipy.linalg import svd

from deep_neurographs import utils


Expand Down Expand Up @@ -87,6 +88,12 @@ def compute_tangent(xyz):
return tangent / np.linalg.norm(tangent)


def compute_normal(xyz):
U, S, VT = compute_svd(xyz)
normal = VT[-1]
return normal / np.linalg.norm(normal)


# Smoothing
def smooth_branch(xyz):
if xyz.shape[0] > 5:
Expand Down Expand Up @@ -117,12 +124,30 @@ def smooth_end(branch_xyz, radii, ref_xyz, num_pts=8):
return branch_xyz, radii, None


# Image feature extraction
def get_profile(
img, xyz_arr, anisotropy=[1.0, 1.0, 1.0], window_size=[4, 4, 4]
):
xyz_arr = get_coords(xyz_arr, anisotropy=anisotropy)
profile = []
for xyz in xyz_arr:
img_chunk = utils.read_img_chunk(img, xyz, window_size)
profile.append(np.max(img_chunk))
return profile


def get_coords(xyz_arr, anisotropy=[1.0, 1.0, 1.0]):
for i in range(3):
xyz_arr[:, i] = xyz_arr[:, i] / anisotropy[i]
return xyz_arr.astype(int)


# Miscellaneous
def compare_edges(xyx_i, xyz_j, xyz_k):
dist_ij = dist(xyx_i, xyz_j)
dist_ik = dist(xyx_i, xyz_k)
return dist_ij < dist_ik


def dist(x, y, metric="l2"):
"""
Expand All @@ -141,6 +166,7 @@ def dist(x, y, metric="l2"):
else:
return np.linalg.norm(np.subtract(x, y), ord=2)


def make_line(xyz_1, xyz_2, num_steps):
t_steps = np.linspace(0, 1, num_steps)
return [(1 - t) * xyz_1 + t * xyz_2 for t in t_steps]
return np.array([(1 - t) * xyz_1 + t * xyz_2 for t in t_steps])
7 changes: 1 addition & 6 deletions src/deep_neurographs/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@
"""

from copy import deepcopy as cp

import os
import networkx as nx
import numpy as np

from deep_neurographs import swc_utils, utils

Expand All @@ -32,7 +28,6 @@ def get_irreducibles(graph):
def extract_irreducible_graph(swc_dict, prune=True, prune_depth=16):
graph = swc_utils.file_to_graph(swc_dict)
leafs, junctions = get_irreducibles(graph)
irreducible_nodes = set(leafs + junctions)
irreducible_edges, leafs = extract_irreducible_edges(
graph, leafs, junctions, swc_dict, prune=prune, prune_depth=prune_depth
)
Expand Down Expand Up @@ -148,7 +143,7 @@ def get_edge_attr(graph, edge, attr):
edge_data = graph.get_edge_data(*edge)
return edge_data[attr]


def is_leaf(graph, i):
nbs = [j for j in graph.neighbors(i)]
return True if len(nbs) == 1 else False

45 changes: 5 additions & 40 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@

import os

import numpy as np
import torch
from torch_geometric.data import Data

from deep_neurographs import neurograph as ng
from deep_neurographs import geometry_utils, s3_utils, swc_utils, utils
from deep_neurographs import s3_utils, swc_utils, utils


# --- Build graph ---
Expand Down Expand Up @@ -59,8 +58,8 @@ def build_neurograph(
prune_depth=prune_depth,
)
neurograph.generate_mutables(
max_degree=max_mutable_degree, max_dist=max_mutable_dist
)
max_degree=max_mutable_degree, max_dist=max_mutable_dist
)
return neurograph


Expand Down Expand Up @@ -132,48 +131,14 @@ def init_data(
x = torch.tensor(node_features, dtype=torch.float)
edge_index = torch.tensor(list(supergraph.edges()), dtype=torch.long)
edge_features = torch.tensor(edge_features, dtype=torch.float)
edge_label_index, mistake_log = get_target_edges(
supergraph,
edge_index.tolist(),
bucket,
file_key,
access_key_id=access_key_id,
secret_access_key=secret_access_key,
)
edge_label_index = None # target labels
data = Data(
x=x,
edge_index=edge_index.t().contiguous(),
edge_label_index=edge_label_index,
edge_attr=edge_features,
)
return data, mistake_log


def get_target_edges(
supergraph,
edges,
bucket,
file_key,
access_key_id=None,
secret_access_key=None,
):
"""
To do...
"""
s3_client = s3_utils.init_session(
access_key_id=access_key_id, secret_access_key=secret_access_key
)
hash_table = read_mistake_log(bucket, file_key, s3_client)
target_edges = torch.zeros((len(edges)))
cnt = 0
for i, e in enumerate(edges):
e1, e2 = get_old_edge(supergraph, e)
if utils.check_key(hash_table, e1) or utils.check_key(hash_table, e2):
target_edges[i] = 1
cnt += 1
print("Number of mistakes:", len(hash_table))
print("Number of hits:", cnt)
return torch.tensor(target_edges), hash_table
return data


def read_mistake_log(bucket, file_key, s3_client):
Expand Down
Loading

0 comments on commit 5b62912

Please sign in to comment.