major upd : added gcs support (#25)
Co-authored-by: anna-grim <[email protected]>
anna-grim and anna-grim authored Jan 9, 2024
1 parent bd01b92 commit 5a535fd
Showing 6 changed files with 329 additions and 114 deletions.
3 changes: 1 addition & 2 deletions src/deep_neurographs/geometry_utils.py
@@ -60,7 +60,7 @@ def get_midpoint(xyz_1, xyz_2):

 # Smoothing
 def smooth_branch(xyz, s=None):
-    if xyz.shape[0] > 5:
+    if xyz.shape[0] > 8:
         t = np.linspace(0, 1, xyz.shape[0])
         spline_x, spline_y, spline_z = fit_spline(xyz, s=s)
         xyz = np.column_stack((spline_x(t), spline_y(t), spline_z(t)))
@@ -95,7 +95,6 @@ def fill_path(img, path, val=-1):
     for xyz in path:
         x, y, z = tuple(np.floor(xyz).astype(int))
         img[x - 1 : x + 2, y - 1 : y + 2, z - 1 : z + 2] = val
-        # img[x, y, z] = val
     return img


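The change raises the minimum branch length for smoothing: branches with eight or fewer points are now returned unchanged (previously the cutoff was five). A minimal sketch of the smoothing step, assuming fit_spline fits one scipy.interpolate.UnivariateSpline per coordinate (fit_spline itself is not shown in this diff):

import numpy as np
from scipy.interpolate import UnivariateSpline

def smooth_branch_sketch(xyz, s=None):
    # Skip short branches; too few points over-constrain a cubic spline.
    if xyz.shape[0] > 8:
        t = np.linspace(0, 1, xyz.shape[0])
        # One 1-D spline per coordinate, parameterized by t in [0, 1].
        splines = [UnivariateSpline(t, xyz[:, d], s=s, k=3) for d in range(3)]
        xyz = np.column_stack([spline(t) for spline in splines])
    return xyz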
22 changes: 11 additions & 11 deletions src/deep_neurographs/graph_utils.py
@@ -14,17 +14,6 @@
 from deep_neurographs import swc_utils, utils


-def get_irreducibles(graph):
-    leafs = []
-    junctions = []
-    for i in graph.nodes:
-        if graph.degree[i] == 1:
-            leafs.append(i)
-        elif graph.degree[i] > 2:
-            junctions.append(i)
-    return leafs, junctions
-
-
 def extract_irreducible_graph(swc_dict, prune=True, prune_depth=16):
     graph = swc_utils.file_to_graph(swc_dict)
     leafs, junctions = get_irreducibles(graph)
@@ -40,6 +29,17 @@ def extract_irreducible_graph(swc_dict, prune=True, prune_depth=16):
     return leafs, junctions, irreducible_edges


+def get_irreducibles(graph):
+    leafs = []
+    junctions = []
+    for i in graph.nodes:
+        if graph.degree[i] == 1:
+            leafs.append(i)
+        elif graph.degree[i] > 2:
+            junctions.append(i)
+    return leafs, junctions
+
+
 def extract_irreducible_edges(
     graph, leafs, junctions, swc_dict, prune=True, prune_depth=16
 ):
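get_irreducibles itself is unchanged; the commit only moves it below its caller. For reference, a toy networkx example of what it computes, where degree-1 nodes are leafs and nodes of degree greater than two are junctions:

import networkx as nx

# A path 0-1-2 with two extra branches at node 2.
graph = nx.Graph([(0, 1), (1, 2), (2, 3), (2, 4)])
leafs = [i for i in graph.nodes if graph.degree[i] == 1]
junctions = [i for i in graph.nodes if graph.degree[i] > 2]
print(leafs, junctions)  # [0, 3, 4] [2]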
262 changes: 203 additions & 59 deletions src/deep_neurographs/intake.py
@@ -9,98 +9,242 @@
 """

 import os
+import concurrent.futures

+from concurrent.futures import ThreadPoolExecutor, as_completed
+from google.cloud import storage
+from io import BytesIO
 from deep_neurographs import swc_utils, utils
 from deep_neurographs.neurograph import NeuroGraph
+from time import time
+from zipfile import ZipFile

+N_PROPOSALS_PER_LEAF = 3
+OPTIMIZE_ALIGNMENT = False
+OPTIMIZE_DEPTH = 15
+PRUNE = True
+PRUNE_DEPTH = 16
+SEARCH_RADIUS = 0
+SIZE_THRESHOLD = 100
+SMOOTH = False


 # --- Build graph ---
-def build_neurograph(
-    swc_dir,
-    anisotropy=[1.0, 1.0, 1.0],
+def build_neurograph_from_local(
+    swc_dir=None,
+    swc_paths=None,
+    img_patch_shape=None,
+    img_patch_origin=None,
     img_path=None,
-    size_threshold=40,
-    num_proposals=3,
-    search_radius=25.0,
-    prune=True,
-    prune_depth=16,
-    optimize_depth=15,
-    optimize_alignment=True,
     optimize_path=False,
-    origin=None,
-    shape=None,
-    smooth=True,
+    n_proposals_per_leaf=N_PROPOSALS_PER_LEAF,
+    prune=PRUNE,
+    prune_depth=PRUNE_DEPTH,
+    optimize_alignment=OPTIMIZE_ALIGNMENT,
+    optimize_depth=OPTIMIZE_DEPTH,
+    search_radius=SEARCH_RADIUS,
+    size_threshold=SIZE_THRESHOLD,
+    smooth=SMOOTH,
 ):
     """
     Builds a neurograph from a directory of swc files, where each swc
     represents a neuron and these neurons are assumed to be near each
     other.
     """
+    assert utils.xor(swc_dir, swc_paths), "Error: provide swc_dir or swc_paths"
     neurograph = NeuroGraph(
-        swc_dir,
         img_path=img_path,
         optimize_depth=optimize_depth,
         optimize_alignment=optimize_alignment,
         optimize_path=optimize_path,
-        origin=origin,
-        shape=shape,
+        origin=img_patch_origin,
+        shape=img_patch_shape,
     )
-    neurograph = init_immutables(
+    neurograph = init_immutables_from_local(
         neurograph,
-        anisotropy=anisotropy,
+        swc_dir=swc_dir,
+        swc_paths=swc_paths,
         prune=prune,
         prune_depth=prune_depth,
         size_threshold=size_threshold,
         smooth=smooth,
     )
     if search_radius > 0:
         neurograph.generate_proposals(
-            num_proposals=num_proposals, search_radius=search_radius
+            n_proposals_per_leaf=n_proposals_per_leaf,
+            search_radius=search_radius
         )
     return neurograph
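Note that exactly one of swc_dir or swc_paths must be supplied (enforced by the xor assert), and with the new default SEARCH_RADIUS = 0 no proposals are generated unless a positive search_radius is passed. A hypothetical call (path and radius are illustrative):

from deep_neurographs import intake

neurograph = intake.build_neurograph_from_local(
    swc_dir="/path/to/swcs",  # illustrative path
    search_radius=10.0,       # must be > 0 to generate proposals
    smooth=True,
)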


-def init_immutables(
+def build_neurograph_from_gcs_zips(
+    bucket_name,
+    cloud_path,
+    img_path=None,
+    size_threshold=SIZE_THRESHOLD,
+    n_proposals_per_leaf=N_PROPOSALS_PER_LEAF,
+    search_radius=SEARCH_RADIUS,
+    prune=PRUNE,
+    prune_depth=PRUNE_DEPTH,
+    optimize_alignment=OPTIMIZE_ALIGNMENT,
+    optimize_depth=OPTIMIZE_DEPTH,
+    smooth=SMOOTH,
+):
+    neurograph = NeuroGraph(
+        img_path=img_path,
+        optimize_alignment=optimize_alignment,
+        optimize_depth=optimize_depth,
+    )
+    neurograph = init_immutables_from_gcs_zips(
+        neurograph,
+        bucket_name,
+        cloud_path,
+        prune=prune,
+        prune_depth=prune_depth,
+        size_threshold=size_threshold,
+        smooth=smooth,
+    )
+    if search_radius > 0:
+        neurograph.generate_proposals(
+            n_proposals_per_leaf=n_proposals_per_leaf,
+            search_radius=search_radius
+        )
+    return neurograph
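The GCS entry point mirrors the local one but reads zipped swc files from a bucket; it assumes google-cloud-storage is installed and authenticated. A hypothetical call (bucket name and prefix are illustrative):

from deep_neurographs import intake

neurograph = intake.build_neurograph_from_gcs_zips(
    "my-bucket",   # illustrative bucket name
    "swc_zips/",   # illustrative prefix containing the *.zip blobs
    search_radius=10.0,
)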


+def init_immutables_from_local(
     neurograph,
-    anisotropy=[1.0, 1.0, 1.0],
-    prune=True,
-    prune_depth=16,
-    size_threshold=40,
-    smooth=True,
+    swc_dir=None,
+    swc_paths=None,
+    prune=PRUNE,
+    prune_depth=PRUNE_DEPTH,
+    size_threshold=SIZE_THRESHOLD,
+    smooth=SMOOTH,
 ):
     """
     To do...
     """
+    swc_paths = get_paths(swc_dir) if swc_dir else swc_paths
+    for path in swc_paths:
+        neurograph.ingest_swc_from_local(
+            path,
+            prune=True,
+            prune_depth=16,
+            smooth=smooth,
+        )
+    return neurograph


-    for path in get_paths(neurograph.path):
-        swc_id = get_id(path)
-        swc_dict = swc_utils.parse(
+def get_paths(swc_dir):
+    swc_paths = []
+    for f in utils.listdir(swc_dir, ext=".swc"):
+        swc_paths.append(os.path.join(swc_dir, f))
+    return swc_paths


+def init_immutables_from_gcs_zips(
+    neurograph,
+    bucket_name,
+    cloud_path,
+    prune=PRUNE,
+    prune_depth=PRUNE_DEPTH,
+    size_threshold=SIZE_THRESHOLD,
+    smooth=SMOOTH,
+):
+    # Initializations
+    storage_client = storage.Client()
+    bucket = storage_client.bucket(bucket_name)
+    zip_paths = list_gcs_filenames(bucket, cloud_path, ".zip")
+    n_swc_files = 2080791  # count_files_in_zips(bucket, zip_paths)
+    chunk_size = int(n_swc_files * 0.05)
+    print("# zip files:", len(zip_paths))
+    print(f"# swc files: {utils.reformat_number(n_swc_files)} \n\n")
+
+    # Parse
+    cnt = 1
+    t0 = time()
+    t1 = time()
+    n_files_completed = 0
+    print(f"-- Starting Multithread Reads with chunk_size={chunk_size} -- \n")
+    for path in zip_paths:
+        # Add to neurograph
+        swc_dicts = process_gcs_zip(
+            bucket,
             path,
-            anisotropy=anisotropy,
-            bbox=neurograph.bbox,
-            img_shape=neurograph.shape,
         )
-        if len(swc_dict["xyz"]) < size_threshold:
-            continue
         if smooth:
-            swc_dict = swc_utils.smooth(swc_dict)
-        neurograph.generate_immutables(
-            swc_id, swc_dict, prune=prune, prune_depth=prune_depth
-        )
+            with concurrent.futures.ProcessPoolExecutor() as executor:
+                swc_dicts = list(executor.map(swc_utils.smooth, swc_dicts))
+
+        # Readout progress
+        n_files_completed += len(swc_dicts)
+        if n_files_completed > cnt * chunk_size:
+            report_runtimes(
+                n_swc_files,
+                n_files_completed,
+                chunk_size,
+                time() - t1,
+                time() - t0,
+            )
+            cnt += 1
+            t1 = time()
+    t, unit = utils.time_writer(time() - t0)
+    print(f"Total Runtime: {round(t, 4)} {unit}")
     return neurograph


-def get_paths(path_or_list):
-    if type(path_or_list) == str:
-        paths = []
-        for f in utils.listdir(path_or_list, ext=".swc"):
-            paths.append(os.path.join(path_or_list, f))
-        return paths
-    elif type(path_or_list) == list:
-        return path_or_list
+def count_files_in_zips(bucket, zip_paths):
+    t0 = time()
+    file_cnt = 0
+    for zip_path in zip_paths:
+        zip_blob = bucket.blob(zip_path)
+        zip_content = zip_blob.download_as_bytes()
+        file_paths = list_files_in_gcs_zip(zip_content)
+        file_cnt += len(file_paths)
+    return file_cnt


-def get_id(path):
-    filename = path.split("/")[-1]
-    return filename.replace(".0.swc", "")
+def list_files_in_gcs_zip(zip_content):
+    """
+    Lists all files in a zip file stored in a GCS bucket.
+    """
+    with ZipFile(BytesIO(zip_content), 'r') as zip_file:
+        return zip_file.namelist()


+def list_gcs_filenames(bucket, cloud_path, extension):
+    """
+    Lists all files in a GCS bucket with the given extension.
+    """
+    blobs = bucket.list_blobs(prefix=cloud_path)
+    return [blob.name for blob in blobs if extension in blob.name]


+def process_gcs_zip(bucket, zip_path):
+    # Get filenames
+    zip_blob = bucket.blob(zip_path)
+    zip_content = zip_blob.download_as_bytes()
+    swc_paths = list_files_in_gcs_zip(zip_content)
+
+    # Read files
+    t0 = time()
+    swc_dicts = [None] * len(swc_paths)
+    with ZipFile(BytesIO(zip_content)) as zip_file:
+        with ThreadPoolExecutor() as executor:
+            results = [
+                executor.submit(swc_utils.parse_gcs_zip, zip_file, path)
+                for path in swc_paths
+            ]
+            for i, result_i in enumerate(as_completed(results)):
+                swc_dicts[i] = result_i.result()
+    return swc_dicts
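process_gcs_zip downloads one zip into memory and parses its members with a thread pool; since as_completed yields futures in completion order, swc_dicts is filled completely but not in the same order as swc_paths. The same pattern in isolation, using only the standard library (read_zip_members and its use of zip_file.read are illustrative stand-ins for swc_utils.parse_gcs_zip):

from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from zipfile import ZipFile

def read_zip_members(zip_content):
    # Read every member of an in-memory zip concurrently.
    with ZipFile(BytesIO(zip_content)) as zip_file:
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(zip_file.read, name)
                for name in zip_file.namelist()
            ]
            # Results arrive in completion order, not submission order.
            return [f.result() for f in as_completed(futures)]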


+def report_runtimes(
+    n_files,
+    n_files_completed,
+    chunk_size,
+    chunk_runtime,
+    total_runtime,
+):
+    n_files_remaining = n_files - n_files_completed
+    file_rate = chunk_runtime / chunk_size
+    eta = (total_runtime + n_files_remaining * file_rate) / 60
+    files_processed = f"{n_files_completed - chunk_size}-{n_files_completed}"
+    print(f"Completed: {round(100 * n_files_completed / n_files, 2)}%")
+    print(f"Runtime for Files : {files_processed} {round(chunk_runtime, 4)} seconds")
+    print(f"File Processing Rate: {file_rate} seconds")
+    print(f"Approximate Total Runtime: {round(eta, 4)} minutes")
+    print("")