Skip to content

Commit

Permalink
bug: ground truth generation and proposal alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Jan 12, 2024
1 parent 1263793 commit 1832f4a
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 56 deletions.
25 changes: 12 additions & 13 deletions src/deep_neurographs/densegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ class DenseGraph:
"""

def __init__(self, swc_dir):
def __init__(self, swc_paths):
"""
Constructs a DenseGraph object from a directory of swc files.
Parameters
----------
swc_dir : path
Path to directory of swc files which are used to construct a hash
swc_paths : list[str]
List of paths to swc files which are used to construct a hash
table in which the entries are filename-graph pairs.
Returns
Expand All @@ -43,18 +43,18 @@ def __init__(self, swc_dir):
"""
self.xyz_to_node = dict()
self.xyz_to_swc = dict()
self.init_graphs(swc_dir)
self.init_graphs(swc_paths)
self.init_kdtree()

def init_graphs(self, swc_dir):
def init_graphs(self, swc_paths):
"""
Initializes graphs by reading swc files in "swc_dir". Graphs are
Initializes graphs by reading swc files in "swc_paths". Graphs are
stored in a hash table where the entries are filename-graph pairs.
Parameters
----------
swc_dir : path
Path to directory of swc files which are used to construct a hash
swc_paths : list[str]
List of paths to swc files which are used to construct a hash
table in which the entries are filename-graph pairs.
Returns
Expand All @@ -63,16 +63,15 @@ def init_graphs(self, swc_dir):
"""
self.graphs = dict()
for f in utils.listdir(swc_dir, ext=".swc"):
for path in swc_paths:
# Construct Graph
path = os.path.join(swc_dir, f)
swc_dict = swc_utils.parse_local_swc(path)
graph, xyz_to_node = swc_utils.to_graph(swc_dict, set_attrs=True)

# Store
xyz_to_id = dict(zip_broadcast(swc_dict["xyz"], f))
self.graphs[f] = graph
self.xyz_to_node[f] = xyz_to_node
xyz_to_id = dict(zip_broadcast(swc_dict["xyz"], path))
self.graphs[path] = graph
self.xyz_to_node[path] = xyz_to_node
self.xyz_to_swc.update(xyz_to_id)

def init_kdtree(self):
Expand Down
30 changes: 3 additions & 27 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def build_neurograph_from_local(
swc_dicts,
bbox=bbox,
img_path=img_path,
swc_paths=paths,
prune=prune,
prune_depth=prune_depth,
smooth=smooth,
Expand Down Expand Up @@ -245,36 +246,11 @@ def list_gcs_filenames(bucket, cloud_path, extension):


# -- Build neurograph ---
def build_neurograph_old(
swc_dicts,
bbox=None,
img_path=None,
prune=PRUNE,
prune_depth=PRUNE_DEPTH,
smooth=SMOOTH,
):
# Extract irreducibles
t0 = time()
irreducibles = dict()
for key in swc_dicts.keys():
irreducibles[key] = gutils.get_irreducibles(
swc_dicts[key], prune=prune, depth=prune_depth, smooth=smooth
)
print(f" --> get_irreducibles(): {time() - t0} seconds")

# Build neurograph
t0 = time()
neurograph = NeuroGraph(bbox=bbox, img_path=img_path)
for key in swc_dicts.keys():
neurograph.add_immutables(swc_dicts[key], irreducibles[key])
print(f" --> add_irreducibles(): {time() - t0} seconds")
return neurograph


def build_neurograph(
swc_dicts,
bbox=None,
img_path=None,
swc_paths=None,
prune=PRUNE,
prune_depth=PRUNE_DEPTH,
smooth=SMOOTH,
Expand Down Expand Up @@ -310,7 +286,7 @@ def build_neurograph(

# Build neurograph
t0 = time()
neurograph = NeuroGraph(bbox=bbox, img_path=img_path)
neurograph = NeuroGraph(bbox=bbox, img_path=img_path, swc_paths=swc_paths)
for key in swc_dicts.keys():
neurograph.add_immutables(irreducibles[key], swc_dicts[key], key)
print(f" --> add_irreducibles(): {time() - t0} seconds")
Expand Down
14 changes: 7 additions & 7 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ class NeuroGraph(nx.Graph):
"""

