From 71d5d7c75baaca69814b5c687913992b35661fcc Mon Sep 17 00:00:00 2001
From: Anna Grim <108307071+anna-grim@users.noreply.github.com>
Date: Thu, 26 Sep 2024 21:46:25 -0700
Subject: [PATCH] Refactor gnn training (#250)

* minor upds

* refactor: training pipeline

* feat: find gcs image path

* feat: feature generation in trainer

* feat: validation sets in training

* bug: hgraph forward passes with missing edge types

* refactor: hgnn trainer

* feat: functional training pipeline

* bug: set validation data

---------

Co-authored-by: anna-grim
---
 .../graph_artifact_removal.py                 |   6 +-
 .../machine_learning/gnn_trainer.py           |  11 +-
 .../groundtruth_generation.py                 |  10 +-
 src/deep_neurographs/neurograph.py            |  17 +-
 src/deep_neurographs/train_pipeline.py        |   5 +-
 src/deep_neurographs/utils/graph_util.py      |   7 +-
 src/deep_neurographs/utils/util.py            | 209 ++++--------------
 7 files changed, 72 insertions(+), 193 deletions(-)

diff --git a/src/deep_neurographs/graph_artifact_removal.py b/src/deep_neurographs/graph_artifact_removal.py
index 421adba..627b65d 100644
--- a/src/deep_neurographs/graph_artifact_removal.py
+++ b/src/deep_neurographs/graph_artifact_removal.py
@@ -8,6 +8,8 @@ other from a NeuroGraph.
 
 """
 
+from collections import defaultdict
+
 import numpy as np
 from networkx import connected_components
 from tqdm import tqdm
@@ -93,7 +95,7 @@ def compute_projections(neurograph, kdtree, edge):
         projection distances.
 
     """
-    hits = dict()
+    hits = defaultdict(list)
     query_id = neurograph.edges[edge]["swc_id"]
     for i, xyz in enumerate(neurograph.edges[edge]["xyz"]):
         # Compute projections
@@ -108,7 +110,7 @@ def compute_projections(neurograph, kdtree, edge):
 
         # Store best
         if best_id:
-            hits = util.append_dict_value(hits, best_id, best_dist)
+            hits[best_id].append(best_dist)
         elif i == 15 and len(hits) == 0:
             return hits
     return hits
diff --git a/src/deep_neurographs/machine_learning/gnn_trainer.py b/src/deep_neurographs/machine_learning/gnn_trainer.py
index bd32650..2c17c2a 100644
--- a/src/deep_neurographs/machine_learning/gnn_trainer.py
+++ b/src/deep_neurographs/machine_learning/gnn_trainer.py
@@ -118,7 +118,7 @@ def run(self, train_dataset_list, validation_dataset_list):
                 loss.backward()
                 self.optimizer.step()
 
-                # Store predictions
+                # Store prediction
                 y.extend(toCPU(y_i))
                 hat_y.extend(toCPU(hat_y_i))
 
@@ -159,13 +159,8 @@ def predict(self, data):
             Prediction.
 
         """
-        # Run model
-        x_dict, edge_index_dict, edge_attr_dict = gnn_util.get_inputs(
-            data, "HeteroGNN"
-        )
-        hat_y = self.model(x_dict, edge_index_dict, edge_attr_dict)
-
-        # Output
+        x, edge_index, edge_attr = gnn_util.get_inputs(data, "HeteroGNN")
+        hat_y = self.model(x, edge_index, edge_attr)
         y = data["proposal"]["y"]
         return truncate(hat_y, y), y
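
Note: the hunks above (and several in the files that follow) replace the util.append_dict_value helper with collections.defaultdict(list), so per-key lists no longer need an explicit existence check. A minimal standalone sketch of that pattern, using made-up ids and distances rather than anything from this patch:

from collections import defaultdict

def group_by_id(pairs):
    # defaultdict(list) creates the empty list on first access, so values
    # can be appended without first checking whether the key exists.
    hits = defaultdict(list)
    for swc_id, dist in pairs:
        hits[swc_id].append(dist)
    return hits

# Prints {'123': [1.0, 0.25], '456': [0.5]}
print(dict(group_by_id([("123", 1.0), ("456", 0.5), ("123", 0.25)])))
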
""" - # Run model - x_dict, edge_index_dict, edge_attr_dict = gnn_util.get_inputs( - data, "HeteroGNN" - ) - hat_y = self.model(x_dict, edge_index_dict, edge_attr_dict) - - # Output + x, edge_index, edge_attr = gnn_util.get_inputs(data, "HeteroGNN") + hat_y = self.model(x, edge_index, edge_attr) y = data["proposal"]["y"] return truncate(hat_y, y), y diff --git a/src/deep_neurographs/machine_learning/groundtruth_generation.py b/src/deep_neurographs/machine_learning/groundtruth_generation.py index c272143..0095630 100644 --- a/src/deep_neurographs/machine_learning/groundtruth_generation.py +++ b/src/deep_neurographs/machine_learning/groundtruth_generation.py @@ -9,6 +9,8 @@ """ +from collections import defaultdict + import networkx as nx import numpy as np @@ -145,13 +147,13 @@ def is_component_aligned(target_graph, pred_graph, component, kdtree): """ # Compute distances - dists = dict() + dists = defaultdict(list) for edge in pred_graph.subgraph(component).edges: for xyz in pred_graph.edges[edge]["xyz"]: hat_xyz = geometry.kdtree_query(kdtree, xyz) hat_swc_id = target_graph.xyz_to_swc(hat_xyz) d = get_dist(hat_xyz, xyz) - dists = util.append_dict_value(dists, hat_swc_id, d) + dists[hat_swc_id].append(d) # Deterine whether aligned hat_swc_id = util.find_best(dists) @@ -212,14 +214,14 @@ def is_valid(target_graph, pred_graph, kdtree, target_id, edge): def proj_branch(target_graph, pred_graph, kdtree, target_id, i): # Compute projections - hits = dict() + hits = defaultdict(list) for branch in pred_graph.get_branches(i): for xyz in branch: hat_xyz = geometry.kdtree_query(kdtree, xyz) swc_id = target_graph.xyz_to_swc(hat_xyz) if swc_id == target_id: hat_edge = target_graph.xyz_to_edge[hat_xyz] - hits = util.append_dict_value(hits, hat_edge, hat_xyz) + hits[hat_edge].append(hat_xyz) # Determine closest edge min_dist = np.inf diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index d9cad37..e131fe6 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -892,12 +892,6 @@ def leaf_neighbor(self, i): assert self.is_leaf(i) return list(self.neighbors(i))[0] - """ - def get_edge_attr(self, edge, key): - xyz_arr = gutil.get_edge_attr(self, edge, key) - return xyz_arr[0], xyz_arr[-1] - """ - def to_patch_coords(self, edge, midpoint, chunk_size): patch_coords = list() for xyz in self.edges[edge]["xyz"]: @@ -917,6 +911,7 @@ def xyz_to_swc(self, xyz, return_node=False): else: return None + """ def component_cardinality(self, root): cardinality = 0 queue = [(-1, root)] @@ -933,6 +928,7 @@ def component_cardinality(self, root): if frozenset((j, k)) not in visited: queue.append((j, k)) return cardinality + """ # --- write graph to swcs --- def to_zipped_swcs(self, zip_path, color=None): @@ -956,8 +952,7 @@ def to_zipped_swc(self, zip_writer, nodes, color): swc_id = self.nodes[i]["swc_id"] x, y, z = tuple(self.nodes[i]["xyz"]) r = self.nodes[i]["radius"] - if color != "1.0 0.0 0.0": - r += 1.5 + text_buffer.write("\n" + f"1 2 {x} {y} {z} {r} -1") node_to_idx[i] = 1 n_entries += 1 @@ -1056,11 +1051,11 @@ def branch_to_zip(self, text_buffer, n_entries, i, j, parent, color): branch_radius = np.flip(branch_radius, axis=0) # Make entries - for k in range(1, len(branch_xyz)): + idxs = np.arange(1, len(branch_xyz)) + for k in util.spaced_idxs(idxs, 4): x, y, z = tuple(branch_xyz[k]) r = branch_radius[k] - if color != "1.0 0.0 0.0": - r += 1 + node_id = n_entries + 1 parent = n_entries if k > 1 else parent text_buffer.write("\n" + f"{node_id} 2 {x} 
{y} {z} {r} {parent}") diff --git a/src/deep_neurographs/train_pipeline.py b/src/deep_neurographs/train_pipeline.py index f400840..c25bb77 100644 --- a/src/deep_neurographs/train_pipeline.py +++ b/src/deep_neurographs/train_pipeline.py @@ -135,6 +135,8 @@ def run(self): # Initialize training data self.generate_proposals() self.generate_features() + self.set_validation_idxs() + assert len(self.validation_dataset_list) > 0, "No validation data!" # Train model trainer = Trainer( @@ -174,7 +176,6 @@ def generate_proposals(self): print(f"{sample_id} {example_id} {n_proposals} {p_accepts}") def generate_features(self): - self.set_validation_idxs() for i in range(self.n_examples()): # Get proposals proposals_dict = { @@ -199,7 +200,7 @@ def generate_features(self): self.model_type, computation_graph=proposals_dict["graph"] ) - if i in self.validation_ids: + if i in self.validation_idxs: self.validation_dataset_list.append(dataset) else: self.train_dataset_list.append(dataset) diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index b01d9ce..fa2b8ce 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -22,6 +22,7 @@ """ +from collections import defaultdict from concurrent.futures import ProcessPoolExecutor, as_completed from random import sample @@ -372,7 +373,7 @@ def get_subcomponent_irreducibles(graph, swc_dict, smooth_bool): # Extract edges edges = dict() - nbs = dict() + nbs = defaultdict(list) root = None for (i, j) in nx.dfs_edges(graph, source=source): # Check if start of path is valid @@ -390,8 +391,8 @@ def get_subcomponent_irreducibles(graph, swc_dict, smooth_bool): ) else: edges[(root, j)] = attrs - nbs = util.append_dict_value(nbs, root, j) - nbs = util.append_dict_value(nbs, j, root) + nbs[root].append(j) # = util.append_dict_value(nbs, root, j) + nbs[j].append(root) # = util.append_dict_value(nbs, j, root) root = None # Output diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py index 9c4809a..e8c9a68 100644 --- a/src/deep_neurographs/utils/util.py +++ b/src/deep_neurographs/utils/util.py @@ -14,7 +14,6 @@ import shutil from io import BytesIO from random import sample -from time import time from zipfile import ZipFile import boto3 @@ -23,158 +22,6 @@ from google.cloud import storage -# --- dictionary utils --- -def remove_item(my_set, item): - """ - Removes item from a set. - - Parameters - ---------- - my_set : set - Set to be queried. - item : - Value to query. - - Returns - ------- - set - Set "my_set" with "item" removed if it existed. - - """ - if item in my_set: - my_set.remove(item) - return my_set - - -def check_key(my_dict, key): - """ - Checks whether "key" is contained in "my_dict". If so, returns the - corresponding value. - - Parameters - ---------- - my_dict : dict - Dictionary to be checked - key : hashable data type - - Returns - ------- - dict value or bool - If "key" is a key in "my_dict", then the associated value is returned. - Otherwise, the bool "False" is returned. - - """ - if key in my_dict.keys(): - return my_dict[key] - else: - return False - - -def remove_key(my_dict, key): - """ - Removes "key" from "my_dict" in the case when key may need to be reversed. - - Parameters - ---------- - my_dict : dict - Dictionary to be queried - key : hashable data type - Key to query. - - Returns - ------- - dict - Updated dictionary. 
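
Note: the train_pipeline.py hunks above move set_validation_idxs out of generate_features, and generate_features now routes each dataset into either the training or validation list via self.validation_idxs. The implementation of set_validation_idxs is not part of this patch; the sketch below is only a guess at the kind of random hold-out split it might perform, with hypothetical names and fraction:

import random

def sample_validation_idxs(n_examples, validation_frac=0.2):
    # Hypothetical stand-in for set_validation_idxs: hold out a random
    # subset of example indices for validation.
    n_val = max(1, int(validation_frac * n_examples))
    return set(random.sample(range(n_examples), n_val))

# Route datasets the way generate_features does after this patch.
validation_idxs = sample_validation_idxs(10)
train_list, validation_list = [], []
for i in range(10):
    dataset = f"dataset_{i}"  # placeholder for the real feature dataset
    if i in validation_idxs:
        validation_list.append(dataset)
    else:
        train_list.append(dataset)
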
diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py
index 9c4809a..e8c9a68 100644
--- a/src/deep_neurographs/utils/util.py
+++ b/src/deep_neurographs/utils/util.py
@@ -14,7 +14,6 @@
 import shutil
 from io import BytesIO
 from random import sample
-from time import time
 from zipfile import ZipFile
 
 import boto3
@@ -23,158 +22,6 @@
 from google.cloud import storage
 
 
-# --- dictionary utils ---
-def remove_item(my_set, item):
-    """
-    Removes item from a set.
-
-    Parameters
-    ----------
-    my_set : set
-        Set to be queried.
-    item :
-        Value to query.
-
-    Returns
-    -------
-    set
-        Set "my_set" with "item" removed if it existed.
-
-    """
-    if item in my_set:
-        my_set.remove(item)
-    return my_set
-
-
-def check_key(my_dict, key):
-    """
-    Checks whether "key" is contained in "my_dict". If so, returns the
-    corresponding value.
-
-    Parameters
-    ----------
-    my_dict : dict
-        Dictionary to be checked
-    key : hashable data type
-
-    Returns
-    -------
-    dict value or bool
-        If "key" is a key in "my_dict", then the associated value is returned.
-        Otherwise, the bool "False" is returned.
-
-    """
-    if key in my_dict.keys():
-        return my_dict[key]
-    else:
-        return False
-
-
-def remove_key(my_dict, key):
-    """
-    Removes "key" from "my_dict" in the case when key may need to be reversed.
-
-    Parameters
-    ----------
-    my_dict : dict
-        Dictionary to be queried
-    key : hashable data type
-        Key to query.
-
-    Returns
-    -------
-    dict
-        Updated dictionary.
-
-    """
-    if check_key(my_dict, key):
-        my_dict.pop(key)
-    elif check_key(my_dict, (key[1], key[0])):
-        my_dict.pop((key[1], key[0]))
-    return my_dict
-
-
-def remove_items(my_dict, keys):
-    """
-    Removes dictionary items corresponding to "keys".
-
-    Parameters
-    ----------
-    my_dict : dict
-        Dictionary to be edited.
-    keys : list
-        List of keys to be deleted from "my_dict".
-
-    Returns
-    -------
-    dict
-        Updated dictionary.
-
-    """
-    for key in keys:
-        if key in my_dict.keys():
-            del my_dict[key]
-    return my_dict
-
-
-def append_dict_value(my_dict, key, value):
-    """
-    Appends "value" to the list stored at "key".
-
-    Parameters
-    ----------
-    my_dict : dict
-        Dictionary to be queried.
-    key : hashable data type
-        Key to be query.
-    value : list item type
-        Value to append to list stored at "key".
-
-    Returns
-    -------
-    dict
-        Updated dictionary.
-
-    """
-    if key in my_dict.keys():
-        my_dict[key].append(value)
-    else:
-        my_dict[key] = [value]
-    return my_dict
-
-
-def find_best(my_dict, maximize=True):
-    """
-    Given a dictionary where each value is either a list or int (i.e. cnt),
-    finds the key associated with the longest list or largest integer.
-
-    Parameters
-    ----------
-    my_dict : dict
-        Dictionary to be searched.
-    maximize : bool, optional
-        Indication of whether to find the largest value or highest vote cnt.
-
-    Returns
-    -------
-    hashable data type
-        Key associated with the longest list or largest integer in "my_dict".
-
-    """
-    best_key = None
-    best_vote_cnt = 0 if maximize else np.inf
-    for key in my_dict.keys():
-        val_type = type(my_dict[key])
-        vote_cnt = my_dict[key] if val_type == float else len(my_dict[key])
-        if vote_cnt > best_vote_cnt and maximize:
-            best_key = key
-            best_vote_cnt = vote_cnt
-        elif vote_cnt < best_vote_cnt and not maximize:
-            best_key = key
-            best_vote_cnt = vote_cnt
-    return best_key
-
-
 # --- os utils ---
 def mkdir(path, delete=False):
     """
@@ -633,26 +480,63 @@ def sample_once(my_container):
     return sample(my_container, 1)[0]
 
 
-# --- runtime ---
-def init_timers():
+# --- dictionary utils ---
+def remove_items(my_dict, keys):
     """
-    Initializes two timers.
+    Removes dictionary items corresponding to "keys".
 
     Parameters
     ----------
-    None
+    my_dict : dict
+        Dictionary to be edited.
+    keys : list
+        List of keys to be deleted from "my_dict".
 
     Returns
    -------
-    time.time
-        Timer.
-    time.time
-        Timer.
-
+    dict
+        Updated dictionary.
+
+    """
+    for key in keys:
+        if key in my_dict:
+            del my_dict[key]
+    return my_dict
+
+
+def find_best(my_dict, maximize=True):
     """
-    return time(), time()
+    Given a dictionary where each value is either a list or int (i.e. cnt),
+    finds the key associated with the longest list or largest integer.
+
+    Parameters
+    ----------
+    my_dict : dict
+        Dictionary to be searched.
+    maximize : bool, optional
+        Indication of whether to find the largest value or highest vote cnt.
+
+    Returns
+    -------
+    hashable data type
+        Key associated with the longest list or largest integer in "my_dict".
+
+    """
+    best_key = None
+    best_vote_cnt = 0 if maximize else np.inf
+    for key in my_dict.keys():
+        val_type = type(my_dict[key])
+        vote_cnt = my_dict[key] if val_type == float else len(my_dict[key])
+        if vote_cnt > best_vote_cnt and maximize:
+            best_key = key
+            best_vote_cnt = vote_cnt
+        elif vote_cnt < best_vote_cnt and not maximize:
+            best_key = key
+            best_vote_cnt = vote_cnt
+    return best_key
+
+
+# --- miscellaneous ---
 def time_writer(t, unit="seconds"):
     """
     Converts a runtime "t" to a larger unit of time if applicable.
@@ -683,7 +567,6 @@ def time_writer(t, unit="seconds"):
     return t, unit
 
 
-# --- miscellaneous ---
 def get_swc_id(path):
     """
     Gets segment id of the swc file at "path".
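
Note: after this patch, distances are still grouped per id (now in a defaultdict) and util.find_best, which is kept in utils/util.py, returns the key with the longest list, as is_component_aligned does with its dists. A small usage sketch, assuming the deep_neurographs package is installed and using toy values:

from collections import defaultdict

from deep_neurographs.utils import util

# Group projection distances by the swc id they landed on, as
# compute_projections does after this patch.
hits = defaultdict(list)
for swc_id, dist in [("123", 1.2), ("456", 0.4), ("123", 0.9), ("123", 0.7)]:
    hits[swc_id].append(dist)

# find_best returns the key with the most recorded hits, here "123".
print(util.find_best(hits))
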