Skip to content

Commit

Permalink
upd : gcs intake (#27)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Jan 10, 2024
1 parent 62d0c90 commit 83c3e4e
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 104 deletions.
4 changes: 1 addition & 3 deletions src/deep_neurographs/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@


def extract_irreducible_graph(swc_dict, prune=True, prune_depth=16):
graph = swc_utils.file_to_graph(swc_dict)
graph = swc_utils.to_graph(swc_dict)
leafs, junctions = get_irreducibles(graph)
irreducible_edges, leafs = extract_irreducible_edges(
graph, leafs, junctions, swc_dict, prune=prune, prune_depth=prune_depth
)

# Check irreducility holds after pruning
if prune:
irreducible_edges, junctions = check_irreducibility(
junctions, irreducible_edges
Expand Down
220 changes: 149 additions & 71 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@

from google.cloud import storage

from deep_neurographs import swc_utils, utils
from deep_neurographs import graph_utils as gutils, swc_utils, utils
from deep_neurographs.neurograph import NeuroGraph
from deep_neurographs.swc_utils import parse_gcs_zip

N_PROPOSALS_PER_LEAF = 3
OPTIMIZE_ALIGNMENT = False
OPTIMIZE_DEPTH = 15
PRUNE = True
PRUNE_DEPTH = 16
SEARCH_RADIUS = 0
SIZE_THRESHOLD = 100
MIN_SIZE = 30
SMOOTH = False


Expand All @@ -43,7 +44,7 @@ def build_neurograph_from_local(
optimize_alignment=OPTIMIZE_ALIGNMENT,
optimize_depth=OPTIMIZE_DEPTH,
search_radius=SEARCH_RADIUS,
size_threshold=SIZE_THRESHOLD,
min_size=MIN_SIZE,
smooth=SMOOTH,
):
assert utils.xor(swc_dir, swc_paths), "Error: provide swc_dir or swc_paths"
Expand All @@ -61,7 +62,7 @@ def build_neurograph_from_local(
swc_paths=swc_paths,
prune=prune,
prune_depth=prune_depth,
size_threshold=size_threshold,
min_size=min_size,
smooth=smooth,
)
if search_radius > 0:
Expand All @@ -76,7 +77,7 @@ def build_neurograph_from_gcs_zips(
bucket_name,
cloud_path,
img_path=None,
size_threshold=SIZE_THRESHOLD,
min_size=MIN_SIZE,
n_proposals_per_leaf=N_PROPOSALS_PER_LEAF,
search_radius=SEARCH_RADIUS,
prune=PRUNE,
Expand All @@ -85,19 +86,52 @@ def build_neurograph_from_gcs_zips(
optimize_depth=OPTIMIZE_DEPTH,
smooth=SMOOTH,
):
neurograph = NeuroGraph(
"""
Builds a neurograph from a GCS bucket that contain of zips of swc files.
Parameters
----------
bucket_name : str
Name of GCS bucket where zips are stored.
cloud_path : str
Path within GCS bucket to directory containing zips.
img_path : str
Path to image stored GCS Bucket that swc files were generated from.
min_size : int
Minimum path length of swc files which are stored.
n_proposals_per_leaf : int
Number of edge proposals generated from each leaf node in an swc file.
search_radius : float
Maximum Euclidean length of an edge proposal.
prune : bool
Indication of whether to prune short branches.
prune_depth : int
Branches less than "prune_depth" microns are pruned if "prune" is
True.
optimize_alignment : bool
Indication of whether to optimize alignment of edge proposals to image
signal.
optimize_depth : int
Distance from each edge proposal end point that is search during
alignment optimization.
smooth : bool
Indication of whether to smooth branches from swc files.
Returns
-------
neurograph : NeuroGraph
Neurograph generated from zips of swc files stored in a GCS bucket.
"""
swc_dicts = download_gcs_zips(bucket_name, cloud_path, min_size=min_size)
neurograph = build_neurograph(
swc_dicts,
img_path=img_path,
optimize_alignment=optimize_alignment,
optimize_depth=optimize_depth,
)
neurograph = init_immutables_from_gcs_zips(
neurograph,
bucket_name,
cloud_path,
prune=prune,
prune_depth=prune_depth,
size_threshold=size_threshold,
smooth=smooth,
optimize_alignment=OPTIMIZE_ALIGNMENT,
optimize_depth=OPTIMIZE_DEPTH,
)
if search_radius > 0:
neurograph.generate_proposals(
Expand All @@ -113,7 +147,7 @@ def init_immutables_from_local(
swc_paths=None,
prune=PRUNE,
prune_depth=PRUNE_DEPTH,
size_threshold=SIZE_THRESHOLD,
min_size=MIN_SIZE,
smooth=SMOOTH,
):
swc_paths = get_paths(swc_dir) if swc_dir else swc_paths
Expand All @@ -131,52 +165,66 @@ def get_paths(swc_dir):
return paths


def init_immutables_from_gcs_zips(
neurograph,
def download_gcs_zips(
bucket_name,
cloud_path,
prune=PRUNE,
prune_depth=PRUNE_DEPTH,
size_threshold=SIZE_THRESHOLD,
smooth=SMOOTH,
min_size=0,
):
"""
Downloads swc files from zips stored in a GCS bucket.
Parameters
----------
bucket_name : str
Name of GCS bucket where zips are stored.
cloud_path : str
Path within GCS bucket to directory containing zips.
min_size : int
Minimum path length of swc files which are stored.
Returns
-------
swc_dicts : list
"""
# Initializations
storage_client = storage.Client()
bucket = storage_client.bucket(bucket_name)
zip_paths = list_gcs_filenames(bucket, cloud_path, ".zip")
n_swc_files = 2080791 # count_files_in_zips(bucket, zip_paths)
chunk_size = int(n_swc_files * 0.05)
print("# zip files:", len(zip_paths))
print(f"# swc files: {utils.reformat_number(n_swc_files)} \n\n")
chunk_size = int(len(zip_paths) * 0.1)
print(f"# zip files: {len(zip_paths)} \n\n", )

# Parse
cnt = 1
t0 = time()
t1 = time()
n_files_completed = 0
swc_dicts = []
print(f"-- Starting Multithread Reads with chunk_size={chunk_size} -- \n")
for path in zip_paths:
# Add to neurograph
swc_dicts = process_gcs_zip(bucket, path)
if smooth:
with concurrent.futures.ProcessPoolExecutor() as executor:
swc_dicts = list(executor.map(swc_utils.smooth, swc_dicts))

# Readout progress
n_files_completed += len(swc_dicts)
if n_files_completed > cnt * chunk_size:
report_runtimes(
n_swc_files,
n_files_completed,
chunk_size,
time() - t1,
time() - t0,
)
cnt += 1
for i, path in enumerate(zip_paths):
swc_dict_i = download_zip(bucket, path, min_size=min_size)
swc_dicts.extend(swc_dict_i)
if i > cnt * chunk_size:
report_runtimes(len(zip_paths), i, chunk_size, t0, t1)
t1 = time()
cnt += 1
break
t, unit = utils.time_writer(time() - t0)
print(f"Total Runtime: {round(t, 4)} {unit}")
return neurograph
print("# connected components:", len(swc_dicts))
print(f"Download Runtime: {round(t, 4)} {unit}")
return swc_dicts


def download_zip(bucket, zip_path, min_size=0):
zip_blob = bucket.blob(zip_path)
zip_content = zip_blob.download_as_bytes()
with ZipFile(BytesIO(zip_content)) as zip_file:
with ThreadPoolExecutor() as executor:
results = [
executor.submit(parse_gcs_zip, zip_file, path, min_size)
for path in list_files_in_gcs_zip(zip_content)
]
swc_dicts = [result.result() for result in as_completed(results)]
return swc_dicts


def count_files_in_zips(bucket, zip_paths):
Expand Down Expand Up @@ -208,37 +256,67 @@ def list_gcs_filenames(bucket, cloud_path, extension):
return [blob.name for blob in blobs if extension in blob.name]


def process_gcs_zip(bucket, zip_path):
# Get filenames
zip_blob = bucket.blob(zip_path)
zip_content = zip_blob.download_as_bytes()
swc_paths = list_files_in_gcs_zip(zip_content)

# Read files
t0 = time()
swc_dicts = [None] * len(swc_paths)
with ZipFile(BytesIO(zip_content)) as zip_file:
with ThreadPoolExecutor() as executor:
results = [
executor.submit(swc_utils.parse_gcs_zip, zip_file, path)
for path in swc_paths
]
for i, result_i in enumerate(as_completed(results)):
swc_dicts[i] = result_i.result()
return swc_dicts


def report_runtimes(
n_files, n_files_completed, chunk_size, chunk_runtime, total_runtime
n_files, n_files_completed, chunk_size, start, start_chunk,
):
runtime = time() - start
chunk_runtime = time() - start_chunk
n_files_remaining = n_files - n_files_completed
file_rate = chunk_runtime / chunk_size
eta = (total_runtime + n_files_remaining * file_rate) / 60
rate = chunk_runtime / chunk_size
eta = (runtime + n_files_remaining * rate) / 60
files_processed = f"{n_files_completed - chunk_size}-{n_files_completed}"
print(f"Completed: {round(100 * n_files_completed / n_files, 2)}%")
print(
f"Runtime for Files : {files_processed} {round(chunk_runtime, 4)} seconds"
f"Runtime for Zips {files_processed}: {round(chunk_runtime, 4)} seconds"
)
print(f"File Processing Rate: {file_rate} seconds")
print(f"Zip Processing Rate: {file_rate} seconds")
print(f"Approximate Total Runtime: {round(eta, 4)} minutes")
print("")


def build_neurograph(
swc_dicts,
img_path=None,
optimize_alignment=OPTIMIZE_ALIGNMENT,
optimize_depth=OPTIMIZE_DEPTH,
prune=PRUNE,
prune_depth=PRUNE_DEPTH,
smooth=SMOOTH
):
graph_list = build_graphs(swc_dicts, prune, prune_depth, smooth)
start_ids = get_start_ids(swc_dicts)
print("Total Runtime:", 1600 * t)
stop
neurograph = NeuroGraph(
img_path=img_path,
optimize_alignment=optimize_alignment,
optimize_depth=optimize_depth,
)

def build_graphs(swc_dicts, prune, prune_depth, smooth):
t0 = time()
graphs = [None] * len(swc_dicts)
for i, swc_dict in enumerate(swc_dicts):
graphs[i] = build_subgraph(swc_dict)
t = time() - t0
print(f"build_subgraphs(): {t} seconds")
return graphs


def build_subgraph(swc_dict):
graph = nx.Graph()
graph.add_edges_from(zip(swc_dict["id"][1:], swc_dict["pid"][1:]))
return graph


def get_start_ids(swc_dicts):
# runtime: ~ 1 minute
t0 = time()
node_ids = []
cnt = 0
for swc_dict in swc_dicts:
graph = swc_utils.to_graph(swc_dict)
leafs, junctions = gutils.get_irreducibles(graph)
node_ids.append(cnt)
cnt += len(leafs) + len(junctions)
return node_ids
Loading

0 comments on commit 83c3e4e

Please sign in to comment.