Skip to content

Commit

Permalink
bug: fixed graph traversal (#49)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Feb 6, 2024
1 parent a420156 commit 1d46003
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 14 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dynamic = ["version"]

dependencies = [
'boto3',
'fastremap',
'lightning',
'more_itertools',
'networkx',
Expand Down
19 changes: 12 additions & 7 deletions src/deep_neurographs/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,13 @@ def get_irreducibles(swc_dict, swc_id=None, prune=True, depth=16, smooth=True):
"""
# Build dense graph
swc_dict["idx"] = dict(zip(swc_dict["id"], range(len(swc_dict["id"]))))
dense_graph = swc_utils.to_graph(swc_dict)
if prune:
dense_graph = prune_short_branches(dense_graph, depth)

# Extract nodes
leafs, junctions = get_irreducible_nodes(dense_graph, swc_dict)
leafs, junctions = get_irreducible_nodes(dense_graph)
if len(leafs) == 0:
return False, None

Expand Down Expand Up @@ -100,7 +101,7 @@ def get_irreducibles(swc_dict, swc_id=None, prune=True, depth=16, smooth=True):
return swc_id, irreducibles


def get_irreducible_nodes(graph, swc_dict):
def get_irreducible_nodes(graph):
"""
Gets irreducible nodes (i.e. leafs and junctions) of a graph.
Expand Down Expand Up @@ -315,7 +316,8 @@ def init_edge_attrs(swc_dict, i):
Edge attribute dictionary.
"""
return {"radius": [swc_dict["radius"][i]], "xyz": [swc_dict["xyz"][i]]}
j = swc_dict["idx"][i]
return {"radius": [swc_dict["radius"][j]], "xyz": [swc_dict["xyz"][j]]}


def upd_edge_attrs(swc_dict, attrs, i):
Expand All @@ -337,8 +339,9 @@ def upd_edge_attrs(swc_dict, attrs, i):
Edge attribute dictionary.
"""
attrs["radius"].append(swc_dict["radius"][i])
attrs["xyz"].append(swc_dict["xyz"][i])
j = swc_dict["idx"][i]
attrs["radius"].append(swc_dict["radius"][j])
attrs["xyz"].append(swc_dict["xyz"][j])
return attrs


Expand Down Expand Up @@ -403,7 +406,8 @@ def set_node_attrs(swc_dict, nodes):
"""
attrs = dict()
for i in nodes:
attrs[i] = {"radius": swc_dict["radius"][i], "xyz": swc_dict["xyz"][i]}
j = swc_dict["idx"][i]
attrs[i] = {"radius": swc_dict["radius"][j], "xyz": swc_dict["xyz"][j]}
return attrs


Expand Down Expand Up @@ -434,7 +438,8 @@ def upd_node_attrs(swc_dict, leafs, junctions, i):
Updated dictionary if "i" was contained in "junctions.keys()".
"""
upd_attrs = {"radius": swc_dict["radius"][i], "xyz": swc_dict["xyz"][i]}
j = swc_dict["idx"][i]
upd_attrs = {"radius": swc_dict["radius"][j], "xyz": swc_dict["xyz"][j]}
if i in leafs:
leafs[i] = upd_attrs
else:
Expand Down
2 changes: 2 additions & 0 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ def download_gcs_zips(bucket_name, cloud_path, min_size):
cnt, t1 = report_progress(
i, len(zip_paths), chunk_size, cnt, t0, t1
)
if len(swc_dicts) > 2000:
stop
return swc_dicts


Expand Down
20 changes: 13 additions & 7 deletions src/deep_neurographs/swc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,8 @@ def parse_local_swc(path, bbox=None, min_size=0):

def parse_gcs_zip(zip_file, path, min_size=0):
contents = read_from_gcs_zip(zip_file, path)
swc_dict = (
fast_parse(contents) if len(contents) > min_size else {"id": [-1]}
)
parse_bool = len(contents) > min_size
swc_dict = fast_parse(contents) if parse_bool else {"id": [-1]}
return utils.get_swc_id(path), swc_dict


Expand Down Expand Up @@ -143,9 +142,14 @@ def fast_parse(contents):
min_id = np.min(swc_dict["id"])
swc_dict["id"] -= min_id
swc_dict["pid"] -= min_id
swc_dict["radius"] /= 1000.0
return swc_dict


def reindex(arr, idxs):
return arr[idxs]


def get_contents(swc_contents):
offset = [0, 0, 0]
for i, line in enumerate(swc_contents):
Expand Down Expand Up @@ -333,10 +337,12 @@ def to_graph(swc_dict, graph_id=None, set_attrs=False):


def __add_attributes(swc_dict, graph):
xyz = swc_dict["xyz"]
radii = swc_dict["radius"]
attrs = [{"xyz": xyz[i], "radius": radii[i]} for i in graph.nodes]
nx.set_node_attributes(graph, dict(zip(swc_dict["id"], attrs)))
attrs = dict()
for idx, node_id in enumerate(swc_dict["id"]):
attrs[node_id] = {
"xyz": swc_dict["xyz"][idx], "radius": swc_dict["radius"][idx]
}
nx.set_node_attributes(graph, attrs)
return graph


Expand Down
12 changes: 12 additions & 0 deletions src/deep_neurographs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import shutil
from copy import deepcopy
from io import BytesIO
from skimage.color import label2rgb
from time import time
from zipfile import ZipFile

Expand Down Expand Up @@ -425,6 +426,11 @@ def write_json(path, contents):
json.dump(contents, f)


def write_txt(path, contents):
f = open(path, "w")
f.write(contents)
f.close()

# --- coordinate conversions ---
def world_to_img(neurograph, node_or_xyz):
if type(node_or_xyz) == int:
Expand Down Expand Up @@ -500,6 +506,12 @@ def get_img_mip(img, axis=0):
return np.max(img, axis=axis)


def get_labels_mip(img, axis=0):
mip = np.max(img, axis=axis)
mip = label2rgb(mip)
return (255 * mip).astype(np.uint8)


def normalize_img(img):
img -= np.min(img)
return img / np.max(img)
Expand Down

0 comments on commit 1d46003

Please sign in to comment.