Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: removed trim option, exact path length threshold, mark merges #259

Merged
merged 1 commit into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/deep_neurographs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ class GraphConfig:
smooth_bool : bool, optional
Indication of whether to smooth branches in graph. The default is
True.
trim_depth : float, optional
Maximum path length (in microns) to trim from all branches. The
default is 5.
trim_endpoints_bool : bool, optional
Indication of whether to endpoints of branches with exactly one
proposal. The default is True.
Expand All @@ -67,7 +64,6 @@ class GraphConfig:
remove_doubles_bool: bool = False
search_radius: float = 20.0
smooth_bool: bool = True
trim_depth: float = 5.0
trim_endpoints_bool: bool = True


Expand Down
102 changes: 42 additions & 60 deletions src/deep_neurographs/groundtruth_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
MIN_INTERSECTION = 10


def init_targets(pred_graph, target_graph, strict=True):
def init_targets(pred_graph, target_graph):
"""
Initializes ground truth for edge proposals.

Expand All @@ -33,62 +33,53 @@ def init_targets(pred_graph, target_graph, strict=True):
Graph built from ground truth swc files.
pred_graph : NeuroGraph
Graph build from predicted swc files.
strict : bool, optional
Indication if whether target edges should be determined by using
stricter criteria that checks if proposals are reasonably well
aligned. The default is True.

Returns
-------
target_edges : set
Edge proposals that machine learning model learns to accept.
set
Proposals that machine learning model learns to accept.

"""
# Initializations
valid_proposals = get_valid_proposals(target_graph, pred_graph)
lengths = [pred_graph.proposal_length(e) for e in valid_proposals]
proposals = get_valid_proposals(target_graph, pred_graph)
lengths = [pred_graph.proposal_length(p) for p in proposals]

# Add best simple edges
target_edges = set()
gt_accepts = set()
graph = pred_graph.copy_graph()
for i in np.argsort(lengths):
edge = valid_proposals[i]
created_cycle, _ = gutil.creates_cycle(graph, tuple(edge))
if not created_cycle:
graph.add_edges_from([edge])
target_edges.add(edge)
return target_edges
for idx in np.argsort(lengths):
i, j = tuple(proposals[idx])
if not nx.has_path(graph, i, j):
graph.add_edge(i, j)
gt_accepts.add(proposals[idx])
return gt_accepts


