diff --git a/src/deep_neurographs/graph_utils.py b/src/deep_neurographs/graph_utils.py index 3ad9e21..fae3b5b 100644 --- a/src/deep_neurographs/graph_utils.py +++ b/src/deep_neurographs/graph_utils.py @@ -230,6 +230,7 @@ def upd_branch_endpoint(edges, key, old_xyz, new_xyz): # -- attribute utils -- def __init_edge_attrs(swc_dict, i): + #print(len(swc_dict["radius"]), i) return {"radius": [swc_dict["radius"][i]], "xyz": [swc_dict["xyz"][i]]} diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index 94bf551..87581ce 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -55,7 +55,9 @@ def build_neurograph_from_local( assert swc_dir or swc_paths, "Provide swc_dir or swc_paths!" bbox = utils.get_bbox(img_patch_origin, img_patch_shape) paths = get_paths(swc_dir) if swc_dir else swc_paths + t0 = time() swc_dicts = process_local_paths(paths, min_size, bbox=bbox) + print(f"build_neurograph_from_local(): {time() - t0} seconds") # Build neurograph t0 = time() diff --git a/src/deep_neurographs/swc_utils.py b/src/deep_neurographs/swc_utils.py index c9b03e6..f468229 100644 --- a/src/deep_neurographs/swc_utils.py +++ b/src/deep_neurographs/swc_utils.py @@ -13,6 +13,7 @@ import networkx as nx import numpy as np from concurrent.futures import ThreadPoolExecutor, as_completed +from time import time from zipfile import ZipFile from deep_neurographs import geometry_utils @@ -49,17 +50,21 @@ def process_gsc_zip(bucket, zip_path, min_size=0): def parse_local_swc(path, bbox=None, min_size=0): contents = read_from_local(path) - swc_dict = parse(contents) if len(contents) > min_size else {"id": [-1]} + parse_bool = len(contents) > min_size + if parse_bool: + swc_dict = parse(contents, bbox=bbox) if bbox else fast_parse(contents) + else: + swc_dict = {"id": [-1]} return utils.get_swc_id(path), swc_dict def parse_gcs_zip(zip_file, path, min_size=0): contents = read_from_gcs_zip(zip_file, path) - swc_dict = parse(contents) if len(contents) > min_size else {"id": [-1]} + swc_dict = fast_parse(contents) if len(contents) > min_size else {"id": [-1]} return utils.get_swc_id(path), swc_dict -def parse(swc_contents, bbox=None): +def parse(contents, bbox=None): """ Parses an swc file to extract the contents which is stored in a dict. Note that node_ids from swc are refactored to index from 0 to n-1 where n is @@ -79,26 +84,21 @@ def parse(swc_contents, bbox=None): min_id = np.inf offset = [0, 0, 0] swc_dict = {"id": [], "radius": [], "pid": [], "xyz": []} - for line in swc_contents: - if line.startswith("# OFFSET"): - parts = line.split() - offset = read_xyz(parts[2:5]) - if not line.startswith("#"): - parts = line.split() - xyz = read_xyz(parts[2:5], offset=offset) - if bbox: - if not utils.is_contained(bbox, xyz): - break - - swc_dict["id"].append(int(parts[0])) - swc_dict["radius"].append(float(parts[-2])) - swc_dict["pid"].append(int(parts[-1])) - swc_dict["xyz"].append(xyz) - if swc_dict["id"][-1] < min_id: - min_id = swc_dict["id"][-1] + contents, swc_id = get_contents(contents) + for line in contents: + parts = line.split() + xyz = read_xyz(parts[2:5], offset=offset) + if bbox: + if not utils.is_contained(bbox, xyz): + break + swc_dict["id"].append(int(parts[0])) + swc_dict["radius"].append(float(parts[-2])) + swc_dict["pid"].append(int(parts[-1])) + swc_dict["xyz"].append(xyz) + if swc_dict["id"][-1] < min_id: + min_id = swc_dict["id"][-1] # Reindex from zero - swc_dict["radius"] = np.array(swc_dict["radius"]) for i in range(len(swc_dict["id"])): swc_dict["id"][i] -= min_id swc_dict["pid"][i] -= min_id @@ -106,6 +106,59 @@ def parse(swc_contents, bbox=None): return swc_dict if len(swc_dict["id"]) > 1 else {"id": [-1]} +def fast_parse(contents): + """ + Parses an swc file to extract the contents which is stored in a dict. Note + that node_ids from swc are refactored to index from 0 to n-1 where n is + the number of entries in the swc file. + + Parameters + ---------- + path : str + Path to an swc file. + ... + + Returns + ------- + ... + + """ + contents, offset = get_contents(contents) + dtype = np.int16 if len(contents) < 2 ** 16 else np.int32 + swc_dict = { + "id": np.zeros((len(contents)), dtype=dtype), + "radius": np.zeros((len(contents)), dtype=np.float16), + "pid": np.zeros((len(contents)), dtype=dtype), + "xyz": [] + } + + min_id = np.inf + for i, line in enumerate(contents): + parts = line.split() + xyz = read_xyz(parts[2:5], offset=offset) + swc_dict["id"][i] = int(parts[0]) + swc_dict["radius"][i] = float(parts[-2]) + swc_dict["pid"][i] = int(parts[-1]) + swc_dict["xyz"].append(xyz) + + # Reindex from zero + min_id = np.min(swc_dict["id"]) + swc_dict["id"] -= min_id + swc_dict["pid"] -= min_id + return swc_dict + + +def get_contents(swc_contents): + offset = [0, 0, 0] + for i, line in enumerate(swc_contents): + if line.startswith("# OFFSET"): + parts = line.split() + offset = read_xyz(parts[2:5]) + if not line.startswith("#"): + break + return swc_contents[i:], offset + + def read_from_local(path): """ Reads swc file stored at "path" on local machine. @@ -150,7 +203,7 @@ def read_xyz(xyz, offset=[0, 0, 0]): The (x,y,z) coordinates from an swc file. """ - return tuple([float(xyz[i]) + offset[i] for i in range(3)]) + return tuple([np.float32(xyz[i]) + offset[i] for i in range(3)]) def write(path, contents):