diff --git a/src/deep_neurographs/deep_learning/train.py b/src/deep_neurographs/deep_learning/train.py index 20702e4..3bc8998 100644 --- a/src/deep_neurographs/deep_learning/train.py +++ b/src/deep_neurographs/deep_learning/train.py @@ -223,8 +223,17 @@ def random_split(train_set, train_ratio=0.85): def eval_network(X, model): + # Prep data + if type(X) == dict: + X = [ + torch.tensor(X["features"], dtype=torch.float32), + torch.tensor(X["imgs"], dtype=torch.float32), + ] + else: + X = torch.tensor(X, dtype=torch.float32) + + # Run model model.eval() - X = torch.tensor(X, dtype=torch.float32) with torch.no_grad(): y_pred = sigmoid(model.net(X)) return np.array(y_pred) diff --git a/src/deep_neurographs/densegraph.py b/src/deep_neurographs/densegraph.py index e7e4297..0231cbe 100644 --- a/src/deep_neurographs/densegraph.py +++ b/src/deep_neurographs/densegraph.py @@ -66,7 +66,7 @@ def init_graphs(self, swc_dir): for f in utils.listdir(swc_dir, ext=".swc"): # Construct Graph path = os.path.join(swc_dir, f) - swc_dict = swc_utils.parse(path) + swc_dict = swc_utils.parse_local_swc(path) graph, xyz_to_node = swc_utils.file_to_graph( swc_dict, set_attrs=True, return_dict=True ) diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index 40a84c9..19660df 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -8,17 +8,18 @@ """ -import os import concurrent.futures - +import os from concurrent.futures import ThreadPoolExecutor, as_completed -from google.cloud import storage from io import BytesIO -from deep_neurographs import swc_utils, utils -from deep_neurographs.neurograph import NeuroGraph from time import time from zipfile import ZipFile +from google.cloud import storage + +from deep_neurographs import swc_utils, utils +from deep_neurographs.neurograph import NeuroGraph + N_PROPOSALS_PER_LEAF = 3 OPTIMIZE_ALIGNMENT = False OPTIMIZE_DEPTH = 15 @@ -45,8 +46,9 @@ def build_neurograph_from_local( size_threshold=SIZE_THRESHOLD, smooth=SMOOTH, ): - assert utils.xor(swc_dir, swc_list), "Error: provide swc_dir or swc_paths" + assert utils.xor(swc_dir, swc_paths), "Error: provide swc_dir or swc_paths" neurograph = NeuroGraph( + swc_dir=swc_dir, img_path=img_path, optimize_depth=optimize_depth, optimize_alignment=optimize_alignment, @@ -65,7 +67,7 @@ def build_neurograph_from_local( if search_radius > 0: neurograph.generate_proposals( n_proposals_per_leaf=n_proposals_per_leaf, - search_radius=search_radius + search_radius=search_radius, ) return neurograph @@ -100,7 +102,7 @@ def build_neurograph_from_gcs_zips( if search_radius > 0: neurograph.generate_proposals( n_proposals_per_leaf=n_proposals_per_leaf, - search_radius=search_radius + search_radius=search_radius, ) return neurograph @@ -117,16 +119,13 @@ def init_immutables_from_local( swc_paths = get_paths(swc_dir) if swc_dir else swc_paths for path in swc_paths: neurograph.ingest_swc_from_local( - path, - prune=True, - prune_depth=16, - smooth=smooth, + path, prune=True, prune_depth=16, smooth=smooth ) return neurograph def get_paths(swc_dir): - swc_paths = [] + paths = [] for f in utils.listdir(swc_dir, ext=".swc"): paths.append(os.path.join(swc_dir, f)) return paths @@ -145,7 +144,7 @@ def init_immutables_from_gcs_zips( storage_client = storage.Client() bucket = storage_client.bucket(bucket_name) zip_paths = list_gcs_filenames(bucket, cloud_path, ".zip") - n_swc_files = 2080791 #count_files_in_zips(bucket, zip_paths) + n_swc_files = 2080791 # count_files_in_zips(bucket, zip_paths) chunk_size = int(n_swc_files * 0.05) print("# zip files:", len(zip_paths)) print(f"# swc files: {utils.reformat_number(n_swc_files)} \n\n") @@ -158,10 +157,7 @@ def init_immutables_from_gcs_zips( print(f"-- Starting Multithread Reads with chunk_size={chunk_size} -- \n") for path in zip_paths: # Add to neurograph - swc_dicts = process_gcs_zip( - bucket, - path, - ) + swc_dicts = process_gcs_zip(bucket, path) if smooth: with concurrent.futures.ProcessPoolExecutor() as executor: swc_dicts = list(executor.map(swc_utils.smooth, swc_dicts)) @@ -199,7 +195,7 @@ def list_files_in_gcs_zip(zip_content): Lists all files in a zip file stored in a GCS bucket. """ - with ZipFile(BytesIO(zip_content), 'r') as zip_file: + with ZipFile(BytesIO(zip_content), "r") as zip_file: return zip_file.namelist() @@ -233,18 +229,16 @@ def process_gcs_zip(bucket, zip_path): def report_runtimes( - n_files, - n_files_completed, - chunk_size, - chunk_runtime, - total_runtime, + n_files, n_files_completed, chunk_size, chunk_runtime, total_runtime ): n_files_remaining = n_files - n_files_completed file_rate = chunk_runtime / chunk_size - eta = (total_runtime + n_files_remaining * file_rate) / 60 + eta = (total_runtime + n_files_remaining * file_rate) / 60 files_processed = f"{n_files_completed - chunk_size}-{n_files_completed}" print(f"Completed: {round(100 * n_files_completed / n_files, 2)}%") - print(f"Runtime for Files : {files_processed} {round(chunk_runtime, 4)} seconds") + print( + f"Runtime for Files : {files_processed} {round(chunk_runtime, 4)} seconds" + ) print(f"File Processing Rate: {file_rate} seconds") print(f"Approximate Total Runtime: {round(eta, 4)} minutes") print("") diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index 2d4048c..0010a7d 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -17,8 +17,7 @@ from deep_neurographs import geometry_utils from deep_neurographs import graph_utils as gutils -from deep_neurographs import swc_utils -from deep_neurographs import utils +from deep_neurographs import swc_utils, utils from deep_neurographs.densegraph import DenseGraph from deep_neurographs.geometry_utils import dist as get_dist @@ -66,6 +65,7 @@ def __init__( self.mutable_edges = set() self.target_edges = set() self.xyz_to_edge = dict() + self.kdtree = None self.img_path = img_path self.optimize_depth = optimize_depth @@ -101,12 +101,12 @@ def init_densegraph(self): self.densegraph = DenseGraph(self.path) # --- Add nodes or edges --- - def ingest_swc_from_local(self, path, prune=True, prune_depth=16, smooth=True): + def ingest_swc_from_local( + self, path, prune=True, prune_depth=16, smooth=True + ): swc_id = utils.get_id(path) swc_dict = swc_utils.parse_local_swc( - path, - bbox=self.bbox, - img_shape=self.shape, + path, bbox=self.bbox, img_shape=self.shape ) if len(swc_dict["xyz"]) > self.size_threshold: swc_dict = swc_utils.smooth(swc_dict) if smooth else swc_dict @@ -114,9 +114,7 @@ def ingest_swc_from_local(self, path, prune=True, prune_depth=16, smooth=True): swc_id, swc_dict, prune=prune, prune_depth=prune_depth ) - def add_immutables( - self, swc_id, swc_dict, prune=True, prune_depth=16 - ): + def add_immutables(self, swc_id, swc_dict, prune=True, prune_depth=16): """ Adds nodes to graph from a dictionary generated from an swc files. @@ -172,7 +170,7 @@ def add_immutables( self.leafs.add(node_id[l]) for j in junctions: - self.junctions.add(node_id[j]) + self.junctions.add(node_id[j]) # --- Proposal Generation --- def generate_proposals(self, n_proposals_per_leaf=3, search_radius=25.0): @@ -184,7 +182,7 @@ def generate_proposals(self, n_proposals_per_leaf=3, search_radius=25.0): None """ - self._init_kdtree() + self.init_kdtree() self.mutable_edges = set() for leaf in self.leafs: if not self.is_contained(leaf): @@ -305,10 +303,10 @@ def __add_edge(self, edge, attrs, idxs): self.xyz_to_edge[tuple(xyz)] = edge self.immutable_edges.add(frozenset(edge)) - def _init_kdtree(self): + def init_kdtree(self): """ Builds a KD-Tree from the (x,y,z) coordinates of the subnodes of - each node in the graph. + each connected component in the graph. Parameters ---------- @@ -319,7 +317,8 @@ def _init_kdtree(self): None """ - self.kdtree = KDTree(list(self.xyz_to_edge.keys())) + if not self.kdtree: + self.kdtree = KDTree(list(self.xyz_to_edge.keys())) def _query_kdtree(self, query, d): """ @@ -385,9 +384,12 @@ def orient_edge(self, edge, i): # --- Ground Truth Generation --- def init_targets(self, target_neurograph): # Initializations + msg = "Error: Provide swc_dir/swc_paths to initialize target edges!" + assert target_neurograph.path, msg + target_neurograph.init_densegraph() + target_neurograph.init_kdtree() self.target_edges = set() self.init_predicted_graph() - target_neurograph.init_densegraph() # Add best simple edges remaining_proposals = [] diff --git a/src/deep_neurographs/swc_utils.py b/src/deep_neurographs/swc_utils.py index 3e8522b..2662afa 100644 --- a/src/deep_neurographs/swc_utils.py +++ b/src/deep_neurographs/swc_utils.py @@ -23,18 +23,12 @@ # -- io utils -- def parse_local_swc(path, bbox=None, img_shape=None): - return parse( - read_from_local(path), - bbox=bbox, - img_shape=img_shape, - ) + return parse(read_from_local(path), bbox=bbox, img_shape=img_shape) def parse_gcs_zip(zip_file, path, bbox=None, img_shape=None): return parse( - read_from_gcs_zip(zip_file, path), - bbox=bbox, - img_shape=img_shape, + read_from_gcs_zip(zip_file, path), bbox=bbox, img_shape=img_shape ) @@ -109,7 +103,7 @@ def read_from_gcs_zip(zip_file, path): """ with zip_file.open(path) as text_file: - return text_file.read().decode('utf-8').splitlines() + return text_file.read().decode("utf-8").splitlines() def read_xyz(xyz, offset=[0, 0, 0]): diff --git a/src/deep_neurographs/utils.py b/src/deep_neurographs/utils.py index 4be8b62..8ca1eea 100644 --- a/src/deep_neurographs/utils.py +++ b/src/deep_neurographs/utils.py @@ -255,21 +255,21 @@ def open_tensorstore(path, driver): def read_img_chunk(img, xyz, shape): start, end = get_start_end(xyz, shape) return img[ - start[2]:end[2], start[1]:end[1], start[0]:end[0] + start[2] : end[2], start[1] : end[1], start[0] : end[0] ].transpose(2, 1, 0) def get_chunk(arr, xyz, shape): start, end = get_start_end(xyz, shape) return deepcopy( - arr[start[0]:end[0], start[1]:end[1], start[2]:end[2]] + arr[start[0] : end[0], start[1] : end[1], start[2] : end[2]] ) def read_tensorstore(ts_arr, xyz, shape): start, end = get_start_end(xyz, shape) return ( - ts_arr[start[0]:end[0], start[1]:end[1], start[2]:end[2]] + ts_arr[start[0] : end[0], start[1] : end[1], start[2] : end[2]] .read() .result() ) @@ -463,9 +463,9 @@ def progress_bar(current, total, bar_length=50): ) print(f"\r{bar}", end="", flush=True) + def xor(a, b): if (a and b) or (not a and not b): return False else: return True - \ No newline at end of file