From 07de06f7833355a73458cef7d49a043e0e77a277 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sat, 12 Oct 2024 02:02:28 +0000 Subject: [PATCH] working version --- .../machine_learning/datasets.py | 11 +- .../machine_learning/feature_generation.py | 926 +++++++++++------- .../feature_generation_graphs.py | 471 --------- .../machine_learning/heterograph_datasets.py | 68 +- .../machine_learning/heterograph_models.py | 257 ++--- src/deep_neurographs/train.py | 31 +- 6 files changed, 673 insertions(+), 1091 deletions(-) delete mode 100644 src/deep_neurographs/machine_learning/feature_generation_graphs.py diff --git a/src/deep_neurographs/machine_learning/datasets.py b/src/deep_neurographs/machine_learning/datasets.py index 5af5349..d1a7e78 100644 --- a/src/deep_neurographs/machine_learning/datasets.py +++ b/src/deep_neurographs/machine_learning/datasets.py @@ -293,7 +293,7 @@ def reformat(arr): return np.expand_dims(arr, axis=1).astype(np.float32) -def init_idxs(idxs): +def init_idx_mapping(idx_to_id): """ Adds dictionary item called "edge_to_index" which maps a branch/proposal in a neurograph to an idx that represents it's position in the feature @@ -310,7 +310,8 @@ def init_idxs(idxs): Updated dictionary. """ - idxs["edge_to_idx"] = dict() - for idx, edge in idxs["idx_to_edge"].items(): - idxs["edge_to_idx"][edge] = idx - return idxs + idx_mapping = { + "idx_to_id": idx_to_id, + "id_to_idx": {v: k for k, v in idx_to_id.items()} + } + return idx_mapping \ No newline at end of file diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index 6e03c75..ca06efb 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -6,353 +6,602 @@ Generates features for training a model and performing inference. -Conventions: (1) "xyz" refers to a real world coordinate such as those from - an swc file. +Conventions: + (1) "xyz" refers to a real world coordinate such as those from an swc file - (2) "voxel" refers to an voxel coordinate in a whole exaspim - image. + (2) "voxel" refers to an voxel coordinate in a whole exaspim image. """ -from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from copy import deepcopy from random import sample import numpy as np -import tensorstore as ts +from scipy.ndimage import zoom from deep_neurographs import geometry -from deep_neurographs.machine_learning.feature_generation_graphs import ( - generate_gnn_features, -) from deep_neurographs.utils import img_util, util -CHUNK_SIZE = [48, 48, 48] -N_BRANCH_PTS = 50 -N_PROFILE_PTS = 16 # 10 -N_SKEL_FEATURES = 22 - - -def run( - neurograph, - img, - model_type, - proposals_dict, - radius, - downsample_factor=1, - labels=None, -): - """ - Generates feature vectors that are used by a machine learning model to - classify proposals. - - Parameters - ---------- - neurograph : NeuroGraph - Graph that "proposals" belong to. - img : tensorstore.Tensorstore - Image stored in a GCS bucket. - model_type : str - Type of machine learning model used to classify proposals. - proposals_dict : dict - Dictionary that contains the items (1) "proposals" which are the - proposals from "neurograph" that features will be generated and - (2) "graph" which is the computation graph used by the gnn. - radius : float - Search radius used to generate proposals. - downsample_factor : int, optional - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. The default is 0. - labels : tensorstore.TensorStore, optional - Segmentation mask stored in a GCS bucket. The default is None. - - Returns - ------- - dict - Feature vectors. +class FeatureGenerator: """ - # Init leaf kd-tree (if applicable) - if neurograph.leaf_kdtree is None: - neurograph.init_kdtree(node_type="leaf") - - # Feature generation by type of machine learning model - if model_type == "GraphNeuralNet": - return generate_gnn_features( - neurograph, img, proposals_dict, radius, downsample_factor - ) - else: - return generate_features( - neurograph, img, proposals_dict, radius, downsample_factor - ) - + Class that generates features vectors that are used by a graph neural + network to classify proposals. -def generate_features( - neurograph, img, proposals_dict, radius, downsample_factor -): """ - Generates feature vectors that are used by a general machine learning model - (e.g. random forest or feed forward neural network). - - Parameters - ---------- - neurograph : NeuroGraph - Graph that "proposals" belong to. - img : tensorstore.Tensorstore - Image stored in a GCS bucket. - proposals_dict : dict - Dictionary containing the computation graph used by gnn and proposals - to be classified. - radius : float - Search radius used to generate proposals. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - dict - Feature vectors. - - """ - features = defaultdict(bool) - features["proposals"] = { - "skel": proposal_skeletal( - neurograph, proposals_dict["proposals"], radius - ), - "profiles": proposal_profiles( - neurograph, img, proposals_dict["proposals"], downsample_factor - ), - } - return features - - -def proposal_profiles(neurograph, img, proposals, downsample_factor): - """ - Generates an image intensity profile along each proposal by reading from - an image on the cloud. - - Parameters - ---------- - neurograph : NeuroGraph - Graph that "proposals" belong to. - img : tensorstore.TensorStore - Image stored in a GCS bucket. - proposals : list[frozenset] - List of proposals for which features will be generated. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - dict - Dictonary such that each pair is the proposal id and image intensity - profile. - - """ - with ThreadPoolExecutor() as executor: - threads = [] + # Class attributes + patch_shape = [96, 96, 96] + n_profile_points = 16 + + def __init__( + self, + img_path, + downsample_factor, + label_path=None, + use_img_embedding=False, + ): + """ + Initializes object that generates features for a graph. + + Parameters + ---------- + img_path : str + Path to the raw image assumed to be stored in a GCS bucket. + downsample_factor : int + Downsampling factor that accounts for which level in the image + pyramid the voxel coordinates must index into. + label_path : str, optional + Path to the segmentation assumed to be stored on a GCS bucket. The + default is None. + use_img_embedding : bool, optional + ... + + Returns + ------- + None + + """ + # Initialize instance attributes + self.downsample_factor = downsample_factor + self.use_img_embedding = use_img_embedding + + # Initialize image-based attributes + driver = "n5" if ".n5" in img_path else "zarr" + self.img = img_util.open_tensorstore(img_path, driver=driver) + if label_path: + self.labels = img_util.open_tensorstore(label_path) + else: + self.labels = None + + # Set chunk shapes + self.img_patch_shape = self.set_patch_shape(downsample_factor) + self.label_patch_shape = self.set_patch_shape(0) + + # Validate embedding requirements + if self.use_img_embedding and not label_path: + raise("Must provide labels to generate image embeddings") + + @classmethod + def set_patch_shape(cls, downsample_factor): + """ + Adjusts the chunk shape by downsampling each dimension by a specified + factor. + + Parameters + ---------- + downsample_factor : int + The factor by which to downsample each dimension of the current + chunk shape. + + Returns + ------- + list + Adjusted chunk shape with each dimension reduced by the downsample + factor. + + """ + return [s // 2 ** downsample_factor for s in cls.patch_shape] + + @classmethod + def get_n_profile_points(cls): + return cls.n_profile_points + + def run(self, neurograph, proposals_dict, radius): + """ + Generates feature vectors for nodes, edges, and + proposals in a graph. + + Parameters + ---------- + neurograph : NeuroGraph + Graph that "proposals" belong to. + proposals_dict : dict + Dictionary that contains the items (1) "proposals" which are the + proposals from "neurograph" that features will be generated and + (2) "graph" which is the computation graph used by the gnn. + radius : float + Search radius used to generate proposals. + + Returns + ------- + dict + Dictionary that contains different types of feature vectors for + nodes, edges, and proposals. + + """ + # Initializations + computation_graph = proposals_dict["graph"] + proposals = proposals_dict["proposals"] + if neurograph.leaf_kdtree is None: + neurograph.init_kdtree(node_type="leaf") + + # Main + features = { + "nodes": self.run_on_nodes(neurograph, computation_graph), + "branches": self.run_on_branches(neurograph, computation_graph), + "proposals": self.run_on_proposals(neurograph, proposals, radius) + } + + # Generate image patches (if applicable) + if self.use_img_embedding: + features["patches"] = self.proposal_patches(neurograph, proposals) + return features + + def run_on_nodes(self, neurograph, computation_graph): + """ + Generates feature vectors for every node in "computation_graph". + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + computation_graph : networkx.Graph + Graph used by gnn to classify proposals. + + Returns + ------- + dict + Dictionary that maps a node id to a feature vector. + + """ + return self.node_skeletal(neurograph, computation_graph) + + def run_on_branches(self, neurograph, computation_graph): + """ + Generates feature vectors for every edge in "computation_graph". + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + computation_graph : networkx.Graph + Graph used by gnn to classify proposals. + + Returns + ------- + dict + Dictionary that maps an edge id to a feature vector. + + """ + return self.branch_skeletal(neurograph, computation_graph) + + def run_on_proposals(self, neurograph, proposals, radius): + """ + Generates feature vectors for every proposal in "neurograph". + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + proposals : list[frozenset] + List of proposals for which features will be generated. + radius : float + Search radius used to generate proposals. + + Returns + ------- + dict + Dictionary that maps a proposal id to a feature vector. + + """ + features = self.proposal_skeletal(neurograph, proposals, radius) + if not self.use_img_embedding: + profiles = self.proposal_profiles(neurograph, proposals) + for p in proposals: + features[p] = np.concatenate((features[p], profiles[p])) + return features + + # -- Skeletal Features -- + def node_skeletal(self, neurograph, computation_graph): + """ + Generates skeleton-based features for nodes in "computation_graph". + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + computation_graph : networkx.Graph + Graph used by gnn to classify proposals. + + Returns + ------- + dict + Dictionary that maps a node id to a feature vector. + + """ + node_skeletal_features = dict() + for i in computation_graph.nodes: + node_skeletal_features[i] = np.concatenate( + ( + neurograph.degree[i], + neurograph.nodes[i]["radius"], + len(neurograph.nodes[i]["proposals"]), + ), + axis=None, + ) + return node_skeletal_features + + def branch_skeletal(self, neurograph, computation_graph): + """ + Generates skeleton-based features for edges in "computation_graph". + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + computation_graph : networkx.Graph + Graph used by gnn to classify proposals. + + Returns + ------- + dict + Dictionary that maps an edge id to a feature vector. + + """ + branch_skeletal_features = dict() + for edge in neurograph.edges: + branch_skeletal_features[frozenset(edge)] = np.array( + [ + np.mean(neurograph.edges[edge]["radius"]), + min(neurograph.edges[edge]["length"], 500) / 500, + ], + ) + return branch_skeletal_features + + def proposal_skeletal(self, neurograph, proposals, radius): + """ + Generates skeleton-based features for "proposals". + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + proposals : list[frozenset] + List of proposals for which features will be generated. + radius : float + Search radius used to generate proposals. + + Returns + ------- + dict + Dictionary that maps a node id to a feature vector. + + """ + proposal_skeletal_features = dict() for proposal in proposals: - xyz_1, xyz_2 = neurograph.proposal_xyz(proposal) - specs = get_profile_specs(xyz_1, xyz_2, downsample_factor) - threads.append(executor.submit(get_profile, img, specs, proposal)) - - profiles = dict() - for thread in as_completed(threads): - profiles.update(thread.result()) - return profiles + proposal_skeletal_features[proposal] = np.concatenate( + ( + neurograph.proposal_length(proposal) / radius, + neurograph.n_nearby_leafs(proposal, radius), + neurograph.proposal_radii(proposal), + neurograph.proposal_directionals(proposal, 16), + neurograph.proposal_directionals(proposal, 32), + neurograph.proposal_directionals(proposal, 64), + neurograph.proposal_directionals(proposal, 128), + ), + axis=None, + ) + return proposal_skeletal_features + + # --- Image features --- + def node_profiles(self, neurograph, computation_graph): + """ + Generates image profiles for nodes in "computation_graph". + + Parameters + ---------- + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + computation_graph : networkx.Graph + Graph used by gnn to classify proposals. + + Returns + ------- + dict + Dictionary that maps a node id to an image profile. + + """ + with ThreadPoolExecutor() as executor: + # Assign threads + threads = computation_graph.number_of_nodes() * [None] + for idx, i in enumerate(computation_graph.nodes): + # Get profile path + if neurograph.is_leaf(i): + xyz_path = self.get_leaf_path(neurograph, i) + else: + xyz_path = self.get_branching_path(neurograph, i) + + # Assign + threads[idx] = executor.submit( + img_util.get_profile, self.img, self.get_spec(xyz_path), i + ) + # Store results + node_profile_features = dict() + for thread in as_completed(threads): + node_profile_features.update(thread.result()) + return node_profile_features + + def proposal_profiles(self, neurograph, proposals): + """ + Generates an image intensity profile along the proposal. + + Parameters + ---------- + neurograph : NeuroGraph + Graph that "proposals" belong to. + proposals : list[frozenset] + List of proposals for which features will be generated. + + Returns + ------- + dict + Dictonary such that each pair is the proposal id and image + intensity profile. + + """ + with ThreadPoolExecutor() as executor: + # Assign threads + threads = list() + for p in proposals: + n_points = self.get_n_profile_points() + xyz_1, xyz_2 = neurograph.proposal_xyz(p) + xyz_path = geometry.make_line(xyz_1, xyz_2, n_points) + threads.append(executor.submit(self.get_profile, xyz_path, p)) + + # Store results + profiles = dict() + for thread in as_completed(threads): + profiles.update(thread.result()) + return profiles + + def proposal_patches(self, neurograph, proposals): + """ + Generates an image intensity profile along the proposal. + + Parameters + ---------- + neurograph : NeuroGraph + Graph that "proposals" belong to. + proposals : list[frozenset] + List of proposals for which features will be generated. + + Returns + ------- + dict + Dictonary such that each pair is the proposal id and image + intensity profile. + + """ + with ThreadPoolExecutor() as executor: + # Assign threads + threads = list() + for p in proposals: + labels = neurograph.proposal_labels(p) + xyz_path = np.vstack(neurograph.proposal_xyz(p)) + threads.append( + executor.submit(self.get_patch, labels, xyz_path, p) + ) -def get_profile_specs(xyz_1, xyz_2, downsample_factor): + # Store results + chunks = dict() + for thread in as_completed(threads): + chunks.update(thread.result()) + return chunks + + def get_profile(self, xyz_path, profile_id): + """ + Gets the image intensity profile given xyz coordinates that form a + path. + + Parameters + ---------- + xyz_path : numpy.ndarray + xyz coordinates of a profile path. + profile_id : hashable + Identifier of profile. + + Returns + ------- + dict + Dictionary that maps an id (e.g. node, edge, or proposal) to its + profile. + + """ + profile = img_util.read_profile(self.img, self.get_spec(xyz_path)) + profile.extend(list(util.get_avg_std(profile))) + return {profile_id: profile} + + def get_spec(self, xyz_path): + """ + Gets image bounding box and voxel coordinates needed to compute an + image intensity profile or extract image chunk for cnn embedding. + + Parameters + ---------- + xyz_path : numpy.ndarray + xyz coordinates of a profile path. + + Returns + ------- + dict + Specifications needed to compute a profile. + + """ + voxels = self.transform_path(xyz_path) + bbox = self.get_bbox(voxels) + profile_path = geometry.shift_path(voxels, bbox["min"]) + return {"bbox": bbox, "profile_path": profile_path} + + def transform_path(self, xyz_path): + """ + Converts "xyz_path" from world to voxel coordinates. + + Parameters + ---------- + xyz_path : numpy.ndarray + xyz coordinates of a profile path. + + Returns + ------- + numpy.ndarray + Voxel coordinates of given path. + + """ + voxels = np.zeros((len(xyz_path), 3), dtype=int) + for i, xyz in enumerate(xyz_path): + voxels[i] = img_util.to_voxels(xyz, self.downsample_factor) + return voxels + + def get_bbox(self, voxels, is_img=True): + center = np.round(np.mean(voxels, axis=0)).astype(int) + shape = self.img_patch_shape if is_img else self.label_patch_shape + bbox = { + "min": [c - s // 2 for c, s in zip(center, shape)], + "max": [c + s // 2 for c, s in zip(center, shape)], + } + return bbox + + def get_patch(self, labels, xyz_path, proposal): + # Initializations + center = np.mean(xyz_path, axis=0) + voxels = [img_util.to_voxels(xyz) for xyz in xyz_path] + + # Read patches + img_patch = self.read_img_patch(center) + label_patch = self.read_label_patch(voxels, labels) + return {proposal: np.stack([img_patch, label_patch], axis=0)} + + def read_img_patch(self, xyz_centroid): + center = img_util.to_voxels(xyz_centroid, self.downsample_factor) + img_patch = img_util.read_tensorstore( + self.img, center, self.img_patch_shape + ) + return img_util.normalize(img_patch) + + def read_label_patch(self, voxels, labels): + bbox = self.get_bbox(voxels, is_img=False) + label_patch = img_util.read_tensorstore_with_bbox(self.labels, bbox) + voxels = geometry.shift_path(voxels, bbox["min"]) + return self.relabel(label_patch, voxels, labels) + + def relabel(self, label_patch, voxels, labels): + # Initializations + n_points = self.get_n_profile_points() + scaling_factor = 2 ** self.downsample_factor + label_patch = zoom(label_patch, 1.0 / scaling_factor, order=0) + for i, voxel in enumerate(voxels): + voxels[i] = [v // scaling_factor for v in voxel] + + # Main + relabel_patch = np.zeros(label_patch.shape) + relabel_patch[label_patch == labels[0]] = 1 + relabel_patch[label_patch == labels[1]] = 2 + line = geometry.make_line(voxels[0], voxels[-1], n_points) + return geometry.fill_path(relabel_patch, line, val=-1) + + +# --- Profile utils --- +def get_leaf_path(neurograph, i): """ - Gets image bounding box and voxel coordinates needed to compute an image - profile. + Gets path that profile will be computed over for the leaf node "i". Parameters ---------- - xyz_1 : numpy.ndarray - xyz coordinate of starting point of profile. - xyz_2 : numpy.ndarray - xyz coordinate of ending point of profile. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + i : int + Leaf node in "neurograph". Returns ------- - dict - Specifications needed to compute an image profile for a given - proposal. + list + Voxel coordinates that profile is generated from. """ - # Compute voxel coordinates - voxel_1 = img_util.to_voxels(xyz_1, downsample_factor=downsample_factor) - voxel_2 = img_util.to_voxels(xyz_2, downsample_factor=downsample_factor) - - # Store local coordinates - bbox = img_util.get_fixed_bbox(np.vstack([voxel_1, voxel_2]), CHUNK_SIZE) - start = [voxel_1[i] - bbox["min"][i] for i in range(3)] - end = [voxel_2[i] - bbox["min"][i] for i in range(3)] - specs = { - "bbox": bbox, - "profile_path": geometry.make_line(start, end, N_PROFILE_PTS), - } - return specs + j = neurograph.leaf_neighbor(i) + xyz_path = neurograph.oriented_edge((i, j), i) + return geometry.truncate_path(xyz_path) -def get_profile(img, specs, profile_id): +def get_branching_path(neurograph, i): """ - Gets the image profile for a given proposal. + Gets path that profile will be computed over for the branching node "i". Parameters ---------- - img : tensorstore.TensorStore - Image that profiles are generated from. - specs : dict - Dictionary that contains the image bounding box and coordinates of the - image profile path. - profile_id : frozenset - ... + neurograph : NeuroGraph + NeuroGraph generated from a predicted segmentation. + i : int + branching node in "neurograph". Returns ------- - dict - Dictionary that maps an id (e.g. node, edge, or proposal) to its image - profile. + list + Voxel coordinates that profile is generated from. """ - profile = img_util.read_profile(img, specs) - avg, std = util.get_avg_std(profile) - profile.extend([avg, std]) - return {profile_id: profile} + j_1, j_2 = tuple(neurograph.neighbors(i)) + voxels_1 = geometry.truncate_path(neurograph.oriented_edge((i, j_1), i)) + voxles_2 = geometry.truncate_path(neurograph.oriented_edge((i, j_2), i)) + return np.vstack([np.flip(voxels_1, axis=0), voxles_2]) -def proposal_skeletal(neurograph, proposals, radius): - """ - Generates features from skeleton (i.e. graph) which are graph or - geometry type features. +# --- Build feature matrix --- +def get_matrix(features, gt_accepts=set()): + # Initialize matrices + key = sample(list(features.keys()), 1)[0] + X = np.zeros((len(features.keys()), len(features[key]))) + y = np.zeros((len(features.keys()))) - Parameters - ---------- - neurograph : NeuroGraph - Graph that "proposals" belong to. - proposals : list - Proposals for which features will be generated - radius : float - Search radius used to generate proposals. + # Populate + idx_to_id = dict() + for i, id_i in enumerate(features): + idx_to_id[i] = id_i + X[i, :] = features[id_i] + y[i] = 1 if id_i in gt_accepts else 0 + return X, y, idx_to_id - Returns - ------- - dict - Features generated from skeleton. - """ - features = dict() - for proposal in proposals: - i, j = tuple(proposal) - features[proposal] = np.concatenate( - ( - neurograph.proposal_length(proposal), - neurograph.degree[i], - neurograph.degree[j], - len(neurograph.nodes[i]["proposals"]), - len(neurograph.nodes[j]["proposals"]), - neurograph.n_nearby_leafs(proposal, radius), - neurograph.proposal_radii(proposal), - neurograph.proposal_avg_radii(proposal), - neurograph.proposal_directionals(proposal, 8), - neurograph.proposal_directionals(proposal, 16), - neurograph.proposal_directionals(proposal, 32), - neurograph.proposal_directionals(proposal, 64), - ), - axis=None, - ) - return features - - -# --- part 2: edge feature generation -- -def compute_curvature(neurograph, edge): - kappa = curvature(neurograph.edges[edge]["xyz"]) - n_pts = len(kappa) - if n_pts <= N_BRANCH_PTS: - sampled_kappa = np.zeros((N_BRANCH_PTS)) - sampled_kappa[0:n_pts] = kappa - else: - idxs = np.linspace(0, n_pts - 1, N_BRANCH_PTS).astype(int) - sampled_kappa = kappa[idxs] - return np.array(sampled_kappa) - - -def curvature(xyz_list): - a = np.linalg.norm(xyz_list[1:-1] - xyz_list[:-2], axis=1) - b = np.linalg.norm(xyz_list[2:] - xyz_list[1:-1], axis=1) - c = np.linalg.norm(xyz_list[2:] - xyz_list[:-2], axis=1) - s = 0.5 * (a + b + c) - delta = np.sqrt(s * (s - a) * (s - b) * (s - c)) - return 4 * delta / (a * b * c) - - -# -- Build Feature Matrix -- -def get_matrix(neurographs, features, sample_ids=None): - if sample_ids: - return stack_feature_matrices(neurographs, features, sample_ids) - else: - return get_feature_matrix(neurographs, features) - - -def stack_feature_matrices(neurographs, features, blocks): - # Initialize - X = None - y = None - idx_transforms = {"block_to_idxs": dict(), "idx_to_edge": dict()} - - # Feature extraction +def stack_matrices(neurographs, features, blocks): + idx_to_id = dict() + X, y = None, None for block_id in blocks: - # Extract feature matrix - idx_shift = 0 if X is None else X.shape[0] - X_i, y_i, idx_transforms_i = get_feature_matrix( - neurographs[block_id], features[block_id], shift=idx_shift - ) - - # Concatenate + X_i, y_i, _ = get_matrix(features[block_id]) if X is None: X = deepcopy(X_i) y = deepcopy(y_i) else: X = np.concatenate((X, X_i), axis=0) y = np.concatenate((y, y_i), axis=0) + return X, y - # Update dicts - idx_transforms["block_to_idxs"][block_id] = idx_transforms_i[ - "block_to_idxs" - ] - idx_transforms["idx_to_edge"].update(idx_transforms_i["idx_to_edge"]) - return X, y, idx_transforms - - -def get_feature_matrix(neurograph, features, shift=0): - # Initialize - features = combine_features(features) - key = sample(list(features.keys()), 1)[0] - X = np.zeros((len(features.keys()), len(features[key]))) - y = np.zeros((len(features.keys()))) - idx_transforms = {"block_to_idxs": set(), "idx_to_edge": dict()} - - # Build - for i, edge in enumerate(features.keys()): - X[i, :] = features[edge] - y[i] = 1 if edge in neurograph.gt_accepts else 0 - idx_transforms["block_to_idxs"].add(i + shift) - idx_transforms["idx_to_edge"][i + shift] = edge - return X, y, idx_transforms - -# -- util -- -def count_features(): +# --- Utils --- +def get_node_dict(use_img_embedding=False): """ - Counts number of features based on the "model_type". + Returns the number of features for different node types. Parameters ---------- @@ -360,95 +609,30 @@ def count_features(): Returns ------- - int - Number of features. + dict + A dictionary containing the number of features for each node type + """ - return N_SKEL_FEATURES + N_PROFILE_PTS + 2 - - -def combine_features(features): - combined = dict() - for edge in features["skel"].keys(): - combined[edge] = None - for key in features.keys(): - if combined[edge] is None: - combined[edge] = deepcopy(features[key][edge]) - else: - combined[edge] = np.concatenate( - (combined[edge], features[key][edge]) - ) - return combined + return {"branch": 2, "proposal": 34} -def generate_chunks(neurograph, proposals, img, labels): +def get_edge_dict(): """ - Generates an image chunk for each proposal such that the centroid of the - image chunk is the midpoint of the proposal. Image chunks contain two - channels: raw image and predicted segmentation. + Returns the number of features for different edge types. Parameters ---------- - neurograph : NeuroGraph - Graph that "proposals" belong to. - img : tensorstore.TensorStore - Image stored in a GCS bucket. - labels : tensorstore.TensorStore - Predicted segmentation mask stored in a GCS bucket. - proposals : list[frozenset], optional - List of proposals for which features will be generated. The - default is None. + None Returns ------- dict - Dictonary such that each pair is the proposal id and image chunk. + A dictionary containing the number of features for each edge type """ - with ThreadPoolExecutor() as executor: - # Assign Threads - threads = [None] * len(proposals) - for t, proposal in enumerate(proposals): - xyz_0, xyz_1 = neurograph.proposal_xyz(proposal) - voxel_1 = util.to_voxels(xyz_0) - voxel_2 = util.to_voxels(xyz_1) - threads[t] = executor.submit( - get_chunk, img, labels, voxel_1, voxel_2, proposal - ) - - # Save result - chunks = dict() - profiles = dict() - for thread in as_completed(threads): - proposal, chunk, profile = thread.result() - chunks[proposal] = chunk - profiles[proposal] = profile - return chunks, profiles - - -def get_chunk(img, labels, voxel_1, voxel_2, thread_id=None): - # Extract chunks - midpoint = geometry.get_midpoint(voxel_1, voxel_2).astype(int) - if type(img) is ts.TensorStore: - chunk = util.read_tensorstore(img, midpoint, CHUNK_SIZE) - labels_chunk = util.read_tensorstore(labels, midpoint, CHUNK_SIZE) - else: - chunk = img_util.read_chunk(img, midpoint, CHUNK_SIZE) - labels_chunk = img_util.read_chunk(labels, midpoint, CHUNK_SIZE) - - # Coordinate transform - chunk = util.normalize(chunk) - patch_voxel_1 = util.voxels_to_patch(voxel_1, midpoint, CHUNK_SIZE) - patch_voxel_2 = util.voxels_to_patch(voxel_2, midpoint, CHUNK_SIZE) - - # Generate features - path = geometry.make_line(patch_voxel_1, patch_voxel_2, N_PROFILE_PTS) - profile = geometry.get_profile(chunk, path) - labels_chunk[labels_chunk > 0] = 1 - labels_chunk = geometry.fill_path(labels_chunk, path, val=2) - chunk = np.stack([chunk, labels_chunk], axis=0) - - # Output - if thread_id: - return thread_id, chunk, profile - else: - return chunk, profile + edge_dict = { + ("proposal", "edge", "proposal"): 3, + ("branch", "edge", "branch"): 3, + ("branch", "edge", "proposal"): 3 + } + return edge_dict \ No newline at end of file diff --git a/src/deep_neurographs/machine_learning/feature_generation_graphs.py b/src/deep_neurographs/machine_learning/feature_generation_graphs.py deleted file mode 100644 index e750825..0000000 --- a/src/deep_neurographs/machine_learning/feature_generation_graphs.py +++ /dev/null @@ -1,471 +0,0 @@ -""" -Created on Sat May 9 11:00:00 2024 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Generates features for training and performing inference with a heterogenous -graph neural network. - -""" -from concurrent.futures import ThreadPoolExecutor, as_completed - -import numpy as np - -from deep_neurographs import geometry -from deep_neurographs.machine_learning import feature_generation as feats -from deep_neurographs.utils import img_util - -N_PROFILE_PTS = 16 -NODE_PROFILE_DEPTH = 16 -WINDOW = [5, 5, 5] - - -def generate_gnn_features( - neurograph, img, proposals_dict, radius, downsample_factor -): - """ - Generates node and edge features for graph neural network. - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - img : str - Image stored on a GCS bucket. - proposals_dict : dict - Dictionary containing the computation graph used by gnn and proposals - to be classified. - radius : float - Search radius used to generate proposals. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - dict - Dictionary that contains different types of feature vectors for - nodes, edges, and proposals. - - """ - computation_graph = proposals_dict["graph"] - proposals = proposals_dict["proposals"] - features = { - "nodes": run_on_nodes( - neurograph, computation_graph, img, downsample_factor - ), - "edge": run_on_edges(neurograph, computation_graph), - "proposals": run_on_proposals( - neurograph, img, proposals, radius, downsample_factor - ), - } - return features - - -def run_on_nodes(neurograph, computation_graph, img, downsample_factor): - """ - Generates feature vectors for every node in "computation_graph". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - computation_graph : networkx.Graph - Graph used by gnn to classify proposals. - img : str - Image stored in a GCS bucket. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - dict - Dictionary whose keys are feature types (i.e. skeletal) and values are - a dictionary that maps a node id to the corresponding feature vector. - - """ - return {"skel": node_skeletal(neurograph, computation_graph)} - - -def run_on_edges(neurograph, computation_graph): - """ - Generates feature vectors for every edge in "computation_graph". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - computation_graph : networkx.Graph - Graph used by gnn to classify proposals. - - Returns - ------- - dict - Dictionary whose keys are feature types (i.e. skeletal) and values are - a dictionary that maps an edge id to the corresponding feature vector. - - """ - return {"skel": edge_skeletal(neurograph, computation_graph)} - - -def run_on_proposals(neurograph, img, proposals, radius, downsample_factor): - """ - Generates feature vectors for every proposal in "neurograph". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - img : str - Image stored in a GCS bucket. - proposals : list[frozenset] - List of proposals for which features will be generated. - radius : float - Search radius used to generate proposals. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - dict - Dictionary whose keys are feature types (i.e. skeletal and profiles) - and values are a dictionary that maps a proposal id to the - corresponding feature vector. - - """ - proposal_features = { - "skel": proposal_skeletal(neurograph, proposals, radius), - "profiles": feats.proposal_profiles( - neurograph, img, proposals, downsample_factor - ), - } - return proposal_features - - -# -- Skeletal Features -- -def node_skeletal(neurograph, computation_graph): - """ - Generates skeleton-based features for nodes in "computation_graph". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - computation_graph : networkx.Graph - Graph used by gnn to classify proposals. - - Returns - ------- - dict - Dictionary that maps a node id to the corresponding feature vector. - - """ - node_skeletal_features = dict() - for i in computation_graph.nodes: - node_skeletal_features[i] = np.concatenate( - ( - neurograph.degree[i], - neurograph.nodes[i]["radius"], - len(neurograph.nodes[i]["proposals"]), - ), - axis=None, - ) - return node_skeletal_features - - -def edge_skeletal(neurograph, computation_graph): - """ - Generates skeleton-based features for edges in "computation_graph". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - computation_graph : networkx.Graph - Graph used by gnn to classify proposals. - - Returns - ------- - dict - Dictionary that maps an edge id to the corresponding feature vector. - - """ - edge_skeletal_features = dict() - for edge in neurograph.edges: - edge_skeletal_features[frozenset(edge)] = np.array( - [ - np.mean(neurograph.edges[edge]["radius"]), - min(neurograph.edges[edge]["length"], 500) / 500, - ], - ) - return edge_skeletal_features - - -def proposal_skeletal(neurograph, proposals, radius): - """ - Generates skeleton-based features for "proposals". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - proposals : list[frozenset] - List of proposals for which features will be generated. - radius : float - Search radius used to generate proposals. - - Returns - ------- - dict - Dictionary that maps a node id to the corresponding feature vector. - - """ - proposal_skeletal_features = dict() - for proposal in proposals: - proposal_skeletal_features[proposal] = np.concatenate( - ( - neurograph.proposal_length(proposal) / radius, - neurograph.n_nearby_leafs(proposal, radius), - neurograph.proposal_radii(proposal), - neurograph.proposal_directionals(proposal, 16), - neurograph.proposal_directionals(proposal, 32), - neurograph.proposal_directionals(proposal, 64), - neurograph.proposal_directionals(proposal, 128), - ), - axis=None, - ) - return proposal_skeletal_features - - -# -- image features -- -def node_profiles(neurograph, computation_graph, img, downsample_factor): - """ - Generates image profiles for nodes in "computation_graph". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - computation_graph : networkx.Graph - Graph used by gnn to classify proposals. - img : str - Image to be read from. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - dict - Dictionary that maps a node id to the corresponding image profile. - - """ - # Get specifications to compute profiles - specs = dict() - for i in computation_graph.nodes: - if neurograph.degree[i] == 1: - profile_path = get_leaf_profile_path(neurograph, i) - else: - profile_path = get_branching_profile_path(neurograph, i) - specs[i] = get_node_profile_specs(profile_path, downsample_factor) - - # Generate profiles - with ThreadPoolExecutor() as executor: - threads = [] - for i, specs_i in specs.items(): - threads.append(executor.submit(feats.get_profile, img, specs_i, i)) - - node_profile_features = dict() - for thread in as_completed(threads): - node_profile_features.update(thread.result()) - return node_profile_features - - -def get_leaf_profile_path(neurograph, i): - """ - Gets path that profile will be computed over for the leaf node "i". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - i : int - Leaf node in "neurograph". - - Returns - ------- - list - Voxel coordinates that profile is generated from. - - """ - j = neurograph.leaf_neighbor(i) - return get_profile_path(neurograph.oriented_edge((i, j), i, key="xyz")) - - -def get_branching_profile_path(neurograph, i): - """ - Gets path that profile will be computed over for the branching node "i". - - Parameters - ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. - i : int - branching node in "neurograph". - - Returns - ------- - list - Voxel coordinates that profile is generated from. - - """ - nbs = list(neurograph.neighbors(i)) - voxels_1 = get_profile_path(neurograph.oriented_edge((i, nbs[0]), i)) - voxles_2 = get_profile_path(neurograph.oriented_edge((i, nbs[1]), i)) - return np.vstack([np.flip(voxels_1, axis=0), voxles_2]) - - -def get_profile_path(xyz_path): - """ - Gets a sub-path from "xyz_path" that has a path length of at most - "NODE_PROFILE_DEPTH" microns. - - Parameters - ---------- - xyz_path : numpy.ndarray - xyz coordinates that correspond to some edge in a neurograph from - which the profile path is extracted from. - - Returns - ------- - numpy.ndarray - xyz coordinates that an image profile will be generated from. - - """ - # Check for degeneracy - if xyz_path.shape[0] == 1: - xyz_path = np.vstack([xyz_path, xyz_path - 0.01]) - - # Truncate path - length = 0 - for i in range(1, xyz_path.shape[0]): - length += geometry.dist(xyz_path[i - 1], xyz_path[i]) - if length >= NODE_PROFILE_DEPTH: - break - return xyz_path[0:i, :] - - -def get_node_profile_specs(xyz_path, downsample_factor): - """ - Gets image bounding box and voxel coordinates needed to compute an image - profile. - - Parameters - ---------- - xyz_path : numpy.ndarray - xyz coordinates that represent an image profile path. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - dict - Specifications needed to compute image profile for a given proposal. - - """ - voxels = transform_path(xyz_path, downsample_factor) - bbox = img_util.get_minimal_bbox(voxels, buffer=1) - return {"bbox": bbox, "profile_path": shift_path(voxels, bbox)} - - -def transform_path(xyz_path, downsample_factor): - """ - Transforms "xyz_path" by converting the xyz coordinates to voxels and - resampling "N_PROFILE_PTS" from voxel coordinates. - - Parameters - ---------- - xyz_path : numpy.ndarray - xyz coordinates that represent an image profile path. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. - - Returns - ------- - numpy.ndarray - Voxel coordinates that represent an image profile path. - - """ - # Main - voxels = list() - for xyz in xyz_path: - voxels.append( - img_util.to_voxels(xyz, downsample_factor=downsample_factor) - ) - - # Finish - voxels = np.array(voxels) - if voxels.shape[0] < 5: - voxels = check_degenerate(voxels) - return geometry.sample_curve(voxels, N_PROFILE_PTS) - - -def shift_path(voxels, bbox): - """ - Shifts "voxels" by subtracting the min coordinate in "bbox". - - Parameters - ---------- - voxels : numpy.ndarray - Voxel coordinates to be shifted. - bbox : dict - Coordinates of a bounding box that contains "voxels". - - Returns - ------- - numpy.ndarray - Voxels shifted by min coordinate in "bbox". - - """ - return [voxel - bbox["min"] for voxel in voxels] - - -def check_degenerate(voxels): - """ - Checks whether "voxels" contains at least two unique points. If False, the - unique voxel coordinate is perturbed and added to "voxels". - - Parameters - ---------- - voxels : numpy.ndarray - Voxel coordinates to be checked. - - Returns - ------- - numpy.ndarray - Voxel coordinates that form a non-degenerate path. - - """ - if np.unique(voxels, axis=0).shape[0] == 1: - voxels = np.vstack( - [voxels, voxels[0, :] + np.array([1, 1, 1], dtype=int)] - ) - return voxels - - -def n_node_features(): - return {"branch": 2, "proposal": 34} - - -def n_edge_features(): - n_edge_features_dict = { - ("proposal", "edge", "proposal"): 3, - ("branch", "edge", "branch"): 3, - ("branch", "edge", "proposal"): 3 - } - return n_edge_features_dict diff --git a/src/deep_neurographs/machine_learning/heterograph_datasets.py b/src/deep_neurographs/machine_learning/heterograph_datasets.py index b41afae..a45d913 100644 --- a/src/deep_neurographs/machine_learning/heterograph_datasets.py +++ b/src/deep_neurographs/machine_learning/heterograph_datasets.py @@ -17,7 +17,8 @@ import torch from torch_geometric.data import HeteroData as HeteroGraphData -from deep_neurographs.machine_learning import datasets, feature_generation +from deep_neurographs.machine_learning import datasets +from deep_neurographs.machine_learning.feature_generation import get_matrix from deep_neurographs.utils import gnn_util DTYPE = torch.float32 @@ -44,17 +45,21 @@ def init(neurograph, features, computation_graph): Custom dataset. """ + # Check for groundtruth + if neurograph.gt_accepts is not None: + gt_accepts = neurograph.gt_accepts + else: + gt_accepts = set() + # Extract features - x_branches, _, idxs_branches = feature_generation.get_matrix( - neurograph, features["edge"] - ) - x_proposals, y_proposals, idxs_proposals = feature_generation.get_matrix( - neurograph, features["proposals"] + x_branches, _, idxs_branches = get_matrix(features["branches"]) + x_proposals, y_proposals, idxs_proposals = get_matrix( + features["proposals"], gt_accepts ) - x_nodes = feature_generation.combine_features(features["nodes"]) + x_nodes = features["nodes"] # Initialize dataset - proposals = list(features["proposals"]["skel"].keys()) + proposals = list(features["proposals"].keys()) heterograph_dataset = HeteroGraphDataset( computation_graph, proposals, @@ -63,7 +68,7 @@ def init(neurograph, features, computation_graph): x_proposals, y_proposals, idxs_branches, - idxs_proposals, + idxs_proposals ) return heterograph_dataset @@ -103,12 +108,9 @@ def __init__( Feature matrix generated from "proposals" in "computation_graph". y_proposals : numpy.ndarray Ground truth of proposals. - idxs_branches : dict - Dictionary that maps edges in "computation_graph" to an index that - represents the edge's position in "x_branches". - idxs_proposals : dict - Dictionary that maps "proposals" to an index that represents the - edge's position in "x_proposals". + idx_to_id : dict + Dictionary that maps an edge id in "computation_graph" to its + index in either x_branches or x_proposals. Returns ------- @@ -116,8 +118,8 @@ def __init__( """ # Conversion idxs - self.idxs_branches = datasets.init_idxs(idxs_branches) - self.idxs_proposals = datasets.init_idxs(idxs_proposals) + self.idxs_branches = datasets.init_idx_mapping(idxs_branches) + self.idxs_proposals = datasets.init_idx_mapping(idxs_proposals) self.computation_graph = computation_graph self.proposals = proposals @@ -214,11 +216,11 @@ def check_missing_edge_types(self): edges = [[n - 1, n - 2], [n - 2, n - 1]] self.data[edge_type].edge_index = gnn_util.toTensor(edges) if node_type == "branch": - self.idxs_branches["idx_to_edge"][n - 1] = e_1 - self.idxs_branches["idx_to_edge"][n - 2] = e_2 + self.idxs_branches["idx_to_id"][n - 1] = e_1 + self.idxs_branches["idx_to_id"][n - 2] = e_2 else: - self.idxs_proposals["idx_to_edge"][n - 1] = e_1 - self.idxs_proposals["idx_to_edge"][n - 2] = e_2 + self.idxs_proposals["idx_to_id"][n - 1] = e_1 + self.idxs_proposals["idx_to_id"][n - 2] = e_2 # -- Getters -- def n_branch_features(self): @@ -289,8 +291,8 @@ def proposal_to_proposal(self): edge_index = [] line_graph = gnn_util.init_line_graph(self.proposals) for e1, e2 in line_graph.edges: - v1 = self.idxs_proposals["edge_to_idx"][frozenset(e1)] - v2 = self.idxs_proposals["edge_to_idx"][frozenset(e2)] + v1 = self.idxs_proposals["id_to_idx"][frozenset(e1)] + v2 = self.idxs_proposals["id_to_idx"][frozenset(e2)] edge_index.extend([[v1, v2], [v2, v1]]) return gnn_util.toTensor(edge_index) @@ -315,8 +317,8 @@ def branch_to_branch(self): e1_edge_bool = frozenset(e1) not in self.proposals e2_edge_bool = frozenset(e2) not in self.proposals if e1_edge_bool and e2_edge_bool: - v1 = self.idxs_branches["edge_to_idx"][frozenset(e1)] - v2 = self.idxs_branches["edge_to_idx"][frozenset(e2)] + v1 = self.idxs_branches["id_to_idx"][frozenset(e1)] + v2 = self.idxs_branches["id_to_idx"][frozenset(e2)] edge_index.extend([[v1, v2], [v2, v1]]) return gnn_util.toTensor(edge_index) @@ -339,14 +341,14 @@ def branch_to_proposal(self): edge_index = [] for p in self.proposals: i, j = tuple(p) - v1 = self.idxs_proposals["edge_to_idx"][frozenset(p)] + v1 = self.idxs_proposals["id_to_idx"][frozenset(p)] for k in self.computation_graph.neighbors(i): if frozenset((i, k)) not in self.proposals: - v2 = self.idxs_branches["edge_to_idx"][frozenset((i, k))] + v2 = self.idxs_branches["id_to_idx"][frozenset((i, k))] edge_index.extend([[v2, v1]]) for k in self.computation_graph.neighbors(j): if frozenset((j, k)) not in self.proposals: - v2 = self.idxs_branches["edge_to_idx"][frozenset((j, k))] + v2 = self.idxs_branches["id_to_idx"][frozenset((j, k))] edge_index.extend([[v2, v1]]) return gnn_util.toTensor(edge_index) @@ -419,8 +421,8 @@ def node_intersection(idx_map, e1, e2): Common node between "e1" and "e2". """ - hat_e1 = idx_map["idx_to_edge"][int(e1)] - hat_e2 = idx_map["idx_to_edge"][int(e2)] + hat_e1 = idx_map["idx_to_id"][int(e1)] + hat_e2 = idx_map["idx_to_id"][int(e2)] node = list(hat_e1.intersection(hat_e2)) assert len(node) == 1, "Node intersection is not unique!" return node[0] @@ -444,8 +446,8 @@ def hetero_node_intersection(idx_map_1, idx_map_2, e1, e2): Common node between "e1" and "e2". """ - hat_e1 = idx_map_1["idx_to_edge"][int(e1)] - hat_e2 = idx_map_2["idx_to_edge"][int(e2)] + hat_e1 = idx_map_1["idx_to_id"][int(e1)] + hat_e2 = idx_map_2["idx_to_id"][int(e2)] node = list(hat_e1.intersection(hat_e2)) assert len(node) == 1, "Node intersection is empty or not unique!" return node[0] @@ -467,4 +469,4 @@ def n_edge_features(x): """ key = sample(list(x.keys()), 1)[0] - return x[key].shape[0] + return x[key].shape[0] \ No newline at end of file diff --git a/src/deep_neurographs/machine_learning/heterograph_models.py b/src/deep_neurographs/machine_learning/heterograph_models.py index 8a35c21..9124acf 100644 --- a/src/deep_neurographs/machine_learning/heterograph_models.py +++ b/src/deep_neurographs/machine_learning/heterograph_models.py @@ -26,9 +26,15 @@ class HeteroGNN(torch.nn.Module): """ - Heterogeneous graph neural network that utilizes edge features. + Heterogeneous graph attention network that classifies proposals. """ + # Class attributes + relation_types = [ + ("proposal", "edge", "proposal"), + ("branch", "edge", "proposal"), + ("branch", "edge", "branch"), + ] def __init__( self, @@ -44,85 +50,65 @@ def __init__( """ super().__init__() # Feature vector sizes - node_dict = ml.feature_generation_graphs.n_node_features() - edge_dict = ml.feature_generation_graphs.n_edge_features() - hidden = scale_hidden * np.max(list(node_dict.values())) + node_dict = ml.feature_generation.get_node_dict() + edge_dict = ml.feature_generation.get_edge_dict() + hidden_dim = scale_hidden * np.max(list(node_dict.values())) + output_dim = heads_1 * heads_2 * hidden_dim - # Linear layers - output_dim = heads_1 * heads_2 * hidden + # Nonlinear activation + self.dropout = dropout + self.dropout_layer = Dropout(dropout) + self.leaky_relu = LeakyReLU() + + # Linear layers self.input_nodes = nn.ModuleDict() self.input_edges = dict() for key, d in node_dict.items(): - self.input_nodes[key] = nn.Linear(d, hidden, device=device) + self.input_nodes[key] = nn.Linear(d, hidden_dim, device=device) for key, d in edge_dict.items(): - self.input_edges[key] = nn.Linear(d, hidden, device=device) + self.input_edges[key] = nn.Linear(d, hidden_dim, device=device) self.output = Linear(output_dim, 1).to(device) - # Convolutional layers - self.conv1 = HeteroConv( - { - ("proposal", "edge", "proposal"): GATConv( - -1, - hidden, - dropout=dropout, - edge_dim=hidden, - heads=heads_1, - ), - ("branch", "edge", "branch"): GATConv( - -1, - hidden, - dropout=dropout, - edge_dim=hidden, - heads=heads_1, - ), - ("branch", "edge", "proposal"): GATConv( - (hidden, hidden), - hidden, - add_self_loops=False, - edge_dim=hidden, - heads=heads_1, - ), - }, - aggr="sum", - ) - edge_dim = hidden - hidden = heads_1 * hidden - - self.conv2 = HeteroConv( - { - ("proposal", "edge", "proposal"): GATConv( - -1, - hidden, - dropout=dropout, - edge_dim=edge_dim, - heads=heads_2, - ), - ("branch", "edge", "branch"): GATConv( - -1, - hidden, - dropout=dropout, - edge_dim=edge_dim, - heads=heads_2, - ), - ("branch", "edge", "proposal"): GATConv( - (hidden, hidden), - hidden, - add_self_loops=False, - edge_dim=edge_dim, - heads=heads_2, - ), - }, - aggr="sum", - ) - hidden = heads_2 * hidden - - # Nonlinear activation - self.dropout = Dropout(dropout) - self.leaky_relu = LeakyReLU() + # Message passing layers + self.gat1 = self.init_gat_layer(hidden_dim, hidden_dim, heads_1) + self.gat2 = self.init_gat_layer(hidden_dim * heads_2, hidden_dim, heads_2) # Initialize weights self.init_weights() + # --- Class methods --- + @classmethod + def get_relation_types(cls): + return cls.relation_types + + # --- Architecture --- + def init_gat_layer(self, hidden_dim, edge_dim, heads): + gat_dict = dict() + for r in self.get_relation_types(): + is_same = True if r[0] == r[2] else False + init_gat = self.init_gat_same if is_same else self.init_gat_mixed + gat_dict[r] = init_gat(hidden_dim, edge_dim, heads) + return HeteroConv(gat_dict, aggr="sum") + + def init_gat_same(self, hidden_dim, edge_dim, heads): + gat_layer = GATConv( + -1, + hidden_dim, + dropout=self.dropout, + edge_dim=edge_dim, + heads=heads, + ) + return gat_layer + + def init_gat_mixed(self, hidden_dim, edge_dim, heads): + gat_layer = GATConv( + (hidden_dim, hidden_dim), + hidden_dim, + add_self_loops=False, + edge_dim=edge_dim, + heads=heads, + ) + return gat_layer def init_weights(self): """ Initializes linear layers. @@ -159,7 +145,7 @@ def activation(self, x_dict): """ x_dict = {key: self.leaky_relu(x) for key, x in x_dict.items()} - x_dict = {key: self.dropout(x) for key, x in x_dict.items()} + x_dict = {key: self.dropout_layer(x) for key, x in x_dict.items()} return x_dict def forward(self, x_dict, edge_index_dict, edge_attr_dict): @@ -173,140 +159,13 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict): edge_attr_dict = self.activation(edge_attr_dict) # Convolutional layers - x_dict = self.conv1( + x_dict = self.gat1( x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict ) - x_dict = self.conv2( + x_dict = self.gat2( x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict ) # Output x_dict = self.output(x_dict["proposal"]) return x_dict - - -class HEATGNN(torch.nn.Module): - """ - Heterogeneous graph neural network. - - """ - - def __init__( - self, - hidden, - metadata, - node_dict, - edge_dict, - dropout=DROPOUT, - heads_1=HEADS_1, - heads_2=HEADS_2, - ): - """ - Constructs a heterogeneous graph neural network. - - """ - super().__init__() - # Linear layers - self.input_nodes = nn.ModuleDict( - {key: nn.Linear(d, hidden) for key, d in node_dict.items()} - ) - self.input_edges = { - key: nn.Linear(d, hidden) for key, d in edge_dict.items() - } - self.output = Linear(heads_1 * heads_2 * hidden) - - # Convolutional layers - self.conv1 = HEATConv( - hidden, - hidden, - heads=heads_1, - dropout=dropout, - metadata=metadata, - ) - """ - x in_channels (int) – Size of each input sample, or -1 to - derive the size from the first input(s) to the forward method. - x out_channels (int) – Size of each output sample. - x num_node_types (int) – The number of node types. - x num_edge_types (int) – The number of edge types. - edge_type_emb_dim (int) – The embedding size of edge types. - edge_dim (int) – Edge feature dimensionality. - edge_attr_emb_dim (int) – The embedding size of edge features. - heads (int, optional) – Number of multi-head-attentions. (default: 1) - """ - hidden = heads_1 * hidden - - self.conv2 = HEATConv( - hidden, - hidden, - heads=heads_2, - dropout=dropout, - metadata=metadata, - ) - hidden = heads_2 * hidden - - # Nonlinear activation - self.dropout = Dropout(dropout) - self.leaky_relu = LeakyReLU() - - # Initialize weights - self.init_weights() - - def init_weights(self): - """ - Initializes linear and convolutional layers. - - Parameters - ---------- - None - - Returns - ------- - None - - """ - layers = [self.input_nodes, self.conv1, self.conv2, self.output] - for layer in layers: - for param in layer.parameters(): - if len(param.shape) > 1: - init.kaiming_normal_(param) - else: - init.zeros_(param) - - def activation(self, x_dict): - """ - Applies nonlinear activation - - Parameters - ---------- - x_dict : dict - Dictionary that maps node/edge types to feature matrices. - - Returns - ------- - dict - Feature matrices with activation applied. - - """ - x_dict = {key: self.leaky_relu(x) for key, x in x_dict.items()} - x_dict = {key: self.dropout(x) for key, x in x_dict.items()} - return x_dict - - def forward(self, x_dict, edge_index_dict, edge_attr_dict, metadata): - # Input - Nodes - x_dict = {key: f(x_dict[key]) for key, f in self.input_nodes.items()} - x_dict = self.activation(x_dict) - - # Input - Edges - edge_attr_dict = { - key: f(edge_attr_dict[key]) for key, f in self.input_edges.items() - } - edge_attr_dict = self.activation(edge_attr_dict) - - # Convolutional layers - x_dict = self.conv1(x_dict, edge_index_dict, metadata) - x_dict = self.conv2(x_dict, edge_index_dict, metadata) - - # Output - x_dict = self.output(x_dict["proposal"]) - return x_dict diff --git a/src/deep_neurographs/train.py b/src/deep_neurographs/train.py index 530be90..24a223b 100644 --- a/src/deep_neurographs/train.py +++ b/src/deep_neurographs/train.py @@ -25,15 +25,15 @@ from torch.optim.lr_scheduler import StepLR from torch.utils.tensorboard import SummaryWriter -from deep_neurographs.machine_learning import feature_generation +from deep_neurographs.machine_learning.feature_generation import FeatureGenerator from deep_neurographs.utils import gnn_util, img_util, ml_util, util from deep_neurographs.utils.gnn_util import toCPU from deep_neurographs.utils.graph_util import GraphLoader LR = 1e-3 -N_EPOCHS = 200 -SCHEDULER_GAMMA = 0.5 -SCHEDULER_STEP_SIZE = 1000 +N_EPOCHS = 500 +SCHEDULER_GAMMA = 0.7 +SCHEDULER_STEP_SIZE = 100 WEIGHT_DECAY = 1e-3 @@ -50,6 +50,7 @@ def __init__( model_type, criterion=None, output_dir=None, + use_img_embedding=False, validation_ids=None, save_model_bool=True, ): @@ -58,17 +59,18 @@ def __init__( raise ValueError("Must provide output_dir to save model.") # Set class attributes + self.feature_generators = dict() self.idx_to_ids = list() self.model = model self.model_type = model_type self.output_dir = output_dir self.save_model_bool = save_model_bool + self.use_img_embedding = use_img_embedding self.validation_ids = validation_ids # Set data structures for training examples self.gt_graphs = list() self.pred_graphs = list() - self.imgs = dict() self.train_dataset_list = list() self.validation_dataset_list = list() @@ -142,9 +144,16 @@ def load_example( } ) - def load_img(self, path, sample_id): - if sample_id not in self.imgs: - self.imgs[sample_id] = img_util.open_tensorstore(path, "zarr") + def load_img( + self, sample_id, img_path, downsample_factor, label_path=None + ): + if sample_id not in self.feature_generators: + self.feature_generators[sample_id] = FeatureGenerator( + img_path, + downsample_factor, + label_path=label_path, + use_img_embedding=self.use_img_embedding, + ) # --- main pipeline --- def run(self): @@ -200,10 +209,8 @@ def generate_features(self): # Generate features sample_id = self.idx_to_ids[i]["sample_id"] - features = feature_generation.run( + features = self.feature_generators[sample_id].run( self.pred_graphs[i], - self.imgs[sample_id], - self.model_type, proposals_dict, self.graph_config.search_radius, ) @@ -463,4 +470,4 @@ def get_predictions(hat_y, threshold=0.5): Binary predictions based on the given threshold. """ - return (ml_util.sigmoid(np.array(hat_y)) > threshold).tolist() + return (ml_util.sigmoid(np.array(hat_y)) > threshold).tolist() \ No newline at end of file