Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat multimodal gnn #265

Merged
merged 8 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/deep_neurographs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class MLConfig:
batch_size: int = 2000
downsample_factor: int = 1
high_threshold: float = 0.9
lr: float = 1e-3
lr: float = 1e-4
threshold: float = 0.6
model_type: str = "GraphNeuralNet"
n_epochs: int = 1000
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/groundtruth_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def is_valid(target_graph, pred_graph, kdtree, target_id, edge):
def proj_branch(target_graph, pred_graph, kdtree, target_id, i):
# Compute projections
hits = defaultdict(list)
for branch in pred_graph.get_branches(i):
for branch in pred_graph.branches(i):
for xyz in branch:
hat_xyz = geometry.kdtree_query(kdtree, xyz)
swc_id = target_graph.xyz_to_swc(hat_xyz)
Expand Down
4 changes: 3 additions & 1 deletion src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,9 @@ def get_batch_dataset(self, neurograph, batch):
...

"""
t0 = time()
features = self.feature_generator.run(neurograph, batch, self.radius)
print("Feature Generation:", time() - t0)
computation_graph = batch["graph"] if type(batch) is dict else None
dataset = ml_util.init_dataset(
neurograph,
Expand Down Expand Up @@ -590,7 +592,7 @@ def predict(self, dataset):
preds = np.array(self.model.predict_proba(dataset.data.x)[:, 1])

# Reformat prediction
idxs = dataset.idxs_proposals["idx_to_edge"]
idxs = dataset.idxs_proposals["idx_to_id"]
return {idxs[i]: p for i, p in enumerate(preds)}


Expand Down
13 changes: 7 additions & 6 deletions src/deep_neurographs/machine_learning/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self, proposals, x_proposals, y_proposals, idxs_proposals):
"""
# Conversion idxs
self.block_to_idxs = idxs_proposals["block_to_idxs"]
self.idxs_proposals = init_idxs(idxs_proposals)
self.idxs_proposals = init_idx_mapping(idxs_proposals)
self.proposals = proposals

# Features
Expand Down Expand Up @@ -293,7 +293,7 @@ def reformat(arr):
return np.expand_dims(arr, axis=1).astype(np.float32)


def init_idxs(idxs):
def init_idx_mapping(idx_to_id):
"""
Adds dictionary item called "edge_to_index" which maps a branch/proposal
in a neurograph to an idx that represents it's position in the feature
Expand All @@ -310,7 +310,8 @@ def init_idxs(idxs):
Updated dictionary.

"""
idxs["edge_to_idx"] = dict()
for idx, edge in idxs["idx_to_edge"].items():
idxs["edge_to_idx"][edge] = idx
return idxs
idx_mapping = {
"idx_to_id": idx_to_id,
"id_to_idx": {v: k for k, v in idx_to_id.items()}
}
return idx_mapping
Loading
Loading