diff --git a/src/deep_neurographs/densegraph.py b/src/deep_neurographs/densegraph.py index fc92672..9839a68 100644 --- a/src/deep_neurographs/densegraph.py +++ b/src/deep_neurographs/densegraph.py @@ -26,14 +26,14 @@ class DenseGraph: """ - def __init__(self, swc_dir): + def __init__(self, swc_paths): """ Constructs a DenseGraph object from a directory of swc files. Parameters ---------- - swc_dir : path - Path to directory of swc files which are used to construct a hash + swc_paths : list[str] + List of paths to swc files which are used to construct a hash table in which the entries are filename-graph pairs. Returns @@ -43,18 +43,18 @@ def __init__(self, swc_dir): """ self.xyz_to_node = dict() self.xyz_to_swc = dict() - self.init_graphs(swc_dir) + self.init_graphs(swc_paths) self.init_kdtree() - def init_graphs(self, swc_dir): + def init_graphs(self, swc_paths): """ - Initializes graphs by reading swc files in "swc_dir". Graphs are + Initializes graphs by reading swc files in "swc_paths". Graphs are stored in a hash table where the entries are filename-graph pairs. Parameters ---------- - swc_dir : path - Path to directory of swc files which are used to construct a hash + swc_paths : list[str] + List of paths to swc files which are used to construct a hash table in which the entries are filename-graph pairs. Returns @@ -63,16 +63,15 @@ def init_graphs(self, swc_dir): """ self.graphs = dict() - for f in utils.listdir(swc_dir, ext=".swc"): + for path in swc_paths: # Construct Graph - path = os.path.join(swc_dir, f) swc_dict = swc_utils.parse_local_swc(path) graph, xyz_to_node = swc_utils.to_graph(swc_dict, set_attrs=True) # Store - xyz_to_id = dict(zip_broadcast(swc_dict["xyz"], f)) - self.graphs[f] = graph - self.xyz_to_node[f] = xyz_to_node + xyz_to_id = dict(zip_broadcast(swc_dict["xyz"], path)) + self.graphs[path] = graph + self.xyz_to_node[path] = xyz_to_node self.xyz_to_swc.update(xyz_to_id) def init_kdtree(self): diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index 8f9e5b8..2716b92 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -65,6 +65,7 @@ def build_neurograph_from_local( swc_dicts, bbox=bbox, img_path=img_path, + swc_paths=paths, prune=prune, prune_depth=prune_depth, smooth=smooth, @@ -245,36 +246,11 @@ def list_gcs_filenames(bucket, cloud_path, extension): # -- Build neurograph --- -def build_neurograph_old( - swc_dicts, - bbox=None, - img_path=None, - prune=PRUNE, - prune_depth=PRUNE_DEPTH, - smooth=SMOOTH, -): - # Extract irreducibles - t0 = time() - irreducibles = dict() - for key in swc_dicts.keys(): - irreducibles[key] = gutils.get_irreducibles( - swc_dicts[key], prune=prune, depth=prune_depth, smooth=smooth - ) - print(f" --> get_irreducibles(): {time() - t0} seconds") - - # Build neurograph - t0 = time() - neurograph = NeuroGraph(bbox=bbox, img_path=img_path) - for key in swc_dicts.keys(): - neurograph.add_immutables(swc_dicts[key], irreducibles[key]) - print(f" --> add_irreducibles(): {time() - t0} seconds") - return neurograph - - def build_neurograph( swc_dicts, bbox=None, img_path=None, + swc_paths=None, prune=PRUNE, prune_depth=PRUNE_DEPTH, smooth=SMOOTH, @@ -310,7 +286,7 @@ def build_neurograph( # Build neurograph t0 = time() - neurograph = NeuroGraph(bbox=bbox, img_path=img_path) + neurograph = NeuroGraph(bbox=bbox, img_path=img_path, swc_paths=swc_paths) for key in swc_dicts.keys(): neurograph.add_immutables(irreducibles[key], swc_dicts[key], key) print(f" --> add_irreducibles(): {time() - t0} seconds") diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index afbc893..1f530c5 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -33,13 +33,13 @@ class NeuroGraph(nx.Graph): """ def __init__( - self, bbox=None, swc_dir=None, img_path=None, label_mask=None + self, bbox=None, swc_paths=None, img_path=None, label_mask=None ): super(NeuroGraph, self).__init__() # Initialize paths self.img_path = img_path self.label_mask = label_mask - self.swc_paths = swc_dir + self.swc_paths = swc_paths # Initialize node and edge sets self.leafs = set() @@ -204,7 +204,7 @@ def __get_proposals( best_dist = dict() query_swc_id = self.nodes[query_id]["swc_id"] for xyz in self._query_kdtree(query_xyz, search_radius): - if not self.is_contained(xyz): + if not self.is_contained(xyz, buffer=36): continue xyz = tuple(xyz) edge = self.xyz_to_edge[xyz] @@ -336,8 +336,8 @@ def orient_edge(self, edge, i): # --- Ground Truth Generation --- def init_targets(self, target_neurograph): # Initializations - msg = "Error: Provide swc_dir/swc_paths to initialize target edges!" - assert target_neurograph.swc_path, msg + msg = "Provide swc_dir/swc_paths to initialize target edges!" + assert target_neurograph.swc_paths, msg target_neurograph.init_densegraph() target_neurograph.init_kdtree() self.target_edges = set() @@ -534,11 +534,11 @@ def get_projection(self, xyz): def is_nb(self, i, j): return True if i in self.neighbors(j) else False - def is_contained(self, node_or_xyz): + def is_contained(self, node_or_xyz, buffer=0): if self.bbox: if type(node_or_xyz) == int: node_or_xyz = deepcopy(self.nodes[node_or_xyz]["xyz"]) - return utils.is_contained(self.bbox, node_or_xyz) + return utils.is_contained(self.bbox, node_or_xyz, buffer=buffer) else: return True diff --git a/src/deep_neurographs/swc_utils.py b/src/deep_neurographs/swc_utils.py index ea58f94..19464a7 100644 --- a/src/deep_neurographs/swc_utils.py +++ b/src/deep_neurographs/swc_utils.py @@ -21,15 +21,16 @@ def process_local_paths(paths, min_size, bbox=None): swc_dicts = dict() for path in paths: - swc_dicts.update(parse_local_swc(path, bbox=bbox, min_size=min_size)) + swc_dict = parse_local_swc(path, bbox=bbox, min_size=min_size) + if len(swc_dict): + swc_id = utils.get_swc_id(path) + swc_dicts[swc_id] = swc_dict return swc_dicts -def parse_local_swc(path, bbox=None, min_size=25): - swc_id = utils.get_swc_id(path) - swc_contents = read_from_local(path) - swc_dict = parse(swc_contents, bbox=bbox) - return {swc_id: swc_dict} if len(swc_dict["id"]) > min_size else dict() +def parse_local_swc(path, bbox=None, min_size=0): + contents = read_from_local(path) + return parse(contents, bbox=bbox) if len(contents) > min_size else [] def parse_gcs_zip(zip_file, path, min_size=0): @@ -80,7 +81,7 @@ def parse(swc_contents, bbox=None): swc_dict["id"][i] -= min_id swc_dict["pid"][i] -= min_id - return swc_dict + return swc_dict if len(swc_dict["id"]) > 1 else [] def read_from_local(path): diff --git a/src/deep_neurographs/utils.py b/src/deep_neurographs/utils.py index 17b8602..cf9ac54 100644 --- a/src/deep_neurographs/utils.py +++ b/src/deep_neurographs/utils.py @@ -411,10 +411,10 @@ def get_avg_std(data, weights=None): return avg, math.sqrt(var) -def is_contained(bbox, xyz): +def is_contained(bbox, xyz, buffer=0): xyz = apply_anisotropy(xyz - bbox["min"]) shape = bbox["max"] - bbox["min"] - if any(xyz < 0) or any(xyz >= shape): + if any(xyz < buffer) or any(xyz >= shape - buffer): return False else: return True