From 46f57456dc2a8e6df0e968e2405a8e9d2761bc35 Mon Sep 17 00:00:00 2001
From: anna-grim
Date: Tue, 28 Nov 2023 19:57:54 +0000
Subject: [PATCH] optimize feature generation

---
 .../deep_learning/datasets.py                 |  2 +-
 src/deep_neurographs/deep_learning/models.py  |  4 +-
 src/deep_neurographs/deep_learning/train.py   | 10 +-
 src/deep_neurographs/evaluation.py            | 60 ++++++++----
 src/deep_neurographs/feature_extraction.py    | 96 ++++++++-----------
 src/deep_neurographs/geometry_utils.py        | 45 ++-------
 src/deep_neurographs/neurograph.py            | 19 ++--
 src/deep_neurographs/utils.py                 | 45 +++++----
 8 files changed, 134 insertions(+), 147 deletions(-)

diff --git a/src/deep_neurographs/deep_learning/datasets.py b/src/deep_neurographs/deep_learning/datasets.py
index f53df06..1d12e67 100644
--- a/src/deep_neurographs/deep_learning/datasets.py
+++ b/src/deep_neurographs/deep_learning/datasets.py
@@ -251,7 +251,7 @@ def __init__(self):
                 tio.RandomNoise(std=(0, 0.0125)),
                 tio.RandomFlip(axes=(0, 1, 2)),
                 # tio.RandomAffine(
-                #    degrees=20, scales=(0.8, 1), image_interpolation="nearest"
+                #     degrees=20, scales=(0.8, 1), image_interpolation="nearest"
                 # )
             ]
         )
diff --git a/src/deep_neurographs/deep_learning/models.py b/src/deep_neurographs/deep_learning/models.py
index a3c9314..3f4e4ff 100644
--- a/src/deep_neurographs/deep_learning/models.py
+++ b/src/deep_neurographs/deep_learning/models.py
@@ -36,9 +36,7 @@ def __init__(self, num_features):
         nn.Module.__init__(self)
         self.fc1 = self._init_fc_layer(num_features, num_features // 2)
         self.fc2 = self._init_fc_layer(num_features // 2, num_features // 2)
-        self.output = nn.Sequential(
-            nn.Linear(num_features // 2, 1), nn.Sigmoid()
-        )
+        self.output = nn.Linear(num_features // 2, 1)

     def _init_fc_layer(self, D_in, D_out):
         """
diff --git a/src/deep_neurographs/deep_learning/train.py b/src/deep_neurographs/deep_learning/train.py
index ff9e819..017e891 100644
--- a/src/deep_neurographs/deep_learning/train.py
+++ b/src/deep_neurographs/deep_learning/train.py
@@ -20,6 +20,7 @@
 from lightning.pytorch.profilers import PyTorchProfiler
 from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
 from torch.utils.data import DataLoader
+from torch.nn.functional import sigmoid
 from torcheval.metrics.functional import (
     binary_accuracy,
     binary_f1_score,
@@ -131,16 +132,16 @@ def train_network(
     return model


-def random_split(train_set, train_ratio=0.8):
+def random_split(train_set, train_ratio=0.85):
     train_set_size = int(len(train_set) * train_ratio)
     valid_set_size = len(train_set) - train_set_size
     return torch_data.random_split(train_set, [train_set_size, valid_set_size])


-def eval_network(X, model, threshold=0):
+def eval_network(X, model, threshold=0.5):
     model.eval()
     X = torch.tensor(X, dtype=torch.float32)
-    y_pred = model.net(X)
+    y_pred = sigmoid(model.net(X))
     return np.array(y_pred > threshold, dtype=int)


@@ -174,9 +175,8 @@ def validation_step(self, batch, batch_idx):
         self.compute_stats(y_hat, y, prefix="val_")

     def compute_stats(self, y_hat, y, prefix=""):
-        y_hat = torch.flatten(y_hat)
+        y_hat = torch.flatten(sigmoid(y_hat))
         y = torch.flatten(y).to(torch.int)
-        self.log(prefix + "accuracy", binary_accuracy(y_hat, y))
         self.log(prefix + "precision", binary_precision(y_hat, y))
         self.log(prefix + "recall", binary_recall(y_hat, y))
         self.log(prefix + "f1", binary_f1_score(y_hat, y))
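The models.py and train.py hunks above move the sigmoid out of the network: the model now emits raw logits, and the squashing happens only at evaluation time. A minimal sketch of that logits-plus-thresholding pattern, assuming training pairs the logits with BCEWithLogitsLoss (the toy shapes and loop here are illustrative, not the repo's code):

```python
import torch
import torch.nn as nn

# Toy stand-in for the classifier: the last layer emits raw logits.
model = nn.Sequential(nn.Linear(16, 8), nn.LeakyReLU(), nn.Linear(8, 1))
criterion = nn.BCEWithLogitsLoss()  # applies the sigmoid internally; numerically stabler than Sigmoid + BCELoss

x = torch.randn(4, 16)
y = torch.randint(0, 2, (4, 1)).float()

# Training step: the loss consumes logits directly.
loss = criterion(model(x), y)
loss.backward()

# Inference mirrors eval_network: sigmoid, then threshold at 0.5.
model.eval()
with torch.no_grad():
    y_pred = (torch.sigmoid(model(x)) > 0.5).int()
```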
diff --git a/src/deep_neurographs/evaluation.py b/src/deep_neurographs/evaluation.py
index 439e284..230f1f8 100644
--- a/src/deep_neurographs/evaluation.py
+++ b/src/deep_neurographs/evaluation.py
@@ -7,16 +7,25 @@
 Evaluates performance of edge classifier.

 """
-from copy import deepcopy
-
 import numpy as np

+STATS_LIST = [
+    "precision",
+    "recall",
+    "f1",
+    "# splits fixed",
+    "# merges created",
+]
+

 def run_evaluation(
     target_graphs, pred_graphs, y_pred, block_to_idxs, idx_to_edge, blocks
 ):
-    stats = init_counters()
-    stats_by_type = {"simple": init_counters(), "complex": init_counters()}
+    stats = dict([(s, []) for s in STATS_LIST])
+    stats_by_type = {
+        "simple": dict([(s, []) for s in STATS_LIST]),
+        "complex": dict([(s, []) for s in STATS_LIST]),
+    }
     for block_id in blocks:
         # Get predicted edges
         pred_edges = get_predictions(
@@ -34,16 +43,12 @@ def run_evaluation(
         simple_stats, complex_stats = __reconstruction_type_stats(
             target_graphs[block_id], pred_graphs[block_id], pred_edges
         )
-        for key in stats.keys():
+        for key in STATS_LIST:
             stats_by_type["simple"][key].append(simple_stats[key])
             stats_by_type["complex"][key].append(complex_stats[key])
     return stats, stats_by_type


-def init_counters(val=[]):
-    return {"# splits fixed": deepcopy(val), "# merges created": deepcopy(val)}
-
-
 def get_predictions(idxs, idx_to_edge, y_pred):
     edge_idxs = set(np.where(y_pred > 0)[0]).intersection(idxs)
     return set([idx_to_edge[idx] for idx in edge_idxs])
@@ -61,8 +66,8 @@ def __reconstruction_stats(target_graph, pred_graph, pred_edges):


 def __reconstruction_type_stats(target_graph, pred_graph, pred_edges):
-    simple_stats = init_counters(val=0)
-    complex_stats = init_counters(val=0)
+    simple_stats = dict([(s, 0) for s in STATS_LIST])
+    complex_stats = dict([(s, 0) for s in STATS_LIST])
     for edge in pred_edges:
         i, j = tuple(edge)
         deg_i = pred_graph.immutable_degree(i)
@@ -77,13 +82,32 @@ def __reconstruction_type_stats(target_graph, pred_graph, pred_edges):
                 simple_stats["# merges created"] += 1
             else:
                 complex_stats["# merges created"] += 1
-    return simple_stats, complex_stats

+    num_simple, num_complex = compute_edge_type(pred_graph)
+    simple_stats = compute_accuracy(simple_stats, num_simple)
+    complex_stats = compute_accuracy(complex_stats, num_complex)
+    return simple_stats, complex_stats

-def compute_accuracy(stats, type_key, num_edges):
-    tp = deepcopy(stats[type_key]["# splits fixed"])
-    fp = deepcopy(stats[type_key]["# merges created"])
-    recall = tp / num_edges
-    precision = tp / (tp + fp)
-    f1 = (2 * recall * precision) / (recall + precision)
+
+def compute_edge_type(graph):
+    num_simple = 0
+    num_complex = 0
+    for edge in graph.target_edges:
+        i, j = tuple(edge)
+        deg_i = graph.immutable_degree(i)
+        deg_j = graph.immutable_degree(j)
+        if deg_i == 1 and deg_j == 1:
+            num_simple += 1
+        else:
+            num_complex += 1
+    return num_simple, num_complex
+
+
+def compute_accuracy(stats, num_edges):
+    d = stats["# merges created"] + stats["# splits fixed"]
+    r = 1 if num_edges == 0 else stats["# splits fixed"] / num_edges
+    p = 1 if d == 0 else stats["# splits fixed"] / d
+    stats["f1"] = 0 if r + p == 0 else (2 * r * p) / (r + p)
+    stats["precision"] = p
+    stats["recall"] = r
+    return stats
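The rewritten compute_accuracy treats "# splits fixed" as true positives and "# merges created" as false positives, and guards every division against zero. A standalone rendering of the same arithmetic, with a hypothetical stats dict as input:

```python
def compute_accuracy(stats, num_edges):
    tp = stats["# splits fixed"]   # correctly proposed connections
    fp = stats["# merges created"]  # spurious connections
    recall = 1 if num_edges == 0 else tp / num_edges
    precision = 1 if tp + fp == 0 else tp / (tp + fp)
    f1 = 0 if recall + precision == 0 else 2 * recall * precision / (recall + precision)
    return {**stats, "precision": precision, "recall": recall, "f1": f1}

# 8 of 10 ground-truth edges recovered, 2 spurious merges:
print(compute_accuracy({"# splits fixed": 8, "# merges created": 2}, 10))
# precision = recall = f1 = 0.8
```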
diff --git a/src/deep_neurographs/feature_extraction.py b/src/deep_neurographs/feature_extraction.py
index 288fbd0..5fd2896 100644
--- a/src/deep_neurographs/feature_extraction.py
+++ b/src/deep_neurographs/feature_extraction.py
@@ -16,7 +16,6 @@
 from deep_neurographs import geometry_utils, utils

 CHUNK_SIZE = [64, 64, 64]
-BUFFER = 256
 HALF_CHUNK_SIZE = [CHUNK_SIZE[i] // 2 for i in range(3)]
 WINDOW_SIZE = [5, 5, 5]

@@ -69,44 +68,27 @@ def generate_mutable_img_chunk_features(
     neurograph, img_path, labels_path, anisotropy=[1.0, 1.0, 1.0]
 ):
     features = dict()
-    shape = neurograph.shape
-    origin = neurograph.bbox["min"]  # world coordinates
-    origin = utils.apply_anisotropy(
-        origin, anisotropy, return_int=True
-    )  # global image coordinates
+    origin = utils.apply_anisotropy(neurograph.origin, return_int=True)
     img, labels = utils.get_superchunks(
-        img_path, labels_path, origin, shape, from_center=False
+        img_path, labels_path, origin, neurograph.shape, from_center=False
     )
     for edge in neurograph.mutable_edges:
         # Compute image coordinates
-        edge_xyz = deepcopy(neurograph.edges[edge]["xyz"])
-        edge_xyz = [
-            utils.apply_anisotropy(
-                edge_xyz[0] - origin, anisotropy=anisotropy
-            ),
-            utils.apply_anisotropy(
-                edge_xyz[1] - origin, anisotropy=anisotropy
-            ),
-        ]
+        i, j = tuple(edge)
+        xyz_i = get_local_img_coords(neurograph, i)
+        xyz_j = get_local_img_coords(neurograph, j)

         # Extract chunks
-        midpoint = geometry_utils.compute_midpoint(
-            edge_xyz[0], edge_xyz[1]
-        ).astype(int)
+        midpoint = geometry_utils.compute_midpoint(xyz_i, xyz_j).astype(int)
         img_chunk = utils.get_chunk(img, midpoint, CHUNK_SIZE)
         labels_chunk = utils.get_chunk(labels, midpoint, CHUNK_SIZE)

-        # Compute path
-        d = int(geometry_utils.dist(edge_xyz[0], edge_xyz[1]) + 5)
-        img_coords_1 = np.round(
-            edge_xyz[0] - midpoint + HALF_CHUNK_SIZE
-        ).astype(int)
-        img_coords_2 = np.round(
-            edge_xyz[1] - midpoint + HALF_CHUNK_SIZE
-        ).astype(int)
-        path = geometry_utils.make_line(img_coords_1, img_coords_2, d)
-
-        # Fill path
+        # Mark path
+        d = int(geometry_utils.dist(xyz_i, xyz_j) + 5)
+        img_coords_i = np.round(xyz_i - midpoint + HALF_CHUNK_SIZE).astype(int)
+        img_coords_j = np.round(xyz_j - midpoint + HALF_CHUNK_SIZE).astype(int)
+        path = geometry_utils.make_line(img_coords_i, img_coords_j, d)
+
         labels_chunk[labels_chunk > 0] = 1
         labels_chunk = geometry_utils.fill_path(labels_chunk, path, val=-1)
         features[edge] = np.stack([img_chunk, labels_chunk], axis=0)
@@ -114,38 +96,27 @@ def generate_mutable_img_chunk_features(
     return features


-def get_chunk(superchunk, xyz):
-    return deepcopy(
-        superchunk[
-            (xyz[0] - CHUNK_SIZE[0] // 2) : xyz[0] + CHUNK_SIZE[0] // 2,
-            (xyz[1] - CHUNK_SIZE[1] // 2) : xyz[1] + CHUNK_SIZE[1] // 2,
-            (xyz[2] - CHUNK_SIZE[2] // 2) : xyz[2] + CHUNK_SIZE[2] // 2,
-        ]
+def get_local_img_coords(neurograph, i):
+    global_xyz = deepcopy(neurograph.nodes[i]["xyz"])
+    local_xyz = utils.apply_anisotropy(
+        global_xyz - np.array(neurograph.origin)
     )
+    return local_xyz


 def generate_mutable_img_profile_features(
     neurograph, path, anisotropy=[1.0, 1.0, 1.0]
 ):
     features = dict()
-    origin = utils.apply_anisotropy(
-        neurograph.bbox["min"], anisotropy, return_int=True
-    )
-    shape = [neurograph.shape[i] + BUFFER for i in range(3)]
+    origin = utils.apply_anisotropy(neurograph.origin, return_int=True)
     superchunk = utils.get_superchunk(
-        path, "zarr", origin, shape, from_center=False
+        path, "zarr", origin, neurograph.shape, from_center=False
     )
     for edge in neurograph.mutable_edges:
-        edge_xyz = deepcopy(neurograph.edges[edge]["xyz"])
-        edge_xyz = [
-            utils.apply_anisotropy(
-                edge_xyz[0] - neurograph.origin, anisotropy=anisotropy
-            ),
-            utils.apply_anisotropy(
-                edge_xyz[1] - neurograph.origin, anisotropy=anisotropy
-            ),
-        ]
-        line = geometry_utils.make_line(edge_xyz[0], edge_xyz[1], NUM_POINTS)
+        i, j = tuple(edge)
+        xyz_i = get_local_img_coords(neurograph, i)
+        xyz_j = get_local_img_coords(neurograph, j)
+        line = geometry_utils.make_line(xyz_i, xyz_j, NUM_POINTS)
         features[edge] = geometry_utils.get_profile(
             superchunk, line, window_size=WINDOW_SIZE
         )
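Both feature generators above now convert a node's world coordinates into superchunk-local image coordinates and sample along the straight line between the two proposal endpoints. A self-contained sketch of that sampling step, with make_line and the windowed max re-implemented here for illustration (the repo's own helpers may differ in detail):

```python
import numpy as np

def make_line(xyz_i, xyz_j, num_points):
    # num_points evenly spaced points from xyz_i to xyz_j, inclusive.
    t = np.linspace(0, 1, num_points)[:, None]
    return (1 - t) * np.asarray(xyz_i) + t * np.asarray(xyz_j)

def get_profile(img, line, window=2):
    # Max intensity in a small window centered on each sampled point.
    profile = []
    for x, y, z in np.round(line).astype(int):
        chunk = img[x - window: x + window + 1,
                    y - window: y + window + 1,
                    z - window: z + window + 1]
        profile.append(chunk.max())
    return profile

img = np.random.rand(64, 64, 64)
print(get_profile(img, make_line([10, 12, 8], [50, 40, 30], 16)))
```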
@@ -156,17 +127,14 @@ def generate_mutable_skel_features(neurograph):
     features = dict()
     for edge in neurograph.mutable_edges:
         i, j = tuple(edge)
-        deg_i = len(list(neurograph.neighbors(i)))
-        deg_j = len(list(neurograph.neighbors(j)))
-        length = compute_length(neurograph, edge)
         radius_i, radius_j = get_radii(neurograph, edge)
         dot1, dot2, dot3 = get_directionals(neurograph, edge, 5)
         ddot1, ddot2, ddot3 = get_directionals(neurograph, edge, 10)
         features[edge] = np.concatenate(
             (
-                length,
-                deg_i,
-                deg_j,
+                compute_length(neurograph, edge),
+                neurograph.immutable_degree(i),
+                neurograph.immutable_degree(j),
                 radius_i,
                 radius_j,
                 dot1,
@@ -327,3 +295,15 @@ def combine_features(features):
                     (combined[edge], features[key][edge])
                 )
     return combined
+
+
+"""
+def get_chunk(superchunk, xyz):
+    return deepcopy(
+        superchunk[
+            (xyz[0] - CHUNK_SIZE[0] // 2) : xyz[0] + CHUNK_SIZE[0] // 2,
+            (xyz[1] - CHUNK_SIZE[1] // 2) : xyz[1] + CHUNK_SIZE[1] // 2,
+            (xyz[2] - CHUNK_SIZE[2] // 2) : xyz[2] + CHUNK_SIZE[2] // 2,
+        ]
+    )
+"""
diff --git a/src/deep_neurographs/geometry_utils.py b/src/deep_neurographs/geometry_utils.py
index 12c6e4b..9000ca4 100644
--- a/src/deep_neurographs/geometry_utils.py
+++ b/src/deep_neurographs/geometry_utils.py
@@ -2,7 +2,7 @@
 from scipy.interpolate import UnivariateSpline
 from scipy.linalg import svd

-from deep_neurographs import utils
+from deep_neurographs import utils, feature_extraction as extracter


 # Context Tangent Vectors
@@ -116,6 +116,7 @@ def fit_spline(xyz):
     return cs_x, cs_y, cs_z


+"""
 def smooth_end(branch_xyz, radii, ref_xyz, num_pts=8):
     smooth_bool = branch_xyz.shape[0] > 10
     if all(branch_xyz[0] == ref_xyz) and smooth_bool:
@@ -126,55 +127,21 @@ def smooth_end(branch_xyz, radii, ref_xyz, num_pts=8):
         return branch_xyz, radii, -1
     else:
         return branch_xyz, radii, None
-
+"""

 # Image feature extraction
 def get_profile(img, xyz_arr, window_size=[5, 5, 5]):
     return [np.max(utils.get_chunk(img, xyz, window_size)) for xyz in xyz_arr]


-"""
-def get_profile_old(
-    img, xyz_arr, anisotropy=[1.0, 1.0, 1.0], window_size=[5, 5, 5]
-):
-    #xyz_arr = get_coords(xyz_arr, anisotropy=anisotropy)
-    profile = []
-    for xyz in xyz_arr:
-        xyz = xyz.astype(int)
-        img_chunk = utils.get_chunk(img, xyz, window_size)
-        profile.append(np.max(img_chunk))
-    return np.array(profile)
-
-
-    xyz_arr = get_coords(xyz_arr, anisotropy=anisotropy)
-    profile = []
-    for xyz in xyz_arr:
-        img_chunk = utils.read_img_chunk(img, xyz, window_size)
-        profile.append(np.max(img_chunk))
-    return np.array(profile)
-"""
-
-
 def fill_path(img, path, val=-1):
     for xyz in path:
-        x, y, z = tuple(np.round(xyz).astype(int))
-        img[x, y, z] = val
-        # img[(x - 1) : x + 1, (y - 1) : y + 1, (z - 1) : z + 1] = val
+        x, y, z = tuple(np.floor(xyz).astype(int))
+        img[x - 1 : x + 2, y - 1 : y + 2, z - 1 : z + 2] = val
+        # img[x, y, z] = val
     return img


-def get_coords(xyz_arr, anisotropy=[1.0, 1.0, 1.0]):
-    for i in range(3):
-        xyz_arr[:, i] = xyz_arr[:, i] / anisotropy[i]
-    return xyz_arr.astype(int)
-
-
-def get_coord(xyz, anisotropy=[1.0, 1.0, 1.0]):
-    return [int(xyz[i] / anisotropy[i]) for i in range(3)]
-
-
-# Rotate image
-
 # Miscellaneous
 def compare_edges(xyx_i, xyz_j, xyz_k):
     dist_ij = dist(xyx_i, xyz_j)
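fill_path above changes in two ways: coordinates are floored instead of rounded, and a full 3x3x3 neighborhood is stamped instead of a single voxel, presumably to make the proposed connection more salient in the labels channel. A sketch of the new behavior (it assumes every path point lies at least one voxel inside the array; production code would clip the slice bounds):

```python
import numpy as np

def fill_path(img, path, val=-1):
    for xyz in path:
        x, y, z = np.floor(xyz).astype(int)
        img[x - 1: x + 2, y - 1: y + 2, z - 1: z + 2] = val  # 3x3x3 stamp
    return img

labels = np.zeros((8, 8, 8), dtype=int)
fill_path(labels, [(3.6, 3.2, 3.9)])
print(int((labels == -1).sum()))  # 27: one full 3x3x3 block
```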
diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py
index 4e29921..ce8b8d5 100644
--- a/src/deep_neurographs/neurograph.py
+++ b/src/deep_neurographs/neurograph.py
@@ -299,8 +299,6 @@ def _query_kdtree(self, query, dist):
         return self.kdtree.data[idxs]

     def init_targets(self, target_neurograph):
-        self.num_simple_edges = 0
-        self.num_complex_edges = 0
         self.target_edges = set()
         self.groundtruth_graph = self.init_immutable_graph()
         target_densegraph = DenseGraph(target_neurograph.path)
@@ -334,7 +332,6 @@ def init_targets(self, target_neurograph):
                     continue
                 if not target_densegraph.check_aligned(xyz_i, xyz_j):
                     continue
-                self.num_complex_edges += 1
             else:
                 # Simple criteria
                 inclusion_i = proj_xyz_i in site_to_site.keys()
@@ -359,8 +356,6 @@ def init_targets(self, target_neurograph):
                     site_to_site, pair_to_edge = self.remove_site(
                         site_to_site, pair_to_edge, proj_xyz_i, proj_xyz_k
                     )
-                else:
-                    self.num_simple_edges += 1

             # Add site
             site_to_site, pair_to_edge = self.add_site(
@@ -590,9 +585,11 @@ def is_nb(self, i, j):
         return True if i in self.neighbors(j) else False

     def is_contained(self, xyz):
+        xyz = utils.apply_anisotropy(xyz - np.array(self.bbox["min"]))
+        img_shape = np.array(self.shape)
         for i in range(3):
-            lower_bool = xyz[i] < self.bbox["min"][i]
-            upper_bool = xyz[i] >= self.bbox["max"][i]
+            lower_bool = xyz[i] < 32
+            upper_bool = xyz[i] > img_shape[i] - 32
             if lower_bool or upper_bool:
                 return False
         return True
@@ -619,6 +616,14 @@ def get_center(self):
             self.bbox["min"], self.bbox["max"]
         )

+    def get_simple_proposals(self):
+        simple_proposals = set()
+        for edge in self.mutable_edges:
+            i, j = tuple(edge)
+            if self.immutable_degree(i) == 1 and self.immutable_degree(j) == 1:
+                simple_proposals.add(edge)
+        return simple_proposals
+
     def to_line_graph(self):
         """
         Converts graph to a line graph.
diff --git a/src/deep_neurographs/utils.py b/src/deep_neurographs/utils.py
index 23b79f7..5e8ee57 100644
--- a/src/deep_neurographs/utils.py
+++ b/src/deep_neurographs/utils.py
@@ -19,6 +19,7 @@
 import zarr
 from plotly.subplots import make_subplots

+ANISOTROPY = [0.748, 0.748, 1.0]
 SUPPORTED_DRIVERS = ["neuroglancer_precomputed", "zarr"]


@@ -144,9 +145,9 @@ def open_tensorstore(path, driver):

 def read_img_chunk(img, xyz, shape):
     return img[
-        (xyz[2] - shape[2] // 2) : xyz[2] + shape[2] // 2,
-        (xyz[1] - shape[1] // 2) : xyz[1] + shape[1] // 2,
-        (xyz[0] - shape[0] // 2) : xyz[0] + shape[0] // 2,
+        xyz[2] - shape[2] // 2: xyz[2] + shape[2] // 2,
+        xyz[1] - shape[1] // 2: xyz[1] + shape[1] // 2,
+        xyz[0] - shape[0] // 2: xyz[0] + shape[0] // 2,
     ].transpose(2, 1, 0)


@@ -159,20 +160,32 @@ def get_chunk(arr, xyz, shape):
 def read_tensorstore(ts_arr, xyz, shape):
     return (
         ts_arr[
-            (xyz[0] - shape[0] // 2) : xyz[0] + shape[0] // 2,
-            (xyz[1] - shape[1] // 2) : xyz[1] + shape[1] // 2,
-            (xyz[2] - shape[2] // 2) : xyz[2] + shape[2] // 2,
+            xyz[0] - shape[0] // 2: xyz[0] + shape[0] // 2,
+            xyz[1] - shape[1] // 2: xyz[1] + shape[1] // 2,
+            xyz[2] - shape[2] // 2: xyz[2] + shape[2] // 2,
         ]
         .read()
         .result()
     )


-def get_superchunks(img_path, label_path, xyz, shape):
+def get_superchunks(img_path, label_path, xyz, shape, from_center=True):
     with concurrent.futures.ThreadPoolExecutor() as executor:
-        img_job = executor.submit(get_superchunk, img_path, "zarr", xyz, shape)
+        img_job = executor.submit(
+            get_superchunk,
+            img_path,
+            "zarr",
+            xyz,
+            shape,
+            from_center=from_center,
+        )
         label_job = executor.submit(
-            get_superchunk, label_path, "neuroglancer_precomputed", xyz, shape
+            get_superchunk,
+            label_path,
+            "neuroglancer_precomputed",
+            xyz,
+            shape,
+            from_center=from_center,
         )
     return img_job.result(), label_job.result()
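get_superchunks issues the image and label reads as two independent jobs on a thread pool and then blocks on both, so the slower read hides the faster one. The same pattern in miniature, with a sleep standing in for the real zarr/tensorstore reads:

```python
import concurrent.futures
import time

def fetch(name, delay):
    time.sleep(delay)  # stand-in for a slow superchunk read
    return name

with concurrent.futures.ThreadPoolExecutor() as executor:
    img_job = executor.submit(fetch, "img", 0.2)
    label_job = executor.submit(fetch, "labels", 0.2)
    img, labels = img_job.result(), label_job.result()
# Wall time is ~0.2 s rather than ~0.4 s because the two reads overlap.
```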
@@ -317,20 +330,20 @@ def normalize_img(img):
     return img


-def to_world(xyz, anisotropy, shift=[0, 0, 0]):
-    return tuple([int((xyz[i] - shift[i]) * anisotropy[i]) for i in range(3)])
+def to_world(xyz, shift=[0, 0, 0]):
+    return tuple([int((xyz[i] - shift[i]) * ANISOTROPY[i]) for i in range(3)])


-def to_img(xyz, anisotropy, shift=[0, 0, 0]):
-    xyz = apply_anisotropy(xyz - shift, anisotropy, return_int=True)
+def to_img(xyz, shift=[0, 0, 0]):
+    xyz = apply_anisotropy(xyz - shift, return_int=True)
     return tuple(xyz)


-def apply_anisotropy(xyz, anisotropy, return_int=False):
+def apply_anisotropy(xyz, return_int=False):
     if return_int:
-        return [int(xyz[i] / anisotropy[i]) for i in range(3)]
+        return [int(xyz[i] / ANISOTROPY[i]) for i in range(3)]
     else:
-        return [xyz[i] / anisotropy[i] for i in range(3)]
+        return [xyz[i] / ANISOTROPY[i] for i in range(3)]


 def time_writer(t, unit="seconds"):
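With ANISOTROPY promoted to a module-level constant, the coordinate helpers lose their anisotropy parameter: apply_anisotropy divides world coordinates down to image coordinates, and to_world multiplies back. A worked round trip under the same convention (note that int() truncates, so the round trip can drift by a voxel):

```python
ANISOTROPY = [0.748, 0.748, 1.0]  # world units per voxel along x, y, z

def apply_anisotropy(xyz, return_int=False):
    scaled = [xyz[i] / ANISOTROPY[i] for i in range(3)]
    return [int(v) for v in scaled] if return_int else scaled

def to_world(xyz, shift=[0, 0, 0]):
    return tuple(int((xyz[i] - shift[i]) * ANISOTROPY[i]) for i in range(3))

img_xyz = apply_anisotropy([74.8, 149.6, 10.0], return_int=True)
world_xyz = to_world(img_xyz)
print(img_xyz, world_xyz)  # image coords, then world coords (truncated)
```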