Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: optimize swc parser #29

Merged
merged 1 commit into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/deep_neurographs/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def upd_branch_endpoint(edges, key, old_xyz, new_xyz):

# -- attribute utils --
def __init_edge_attrs(swc_dict, i):
    """
    Initializes the attribute dictionary of an edge from the i-th entry of
    "swc_dict".

    Parameters
    ----------
    swc_dict : dict
        Dictionary with (at least) the keys "radius" and "xyz", each of
        which can be indexed by "i".
    i : int
        Index of the entry used to seed the edge attributes.

    Returns
    -------
    dict
        Dictionary with keys "radius" and "xyz", each mapped to a
        single-element list holding the i-th value.

    """
    radius_i = swc_dict["radius"][i]
    xyz_i = swc_dict["xyz"][i]
    return {"radius": [radius_i], "xyz": [xyz_i]}


Expand Down
2 changes: 2 additions & 0 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def build_neurograph_from_local(
assert swc_dir or swc_paths, "Provide swc_dir or swc_paths!"
bbox = utils.get_bbox(img_patch_origin, img_patch_shape)
paths = get_paths(swc_dir) if swc_dir else swc_paths
t0 = time()
swc_dicts = process_local_paths(paths, min_size, bbox=bbox)
print(f"build_neurograph_from_local(): {time() - t0} seconds")

# Build neurograph
t0 = time()
Expand Down
97 changes: 75 additions & 22 deletions src/deep_neurographs/swc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import networkx as nx
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
from time import time
from zipfile import ZipFile

from deep_neurographs import geometry_utils
Expand Down Expand Up @@ -49,17 +50,21 @@ def process_gsc_zip(bucket, zip_path, min_size=0):

def parse_local_swc(path, bbox=None, min_size=0):
    """
    Reads and parses an swc file stored on the local machine.

    Parameters
    ----------
    path : str
        Path to the swc file.
    bbox : optional
        Bounding box forwarded to "parse". When it is falsy, the contents
        are parsed with "fast_parse" instead. The default is None.
    min_size : int, optional
        The file is only parsed if it contains more than "min_size" lines.
        The default is 0.

    Returns
    -------
    tuple
        The swc id derived from "path" and the parsed dictionary, or
        {"id": [-1]} when the file is too small to be parsed.

    """
    contents = read_from_local(path)
    # Parse exactly once: previously the contents were parsed
    # unconditionally and the result was immediately overwritten.
    if len(contents) > min_size:
        swc_dict = parse(contents, bbox=bbox) if bbox else fast_parse(contents)
    else:
        swc_dict = {"id": [-1]}
    return utils.get_swc_id(path), swc_dict


def parse_gcs_zip(zip_file, path, min_size=0):
    """
    Reads and parses an swc file stored inside of a zip archive.

    Parameters
    ----------
    zip_file : ZipFile
        Archive handed to "read_from_gcs_zip" (presumably an opened
        zipfile.ZipFile; confirm against the caller).
    path : str
        Path of the swc file within the archive.
    min_size : int, optional
        The file is only parsed if it contains more than "min_size" lines.
        The default is 0.

    Returns
    -------
    tuple
        The swc id derived from "path" and the parsed dictionary, or
        {"id": [-1]} when the file is too small to be parsed.

    """
    contents = read_from_gcs_zip(zip_file, path)
    # Parse exactly once with the fast parser: previously "parse" was run
    # first and its result was immediately discarded.
    if len(contents) > min_size:
        swc_dict = fast_parse(contents)
    else:
        swc_dict = {"id": [-1]}
    return utils.get_swc_id(path), swc_dict


def parse(swc_contents, bbox=None):
def parse(contents, bbox=None):
"""
Parses an swc file to extract the contents which is stored in a dict. Note
that node_ids from swc are refactored to index from 0 to n-1 where n is
Expand All @@ -79,33 +84,81 @@ def parse(swc_contents, bbox=None):
min_id = np.inf
offset = [0, 0, 0]
swc_dict = {"id": [], "radius": [], "pid": [], "xyz": []}
for line in swc_contents:
if line.startswith("# OFFSET"):
parts = line.split()
offset = read_xyz(parts[2:5])
if not line.startswith("#"):
parts = line.split()
xyz = read_xyz(parts[2:5], offset=offset)
if bbox:
if not utils.is_contained(bbox, xyz):
break

swc_dict["id"].append(int(parts[0]))
swc_dict["radius"].append(float(parts[-2]))
swc_dict["pid"].append(int(parts[-1]))
swc_dict["xyz"].append(xyz)
if swc_dict["id"][-1] < min_id:
min_id = swc_dict["id"][-1]
contents, swc_id = get_contents(contents)
for line in contents:
parts = line.split()
xyz = read_xyz(parts[2:5], offset=offset)
if bbox:
if not utils.is_contained(bbox, xyz):
break
swc_dict["id"].append(int(parts[0]))
swc_dict["radius"].append(float(parts[-2]))
swc_dict["pid"].append(int(parts[-1]))
swc_dict["xyz"].append(xyz)
if swc_dict["id"][-1] < min_id:
min_id = swc_dict["id"][-1]

# Reindex from zero
swc_dict["radius"] = np.array(swc_dict["radius"])
for i in range(len(swc_dict["id"])):
swc_dict["id"][i] -= min_id
swc_dict["pid"][i] -= min_id

return swc_dict if len(swc_dict["id"]) > 1 else {"id": [-1]}


def fast_parse(contents):
    """
    Parses the lines of an swc file into a dictionary. Node ids are shifted
    so that the smallest id becomes 0; parent ids are shifted by the same
    amount.

    Parameters
    ----------
    contents : list[str]
        Lines of an swc file. Leading comment lines are removed with
        "get_contents", which also extracts the "# OFFSET" value if present.

    Returns
    -------
    dict
        Dictionary with keys "id", "radius" and "pid" (numpy arrays) and
        "xyz" (list of tuples). If there are no entries to parse,
        {"id": [-1]} is returned, which is the placeholder this module uses
        for an swc file that is not parsed.

    """
    contents, offset = get_contents(contents)
    n_entries = len(contents)
    if n_entries == 0:
        # Nothing to parse; np.min would raise on an empty array.
        return {"id": [-1]}

    # Read into 64-bit buffers first; the compact dtype is chosen below from
    # the actual values rather than from the number of lines.
    ids = np.zeros(n_entries, dtype=np.int64)
    pids = np.zeros(n_entries, dtype=np.int64)
    radius = np.zeros(n_entries, dtype=np.float16)
    xyz = []
    for i, line in enumerate(contents):
        parts = line.split()
        ids[i] = int(parts[0])
        radius[i] = float(parts[-2])
        pids[i] = int(parts[-1])
        xyz.append(read_xyz(parts[2:5], offset=offset))

    # Reindex from zero
    min_id = ids.min()
    ids -= min_id
    pids -= min_id

    # Use the smallest signed integer type that holds every id and pid.
    # (int16 only covers [-32768, 32767], and swc ids are not bounded by the
    # number of lines, so the line count is not a safe criterion.)
    lower = min(ids.min(), pids.min())
    upper = max(ids.max(), pids.max())
    dtype = np.int64
    for candidate in (np.int16, np.int32):
        limits = np.iinfo(candidate)
        if limits.min <= lower and upper <= limits.max:
            dtype = candidate
            break

    return {
        "id": ids.astype(dtype),
        "radius": radius,
        "pid": pids.astype(dtype),
        "xyz": xyz,
    }


def get_contents(swc_contents):
    """
    Strips the leading comment lines from the contents of an swc file and
    extracts the offset stored in a "# OFFSET" comment, if there is one.

    Parameters
    ----------
    swc_contents : list[str]
        Lines of an swc file.

    Returns
    -------
    tuple
        The lines starting at the first one that does not begin with "#"
        (empty if there is no such line) and the offset, which defaults to
        [0, 0, 0] when no "# OFFSET" line precedes the entries.

    """
    offset = [0, 0, 0]
    for i, line in enumerate(swc_contents):
        if line.startswith("# OFFSET"):
            parts = line.split()
            offset = read_xyz(parts[2:5])
        if not line.startswith("#"):
            return swc_contents[i:], offset

    # Empty input or comments only: there are no entries to return. Falling
    # through to "swc_contents[i:]" would hit an unbound "i" on empty input
    # and hand back the last comment line otherwise.
    return swc_contents[:0], offset


def read_from_local(path):
"""
Reads swc file stored at "path" on local machine.
Expand Down Expand Up @@ -150,7 +203,7 @@ def read_xyz(xyz, offset=[0, 0, 0]):
The (x,y,z) coordinates from an swc file.

"""
return tuple([float(xyz[i]) + offset[i] for i in range(3)])
return tuple([np.float32(xyz[i]) + offset[i] for i in range(3)])


def write(path, contents):
Expand Down
Loading