Skip to content

Commit

Permalink
fixed performance bug (#266)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Oct 11, 2024
1 parent 13a1a7f commit 8389bab
Show file tree
Hide file tree
Showing 10 changed files with 49 additions and 69 deletions.
2 changes: 1 addition & 1 deletion src/deep_neurographs/generate_proposals.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,4 +484,4 @@ def tangent(branch, idx, depth):
"""
end = min(idx + depth, len(branch))
return geometry.tangent(branch[idx:end])
return geometry.tangent(branch[idx:end])
2 changes: 1 addition & 1 deletion src/deep_neurographs/groundtruth_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,4 +300,4 @@ def orient_branch(branch_i, branch_j):
def upd_dict(node_to_target_id, nodes, target_id):
for node in nodes:
node_to_target_id[node] = target_id
return node_to_target_id
return node_to_target_id
Original file line number Diff line number Diff line change
Expand Up @@ -469,4 +469,4 @@ def n_edge_features(x):
"""
key = sample(list(x.keys()), 1)[0]
return x[key].shape[0]
return x[key].shape[0]
65 changes: 34 additions & 31 deletions src/deep_neurographs/machine_learning/heterograph_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,16 @@ def __init__(
self.dropout = dropout

# Feature vector sizes
hidden_dim = scale_hidden_dim * np.max(list(node_dict.values()))
hidden_dim = scale_hidden_dim* np.max(list(node_dict.values()))
output_dim = heads_1 * heads_2 * hidden_dim

# Linear layers
self.input_nodes = self.init_linear_layer(hidden_dim, node_dict)
self.input_edges = self.init_linear_layer(hidden_dim, edge_dict)
self.input_nodes = nn.ModuleDict()
self.input_edges = dict()
for key, d in node_dict.items():
self.input_nodes[key] = nn.Linear(d, hidden_dim, device=device)
for key, d in edge_dict.items():
self.input_edges[key] = nn.Linear(d, hidden_dim, device=device)
self.output = Linear(output_dim, 1).to(device)

# Message passing layers
Expand All @@ -84,7 +88,7 @@ def __init__(
def get_relation_types(cls):
return cls.relation_types

# --- Initialize architecture ---
# --- Architecture ---
def init_linear_layer(self, hidden_dim, my_dict):
linear_layer = dict()
for key, dim in my_dict.items():
Expand Down Expand Up @@ -132,14 +136,32 @@ def init_weights(self):
None
"""
# Output layer
for params in self.output.parameters():
if len(params.shape) > 1:
init.kaiming_normal_(params)
else:
init.zeros_(params)

# --- Generate prediction ---
for layer in [self.output, self.input_nodes]:
for param in layer.parameters():
if len(param.shape) > 1:
init.kaiming_normal_(param)
else:
init.zeros_(param)

def activation(self, x_dict):
"""
Applies nonlinear activation
Parameters
----------
x_dict : dict
Dictionary that maps node/edge types to feature matrices.
Returns
-------
dict
Feature matrices with activation applied.
"""
x_dict = {key: self.leaky_relu(x) for key, x in x_dict.items()}
x_dict = {key: self.dropout(x) for key, x in x_dict.items()}
return x_dict

def forward(self, x_dict, edge_index_dict, edge_attr_dict):
# Input - Nodes
x_dict = {key: f(x_dict[key]) for key, f in self.input_nodes.items()}
Expand All @@ -162,25 +184,6 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict):
x_dict = self.output(x_dict["proposal"])
return x_dict

def activation(self, x_dict):
"""
Applies nonlinear activation
Parameters
----------
x_dict : dict
Dictionary that maps node/edge types to feature matrices.
Returns
-------
dict
Feature matrices with activation applied.
"""
x_dict = {key: self.leaky_relu(x) for key, x in x_dict.items()}
x_dict = {key: self.dropout(x) for key, x in x_dict.items()}
return x_dict


class MultiModalHGAT(HeteroGNN):
pass
2 changes: 1 addition & 1 deletion src/deep_neurographs/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,4 +470,4 @@ def get_predictions(hat_y, threshold=0.5):
Binary predictions based on the given threshold.
"""
return (ml_util.sigmoid(np.array(hat_y)) > threshold).tolist()
return (ml_util.sigmoid(np.array(hat_y)) > threshold).tolist()
16 changes: 6 additions & 10 deletions src/deep_neurographs/utils/gnn_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,16 @@

# --- Tensor Operations ---
def get_inputs(data, device=None):
# Extract inputs
x = data.x_dict
edge_index = data.edge_index_dict
edge_attr = data.edge_attr_dict
if device and torch.cuda.is_available():
return toGPU(x), toGPU(edge_index), toGPU(edge_attr)
else:
return x, edge_index, edge_attr

# Move to gpu (if applicable)
if "cuda" in device and torch.cuda.is_available():
x = toGPU(x, device)
edge_index = toGPU(edge_index, device)
edge_attr = toGPU(edge_attr, device)
return x, edge_index, edge_attr


def toGPU(tensor_dict, device):
def toGPU(tensor_dict):
"""
Moves dictionary of tensors from CPU to GPU.
Expand Down Expand Up @@ -301,4 +297,4 @@ def init_line_graph(edges):
"""
graph = nx.Graph()
graph.add_edges_from(edges)
return nx.line_graph(graph)
return nx.line_graph(graph)
23 changes: 2 additions & 21 deletions src/deep_neurographs/utils/graph_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def run(
"""
# Load fragments and extract irreducibles
self.init_img_bbox(img_patch_origin, img_patch_shape)
self.img_bbox = img_util.init_bbox(img_patch_origin, img_patch_shape)
swc_dicts = self.reader.load(fragments_pointer)
irreducibles = get_irreducibles(
swc_dicts,
Expand All @@ -139,25 +139,6 @@ def run(
neurograph.add_component(irreducible_set)
return neurograph

def init_img_bbox(self, img_patch_origin, img_patch_shape):
"""
Sets the bounding box of an image patch as a class attriubte.
Parameters
----------
img_patch_origin : tuple[int]
Origin of bounding box which is assumed to be top, front, left
corner.
img_patch_shape : tuple[int]
Shape of bounding box.
Returns
-------
None
"""
self.img_bbox = img_util.init_bbox(img_patch_origin, img_patch_shape)


# --- Graph structure extraction ---
def get_irreducibles(
Expand Down Expand Up @@ -877,4 +858,4 @@ def largest_components(neurograph, k):
node_ids.pop(-1)
break
i += 1
return node_ids
return node_ids
2 changes: 1 addition & 1 deletion src/deep_neurographs/utils/img_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,4 +412,4 @@ def find_img_path(bucket_name, img_root, dataset_name):
for subdir in util.list_gcs_subdirectories(bucket_name, img_root):
if dataset_name in subdir:
return subdir + "whole-brain/fused.zarr/"
raise f"Dataset not found in {bucket_name} - {img_root}"
raise f"Dataset not found in {bucket_name} - {img_root}"
2 changes: 1 addition & 1 deletion src/deep_neurographs/utils/ml_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,4 @@ def get_kfolds(filenames, k):
folds.append(samples_i)
if n_samples > len(samples):
break
return folds
return folds
2 changes: 1 addition & 1 deletion src/deep_neurographs/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,4 +663,4 @@ def spaced_idxs(container, k):
idxs = np.arange(0, len(container) + k, k)[:-1]
if len(container) % 2 == 0:
idxs = np.append(idxs, len(container) - 1)
return idxs
return idxs

0 comments on commit 8389bab

Please sign in to comment.