def get_valid_proposals(target_graph, pred_graph):
# Initializations
valid_proposals = list()
kdtree = target_graph.get_kdtree()
invalid_ids, node_to_target = unaligned_components(
aligned_fragment_ids, node_to_target = find_aligned_fragments(
target_graph, pred_graph, kdtree
)

# Check whether aligned to same/adjacent target edges (i.e. valid)
for edge in pred_graph.proposals:
# Filter invalid and proposals btw different components
i, j = tuple(edge)
invalid_i = pred_graph.nodes[i]["swc_id"] in invalid_ids
invalid_j = pred_graph.nodes[j]["swc_id"] in invalid_ids
if invalid_i or invalid_j:
continue
elif node_to_target[i] != node_to_target[j]:
continue

# Check whether proposal is valid
target_id = node_to_target[i]
if is_valid(target_graph, pred_graph, kdtree, target_id, edge):
valid_proposals.append(edge)
# Check whether aligned to same/adjacent target edges
valid_proposals = list()
for p in pred_graph.proposals:
i, j = tuple(p)
is_aligned_i = pred_graph.nodes[i]["swc_id"] in aligned_fragment_ids
is_aligned_j = pred_graph.nodes[j]["swc_id"] in aligned_fragment_ids
if is_aligned_i and is_aligned_j:
if node_to_target[i] == node_to_target[j]:
# Check whether proposal is valid
target_id = node_to_target[i]
if is_valid(target_graph, pred_graph, kdtree, target_id, p):
valid_proposals.append(p)
return valid_proposals


def unaligned_components(target_graph, pred_graph, kdtree):
def find_aligned_fragments(target_graph, pred_graph, kdtree):
"""
Detects connected components in "pred_graph" that are unaligned to a
Detects connected components in "pred_graph" that are aligned to some
connected component in "target_graph".

Parameters
Expand All @@ -100,31 +91,30 @@ def unaligned_components(target_graph, pred_graph, kdtree):

Returns
-------
invalid_ids : set
valid_ids : set
IDs in ""pred_graph" that correspond to connected components that
are unaligned to a connected component in "target_graph".
are aligned to some connected component in "target_graph".
node_to_target : dict
Mapping between nodes and target ids.

"""
invalid_ids = set()
valid_ids = set()
node_to_target = dict()
for component in nx.connected_components(pred_graph):
for nodes in nx.connected_components(pred_graph):
aligned, target_id = is_component_aligned(
target_graph, pred_graph, component, kdtree
target_graph, pred_graph, nodes, kdtree
)
if not aligned:
i = util.sample_once(component)
invalid_ids.add(pred_graph.nodes[i]["swc_id"])
else:
node_to_target = upd_dict(node_to_target, component, target_id)
return invalid_ids, node_to_target
if aligned:
i = util.sample_once(nodes)
valid_ids.add(pred_graph.nodes[i]["swc_id"])
node_to_target = upd_dict(node_to_target, nodes, target_id)
return valid_ids, node_to_target


def is_component_aligned(target_graph, pred_graph, component, kdtree):
def is_component_aligned(target_graph, pred_graph, nodes, kdtree):
"""
Determines whether the connected component defined by "node_subset" is
close to a component in "target_graph". This routine iterates over
Determines whether the connected component "nodes" is
close to some component in "target_graph". This routine iterates over
"node_subset" and projects each node onto "target_graph", then
computes the projection distance. If (on average) each node in
"node_subset" is less 3.5 microns from a component in the ground truth,
Expand All @@ -142,13 +132,13 @@ def is_component_aligned(target_graph, pred_graph, component, kdtree):
Returns
-------
bool
Indication of whether "component" is aligned to a connected
Indication of whether connected component "nodes" is aligned to a connected
component in "target_graph".

"""
# Compute distances
dists = defaultdict(list)
for edge in pred_graph.subgraph(component).edges:
for edge in pred_graph.subgraph(nodes).edges:
for xyz in pred_graph.edges[edge]["xyz"]:
hat_xyz = geometry.kdtree_query(kdtree, xyz)
hat_swc_id = target_graph.xyz_to_swc(hat_xyz)
Expand Down Expand Up @@ -276,14 +266,6 @@ def is_adjacent_aligned(hat_branch_i, hat_branch_j, xyz_i, xyz_j):


# -- util --
def upd_dict_cnts(my_dict, key):
if key in my_dict.keys():
my_dict[key] += 1
else:
my_dict[key] = 1
return my_dict


def orient_branch(branch_i, branch_j):
"""
Flips branches so that "all(branch_i[0] == branch_j[0])" is True.
Expand Down
13 changes: 5 additions & 8 deletions src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,23 @@

"""

import os
from datetime import datetime
from time import time
from torch.nn.functional import sigmoid
from tqdm import tqdm

import networkx as nx
import numpy as np
import os
import torch
from torch.nn.functional import sigmoid
from tqdm import tqdm

from deep_neurographs.graph_artifact_removal import remove_doubles
from deep_neurographs.machine_learning import feature_generation
from deep_neurographs.utils import gnn_util
from deep_neurographs.utils import graph_util as gutil
from deep_neurographs.utils import img_util, ml_util
from deep_neurographs.utils import util
from deep_neurographs.utils.graph_util import GraphLoader
from deep_neurographs.utils import img_util, ml_util, util
from deep_neurographs.utils.gnn_util import toCPU
from deep_neurographs.utils.graph_util import GraphLoader

BATCH_SIZE = 2000
CONFIDENCE_THRESHOLD = 0.7
Expand Down Expand Up @@ -190,7 +189,6 @@ def build_graph(self, fragments_pointer):
anisotropy=self.graph_config.anisotropy,
min_size=self.graph_config.min_size,
node_spacing=self.graph_config.node_spacing,
trim_depth=self.graph_config.trim_depth,
)
self.graph = graph_builder.run(fragments_pointer)

