major upd : added gcs support #25

Merged
merged 1 commit into from Jan 9, 2024
3 changes: 1 addition & 2 deletions src/deep_neurographs/geometry_utils.py
@@ -60,7 +60,7 @@ def get_midpoint(xyz_1, xyz_2):

# Smoothing
def smooth_branch(xyz, s=None):
-    if xyz.shape[0] > 5:
+    if xyz.shape[0] > 8:
        t = np.linspace(0, 1, xyz.shape[0])
        spline_x, spline_y, spline_z = fit_spline(xyz, s=s)
        xyz = np.column_stack((spline_x(t), spline_y(t), spline_z(t)))
@@ -95,7 +95,6 @@ def fill_path(img, path, val=-1):
    for xyz in path:
        x, y, z = tuple(np.floor(xyz).astype(int))
        img[x - 1 : x + 2, y - 1 : y + 2, z - 1 : z + 2] = val
-        # img[x, y, z] = val
    return img


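For context, smooth_branch now requires more than eight points before it will fit and resample a spline. A minimal standalone sketch of the same idea, assuming fit_spline wraps one univariate smoothing spline per coordinate (that helper is not part of this diff):

import numpy as np
from scipy.interpolate import UnivariateSpline

def smooth_branch_sketch(xyz, s=None):
    # Guard against short branches: a cubic smoothing spline is poorly
    # constrained by only a handful of points, hence the size check.
    if xyz.shape[0] > 8:
        t = np.linspace(0, 1, xyz.shape[0])
        spline_x = UnivariateSpline(t, xyz[:, 0], s=s, k=3)
        spline_y = UnivariateSpline(t, xyz[:, 1], s=s, k=3)
        spline_z = UnivariateSpline(t, xyz[:, 2], s=s, k=3)
        xyz = np.column_stack((spline_x(t), spline_y(t), spline_z(t)))
    return xyz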
22 changes: 11 additions & 11 deletions src/deep_neurographs/graph_utils.py
@@ -14,17 +14,6 @@
from deep_neurographs import swc_utils, utils


-def get_irreducibles(graph):
-    leafs = []
-    junctions = []
-    for i in graph.nodes:
-        if graph.degree[i] == 1:
-            leafs.append(i)
-        elif graph.degree[i] > 2:
-            junctions.append(i)
-    return leafs, junctions
-
-
def extract_irreducible_graph(swc_dict, prune=True, prune_depth=16):
    graph = swc_utils.file_to_graph(swc_dict)
    leafs, junctions = get_irreducibles(graph)
@@ -40,6 +29,17 @@ def extract_irreducible_graph(swc_dict, prune=True, prune_depth=16):
    return leafs, junctions, irreducible_edges


+def get_irreducibles(graph):
+    leafs = []
+    junctions = []
+    for i in graph.nodes:
+        if graph.degree[i] == 1:
+            leafs.append(i)
+        elif graph.degree[i] > 2:
+            junctions.append(i)
+    return leafs, junctions
+
+
def extract_irreducible_edges(
    graph, leafs, junctions, swc_dict, prune=True, prune_depth=16
):
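The relocated get_irreducibles is plain networkx, so its behavior is easy to check on a toy graph (the import path below is assumed from the file layout):

import networkx as nx

from deep_neurographs.graph_utils import get_irreducibles

# Path 0-1-2 plus two extra edges at node 2: node 2 becomes a junction
# (degree 3), nodes 0, 3, and 4 are leafs, and node 1 is neither.
graph = nx.Graph([(0, 1), (1, 2), (2, 3), (2, 4)])
leafs, junctions = get_irreducibles(graph)
print(leafs)      # [0, 3, 4]
print(junctions)  # [2]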
262 changes: 203 additions & 59 deletions src/deep_neurographs/intake.py
@@ -9,98 +9,242 @@
"""

import os
+import concurrent.futures

+from concurrent.futures import ThreadPoolExecutor, as_completed
+from google.cloud import storage
+from io import BytesIO
from deep_neurographs import swc_utils, utils
from deep_neurographs.neurograph import NeuroGraph
+from time import time
+from zipfile import ZipFile

+N_PROPOSALS_PER_LEAF = 3
+OPTIMIZE_ALIGNMENT = False
+OPTIMIZE_DEPTH = 15
+PRUNE = True
+PRUNE_DEPTH = 16
+SEARCH_RADIUS = 0
+SIZE_THRESHOLD = 100
+SMOOTH = False


+# --- Build graph ---
-def build_neurograph(
-    swc_dir,
-    anisotropy=[1.0, 1.0, 1.0],
+def build_neurograph_from_local(
+    swc_dir=None,
+    swc_paths=None,
+    img_patch_shape=None,
+    img_patch_origin=None,
    img_path=None,
-    size_threshold=40,
-    num_proposals=3,
-    search_radius=25.0,
-    prune=True,
-    prune_depth=16,
-    optimize_depth=15,
-    optimize_alignment=True,
    optimize_path=False,
-    origin=None,
-    shape=None,
-    smooth=True,
+    n_proposals_per_leaf=N_PROPOSALS_PER_LEAF,
+    prune=PRUNE,
+    prune_depth=PRUNE_DEPTH,
+    optimize_alignment=OPTIMIZE_ALIGNMENT,
+    optimize_depth=OPTIMIZE_DEPTH,
+    search_radius=SEARCH_RADIUS,
+    size_threshold=SIZE_THRESHOLD,
+    smooth=SMOOTH,
):
"""
Builds a neurograph from a directory of swc files, where each swc
represents a neuron and these neurons are assumed to be near each
other.

"""
+    assert utils.xor(swc_dir, swc_paths), "Error: provide swc_dir or swc_paths"
    neurograph = NeuroGraph(
        swc_dir,
        img_path=img_path,
        optimize_depth=optimize_depth,
        optimize_alignment=optimize_alignment,
        optimize_path=optimize_path,
-        origin=origin,
-        shape=shape,
+        origin=img_patch_origin,
+        shape=img_patch_shape,
    )
-    neurograph = init_immutables(
+    neurograph = init_immutables_from_local(
        neurograph,
-        anisotropy=anisotropy,
+        swc_dir=swc_dir,
+        swc_paths=swc_paths,
        prune=prune,
        prune_depth=prune_depth,
        size_threshold=size_threshold,
        smooth=smooth,
    )
    if search_radius > 0:
        neurograph.generate_proposals(
-            num_proposals=num_proposals, search_radius=search_radius
+            n_proposals_per_leaf=n_proposals_per_leaf,
+            search_radius=search_radius,
        )
    return neurograph


-def init_immutables(
+def build_neurograph_from_gcs_zips(
+    bucket_name,
+    cloud_path,
+    img_path=None,
+    size_threshold=SIZE_THRESHOLD,
+    n_proposals_per_leaf=N_PROPOSALS_PER_LEAF,
+    search_radius=SEARCH_RADIUS,
+    prune=PRUNE,
+    prune_depth=PRUNE_DEPTH,
+    optimize_alignment=OPTIMIZE_ALIGNMENT,
+    optimize_depth=OPTIMIZE_DEPTH,
+    smooth=SMOOTH,
+):
+    neurograph = NeuroGraph(
+        img_path=img_path,
+        optimize_alignment=optimize_alignment,
+        optimize_depth=optimize_depth,
+    )
+    neurograph = init_immutables_from_gcs_zips(
+        neurograph,
+        bucket_name,
+        cloud_path,
+        prune=prune,
+        prune_depth=prune_depth,
+        size_threshold=size_threshold,
+        smooth=smooth,
+    )
+    if search_radius > 0:
+        neurograph.generate_proposals(
+            n_proposals_per_leaf=n_proposals_per_leaf,
+            search_radius=search_radius,
+        )
+    return neurograph


+def init_immutables_from_local(
    neurograph,
-    anisotropy=[1.0, 1.0, 1.0],
-    prune=True,
-    prune_depth=16,
-    size_threshold=40,
-    smooth=True,
+    swc_dir=None,
+    swc_paths=None,
+    prune=PRUNE,
+    prune_depth=PRUNE_DEPTH,
+    size_threshold=SIZE_THRESHOLD,
+    smooth=SMOOTH,
):
"""
To do...
"""
+    swc_paths = get_paths(swc_dir) if swc_dir else swc_paths
+    for path in swc_paths:
+        neurograph.ingest_swc_from_local(
+            path,
+            prune=prune,
+            prune_depth=prune_depth,
+            smooth=smooth,
+        )
+    return neurograph


-    for path in get_paths(neurograph.path):
-        swc_id = get_id(path)
-        swc_dict = swc_utils.parse(
+def get_paths(swc_dir):
+    swc_paths = []
+    for f in utils.listdir(swc_dir, ext=".swc"):
+        swc_paths.append(os.path.join(swc_dir, f))
+    return swc_paths


+def init_immutables_from_gcs_zips(
+    neurograph,
+    bucket_name,
+    cloud_path,
+    prune=PRUNE,
+    prune_depth=PRUNE_DEPTH,
+    size_threshold=SIZE_THRESHOLD,
+    smooth=SMOOTH,
+):
+    # Initializations
+    storage_client = storage.Client()
+    bucket = storage_client.bucket(bucket_name)
+    zip_paths = list_gcs_filenames(bucket, cloud_path, ".zip")
+    n_swc_files = 2080791  # count_files_in_zips(bucket, zip_paths)
+    chunk_size = int(n_swc_files * 0.05)
+    print("# zip files:", len(zip_paths))
+    print(f"# swc files: {utils.reformat_number(n_swc_files)} \n\n")
+
+    # Parse
+    cnt = 1
+    t0 = time()
+    t1 = time()
+    n_files_completed = 0
+    print(f"-- Starting Multithread Reads with chunk_size={chunk_size} -- \n")
+    for path in zip_paths:
+        # Add to neurograph
+        swc_dicts = process_gcs_zip(
+            bucket,
            path,
-            anisotropy=anisotropy,
-            bbox=neurograph.bbox,
-            img_shape=neurograph.shape,
        )
-        if len(swc_dict["xyz"]) < size_threshold:
-            continue
-        if smooth:
-            swc_dict = swc_utils.smooth(swc_dict)
-        neurograph.generate_immutables(
-            swc_id, swc_dict, prune=prune, prune_depth=prune_depth
-        )
+        with concurrent.futures.ProcessPoolExecutor() as executor:
+            swc_dicts = list(executor.map(swc_utils.smooth, swc_dicts))

+        # Readout progress
+        n_files_completed += len(swc_dicts)
+        if n_files_completed > cnt * chunk_size:
+            report_runtimes(
+                n_swc_files,
+                n_files_completed,
+                chunk_size,
+                time() - t1,
+                time() - t0,
+            )
+            cnt += 1
+            t1 = time()
+    t, unit = utils.time_writer(time() - t0)
+    print(f"Total Runtime: {round(t, 4)} {unit}")
+    return neurograph


-def get_paths(path_or_list):
-    if type(path_or_list) == str:
-        paths = []
-        for f in utils.listdir(path_or_list, ext=".swc"):
-            paths.append(os.path.join(path_or_list, f))
-        return paths
-    elif type(path_or_list) == list:
-        return path_or_list
+def count_files_in_zips(bucket, zip_paths):
+    t0 = time()
+    file_cnt = 0
+    for zip_path in zip_paths:
+        zip_blob = bucket.blob(zip_path)
+        zip_content = zip_blob.download_as_bytes()
+        file_paths = list_files_in_gcs_zip(zip_content)
+        file_cnt += len(file_paths)
+    return file_cnt


+def list_files_in_gcs_zip(zip_content):
+    """
+    Lists all files in a zip file stored in a GCS bucket.
+
-def get_id(path):
-    filename = path.split("/")[-1]
-    return filename.replace(".0.swc", "")
+    """
+    with ZipFile(BytesIO(zip_content), "r") as zip_file:
+        return zip_file.namelist()


+def list_gcs_filenames(bucket, cloud_path, extension):
+    """
+    Lists all files in a GCS bucket with the given extension.
+
+    """
+    blobs = bucket.list_blobs(prefix=cloud_path)
+    return [blob.name for blob in blobs if extension in blob.name]


+def process_gcs_zip(bucket, zip_path):
+    # Get filenames
+    zip_blob = bucket.blob(zip_path)
+    zip_content = zip_blob.download_as_bytes()
+    swc_paths = list_files_in_gcs_zip(zip_content)
+
+    # Read files
+    t0 = time()
+    swc_dicts = [None] * len(swc_paths)
+    with ZipFile(BytesIO(zip_content)) as zip_file:
+        with ThreadPoolExecutor() as executor:
+            results = [
+                executor.submit(swc_utils.parse_gcs_zip, zip_file, path)
+                for path in swc_paths
+            ]
+            for i, result_i in enumerate(as_completed(results)):
+                swc_dicts[i] = result_i.result()
+    return swc_dicts


+def report_runtimes(
+    n_files,
+    n_files_completed,
+    chunk_size,
+    chunk_runtime,
+    total_runtime,
+):
+    n_files_remaining = n_files - n_files_completed
+    file_rate = chunk_runtime / chunk_size
+    eta = (total_runtime + n_files_remaining * file_rate) / 60
+    files_processed = f"{n_files_completed - chunk_size}-{n_files_completed}"
+    print(f"Completed: {round(100 * n_files_completed / n_files, 2)}%")
+    print(f"Runtime for Files {files_processed}: {round(chunk_runtime, 4)} seconds")
+    print(f"File Processing Rate: {round(file_rate, 4)} seconds per file")
+    print(f"Approximate Total Runtime: {round(eta, 4)} minutes")
+    print("")