Skip to content

Commit

Permalink
basic-feature-generation (#21)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Sep 7, 2023
1 parent 7a9602a commit d9a6557
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 32 deletions.
98 changes: 74 additions & 24 deletions src/deep_neurographs/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,57 @@
"""

import numpy as np
from deep_neurographs import utils

NUM_EDGE_FEATURES = 1
NUM_IMG_FEATURES = 0
NUM_SKEL_FEATURES = 4
NUM_PC_FEATURES = 0


# -- Wrappers --
def generate_node_features(supergraph, img=True, pointcloud=True, skel=True):
def generate_node_features(neurograph, img=True, pointcloud=True, skel=True):
features = dict()
if img:
features["img"] = generate_img_features(supergraph)
features["img"] = generate_img_features(neurograph)

if skel:
features["skel"] = generate_skel_features(supergraph)
features["skel"] = generate_skel_features(neurograph)

if pointcloud:
features["pointcloud"] = generate_pointcloud_features(supergraph)
features["pointcloud"] = generate_pointcloud_features(neurograph)
return extract_feature_vec(features)


def generate_immutable_features(neurograph, img=True, pointcloud=True, skel=True):
features = dict()
if img:
features["img"] = generate_img_features(neurograph)

if skel:
features["skel"] = generate_immutable_skel_features(neurograph)

if pointcloud:
features["pointcloud"] = generate_pointcloud_features(neurograph)
return extract_feature_vec(features)


def generate_edge_features(supergraph):
features = np.zeros((supergraph.num_edges(), NUM_EDGE_FEATURES))
for i, edge in enumerate(supergraph.edges()):
features[i] = supergraph.edges[edge]["distance"]
return features
def generate_mutable_features(neurograph, img=True, pointcloud=True, skel=True):
features = dict()
if img:
features["img"] = generate_img_features(neurograph)

if skel:
features["skel"] = generate_mutable_skel_features(neurograph)

if pointcloud:
features["pointcloud"] = generate_pointcloud_features(neurograph)
return extract_feature_vec(features)


# -- Node feature extraction --
def generate_img_features(supergraph):
img_features = np.zeros((supergraph.num_nodes(), NUM_IMG_FEATURES))
for node in supergraph.nodes:
def generate_img_features(neurograph):
img_features = np.zeros((neurograph.num_nodes(), NUM_IMG_FEATURES))
for node in neurograph.nodes:
img_features[node] = _generate_node_img_features()
return img_features

Expand All @@ -50,22 +68,23 @@ def _generate_node_img_features():
pass


def generate_skel_features(supergraph):
skel_features = np.zeros((supergraph.num_nodes(), NUM_SKEL_FEATURES))
for node in supergraph.nodes:
skel_features[node] = _generate_node_skel_features(supergraph, node)
def generate_skel_features(neurograph):
skel_features = np.zeros((neurograph.num_nodes(), NUM_SKEL_FEATURES))
for node in neurograph.nodes:
output = _generate_node_skel_features(neurograph, node)
skel_features[node] = _generate_node_skel_features(neurograph, node)
return skel_features


def _generate_node_skel_features(supergraph, node):
mean_radius = np.mean(supergraph.nodes[node]["radius"])
mean_xyz = np.mean(supergraph.nodes[node]["xyz"], axis=0)
return np.concatenate((mean_radius, mean_xyz), axis=None)
def _generate_node_skel_features(neurograph, node):
radius = neurograph.nodes[node]["radius"]
xyz = neurograph.nodes[node]["xyz"]
return np.append(xyz, radius)


def generate_pointcloud_features(supergraph):
pc_features = np.zeros((supergraph.num_nodes(), NUM_PC_FEATURES))
for node in supergraph.nodes:
def generate_pointcloud_features(neurograph):
pc_features = np.zeros((neurograph.num_nodes(), NUM_PC_FEATURES))
for node in neurograph.nodes:
pc_features[node] = _generate_pointcloud_node_features()
return pc_features

Expand All @@ -75,6 +94,31 @@ def _generate_pointcloud_node_features():


# -- Edge feature extraction --
def generate_immutable_skel_features(neurograph):
features = dict()
for edge in neurograph.immutable_edges:
features[edge] = _generate_immutable_skel_features(neurograph, edge)
return features


def _generate_immutable_skel_features(neurograph, edge):
mean_xyz = np.mean(neurograph.edges[edge]["xyz"], axis=0)
mean_radius = np.mean(neurograph.edges[edge]["radius"], axis=0)
path_length = len(neurograph.edges[edge]["radius"])
return np.concatenate((mean_xyz, mean_radius, path_length), axis=None)


def generate_mutable_skel_features(neurograph):
features = dict()
for edge in neurograph.mutable_edges:
features[edge] = _generate_mutable_skel_features(neurograph, edge)
return features


def _generate_mutable_skel_features(neurograph, edge):
mean_xyz = np.mean(neurograph.edges[edge]["xyz"], axis=0)
edge_length = compute_length(neurograph, edge)
return np.concatenate((mean_xyz, edge_length), axis=None)


# -- Utils --
Expand All @@ -93,3 +137,9 @@ def extract_feature_vec(features,):
else:
feature_vec = np.concatenate((feature_vec, features[key]), axis=1)
return feature_vec


def compute_length(neurograph, edge):
xyz_1 = neurograph.edges[edge]["xyz"][0]
xyz_2 = neurograph.edges[edge]["xyz"][1]
return utils.dist(xyz_1, xyz_2)
4 changes: 3 additions & 1 deletion src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def build_neurograph(
bucket=None,
access_key_id=None,
secret_access_key=None,
generate_mutables=True,
max_mutable_degree=5,
max_mutable_dist=50.0,
prune=True,
Expand Down Expand Up @@ -82,7 +83,7 @@ def init_immutables_from_s3(
access_key_id=access_key_id, secret_access_key=secret_access_key
)
for file_key in s3_utils.listdir(bucket, swc_dir, s3_client, ext=".swc"):
swc_id = file_key.split("/")[-1]
swc_id = file_key.split("/")[-1].replace(".swc", "")
raw_swc = s3_utils.read_from_s3(bucket, file_key, s3_client)
swc_dict = swc_utils.parse(raw_swc, anisotropy=anisotropy)
if smooth:
Expand All @@ -106,6 +107,7 @@ def init_immutables_from_local(
"""
for swc_id in utils.listdir(swc_dir, ext=".swc"):
raw_swc = swc_utils.read_swc(os.path.join(swc_dir, swc_id))
swc_id = swc_id.replace(".0.swc", "")
swc_dict = swc_utils.parse(raw_swc, anisotropy=anisotropy)
#if smooth:
# swc_dict = swc_utils.smooth(swc_dict)
Expand Down
38 changes: 31 additions & 7 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def __init__(self):
super(NeuroGraph, self).__init__()
self.leafs = set()
self.junctions = set()
self.mutable_edges = set()
self.immutable_edges = set()
self.mutable_edges = set()
self.target_edges = set()
self.xyz_to_edge = dict()

# --- Add nodes or edges ---
Expand Down Expand Up @@ -80,7 +81,7 @@ def generate_immutables(
node_id[i] = len(self.nodes)
self.add_node(
node_id[i],
xyz=swc_dict["xyz"][i],
xyz=np.array(swc_dict["xyz"][i]),
radius=swc_dict["radius"][i],
swc_id=swc_id,
)
Expand Down Expand Up @@ -134,10 +135,10 @@ def generate_mutables(self, max_degree=5, max_dist=100.0):
attrs = self.get_edge_data(i, j)

# Get connecting node
if utils.dist(xyz, attrs["xyz"][0]) < 16:
if utils.dist(xyz, attrs["xyz"][0]) < 8:
node = i
xyz = self.nodes[node]["xyz"]
elif utils.dist(xyz, attrs["xyz"][-1]) < 16:
elif utils.dist(xyz, attrs["xyz"][-1]) < 8:
node = j
xyz = self.nodes[node]["xyz"]
if node == leaf:
Expand All @@ -148,9 +149,6 @@ def generate_mutables(self, max_degree=5, max_dist=100.0):

# Add edge
self.add_edge(leaf, node, xyz=np.array([xyz_leaf, xyz]))
if frozenset((leaf, node)) == frozenset({309}):
print((leaf, node))
stop
self.mutable_edges.add(frozenset((leaf, node)))

def _get_mutables(self, query_id, query_xyz, max_degree, max_dist):
Expand Down Expand Up @@ -271,6 +269,32 @@ def _query_kdtree(self, query, max_dist):
idxs = self.kdtree.query_ball_point(query, max_dist)
return self.kdtree.data[idxs]

def init_targets(self, log_path, dist_threshold):
# Initializations
splits_log = utils.read_mistake_log(log_path)
split_edges = set(splits_log.keys())
target_edges = set()

# Parse mutables
for edge in self.mutable_edges:
i, j = tuple(edge)
key = frozenset(self.get_edge_attr("swc_id", i, j))
if key in split_edges:
k, l = list(edge)
mutable_xyz_1 = self.nodes[k]["xyz"]
mutable_xyz_2 = self.nodes[l]["xyz"]
log_xyz_1 = splits_log[key]["xyz"][0]
log_xyz_2 = splits_log[key]["xyz"][1]

pair_1 = [mutable_xyz_1, mutable_xyz_2]
pair_2 = [log_xyz_1, log_xyz_2]
d = utils.pair_dist(pair_1, pair_2)
if d < dist_threshold:
target_edges.add(frozenset((i, j)))
self.target_edges = target_edges
print("% target edges in mistake log:", len(target_edges) / len(split_edges))
print("% target edges in mutable:", len(target_edges) / len(self.mutable_edges))

# --- Visualization ---
def visualize_immutables(self, return_data=False, title="Immutable Graph"):
"""
Expand Down
37 changes: 37 additions & 0 deletions src/deep_neurographs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,29 @@ def read_txt(path):
return f.read()


def read_mistake_log(path):
splits_log = dict()
with open(path, "r") as file:
for line in file:
if not line.startswith("#") and len(line) > 0:
parts = line.split()
xyz_1 = extract_coords(parts[0:3])
xyz_2 = extract_coords(parts[3:6])
swc_1 = parts[6].replace(",", "")
swc_2 = parts[7].replace(",", "")
key = frozenset([swc_1, swc_2])
splits_log[key] = {"swc": [swc_1, swc_2], "xyz": [xyz_1, xyz_2]}
return splits_log


def extract_coords(parts):
coords = []
for p in parts:
p = p.replace("[", "").replace("]", "").replace(",", "")
coords.append(float(p))
return np.array(coords, dtype=int)


def write_json(path, contents):
"""
Writes "contents" to a .json file at "path".
Expand Down Expand Up @@ -196,7 +219,21 @@ def dist(x, y, metric="l2"):
return np.linalg.norm(np.subtract(x, y), ord=1)
else:
return np.linalg.norm(np.subtract(x, y), ord=2)


def pair_dist(pair_1, pair_2, metric="l2"):
pair_1.reverse()
d1 = _pair_dist(pair_1, pair_2)

pair_1.reverse()
d2 = _pair_dist(pair_1, pair_2)
return min(d1, d2)


def _pair_dist(pair_1, pair_2, metric="l2"):
d1 = dist(pair_1[0], pair_2[0], metric=metric)
d2 = dist(pair_1[1], pair_2[1], metric=metric)
return max(d1, d2)


def smooth_branch(xyz, k=3):
Expand Down

0 comments on commit d9a6557

Please sign in to comment.