Expand Down Expand Up @@ -345,7 +343,6 @@ def write_metadata(self):
"confidence_threshold": self.ml_config.threshold,
"node_spacing": self.graph_config.node_spacing,
"remove_doubles": self.graph_config.remove_doubles_bool,
"trim_depth": self.graph_config.trim_depth,
}
path = os.path.join(self.output_dir, "metadata.json")
util.write_json(path, metadata)
Expand Down
6 changes: 3 additions & 3 deletions src/deep_neurographs/machine_learning/heterograph_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def __init__(
"""
super().__init__()
# Feature vector sizes
node_dict = ml.heterograph_feature_generation.n_node_features()
edge_dict = ml.heterograph_feature_generation.n_edge_features()
node_dict = ml.feature_generation_graphs.n_node_features()
edge_dict = ml.feature_generation_graphs.n_edge_features()
hidden = scale_hidden * np.max(list(node_dict.values()))

# Linear layers
Expand All @@ -56,7 +56,7 @@ def __init__(
self.input_nodes[key] = nn.Linear(d, hidden, device=device)
for key, d in edge_dict.items():
self.input_edges[key] = nn.Linear(d, hidden, device=device)
self.output = Linear(output_dim, 1, device=device)
self.output = Linear(output_dim, 1).to(device)

# Convolutional layers
self.conv1 = HeteroConv(
Expand Down
42 changes: 3 additions & 39 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@

from deep_neurographs import generate_proposals, geometry
from deep_neurographs.geometry import dist as get_dist
from deep_neurographs.groundtruth_generation import (
init_targets,
)
from deep_neurographs.groundtruth_generation import init_targets
from deep_neurographs.utils import graph_util as gutil
from deep_neurographs.utils import img_util, swc_util, util

Expand Down Expand Up @@ -658,6 +656,8 @@ def merge_proposal(self, proposal):
attrs = dict()
for k in ["xyz", "radius"]:
combine = np.vstack if k == "xyz" else np.array
self.nodes[i][k][-1] = 8.0
self.nodes[j][k][0] = 8.0
attrs[k] = combine([self.nodes[i][k], self.nodes[j][k]])

# Sparse attributes
Expand Down Expand Up @@ -843,23 +843,6 @@ def oriented_edge(self, edge, i, key="xyz"):
else:
return np.flip(self.edges[edge][key], axis=0)

def edge_length(self, edge):
"""
Computes length of path stored as xyz coordinates in "edge".

Parameters
----------
edge : tuple
Edge in self.

Returns
-------
float
Path length of edge.

"""
return geometry.path_length(self.edges[edge]["xyz"])

def is_contained(self, node_or_xyz, buffer=0):
if self.bbox:
coord = self.to_voxels(node_or_xyz)
Expand Down Expand Up @@ -936,25 +919,6 @@ def xyz_to_swc(self, xyz, return_node=False):
else:
return None

"""
def component_cardinality(self, root):
cardinality = 0
queue = [(-1, root)]
visited = set()
while len(queue):
# Visit
i, j = queue.pop()
visited.add(frozenset((i, j)))
if i != -1:
cardinality = len(self.edges[i, j]["xyz"])

# Add neighbors
for k in self.neighbors(j):
if frozenset((j, k)) not in visited:
queue.append((j, k))
return cardinality
"""

# --- write graph to swcs ---
def to_zipped_swcs(self, zip_path, color=None):
with zipfile.ZipFile(zip_path, "w") as zip_writer:
Expand Down
Loading
Loading