From 132269534f35e081fd14dbf420f5e01c805ba851 Mon Sep 17 00:00:00 2001
From: Anna Grim <108307071+anna-grim@users.noreply.github.com>
Date: Mon, 6 Jan 2025 13:32:50 -0800
Subject: [PATCH] major updates

---
 .../machine_learning/data_handling.py | 358 ++++++++++++++++++
 .../machine_learning/models.py        | 159 ++++++++
 .../machine_learning/trainer.py       |  54 +++
 3 files changed, 571 insertions(+)
 create mode 100644 src/aind_exaspim_soma_detection/machine_learning/data_handling.py
 create mode 100644 src/aind_exaspim_soma_detection/machine_learning/models.py
 create mode 100644 src/aind_exaspim_soma_detection/machine_learning/trainer.py

diff --git a/src/aind_exaspim_soma_detection/machine_learning/data_handling.py b/src/aind_exaspim_soma_detection/machine_learning/data_handling.py
new file mode 100644
index 0000000..b050d5b
--- /dev/null
+++ b/src/aind_exaspim_soma_detection/machine_learning/data_handling.py
@@ -0,0 +1,358 @@
+"""
+Created on Thu Dec 5 14:00:00 2024
+
+@author: Anna Grim
+@email: anna.grim@alleninstitute.org
+
+Routines for loading data and applying data augmentation (if applicable).
+
+"""
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from scipy.ndimage import rotate, zoom
+from torch.utils.data import Dataset
+
+import numpy as np
+import random
+import torch
+import torchvision.transforms as transforms
+
+from aind_exaspim_soma_detection.utils import img_util
+
+
+# --- Custom Dataset ---
+class ProposalDataset(Dataset):
+    """
+    A custom dataset used to train a neural network to classify soma
+    proposals as either accept or reject. The dataset is initialized by
+    providing the following inputs:
+        (1) Path to a whole brain dataset stored in an S3 bucket
+        (2) List of voxel coordinates representing soma proposals
+        (3) Optionally, labels for each proposal (i.e. 0 or 1)
+
+    Note: This dataset supports inputs from multiple whole brain datasets.
+
+    """
+
+    def __init__(self, patch_shape, transform=False):
+        # Initialize class attributes
+        self.examples = dict()  # key: (brain_id, voxel), value: label
+        self.imgs = dict()  # key: brain_id, value: image
+        self.img_paths = dict()
+        self.patch_shape = patch_shape
+
+        # Data augmentation
+        if transform:
+            self.transform = transforms.Compose(
+                [
+                    RandomFlip3D(),
+                    RandomNoise3D(),
+                    RandomRotation3D(angles=(-30, 30)),
+                    RandomContrast3D(factor_range=(0.7, 1.3)),
+                    lambda x: torch.tensor(x, dtype=torch.float32).unsqueeze(
+                        0
+                    ),
+                ]
+            )
+        else:
+            self.transform = transform
+
+    def __len__(self):
+        return len(self.examples)
+
+    def __getitem__(self, key):
+        # Normalize raw 16-bit image intensities by 2**15
+        brain_id, voxel = key
+        img_patch = img_util.get_patch(
+            self.imgs[brain_id], voxel, self.patch_shape
+        )
+        return key, img_patch / 2**15, self.examples[key]
+
+    def get_positives(self):
+        """
+        Gets all positive examples in the dataset.
+
+        Parameters
+        ----------
+        None
+
+        Returns
+        -------
+        dict
+            Positive examples in dataset.
+
+        """
+        return {k: v for k, v in self.examples.items() if v}
+
+    def get_negatives(self):
+        """
+        Gets all negative examples in the dataset.
+
+        Parameters
+        ----------
+        None
+
+        Returns
+        -------
+        dict
+            Negative examples in dataset.
+
+        """
+        return {k: v for k, v in self.examples.items() if not v}
+
+    def n_positives(self):
+        """
+        Counts the number of positive examples in the dataset.
+
+        Parameters
+        ----------
+        None
+
+        Returns
+        -------
+        int
+            Number of positive examples in the dataset.
+
+        """
+        return len(self.get_positives())
+
+    def n_negatives(self):
+        """
+        Counts the number of negative examples in the dataset.
+
+        Parameters
+        ----------
+        None
+
+        Returns
+        -------
+        int
+            Number of negative examples in the dataset.
+
+        """
+        return len(self.get_negatives())
+
+    def ingest_examples(self, brain_id, img_prefix, proposals, labels=None):
+        """
+        Adds examples from a whole brain dataset, opening the image if it
+        has not been loaded yet.
+        """
+        # Sanity check
+        if labels is not None:
+            assert len(proposals) == len(labels), "#proposals != #labels"
+
+        # Load image (if applicable)
+        if brain_id not in self.imgs:
+            self.imgs[brain_id] = img_util.open_img(img_prefix)
+            self.img_paths[brain_id] = img_prefix
+
+        # Load proposal voxel coordinates
+        for i, voxel in enumerate(proposals):
+            key = (brain_id, tuple(voxel))
+            self.examples[key] = labels[i] if labels is not None else None
+
+    def remove_example(self, key):
+        """
+        Removes the example indexed by "key" from the dataset.
+        """
+        del self.examples[key]
+
+    def visualize_example(self, key):
+        """
+        Plots the maximum intensity projections of an image patch.
+        """
+        _, img_patch, _ = self.__getitem__(key)
+        img_util.plot_mips(img_patch, clip_bool=True)
+
+    def visualize_augmented_example(self, key):
+        """
+        Plots the maximum intensity projections of an image patch before
+        and after data augmentation.
+        """
+        # Get image patch
+        _, img_patch, _ = self.__getitem__(key)
+        img_util.plot_mips(img_patch, clip_bool=True)
+
+        # Apply transforms
+        img_patch = np.array(self.transform(img_patch))
+        img_util.plot_mips(img_patch[0, ...], clip_bool=True)
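+
+# Example usage (illustrative sketch; the brain id, S3 prefix, proposal
+# coordinate, and label below are hypothetical placeholders):
+#
+#   dataset = ProposalDataset(patch_shape=(64, 64, 64), transform=True)
+#   dataset.ingest_examples(
+#       "123456", "s3://bucket/123456/fused.zarr", [(100, 200, 300)], [1]
+#   )
+#   key, patch, label = dataset[("123456", (100, 200, 300))]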
+ + """ + return len(self.get_positives()) + + def n_negatives(self): + """ + Counts the number of negative examples in the dataset. + + Parameters + ---------- + None + + Returns + ------- + int + Number of negative examples in the dataset. + + """ + return len(self.get_negatives()) + + def ingest_examples(self, brain_id, img_prefix, proposals, labels=None): + # Sanity check + if labels is not None: + assert len(proposals) == len(labels), "#proposals != #labels" + + # Load image (if applicable) + if brain_id not in self.imgs: + self.imgs[brain_id] = img_util.open_img(img_prefix) + self.img_paths[brain_id] = img_prefix + + # Load proposal voxel coordinates + for i, voxel in enumerate(proposals): + key = (brain_id, tuple(voxel)) + self.examples[key] = labels[i] if labels else None + + def remove_example(self, key): + del self.examples[key] + + def visualize_example(self, key): + _, img_patch, _ = self.__getitem__(key) + img_util.plot_mips(img_patch, clip_bool=True) + + def visualize_augmented_example(self, key): + # Get image patch + _, img_patch, _ = self.__getitem__(key) + img_util.plot_mips(img_patch, clip_bool=True) + + # Apply transforms + img_patch = np.array(self.transform(img_patch)) + img_util.plot_mips(img_patch[0, ...], clip_bool=True) + + +# --- Data Augmentation --- +class RandomFlip3D: + """ + Randomly flip a 3D image along one or more axes. + + """ + + def __init__(self, axes=(0, 1, 2)): + self.axes = axes + + def __call__(self, img): + for axis in self.axes: + if random.random() > 0.5: + img = np.flip(img, axis=axis) + return img + + +class RandomNoise3D: + """ + Adds random Gaussian noise to a 3D image. + + """ + + def __init__(self, mean=0.0, std=0.001): + self.mean = mean + self.std = std + + def __call__(self, img): + noise = np.random.normal(self.mean, self.std, img.shape) + return img + noise + + +class RandomRotation3D: + """ + Applies random rotation to a 3D image along a randomly chosen axis. + + """ + + def __init__(self, angles=(-45, 45), axes=((0, 1), (0, 2), (1, 2))): + self.angles = angles + self.axes = axes + + def __call__(self, img): + for _ in range(2): + angle = random.uniform(*self.angles) + axis = random.choice(self.axes) + img = rotate(img, angle, axes=axis, reshape=False, order=1) + return img + + +class RandomContrast3D: + """ + Adjusts the contrast of a 3D image by scaling voxel intensities. + + """ + + def __init__(self, factor_range=(0.8, 1.2)): + self.factor_range = factor_range + + def __call__(self, img): + factor = random.uniform(*self.factor_range) + return np.clip(img * factor, img.min(), img.max()) + + +class RandomScale3D: + """ + Randomly scale a 3D image along all axes. + + """ + + def __init__(self, scale_range=(0.8, 1.2), axes=(0, 1, 2)): + self.scale_range = scale_range + self.axes = axes + + def __call__(self, img): + # Generate scaling factors + scales = list() + for _ in self.axes: + scales.append( + random.uniform(self.scale_range[0], self.scale_range[1]) + ) + + # Create a new shape by scaling the dimensions of the image + new_shape = list(img.shape) + for i, axis in enumerate(self.axes): + new_shape[axis] = int(img.shape[axis] * scales[i]) + return np.resize(img, new_shape) + + +# --- Custom Dataloader --- +class MultiThreadedDataLoader: + """ + DataLoader that uses multithreading to fetch image patches from the cloud + to form batches. + + """ + + def __init__(self, dataset, batch_size, return_keys=False): + """ + Constructs a multithreaded data loading object. 
+
+
+# --- utils ---
+def init_subdataset(dataset, positives, negatives, patch_shape, transform):
+    """
+    Initializes a subdataset from the given positive and negative examples.
+    """
+    subdataset = ProposalDataset(patch_shape, transform=transform)
+    for example_tuple in merge_examples(dataset, positives, negatives):
+        subdataset.ingest_examples(*example_tuple)
+    return subdataset
+
+
+def merge_examples(soma_dataset, positives, negatives):
+    """
+    Merges positive and negative examples into tuples that can be ingested
+    by a ProposalDataset.
+    """
+    examples = list()
+    combined_dict = positives.copy()
+    combined_dict.update(negatives)
+    for key, value in combined_dict.items():
+        brain_id, voxel = key
+        img_path = soma_dataset.img_paths[brain_id]
+        examples.append((brain_id, img_path, [voxel], [value]))
+    return examples
diff --git a/src/aind_exaspim_soma_detection/machine_learning/models.py b/src/aind_exaspim_soma_detection/machine_learning/models.py
new file mode 100644
index 0000000..7a5dc95
--- /dev/null
+++ b/src/aind_exaspim_soma_detection/machine_learning/models.py
@@ -0,0 +1,159 @@
+"""
+Created on Mon Dec 9 13:00:00 2024
+
+@author: Anna Grim
+@email: anna.grim@alleninstitute.org
+
+Neural network architectures that classify 3D image patches containing a
+soma proposal as accept or reject.
+
+"""
+
+import torch
+import torch.nn as nn
+
+
+class Fast3dCNN(nn.Module):
+    """
+    Fast 3D convolutional neural network that uses 2.5D convolutional
+    layers to reduce computational complexity.
+
+    """
+
+    def __init__(self, patch_shape):
+        """
+        Constructs the neural network architecture.
+
+        Parameters
+        ----------
+        patch_shape : Tuple[int]
+            Shape of image patches to be run through the network. Note: the
+            fully connected head assumes cubic patches whose side length is
+            divisible by 8 (due to the three max pooling layers).
+
+        Returns
+        -------
+        None
+
+        """
+        super(Fast3dCNN, self).__init__()
+        self.patch_shape = patch_shape
+
+        # Convolutional Layer 1
+        self.layer1 = nn.Sequential(
+            FastConvLayer(1, 32),
+            nn.BatchNorm3d(32),
+            nn.ReLU(),
+            nn.MaxPool3d(kernel_size=2, stride=2),
+        )
+
+        # Convolutional Layer 2
+        self.layer2 = nn.Sequential(
+            FastConvLayer(32, 64),
+            nn.BatchNorm3d(64),
+            nn.ReLU(),
+            nn.MaxPool3d(kernel_size=2, stride=2),
+        )
+
+        # Convolutional Layer 3
+        self.layer3 = nn.Sequential(
+            FastConvLayer(64, 128),
+            nn.BatchNorm3d(128),
+            nn.ReLU(),
+            nn.MaxPool3d(kernel_size=2, stride=2),
+        )
+
+        # Final fully connected layers
+        self.output = nn.Sequential(
+            nn.Linear(128 * (self.patch_shape[0] // 8) ** 3, 256),
+            nn.ReLU(),
+            nn.Dropout(0.25),
+            nn.Linear(256, 1),
+        )
+
+    def forward(self, x):
+        """
+        Forward pass of the 2.5D convolutional neural network.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input of shape (batch_size, in_channels, depth, height, width).
+
+        Returns
+        -------
+        torch.Tensor
+            Output with shape (batch_size, 1).
+
+        """
+        # Convolutional Layers
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+
+        # Output layer
+        x = torch.flatten(x, start_dim=1)
+        x = self.output(x)
+        return x
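+
+# Shape sanity check (illustrative sketch; assumes a cubic 64-voxel patch):
+#
+#   model = Fast3dCNN(patch_shape=(64, 64, 64))
+#   logits = model(torch.rand(2, 1, 64, 64, 64))  # shape: (2, 1)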
+
+
+class FastConvLayer(nn.Module):
+    """
+    Single 2.5D convolutional layer that applies 2D convolutions to the
+    XY, XZ, and YZ slices of the input, then fuses the results with a 3D
+    convolution.
+
+    """
+
+    def __init__(self, in_channels, out_channels):
+        """
+        Constructs a 2.5D convolutional layer.
+
+        Parameters
+        ----------
+        in_channels : int
+            Number of input channels.
+        out_channels : int
+            Number of output channels produced by the fusing 3D
+            convolution.
+
+        Returns
+        -------
+        None
+
+        """
+        super(FastConvLayer, self).__init__()
+        self.conv_2d = nn.Conv2d(in_channels, in_channels, 3, padding=1)
+        self.conv_3d = nn.Conv3d(3 * in_channels, out_channels, 3, padding=1)
+
+    def forward(self, x):
+        """
+        Forward pass of a single 2.5D convolution block.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input of shape (batch_size, in_channels, depth, height, width).
+
+        Returns
+        -------
+        torch.Tensor
+            Output of shape (batch_size, out_channels, depth, height,
+            width).
+
+        """
+        B, C, D, H, W = x.shape
+
+        # Process XY slices
+        xy_slices = x.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W)
+        xy_features = self.conv_2d(xy_slices).reshape(B, D, -1, H, W)
+        xy_features = xy_features.permute(0, 2, 1, 3, 4)
+
+        # Process XZ slices
+        xz_slices = x.permute(0, 3, 1, 2, 4).reshape(B * H, C, D, W)
+        xz_features = self.conv_2d(xz_slices).reshape(B, H, -1, D, W)
+        xz_features = xz_features.permute(0, 2, 3, 1, 4)
+
+        # Process YZ slices
+        yz_slices = x.permute(0, 4, 1, 2, 3).reshape(B * W, C, D, H)
+        yz_features = self.conv_2d(yz_slices).reshape(B, W, -1, D, H)
+        yz_features = yz_features.permute(0, 2, 3, 4, 1)
+
+        # Fuse features using 3D convolution
+        combined = torch.cat([xy_features, xz_features, yz_features], dim=1)
+        output = self.conv_3d(combined)
+        return output
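+
+# Standalone shape check for the 2.5D layer (illustrative sketch):
+#
+#   layer = FastConvLayer(in_channels=1, out_channels=32)
+#   out = layer(torch.rand(2, 1, 16, 16, 16))  # shape: (2, 32, 16, 16, 16)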
+ +""" + +from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score + +import numpy as np + + +def evaluation_metrics(epoch, writer, y, hat_y, prefix=""): + """ + Computes and logs various evaluation metrics to a TensorBoard. + + Parameters + ---------- + epoch : int + Current training epoch. Used as the x-axis value for logging metrics + in the TensorBoard. + writer : torch.utils.tensorboard.SummaryWriter + TensorBoard writer object to log the metrics. + y : ArrayLike + True labels or ground truth values. + hat_y : ArrayLike + Predicted labels from a model. + prefix : str, optional + String prefix to prepend to the metric names when logging to + TensorBoard. Default is an empty string. + + Returns + ------- + float + F1 score for the given epoch. + + """ + # Compute metrics + accuracy = accuracy_score(y, hat_y) + accuracy_dif = accuracy - np.sum(y) / len(y) + f1 = f1_score(y, hat_y) + precision = precision_score(y, hat_y) + recall = recall_score(y, hat_y) + + # Write results to tensorboard + writer.add_scalar(prefix + "_accuracy", accuracy, epoch) + writer.add_scalar(prefix + "_accuracy_df", accuracy_dif, epoch) + writer.add_scalar(prefix + "_precision:", precision, epoch) + writer.add_scalar(prefix + "_recall:", recall, epoch) + writer.add_scalar(prefix + "_f1:", f1, epoch) + return f1