From 5a116ad7aa1de4b266790a6746b29808b9fef388 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Mon, 13 Nov 2023 19:58:03 +0000 Subject: [PATCH] features extraction speedup --- pyproject.toml | 12 +++ .../deep_learning/datasets.py | 2 +- src/deep_neurographs/deep_learning/train.py | 1 + src/deep_neurographs/evaluation.py | 89 +++++++++++++++++++ src/deep_neurographs/feature_extraction.py | 84 ++++++++++++----- src/deep_neurographs/geometry_utils.py | 26 +++++- src/deep_neurographs/intake.py | 7 +- src/deep_neurographs/neurograph.py | 63 ++++++++----- src/deep_neurographs/utils.py | 48 ++++++++-- 9 files changed, 271 insertions(+), 61 deletions(-) create mode 100644 src/deep_neurographs/evaluation.py diff --git a/pyproject.toml b/pyproject.toml index 92a7c99..9167aac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,18 @@ readme = "README.md" dynamic = ["version"] dependencies = [ + 'boto3', + 'lightning', + 'more_itertools', + 'networkx', + 'plotly', + 'scikit-learn', + 'scipy', + 'tensorstore', + 'torch', + 'torcheval', + 'torchio', + 'zarr', ] [project.optional-dependencies] diff --git a/src/deep_neurographs/deep_learning/datasets.py b/src/deep_neurographs/deep_learning/datasets.py index 63d44c3..f53df06 100644 --- a/src/deep_neurographs/deep_learning/datasets.py +++ b/src/deep_neurographs/deep_learning/datasets.py @@ -248,7 +248,7 @@ def __init__(self): self.transform = tio.Compose( [ tio.RandomBlur(std=(0, 0.4)), - tio.RandomNoise(std=(0, 0.03)), + tio.RandomNoise(std=(0, 0.0125)), tio.RandomFlip(axes=(0, 1, 2)), # tio.RandomAffine( # degrees=20, scales=(0.8, 1), image_interpolation="nearest" diff --git a/src/deep_neurographs/deep_learning/train.py b/src/deep_neurographs/deep_learning/train.py index de06130..ff9e819 100644 --- a/src/deep_neurographs/deep_learning/train.py +++ b/src/deep_neurographs/deep_learning/train.py @@ -99,6 +99,7 @@ def train_network( train_set, num_workers=NUM_WORKERS, batch_size=BATCH_SIZE, + pin_memory=True, shuffle=SHUFFLE, ) valid_loader = DataLoader( diff --git a/src/deep_neurographs/evaluation.py b/src/deep_neurographs/evaluation.py new file mode 100644 index 0000000..439e284 --- /dev/null +++ b/src/deep_neurographs/evaluation.py @@ -0,0 +1,89 @@ +""" +Created on Sat July 15 9:00:00 2023 + +@author: Anna Grim +@email: anna.grim@alleninstitute.org + +Evaluates performance of edge classifier. + +""" +from copy import deepcopy + +import numpy as np + + +def run_evaluation( + target_graphs, pred_graphs, y_pred, block_to_idxs, idx_to_edge, blocks +): + stats = init_counters() + stats_by_type = {"simple": init_counters(), "complex": init_counters()} + for block_id in blocks: + # Get predicted edges + pred_edges = get_predictions( + block_to_idxs[block_id], idx_to_edge, y_pred + ) + + # Overall performance + num_fixes, num_mistakes = __reconstruction_stats( + target_graphs[block_id], pred_graphs[block_id], pred_edges + ) + stats["# splits fixed"].append(num_fixes) + stats["# merges created"].append(num_mistakes) + + # In-depth performance + simple_stats, complex_stats = __reconstruction_type_stats( + target_graphs[block_id], pred_graphs[block_id], pred_edges + ) + for key in stats.keys(): + stats_by_type["simple"][key].append(simple_stats[key]) + stats_by_type["complex"][key].append(complex_stats[key]) + return stats, stats_by_type + + +def init_counters(val=[]): + return {"# splits fixed": deepcopy(val), "# merges created": deepcopy(val)} + + +def get_predictions(idxs, idx_to_edge, y_pred): + edge_idxs = set(np.where(y_pred > 0)[0]).intersection(idxs) + return set([idx_to_edge[idx] for idx in edge_idxs]) + + +def __reconstruction_stats(target_graph, pred_graph, pred_edges): + true_positives = 0 + false_positives = 0 + for edge in pred_edges: + if edge in pred_graph.target_edges: + true_positives += 1 + else: + false_positives += 1 + return true_positives, false_positives + + +def __reconstruction_type_stats(target_graph, pred_graph, pred_edges): + simple_stats = init_counters(val=0) + complex_stats = init_counters(val=0) + for edge in pred_edges: + i, j = tuple(edge) + deg_i = pred_graph.immutable_degree(i) + deg_j = pred_graph.immutable_degree(j) + if edge in pred_graph.target_edges: + if deg_i == 1 and deg_j == 1: + simple_stats["# splits fixed"] += 1 + else: + complex_stats["# splits fixed"] += 1 + else: + if deg_i == 1 and deg_j == 1: + simple_stats["# merges created"] += 1 + else: + complex_stats["# merges created"] += 1 + return simple_stats, complex_stats + + +def compute_accuracy(stats, type_key, num_edges): + tp = deepcopy(stats[type_key]["# splits fixed"]) + fp = deepcopy(stats[type_key]["# merges created"]) + + recall = tp / num_edges + precision = tp / (tp + fp) + f1 = (2 * recall * precision) / (recall + precision) diff --git a/src/deep_neurographs/feature_extraction.py b/src/deep_neurographs/feature_extraction.py index 59df17a..288fbd0 100644 --- a/src/deep_neurographs/feature_extraction.py +++ b/src/deep_neurographs/feature_extraction.py @@ -4,7 +4,7 @@ @author: Anna Grim @email: anna.grim@alleninstitute.org -Builds graph for postprocessing with GNN. +Generates features. """ @@ -16,6 +16,7 @@ from deep_neurographs import geometry_utils, utils CHUNK_SIZE = [64, 64, 64] +BUFFER = 256 HALF_CHUNK_SIZE = [CHUNK_SIZE[i] // 2 for i in range(3)] WINDOW_SIZE = [5, 5, 5] @@ -67,51 +68,86 @@ def generate_mutable_features( def generate_mutable_img_chunk_features( neurograph, img_path, labels_path, anisotropy=[1.0, 1.0, 1.0] ): - img = utils.open_zarr(img_path) - pred_labels = utils.open_tensorstore(labels_path) features = dict() + shape = neurograph.shape + origin = neurograph.bbox["min"] # world coordinates + origin = utils.apply_anisotropy( + origin, anisotropy, return_int=True + ) # global image coordinates + img, labels = utils.get_superchunks( + img_path, labels_path, origin, shape, from_center=False + ) for edge in neurograph.mutable_edges: - # Extract coordinates - edge_xyz = neurograph.edges[edge]["xyz"] - edge_xyz[0] = utils.apply_anisotropy( - edge_xyz[0], anisotropy=anisotropy - ) - edge_xyz[1] = utils.apply_anisotropy( - edge_xyz[1], anisotropy=anisotropy - ) + # Compute image coordinates + edge_xyz = deepcopy(neurograph.edges[edge]["xyz"]) + edge_xyz = [ + utils.apply_anisotropy( + edge_xyz[0] - origin, anisotropy=anisotropy + ), + utils.apply_anisotropy( + edge_xyz[1] - origin, anisotropy=anisotropy + ), + ] - # Read chunks - midpoint = geometry_utils.compute_midpoint(edge_xyz[0], edge_xyz[1]) - origin = tuple(np.round(midpoint).astype(int)) - img_chunk = utils.read_img_chunk(img, origin, CHUNK_SIZE) - labels_chunk = utils.read_tensorstore(pred_labels, origin, CHUNK_SIZE) + # Extract chunks + midpoint = geometry_utils.compute_midpoint( + edge_xyz[0], edge_xyz[1] + ).astype(int) + img_chunk = utils.get_chunk(img, midpoint, CHUNK_SIZE) + labels_chunk = utils.get_chunk(labels, midpoint, CHUNK_SIZE) - # Add path - d = geometry_utils.dist(edge_xyz[0], edge_xyz[1]) + # Compute path + d = int(geometry_utils.dist(edge_xyz[0], edge_xyz[1]) + 5) img_coords_1 = np.round( - (edge_xyz[0] - midpoint + HALF_CHUNK_SIZE) + edge_xyz[0] - midpoint + HALF_CHUNK_SIZE ).astype(int) img_coords_2 = np.round( edge_xyz[1] - midpoint + HALF_CHUNK_SIZE ).astype(int) - path = geometry_utils.make_line(img_coords_1, img_coords_2, int(d + 5)) + path = geometry_utils.make_line(img_coords_1, img_coords_2, d) + # Fill path labels_chunk[labels_chunk > 0] = 1 labels_chunk = geometry_utils.fill_path(labels_chunk, path, val=-1) features[edge] = np.stack([img_chunk, labels_chunk], axis=0) + return features +def get_chunk(superchunk, xyz): + return deepcopy( + superchunk[ + (xyz[0] - CHUNK_SIZE[0] // 2) : xyz[0] + CHUNK_SIZE[0] // 2, + (xyz[1] - CHUNK_SIZE[1] // 2) : xyz[1] + CHUNK_SIZE[1] // 2, + (xyz[2] - CHUNK_SIZE[2] // 2) : xyz[2] + CHUNK_SIZE[2] // 2, + ] + ) + + def generate_mutable_img_profile_features( neurograph, path, anisotropy=[1.0, 1.0, 1.0] ): - img = utils.open_zarr(path) features = dict() + origin = utils.apply_anisotropy( + neurograph.bbox["min"], anisotropy, return_int=True + ) + shape = [neurograph.shape[i] + BUFFER for i in range(3)] + superchunk = utils.get_superchunk( + path, "zarr", origin, shape, from_center=False + ) for edge in neurograph.mutable_edges: - xyz = neurograph.edges[edge]["xyz"] - line = geometry_utils.make_line(xyz[0], xyz[1], NUM_POINTS) + edge_xyz = deepcopy(neurograph.edges[edge]["xyz"]) + edge_xyz = [ + utils.apply_anisotropy( + edge_xyz[0] - neurograph.origin, anisotropy=anisotropy + ), + utils.apply_anisotropy( + edge_xyz[1] - neurograph.origin, anisotropy=anisotropy + ), + ] + line = geometry_utils.make_line(edge_xyz[0], edge_xyz[1], NUM_POINTS) features[edge] = geometry_utils.get_profile( - img, line, anisotropy=anisotropy, window_size=WINDOW_SIZE + superchunk, line, window_size=WINDOW_SIZE ) return features diff --git a/src/deep_neurographs/geometry_utils.py b/src/deep_neurographs/geometry_utils.py index 47b22fd..12c6e4b 100644 --- a/src/deep_neurographs/geometry_utils.py +++ b/src/deep_neurographs/geometry_utils.py @@ -129,21 +129,37 @@ def smooth_end(branch_xyz, radii, ref_xyz, num_pts=8): # Image feature extraction -def get_profile( - img, xyz_arr, anisotropy=[1.0, 1.0, 1.0], window_size=[4, 4, 4] +def get_profile(img, xyz_arr, window_size=[5, 5, 5]): + return [np.max(utils.get_chunk(img, xyz, window_size)) for xyz in xyz_arr] + + +""" +def get_profile_old( + img, xyz_arr, anisotropy=[1.0, 1.0, 1.0], window_size=[5, 5, 5] ): + #xyz_arr = get_coords(xyz_arr, anisotropy=anisotropy) + profile = [] + for xyz in xyz_arr: + xyz = xyz.astype(int) + img_chunk = utils.get_chunk(img, xyz, window_size) + profile.append(np.max(img_chunk)) + return np.array(profile) + + xyz_arr = get_coords(xyz_arr, anisotropy=anisotropy) profile = [] for xyz in xyz_arr: img_chunk = utils.read_img_chunk(img, xyz, window_size) profile.append(np.max(img_chunk)) return np.array(profile) +""" def fill_path(img, path, val=-1): for xyz in path: x, y, z = tuple(np.round(xyz).astype(int)) - img[(x - 1) : x + 1, (y - 1) : y + 1, (z - 1) : z + 1] = val + img[x, y, z] = val + # img[(x - 1) : x + 1, (y - 1) : y + 1, (z - 1) : z + 1] = val return img @@ -185,8 +201,10 @@ def dist(x, y, metric="l2"): def make_line(xyz_1, xyz_2, num_steps): + xyz_1 = np.array(xyz_1) + xyz_2 = np.array(xyz_2) t_steps = np.linspace(0, 1, num_steps) - return np.array([(1 - t) * xyz_1 + t * xyz_2 for t in t_steps]) + return np.array([(1 - t) * xyz_1 + t * xyz_2 for t in t_steps], dtype=int) def normalize(x, norm="l2"): diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index 86a6fd4..2d79eac 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -52,9 +52,10 @@ def build_neurograph( prune=prune, prune_depth=prune_depth, ) - neurograph.generate_mutables( - max_degree=max_mutable_degree, search_radius=search_radius - ) + if search_radius > 0: + neurograph.generate_mutables( + max_degree=max_mutable_degree, search_radius=search_radius + ) return neurograph diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index d496ae9..4e29921 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -51,17 +51,18 @@ def __init__(self, swc_path, label_mask=None, origin=None, shape=None): self.mutable_edges = set() self.target_edges = set() self.xyz_to_edge = dict() - if origin is not None and shape is not None: - self.init_bbox(origin, shape) + + if origin and shape: + self.bbox = { + "min": list(origin), + "max": [origin[i] + shape[i] for i in range(3)], + } + self.origin = origin + self.shape = shape else: self.bbox = None # --- Add nodes or edges --- - def init_bbox(self, origin, shape): - self.bbox = dict() - self.bbox["min"] = list(origin) - self.bbox["max"] = [self.bbox["min"][i] + shape[i] for i in range(3)] - def generate_immutables( self, swc_id, swc_dict, prune=True, prune_depth=16 ): @@ -122,6 +123,9 @@ def generate_immutables( for j in junctions: self.junctions.add(node_id[j]) + # Build kdtree + self._init_kdtree() + def init_immutable_graph(self): immutable_graph = nx.Graph() immutable_graph.add_nodes_from(self) @@ -137,28 +141,32 @@ def generate_mutables(self, max_degree=3, search_radius=25.0): None """ - # Search for mutable connections self.mutable_edges = set() - self._init_kdtree() for leaf in self.leafs: - xyz_leaf = self.nodes[leaf]["xyz"] - if not self.is_contained(xyz_leaf): + if not self.is_contained(self.nodes[leaf]["xyz"]): continue + xyz_leaf = self.nodes[leaf]["xyz"] mutables = self._get_mutables( leaf, xyz_leaf, max_degree, search_radius ) for xyz in mutables: - if not self.is_contained(xyz): - continue # Extract info on mutable connection (i, j) = self.xyz_to_edge[xyz] attrs = self.get_edge_data(i, j) # Get connecting node - if geometry_utils.dist(xyz, attrs["xyz"][0]) < 10: + contained_i = self.is_contained(self.nodes[i]["xyz"]) + contained_j = self.is_contained(self.nodes[j]["xyz"]) + if ( + geometry_utils.dist(xyz, attrs["xyz"][0]) < 10 + and contained_i + ): node = i xyz = self.nodes[node]["xyz"] - elif geometry_utils.dist(xyz, attrs["xyz"][-1]) < 10: + elif ( + geometry_utils.dist(xyz, attrs["xyz"][-1]) < 10 + and contained_j + ): node = j xyz = self.nodes[node]["xyz"] else: @@ -187,6 +195,8 @@ def _get_mutables(self, query_id, query_xyz, max_degree, search_radius): best_dist = dict() query_swc_id = self.nodes[query_id]["swc_id"] for xyz in self._query_kdtree(query_xyz, search_radius): + if not self.is_contained(xyz): + continue xyz = tuple(xyz) edge = self.xyz_to_edge[xyz] swc_id = gutils.get_edge_attr(self, edge, "swc_id") @@ -289,6 +299,8 @@ def _query_kdtree(self, query, dist): return self.kdtree.data[idxs] def init_targets(self, target_neurograph): + self.num_simple_edges = 0 + self.num_complex_edges = 0 self.target_edges = set() self.groundtruth_graph = self.init_immutable_graph() target_densegraph = DenseGraph(target_neurograph.path) @@ -308,7 +320,7 @@ def init_targets(self, target_neurograph): proj_xyz_j, d_j = target_neurograph.get_projection(xyz_j) # Check criteria - if d_i > 8 or d_j > 8: + if d_i > 7.5 or d_j > 7.5: continue elif self.check_cycle((i, j)): continue @@ -322,6 +334,7 @@ def init_targets(self, target_neurograph): continue if not target_densegraph.check_aligned(xyz_i, xyz_j): continue + self.num_complex_edges += 1 else: # Simple criteria inclusion_i = proj_xyz_i in site_to_site.keys() @@ -346,6 +359,8 @@ def init_targets(self, target_neurograph): site_to_site, pair_to_edge = self.remove_site( site_to_site, pair_to_edge, proj_xyz_i, proj_xyz_k ) + else: + self.num_simple_edges += 1 # Add site site_to_site, pair_to_edge = self.add_site( @@ -575,12 +590,11 @@ def is_nb(self, i, j): return True if i in self.neighbors(j) else False def is_contained(self, xyz): - if type(self.bbox) is dict: - for i in range(3): - lower_bool = xyz[i] < self.bbox["min"][i] - upper_bool = xyz[i] > self.bbox["max"][i] - if lower_bool or upper_bool: - return False + for i in range(3): + lower_bool = xyz[i] < self.bbox["min"][i] + upper_bool = xyz[i] >= self.bbox["max"][i] + if lower_bool or upper_bool: + return False return True def is_leaf(self, i): @@ -600,6 +614,11 @@ def get_edge_attr(self, key, i, j): attr_2 = self.nodes[j][key] return attr_1, attr_2 + def get_center(self): + return geometry_utils.compute_midpoint( + self.bbox["min"], self.bbox["max"] + ) + def to_line_graph(self): """ Converts graph to a line graph. diff --git a/src/deep_neurographs/utils.py b/src/deep_neurographs/utils.py index e63b2fb..23b79f7 100644 --- a/src/deep_neurographs/utils.py +++ b/src/deep_neurographs/utils.py @@ -8,7 +8,7 @@ """ - +import concurrent.futures import json import os import shutil @@ -19,6 +19,8 @@ import zarr from plotly.subplots import make_subplots +SUPPORTED_DRIVERS = ["neuroglancer_precomputed", "zarr"] + # --- dictionary utils --- def remove_item(my_set, item): @@ -103,7 +105,7 @@ def open_zarr(path): return zarr.open(n5store).s0 -def open_tensorstore(path): +def open_tensorstore(path, driver): """ Uploads segmentation mask stored as a directory of shard files. @@ -111,6 +113,8 @@ def open_tensorstore(path): ---------- path : str Path to directory containing shard files. + driver : str + Storage driver needed to read data at "path". Returns ------- @@ -118,9 +122,10 @@ def open_tensorstore(path): Sparse image volume. """ + assert driver in SUPPORTED_DRIVERS, "Error! Driver is not supported!" ts_arr = ts.open( { - "driver": "neuroglancer_precomputed", + "driver": driver, "kvstore": { "driver": "gcs", "bucket": "allen-nd-goog", @@ -128,7 +133,13 @@ def open_tensorstore(path): }, } ).result() - return ts_arr[ts.d["channel"][0]] + if driver == "neuroglancer_precomputed": + return ts_arr[ts.d["channel"][0]] + elif driver == "zarr": + ts_arr = ts_arr[0, 0, :, :, :] + ts_arr = ts_arr[ts.d[0].transpose[2]] + ts_arr = ts_arr[ts.d[0].transpose[1]] + return ts_arr def read_img_chunk(img, xyz, shape): @@ -139,8 +150,14 @@ def read_img_chunk(img, xyz, shape): ].transpose(2, 1, 0) +def get_chunk(arr, xyz, shape): + xyz_1 = [max(xyz[i] - shape[i] // 2, 0) for i in range(3)] + xyz_2 = [min(xyz[i] + shape[i] // 2, arr.shape[i] - 1) for i in range(3)] + return arr[xyz_1[0] : xyz_2[0], xyz_1[1] : xyz_2[1], xyz_1[2] : xyz_2[2]] + + def read_tensorstore(ts_arr, xyz, shape): - arr = ( + return ( ts_arr[ (xyz[0] - shape[0] // 2) : xyz[0] + shape[0] // 2, (xyz[1] - shape[1] // 2) : xyz[1] + shape[1] // 2, @@ -149,7 +166,24 @@ def read_tensorstore(ts_arr, xyz, shape): .read() .result() ) - return arr + + +def get_superchunks(img_path, label_path, xyz, shape): + with concurrent.futures.ThreadPoolExecutor() as executor: + img_job = executor.submit(get_superchunk, img_path, "zarr", xyz, shape) + label_job = executor.submit( + get_superchunk, label_path, "neuroglancer_precomputed", xyz, shape + ) + return img_job.result(), label_job.result() + + +def get_superchunk(path, driver, xyz, shape, from_center=True): + ts_arr = open_tensorstore(path, driver) + if from_center: + return read_tensorstore(ts_arr, xyz, shape) + else: + xyz = [xyz[i] + shape[i] // 2 for i in range(3)] + return read_tensorstore(ts_arr, xyz, shape) def read_json(path): @@ -288,7 +322,7 @@ def to_world(xyz, anisotropy, shift=[0, 0, 0]): def to_img(xyz, anisotropy, shift=[0, 0, 0]): - xyz = apply_anisotropy(xyz - shift, return_int=True) + xyz = apply_anisotropy(xyz - shift, anisotropy, return_int=True) return tuple(xyz)