Skip to content

Commit

Permalink
added feature generation and training
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Oct 10, 2023
1 parent d9a6557 commit 4999172
Show file tree
Hide file tree
Showing 6 changed files with 575 additions and 124 deletions.
195 changes: 134 additions & 61 deletions src/deep_neurographs/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,51 +9,25 @@
"""

import numpy as np
from deep_neurographs import utils
from copy import deepcopy
from deep_neurographs import utils, geometry_utils
from random import sample
from scipy.linalg import svd

NUM_IMG_FEATURES = 0
NUM_SKEL_FEATURES = 4
NUM_SKEL_FEATURES = 9
NUM_PC_FEATURES = 0


# -- Wrappers --
def generate_node_features(neurograph, img=True, pointcloud=True, skel=True):
features = dict()
if img:
features["img"] = generate_img_features(neurograph)

if skel:
features["skel"] = generate_skel_features(neurograph)

if pointcloud:
features["pointcloud"] = generate_pointcloud_features(neurograph)
return extract_feature_vec(features)


def generate_immutable_features(neurograph, img=True, pointcloud=True, skel=True):
features = dict()
if img:
features["img"] = generate_img_features(neurograph)

if skel:
features["skel"] = generate_immutable_skel_features(neurograph)

if pointcloud:
features["pointcloud"] = generate_pointcloud_features(neurograph)
return extract_feature_vec(features)


def generate_mutable_features(neurograph, img=True, pointcloud=True, skel=True):
features = dict()
if img:
features["img"] = generate_img_features(neurograph)

if skel:
features["skel"] = generate_mutable_skel_features(neurograph)

if pointcloud:
features["pointcloud"] = generate_pointcloud_features(neurograph)
return extract_feature_vec(features)
features = combine_feature_vecs(features)
return features


# -- Node feature extraction --
Expand All @@ -71,7 +45,6 @@ def _generate_node_img_features():
def generate_skel_features(neurograph):
skel_features = np.zeros((neurograph.num_nodes(), NUM_SKEL_FEATURES))
for node in neurograph.nodes:
output = _generate_node_skel_features(neurograph, node)
skel_features[node] = _generate_node_skel_features(neurograph, node)
return skel_features

Expand All @@ -94,31 +67,94 @@ def _generate_pointcloud_node_features():


# -- Edge feature extraction --
def generate_immutable_skel_features(neurograph):
def generate_mutable_skel_features(neurograph):
features = dict()
for edge in neurograph.immutable_edges:
features[edge] = _generate_immutable_skel_features(neurograph, edge)
for edge in neurograph.mutable_edges:
length = compute_length(neurograph, edge)
radius_i, radius_j = get_radii(neurograph, edge)

dot1, dot2, dot3 = get_directionals(neurograph, edge, 5)
ddot1, ddot2, ddot3 = get_directionals(neurograph, edge, 10)
features[edge] = np.concatenate((length, dot1, dot2, dot3), axis=None)
return features


def _generate_immutable_skel_features(neurograph, edge):
mean_xyz = np.mean(neurograph.edges[edge]["xyz"], axis=0)
mean_radius = np.mean(neurograph.edges[edge]["radius"], axis=0)
path_length = len(neurograph.edges[edge]["radius"])
return np.concatenate((mean_xyz, mean_radius, path_length), axis=None)
def compute_length(neurograph, edge, metric="l2"):
i, j = tuple(edge)
xyz_1, xyz_2 = neurograph.get_edge_attr("xyz", i, j)
return utils.dist(xyz_1, xyz_2, metric=metric)


def get_directionals(neurograph, edge, window_size):
# Compute tangent vectors
i, j = tuple(edge)
mutable_xyz_i, mutable_xyz_j = neurograph.get_edge_attr("xyz", i, j)
mutable_xyz = np.array([mutable_xyz_i, mutable_xyz_j])
mutable_tangent = geometry_utils.compute_tangent(mutable_xyz)
context_tangent_1 = geometry_utils.compute_context_vec(neurograph, i, mutable_tangent, window_size=window_size)
context_tangent_2 = geometry_utils.compute_context_vec(neurograph, j, mutable_tangent, window_size=window_size)

# Compute features
inner_product_1 = abs(np.dot(mutable_tangent, context_tangent_1))
inner_product_2 = abs(np.dot(mutable_tangent, context_tangent_2))
inner_product_3 = np.dot(context_tangent_1, context_tangent_2)
return inner_product_1, inner_product_2, inner_product_3


def get_radii(neurograph, edge):
i, j = tuple(edge)
radius_i = neurograph.nodes[i]["radius"]
radius_j = neurograph.nodes[j]["radius"]
return radius_i, radius_j


# -- Combine feature vectors
def build_feature_matrix(neurographs, features, blocks):
# Initialize
X = None
block_to_idxs = dict()
idx_to_edge = dict()

# Feature extraction
for block_id in blocks:
# Get features
idx_shift = 0 if X is None else X.shape[0]
X_i, y_i, idx_to_edge_i = build_feature_submatrix(
neurographs[block_id],
features[block_id],
idx_shift,
)

# Concatenate
if X is None:
X = deepcopy(X_i)
y = deepcopy(y_i)
else:
X = np.concatenate((X, X_i), axis=0)
y = np.concatenate((y, y_i), axis=0)

# Update dicts
idxs = set(np.arange(idx_shift, idx_shift + len(idx_to_edge_i)))
block_to_idxs[block_id] = idxs
idx_to_edge.update(idx_to_edge_i)
return X, y, block_to_idxs, idx_to_edge

def generate_mutable_skel_features(neurograph):
features = dict()
for edge in neurograph.mutable_edges:
features[edge] = _generate_mutable_skel_features(neurograph, edge)
return features

def build_feature_submatrix(neurograph, feat_dict, shift):
# Extract info
key = sample(list(feat_dict.keys()), 1)[0]
num_edges = neurograph.num_mutables()
num_features = len(feat_dict[key])

def _generate_mutable_skel_features(neurograph, edge):
mean_xyz = np.mean(neurograph.edges[edge]["xyz"], axis=0)
edge_length = compute_length(neurograph, edge)
return np.concatenate((mean_xyz, edge_length), axis=None)
# Build
idx_to_edge = dict()
X = np.zeros((num_edges, num_features))
y = np.zeros((num_edges))
for i, edge in enumerate(feat_dict.keys()):
idx_to_edge[i + shift] = edge
X[i, :] = feat_dict[edge]
y[i] = 1 if edge in neurograph.target_edges else 0
return X, y, idx_to_edge


# -- Utils --
Expand All @@ -129,17 +165,54 @@ def compute_num_features(features):
return num_features


def extract_feature_vec(features,):
feature_vec = None
def combine_feature_vecs(features):
vec = None
for key in features.keys():
if feature_vec is None:
feature_vec = features[key]
if vec is None:
vec = features[key]
else:
feature_vec = np.concatenate((feature_vec, features[key]), axis=1)
return feature_vec
vec = np.concatenate((vec, features[key]), axis=1)
return vec



"""
def generate_node_features(neurograph, img=True, pointcloud=True, skel=True):
features = dict()
if img:
features["img"] = generate_img_features(neurograph)
if skel:
features["skel"] = generate_skel_features(neurograph)
if pointcloud:
features["pointcloud"] = generate_pointcloud_features(neurograph)
return extract_feature_vec(features)
def compute_length(neurograph, edge):
xyz_1 = neurograph.edges[edge]["xyz"][0]
xyz_2 = neurograph.edges[edge]["xyz"][1]
return utils.dist(xyz_1, xyz_2)
def generate_immutable_features(
neurograph, img=True, pointcloud=True, skel=True
):
features = dict()
if img:
features["img"] = generate_img_features(neurograph)
if skel:
features["skel"] = generate_immutable_skel_features(neurograph)
if pointcloud:
features["pointcloud"] = generate_pointcloud_features(neurograph)
return extract_feature_vec(features)
def generate_immutable_skel_features(neurograph):
features = dict()
for edge in neurograph.immutable_edges:
features[edge] = _generate_immutable_skel_features(neurograph, edge)
return features
def _generate_immutable_skel_features(neurograph, edge):
mean_radius = np.mean(neurograph.edges[edge]["radius"], axis=0)
return np.concatenate((mean_radius), axis=None)
"""
78 changes: 78 additions & 0 deletions src/deep_neurographs/geometry_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np
from deep_neurographs import utils
from scipy.linalg import svd


# Context Tangent Vectors
def compute_context_vec(neurograph, i, mutable_tangent, window_size=5, return_pts=False, vec_type="tangent"):
# Compute context vecs
branches = get_branches(neurograph, i)
context_vec_list = []
xyz_list = []
ref_xyz = neurograph.nodes[i]["xyz"]
for branch in branches:
context_vec, xyz = _compute_context_vec(branch, ref_xyz, window_size, vec_type)
context_vec_list.append(context_vec)
xyz_list.append(xyz)

# Determine best
max_dot_prod = 0
arg_max = -1
for k in range(len(context_vec_list)):
dot_prod = abs(np.dot(mutable_tangent, context_vec_list[k]))
if dot_prod >= max_dot_prod:
max_dot_prod = dot_prod
arg_max = k

# Compute normal
if return_pts:
return context_vec_list, branches, xyz_list, arg_max
else:
return context_vec_list[arg_max]


def _compute_context_vec(all_xyz, ref_xyz, window_size, vec_type):
from_start = orient_pts(all_xyz, ref_xyz)
xyz = get_pts(all_xyz, window_size, from_start)
if vec_type == "normal":
vec = compute_normal(xyz)
else:
vec = compute_tangent(xyz)
return vec, np.mean(xyz, axis=0).reshape(1, 3)


def get_branches(neurograph, i):
nbs = []
for j in list(neurograph.neighbors(i)):
if frozenset((i, j)) in neurograph.immutable_edges:
nbs.append(j)
return [neurograph.edges[i, j]["xyz"] for j in nbs]


def orient_pts(xyz, ref_xyz):
return True if all(xyz[0] == ref_xyz) else False


def get_pts(xyz, window_size, from_start):
if len(xyz) > window_size and from_start:
return xyz[0:window_size]
elif len(xyz) > window_size and not from_start:
return xyz[-window_size:]
else:
return xyz


def compute_svd(xyz):
xyz = xyz - np.mean(xyz, axis=0)
return svd(xyz)


def compute_tangent(xyz):
if xyz.shape[0] == 2:
tangent = (xyz[1] - xyz[0]) / utils.dist(xyz[1], xyz[0])
else:
U, S, VT = compute_svd(xyz)
tangent = VT[0]
return tangent / np.linalg.norm(tangent)


6 changes: 3 additions & 3 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def init_immutables_from_local(
anisotropy=[1.0, 1.0, 1.0],
prune=True,
prune_depth=16,
smooth=True,
smooth=False,
):
"""
To do...
Expand All @@ -109,8 +109,8 @@ def init_immutables_from_local(
raw_swc = swc_utils.read_swc(os.path.join(swc_dir, swc_id))
swc_id = swc_id.replace(".0.swc", "")
swc_dict = swc_utils.parse(raw_swc, anisotropy=anisotropy)
#if smooth:
# swc_dict = swc_utils.smooth(swc_dict)
if smooth:
swc_dict = swc_utils.smooth(swc_dict)
neurograph.generate_immutables(
swc_id, swc_dict, prune=prune, prune_depth=prune_depth
)
Expand Down
Loading

0 comments on commit 4999172

Please sign in to comment.