Skip to content

Commit

Permalink
Feat multimodal gnn (#265)
Browse files Browse the repository at this point in the history
* refactor: feature generation

* refactor: simplified feature generation

* refactor: chunk extraction is functional

* refactor: heterognn simplified

* refactor with issue

---------

Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Oct 10, 2024
1 parent 0d8c6ef commit 13a1a7f
Show file tree
Hide file tree
Showing 12 changed files with 223 additions and 282 deletions.
2 changes: 1 addition & 1 deletion src/deep_neurographs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class MLConfig:
batch_size: int = 2000
downsample_factor: int = 1
high_threshold: float = 0.9
lr: float = 1e-3
lr: float = 1e-4
threshold: float = 0.6
model_type: str = "GraphNeuralNet"
n_epochs: int = 1000
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/groundtruth_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def is_valid(target_graph, pred_graph, kdtree, target_id, edge):
def proj_branch(target_graph, pred_graph, kdtree, target_id, i):
# Compute projections
hits = defaultdict(list)
for branch in pred_graph.get_branches(i):
for branch in pred_graph.branches(i):
for xyz in branch:
hat_xyz = geometry.kdtree_query(kdtree, xyz)
swc_id = target_graph.xyz_to_swc(hat_xyz)
Expand Down
4 changes: 3 additions & 1 deletion src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,9 @@ def get_batch_dataset(self, neurograph, batch):
...
"""
t0 = time()
features = self.feature_generator.run(neurograph, batch, self.radius)
print("Feature Generation:", time() - t0)
computation_graph = batch["graph"] if type(batch) is dict else None
dataset = ml_util.init_dataset(
neurograph,
Expand Down Expand Up @@ -590,7 +592,7 @@ def predict(self, dataset):
preds = np.array(self.model.predict_proba(dataset.data.x)[:, 1])

# Reformat prediction
idxs = dataset.idxs_proposals["idx_to_edge"]
idxs = dataset.idxs_proposals["idx_to_id"]
return {idxs[i]: p for i, p in enumerate(preds)}


Expand Down
13 changes: 7 additions & 6 deletions src/deep_neurographs/machine_learning/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self, proposals, x_proposals, y_proposals, idxs_proposals):
"""
# Conversion idxs
self.block_to_idxs = idxs_proposals["block_to_idxs"]
self.idxs_proposals = init_idxs(idxs_proposals)
self.idxs_proposals = init_idx_mapping(idxs_proposals)
self.proposals = proposals

# Features
Expand Down Expand Up @@ -293,7 +293,7 @@ def reformat(arr):
return np.expand_dims(arr, axis=1).astype(np.float32)


def init_idxs(idxs):
def init_idx_mapping(idx_to_id):
"""
Adds dictionary item called "edge_to_index" which maps a branch/proposal
in a neurograph to an idx that represents it's position in the feature
Expand All @@ -310,7 +310,8 @@ def init_idxs(idxs):
Updated dictionary.
"""
idxs["edge_to_idx"] = dict()
for idx, edge in idxs["idx_to_edge"].items():
idxs["edge_to_idx"][edge] = idx
return idxs
idx_mapping = {
"idx_to_id": idx_to_id,
"id_to_idx": {v: k for k, v in idx_to_id.items()}
}
return idx_mapping
Loading

0 comments on commit 13a1a7f

Please sign in to comment.