Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: batch formation #283

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/deep_neurographs/fragments_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,39 @@ def list_proposals(self):
"""
return list(self.proposals)

def proposal_connected_component(self, proposal):
    """
    Extracts the connected component that "proposal" belongs to in the
    proposal induced subgraph.

    Parameters
    ----------
    proposal : frozenset
        Proposal used as the root to extract its connected component in
        the proposal induced subgraph.

    Returns
    -------
    Set[frozenset]
        Proposals in the connected component that "proposal" belongs to
        in the proposal induced subgraph; always contains "proposal".

    """
    # Mark proposals as visited when they are pushed (not when popped) so
    # that each proposal is placed on the stack at most once.
    visited = {proposal}
    stack = [proposal]
    while stack:
        # Visit proposal
        p = stack.pop()

        # Push every unvisited proposal that shares a node with "p"
        for i in p:
            for j in self.nodes[i]["proposals"]:
                p_ij = frozenset({i, j})
                if p_ij not in visited:
                    visited.add(p_ij)
                    stack.append(p_ij)
    return visited

# -- KDTree --
def init_kdtree(self, node_type):
"""
Expand Down
82 changes: 52 additions & 30 deletions src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(
self.model_path,
self.ml_config.model_type,
self.graph_config.search_radius,
batch_size=self.ml_config.batch_size,
confidence_threshold=self.ml_config.threshold,
device=device,
downsample_factor=self.ml_config.downsample_factor,
Expand Down Expand Up @@ -535,62 +536,66 @@ def __init__(
if self.is_gnn and "cuda" in device:
self.model = self.model.to(self.device)

def run(self, neurograph, proposals):
def run(self, fragments_graph, proposals):
"""
Runs inference by forming batches of proposals, then performing the
following steps for each batch: (1) generate features, (2) classify
proposals by running model, and (3) adding each accepted proposal as
an edge to "neurograph" if it does not create a cycle.
an edge to "fragments_graph" if it does not create a cycle.

Parameters
----------
neurograph : NeuroGraph
fragments_graph : FragmentsGraph
Graph that inference will be performed on.
proposals : list
Proposals to be classified as accept or reject.

Returns
-------
NeuroGraph
FragmentsGraph
Updated graph with accepted proposals added as edges.
list
Accepted proposals.

"""
# Initializations
assert not gutil.cycle_exists(neurograph), "Graph contains cycle!"
assert not gutil.cycle_exists(fragments_graph), "Graph has cycle!"
if self.is_gnn:
proposals = set(proposals)
else:
proposals = sort_proposals(neurograph, proposals)
proposals = sort_proposals(fragments_graph, proposals)

# Main
flagged = get_large_proposal_components(fragments_graph, 4)
with tqdm(total=len(proposals), desc="Inference") as pbar:
accepts = list()
while len(proposals) > 0:
# Predict
batch = self.get_batch(neurograph, proposals)
dataset = self.get_batch_dataset(neurograph, batch)
batch = self.get_batch(fragments_graph, proposals, flagged)
dataset = self.get_batch_dataset(fragments_graph, batch)
preds = self.predict(dataset)

# Update graph
for p in get_accepts(neurograph, preds, self.threshold):
neurograph.merge_proposal(p)
for p in get_accepts(fragments_graph, preds, self.threshold):
fragments_graph.merge_proposal(p)
accepts.append(p)
pbar.update(len(batch["proposals"]))
neurograph.absorb_reducibles()
return neurograph, accepts
fragments_graph.absorb_reducibles()
return fragments_graph, accepts

def get_batch(self, neurograph, proposals):
def get_batch(self, fragments_graph, proposals, flagged_proposals):
"""
Generates a batch of proposals.

Parameters
----------
neurograph : NeuroGraph
fragments_graph : FragmentsGraph
Graph that proposals were generated from.
proposals : list
proposals : List[frozenset]
Proposals for which batch is to be generated from.
flagged_proposals : List[frozenset]
List of proposals that are part of a "large" connected component
in the proposal induced subgraph of "fragments_graph".

Returns
-------
Expand All @@ -600,20 +605,22 @@ def get_batch(self, neurograph, proposals):

"""
if self.is_gnn:
return gnn_util.get_batch(neurograph, proposals, self.batch_size)
return gnn_util.get_batch(
fragments_graph, proposals, self.batch_size, flagged_proposals
)
else:
batch = {"proposals": proposals[0:self.batch_size], "graph": None}
del proposals[0:self.batch_size]
return batch

def get_batch_dataset(self, neurograph, batch):
def get_batch_dataset(self, fragments_graph, batch):
"""
Generates features and initializes dataset that can be input to a
machine learning model.

Parameters
----------
neurograph : NeuroGraph
fragments_graph : FragmentsGraph
Graph that inference will be performed on.
batch : list
Proposals to be classified.
Expand All @@ -623,10 +630,12 @@ def get_batch_dataset(self, neurograph, batch):
...

"""
features = self.feature_generator.run(neurograph, batch, self.radius)
features = self.feature_generator.run(
fragments_graph, batch, self.radius
)
computation_graph = batch["graph"] if type(batch) is dict else None
dataset = ml_util.init_dataset(
neurograph,
fragments_graph,
features,
self.is_gnn,
computation_graph=computation_graph,
Expand Down Expand Up @@ -694,14 +703,14 @@ def predict_with_gnn(model, data, device=None):
return toCPU(preds[0:len(data["proposal"]["y"]), 0])


def get_accepts(neurograph, preds, threshold, high_threshold=0.9):
def get_accepts(fragments_graph, preds, threshold, high_threshold=0.9):
"""
Determines which proposals to accept based on prediction scores and the
specified threshold.

Parameters
----------
neurograph : NeuroGraph
fragments_graph : FragmentsGraph
Graph that proposals belong to.
preds : dict
Dictionary that maps proposal ids to probability generated from
Expand All @@ -713,20 +722,20 @@ def get_accepts(neurograph, preds, threshold, high_threshold=0.9):
Returns
-------
list
Proposals to be added as edges to "neurograph".
Proposals to be added as edges to "fragments_graph".

"""
# Partition proposals into best and the rest
preds = {k: v for k, v in preds.items() if v > threshold}
best_proposals, proposals = separate_best(
preds, neurograph.simple_proposals(), high_threshold
preds, fragments_graph.simple_proposals(), high_threshold
)

# Determine which proposals to accept
accepts = list()
accepts.extend(filter_proposals(neurograph, best_proposals))
accepts.extend(filter_proposals(neurograph, proposals))
neurograph.remove_edges_from(map(tuple, accepts))
accepts.extend(filter_proposals(fragments_graph, best_proposals))
accepts.extend(filter_proposals(fragments_graph, proposals))
fragments_graph.remove_edges_from(map(tuple, accepts))
return accepts


Expand Down Expand Up @@ -795,13 +804,13 @@ def filter_proposals(graph, proposals):
return accepts


def sort_proposals(neurograph, proposals):
def sort_proposals(fragments_graph, proposals):
"""
Sorts proposals by length.

Parameters
----------
neurograph : NeuroGraph
fragments_graph : FragmentsGraph
Graph that proposals were generated from.
proposals : list[frozenset]
List of proposals.
Expand All @@ -812,5 +821,18 @@ def sort_proposals(neurograph, proposals):
Sorted proposals.

"""
idxs = np.argsort([neurograph.proposal_length(p) for p in proposals])
idxs = np.argsort([fragments_graph.proposal_length(p) for p in proposals])
return [proposals[idx] for idx in idxs]


# --- Batch Formation ---
def get_large_proposal_components(fragments_graph, k):
    """
    Finds the proposals that belong to a "large" connected component of the
    proposal induced subgraph, meaning a component with more than "k"
    proposals.

    Parameters
    ----------
    fragments_graph : FragmentsGraph
        Graph that proposals were generated from.
    k : int
        Size threshold; components with more than "k" proposals are flagged.

    Returns
    -------
    Set[frozenset]
        Proposals contained in a connected component with more than "k"
        proposals.

    """
    flagged_proposals = set()
    visited = set()
    for p in fragments_graph.list_proposals():
        # Each component only needs to be extracted once
        if p in visited:
            continue
        component = fragments_graph.proposal_connected_component(p)
        if len(component) > k:
            # In-place update avoids rebuilding the accumulated set
            flagged_proposals.update(component)
        visited.update(component)
    return flagged_proposals
Loading
Loading