diff --git a/src/deep_neurographs/feature_extraction.py b/src/deep_neurographs/feature_extraction.py index 3f7e049..e71a8db 100644 --- a/src/deep_neurographs/feature_extraction.py +++ b/src/deep_neurographs/feature_extraction.py @@ -15,6 +15,7 @@ from deep_neurographs import geometry_utils, utils +CHUNK_SIZE = [32, 32, 32] NUM_POINTS = 5 WINDOW_SIZE = [5, 5, 5] @@ -24,7 +25,7 @@ # -- Wrappers -- def generate_mutable_features( - neurograph, anisotropy=[1.0, 1.0, 1.0], img_path=None + neurograph, anisotropy=[1.0, 1.0, 1.0], img_path=None, img_profile=True ): """ Generates feature vectors for every mutable edge in a neurograph. @@ -46,15 +47,32 @@ def generate_mutable_features( """ features = {"skel": generate_mutable_skel_features(neurograph)} - if img_path is not None: - features["img"] = generate_mutable_img_features( + if img_path and img_profile: + features["img"] = generate_mutable_img_profile_features( neurograph, img_path, anisotropy=anisotropy ) - return combine_feature_vecs(features) + elif img_path and not img_profile: + features["img"] = generate_mutable_img_chunk_features( + neurograph, img_path, anisotropy=anisotropy + ) + return features # -- Edge feature extraction -- -def generate_mutable_img_features( +def generate_mutable_img_chunk_features( + neurograph, path, anisotropy=[1.0, 1.0, 1.0] +): + img = utils.open_zarr(path) + features = dict() + for edge in neurograph.mutable_edges: + xyz_edge = neurograph.edges[edge]["xyz"] + xyz = geometry_utils.compute_midpoint(xyz_edge[0], xyz_edge[1]) + xyz = geometry_utils.get_coord(xyz, anisotropy=anisotropy) + features[edge] = utils.read_img_chunk(img, xyz, CHUNK_SIZE) + return features + + +def generate_mutable_img_profile_features( neurograph, path, anisotropy=[1.0, 1.0, 1.0] ): img = utils.open_zarr(path) @@ -131,9 +149,10 @@ def get_radii(neurograph, edge): # -- Combine feature vectors -def build_feature_matrix(neurographs, features, blocks): +def build_feature_matrix(neurographs, features, blocks, img_chunks=False): # Initialize X = None + y = None block_to_idxs = dict() idx_to_edge = dict() @@ -141,9 +160,14 @@ def build_feature_matrix(neurographs, features, blocks): for block_id in blocks: # Get features idx_shift = 0 if X is None else X.shape[0] - X_i, y_i, idx_to_edge_i = build_feature_submatrix( - neurographs[block_id], features[block_id], idx_shift - ) + if img_chunks: + X_i, y_i, idx_to_edge_i = build_img_chunk_submatrix( + neurographs[block_id], features[block_id], idx_shift + ) + else: + X_i, y_i, idx_to_edge_i = build_feature_submatrix( + neurographs[block_id], features[block_id], idx_shift + ) # Concatenate if X is None: @@ -160,33 +184,53 @@ def build_feature_matrix(neurographs, features, blocks): return X, y, block_to_idxs, idx_to_edge -def build_feature_submatrix(neurograph, feat_dict, shift): +def build_feature_submatrix(neurograph, features, shift): # Extract info - key = sample(list(feat_dict.keys()), 1)[0] + features = combine_features(features) + key = sample(list(features.keys()), 1)[0] num_edges = neurograph.num_mutables() - num_features = len(feat_dict[key]) + num_features = len(features[key]) # Build idx_to_edge = dict() X = np.zeros((num_edges, num_features)) y = np.zeros((num_edges)) - for i, edge in enumerate(feat_dict.keys()): + for i, edge in enumerate(features.keys()): + idx_to_edge[i + shift] = edge + X[i, :] = features[edge] + y[i] = 1 if edge in neurograph.target_edges else 0 + return X, y, idx_to_edge + + +def build_img_chunk_submatrix(neurograph, features, shift): + # Extract info + key = sample(list(features.keys()), 1)[0] + num_edges = neurograph.num_mutables() + num_features = len(features[key]) + + # Build + idx_to_edge = dict() + X = np.zeros(((num_edges,) + tuple(CHUNK_SIZE))) + y = np.zeros((num_edges)) + for i, edge in enumerate(features["img"].keys()): idx_to_edge[i + shift] = edge - X[i, :] = feat_dict[edge] + X[i, :] = features["img"][edge] y[i] = 1 if edge in neurograph.target_edges else 0 return X, y, idx_to_edge # -- Utils -- -def compute_num_features(): - return NUM_SKEL_FEATURES # NUM_IMG_FEATURES + +def compute_num_features(skel_features=True, img_features=True): + num_features = NUM_SKEL_FEATURES if skel_features else 0 + num_features += NUM_IMG_FEATURES if img_features else 0 + return num_features -def combine_feature_vecs(features): +def combine_features(features): for edge in features["skel"].keys(): - for feat_key in [key for key in features.keys() if key != "skel"]: + for key in [key for key in features.keys() if key != "skel"]: features["skel"][edge] = np.concatenate( - (features["skel"][edge], features[feat_key][edge]) + (features["skel"][edge], features[key][edge]) ) return features["skel"] diff --git a/src/deep_neurographs/geometry_utils.py b/src/deep_neurographs/geometry_utils.py index 3788599..8a85d78 100644 --- a/src/deep_neurographs/geometry_utils.py +++ b/src/deep_neurographs/geometry_utils.py @@ -94,6 +94,10 @@ def compute_normal(xyz): return normal / np.linalg.norm(normal) +def compute_midpoint(xyz1, xyz2): + return np.mean([xyz1, xyz2], axis=0) + + # Smoothing def smooth_branch(xyz): if xyz.shape[0] > 5: @@ -104,7 +108,7 @@ def smooth_branch(xyz): def fit_spline(xyz): - s = xyz.shape[0] / 10 + s = xyz.shape[0] / 5 t = np.arange(xyz.shape[0]) cs_x = UnivariateSpline(t, xyz[:, 0], s=s, k=3) cs_y = UnivariateSpline(t, xyz[:, 1], s=s, k=3) @@ -142,6 +146,10 @@ def get_coords(xyz_arr, anisotropy=[1.0, 1.0, 1.0]): return xyz_arr.astype(int) +def get_coord(xyz, anisotropy=[1.0, 1.0, 1.0]): + return [int(xyz[i] / anisotropy[i]) for i in range(3)] + + # Miscellaneous def compare_edges(xyx_i, xyz_j, xyz_k): dist_ij = dist(xyx_i, xyz_j) diff --git a/src/deep_neurographs/neural_networks.py b/src/deep_neurographs/neural_networks.py index 15d182b..ebea718 100644 --- a/src/deep_neurographs/neural_networks.py +++ b/src/deep_neurographs/neural_networks.py @@ -1,43 +1,70 @@ +import torch from torch import nn class FeedFowardNet(nn.Module): - def __init__(self, num_features, depth=3): + def __init__(self, num_features): nn.Module.__init__(self) + self.fc1 = self._init_fc_layer(num_features, num_features) + self.fc2 = self._init_fc_layer(num_features, num_features // 2) + self.output = nn.Sequential(nn.Linear(num_features // 2, 1)) - # Parameters - assert depth < num_features - self.depth = depth - self.num_features = num_features - - # Layers - print("Network Architecture...") - self.activation = nn.ELU() - self.dropout = nn.Dropout(p=0.2) - for d in range(self.depth): - D_in = num_features // max(d, 1) - D_out = num_features // (d + 1) - self.add_fc_layer(d, D_in, D_out) - self.last_fc = nn.Linear(D_out, 1) - self.sigmoid = nn.Sigmoid() + def _init_fc_layer(self, D_in, D_out): + fc_layer = nn.Sequential( + nn.Linear(D_in, D_out), nn.LeakyReLU(), nn.Dropout(p=0.25) + ) + return fc_layer def forward(self, x): - for d in range(self.depth): - fc_d = getattr(self, "fc{}".format(d)) - x = self.activation(self.dropout(fc_d(x))) - x = self.last_fc(x) - return self.sigmoid(x) - - def add_fc_layer(self, d, D_in, D_out): - setattr(self, "fc{}".format(d), nn.Linear(D_in, D_out)) - print(" {} --> {}".format(D_in, D_out)) + x = self.fc1(x) + x = self.fc2(x) + x = self.output(x) + return x class ConvNet(nn.Module): - def __init__(self, input_dims, depth=3): - pass + def __init__(self): + nn.Module.__init__(self) + self.conv1 = self._init_conv_layer(1, 4) + self.conv2 = self._init_conv_layer(4, 8) + self.output = nn.Sequential( + nn.Linear(8*6*6*6, 64), + nn.LeakyReLU(), + nn.Linear(64, 1) + ) + + def _init_conv_layer(self, in_channels, out_channels): + conv_layer = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=1, + padding=0, + ), + nn.BatchNorm3d(out_channels), + nn.LeakyReLU(), + nn.Dropout(p=0.25), + nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2), + ) + return conv_layer + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = x.view(x.size(0), -1) + x = self.output(x) + return x class MultiModalNet(nn.Module): def __init__(self, feature_vec_shape, img_patch_shape): pass + + +def weights_init(net): + for module in net.modules(): + if isinstance(module, nn.Conv3d): + torch.nn.init.xavier_normal_(module.weight) + elif isinstance(module, nn.Linear): + torch.nn.init.xavier_normal_(module.weight) diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index b53eba0..19180eb 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -301,7 +301,7 @@ def init_targets(self, target_neurograph, target_densegraph): xyz_j = self.nodes[j]["xyz"] proj_xyz_i, d_i = target_neurograph.get_projection(xyz_i) proj_xyz_j, d_j = target_neurograph.get_projection(xyz_j) - if d_i > 10 or d_j > 10: + if d_i > 7 or d_j > 7: continue # Get corresponding edges on target diff --git a/src/deep_neurographs/train.py b/src/deep_neurographs/train.py index fa81861..7b09f7e 100644 --- a/src/deep_neurographs/train.py +++ b/src/deep_neurographs/train.py @@ -1,12 +1,16 @@ from random import sample - +from deep_neurographs import utils import lightning.pytorch as pl +from lightning.pytorch.callbacks import ModelCheckpoint import numpy as np from sklearn.metrics import roc_auc_score +import torchio as tio + import torch import torch.nn.functional as F +import torch.utils.data as torch_data +from torch.utils.data import Dataset, DataLoader from torch_geometric.utils import negative_sampling -from torch.utils.data import Dataset from torcheval.metrics.functional import ( binary_accuracy, binary_f1_score, @@ -14,69 +18,163 @@ binary_recall, ) +BATCH_SIZE = 32 +NUM_WORKERS = 0 +SHUFFLE = True + -# Cross Validation +# Training def get_kfolds(train_data, k): folds = [] samples = set(train_data) num_samples = int(np.floor(len(train_data) / k)) assert num_samples > 0, "Sample size is too small for {}-folds".format(k) for i in range(k): - if i < k - 1: - samples_i = sample(samples, num_samples) - samples = samples.difference(samples_i) - folds.append(set(samples_i)) - else: - folds.append(samples) + samples_i = sample(samples, num_samples) + samples = samples.difference(samples_i) + folds.append(samples_i) + if num_samples > len(samples): + break return folds -# Neural Network Training -class EdgeDataset(Dataset): - def __init__(self, data, labels): - self.data = data.astype(np.float32) +def train_network(dataset, net, max_epochs=100): + # Load data + train_set, valid_set = random_split(dataset) + train_loader = DataLoader( + train_set, + num_workers=NUM_WORKERS, + batch_size=BATCH_SIZE, + shuffle=SHUFFLE, + ) + valid_loader = DataLoader( + valid_set, num_workers=NUM_WORKERS, batch_size=BATCH_SIZE + ) + + # Fit model + model = LitNeuralNet(net) + checkpoint_callback = ModelCheckpoint( + save_top_k=1, monitor="val_f1", mode="max" + ) + trainer = pl.Trainer( + accelerator="gpu", + devices=1, + max_epochs=max_epochs, + callbacks=[checkpoint_callback], + ) + trainer.fit(model, train_loader, valid_loader) + return model + + +def random_split(train_set, train_ratio=0.8): + train_set_size = int(len(train_set) * train_ratio) + valid_set_size = len(train_set) - train_set_size + return torch_data.random_split(train_set, [train_set_size, valid_set_size]) + + +def eval_network(X, model, threshold=0.5): + model.eval() + X = torch.tensor(X, dtype=torch.float32) + y_pred = model.net(X) + return np.array(y_pred > threshold, dtype=int) + + +# Custom Datasets +class ProposalDataset(Dataset): + def __init__(self, inputs, labels, transform=None, target_transform=None): + self.inputs = inputs.astype(np.float32) self.labels = labels.astype(np.float32) def __len__(self): - return len(self.data) + return len(self.labels) def __getitem__(self, idx): - return {"data": self.data[idx], "label": self.labels[idx]} + return {"inputs": self.inputs[idx], "labels": self.labels[idx]} + +class ImgProposalDataset(Dataset): + def __init__(self, inputs, labels, transform=True): + self.inputs = self.reformat(inputs) + self.labels = self.reformat(labels) + if transform: + self.transform = Augmentator() + self.transform_bool = True if transform else False + def __len__(self): + return len(self.labels) + + def __getitem__(self, idx): + if self.transform_bool: + inputs = utils.normalize(self.inputs[idx]) + inputs = self.transform.run(inputs) + else: + inputs = self.inputs[idx] + return {"inputs": inputs, "labels": self.labels[idx]} + + def reformat(self, x): + return np.expand_dims(x, axis=1).astype(np.float32) + + +class Augmentator: + def __init__(self): + self.blur = tio.RandomBlur(std=(0, 0.5)) # 1 + self.noise = tio.RandomNoise(std=(0, 0.03)) + self.elastic = tio.RandomElasticDeformation(max_displacement=10) + self.apply_geometric = tio.Compose({ + #tio.RandomFlip(axes=(0, 1, 2)), + tio.RandomAffine(degrees=30, scales=(0.8, 1)), + }) + + def run(self, arr): + arr = self.blur(arr) + arr = self.noise(arr) + #arr = self.elastic(arr) + arr = self.apply_geometric(arr) + return arr + + +# Neural Network Training class LitNeuralNet(pl.LightningModule): def __init__(self, net): super().__init__() self.net = net + def forward(self, batch): + x = self.get_example(batch, "inputs") + return self.net(x) + def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + optimizer = torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-5) return optimizer def training_step(self, batch, batch_idx): - x = self.get_example(batch, "data") - y = self.get_example(batch, "label") - y_hat = self.net(x) - return F.mse_loss(y_hat, y) - - def test_step(self, batch, batch_idx): - x = self.get_example(batch, "data") - y = self.get_example(batch, "label") - y_hat = self.net(x) - self.compute_stats(y_hat, y) - - def compute_stats(self, y_hat, y): + X = self.get_example(batch, "inputs") + y = self.get_example(batch, "labels") + y_hat = self.net(X) + loss = F.binary_cross_entropy_with_logits(y_hat, y) + self.log("train_loss", loss) + self.compute_stats(y_hat, y, prefix="train_") + return loss + + def validation_step(self, batch, batch_idx): + X = self.get_example(batch, "inputs") + y = self.get_example(batch, "labels") + y_hat = self.net(X) + self.compute_stats(y_hat, y, prefix="val_") + + def compute_stats(self, y_hat, y, prefix=""): y_hat = torch.flatten(y_hat) y = torch.flatten(y).to(torch.int) - self.log("accuracy", binary_accuracy(y_hat, y)) - self.log("precision", binary_precision(y_hat, y)) - self.log("recall", binary_recall(y_hat, y)) - self.log("f1", binary_f1_score(y_hat, y)) + self.log(prefix + "accuracy", binary_accuracy(y_hat, y)) + self.log(prefix + "precision", binary_precision(y_hat, y)) + self.log(prefix + "recall", binary_recall(y_hat, y)) + self.log(prefix + "f1", binary_f1_score(y_hat, y)) def get_example(self, batch, key): - return batch[key].view(batch[key].size(0), -1) + return batch[key] +""" def train(model, optimizer, criterion, train_data): model.train() optimizer.zero_grad() @@ -116,7 +214,7 @@ def test(model, data): return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy()) -""" + from torch_geometric.datasets import Planetoid from torch_geometric.nn import GCNConv import torch_geometric.transforms as T diff --git a/src/deep_neurographs/utils.py b/src/deep_neurographs/utils.py index d2269e8..21a6685 100644 --- a/src/deep_neurographs/utils.py +++ b/src/deep_neurographs/utils.py @@ -235,6 +235,12 @@ def subplot(data1, data2, title): # --- miscellaneous --- +def normalize(img): + img -= np.min(img) + img = img / np.max(img) + return img + + def to_world(xyz, anisotropy, shift=[0, 0, 0]): return tuple([int((xyz[i] - shift[i]) * anisotropy[i]) for i in range(3)])