From 5b6291288d3f2ff53919e848bc29d97f2607c2a2 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sun, 22 Oct 2023 00:24:18 +0000 Subject: [PATCH] training updates --- src/deep_neurographs/densegraph.py | 17 +-- src/deep_neurographs/feature_extraction.py | 116 ++++++++++++--------- src/deep_neurographs/geometry_utils.py | 32 +++++- src/deep_neurographs/graph_utils.py | 7 +- src/deep_neurographs/intake.py | 45 +------- src/deep_neurographs/neurograph.py | 86 +++++++-------- src/deep_neurographs/swc_utils.py | 34 +++--- src/deep_neurographs/utils.py | 71 +++---------- tests/test_build_neurograph.py | 11 +- tests/test_example.py | 16 --- tests/test_train_pipeline.py | 109 ------------------- 11 files changed, 184 insertions(+), 360 deletions(-) delete mode 100644 tests/test_example.py delete mode 100644 tests/test_train_pipeline.py diff --git a/src/deep_neurographs/densegraph.py b/src/deep_neurographs/densegraph.py index a80592e..09068e4 100644 --- a/src/deep_neurographs/densegraph.py +++ b/src/deep_neurographs/densegraph.py @@ -1,14 +1,15 @@ import os + import networkx as nx import numpy as np -from deep_neurographs import swc_utils, utils -from deep_neurographs.geometry_utils import dist, make_line from more_itertools import zip_broadcast from scipy.spatial import KDTree +from deep_neurographs import swc_utils, utils +from deep_neurographs.geometry_utils import dist, make_line + class DenseGraph: - def __init__(self, swc_dir): self.xyz_to_node = dict() self.xyz_to_swc = dict() @@ -24,7 +25,7 @@ def init_graphs(self, swc_dir): # Construct Graph swc_dict = swc_utils.parse(swc_utils.read_swc(path)) graph, xyz_to_node = swc_utils.file_to_graph( - swc_dict, set_attrs=True, return_dict=True, + swc_dict, set_attrs=True, return_dict=True ) # Store @@ -41,7 +42,7 @@ def get_projection(self, xyz): proj_xyz = tuple(self.kdtree.data[idx]) proj_dist = dist(proj_xyz, xyz) return proj_xyz, proj_dist - + def connect_nodes(self, graph_id, xyz_i, xyz_j, return_dist=True): i = self.xyz_to_node[graph_id][xyz_i] j = self.xyz_to_node[graph_id][xyz_j] @@ -56,7 +57,7 @@ def compute_dist(self, graph_id, path): d = 0 for i in range(1, len(path)): xyz_1 = self.graphs[graph_id].nodes[i]["xyz"] - xyz_2 = self.graphs[graph_id].nodes[i-1]["xyz"] + xyz_2 = self.graphs[graph_id].nodes[i - 1]["xyz"] d += dist(xyz_1, xyz_2) return d @@ -77,7 +78,7 @@ def check_aligned(self, pred_xyz_i, pred_xyz_j): target_dist = max(target_dist, 1) ratio = min(pred_dist, target_dist) / max(pred_dist, target_dist) - if ratio < 0.7 and pred_dist > 25: + if ratio < 0.6 and pred_dist > 25: return False elif ratio < 0.25: return False @@ -93,7 +94,7 @@ def check_aligned(self, pred_xyz_i, pred_xyz_j): intersection = proj_nodes.intersection(set(target_path)) overlap = len(intersection) / len(target_path) - if overlap < 0.4 and pred_dist > 25: + if overlap < 0.5 and pred_dist > 25: return False elif overlap < 0.2: return False diff --git a/src/deep_neurographs/feature_extraction.py b/src/deep_neurographs/feature_extraction.py index 7d781e9..7572326 100644 --- a/src/deep_neurographs/feature_extraction.py +++ b/src/deep_neurographs/feature_extraction.py @@ -12,74 +12,87 @@ from random import sample import numpy as np -from scipy.linalg import svd from deep_neurographs import geometry_utils, utils -NUM_IMG_FEATURES = 0 -NUM_SKEL_FEATURES = 9 -NUM_PC_FEATURES = 0 +NUM_POINTS = 10 +WINDOW_SIZE = [6, 6, 6] + +NUM_IMG_FEATURES = NUM_POINTS +NUM_SKEL_FEATURES = 11 # -- Wrappers -- def generate_mutable_features( - neurograph, img=True, pointcloud=True, skel=True + neurograph, anisotropy=[1.0, 1.0, 1.0], img_path=None ): - features = dict() - if img: - features["img"] = generate_img_features(neurograph) - if skel: - features["skel"] = generate_mutable_skel_features(neurograph) - features = combine_feature_vecs(features) - return features - - -# -- Node feature extraction -- -def generate_img_features(neurograph): - img_features = np.zeros((neurograph.num_nodes(), NUM_IMG_FEATURES)) - for node in neurograph.nodes: - img_features[node] = _generate_node_img_features() - return img_features - - -def _generate_node_img_features(): - pass - - -def generate_skel_features(neurograph): - skel_features = np.zeros((neurograph.num_nodes(), NUM_SKEL_FEATURES)) - for node in neurograph.nodes: - skel_features[node] = _generate_node_skel_features(neurograph, node) - return skel_features - - -def _generate_node_skel_features(neurograph, node): - radius = neurograph.nodes[node]["radius"] - xyz = neurograph.nodes[node]["xyz"] - return np.append(xyz, radius) - - -def generate_pointcloud_features(neurograph): - pc_features = np.zeros((neurograph.num_nodes(), NUM_PC_FEATURES)) - for node in neurograph.nodes: - pc_features[node] = _generate_pointcloud_node_features() - return pc_features + """ + Generates feature vectors for every mutable edge in a neurograph. + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a directory of swcs generated from a + predicted segmentation. + anisotropy : list[float] + Real-world to image coordinates scaling factor for (x, y, z). + img_path : str, optional + Path to image volume. + + Returns + ------- + Dictionary where each key-value pair corresponds to a type of feature + vector and the numerical vector. + + """ + features = {"skel": generate_mutable_skel_features(neurograph)} + if img_path is not None: + features["img"] = generate_mutable_img_features( + neurograph, img_path, anisotropy=anisotropy + ) + return combine_feature_vecs(features) -def _generate_pointcloud_node_features(): - pass +# -- Edge feature extraction -- +def generate_mutable_img_features( + neurograph, path, anisotropy=[1.0, 1.0, 1.0] +): + img = utils.open_zarr(path) + features = dict() + for edge in neurograph.mutable_edges: + xyz = neurograph.edges[edge]["xyz"] + line = geometry_utils.make_line(xyz[0], xyz[1], NUM_POINTS) + features[edge] = geometry_utils.get_profile( + img, line, anisotropy=anisotropy, window_size=WINDOW_SIZE + ) + return features -# -- Edge feature extraction -- def generate_mutable_skel_features(neurograph): features = dict() for edge in neurograph.mutable_edges: + i, j = tuple(edge) + deg_i = len(list(neurograph.neighbors(i))) + deg_j = len(list(neurograph.neighbors(j))) length = compute_length(neurograph, edge) radius_i, radius_j = get_radii(neurograph, edge) dot1, dot2, dot3 = get_directionals(neurograph, edge, 5) ddot1, ddot2, ddot3 = get_directionals(neurograph, edge, 10) features[edge] = np.concatenate( - (length, radius_i, radius_j, dot1, dot2, dot3, ddot1, ddot2, ddot3), axis=None + ( + length, + deg_i, + deg_j, + radius_i, + radius_j, + dot1, + dot2, + dot3, + ddot1, + ddot2, + ddot3, + ), + axis=None, ) return features @@ -221,4 +234,11 @@ def generate_immutable_skel_features(neurograph): def _generate_immutable_skel_features(neurograph, edge): mean_radius = np.mean(neurograph.edges[edge]["radius"], axis=0) return np.concatenate((mean_radius), axis=None) + +def generate_skel_features(neurograph): + skel_features = np.zeros((neurograph.num_nodes(), NUM_SKEL_FEATURES)) + for node in neurograph.nodes: + skel_features[node] = _generate_node_skel_features(neurograph, node) + return skel_features + """ diff --git a/src/deep_neurographs/geometry_utils.py b/src/deep_neurographs/geometry_utils.py index a47438d..11b2c4a 100644 --- a/src/deep_neurographs/geometry_utils.py +++ b/src/deep_neurographs/geometry_utils.py @@ -1,6 +1,7 @@ import numpy as np -from scipy.interpolate import CubicSpline, UnivariateSpline +from scipy.interpolate import UnivariateSpline from scipy.linalg import svd + from deep_neurographs import utils @@ -87,6 +88,12 @@ def compute_tangent(xyz): return tangent / np.linalg.norm(tangent) +def compute_normal(xyz): + U, S, VT = compute_svd(xyz) + normal = VT[-1] + return normal / np.linalg.norm(normal) + + # Smoothing def smooth_branch(xyz): if xyz.shape[0] > 5: @@ -117,12 +124,30 @@ def smooth_end(branch_xyz, radii, ref_xyz, num_pts=8): return branch_xyz, radii, None +# Image feature extraction +def get_profile( + img, xyz_arr, anisotropy=[1.0, 1.0, 1.0], window_size=[4, 4, 4] +): + xyz_arr = get_coords(xyz_arr, anisotropy=anisotropy) + profile = [] + for xyz in xyz_arr: + img_chunk = utils.read_img_chunk(img, xyz, window_size) + profile.append(np.max(img_chunk)) + return profile + + +def get_coords(xyz_arr, anisotropy=[1.0, 1.0, 1.0]): + for i in range(3): + xyz_arr[:, i] = xyz_arr[:, i] / anisotropy[i] + return xyz_arr.astype(int) + + # Miscellaneous def compare_edges(xyx_i, xyz_j, xyz_k): dist_ij = dist(xyx_i, xyz_j) dist_ik = dist(xyx_i, xyz_k) return dist_ij < dist_ik - + def dist(x, y, metric="l2"): """ @@ -141,6 +166,7 @@ def dist(x, y, metric="l2"): else: return np.linalg.norm(np.subtract(x, y), ord=2) + def make_line(xyz_1, xyz_2, num_steps): t_steps = np.linspace(0, 1, num_steps) - return [(1 - t) * xyz_1 + t * xyz_2 for t in t_steps] + return np.array([(1 - t) * xyz_1 + t * xyz_2 for t in t_steps]) diff --git a/src/deep_neurographs/graph_utils.py b/src/deep_neurographs/graph_utils.py index 44debb9..1f9a236 100644 --- a/src/deep_neurographs/graph_utils.py +++ b/src/deep_neurographs/graph_utils.py @@ -9,11 +9,7 @@ """ -from copy import deepcopy as cp - -import os import networkx as nx -import numpy as np from deep_neurographs import swc_utils, utils @@ -32,7 +28,6 @@ def get_irreducibles(graph): def extract_irreducible_graph(swc_dict, prune=True, prune_depth=16): graph = swc_utils.file_to_graph(swc_dict) leafs, junctions = get_irreducibles(graph) - irreducible_nodes = set(leafs + junctions) irreducible_edges, leafs = extract_irreducible_edges( graph, leafs, junctions, swc_dict, prune=prune, prune_depth=prune_depth ) @@ -148,7 +143,7 @@ def get_edge_attr(graph, edge, attr): edge_data = graph.get_edge_data(*edge) return edge_data[attr] + def is_leaf(graph, i): nbs = [j for j in graph.neighbors(i)] return True if len(nbs) == 1 else False - \ No newline at end of file diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index 9b3edbc..f1381b4 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -10,12 +10,11 @@ import os -import numpy as np import torch from torch_geometric.data import Data from deep_neurographs import neurograph as ng -from deep_neurographs import geometry_utils, s3_utils, swc_utils, utils +from deep_neurographs import s3_utils, swc_utils, utils # --- Build graph --- @@ -59,8 +58,8 @@ def build_neurograph( prune_depth=prune_depth, ) neurograph.generate_mutables( - max_degree=max_mutable_degree, max_dist=max_mutable_dist - ) + max_degree=max_mutable_degree, max_dist=max_mutable_dist + ) return neurograph @@ -132,48 +131,14 @@ def init_data( x = torch.tensor(node_features, dtype=torch.float) edge_index = torch.tensor(list(supergraph.edges()), dtype=torch.long) edge_features = torch.tensor(edge_features, dtype=torch.float) - edge_label_index, mistake_log = get_target_edges( - supergraph, - edge_index.tolist(), - bucket, - file_key, - access_key_id=access_key_id, - secret_access_key=secret_access_key, - ) + edge_label_index = None # target labels data = Data( x=x, edge_index=edge_index.t().contiguous(), edge_label_index=edge_label_index, edge_attr=edge_features, ) - return data, mistake_log - - -def get_target_edges( - supergraph, - edges, - bucket, - file_key, - access_key_id=None, - secret_access_key=None, -): - """ - To do... - """ - s3_client = s3_utils.init_session( - access_key_id=access_key_id, secret_access_key=secret_access_key - ) - hash_table = read_mistake_log(bucket, file_key, s3_client) - target_edges = torch.zeros((len(edges))) - cnt = 0 - for i, e in enumerate(edges): - e1, e2 = get_old_edge(supergraph, e) - if utils.check_key(hash_table, e1) or utils.check_key(hash_table, e2): - target_edges[i] = 1 - cnt += 1 - print("Number of mistakes:", len(hash_table)) - print("Number of hits:", cnt) - return torch.tensor(target_edges), hash_table + return data def read_mistake_log(bucket, file_key, s3_client): diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index 412c8f3..4b732d3 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -8,24 +8,16 @@ """ -import copy -import matplotlib.colors as mcolors -import matplotlib.pyplot as plt import networkx as nx import numpy as np import plotly.graph_objects as go -import plotly.tools as tls import tensorstore as ts -from more_itertools import zip_broadcast -from plotly.subplots import make_subplots from scipy.spatial import KDTree -from tifffile import imwrite +from deep_neurographs import geometry_utils from deep_neurographs import graph_utils as gutils -from deep_neurographs import geometry_utils, swc_utils, utils +from deep_neurographs import utils -COLORS = list(mcolors.TABLEAU_COLORS.keys()) -nCOLORS = len(COLORS) SUPPORTED_LABEL_MASK_TYPES = [dict, np.array, ts.TensorStore] @@ -110,15 +102,9 @@ def generate_immutables( # Add edge self.immutable_edges.add(frozenset(edge)) self.add_edge( - node_id[i], - node_id[j], - xyz=xyz, - radius=radii, - swc_id=swc_id, - ) - xyz_to_edge = dict( - (tuple(xyz), edge) for xyz in xyz + node_id[i], node_id[j], xyz=xyz, radius=radii, swc_id=swc_id ) + xyz_to_edge = dict((tuple(xyz), edge) for xyz in xyz) check_xyz = set(xyz_to_edge.keys()) collisions = check_xyz.intersection(set(self.xyz_to_edge.keys())) if len(collisions) > 0: @@ -132,13 +118,13 @@ def generate_immutables( for j in junctions: self.junctions.add(node_id[j]) - + def init_immutable_graph(self): immutable_graph = nx.Graph() immutable_graph.add_nodes_from(self) immutable_graph.add_edges_from(self.immutable_edges) return immutable_graph - + def generate_mutables(self, max_degree=5, max_dist=50.0): """ Generates edges for the graph. @@ -147,7 +133,7 @@ def generate_mutables(self, max_degree=5, max_dist=50.0): ------- None - """ + """ # Search for mutable connections self.mutable_edges = set() self._init_kdtree() @@ -178,9 +164,6 @@ def generate_mutables(self, max_degree=5, max_dist=50.0): self.add_edge(leaf, node, xyz=np.array([xyz_leaf, xyz])) self.mutable_edges.add(frozenset((leaf, node))) - if len(self.mutable_edges) > 0: - print("# proposed edges:", len(self.mutable_edges)) - def _get_mutables(self, query_id, query_xyz, max_degree, max_dist): """ Parameters @@ -299,16 +282,21 @@ def _query_kdtree(self, query, dist): """ idxs = self.kdtree.query_ball_point(query, dist, return_sorted=True) return self.kdtree.data[idxs] - + def init_targets(self, target_neurograph, target_densegraph): + self.target_edges = set() self.groundtruth_graph = self.init_immutable_graph() + predicted_graph = self.init_immutable_graph() complex_mutables = [] site_to_site = dict() pair_to_edge = dict() - for edge in self.mutable_edges: + + mutable_edges = list(self.mutable_edges) + dists = [self.compute_length(edge) for edge in mutable_edges] + for idx in np.argsort(dists): # Get projection - i, j = tuple(edge) + i, j = tuple(mutable_edges[idx]) xyz_i = self.nodes[i]["xyz"] xyz_j = self.nodes[j]["xyz"] proj_xyz_i, d_i = target_neurograph.get_projection(xyz_i) @@ -319,12 +307,11 @@ def init_targets(self, target_neurograph, target_densegraph): # Get corresponding edges on target edge_i = target_neurograph.xyz_to_edge[proj_xyz_i] edge_j = target_neurograph.xyz_to_edge[proj_xyz_j] - - + # Check whether complex - if edge_i != edge_j: + if edge_i != edge_j: if target_neurograph.is_adjacent(edge_i, edge_j): - complex_mutables.append(edge) + complex_mutables.append(mutable_edges[idx]) else: # Simple criteria inclusion_i = proj_xyz_i in site_to_site.keys() @@ -339,7 +326,7 @@ def init_targets(self, target_neurograph, target_densegraph): pair_to_edge, proj_xyz_i, proj_xyz_j, - edge, + mutable_edges[idx], ) else: # Get projected points @@ -350,23 +337,22 @@ def init_targets(self, target_neurograph, target_densegraph): proj_xyz_k = site_to_site[proj_xyz_i] # Compare edge - if geometry_utils.compare_edges(proj_xyz_i, proj_xyz_j, proj_xyz_k): + if geometry_utils.compare_edges( + proj_xyz_i, proj_xyz_j, proj_xyz_k + ): site_to_site, pair_to_edge = self.remove_site( - site_to_site, - pair_to_edge, - proj_xyz_i, - proj_xyz_k, + site_to_site, pair_to_edge, proj_xyz_i, proj_xyz_k ) site_to_site, pair_to_edge = self.add_site( site_to_site, pair_to_edge, proj_xyz_i, proj_xyz_j, - edge, + mutable_edges[idx], ) # Filter - print("# complex proposed edges", len(complex_mutables)) + # print("# complex proposed edges", len(complex_mutables)) filtered_proposals = [] for edge in complex_mutables: # Check whether edge creates a cycle @@ -385,23 +371,23 @@ def init_targets(self, target_neurograph, target_densegraph): filtered_proposals.append(edge) # Parse filtered proposals - print("# filtered proposed edges", len(filtered_proposals)) + # print("# filtered proposed edges", len(filtered_proposals)) dists = [self.compute_length(edge) for edge in filtered_proposals] for idx in np.argsort(dists): edge = filtered_proposals[idx] if self.check_cycle(tuple(edge)): continue - + add_bool = True if add_bool: site_to_site, pair_to_edge = self.add_site( - site_to_site, pair_to_edge, proj_xyz_i, proj_xyz_j, edge - ) + site_to_site, pair_to_edge, proj_xyz_i, proj_xyz_j, edge + ) # Print results - target_ratio = len(self.target_edges) / len(self.mutable_edges) - print("# target edges:", len(self.target_edges)) - print("% target edges in mutable:", target_ratio) + # target_ratio = len(self.target_edges) / len(self.mutable_edges) + # print("# target edges:", len(self.target_edges)) + # print("% target edges in mutable:", target_ratio) def add_site(self, site_to_site, pair_to_edge, xyz_i, xyz_j, edge): # Check whether cycle is created @@ -609,7 +595,7 @@ def is_adjacent(self, edge_i, edge_j): return True else: return False - + def is_nb(self, i, j): return True if i in self.neighbors(j) else False @@ -621,10 +607,10 @@ def is_contained(self, xyz): if lower_bool or upper_bool: return False return True - + def is_leaf(self, i): return True if len(self.neighbors(i)) == 1 else False - + def check_cycle(self, edge): self.groundtruth_graph.add_edges_from([edge]) try: @@ -657,7 +643,7 @@ def to_line_graph(self): graph.add_nodes_from(self.nodes) graph.add_edges_from(self.edges) return nx.line_graph(graph) - + # Check whether to trim end of branch """ diff --git a/src/deep_neurographs/swc_utils.py b/src/deep_neurographs/swc_utils.py index 04a8dca..59be19c 100644 --- a/src/deep_neurographs/swc_utils.py +++ b/src/deep_neurographs/swc_utils.py @@ -10,14 +10,15 @@ """ import os -import random from copy import deepcopy as cp import networkx as nx import numpy as np from more_itertools import zip_broadcast -from deep_neurographs import geometry_utils, graph_utils as gutils, utils +from deep_neurographs import geometry_utils +from deep_neurographs import graph_utils as gutils +from deep_neurographs import utils # -- io utils -- @@ -100,14 +101,14 @@ def read_xyz(xyz, anisotropy=[1.0, 1.0, 1.0], offset=[0, 0, 0]): def write_swc(path, contents): - if type(content) is list: + if type(contents) is list: write_list(path, contents) - elif type(content) is dict: + elif type(contents) is dict: write_dict(path, contents) - elif type(content) is nx.Graph: + elif type(contents) is nx.Graph: write_graph(path, contents) else: - assert True, "Unable to write {} to swc".format(type(content)) + assert True, "Unable to write {} to swc".format(type(contents)) def write_list(path, entry_list, color=None): @@ -116,9 +117,9 @@ def write_list(path, entry_list, color=None): Parameters ---------- - path_to_swc : str + path : str Path that swc will be written to. - list_of_entries : list[list[int]] + entry_list : list[list[int]] List of entries that will be written to an swc file. color : str, optional Color of nodes. The default is None. @@ -166,10 +167,10 @@ def write_graph(path, graph): Parameters ---------- + path : str + Path that swc will be written to. graph : networkx.Graph - Graph that edges in "edge_list" belong to. - edge_list : list[tuple[int]] - List of edges to be written to an swc file. + Graph to be written to swc file. Returns ------- @@ -180,12 +181,11 @@ def write_graph(path, graph): # loop through connected components reindex = dict() - edges = graph.edges if edge_list is None else edge_list - for i, j in edges: + for i, j in graph.edges: if len(reindex) < 1: - entry, reindex = make_entry(graph, i, -1, reindex, anisotropy) + entry, reindex = make_entry(graph, i, -1, reindex) entry_list = [entry] - entry, reindex = make_entry(graph, j, reindex[i], reindex, anisotropy) + entry, reindex = make_entry(graph, j, reindex[i], reindex) entry_list.append(entry) return entry_list @@ -202,6 +202,10 @@ def make_entry(graph, i, parent, r, reindex): Node that entry corresponds to. parent : int Parent of node "i". + r : float + Radius. + reindex : dict + Converts 'graph node id' to 'swc node id'. """ reindex[i] = len(reindex) + 1 diff --git a/src/deep_neurographs/utils.py b/src/deep_neurographs/utils.py index 0b082fc..d2269e8 100644 --- a/src/deep_neurographs/utils.py +++ b/src/deep_neurographs/utils.py @@ -21,10 +21,8 @@ # --- dictionary utils --- def remove_item(my_set, item): - try: + if item in my_set: my_set.remove(item) - except: - pass return my_set @@ -44,9 +42,9 @@ def check_key(my_dict, key): dict value or bool """ - try: + if key in my_dict.keys(): return my_dict[key] - except: + else: return False @@ -96,21 +94,20 @@ def list_subdirs(path, keyword=None): # --- io utils --- -def read_n5(path): - """ - Reads n5 file at "path". +def open_zarr(path): + n5store = zarr.N5FSStore(path, "r") + if "653980" in path: + return zarr.open(n5store).ch488.s0 + elif "653158" in path: + return zarr.open(n5store).s0 - Parameters - ---------- - path : str - Path to n5. - Returns - ------- - np.array - Image volume. - """ - return zarr.open(zarr.N5FSStore(path), "r").volume +def read_img_chunk(img, xyz, shape): + return img[ + xyz[2] - shape[2] // 2 : xyz[2] + shape[2] // 2, + xyz[1] - shape[1] // 2 : xyz[1] + shape[1] // 2, + xyz[0] - shape[0] // 2 : xyz[0] + shape[0] // 2, + ] def read_json(path): @@ -238,44 +235,6 @@ def subplot(data1, data2, title): # --- miscellaneous --- -def pair_dist(pair_1, pair_2, metric="l2"): - pair_1.reverse() - d1 = _pair_dist(pair_1, pair_2) - - pair_1.reverse() - d2 = _pair_dist(pair_1, pair_2) - return min(d1, d2) - - -def _pair_dist(pair_1, pair_2, metric="l2"): - d1 = dist(pair_1[0], pair_2[0], metric=metric) - d2 = dist(pair_1[1], pair_2[1], metric=metric) - return max(d1, d2) - - -def check_img_path(target_labels, xyz_1, xyz_2): - d = dist(xyz_1, xyz_2) - t_steps = np.arange(0, 1, 1 / d) - num_steps = len(t_steps) - labels = set() - collisions = set() - for t in t_steps: - xyz = tuple([int(line(xyz_1[i], xyz_2[i], t)) for i in range(3)]) - if target_labels[xyz] != 0: - # Check for repeat collisions - if xyz in collisions: - num_steps -= 1 - else: - collisions.add(xyz) - - # Check for collision with multiple labels - labels.add(target_labels[xyz]) - if len(labels) > 1: - return False - ratio = len(collisions) / len(t_steps) - return True if ratio > 1 / 3 else False - - def to_world(xyz, anisotropy, shift=[0, 0, 0]): return tuple([int((xyz[i] - shift[i]) * anisotropy[i]) for i in range(3)]) diff --git a/tests/test_build_neurograph.py b/tests/test_build_neurograph.py index deb663e..7b40106 100644 --- a/tests/test_build_neurograph.py +++ b/tests/test_build_neurograph.py @@ -8,20 +8,13 @@ """ -import networkx as nx -import torch -import torch_geometric.transforms as T - -from deep_neurographs import feature_extraction as extracter -from deep_neurographs import intake, net -from deep_neurographs import neurograph as ng -from deep_neurographs import train +from deep_neurographs import intake if __name__ == "__main__": # Parameters max_mutable_degree = 5 - max_mutable_dist = 100.0 + max_mutable_dist = 50.0 prune = True prune_depth = 16 diff --git a/tests/test_example.py b/tests/test_example.py deleted file mode 100644 index 06e9e0d..0000000 --- a/tests/test_example.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Example test template.""" - -import unittest - - -class ExampleTest(unittest.TestCase): - """Example Test Class""" - - def test_assert_example(self): - """Example of how to test the truth of a statement.""" - - self.assertTrue(1 == 1) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_train_pipeline.py b/tests/test_train_pipeline.py deleted file mode 100644 index 26ade66..0000000 --- a/tests/test_train_pipeline.py +++ /dev/null @@ -1,109 +0,0 @@ -""" -Created on Sat July 15 9:00:00 2023 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Tests routines that build graph and generate features - -""" - -import os - -import networkx as nx -import torch -import torch_geometric.transforms as T - -from deep_neurographs import feature_extraction as extracter -from deep_neurographs import intake, net -from deep_neurographs import neurograph as ng -from deep_neurographs import train, utils - - -def train(): - # Data loader - graph = build_neurographs() - # node_features, edge_features = generate_features() - - # Cross validation - - -def build_neurographs(): - print("Building NeuroGraphs...") - graphs = [] - for block_id in utils.listsubdirs(root_dir, keyword="block"): - print(" " + block_id) - swc_dir = os.path.join(root_dir, block_id) - graphs.append(intake.build_neurograph(swc_dir)) - return graphs - - -def generate_features(): - pass - - -if __name__ == "__main__": - # Paramaters - anisotropy = [1.0, 1.0, 1.0] - dataset = "653158" - pred_id = "20230801_2steps_segmentation_filtered" - root_dir = f"/home/jupyter/workspace/data/{dataset}/pred_swcs/{pred_id}" - whole_brain = False - - # Main - train() - - """ - # Feature extraction - node_features = extracter.generate_node_features( - supergraph, img=False, pointcloud=False - ) - edge_features = extracter.generate_edge_features(supergraph) - print("Generated node and edge features...") - print("Number of node features:", node_features.shape[1]) - print("Number of edge features:", edge_features.shape[1]) - print("") - - - # Initialize training data - data, mistake_log = intake.init_data( - supergraph, node_features, edge_features, bucket, mistake_log_path, - ) - - - # Training parameters - num_feats = node_features.shape[1] - model = net.GCN(num_feats, num_feats // 2, num_feats // 4).to(device) - optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01) - criterion = torch.nn.BCEWithLogitsLoss() - transform = T.Compose([ - T.NormalizeFeatures(), - T.ToDevice(device), - T.RandomLinkSplit( - num_val=0.05, - num_test=0.1, - is_undirected=True, - add_negative_train_samples=False, - ), - ]) - - - # Train - print("Training...") - train_data, val_data, test_data = transform(data) - best_val_auc = final_test_auc = 0 - for epoch in range(1, 101): - loss = train.train(model, optimizer, criterion, train_data) - val_auc = train.test(model, val_data) - test_auc = train.test(model, test_data) - if val_auc > best_val_auc: - best_val_auc = val_auc - final_test_auc = test_auc - print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, ' - f'Test: {test_auc:.4f}') - - print(f'Final Test: {final_test_auc:.4f}') - - z = model.encode(test_data.x, test_data.edge_index) - final_edge_index = model.decode_all(z) - """