diff --git a/src/deep_neurographs/feature_extraction.py b/src/deep_neurographs/feature_extraction.py index 2a5e708..33aba7b 100644 --- a/src/deep_neurographs/feature_extraction.py +++ b/src/deep_neurographs/feature_extraction.py @@ -9,51 +9,25 @@ """ import numpy as np -from deep_neurographs import utils +from copy import deepcopy +from deep_neurographs import utils, geometry_utils +from random import sample +from scipy.linalg import svd NUM_IMG_FEATURES = 0 -NUM_SKEL_FEATURES = 4 +NUM_SKEL_FEATURES = 9 NUM_PC_FEATURES = 0 # -- Wrappers -- -def generate_node_features(neurograph, img=True, pointcloud=True, skel=True): - features = dict() - if img: - features["img"] = generate_img_features(neurograph) - - if skel: - features["skel"] = generate_skel_features(neurograph) - - if pointcloud: - features["pointcloud"] = generate_pointcloud_features(neurograph) - return extract_feature_vec(features) - - -def generate_immutable_features(neurograph, img=True, pointcloud=True, skel=True): - features = dict() - if img: - features["img"] = generate_img_features(neurograph) - - if skel: - features["skel"] = generate_immutable_skel_features(neurograph) - - if pointcloud: - features["pointcloud"] = generate_pointcloud_features(neurograph) - return extract_feature_vec(features) - - def generate_mutable_features(neurograph, img=True, pointcloud=True, skel=True): features = dict() if img: features["img"] = generate_img_features(neurograph) - if skel: features["skel"] = generate_mutable_skel_features(neurograph) - - if pointcloud: - features["pointcloud"] = generate_pointcloud_features(neurograph) - return extract_feature_vec(features) + features = combine_feature_vecs(features) + return features # -- Node feature extraction -- @@ -71,7 +45,6 @@ def _generate_node_img_features(): def generate_skel_features(neurograph): skel_features = np.zeros((neurograph.num_nodes(), NUM_SKEL_FEATURES)) for node in neurograph.nodes: - output = _generate_node_skel_features(neurograph, node) skel_features[node] = _generate_node_skel_features(neurograph, node) return skel_features @@ -94,31 +67,94 @@ def _generate_pointcloud_node_features(): # -- Edge feature extraction -- -def generate_immutable_skel_features(neurograph): +def generate_mutable_skel_features(neurograph): features = dict() - for edge in neurograph.immutable_edges: - features[edge] = _generate_immutable_skel_features(neurograph, edge) + for edge in neurograph.mutable_edges: + length = compute_length(neurograph, edge) + radius_i, radius_j = get_radii(neurograph, edge) + + dot1, dot2, dot3 = get_directionals(neurograph, edge, 5) + ddot1, ddot2, ddot3 = get_directionals(neurograph, edge, 10) + features[edge] = np.concatenate((length, dot1, dot2, dot3), axis=None) return features -def _generate_immutable_skel_features(neurograph, edge): - mean_xyz = np.mean(neurograph.edges[edge]["xyz"], axis=0) - mean_radius = np.mean(neurograph.edges[edge]["radius"], axis=0) - path_length = len(neurograph.edges[edge]["radius"]) - return np.concatenate((mean_xyz, mean_radius, path_length), axis=None) +def compute_length(neurograph, edge, metric="l2"): + i, j = tuple(edge) + xyz_1, xyz_2 = neurograph.get_edge_attr("xyz", i, j) + return utils.dist(xyz_1, xyz_2, metric=metric) + + +def get_directionals(neurograph, edge, window_size): + # Compute tangent vectors + i, j = tuple(edge) + mutable_xyz_i, mutable_xyz_j = neurograph.get_edge_attr("xyz", i, j) + mutable_xyz = np.array([mutable_xyz_i, mutable_xyz_j]) + mutable_tangent = geometry_utils.compute_tangent(mutable_xyz) + context_tangent_1 = geometry_utils.compute_context_vec(neurograph, i, mutable_tangent, window_size=window_size) + context_tangent_2 = geometry_utils.compute_context_vec(neurograph, j, mutable_tangent, window_size=window_size) + + # Compute features + inner_product_1 = abs(np.dot(mutable_tangent, context_tangent_1)) + inner_product_2 = abs(np.dot(mutable_tangent, context_tangent_2)) + inner_product_3 = np.dot(context_tangent_1, context_tangent_2) + return inner_product_1, inner_product_2, inner_product_3 + + +def get_radii(neurograph, edge): + i, j = tuple(edge) + radius_i = neurograph.nodes[i]["radius"] + radius_j = neurograph.nodes[j]["radius"] + return radius_i, radius_j + + +# -- Combine feature vectors +def build_feature_matrix(neurographs, features, blocks): + # Initialize + X = None + block_to_idxs = dict() + idx_to_edge = dict() + + # Feature extraction + for block_id in blocks: + # Get features + idx_shift = 0 if X is None else X.shape[0] + X_i, y_i, idx_to_edge_i = build_feature_submatrix( + neurographs[block_id], + features[block_id], + idx_shift, + ) + + # Concatenate + if X is None: + X = deepcopy(X_i) + y = deepcopy(y_i) + else: + X = np.concatenate((X, X_i), axis=0) + y = np.concatenate((y, y_i), axis=0) + # Update dicts + idxs = set(np.arange(idx_shift, idx_shift + len(idx_to_edge_i))) + block_to_idxs[block_id] = idxs + idx_to_edge.update(idx_to_edge_i) + return X, y, block_to_idxs, idx_to_edge -def generate_mutable_skel_features(neurograph): - features = dict() - for edge in neurograph.mutable_edges: - features[edge] = _generate_mutable_skel_features(neurograph, edge) - return features +def build_feature_submatrix(neurograph, feat_dict, shift): + # Extract info + key = sample(list(feat_dict.keys()), 1)[0] + num_edges = neurograph.num_mutables() + num_features = len(feat_dict[key]) -def _generate_mutable_skel_features(neurograph, edge): - mean_xyz = np.mean(neurograph.edges[edge]["xyz"], axis=0) - edge_length = compute_length(neurograph, edge) - return np.concatenate((mean_xyz, edge_length), axis=None) + # Build + idx_to_edge = dict() + X = np.zeros((num_edges, num_features)) + y = np.zeros((num_edges)) + for i, edge in enumerate(feat_dict.keys()): + idx_to_edge[i + shift] = edge + X[i, :] = feat_dict[edge] + y[i] = 1 if edge in neurograph.target_edges else 0 + return X, y, idx_to_edge # -- Utils -- @@ -129,17 +165,54 @@ def compute_num_features(features): return num_features -def extract_feature_vec(features,): - feature_vec = None +def combine_feature_vecs(features): + vec = None for key in features.keys(): - if feature_vec is None: - feature_vec = features[key] + if vec is None: + vec = features[key] else: - feature_vec = np.concatenate((feature_vec, features[key]), axis=1) - return feature_vec + vec = np.concatenate((vec, features[key]), axis=1) + return vec + + + +""" + +def generate_node_features(neurograph, img=True, pointcloud=True, skel=True): + features = dict() + if img: + features["img"] = generate_img_features(neurograph) + + if skel: + features["skel"] = generate_skel_features(neurograph) + + if pointcloud: + features["pointcloud"] = generate_pointcloud_features(neurograph) + return extract_feature_vec(features) -def compute_length(neurograph, edge): - xyz_1 = neurograph.edges[edge]["xyz"][0] - xyz_2 = neurograph.edges[edge]["xyz"][1] - return utils.dist(xyz_1, xyz_2) \ No newline at end of file +def generate_immutable_features( + neurograph, img=True, pointcloud=True, skel=True +): + features = dict() + if img: + features["img"] = generate_img_features(neurograph) + + if skel: + features["skel"] = generate_immutable_skel_features(neurograph) + + if pointcloud: + features["pointcloud"] = generate_pointcloud_features(neurograph) + return extract_feature_vec(features) + +def generate_immutable_skel_features(neurograph): + features = dict() + for edge in neurograph.immutable_edges: + features[edge] = _generate_immutable_skel_features(neurograph, edge) + return features + + +def _generate_immutable_skel_features(neurograph, edge): + mean_radius = np.mean(neurograph.edges[edge]["radius"], axis=0) + return np.concatenate((mean_radius), axis=None) +""" \ No newline at end of file diff --git a/src/deep_neurographs/geometry_utils.py b/src/deep_neurographs/geometry_utils.py new file mode 100644 index 0000000..582c088 --- /dev/null +++ b/src/deep_neurographs/geometry_utils.py @@ -0,0 +1,78 @@ +import numpy as np +from deep_neurographs import utils +from scipy.linalg import svd + + +# Context Tangent Vectors +def compute_context_vec(neurograph, i, mutable_tangent, window_size=5, return_pts=False, vec_type="tangent"): + # Compute context vecs + branches = get_branches(neurograph, i) + context_vec_list = [] + xyz_list = [] + ref_xyz = neurograph.nodes[i]["xyz"] + for branch in branches: + context_vec, xyz = _compute_context_vec(branch, ref_xyz, window_size, vec_type) + context_vec_list.append(context_vec) + xyz_list.append(xyz) + + # Determine best + max_dot_prod = 0 + arg_max = -1 + for k in range(len(context_vec_list)): + dot_prod = abs(np.dot(mutable_tangent, context_vec_list[k])) + if dot_prod >= max_dot_prod: + max_dot_prod = dot_prod + arg_max = k + + # Compute normal + if return_pts: + return context_vec_list, branches, xyz_list, arg_max + else: + return context_vec_list[arg_max] + + +def _compute_context_vec(all_xyz, ref_xyz, window_size, vec_type): + from_start = orient_pts(all_xyz, ref_xyz) + xyz = get_pts(all_xyz, window_size, from_start) + if vec_type == "normal": + vec = compute_normal(xyz) + else: + vec = compute_tangent(xyz) + return vec, np.mean(xyz, axis=0).reshape(1, 3) + + +def get_branches(neurograph, i): + nbs = [] + for j in list(neurograph.neighbors(i)): + if frozenset((i, j)) in neurograph.immutable_edges: + nbs.append(j) + return [neurograph.edges[i, j]["xyz"] for j in nbs] + + +def orient_pts(xyz, ref_xyz): + return True if all(xyz[0] == ref_xyz) else False + + +def get_pts(xyz, window_size, from_start): + if len(xyz) > window_size and from_start: + return xyz[0:window_size] + elif len(xyz) > window_size and not from_start: + return xyz[-window_size:] + else: + return xyz + + +def compute_svd(xyz): + xyz = xyz - np.mean(xyz, axis=0) + return svd(xyz) + + +def compute_tangent(xyz): + if xyz.shape[0] == 2: + tangent = (xyz[1] - xyz[0]) / utils.dist(xyz[1], xyz[0]) + else: + U, S, VT = compute_svd(xyz) + tangent = VT[0] + return tangent / np.linalg.norm(tangent) + + diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index a3b163f..9363f61 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -100,7 +100,7 @@ def init_immutables_from_local( anisotropy=[1.0, 1.0, 1.0], prune=True, prune_depth=16, - smooth=True, + smooth=False, ): """ To do... @@ -109,8 +109,8 @@ def init_immutables_from_local( raw_swc = swc_utils.read_swc(os.path.join(swc_dir, swc_id)) swc_id = swc_id.replace(".0.swc", "") swc_dict = swc_utils.parse(raw_swc, anisotropy=anisotropy) - #if smooth: - # swc_dict = swc_utils.smooth(swc_dict) + if smooth: + swc_dict = swc_utils.smooth(swc_dict) neurograph.generate_immutables( swc_id, swc_dict, prune=prune, prune_depth=prune_depth ) diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index b32a6fa..8a3ae79 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -14,6 +14,7 @@ import numpy as np import plotly.graph_objects as go import plotly.tools as tls +import tensorstore as ts from more_itertools import zip_broadcast from plotly.subplots import make_subplots from scipy.spatial import KDTree @@ -21,8 +22,11 @@ from deep_neurographs import graph_utils as gutils from deep_neurographs import swc_utils, utils +from tifffile import imwrite + COLORS = list(mcolors.TABLEAU_COLORS.keys()) nCOLORS = len(COLORS) +SUPPORTED_LABEL_MASK_TYPES = [dict, np.array, ts.TensorStore] class NeuroGraph(nx.Graph): @@ -33,7 +37,7 @@ class NeuroGraph(nx.Graph): """ - def __init__(self): + def __init__(self, label_mask=None): """ Parameters ---------- @@ -45,6 +49,7 @@ def __init__(self): """ super(NeuroGraph, self).__init__() + self.label_mask = label_mask self.leafs = set() self.junctions = set() self.immutable_edges = set() @@ -135,10 +140,10 @@ def generate_mutables(self, max_degree=5, max_dist=100.0): attrs = self.get_edge_data(i, j) # Get connecting node - if utils.dist(xyz, attrs["xyz"][0]) < 8: + if utils.dist(xyz, attrs["xyz"][0]) < 5: node = i xyz = self.nodes[node]["xyz"] - elif utils.dist(xyz, attrs["xyz"][-1]) < 8: + elif utils.dist(xyz, attrs["xyz"][-1]) < 5: node = j xyz = self.nodes[node]["xyz"] if node == leaf: @@ -170,7 +175,7 @@ def _get_mutables(self, query_id, query_xyz, max_degree, max_dist): best_dist = dict() query_swc_id = self.nodes[query_id]["swc_id"] for xyz in self._query_kdtree(query_xyz, max_dist): - xyz = tuple(xyz.astype(int)) + xyz = tuple(xyz) #.astype(int)) edge = self.xyz_to_edge[xyz] swc_id = gutils.get_edge_attr(self, edge, "swc_id") if swc_id != query_swc_id: @@ -269,34 +274,52 @@ def _query_kdtree(self, query, max_dist): idxs = self.kdtree.query_ball_point(query, max_dist) return self.kdtree.data[idxs] - def init_targets(self, log_path, dist_threshold): - # Initializations - splits_log = utils.read_mistake_log(log_path) - split_edges = set(splits_log.keys()) - target_edges = set() - # Parse mutables + def init_targets(self, target_labels, log_path, anisotropy=[1.0, 1.0, 1.0], shift=[0, 0, 0]): + labels_to_edge = dict() + mistakes_xyz = utils.read_mistake_coords(log_path, anisotropy=anisotropy, shift=shift) + mistakes_kdtree = KDTree(mistakes_xyz) for edge in self.mutable_edges: + # Get edge i, j = tuple(edge) - key = frozenset(self.get_edge_attr("swc_id", i, j)) - if key in split_edges: - k, l = list(edge) - mutable_xyz_1 = self.nodes[k]["xyz"] - mutable_xyz_2 = self.nodes[l]["xyz"] - log_xyz_1 = splits_log[key]["xyz"][0] - log_xyz_2 = splits_log[key]["xyz"][1] - - pair_1 = [mutable_xyz_1, mutable_xyz_2] - pair_2 = [log_xyz_1, log_xyz_2] - d = utils.pair_dist(pair_1, pair_2) - if d < dist_threshold: - target_edges.add(frozenset((i, j))) - self.target_edges = target_edges - print("% target edges in mistake log:", len(target_edges) / len(split_edges)) - print("% target edges in mutable:", len(target_edges) / len(self.mutable_edges)) + xyz_1 = utils.to_img(self.nodes[i]["xyz"], anisotropy, shift=shift) + xyz_2 = utils.to_img(self.nodes[j]["xyz"], anisotropy, shift=shift) + + # Check coords in bounds + bounds_1 = [xyz_1[i] >= target_labels.shape[i] for i in range(3)] + bounds_2 = [xyz_2[i] >= target_labels.shape[i] for i in range(3)] + if any(bounds_1) or any(bounds_2): + continue + + # Check img + cond_1 = utils.check_img_path(target_labels, xyz_1, xyz_2) + + # Check mistake log + d_1, _ = mistakes_kdtree.query(xyz_1, k=1) + d_2, _ = mistakes_kdtree.query(xyz_2, k=1) + cond_2 = d_1 < 10 and d_2 < 10 + if cond_1 and cond_2: + key = frozenset([self.nodes[i]["swc_id"], self.nodes[j]["swc_id"]]) + if key in labels_to_edge.keys(): + # Check whether to update target edge + cur_dist = self.compute_length(labels_to_edge[key]) + new_dist = self.compute_length(edge) + if cur_dist < new_dist: + continue + else: + self.target_edges.remove(labels_to_edge[key]) + + # Add new target edge + labels_to_edge[key] = edge + self.target_edges.add(edge) + print("# target edges:", len(self.target_edges)) + print( + "% target edges in mutable:", + len(self.target_edges) / len(self.mutable_edges), + ) # --- Visualization --- - def visualize_immutables(self, return_data=False, title="Immutable Graph"): + def visualize_immutables(self, title="Immutable Graph", return_data=False): """ Parameters ---------- @@ -317,16 +340,31 @@ def visualize_immutables(self, return_data=False, title="Immutable Graph"): else: utils.plot(data, title) - def visualize_mutables(self, title="Mutable Graph"): + def visualize_mutables(self, title="Mutable Graph", return_data=False): data = [self._plot_nodes()] data.extend(self._plot_edges(self.immutable_edges, color="black")) data.extend(self._plot_edges(self.mutable_edges)) - utils.plot(data, title) + if return_data: + return data + else: + utils.plot(data, title) + - def visualize_targets(self, target_edges, title="Target Edges"): + def visualize_targets(self, target_graph=None, title="Target Edges", return_data=False): data = [self._plot_nodes()] data.extend(self._plot_edges(self.immutable_edges, color="black")) - data.extend(self._plot_edges(target_edges)) + data.extend(self._plot_edges(self.target_edges)) + if target_graph is not None: + data.extend(target_graph._plot_edges(target_graph.immutable_edges, color="blue")) + if return_data: + return data + else: + utils.plot(data, title) + + def visualize_subset(self, edges, title=""): + data = [self._plot_nodes()] + data.extend(self._plot_edges(self.immutable_edges, color="black")) + data.extend(self._plot_edges(edges)) utils.plot(data, title) def _plot_nodes(self): @@ -389,6 +427,49 @@ def num_edges(self): """ return self.number_of_edges() + def num_immutables(self): + """ + Computes number of immutable edges in the graph. + + Parameters + ---------- + None + + Returns + ------- + int + Number of immutable edges in the graph. + + """ + return len(self.immutable_edges) + + def num_mutables(self): + """ + Computes number of mutable edges in the graph. + + Parameters + ---------- + None + + Returns + ------- + int + Number of mutable edges in the graph. + + """ + return len(self.mutable_edges) + + def compute_length(self, edge, metric="l2"): + i, j = tuple(edge) + xyz_1, xyz_2 = self.get_edge_attr("xyz", i, j) + return utils.dist(xyz_1, xyz_2, metric=metric) + + def path_length(self, metric="l2"): + length = 0 + for edge in self.immutable_edges: + length += self.compute_length(edge, metric=metric) + return length + def get_edge_attr(self, key, i, j): attr_1 = self.nodes[i][key] attr_2 = self.nodes[j][key] @@ -411,4 +492,38 @@ def to_line_graph(self): graph = nx.Graph() graph.add_nodes_from(self.nodes) graph.add_edges_from(self.edges) - return nx.line_graph(graph) + return nx.line_graph(graph) + + + + def init_targets_old(self, log_path, dist_threshold): + # Initializations + splits_log = utils.read_mistake_log(log_path) + split_edges = set(splits_log.keys()) + target_edges = set() + + # Parse mutables + for edge in self.mutable_edges: + i, j = tuple(edge) + key = frozenset(self.get_edge_attr("swc_id", i, j)) + if key in split_edges: + k, l = list(edge) + mutable_xyz_1 = self.nodes[k]["xyz"] + mutable_xyz_2 = self.nodes[l]["xyz"] + log_xyz_1 = splits_log[key]["xyz"][0] + log_xyz_2 = splits_log[key]["xyz"][1] + + pair_1 = [mutable_xyz_1, mutable_xyz_2] + pair_2 = [log_xyz_1, log_xyz_2] + d = utils.pair_dist(pair_1, pair_2) + if d < dist_threshold: + target_edges.add(frozenset((i, j))) + self.target_edges = target_edges + print( + "% target edges in mistake log:", + len(target_edges) / len(split_edges), + ) + print( + "% target edges in mutable:", + len(target_edges) / len(self.mutable_edges), + ) diff --git a/src/deep_neurographs/swc_utils.py b/src/deep_neurographs/swc_utils.py index cad0f2a..9738cf5 100644 --- a/src/deep_neurographs/swc_utils.py +++ b/src/deep_neurographs/swc_utils.py @@ -21,13 +21,14 @@ from deep_neurographs import utils +# -- io utils -- def read_swc(path): with open(path, "r") as file: contents = file.readlines() return contents -def parse(raw_swc, anisotropy=[1.0, 1.0, 1.0]): +def parse(raw_swc, anisotropy=[1.0, 1.0, 1.0], idx=False): """ Parses a raw swc file to extract the (x,y,z) coordinates and radii. Note that node_ids from swc are refactored to index from 0 to n-1 where n is @@ -48,7 +49,11 @@ def parse(raw_swc, anisotropy=[1.0, 1.0, 1.0]): """ # Initialize swc - swc_dict = {"id": [], "xyz": [], "radius": [], "pid": []} + swc_dict = {"id": [], "radius": [], "pid": []} + if idx: + swc_dict["idx"] = [] + else: + swc_dict["xyz"] = [] # Parse raw data min_id = np.inf @@ -62,9 +67,14 @@ def parse(raw_swc, anisotropy=[1.0, 1.0, 1.0]): swc_dict["id"].append(int(parts[0])) swc_dict["radius"].append(float(parts[-2])) swc_dict["pid"].append(int(parts[-1])) - swc_dict["xyz"].append( - read_xyz(parts[2:5], anisotropy=anisotropy, offset=offset) - ) + if idx: + swc_dict["idx"].append( + read_idx(parts[2:5], anisotropy=anisotropy, offset=offset) + ) + else: + swc_dict["xyz"].append( + read_xyz(parts[2:5], anisotropy=anisotropy, offset=offset) + ) if swc_dict["id"][-1] < min_id: min_id = swc_dict["id"][-1] @@ -95,11 +105,45 @@ def read_xyz(xyz, anisotropy=[1.0, 1.0, 1.0], offset=[0, 0, 0]): The (x,y,z) coordinates from an swc file. """ - xyz = [int(float(xyz[i]) * anisotropy[i] + offset[i]) for i in range(3)] + xyz = [float(xyz[i]) * anisotropy[i] + offset[i] for i in range(3)] return tuple(xyz) -def write_swc(path, list_of_entries, color=None): +def read_idx(xyz, anisotropy=[1.0, 1.0, 1.0], offset=[0, 0, 0]): + """ + Reads the (z,y,x) coordinates from an swc file, then reverses and scales + them. + + Parameters + ---------- + zyx : str + (z,y,x) coordinates. + anisotropy : list[float] + Image to real-world coordinates scaling factors for [x, y, z] due to + anistropy of the microscope. + + Returns + ------- + list + The (x,y,z) coordinates from an swc file. + + """ + xyz = [int(float(xyz[i]) * anisotropy[i]) + offset[i] for i in range(3)] + return xyz + + +def write_swc(path, contents): + if type(content) is list: + write_list(path, contents) + elif type(content) is dict: + write_dict(path, contents) + elif type(content) is nx.Graph: + write_graph(path, contents) + else: + assert True, "Unable to write {} to swc".format(type(content)) + + +def write_list(path, entry_list, color=None): """ Writes an swc file. @@ -123,13 +167,13 @@ def write_swc(path, list_of_entries, color=None): else: f.write("# id, type, z, y, x, r, pid") f.write("\n") - for i, entry in enumerate(list_of_entries): + for i, entry in enumerate(entry_list): for x in entry: f.write(str(x) + " ") f.write("\n") -def write_swc_dict(path, swc_dict, color=None): +def write_dict(path, swc_dict, color=None): with open(path, "w") as f: if color is not None: f.write("# COLOR" + color) @@ -149,6 +193,57 @@ def write_swc_dict(path, swc_dict, color=None): first = False +def write_graph(path, graph): + """ + Makes a list of entries to be written in an swc file. + + Parameters + ---------- + graph : networkx.Graph + Graph that edges in "edge_list" belong to. + edge_list : list[tuple[int]] + List of edges to be written to an swc file. + + Returns + ------- + list[str] + List of swc file entries to be written. + + """ + # loop through connected components + + reindex = dict() + edges = graph.edges if edge_list is None else edge_list + for i, j in edges: + if len(reindex) < 1: + entry, reindex = make_entry(graph, i, -1, reindex, anisotropy) + entry_list = [entry] + entry, reindex = make_entry(graph, j, reindex[i], reindex, anisotropy) + entry_list.append(entry) + return entry_list + + +def make_entry(graph, i, parent, r, reindex): + """ + Makes an entry to be written in an swc file. + + Parameters + ---------- + graph : networkx.Graph + Graph that "i" and "parent" belong to. + i : int + Node that entry corresponds to. + parent : int + Parent of node "i". + + """ + reindex[i] = len(reindex) + 1 + r = graph.nodes[i]["radius"] + x, y, z = tuple(map(str, graph.nodes[i]["xyz"])) + return [x, y, z, r, parent], reindex + + +# -- Conversions -- def file_to_graph(swc_dict, graph_id=None, set_attrs=False): graph = nx.Graph(graph_id=graph_id) graph.add_edges_from(zip(swc_dict["id"][1:], swc_dict["pid"][1:])) @@ -173,7 +268,7 @@ def dir_to_graphs(swc_dir, anisotropy=[1.0, 1.0, 1.0]): def file_to_volume(swc_dict, sparse=False, vid=None, radius_plus=0): volume = [] for i in swc_dict["id"]: - r = max(3 * int(np.round(swc_dict["radius"][i] + radius_plus)), 5) + r = max(int(np.round(swc_dict["radius"][i] + radius_plus)), 5) xyz = cp(swc_dict["xyz"][i]) volume.extend(generate_coords(xyz, r)) return dict(zip_broadcast(volume, vid)) if sparse else np.array(volume) @@ -191,9 +286,10 @@ def dir_to_volume(swc_dir, radius_plus=0): return volume +# -- miscellaneous -- def smooth(swc_dict): if len(swc_dict["xyz"]) > 10: - xyz = np.array(swc_dict["xyz"], dtype=int) + xyz = np.array(swc_dict["xyz"]) graph = file_to_graph(swc_dict) leafs, junctions = gutils.get_irreducibles(graph) if len(junctions) == 0: diff --git a/src/deep_neurographs/utils.py b/src/deep_neurographs/utils.py index d0b9190..de9c6c4 100644 --- a/src/deep_neurographs/utils.py +++ b/src/deep_neurographs/utils.py @@ -12,11 +12,13 @@ import json import os import shutil +import zarr import numpy as np import plotly.graph_objects as go from plotly.subplots import make_subplots -from scipy.interpolate import SmoothBivariateSpline, UnivariateSpline +from scipy.interpolate import UnivariateSpline, CubicSpline +from scipy.linalg import svd # --- dictionary utils --- @@ -65,9 +67,11 @@ def remove_key(my_dict, key): # --- os utils --- -def mkdir(path_to_dir): - if not os.path.exists(path_to_dir): - os.mkdir(path_to_dir) +def mkdir(path, delete=False): + if os.path.exists(path) and delete: + shutil.rmtree(path) + if not os.path.exists(path): + os.mkdir(path) def rmdir(path): @@ -94,6 +98,22 @@ def list_subdirs(path, keyword=None): # --- io utils --- +def read_n5(path): + """ + Reads n5 file at "path". + + Parameters + ---------- + path : str + Path to n5. + + Returns + ------- + np.array + Image volume. + """ + return zarr.open(zarr.N5FSStore(path), "r").volume + def read_json(path): """ Reads json file stored at "path". @@ -118,6 +138,19 @@ def read_txt(path): return f.read() +def read_mistake_coords(path, anisotropy=[1.0, 1.0, 1.0], shift=[0, 0, 0]): + xyz = [] + with open(path, "r") as file: + for line in file: + if not line.startswith("#") and len(line) > 0: + parts = line.split() + xyz_1 = extract_coords(parts[0:3]) + xyz_2 = extract_coords(parts[3:6]) + xyz.append(to_img(xyz_1, anisotropy, shift=shift)) + xyz.append(to_img(xyz_2, anisotropy, shift=shift)) + return np.array(xyz) + + def read_mistake_log(path): splits_log = dict() with open(path, "r") as file: @@ -129,7 +162,10 @@ def read_mistake_log(path): swc_1 = parts[6].replace(",", "") swc_2 = parts[7].replace(",", "") key = frozenset([swc_1, swc_2]) - splits_log[key] = {"swc": [swc_1, swc_2], "xyz": [xyz_1, xyz_2]} + splits_log[key] = { + "swc": [swc_1, swc_2], + "xyz": [xyz_1, xyz_2], + } return splits_log @@ -220,7 +256,7 @@ def dist(x, y, metric="l2"): else: return np.linalg.norm(np.subtract(x, y), ord=2) - + def pair_dist(pair_1, pair_2, metric="l2"): pair_1.reverse() d1 = _pair_dist(pair_1, pair_2) @@ -230,22 +266,75 @@ def pair_dist(pair_1, pair_2, metric="l2"): return min(d1, d2) -def _pair_dist(pair_1, pair_2, metric="l2"): +def _pair_dist(pair_1, pair_2, metric="l2"): d1 = dist(pair_1[0], pair_2[0], metric=metric) d2 = dist(pair_1[1], pair_2[1], metric=metric) return max(d1, d2) -def smooth_branch(xyz, k=3): - t = np.arange(len(xyz[:, 0])) - s = len(t) / 4 - cs_x = UnivariateSpline(t, xyz[:, 0], k=k, s=s) - cs_y = UnivariateSpline(t, xyz[:, 1], k=k, s=s) - cs_z = UnivariateSpline(t, xyz[:, 2], k=k, s=s) - smoothed_x = cs_x(t) - smoothed_y = cs_y(t) - smoothed_z = cs_z(t) - return np.column_stack((smoothed_x, smoothed_y, smoothed_z)).astype(int) +def smooth_branch(xyz, round=True): + t = np.arange(len(xyz[:, 0]) + 12) + s = len(t) / 10 + cs_x = UnivariateSpline(t, extend_boundary(xyz[:, 0]), s=s, k=3) + cs_y = UnivariateSpline(t, extend_boundary(xyz[:, 1]), s=s, k=3) + cs_z = UnivariateSpline(t, extend_boundary(xyz[:, 2]), s=s, k=3) + smoothed_x = trim_boundary(cs_x(t)) + smoothed_y = trim_boundary(cs_y(t)) + smoothed_z = trim_boundary(cs_z(t)) + smoothed = np.column_stack((smoothed_x, smoothed_y, smoothed_z)) + if round: + return smoothed #np.round(smoothed).astype(int) + else: + return smoothed + + +def extend_boundary(x, num_boundary_points=6): + extended_x = np.concatenate(( + np.linspace(x[0], x[1], num_boundary_points, endpoint=False), + x, + np.linspace(x[-2], x[-1], num_boundary_points, endpoint=False) + )) + return extended_x + + +def trim_boundary(x, num_boundary_points=6): + return x[num_boundary_points:-num_boundary_points] + + +def check_img_path(target_labels, xyz_1, xyz_2): + d = dist(xyz_1, xyz_2) + t_steps = np.arange(0, 1, 1 / d) + num_steps = len(t_steps) + labels = set() + collisions = set() + for t in t_steps: + xyz = tuple([int(line(xyz_1[i], xyz_2[i], t)) for i in range(3)]) + if target_labels[xyz] != 0: + # Check for repeat collisions + if xyz in collisions: + num_steps -= 1 + else: + collisions.add(xyz) + + # Check for collision with multiple labels + labels.add(target_labels[xyz]) + if len(labels) > 1: + return False + ratio = len(collisions) / len(t_steps) + return True if ratio > 1 / 3 else False + + + +def line(xyz_1, xyz_2, t): + return np.round((1 - t) * xyz_1 + t * xyz_2) + + +def to_world(xyz, anisotropy, shift=[0, 0, 0]): + return tuple([int((xyz[i] - shift[i]) * anisotropy[i]) for i in range(3)]) + + +def to_img(xyz, anisotropy, shift=[0, 0, 0]): + return tuple([int((xyz[i] - shift[i]) / anisotropy[i]) for i in range(3)]) def time_writer(t, unit="seconds"):