From d547554cfb4783535555f99516245d69fda00d86 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Wed, 27 Nov 2024 01:14:45 +0000 Subject: [PATCH] refactor: batch formation --- src/deep_neurographs/fragments_graph.py | 33 ++++++++ src/deep_neurographs/inference.py | 82 ++++++++++++-------- src/deep_neurographs/utils/gnn_util.py | 96 ++++++++++++++++++++---- src/deep_neurographs/utils/graph_util.py | 20 ++++- src/deep_neurographs/utils/img_util.py | 9 --- src/deep_neurographs/utils/swc_util.py | 2 +- 6 files changed, 184 insertions(+), 58 deletions(-) diff --git a/src/deep_neurographs/fragments_graph.py b/src/deep_neurographs/fragments_graph.py index 993f0a0..28e58e5 100644 --- a/src/deep_neurographs/fragments_graph.py +++ b/src/deep_neurographs/fragments_graph.py @@ -492,6 +492,39 @@ def list_proposals(self): """ return list(self.proposals) + def proposal_connected_component(self, proposal): + """ + Extracts the connected component that "proposal" belongs to in the + proposal induced subgraph. + + Parameters + ---------- + proposal : frozenset + Proposal used as the root to extract its connected component + in the proposal induced subgraph. + + Returns + ------- + Set[frozenset] + Set of proposals in the connected component that "proposal" + belongs to in the proposal induced subgraph. 
+ + """ + queue = [proposal] + visited = set() + while len(queue) > 0: + # Visit proposal + p = queue.pop() + visited.add(p) + + # Update queue + for i in p: + for j in self.nodes[i]["proposals"]: + p_ij = frozenset({i, j}) + if p_ij not in visited: + queue.append(p_ij) + return visited + # -- KDTree -- def init_kdtree(self, node_type): """ diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index e909427..280c9c2 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -132,6 +132,7 @@ def __init__( self.model_path, self.ml_config.model_type, self.graph_config.search_radius, + batch_size=self.ml_config.batch_size, confidence_threshold=self.ml_config.threshold, device=device, downsample_factor=self.ml_config.downsample_factor, @@ -535,62 +536,66 @@ def __init__( if self.is_gnn and "cuda" in device: self.model = self.model.to(self.device) - def run(self, neurograph, proposals): + def run(self, fragments_graph, proposals): """ Runs inference by forming batches of proposals, then performing the following steps for each batch: (1) generate features, (2) classify proposals by running model, and (3) adding each accepted proposal as - an edge to "neurograph" if it does not create a cycle. + an edge to "fragments_graph" if it does not create a cycle. Parameters ---------- - neurograph : NeuroGraph + fragments_graph : FragmentsGraph Graph that inference will be performed on. proposals : list Proposals to be classified as accept or reject. Returns ------- - NeuroGraph + FragmentsGraph Updated graph with accepted proposals added as edges. list Accepted proposals. """ # Initializations - assert not gutil.cycle_exists(neurograph), "Graph contains cycle!" + assert not gutil.cycle_exists(fragments_graph), "Graph has cycle!" 
if self.is_gnn: proposals = set(proposals) else: - proposals = sort_proposals(neurograph, proposals) + proposals = sort_proposals(fragments_graph, proposals) # Main + flagged = get_large_proposal_components(fragments_graph, 4) with tqdm(total=len(proposals), desc="Inference") as pbar: accepts = list() while len(proposals) > 0: # Predict - batch = self.get_batch(neurograph, proposals) - dataset = self.get_batch_dataset(neurograph, batch) + batch = self.get_batch(fragments_graph, proposals, flagged) + dataset = self.get_batch_dataset(fragments_graph, batch) preds = self.predict(dataset) # Update graph - for p in get_accepts(neurograph, preds, self.threshold): - neurograph.merge_proposal(p) + for p in get_accepts(fragments_graph, preds, self.threshold): + fragments_graph.merge_proposal(p) accepts.append(p) pbar.update(len(batch["proposals"])) - neurograph.absorb_reducibles() - return neurograph, accepts + fragments_graph.absorb_reducibles() + return fragments_graph, accepts - def get_batch(self, neurograph, proposals): + def get_batch(self, fragments_graph, proposals, flagged_proposals): """ Generates a batch of proposals. Parameters ---------- - neurograph : NeuroGraph + fragments_graph : FragmentsGraph Graph that proposals were generated from. - proposals : list + proposals : List[frozenset] Proposals for which batch is to be generated from. + flagged_proposals : List[frozenset] + List of proposals that are part of a "large" connected component + in the proposal induced subgraph of "fragments_graph". 
Returns ------- @@ -600,20 +605,22 @@ def get_batch(self, neurograph, proposals): """ if self.is_gnn: - return gnn_util.get_batch(neurograph, proposals, self.batch_size) + return gnn_util.get_batch( + fragments_graph, proposals, self.batch_size, flagged_proposals + ) else: batch = {"proposals": proposals[0:self.batch_size], "graph": None} del proposals[0:self.batch_size] return batch - def get_batch_dataset(self, neurograph, batch): + def get_batch_dataset(self, fragments_graph, batch): """ Generates features and initializes dataset that can be input to a machine learning model. Parameters ---------- - neurograph : NeuroGraph + fragments_graph : FragmentsGraph Graph that inference will be performed on. batch : list Proposals to be classified. @@ -623,10 +630,12 @@ def get_batch_dataset(self, neurograph, batch): ... """ - features = self.feature_generator.run(neurograph, batch, self.radius) + features = self.feature_generator.run( + fragments_graph, batch, self.radius + ) computation_graph = batch["graph"] if type(batch) is dict else None dataset = ml_util.init_dataset( - neurograph, + fragments_graph, features, self.is_gnn, computation_graph=computation_graph, @@ -694,14 +703,14 @@ def predict_with_gnn(model, data, device=None): return toCPU(preds[0:len(data["proposal"]["y"]), 0]) -def get_accepts(neurograph, preds, threshold, high_threshold=0.9): +def get_accepts(fragments_graph, preds, threshold, high_threshold=0.9): """ Determines which proposals to accept based on prediction scores and the specified threshold. Parameters ---------- - neurograph : NeuroGraph + fragments_graph : FragmentsGraph Graph that proposals belong to. preds : dict Dictionary that maps proposal ids to probability generated from @@ -713,20 +722,20 @@ def get_accepts(neurograph, preds, threshold, high_threshold=0.9): Returns ------- list - Proposals to be added as edges to "neurograph". + Proposals to be added as edges to "fragments_graph". 
""" # Partition proposals into best and the rest preds = {k: v for k, v in preds.items() if v > threshold} best_proposals, proposals = separate_best( - preds, neurograph.simple_proposals(), high_threshold + preds, fragments_graph.simple_proposals(), high_threshold ) # Determine which proposals to accept accepts = list() - accepts.extend(filter_proposals(neurograph, best_proposals)) - accepts.extend(filter_proposals(neurograph, proposals)) - neurograph.remove_edges_from(map(tuple, accepts)) + accepts.extend(filter_proposals(fragments_graph, best_proposals)) + accepts.extend(filter_proposals(fragments_graph, proposals)) + fragments_graph.remove_edges_from(map(tuple, accepts)) return accepts @@ -795,13 +804,13 @@ def filter_proposals(graph, proposals): return accepts -def sort_proposals(neurograph, proposals): +def sort_proposals(fragments_graph, proposals): """ Sorts proposals by length. Parameters ---------- - neurograph : NeuroGraph + fragments_graph : FragmentsGraph Graph that proposals were generated from. proposals : list[frozenset] List of proposals. @@ -812,5 +821,18 @@ def sort_proposals(neurograph, proposals): Sorted proposals. 
""" - idxs = np.argsort([neurograph.proposal_length(p) for p in proposals]) + idxs = np.argsort([fragments_graph.proposal_length(p) for p in proposals]) return [proposals[idx] for idx in idxs] + + +# --- Batch Formation --- +def get_large_proposal_components(fragments_graph, k): + flagged_proposals = set() + visited = set() + for p in fragments_graph.list_proposals(): + if p not in visited: + component = fragments_graph.proposal_connected_component(p) + if len(component) > k: + flagged_proposals = flagged_proposals.union(component) + visited = visited.union(component) + return flagged_proposals diff --git a/src/deep_neurographs/utils/gnn_util.py b/src/deep_neurographs/utils/gnn_util.py index 1a23f82..9946b6e 100644 --- a/src/deep_neurographs/utils/gnn_util.py +++ b/src/deep_neurographs/utils/gnn_util.py @@ -4,7 +4,8 @@ @author: Anna Grim @email: anna.grim@alleninstitute.org -Helper routines for training graph neural networks. +Helper routines for both training graph neural networks and running inference +with them. """ @@ -21,7 +22,32 @@ # --- Tensor Operations --- -def get_inputs(data, device=None, is_multimodal=False): +def get_inputs(data, device="cpu", is_multimodal=False): + """ + Extracts input data for a graph-based model and optionally moves it to a + GPU. + + Parameters + ---------- + data : torch_geometric.data.HeteroData + A data object with the following attributes: + - x_dict: Dictionary of node features for different node types. + - edge_index_dict: Dictionary of edge indices for edge types. + - edge_attr_dict: Dictionary of edge attributes for edge types. + device : str, optional + Target device for the data, 'cuda' for GPU and 'cpu' for CPU. The + default is "cpu". + is_multimodal : bool, optional + Flag for handling multimodal data. The default is False. + + Returns + -------- + tuple: + - x (dict): Node features dictionary. + - edge_index (dict): Edge indices dictionary. + - edge_attr (dict): Edge attributes dictionary. 
+ + """ x = data.x_dict edge_index = data.edge_index_dict edge_attr = data.edge_attr_dict @@ -59,7 +85,8 @@ def toCPU(tensor): Returns ------- - None + list + Tensor moved to CPU and converted into a list. """ return tensor.detach().cpu().tolist() @@ -85,7 +112,9 @@ def toTensor(my_list): # --- Batch Generation --- -def get_batch(graph, proposals, batch_size): +def get_batch( + fragments_graph, proposals, batch_size, flagged_proposals=set() +): """ Gets a batch for training or inference that consist of a computation graph and list of proposals. Note: queue contains tuples that consist of a node @@ -93,12 +122,15 @@ def get_batch(graph, proposals, batch_size): Parameters ---------- - graph : NeuroGraph - Graph that contains proposals + fragments_graph : FragmentsGraph + Graph that contains proposals to be classified. proposals : list Proposals to be classified as accept or reject. batch_size : int Maximum number of proposals in the computation graph. + flagged_proposals : Set[frozenset], optional + Set of proposals that are part of a large connected component in the + proposal induced subgraph of "fragments_graph". The default is an empty set. Returns ------- @@ -107,6 +139,14 @@ def get_batch(graph, proposals, batch_size): graph if the model type is a gnn. 
""" + # Helpers + def visit_proposal(p): + batch["graph"].add_edge(i, j) + batch["proposals"].add(p) + proposals.remove(p) + queue.append((j, 0)) + + # Main batch = reset_batch() visited = set() while len(proposals) > 0 and len(batch["proposals"]) < batch_size: @@ -115,22 +155,30 @@ def get_batch(graph, proposals, batch_size): while len(queue) > 0: # Visit node i, d = queue.pop() - for j in graph.neighbors(i): + visited.add(i) + for j in fragments_graph.neighbors(i): if (i, j) not in batch["graph"].edges: batch["graph"].add_edge(i, j) - for p in graph.nodes[i]["proposals"]: - if frozenset({i, p}) in proposals: - batch["graph"].add_edge(i, p) - batch["proposals"].add(frozenset({i, p})) - proposals.remove(frozenset({i, p})) - queue.append((p, 0)) - visited.add(i) + for j in fragments_graph.nodes[i]["proposals"]: + p = frozenset({i, j}) + if p in proposals and p in flagged_proposals: + for q in fragments_graph.proposal_connected_component(p): + visit_proposal(q) + q_0, q_1 = tuple(q) + if q_0 not in visited: + queue.append((q_0, 0)) + if q_1 not in visited: + queue.append((q_1, 0)) + elif p in proposals: + visit_proposal(p) # Update queue if len(batch["proposals"]) < batch_size: - for j in [j for j in graph.neighbors(i) if j not in visited]: - d_j = min(d + 1, -len(graph.nodes[j]["proposals"])) + nbhd_i = fragments_graph.neighbors(i) + for j in [j for j in nbhd_i if j not in visited]: + n_proposals = len(fragments_graph.nodes[j]["proposals"]) + d_j = min(d + 1, -n_proposals) if d_j <= GNN_DEPTH: queue.append((j, d + 1)) return batch @@ -273,6 +321,22 @@ def proposals_in_graph(graph, proposals): def get_node_proposal_cnt(proposals): + """ + Computes the number of proposals associated with each node. + + Parameters + ---------- + proposals : List[frozenset] + A list of pairs of nodes that represent a proposal in a fragments + graph. 
+ + Returns + ------- + defaultdict + Dictionary where keys are node identifiers and values are the count of + proposals each node appears in. + + """ node_proposal_cnt = defaultdict(int) for i, j in proposals: node_proposal_cnt[i] += 1 diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index 66e23af..5ea7057 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -36,7 +36,7 @@ MIN_SIZE = 30 NODE_SPACING = 1 SMOOTH_BOOL = True -PRUNE_DEPTH = 16 +PRUNE_DEPTH = 20 class GraphLoader: @@ -271,7 +271,8 @@ def prune_branches(self, graph): # Check whether to stop if length > self.prune_depth: if n_passes == 1: - graph.remove_nodes_from(branch[0:min(3, len(branch))]) + k = min(3, len(branch)) + graph.remove_nodes_from(branch[0:k]) break def get_component_irreducibles(self, graph, swc_dict): @@ -623,6 +624,21 @@ def upd_node_attrs(swc_dict, leafs, branchings, i): def compute_path_length(graph): + """ + Computes the path length of the given graph. + + Parameters + ---------- + graph : networkx.Graph + Graph whose nodes have an attribute called "xyz" which represents + a 3d coordinate. + + Returns + ------- + float + Path length of graph. + + """ path_length = 0 for i, j in nx.dfs_edges(graph): path_length += compute_dist(graph, i, j) diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index f70815d..1d8ed8c 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -388,15 +388,6 @@ def get_minimal_bbox(voxels, buffer=0): return bbox -def get_fixed_bbox(voxels, shape): - centroid = np.round(np.mean(voxels, axis=0)).astype(int) - bbox = { - "min": [centroid[i] - shape[i] // 2 for i in range(3)], - "max": [centroid[i] + shape[i] // 2 for i in range(3)], - } - return bbox - - def find_img_path(bucket_name, img_root, dataset_name): """ Find the path of a specific dataset in a GCS bucket. 
diff --git a/src/deep_neurographs/utils/swc_util.py b/src/deep_neurographs/utils/swc_util.py index ce1eef9..8003d21 100644 --- a/src/deep_neurographs/utils/swc_util.py +++ b/src/deep_neurographs/utils/swc_util.py @@ -176,7 +176,7 @@ def load_from_local_zip(self, zip_path): if result: swc_dicts.append(result) return swc_dicts - + def load_from_gcs(self, gcs_dict): """ Reads swc files from zips on a GCS bucket.