Skip to content

Commit

Permalink
Add features (#40)
Browse files Browse the repository at this point in the history
* feat: added new features generated from skeleton

* minor upds

---------

Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Jan 22, 2024
1 parent ac5a621 commit 3e19b70
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 15 deletions.
19 changes: 16 additions & 3 deletions src/deep_neurographs/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
CHUNK_SIZE = [64, 64, 64]
WINDOW = [5, 5, 5]
N_PROFILE_POINTS = 10
N_SKEL_FEATURES = 11
N_SKEL_FEATURES = 18
SUPPORTED_MODELS = [
"AdaBoost",
"RandomForest",
Expand Down Expand Up @@ -183,8 +183,7 @@ def generate_mutable_skel_features(neurograph):
neurograph.immutable_degree(j),
get_radii(neurograph, edge),
get_avg_radii(neurograph, edge),
#get_avg_branch_len(neurograph, edge)
get_directionals(neurograph, edge, 4),
get_avg_branch_lens(neurograph, edge),
get_directionals(neurograph, edge, 8),
get_directionals(neurograph, edge, 16),
get_directionals(neurograph, edge, 32),
Expand Down Expand Up @@ -228,6 +227,20 @@ def get_avg_radius(radii_list):
avg += np.mean(radii[0:end]) / len(radii_list)
return avg


def get_avg_branch_lens(neurograph, edge):
i, j = tuple(edge)
branches_i = neurograph.get_branches(i, key="xyz")
branches_j = neurograph.get_branches(j, key="xyz")
return np.array([get_branch_len(branches_i), get_branch_len(branches_j)])


def get_branch_len(branch_list):
branch_len = 0
for branch in branch_list:
branch_len += len(branch) / len(branch_list)
return branch_len


def get_radii(neurograph, edge):
i, j = tuple(edge)
Expand Down
74 changes: 62 additions & 12 deletions src/deep_neurographs/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def __smooth_branch(swc_dict, attrs, edges, nbs, root, j):
swc_dict : dict
Contents of an swc file.
attrs : dict
Attributes (from swc file) of edge being smoothed.
Attributes (from "swc_dict") of edge being smoothed.
edges : dict
Dictionary where the keys are edges in irreducible graph and values
are the corresponding attributes.
Expand Down Expand Up @@ -238,7 +238,7 @@ def upd_xyz(swc_dict, attrs, edges, nbs, i, endpoint):
swc_dict : dict
Contents of an swc file.
attrs : dict
Attributes (from swc file) of edge being smoothed.
Attributes (from "swc_dict") of edge being smoothed.
edges : dict
Dictionary where the keys are edges in irreducible graph and values
are the corresponding attributes.
Expand Down Expand Up @@ -297,35 +297,85 @@ def upd_endpoint_xyz(edges, key, old_xyz, new_xyz):

# -- attribute utils --
def init_edge_attrs(swc_dict, i):
"""
Initializes edge attribute dictionary with attributes from node "i" which
is an end point of the edge.
Parameters
----------
swc_dict : dict
Contents of an swc file.
i : int
End point of edge and the swc attributes of this node are used to
initialize the edge attriubte dictionary.
Returns
-------
dict
Edge attribute dictionary.
"""
return {"radius": [swc_dict["radius"][i]], "xyz": [swc_dict["xyz"][i]]}


def upd_edge_attrs(swc_dict, attrs, i):
"""
Updates an edge attribute dictionary with attributes of node i.
Parameters
----------
swc_dict : dict
Contents of an swc file.
attrs : dict
Attributes (from "swc_dict") of edge being updated.
i : int
Node of edge whose attributes will be added to "attrs".
Returns
-------
attrs : dict
Edge attribute dictionary.
"""
attrs["radius"].append(swc_dict["radius"][i])
attrs["xyz"].append(swc_dict["xyz"][i])
return attrs


def get_edge_attr(graph, edge, attr):
edge_data = graph.get_edge_data(*edge)
return edge_data[attr]
"""
Gets the attribute "attr" of "edge".
Parameters
----------
graph : networkx.Graph
Graph which "edge" belongs to.
edge : tuple
Edge to be queried for its attributes.
attr : str
Attribute to be queried.
Returns
-------
Attribute "attr" of "edge"
"""
return graph.edges[edge][attr]


def set_edge_attrs(attrs):
attrs["xyz"] = np.array(attrs["xyz"], dtype=np.float32)
attrs["radius"] = np.array(attrs["radius"], dtype=np.float16)
return attrs


def init_node_attrs(swc_dict, i):
return {"radius": swc_dict["radius"][i], "xyz": swc_dict["xyz"][i]}



def set_node_attrs(swc_dict, nodes):
node_attrs = dict()
attrs = dict()
for i in nodes:
node_attrs[i] = init_node_attrs(swc_dict, i)
return node_attrs
attrs[i] = {
"radius": swc_dict["radius"][i], "xyz": swc_dict["xyz"][i]
}
return attrs


def upd_node_attrs(swc_dict, leafs, junctions, i):
Expand Down

0 comments on commit 3e19b70

Please sign in to comment.