From 5c9c7877ea99e386f645db361409855c7d6583de Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 9 Jan 2024 19:08:33 +0000 Subject: [PATCH] major upd : added gcs support --- src/deep_neurographs/geometry_utils.py | 3 +- src/deep_neurographs/graph_utils.py | 22 +-- src/deep_neurographs/intake.py | 262 +++++++++++++++++++------ src/deep_neurographs/neurograph.py | 65 ++++-- src/deep_neurographs/swc_utils.py | 71 +++++-- src/deep_neurographs/utils.py | 20 ++ 6 files changed, 329 insertions(+), 114 deletions(-) diff --git a/src/deep_neurographs/geometry_utils.py b/src/deep_neurographs/geometry_utils.py index 68406fb..db7ea80 100644 --- a/src/deep_neurographs/geometry_utils.py +++ b/src/deep_neurographs/geometry_utils.py @@ -60,7 +60,7 @@ def get_midpoint(xyz_1, xyz_2): # Smoothing def smooth_branch(xyz, s=None): - if xyz.shape[0] > 5: + if xyz.shape[0] > 8: t = np.linspace(0, 1, xyz.shape[0]) spline_x, spline_y, spline_z = fit_spline(xyz, s=s) xyz = np.column_stack((spline_x(t), spline_y(t), spline_z(t))) @@ -95,7 +95,6 @@ def fill_path(img, path, val=-1): for xyz in path: x, y, z = tuple(np.floor(xyz).astype(int)) img[x - 1 : x + 2, y - 1 : y + 2, z - 1 : z + 2] = val - # img[x, y, z] = val return img diff --git a/src/deep_neurographs/graph_utils.py b/src/deep_neurographs/graph_utils.py index 1f9a236..e0926a6 100644 --- a/src/deep_neurographs/graph_utils.py +++ b/src/deep_neurographs/graph_utils.py @@ -14,17 +14,6 @@ from deep_neurographs import swc_utils, utils -def get_irreducibles(graph): - leafs = [] - junctions = [] - for i in graph.nodes: - if graph.degree[i] == 1: - leafs.append(i) - elif graph.degree[i] > 2: - junctions.append(i) - return leafs, junctions - - def extract_irreducible_graph(swc_dict, prune=True, prune_depth=16): graph = swc_utils.file_to_graph(swc_dict) leafs, junctions = get_irreducibles(graph) @@ -40,6 +29,17 @@ def extract_irreducible_graph(swc_dict, prune=True, prune_depth=16): return leafs, junctions, irreducible_edges +def get_irreducibles(graph): + leafs = [] + junctions = [] + for i in graph.nodes: + if graph.degree[i] == 1: + leafs.append(i) + elif graph.degree[i] > 2: + junctions.append(i) + return leafs, junctions + + def extract_irreducible_edges( graph, leafs, junctions, swc_dict, prune=True, prune_depth=16 ): diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index 18dfac9..40a84c9 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -9,46 +9,54 @@ """ import os +import concurrent.futures +from concurrent.futures import ThreadPoolExecutor, as_completed +from google.cloud import storage +from io import BytesIO from deep_neurographs import swc_utils, utils from deep_neurographs.neurograph import NeuroGraph +from time import time +from zipfile import ZipFile + +N_PROPOSALS_PER_LEAF = 3 +OPTIMIZE_ALIGNMENT = False +OPTIMIZE_DEPTH = 15 +PRUNE = True +PRUNE_DEPTH = 16 +SEARCH_RADIUS = 0 +SIZE_THRESHOLD = 100 +SMOOTH = False # --- Build graph --- -def build_neurograph( - swc_dir, - anisotropy=[1.0, 1.0, 1.0], +def build_neurograph_from_local( + swc_dir=None, + swc_paths=None, + img_patch_shape=None, + img_patch_origin=None, img_path=None, - size_threshold=40, - num_proposals=3, - search_radius=25.0, - prune=True, - prune_depth=16, - optimize_depth=15, - optimize_alignment=True, - optimize_path=False, - origin=None, - shape=None, - smooth=True, + n_proposals_per_leaf=N_PROPOSALS_PER_LEAF, + prune=PRUNE, + prune_depth=PRUNE_DEPTH, + optimize_alignment=OPTIMIZE_ALIGNMENT, + optimize_depth=OPTIMIZE_DEPTH, + search_radius=SEARCH_RADIUS, + size_threshold=SIZE_THRESHOLD, + smooth=SMOOTH, ): - """ - Builds a neurograph from a directory of swc files, where each swc - represents a neuron and these neurons are assumed to be near each - other. - - """ + assert utils.xor(swc_dir, swc_list), "Error: provide swc_dir or swc_paths" neurograph = NeuroGraph( - swc_dir, img_path=img_path, optimize_depth=optimize_depth, optimize_alignment=optimize_alignment, - optimize_path=optimize_path, - origin=origin, - shape=shape, + origin=img_patch_origin, + shape=img_patch_shape, ) - neurograph = init_immutables( + neurograph = init_immutables_from_local( neurograph, - anisotropy=anisotropy, + swc_dir=swc_dir, + swc_paths=swc_paths, prune=prune, prune_depth=prune_depth, size_threshold=size_threshold, @@ -56,51 +64,187 @@ def build_neurograph( ) if search_radius > 0: neurograph.generate_proposals( - num_proposals=num_proposals, search_radius=search_radius + n_proposals_per_leaf=n_proposals_per_leaf, + search_radius=search_radius ) return neurograph -def init_immutables( +def build_neurograph_from_gcs_zips( + bucket_name, + cloud_path, + img_path=None, + size_threshold=SIZE_THRESHOLD, + n_proposals_per_leaf=N_PROPOSALS_PER_LEAF, + search_radius=SEARCH_RADIUS, + prune=PRUNE, + prune_depth=PRUNE_DEPTH, + optimize_alignment=OPTIMIZE_ALIGNMENT, + optimize_depth=OPTIMIZE_DEPTH, + smooth=SMOOTH, +): + neurograph = NeuroGraph( + img_path=img_path, + optimize_alignment=optimize_alignment, + optimize_depth=optimize_depth, + ) + neurograph = init_immutables_from_gcs_zips( + neurograph, + bucket_name, + cloud_path, + prune=prune, + prune_depth=prune_depth, + size_threshold=size_threshold, + smooth=smooth, + ) + if search_radius > 0: + neurograph.generate_proposals( + n_proposals_per_leaf=n_proposals_per_leaf, + search_radius=search_radius + ) + return neurograph + + +def init_immutables_from_local( neurograph, - anisotropy=[1.0, 1.0, 1.0], - prune=True, - prune_depth=16, - size_threshold=40, - smooth=True, + swc_dir=None, + swc_paths=None, + prune=PRUNE, + prune_depth=PRUNE_DEPTH, + size_threshold=SIZE_THRESHOLD, + smooth=SMOOTH, ): - """ - To do... - """ + swc_paths = get_paths(swc_dir) if swc_dir else swc_paths + for path in swc_paths: + neurograph.ingest_swc_from_local( + path, + prune=True, + prune_depth=16, + smooth=smooth, + ) + return neurograph + - for path in get_paths(neurograph.path): - swc_id = get_id(path) - swc_dict = swc_utils.parse( +def get_paths(swc_dir): + swc_paths = [] + for f in utils.listdir(swc_dir, ext=".swc"): + paths.append(os.path.join(swc_dir, f)) + return paths + + +def init_immutables_from_gcs_zips( + neurograph, + bucket_name, + cloud_path, + prune=PRUNE, + prune_depth=PRUNE_DEPTH, + size_threshold=SIZE_THRESHOLD, + smooth=SMOOTH, +): + # Initializations + storage_client = storage.Client() + bucket = storage_client.bucket(bucket_name) + zip_paths = list_gcs_filenames(bucket, cloud_path, ".zip") + n_swc_files = 2080791 #count_files_in_zips(bucket, zip_paths) + chunk_size = int(n_swc_files * 0.05) + print("# zip files:", len(zip_paths)) + print(f"# swc files: {utils.reformat_number(n_swc_files)} \n\n") + + # Parse + cnt = 1 + t0 = time() + t1 = time() + n_files_completed = 0 + print(f"-- Starting Multithread Reads with chunk_size={chunk_size} -- \n") + for path in zip_paths: + # Add to neurograph + swc_dicts = process_gcs_zip( + bucket, path, - anisotropy=anisotropy, - bbox=neurograph.bbox, - img_shape=neurograph.shape, ) - if len(swc_dict["xyz"]) < size_threshold: - continue if smooth: - swc_dict = swc_utils.smooth(swc_dict) - neurograph.generate_immutables( - swc_id, swc_dict, prune=prune, prune_depth=prune_depth - ) + with concurrent.futures.ProcessPoolExecutor() as executor: + swc_dicts = list(executor.map(swc_utils.smooth, swc_dicts)) + + # Readout progress + n_files_completed += len(swc_dicts) + if n_files_completed > cnt * chunk_size: + report_runtimes( + n_swc_files, + n_files_completed, + chunk_size, + time() - t1, + time() - t0, + ) + cnt += 1 + t1 = time() + t, unit = utils.time_writer(time() - t0) + print(f"Total Runtime: {round(t, 4)} {unit}") return neurograph -def get_paths(path_or_list): - if type(path_or_list) == str: - paths = [] - for f in utils.listdir(path_or_list, ext=".swc"): - paths.append(os.path.join(path_or_list, f)) - return paths - elif type(path_or_list) == list: - return path_or_list +def count_files_in_zips(bucket, zip_paths): + t0 = time() + file_cnt = 0 + for zip_path in zip_paths: + zip_blob = bucket.blob(zip_path) + zip_content = zip_blob.download_as_bytes() + file_paths = list_files_in_gcs_zip(zip_content) + file_cnt += len(file_paths) + return file_cnt + +def list_files_in_gcs_zip(zip_content): + """ + Lists all files in a zip file stored in a GCS bucket. -def get_id(path): - filename = path.split("/")[-1] - return filename.replace(".0.swc", "") + """ + with ZipFile(BytesIO(zip_content), 'r') as zip_file: + return zip_file.namelist() + + +def list_gcs_filenames(bucket, cloud_path, extension): + """ + Lists all files in a GCS bucket with the given extension. + + """ + blobs = bucket.list_blobs(prefix=cloud_path) + return [blob.name for blob in blobs if extension in blob.name] + + +def process_gcs_zip(bucket, zip_path): + # Get filenames + zip_blob = bucket.blob(zip_path) + zip_content = zip_blob.download_as_bytes() + swc_paths = list_files_in_gcs_zip(zip_content) + + # Read files + t0 = time() + swc_dicts = [None] * len(swc_paths) + with ZipFile(BytesIO(zip_content)) as zip_file: + with ThreadPoolExecutor() as executor: + results = [ + executor.submit(swc_utils.parse_gcs_zip, zip_file, path) + for path in swc_paths + ] + for i, result_i in enumerate(as_completed(results)): + swc_dicts[i] = result_i.result() + return swc_dicts + + +def report_runtimes( + n_files, + n_files_completed, + chunk_size, + chunk_runtime, + total_runtime, +): + n_files_remaining = n_files - n_files_completed + file_rate = chunk_runtime / chunk_size + eta = (total_runtime + n_files_remaining * file_rate) / 60 + files_processed = f"{n_files_completed - chunk_size}-{n_files_completed}" + print(f"Completed: {round(100 * n_files_completed / n_files, 2)}%") + print(f"Runtime for Files : {files_processed} {round(chunk_runtime, 4)} seconds") + print(f"File Processing Rate: {file_rate} seconds") + print(f"Approximate Total Runtime: {round(eta, 4)} minutes") + print("") diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index e2567ef..2d4048c 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -17,6 +17,7 @@ from deep_neurographs import geometry_utils from deep_neurographs import graph_utils as gutils +from deep_neurographs import swc_utils from deep_neurographs import utils from deep_neurographs.densegraph import DenseGraph from deep_neurographs.geometry_utils import dist as get_dist @@ -34,7 +35,7 @@ class NeuroGraph(nx.Graph): def __init__( self, - swc_path, + swc_dir=None, img_path=None, label_mask=None, optimize_depth=10, @@ -42,6 +43,7 @@ def __init__( optimize_path=False, origin=None, shape=None, + size_threshold=30, ): """ Parameters @@ -54,10 +56,11 @@ def __init__( """ super(NeuroGraph, self).__init__() - self.path = swc_path + self.path = swc_dir self.label_mask = label_mask self.leafs = set() self.junctions = set() + self.size_threshold = size_threshold self.immutable_edges = set() self.mutable_edges = set() @@ -98,7 +101,20 @@ def init_densegraph(self): self.densegraph = DenseGraph(self.path) # --- Add nodes or edges --- - def generate_immutables( + def ingest_swc_from_local(self, path, prune=True, prune_depth=16, smooth=True): + swc_id = utils.get_id(path) + swc_dict = swc_utils.parse_local_swc( + path, + bbox=self.bbox, + img_shape=self.shape, + ) + if len(swc_dict["xyz"]) > self.size_threshold: + swc_dict = swc_utils.smooth(swc_dict) if smooth else swc_dict + self.add_immutables( + swc_id, swc_dict, prune=prune, prune_depth=prune_depth + ) + + def add_immutables( self, swc_id, swc_dict, prune=True, prune_depth=16 ): """ @@ -109,7 +125,7 @@ def generate_immutables( node_id : int Node id. swc_dict : dict - Dictionary generated from an swc where the keys are swc type + Dictionary generated from an swc where the keys are swc attributes. Returns @@ -156,13 +172,10 @@ def generate_immutables( self.leafs.add(node_id[l]) for j in junctions: - self.junctions.add(node_id[j]) - - # Build kdtree - self._init_kdtree() + self.junctions.add(node_id[j]) # --- Proposal Generation --- - def generate_proposals(self, num_proposals=3, search_radius=25.0): + def generate_proposals(self, n_proposals_per_leaf=3, search_radius=25.0): """ Generates edges for the graph. @@ -171,13 +184,14 @@ def generate_proposals(self, num_proposals=3, search_radius=25.0): None """ + self._init_kdtree() self.mutable_edges = set() for leaf in self.leafs: if not self.is_contained(leaf): continue xyz_leaf = self.nodes[leaf]["xyz"] - proposals = self._get_proposals( - leaf, xyz_leaf, num_proposals, search_radius + proposals = self.__get_proposals( + leaf, xyz_leaf, n_proposals_per_leaf, search_radius ) for xyz in proposals: # Extract info on mutable connection @@ -205,20 +219,29 @@ def generate_proposals(self, num_proposals=3, search_radius=25.0): if self.optimize_alignment or self.optimize_path: self.run_optimization() - def _get_proposals( - self, query_id, query_xyz, num_proposals, search_radius + def __get_proposals( + self, query_id, query_xyz, n_proposals_per_leaf, search_radius ): """ + Generates edge proposals for node "query_id" by finding points on + distinct connected components near "query_xyz". + Parameters ---------- query_id : int Node id of the query node. query_xyz : tuple[float] - The (x,y,z) coordinates of the query node. + (x,y,z) coordinates of the query node. + n_proposals_per_leaf : int + Number of proposals generated from node "query_id". + search_radius : float + Maximum Euclidean length of edge proposal. Returns ------- - None. + list + List of "n_proposals_per_leaf" best edge proposals generated from + "query_node". """ best_xyz = dict() @@ -238,17 +261,17 @@ def _get_proposals( elif d < best_dist[swc_id]: best_xyz[swc_id] = xyz best_dist[swc_id] = d - return self._get_best_edges(best_dist, best_xyz, num_proposals) + return self._get_best_edges(best_dist, best_xyz, n_proposals_per_leaf) - def _get_best_edges(self, dists, xyz, num_proposals): + def _get_best_edges(self, dists, xyz, n_proposals_per_leaf): """ - Gets the at most "num_proposals" nodes that are closest to the - target node. + Gets the at most "n_proposals_per_leaf" nodes that are closest to + "xyz". """ - if len(dists.keys()) > num_proposals: + if len(dists.keys()) > n_proposals_per_leaf: keys = sorted(dists, key=dists.__getitem__) - return [xyz[key] for key in keys[0:num_proposals]] + return [xyz[key] for key in keys[0:n_proposals_per_leaf]] else: return list(xyz.values()) diff --git a/src/deep_neurographs/swc_utils.py b/src/deep_neurographs/swc_utils.py index b12bfb1..3e8522b 100644 --- a/src/deep_neurographs/swc_utils.py +++ b/src/deep_neurographs/swc_utils.py @@ -22,7 +22,23 @@ # -- io utils -- -def parse(path, anisotropy=[1.0, 1.0, 1.0], bbox=None, img_shape=None): +def parse_local_swc(path, bbox=None, img_shape=None): + return parse( + read_from_local(path), + bbox=bbox, + img_shape=img_shape, + ) + + +def parse_gcs_zip(zip_file, path, bbox=None, img_shape=None): + return parse( + read_from_gcs_zip(zip_file, path), + bbox=bbox, + img_shape=img_shape, + ) + + +def parse(swc_contents, bbox=None, img_shape=None): """ Parses an swc file to extract the contents which is stored in a dict. Note that node_ids from swc are refactored to index from 0 to n-1 where n is @@ -32,28 +48,23 @@ def parse(path, anisotropy=[1.0, 1.0, 1.0], bbox=None, img_shape=None): ---------- path : str Path to an swc file. - anisotropy : list[float] - Image to real-world coordinates scaling factors for [x, y, z] due to - anistropy of the microscope. + ... Returns ------- ... """ - # Initialize swc - swc_dict = {"id": [], "radius": [], "pid": [], "xyz": []} - - # Parse raw data min_id = np.inf offset = [0, 0, 0] - for line in read(path): + swc_dict = {"id": [], "radius": [], "pid": [], "xyz": []} + for line in swc_contents: if line.startswith("# OFFSET"): parts = line.split() offset = read_xyz(parts[2:5]) if not line.startswith("#") and len(line) > 0: parts = line.split() - xyz = read_xyz(parts[2:5], anisotropy=anisotropy, offset=offset) + xyz = read_xyz(parts[2:5], offset=offset) if bbox: if not utils.is_contained(bbox, img_shape, xyz): break @@ -73,13 +84,35 @@ def parse(path, anisotropy=[1.0, 1.0, 1.0], bbox=None, img_shape=None): return swc_dict -def read(path): +def read_from_local(path): + """ + Reads swc file stored at "path" on local machine. + + Parameters + ---------- + Path : str + Path to swc file to be read. + + Returns + ------- + list + List such that each entry is a line from the swc file. + + """ with open(path, "r") as file: - contents = file.readlines() - return contents + return file.readlines() -def read_xyz(xyz, anisotropy=[1.0, 1.0, 1.0], offset=[0, 0, 0]): +def read_from_gcs_zip(zip_file, path): + """ + Reads the content of an swc file from a zip file in a GCS bucket. + + """ + with zip_file.open(path) as text_file: + return text_file.read().decode('utf-8').splitlines() + + +def read_xyz(xyz, offset=[0, 0, 0]): """ Reads the (z,y,x) coordinates from an swc file, then reverses and scales them. @@ -88,18 +121,14 @@ def read_xyz(xyz, anisotropy=[1.0, 1.0, 1.0], offset=[0, 0, 0]): ---------- zyx : str (z,y,x) coordinates. - anisotropy : list[float] - Image to real-world coordinates scaling factors for [x, y, z] due to - anistropy of the microscope. Returns ------- - list + tuple The (x,y,z) coordinates from an swc file. """ - xyz = [float(xyz[i]) * anisotropy[i] + offset[i] for i in range(3)] - return tuple(xyz) + return tuple([float(xyz[i]) + offset[i] for i in range(3)]) def write_swc(path, contents): @@ -282,7 +311,7 @@ def smooth(swc_dict): def upd_edge(xyz, idxs): idxs = np.array(idxs) - xyz[idxs] = geometry_utils.smooth_branch(xyz[idxs].copy()) + xyz[idxs] = geometry_utils.smooth_branch(xyz[idxs], s=10) return xyz diff --git a/src/deep_neurographs/utils.py b/src/deep_neurographs/utils.py index 7cd3fae..4be8b62 100644 --- a/src/deep_neurographs/utils.py +++ b/src/deep_neurographs/utils.py @@ -422,6 +422,15 @@ def is_contained(bbox, img_shape, xyz): # --- miscellaneous --- +def get_id(path): + """ + Gets segment id of the swc file at "path". + + """ + filename = path.split("/")[-1] + return filename.replace(".0.swc", "") + + def get_img_mip(img, axis=0): return np.max(img, axis=axis) @@ -431,6 +440,10 @@ def normalize_img(img): return img / np.max(img) +def reformat_number(number): + return f"{number:,}" + + def time_writer(t, unit="seconds"): assert unit in ["seconds", "minutes", "hours"] upd_unit = {"seconds": "minutes", "minutes": "hours"} @@ -449,3 +462,10 @@ def progress_bar(current, total, bar_length=50): f"[{'=' * progress}{' ' * (bar_length - progress)}] {current}/{total}" ) print(f"\r{bar}", end="", flush=True) + +def xor(a, b): + if (a and b) or (not a and not b): + return False + else: + return True + \ No newline at end of file