From 3e19b701141738f5ceedab8826767b4b3bbe7351 Mon Sep 17 00:00:00 2001 From: Anna Grim <108307071+anna-grim@users.noreply.github.com> Date: Mon, 22 Jan 2024 13:47:34 -0800 Subject: [PATCH] Add features (#40) * feat: added new features generated from skeleton * minor upds --------- Co-authored-by: anna-grim --- src/deep_neurographs/feature_extraction.py | 19 +++++- src/deep_neurographs/graph_utils.py | 74 ++++++++++++++++++---- 2 files changed, 78 insertions(+), 15 deletions(-) diff --git a/src/deep_neurographs/feature_extraction.py b/src/deep_neurographs/feature_extraction.py index 53e0800..c7bf420 100644 --- a/src/deep_neurographs/feature_extraction.py +++ b/src/deep_neurographs/feature_extraction.py @@ -18,7 +18,7 @@ CHUNK_SIZE = [64, 64, 64] WINDOW = [5, 5, 5] N_PROFILE_POINTS = 10 -N_SKEL_FEATURES = 11 +N_SKEL_FEATURES = 18 SUPPORTED_MODELS = [ "AdaBoost", "RandomForest", @@ -183,8 +183,7 @@ def generate_mutable_skel_features(neurograph): neurograph.immutable_degree(j), get_radii(neurograph, edge), get_avg_radii(neurograph, edge), - #get_avg_branch_len(neurograph, edge) - get_directionals(neurograph, edge, 4), + get_avg_branch_lens(neurograph, edge), get_directionals(neurograph, edge, 8), get_directionals(neurograph, edge, 16), get_directionals(neurograph, edge, 32), @@ -228,6 +227,20 @@ def get_avg_radius(radii_list): avg += np.mean(radii[0:end]) / len(radii_list) return avg + +def get_avg_branch_lens(neurograph, edge): + i, j = tuple(edge) + branches_i = neurograph.get_branches(i, key="xyz") + branches_j = neurograph.get_branches(j, key="xyz") + return np.array([get_branch_len(branches_i), get_branch_len(branches_j)]) + + +def get_branch_len(branch_list): + branch_len = 0 + for branch in branch_list: + branch_len += len(branch) / len(branch_list) + return branch_len + def get_radii(neurograph, edge): i, j = tuple(edge) diff --git a/src/deep_neurographs/graph_utils.py b/src/deep_neurographs/graph_utils.py index 42c65c3..b9a9fed 100644 --- a/src/deep_neurographs/graph_utils.py +++ b/src/deep_neurographs/graph_utils.py @@ -210,7 +210,7 @@ def __smooth_branch(swc_dict, attrs, edges, nbs, root, j): swc_dict : dict Contents of an swc file. attrs : dict - Attributes (from swc file) of edge being smoothed. + Attributes (from "swc_dict") of edge being smoothed. edges : dict Dictionary where the keys are edges in irreducible graph and values are the corresponding attributes. @@ -238,7 +238,7 @@ def upd_xyz(swc_dict, attrs, edges, nbs, i, endpoint): swc_dict : dict Contents of an swc file. attrs : dict - Attributes (from swc file) of edge being smoothed. + Attributes (from "swc_dict") of edge being smoothed. edges : dict Dictionary where the keys are edges in irreducible graph and values are the corresponding attributes. @@ -297,35 +297,85 @@ def upd_endpoint_xyz(edges, key, old_xyz, new_xyz): # -- attribute utils -- def init_edge_attrs(swc_dict, i): + """ + Initializes edge attribute dictionary with attributes from node "i" which + is an end point of the edge. + + Parameters + ---------- + swc_dict : dict + Contents of an swc file. + i : int + End point of edge and the swc attributes of this node are used to + initialize the edge attriubte dictionary. + + Returns + ------- + dict + Edge attribute dictionary. + + """ return {"radius": [swc_dict["radius"][i]], "xyz": [swc_dict["xyz"][i]]} def upd_edge_attrs(swc_dict, attrs, i): + """ + Updates an edge attribute dictionary with attributes of node i. + + Parameters + ---------- + swc_dict : dict + Contents of an swc file. + attrs : dict + Attributes (from "swc_dict") of edge being updated. + i : int + Node of edge whose attributes will be added to "attrs". + + Returns + ------- + attrs : dict + Edge attribute dictionary. + + """ attrs["radius"].append(swc_dict["radius"][i]) attrs["xyz"].append(swc_dict["xyz"][i]) return attrs def get_edge_attr(graph, edge, attr): - edge_data = graph.get_edge_data(*edge) - return edge_data[attr] + """ + Gets the attribute "attr" of "edge". + + Parameters + ---------- + graph : networkx.Graph + Graph which "edge" belongs to. + edge : tuple + Edge to be queried for its attributes. + attr : str + Attribute to be queried. + Returns + ------- + Attribute "attr" of "edge" + + """ + return graph.edges[edge][attr] + def set_edge_attrs(attrs): attrs["xyz"] = np.array(attrs["xyz"], dtype=np.float32) attrs["radius"] = np.array(attrs["radius"], dtype=np.float16) return attrs - -def init_node_attrs(swc_dict, i): - return {"radius": swc_dict["radius"][i], "xyz": swc_dict["xyz"][i]} - - + def set_node_attrs(swc_dict, nodes): - node_attrs = dict() + attrs = dict() for i in nodes: - node_attrs[i] = init_node_attrs(swc_dict, i) - return node_attrs + attrs[i] = { + "radius": swc_dict["radius"][i], "xyz": swc_dict["xyz"][i] + } + return attrs def upd_node_attrs(swc_dict, leafs, junctions, i):