def __init__(
self, bbox=None, swc_dir=None, img_path=None, label_mask=None
self, bbox=None, swc_paths=None, img_path=None, label_mask=None
):
super(NeuroGraph, self).__init__()
# Initialize paths
self.img_path = img_path
self.label_mask = label_mask
self.swc_paths = swc_dir
self.swc_paths = swc_paths

# Initialize node and edge sets
self.leafs = set()
Expand Down Expand Up @@ -204,7 +204,7 @@ def __get_proposals(
best_dist = dict()
query_swc_id = self.nodes[query_id]["swc_id"]
for xyz in self._query_kdtree(query_xyz, search_radius):
if not self.is_contained(xyz):
if not self.is_contained(xyz, buffer=36):
continue
xyz = tuple(xyz)
edge = self.xyz_to_edge[xyz]
Expand Down Expand Up @@ -336,8 +336,8 @@ def orient_edge(self, edge, i):
# --- Ground Truth Generation ---
def init_targets(self, target_neurograph):
# Initializations
msg = "Error: Provide swc_dir/swc_paths to initialize target edges!"
assert target_neurograph.swc_path, msg
msg = "Provide swc_dir/swc_paths to initialize target edges!"
assert target_neurograph.swc_paths, msg
target_neurograph.init_densegraph()
target_neurograph.init_kdtree()
self.target_edges = set()
Expand Down Expand Up @@ -534,11 +534,11 @@ def get_projection(self, xyz):
def is_nb(self, i, j):
return True if i in self.neighbors(j) else False

def is_contained(self, node_or_xyz):
def is_contained(self, node_or_xyz, buffer=0):
if self.bbox:
if type(node_or_xyz) == int:
node_or_xyz = deepcopy(self.nodes[node_or_xyz]["xyz"])
return utils.is_contained(self.bbox, node_or_xyz)
return utils.is_contained(self.bbox, node_or_xyz, buffer=buffer)
else:
return True

Expand Down
15 changes: 8 additions & 7 deletions src/deep_neurographs/swc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@
def process_local_paths(paths, min_size, bbox=None):
swc_dicts = dict()
for path in paths:
swc_dicts.update(parse_local_swc(path, bbox=bbox, min_size=min_size))
swc_dict = parse_local_swc(path, bbox=bbox, min_size=min_size)
if len(swc_dict):
swc_id = utils.get_swc_id(path)
swc_dicts[swc_id] = swc_dict
return swc_dicts


def parse_local_swc(path, bbox=None, min_size=25):
swc_id = utils.get_swc_id(path)
swc_contents = read_from_local(path)
swc_dict = parse(swc_contents, bbox=bbox)
return {swc_id: swc_dict} if len(swc_dict["id"]) > min_size else dict()
def parse_local_swc(path, bbox=None, min_size=0):
contents = read_from_local(path)
return parse(contents, bbox=bbox) if len(contents) > min_size else []


def parse_gcs_zip(zip_file, path, min_size=0):
Expand Down Expand Up @@ -80,7 +81,7 @@ def parse(swc_contents, bbox=None):
swc_dict["id"][i] -= min_id
swc_dict["pid"][i] -= min_id

return swc_dict
return swc_dict if len(swc_dict["id"]) > 1 else []


def read_from_local(path):
Expand Down
4 changes: 2 additions & 2 deletions src/deep_neurographs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,10 @@ def get_avg_std(data, weights=None):
return avg, math.sqrt(var)


def is_contained(bbox, xyz):
def is_contained(bbox, xyz, buffer=0):
xyz = apply_anisotropy(xyz - bbox["min"])
shape = bbox["max"] - bbox["min"]
if any(xyz < 0) or any(xyz >= shape):
if any(xyz < buffer) or any(xyz >= shape - buffer):
return False
else:
return True
Expand Down

0 comments on commit 1832f4a

Please sign in to comment.