Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Nov 1, 2023
1 parent 5a8c606 commit 2b2a424
Show file tree
Hide file tree
Showing 6 changed files with 265 additions and 82 deletions.
82 changes: 63 additions & 19 deletions src/deep_neurographs/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from deep_neurographs import geometry_utils, utils

CHUNK_SIZE = [32, 32, 32]
NUM_POINTS = 5
WINDOW_SIZE = [5, 5, 5]

Expand All @@ -24,7 +25,7 @@

# -- Wrappers --
def generate_mutable_features(
neurograph, anisotropy=[1.0, 1.0, 1.0], img_path=None
neurograph, anisotropy=[1.0, 1.0, 1.0], img_path=None, img_profile=True
):
"""
Generates feature vectors for every mutable edge in a neurograph.
Expand All @@ -46,15 +47,32 @@ def generate_mutable_features(
"""
features = {"skel": generate_mutable_skel_features(neurograph)}
if img_path is not None:
features["img"] = generate_mutable_img_features(
if img_path and img_profile:
features["img"] = generate_mutable_img_profile_features(
neurograph, img_path, anisotropy=anisotropy
)
return combine_feature_vecs(features)
elif img_path and not img_profile:
features["img"] = generate_mutable_img_chunk_features(
neurograph, img_path, anisotropy=anisotropy
)
return features


# -- Edge feature extraction --
def generate_mutable_img_features(
def generate_mutable_img_chunk_features(
neurograph, path, anisotropy=[1.0, 1.0, 1.0]
):
img = utils.open_zarr(path)
features = dict()
for edge in neurograph.mutable_edges:
xyz_edge = neurograph.edges[edge]["xyz"]
xyz = geometry_utils.compute_midpoint(xyz_edge[0], xyz_edge[1])
xyz = geometry_utils.get_coord(xyz, anisotropy=anisotropy)
features[edge] = utils.read_img_chunk(img, xyz, CHUNK_SIZE)
return features


def generate_mutable_img_profile_features(
neurograph, path, anisotropy=[1.0, 1.0, 1.0]
):
img = utils.open_zarr(path)
Expand Down Expand Up @@ -131,19 +149,25 @@ def get_radii(neurograph, edge):


# -- Combine feature vectors
def build_feature_matrix(neurographs, features, blocks):
def build_feature_matrix(neurographs, features, blocks, img_chunks=False):
# Initialize
X = None
y = None
block_to_idxs = dict()
idx_to_edge = dict()

# Feature extraction
for block_id in blocks:
# Get features
idx_shift = 0 if X is None else X.shape[0]
X_i, y_i, idx_to_edge_i = build_feature_submatrix(
neurographs[block_id], features[block_id], idx_shift
)
if img_chunks:
X_i, y_i, idx_to_edge_i = build_img_chunk_submatrix(
neurographs[block_id], features[block_id], idx_shift
)
else:
X_i, y_i, idx_to_edge_i = build_feature_submatrix(
neurographs[block_id], features[block_id], idx_shift
)

# Concatenate
if X is None:
Expand All @@ -160,33 +184,53 @@ def build_feature_matrix(neurographs, features, blocks):
return X, y, block_to_idxs, idx_to_edge


def build_feature_submatrix(neurograph, feat_dict, shift):
def build_feature_submatrix(neurograph, features, shift):
# Extract info
key = sample(list(feat_dict.keys()), 1)[0]
features = combine_features(features)
key = sample(list(features.keys()), 1)[0]
num_edges = neurograph.num_mutables()
num_features = len(feat_dict[key])
num_features = len(features[key])

# Build
idx_to_edge = dict()
X = np.zeros((num_edges, num_features))
y = np.zeros((num_edges))
for i, edge in enumerate(feat_dict.keys()):
for i, edge in enumerate(features.keys()):
idx_to_edge[i + shift] = edge
X[i, :] = features[edge]
y[i] = 1 if edge in neurograph.target_edges else 0
return X, y, idx_to_edge


def build_img_chunk_submatrix(neurograph, features, shift):
# Extract info
key = sample(list(features.keys()), 1)[0]
num_edges = neurograph.num_mutables()
num_features = len(features[key])

# Build
idx_to_edge = dict()
X = np.zeros(((num_edges,) + tuple(CHUNK_SIZE)))
y = np.zeros((num_edges))
for i, edge in enumerate(features["img"].keys()):
idx_to_edge[i + shift] = edge
X[i, :] = feat_dict[edge]
X[i, :] = features["img"][edge]
y[i] = 1 if edge in neurograph.target_edges else 0
return X, y, idx_to_edge


# -- Utils --
def compute_num_features():
return NUM_SKEL_FEATURES # NUM_IMG_FEATURES +
def compute_num_features(skel_features=True, img_features=True):
num_features = NUM_SKEL_FEATURES if skel_features else 0
num_features += NUM_IMG_FEATURES if img_features else 0
return num_features


def combine_feature_vecs(features):
def combine_features(features):
for edge in features["skel"].keys():
for feat_key in [key for key in features.keys() if key != "skel"]:
for key in [key for key in features.keys() if key != "skel"]:
features["skel"][edge] = np.concatenate(
(features["skel"][edge], features[feat_key][edge])
(features["skel"][edge], features[key][edge])
)
return features["skel"]

Expand Down
10 changes: 9 additions & 1 deletion src/deep_neurographs/geometry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def compute_normal(xyz):
return normal / np.linalg.norm(normal)


def compute_midpoint(xyz1, xyz2):
return np.mean([xyz1, xyz2], axis=0)


# Smoothing
def smooth_branch(xyz):
if xyz.shape[0] > 5:
Expand All @@ -104,7 +108,7 @@ def smooth_branch(xyz):


def fit_spline(xyz):
s = xyz.shape[0] / 10
s = xyz.shape[0] / 5
t = np.arange(xyz.shape[0])
cs_x = UnivariateSpline(t, xyz[:, 0], s=s, k=3)
cs_y = UnivariateSpline(t, xyz[:, 1], s=s, k=3)
Expand Down Expand Up @@ -142,6 +146,10 @@ def get_coords(xyz_arr, anisotropy=[1.0, 1.0, 1.0]):
return xyz_arr.astype(int)


def get_coord(xyz, anisotropy=[1.0, 1.0, 1.0]):
return [int(xyz[i] / anisotropy[i]) for i in range(3)]


# Miscellaneous
def compare_edges(xyx_i, xyz_j, xyz_k):
dist_ij = dist(xyx_i, xyz_j)
Expand Down
81 changes: 54 additions & 27 deletions src/deep_neurographs/neural_networks.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,70 @@
import torch
from torch import nn


class FeedFowardNet(nn.Module):
def __init__(self, num_features, depth=3):
def __init__(self, num_features):
nn.Module.__init__(self)
self.fc1 = self._init_fc_layer(num_features, num_features)
self.fc2 = self._init_fc_layer(num_features, num_features // 2)
self.output = nn.Sequential(nn.Linear(num_features // 2, 1))

# Parameters
assert depth < num_features
self.depth = depth
self.num_features = num_features

# Layers
print("Network Architecture...")
self.activation = nn.ELU()
self.dropout = nn.Dropout(p=0.2)
for d in range(self.depth):
D_in = num_features // max(d, 1)
D_out = num_features // (d + 1)
self.add_fc_layer(d, D_in, D_out)
self.last_fc = nn.Linear(D_out, 1)
self.sigmoid = nn.Sigmoid()
def _init_fc_layer(self, D_in, D_out):
fc_layer = nn.Sequential(
nn.Linear(D_in, D_out), nn.LeakyReLU(), nn.Dropout(p=0.25)
)
return fc_layer

def forward(self, x):
for d in range(self.depth):
fc_d = getattr(self, "fc{}".format(d))
x = self.activation(self.dropout(fc_d(x)))
x = self.last_fc(x)
return self.sigmoid(x)

def add_fc_layer(self, d, D_in, D_out):
setattr(self, "fc{}".format(d), nn.Linear(D_in, D_out))
print(" {} --> {}".format(D_in, D_out))
x = self.fc1(x)
x = self.fc2(x)
x = self.output(x)
return x


class ConvNet(nn.Module):
def __init__(self, input_dims, depth=3):
pass
def __init__(self):
nn.Module.__init__(self)
self.conv1 = self._init_conv_layer(1, 4)
self.conv2 = self._init_conv_layer(4, 8)
self.output = nn.Sequential(
nn.Linear(8*6*6*6, 64),
nn.LeakyReLU(),
nn.Linear(64, 1)
)

def _init_conv_layer(self, in_channels, out_channels):
conv_layer = nn.Sequential(
nn.Conv3d(
in_channels,
out_channels,
kernel_size=(3, 3, 3),
stride=1,
padding=0,
),
nn.BatchNorm3d(out_channels),
nn.LeakyReLU(),
nn.Dropout(p=0.25),
nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2),
)
return conv_layer

def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1)
x = self.output(x)
return x


class MultiModalNet(nn.Module):
def __init__(self, feature_vec_shape, img_patch_shape):
pass


def weights_init(net):
for module in net.modules():
if isinstance(module, nn.Conv3d):
torch.nn.init.xavier_normal_(module.weight)
elif isinstance(module, nn.Linear):
torch.nn.init.xavier_normal_(module.weight)
2 changes: 1 addition & 1 deletion src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def init_targets(self, target_neurograph, target_densegraph):
xyz_j = self.nodes[j]["xyz"]
proj_xyz_i, d_i = target_neurograph.get_projection(xyz_i)
proj_xyz_j, d_j = target_neurograph.get_projection(xyz_j)
if d_i > 10 or d_j > 10:
if d_i > 7 or d_j > 7:
continue

# Get corresponding edges on target
Expand Down
Loading

0 comments on commit 2b2a424

Please sign in to comment.