diff --git a/src/deep_neurographs/deep_learning/datasets.py b/src/deep_neurographs/deep_learning/datasets.py index f3bbdf0..dd818b8 100644 --- a/src/deep_neurographs/deep_learning/datasets.py +++ b/src/deep_neurographs/deep_learning/datasets.py @@ -12,8 +12,6 @@ import torchio as tio from torch.utils.data import Dataset -from deep_neurographs import utils - # Custom datasets class ProposalDataset(Dataset): @@ -250,7 +248,7 @@ def __init__(self): tio.RandomFlip(axes=(0, 1, 2)), tio.RandomAffine( degrees=20, scales=(0.8, 1), image_interpolation="nearest" - ) + ), ] ) diff --git a/src/deep_neurographs/deep_learning/train.py b/src/deep_neurographs/deep_learning/train.py index fba2b8d..4ec9725 100644 --- a/src/deep_neurographs/deep_learning/train.py +++ b/src/deep_neurographs/deep_learning/train.py @@ -19,10 +19,9 @@ from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.profilers import PyTorchProfiler from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier -from torch.utils.data import DataLoader from torch.nn.functional import sigmoid +from torch.utils.data import DataLoader from torcheval.metrics.functional import ( - binary_accuracy, binary_f1_score, binary_precision, binary_recall, @@ -113,9 +112,7 @@ def train_network( # Configure trainer model = LitNeuralNet(net=net, lr=lr) - ckpt_callback = ModelCheckpoint( - save_top_k=1, monitor="val_f1", mode="max" - ) + ckpt_callback = ModelCheckpoint(save_top_k=1, monitor="val_f1", mode="max") profiler = PyTorchProfiler() if profile else None # Fit model @@ -131,7 +128,7 @@ def train_network( profiler=profiler, ) trainer.fit(model, train_loader, valid_loader) - + # Return best model ckpt = torch.load(ckpt_callback.best_model_path) model.net.load_state_dict(ckpt["state_dict"]) @@ -158,7 +155,7 @@ def __init__(self, net=None, lr=10e-3): super().__init__() self.net = net self.lr = lr - + def forward(self, batch): x = self.get_example(batch, "inputs") return self.net(x) @@ -192,5 +189,5 
@@ def compute_stats(self, y_hat, y, prefix=""): def get_example(self, batch, key): return batch[key] - def state_dict(self, destination=None, prefix='', keep_vars=False): - return self.net.state_dict(destination, prefix + '', keep_vars) + def state_dict(self, destination=None, prefix="", keep_vars=False): + return self.net.state_dict(destination, prefix + "", keep_vars) diff --git a/src/deep_neurographs/densegraph.py b/src/deep_neurographs/densegraph.py index c064cbe..eaa4f8f 100644 --- a/src/deep_neurographs/densegraph.py +++ b/src/deep_neurographs/densegraph.py @@ -4,35 +4,68 @@ @author: Anna Grim @email: anna.grim@alleninstitute.org -Class of graphs that are built from swc files. +Class of graphs built from swc files. Each swc file is stored as a distinct +graph and each node in this graph. """ import os import networkx as nx -import numpy as np from more_itertools import zip_broadcast from scipy.spatial import KDTree from deep_neurographs import swc_utils, utils -from deep_neurographs.geometry_utils import dist +from deep_neurographs.geometry_utils import dist as get_dist class DenseGraph: + """ + Class of graphs built from swc files. Each swc file is stored as a + distinct graph and each node in this graph. + + """ + def __init__(self, swc_dir): + """ + Constructs a DenseGraph object from a directory of swc files. + + Parameters + ---------- + swc_dir : path + Path to directory of swc files which are used to construct a hash + table in which the entries are filename-graph pairs. + + Returns + ------- + None + + """ self.xyz_to_node = dict() self.xyz_to_swc = dict() self.init_graphs(swc_dir) self.init_kdtree() def init_graphs(self, swc_dir): + """ + Initializes graphs by reading swc files in "swc_dir". Graphs are + stored in a hash table where the entries are filename-graph pairs. + + Parameters + ---------- + swc_dir : path + Path to directory of swc files which are used to construct a hash + table in which the entries are filename-graph pairs. 
+ + Returns + ------- + None + + """ self.graphs = dict() for f in utils.listdir(swc_dir, ext=".swc"): - # Extract info - path = os.path.join(swc_dir, f) - # Construct Graph + path = os.path.join(swc_dir, f) swc_dict = swc_utils.parse(swc_utils.read_swc(path)) graph, xyz_to_node = swc_utils.file_to_graph( swc_dict, set_attrs=True, return_dict=True @@ -45,49 +78,147 @@ def init_graphs(self, swc_dir): self.xyz_to_swc.update(xyz_to_id) def init_kdtree(self): + """ + Initializes KDTree from all xyz coordinates contained in all + graphs in "self.graphs". + + Parameters + ---------- + None + + Returns + ------- + None + + """ self.kdtree = KDTree(list(self.xyz_to_swc.keys())) - def get_projection(self, xyz): - _, idx = self.kdtree.query(xyz, k=1) - proj_xyz = tuple(self.kdtree.data[idx]) - proj_dist = dist(proj_xyz, xyz) - return proj_xyz, proj_dist + def query_kdtree(self, xyz): + """ + Queries "self.kdtree" for the nearest neighbor of "xyz". + + Parameters + ---------- + xyz : tuple[float] + Coordinate to be queried. - def connect_nodes(self, graph_id, xyz_i, xyz_j, return_dist=True): - i = self.xyz_to_node[graph_id][xyz_i] - j = self.xyz_to_node[graph_id][xyz_j] + Returns + ------- + tuple[float] + Result of query. + + """ + _, idx = self.kdtree.query(xyz, k=1) + return tuple(self.kdtree.data[idx]) + + def is_connected(self, xyz_1, xyz_2): + """ + Determines whether the points "xyz_1" and "xyz_2" belong to the same + swc file (i.e. graph). + + Parameters + ---------- + xyz_1 : tuple[float] + Coordinate contained in some graph in "self.graph". + xyz_2 : tuple[float] + Coordinate contained in some graph in "self.graph". + + Returns + ------- + bool + Indication of whether "xyz_1" and "xyz_2" belong to the same swc + file (i.e. graph). 
+ + """ + swc_identical = self.xyz_to_swc[xyz_1] == self.xyz_to_swc[xyz_2] + return True if swc_identical else False + + def connect_nodes(self, xyz_1, xyz_2): + """ + Finds path connecting two points that belong to some graph in + "self.graphs". + + Parameters + ---------- + xyz_1 : tuple[float] + Source of path. + xyz_2 : tuple[float] + Target of path. + + Returns + ------- + list[int] + Path of nodes connecting source and target. + float + Length of path with respect to l2-metric. + + """ + graph_id = self.xyz_to_swc[xyz_1] + i = self.xyz_to_node[graph_id][xyz_1] + j = self.xyz_to_node[graph_id][xyz_2] path = nx.shortest_path(self.graphs[graph_id], source=i, target=j) - if return_dist: - dist = self.compute_dist(graph_id, path) - return path, dist - else: - return path - - def compute_dist(self, graph_id, path): - d = 0 + return path, self.path_length(graph_id, path) + + def path_length(self, graph_id, path): + """ + Computes length of path with respect to the l2-metric. + + Parameters + ---------- + graph_id : str + ID of graph that path belongs to. + path : list[int] + List of nodes that form a path. + Returns + ------- + float + Length of path with respect to l2-metric.
+ + """ + path_length = 0 for i in range(1, len(path)): xyz_1 = self.graphs[graph_id].nodes[i]["xyz"] xyz_2 = self.graphs[graph_id].nodes[i - 1]["xyz"] - d += dist(xyz_1, xyz_2) - return d - - def check_aligned(self, pred_xyz_i, pred_xyz_j): - # Get target graph - xyz_i, _ = self.get_projection(pred_xyz_i) - xyz_j, _ = self.get_projection(pred_xyz_j) - graph_id = self.xyz_to_swc[xyz_i] - if self.xyz_to_swc[xyz_i] != self.xyz_to_swc[xyz_j]: - return False - - # Compute distances - pred_xyz_i = np.array(pred_xyz_i) - pred_xyz_j = np.array(pred_xyz_j) - pred_dist = dist(pred_xyz_i, pred_xyz_j) - - # Check criteria - target_path, target_dist = self.connect_nodes(graph_id, xyz_i, xyz_j) - ratio = min(pred_dist, target_dist) / max(pred_dist, target_dist) - if ratio < 0.5 and pred_dist > 10: - return False - else: - return True + path_length += get_dist(xyz_1, xyz_2) + return path_length + + def is_aligned(self, xyz_1, xyz_2, ratio_threshold=0.5, exclude=10.0): + """ + Determines whether the edge proposal corresponding to "xyz_1" and + "xyz_2" is aligned to the ground truth. This is determined by checking + two conditions: (1) connectedness and (2) distance ratio. For (1), we + project "xyz_1" and "xyz_2" onto "self.graphs", then verify that they + project to the same graph. For (2), we compute the ratio between the + Euclidean distance "dist" from "xyz_1" to "xyz_2" and the path length + between the corresponding projections. This ratio can be skewed if + "dist" is small, so we skip this criterion if "dist" <= "exclude". + + Parameters + ---------- + xyz_1 : numpy.array + Endpoint of edge proposal. + xyz_2 : numpy.array + Endpoint of edge proposal. + ratio_threshold : float + Lower bound on threshold used to compare similarity between "dist" + and "path length". + exclude : float + Upper bound on threshold to ignore criterion 2. + + Returns + ------- + bool + Indication of whether edge proposal is aligned to ground truth.
+ + """ + hat_xyz_1 = self.query_kdtree(xyz_1) + hat_xyz_2 = self.query_kdtree(xyz_2) + if self.is_connected(hat_xyz_1, hat_xyz_2): + dist = get_dist(xyz_1, xyz_2) + _, path_length = self.connect_nodes(hat_xyz_1, hat_xyz_2) + ratio = min(dist, path_length) / max(dist, path_length) + if ratio > ratio_threshold and dist > exclude: + return True + elif dist <= exclude: + return True + return False diff --git a/src/deep_neurographs/feature_extraction.py b/src/deep_neurographs/feature_extraction.py index 8040d0a..2108abe 100644 --- a/src/deep_neurographs/feature_extraction.py +++ b/src/deep_neurographs/feature_extraction.py @@ -72,8 +72,7 @@ def generate_mutable_img_chunk_features( img, labels = utils.get_superchunks( img_path, labels_path, origin, neurograph.shape, from_center=False ) - - #img = utils.normalize_img(img) + img = utils.normalize_img(img) for edge in neurograph.mutable_edges: # Compute image coordinates i, j = tuple(edge) @@ -88,7 +87,7 @@ def generate_mutable_img_chunk_features( # Mark path d = int(geometry_utils.dist(xyz_i, xyz_j) + 5) img_coords_i = np.round(xyz_i - midpoint + HALF_CHUNK_SIZE).astype(int) - img_coords_j = np.round(xyz_j - midpoint + HALF_CHUNK_SIZE).astype(int) + img_coords_j = np.round(xyz_j - midpoint + HALF_CHUNK_SIZE).astype(int) path = geometry_utils.make_line(img_coords_i, img_coords_j, d) img_chunk = utils.normalize_img(img_chunk) diff --git a/src/deep_neurographs/geometry_utils.py b/src/deep_neurographs/geometry_utils.py index c8fda30..95ffa71 100644 --- a/src/deep_neurographs/geometry_utils.py +++ b/src/deep_neurographs/geometry_utils.py @@ -1,10 +1,10 @@ import heapq -import networkx as nx + import numpy as np from scipy.interpolate import UnivariateSpline from scipy.linalg import svd -from deep_neurographs import utils, feature_extraction as extracter +from deep_neurographs import utils # Context Tangent Vectors @@ -118,19 +118,6 @@ def fit_spline(xyz): return cs_x, cs_y, cs_z -""" -def smooth_end(branch_xyz, radii, ref_xyz, 
num_pts=8): - smooth_bool = branch_xyz.shape[0] > 10 - if all(branch_xyz[0] == ref_xyz) and smooth_bool: - return branch_xyz[num_pts:-1, :], radii[num_pts:-1], 0 - elif all(branch_xyz[-1] == ref_xyz) and smooth_bool: - branch_xyz = branch_xyz[:-num_pts] - radii = radii[:-num_pts] - return branch_xyz, radii, -1 - else: - return branch_xyz, radii, None -""" - # Image feature extraction def get_profile(img, xyz_arr, window_size=[5, 5, 5]): return [np.max(utils.get_chunk(img, xyz, window_size)) for xyz in xyz_arr] @@ -139,19 +126,35 @@ def get_profile(img, xyz_arr, window_size=[5, 5, 5]): def fill_path(img, path, val=-1): for xyz in path: x, y, z = tuple(np.floor(xyz).astype(int)) - #img[x - 1 : x + 2, y - 1 : y + 2, z - 1 : z + 2] = val - img[x,y,z] = val + # img[x - 1 : x + 2, y - 1 : y + 2, z - 1 : z + 2] = val + img[x, y, z] = val return img # Miscellaneous def shortest_path(img, start, end): def is_valid_move(x, y, z): - return 0 <= x < shape[0] and 0 <= y < shape[1] and 0 <= z < shape[2] and not visited[x, y, z] + return ( + 0 <= x < shape[0] + and 0 <= y < shape[1] + and 0 <= z < shape[2] + and not visited[x, y, z] + ) def get_nbs(x, y, z): - moves = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)] - return [(x + dx, y + dy, z + dz) for dx, dy, dz in moves if is_valid_move(x + dx, y + dy, z + dz)] + moves = [ + (1, 0, 0), + (-1, 0, 0), + (0, 1, 0), + (0, -1, 0), + (0, 0, 1), + (0, 0, -1), + ] + return [ + (x + dx, y + dy, z + dz) + for dx, dy, dz in moves + if is_valid_move(x + dx, y + dy, z + dz) + ] img = img - np.min(img) + 1 start = tuple(start) diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index 9185aae..72474b4 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -10,7 +10,7 @@ import os -from deep_neurographs import s3_utils, swc_utils, utils +from deep_neurographs import swc_utils, utils from deep_neurographs.neurograph import NeuroGraph @@ -46,7 +46,7 @@ def 
build_neurograph( anisotropy=anisotropy, prune=prune, prune_depth=prune_depth, - smooth=smooth + smooth=smooth, ) if search_radius > 0: neurograph.generate_proposals( diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index 828ffa1..0dd358b 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -8,13 +8,14 @@ """ +from copy import deepcopy +from time import time + import networkx as nx import numpy as np import plotly.graph_objects as go import tensorstore as ts -from copy import deepcopy from scipy.spatial import KDTree -from time import time from deep_neurographs import geometry_utils from deep_neurographs import graph_utils as gutils @@ -22,7 +23,7 @@ from deep_neurographs.densegraph import DenseGraph from deep_neurographs.geometry_utils import dist as get_dist -BUFFER = 5 +BUFFER = 8 SUPPORTED_LABEL_MASK_TYPES = [dict, np.array, ts.TensorStore] @@ -34,7 +35,16 @@ class NeuroGraph(nx.Graph): """ - def __init__(self, swc_path, img_path=None, label_mask=None, optimize_proposals=False, origin=None, shape=None): + def __init__( + self, + swc_path, + img_path=None, + label_mask=None, + optimize_depth=5, + optimize_proposals=False, + origin=None, + shape=None, + ): """ Parameters ---------- @@ -57,7 +67,10 @@ def __init__(self, swc_path, img_path=None, label_mask=None, optimize_proposals= self.xyz_to_edge = dict() self.img_path = img_path + self.optimize_depth = optimize_depth self.optimize_proposals = optimize_proposals + self.simple_proposals = set() + self.complex_proposals = set() if origin and shape: self.bbox = { @@ -139,6 +152,7 @@ def init_immutable_graph(self): immutable_graph.add_edges_from(self.immutable_edges) return immutable_graph + # --- Proposal Generation --- def generate_proposals(self, num_proposals=3, search_radius=25.0): """ Generates edges for the graph. 
@@ -163,13 +177,12 @@ def generate_proposals(self, num_proposals=3, search_radius=25.0): # Get connecting node contained_j = self.is_contained(j) - if get_dist(xyz, attrs["xyz"][0]) < 10 and self.is_contained(i): + if get_dist(xyz, attrs["xyz"][0]) < 10 and self.is_contained( + i + ): node = i xyz = self.nodes[node]["xyz"] - elif ( - get_dist(xyz, attrs["xyz"][-1]) < 10 - and contained_j - ): + elif get_dist(xyz, attrs["xyz"][-1]) < 10 and contained_j: node = j xyz = self.nodes[node]["xyz"] else: @@ -180,11 +193,14 @@ def generate_proposals(self, num_proposals=3, search_radius=25.0): self.add_edge(leaf, node, xyz=np.array([xyz_leaf, xyz])) self.mutable_edges.add(frozenset((leaf, node))) + self.simple_proposals = self.get_simple_proposals() + self.complex_proposals = self.get_complex_proposals() if self.optimize_proposals: self.run_optimization() - - def _get_proposals(self, query_id, query_xyz, num_proposals, search_radius): + def _get_proposals( + self, query_id, query_xyz, num_proposals, search_radius + ): """ Parameters ---------- @@ -293,7 +309,8 @@ def _query_kdtree(self, query, d): """ idxs = self.kdtree.query_ball_point(query, d, return_sorted=True) return self.kdtree.data[idxs] - + + # --- Optimize Proposals --- def run_optimization(self): t0 = time() origin = utils.apply_anisotropy(self.origin, return_int=True) @@ -302,28 +319,32 @@ def run_optimization(self): ) img = utils.normalize_img(img) simple_edges = self.get_simple_proposals() - complex_edges = self.get_complex_proposals() for edge in self.mutable_edges: if edge in simple_edges: self.optimize_simple_edge(img, edge) else: self.optimize_complex_edge(img, edge) print("") - print("edge_optimization(): {} seconds / edge".format((time() - t0) / len(self.get_simple_proposals()))) + print( + "edge_optimization(): {} seconds / edge".format( + (time() - t0) / len(self.get_simple_proposals()) + ) + ) def optimize_simple_edge(self, img, edge): # Extract Branches i, j = tuple(edge) - xyz_i = 
self.nodes[i]["xyz"] - xyz_j = self.nodes[j]["xyz"] - branch_i = self.get_branch(xyz_i) - branch_j = self.get_branch(xyz_j) - + branch_i = self.get_branch(self.nodes[i]["xyz"]) + branch_j = self.get_branch(self.nodes[j]["xyz"]) + depth = self.optimize_depth + # Get image patch - hat_xyz_i = self.to_img(branch_i[8]) - hat_xyz_j = self.to_img(branch_j[8]) + hat_xyz_i = self.to_img(branch_i[depth]) + hat_xyz_j = self.to_img(branch_j[depth]) patch_dims = geometry_utils.get_optimal_patch(hat_xyz_i, hat_xyz_j) - center = geometry_utils.compute_midpoint(hat_xyz_i, hat_xyz_j).astype(int) + center = geometry_utils.compute_midpoint(hat_xyz_i, hat_xyz_j).astype( + int + ) img_chunk = utils.get_chunk(img, center, patch_dims) # Optimize @@ -334,7 +355,9 @@ def optimize_simple_edge(self, img, edge): ) origin = utils.apply_anisotropy(self.origin, return_int=True) path = geometry_utils.transform_path(path, origin, center, patch_dims) - self.edges[edge]["xyz"] = np.vstack([branch_i[8], path, branch_j[8]]) + self.edges[edge]["xyz"] = np.vstack( + [branch_i[depth], path, branch_j[depth]] + ) def get_branch(self, xyz): edge = self.xyz_to_edge[tuple(xyz)] @@ -347,114 +370,100 @@ def get_branch(self, xyz): def optimize_complex_edge(self, img_superchunk, edge): pass + # --- Ground Truth Generation --- def init_targets(self, target_neurograph): + # Initializations self.target_edges = set() self.groundtruth_graph = self.init_immutable_graph() - target_densegraph = DenseGraph(target_neurograph.path) + target_neurograph.densegraph = DenseGraph(target_neurograph.path) - predicted_graph = self.init_immutable_graph() - site_to_site = dict() - pair_to_edge = dict() - - proposals = list(self.mutable_edges) + # Add best simple edges + remaining_proposals = [] + proposals = self.filter_infeasible(target_neurograph) dists = [self.compute_length(edge) for edge in proposals] for idx in np.argsort(dists): - # Check for cycle edge = proposals[idx] - i, j = tuple(edge) - if self.check_cycle((i, j)): - 
continue - - # Check projection - xyz_i = self.edges[edge]["xyz"][1] - xyz_j = self.edges[edge]["xyz"][-1] - proj_xyz_i, d_i = target_neurograph.get_projection(xyz_i) - proj_xyz_j, d_j = target_neurograph.get_projection(xyz_j) - if d_i > 5 or d_j > 5: - continue - - # Check cases - edge_i = target_neurograph.xyz_to_edge[proj_xyz_i] - edge_j = target_neurograph.xyz_to_edge[proj_xyz_j] - if edge_i != edge_j: - # Complex criteria - if not target_neurograph.is_adjacent(edge_i, edge_j): - continue - if not target_densegraph.check_aligned(xyz_i, xyz_j): + if self.is_simple(edge): + add_bool = self.is_target( + target_neurograph, edge, dist=2.5, ratio=0.75, exclude=6 + ) + if add_bool: + self.target_edges.add(edge) continue - else: - # Simple criteria - inclusion_i = proj_xyz_i in site_to_site.keys() - inclusion_j = proj_xyz_j in site_to_site.keys() - leaf_i = gutils.is_leaf(predicted_graph, i) - leaf_j = gutils.is_leaf(predicted_graph, j) - if not leaf_i or not leaf_j: - None - # continue - elif inclusion_i or inclusion_j: - if inclusion_j: - proj_xyz_i = proj_xyz_j - proj_xyz_k = site_to_site[proj_xyz_j] - else: - proj_xyz_k = site_to_site[proj_xyz_i] - - # Compare edge - exists = geometry_utils.compare_edges( - proj_xyz_i, proj_xyz_j, proj_xyz_k - ) - if exists: - site_to_site, pair_to_edge = self.remove_site( - site_to_site, pair_to_edge, proj_xyz_i, proj_xyz_k - ) - - # Add site - site_to_site, pair_to_edge = self.add_site( - site_to_site, - pair_to_edge, - proj_xyz_i, - proj_xyz_j, - proposals[idx], + remaining_proposals.append(edge) + + # Check remaining proposals + dists = [self.compute_length(edge) for edge in remaining_proposals] + for idx in np.argsort(dists): + edge = proposals[idx] + add_bool = self.is_target( + target_neurograph, edge, dist=5, ratio=0.5, exclude=10 ) + if add_bool: + self.target_edges.add(edge) # Print results - # target_ratio = len(self.target_edges) / len(self.mutable_edges) - # print("# target edges:", len(self.target_edges)) - # print("% 
target edges in mutable:", target_ratio) + target_ratio = len(self.target_edges) / len(self.mutable_edges) + print("") + print("# target edges:", len(self.target_edges)) + print("% target edges in mutable:", target_ratio) - def check_simple_criteria(self): - pass - - def check_complex_criteria(self): - pass + def filter_infeasible(self, target_neurograph): + proposals = list() + for edge in self.mutable_edges: + i, j = tuple(edge) + xyz_i = self.nodes[i]["xyz"] + xyz_j = self.nodes[j]["xyz"] + if target_neurograph.is_feasible(xyz_i, xyz_j): + proposals.append(edge) + return proposals + + def is_feasible(self, xyz_1, xyz_2): + # Check if edges are identical + edge_1 = self.xyz_to_edge[self.get_projection(xyz_1)[0]] + edge_2 = self.xyz_to_edge[self.get_projection(xyz_2)[0]] + if edge_1 == edge_2: + return True + + # Check if edges are adjacent + i, j = tuple(edge_1) + k, l = tuple(edge_2) + nb_bool_i = self.is_nb(i, k) or self.is_nb(i, l) + nb_bool_j = self.is_nb(j, k) or self.is_nb(j, l) + if nb_bool_i or nb_bool_j: + return True - def add_site(self, site_to_site, pair_to_edge, xyz_i, xyz_j, edge): - self.target_edges.add(edge) - site_to_site[xyz_i] = xyz_j - site_to_site[xyz_j] = xyz_i - pair_to_edge = self._add_pair_edge(pair_to_edge, xyz_i, xyz_j, edge) - return site_to_site, pair_to_edge - - def remove_site(self, site_to_site, pair_to_edge, xyz_i, xyz_j): - del site_to_site[xyz_i] - del site_to_site[xyz_j] - pair_to_edge = self._remove_pair_edge(pair_to_edge, xyz_i, xyz_j) - return site_to_site, pair_to_edge - - def _add_pair_edge(self, pair_to_edge, xyz_i, xyz_j, edge): - key = frozenset([xyz_i, xyz_j]) - if key not in pair_to_edge.keys(): - pair_to_edge[key] = set([edge]) + # Not feasible + return False + + def is_target( + self, target_graph, edge, dist=5, ratio=0.5, strict=True, exclude=10 + ): + # Check for cycle + i, j = tuple(edge) + if self.check_cycle((i, j)): + return False + + # Get branch + if self.optimize_proposals and self.is_simple(edge): + xyz_i 
= self.edges[edge]["xyz"][self.optimize_depth] + xyz_j = self.edges[edge]["xyz"][-self.optimize_depth] else: - pair_to_edge[key].add(edge) - return pair_to_edge + xyz_i = self.edges[edge]["xyz"][0] + xyz_j = self.edges[edge]["xyz"][-1] - def _remove_pair_edge(self, pair_to_edge, xyz_i, xyz_j): - key = frozenset([xyz_i, xyz_j]) - edges = list(pair_to_edge[key]) - if len(edges) == 1: - self.target_edges.remove(edges[0]) - del pair_to_edge[key] - return pair_to_edge + # Check projection distance + proj_i, d_i = target_graph.get_projection(xyz_i) + proj_j, d_j = target_graph.get_projection(xyz_j) + if d_i > dist or d_j > dist: + return False + + # Check alignment + aligned = target_graph.densegraph.is_aligned( + tuple(xyz_i), xyz_j, ratio_threshold=ratio, exclude=exclude + ) + if aligned: + return True # --- Visualization --- def visualize_immutables(self, title="Immutable Graph", return_data=False): @@ -504,10 +513,16 @@ def visualize_targets( else: utils.plot(data, title) - def visualize_subset(self, edges, title=""): + def visualize_subset(self, edges, target_graph=None, title=""): data = [self._plot_nodes()] data.extend(self._plot_edges(self.immutable_edges, color="black")) data.extend(self._plot_edges(edges)) + if target_graph is not None: + data.extend( + target_graph._plot_edges( + target_graph.immutable_edges, color="blue" + ) + ) utils.plot(data, title) def _plot_nodes(self): @@ -627,16 +642,6 @@ def get_projection(self, xyz): proj_dist = get_dist(proj_xyz, xyz) return proj_xyz, proj_dist - def is_adjacent(self, edge_i, edge_j): - i, j = tuple(edge_i) - k, l = tuple(edge_j) - nb_bool_i = self.is_nb(i, k) or self.is_nb(i, l) - nb_bool_j = self.is_nb(j, k) or self.is_nb(j, l) - if nb_bool_i or nb_bool_j: - return True - else: - return False - def is_nb(self, i, j): return True if i in self.neighbors(j) else False @@ -654,7 +659,7 @@ def is_contained(self, node_or_xyz): return True def is_leaf(self, i): - return True if len(self.neighbors(i)) == 1 else False + 
return True if self.immutable_degree(i) == 1 else False def check_cycle(self, edge): self.groundtruth_graph.add_edges_from([edge]) diff --git a/src/deep_neurographs/s3_utils.py b/src/deep_neurographs/s3_utils.py deleted file mode 100644 index 9a08e80..0000000 --- a/src/deep_neurographs/s3_utils.py +++ /dev/null @@ -1,58 +0,0 @@ -""" -Created on Sun July 16 14:00:00 2023 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Routines for reading and writing to an s3 bucket. - -""" - -import io -import json - -import boto3 - - -def init_session(access_key_id=None, secret_access_key=None): - if access_key_id is None or access_key_id is None: - session = boto3.Session() - else: - session = boto3.Session( - aws_access_key_id=access_key_id, - aws_secret_access_key=secret_access_key, - ) - return session.client("s3") - - -def listdir(bucket, dir_prefix, s3_client, ext=None): - response = s3_client.list_objects_v2(Bucket=bucket, Prefix=dir_prefix) - filenames = [] - for file in response["Contents"]: - file_key = file["Key"] - if ext is not None: - if ext in file_key: - filenames.append(file_key) - else: - filenames.append(file_key) - return filenames - - -def read_from_s3(bucket, file_key, s3_client): - if ".txt" in file_key or ".swc" in file_key: - return read_txt(bucket, file_key, s3_client) - elif ".json" in file_key: - return read_json(bucket, file_key, s3_client) - else: - assert False, "File type of {} is not supported".format(file_key) - - -def read_json(bucket, file_key, s3_client): - response = s3_client.get_object(Bucket=bucket, Key=file_key) - json_data = response["Body"].read().decode("utf-8") - return json.loads(json_data) - - -def read_txt(bucket, file_key, s3_client): - s3_object = s3_client.get_object(Bucket=bucket, Key=file_key) - return io.TextIOWrapper(s3_object["Body"]) diff --git a/src/deep_neurographs/utils.py b/src/deep_neurographs/utils.py index dd1b75e..668d031 100644 --- a/src/deep_neurographs/utils.py +++ b/src/deep_neurographs/utils.py 
@@ -12,6 +12,7 @@ import json import os import shutil +from copy import deepcopy import numpy as np import plotly.graph_objects as go @@ -154,7 +155,7 @@ def read_img_chunk(img, xyz, shape): def get_chunk(arr, xyz, shape): xyz_1 = [max(xyz[i] - shape[i] // 2, 0) for i in range(3)] xyz_2 = [min(xyz[i] + shape[i] // 2, arr.shape[i] - 1) for i in range(3)] - return arr[xyz_1[0] : xyz_2[0], xyz_1[1] : xyz_2[1], xyz_1[2] : xyz_2[2]] + return arr[xyz_1[0]: xyz_2[0], xyz_1[1]: xyz_2[1], xyz_1[2]: xyz_2[2]] def read_tensorstore(ts_arr, xyz, shape):