diff --git a/src/aind_exaspim_soma_detection/training/data_handling.py b/src/aind_exaspim_soma_detection/training/data_handling.py
deleted file mode 100644
index f62d2c8..0000000
--- a/src/aind_exaspim_soma_detection/training/data_handling.py
+++ /dev/null
@@ -1,407 +0,0 @@
-"""
-Created on Thu Dec 5 14:00:00 2024
-
-@author: Anna Grim
-@email: anna.grim@alleninstitute.org
-
-Routines for loading data and applying data augmentation (if applicable).
-
-"""
-
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from scipy.ndimage import rotate
-from torch.utils.data import Dataset
-
-import numpy as np
-import os
-import random
-import torch
-import torchvision.transforms as transforms
-
-from aind_exaspim_soma_detection.utils import img_util, util
-
-
-# --- Custom Dataset ---
-class SomaDataset(Dataset):
-    """
-    A custom dataset used to train a neural network to classify soma
-    proposals as either accepted or rejected. The dataset is initialized by
-    providing the following inputs:
-        (1) Path to a whole brain dataset stored in an S3 bucket
-        (2) List of voxel coordinates representing soma proposals
-        (3) Optionally, labels for each proposal (i.e. accept or reject)
-
-    Note: This dataset supports inputs from multiple whole brain datasets.
-
-    """
-
-    def __init__(self, patch_shape, transform=False):
-        # Initialize class attributes
-        self.examples = dict()  # key: (brain_id, voxel), value: label
-        self.imgs = dict()  # key: brain_id, value: image
-        self.img_paths = dict()
-        self.patch_shape = patch_shape
-
-        # Data augmentation
-        if transform:
-            self.transform = transforms.Compose(
-                [
-                    RandomFlip3D(),
-                    RandomNoise3D(),
-                    RandomRotation3D(angles=(-30, 30)),
-                    RandomContrast3D(factor_range=(0.7, 1.3)),
-                    lambda x: torch.tensor(x, dtype=torch.float32).unsqueeze(0),
-                ]
-            )
-        else:
-            self.transform = transform
-
-    def n_examples(self):
-        """
-        Counts the number of examples in the dataset.
-
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        int
-            Number of examples in the dataset.
-
-        """
-        return len(self.examples)
-
-    def n_positives(self):
-        """
-        Counts the number of positive examples in the dataset.
-
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        int
-            Number of positive examples in the dataset.
-
-        """
-        return len(self.get_positives())
-
-    def n_negatives(self):
-        """
-        Counts the number of negative examples in the dataset.
-
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        int
-            Number of negative examples in the dataset.
-
-        """
-        return len(self.get_negatives())
-
-    def get_positives(self):
-        """
-        Gets all positive examples in the dataset.
-
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        dict
-            Positive examples in dataset.
-
-        """
-        return {k: v for k, v in self.examples.items() if v}
-
-    def get_negatives(self):
-        """
-        Gets all negative examples in the dataset.
-
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        dict
-            Negative examples in dataset.
- - """ - return dict({k: v for k, v in self.examples.items() if not v}) - - def __len__(self): - return len(self.examples) - - def __getitem__(self, key): - brain_id, voxel = key - img_patch = img_util.get_patch( - self.imgs[brain_id], voxel, self.patch_shape - ) - return img_patch / 2**15, self.examples[key] - - def ingest_examples(self, brain_id, img_prefix, proposals, labels=None): - # Load image - if brain_id not in self.imgs: - self.imgs[brain_id] = img_util.open_img(img_prefix) - self.img_paths[brain_id] = img_prefix - - # Check if labels are valid - if labels is not None: - assert len(proposals) == len(labels), "#proposals != #labels" - - # Load proposals - for i, voxel in enumerate(proposals): - key = (brain_id, tuple(voxel)) - self.examples[key] = labels[i] if labels else None - - def visualize_example(self, key): - img_patch, _ = self.__getitem__(key) - img_util.plot_mips(img_patch, clip_bool=True) - - def visualize_augmented_example(self, key): - # Get image patch - img_patch, _ = self.__getitem__(key) - img_util.plot_mips(img_patch, clip_bool=True) - - # Apply transforms - img_patch = np.array(self.transform(img_patch)) - img_util.plot_mips(img_patch[0, ...], clip_bool=True) - - -# --- Data Augmentation --- -class RandomFlip3D: - """ - Randomly flip a 3D image along one or more axes. - - """ - - def __init__(self, axes=(0, 1, 2)): - self.axes = axes - - def __call__(self, img): - for axis in self.axes: - if random.random() > 0.5: - img = np.flip(img, axis=axis) - return img - - -class RandomNoise3D: - """ - Adds random Gaussian noise to a 3D image. - - """ - - def __init__(self, mean=0.0, std=0.001): - self.mean = mean - self.std = std - - def __call__(self, img): - noise = np.random.normal(self.mean, self.std, img.shape) - return img + noise - - -class RandomRotation3D: - """ - Applies random rotation to a 3D image along a randomly chosen axis. - - """ - - def __init__(self, angles=(-20, 20), axes=((0, 1), (0, 2), (1, 2))): - self.angles = angles - self.axes = axes - - def __call__(self, img): - for _ in range(2): - angle = random.uniform(*self.angles) - axis = random.choice(self.axes) - img = rotate(img, angle, axes=axis, reshape=False, order=1) - return img - - -class RandomContrast3D: - """ - Adjusts the contrast of a 3D image by scaling voxel intensities. - - """ - - def __init__(self, factor_range=(0.7, 1.3)): - self.factor_range = factor_range - - def __call__(self, img): - factor = random.uniform(*self.factor_range) - return np.clip(img * factor, img.min(), img.max()) - - -# --- Custom Dataloader --- -class MultiThreadedDataLoader: - """ - DataLoader that uses multithreading to fetch image patches from the cloud - to form batches. - - """ - - def __init__(self, dataset, batch_size): - """ - Constructs a multithreaded data loading object. - - Parameters - ---------- - dataset : Dataset.SomaDataset - Instance of custom dataset. - batch_size : int - Number of samples per batch. 
-
-        Returns
-        -------
-        None
-
-        """
-        self.dataset = dataset
-        self.batch_size = batch_size
-
-    def __iter__(self):
-        return self.DataLoaderIterator(self)
-
-    class DataLoaderIterator:
-        def __init__(self, dataloader):
-            self.dataloader = dataloader
-            self.dataset = dataloader.dataset
-            self.batch_size = dataloader.batch_size
-            self.keys = list(self.dataset.examples.keys())
-            np.random.shuffle(self.keys)
-            self.current_index = 0
-
-        def __iter__(self):
-            return self
-
-        def __next__(self):
-            # Check whether to stop
-            if self.current_index >= len(self.keys):
-                raise StopIteration
-
-            # Get the next batch of keys
-            batch_keys = self.keys[
-                self.current_index: self.current_index + self.batch_size
-            ]
-            self.current_index += self.batch_size
-
-            # Load image patches
-            with ThreadPoolExecutor() as executor:
-                # Assign threads
-                threads = {
-                    executor.submit(self.dataset.__getitem__, idx): idx
-                    for idx in batch_keys
-                }
-
-                # Process results
-                patches = list()
-                labels = list()
-                for thread in as_completed(threads):
-                    patch, label = thread.result()
-                    patches.append(torch.tensor(patch, dtype=torch.float))
-                    labels.append(torch.tensor(label, dtype=torch.float))
-
-            # Reformat inputs
-            patches = torch.unsqueeze(torch.stack(patches), dim=1)
-            labels = torch.unsqueeze(torch.stack(labels), dim=1)
-            return patches, labels
-
-
-# --- Fetch Training Data ---
-def fetch_smartsheet_somas(smartsheet_path, img_prefixes_path, multiscale=0):
-    """
-    Fetches soma coordinates stored in a Smartsheet and reformats them for
-    training.
-
-    """
-    # Read data
-    soma_coords = util.extract_somas_from_smartsheet(smartsheet_path)
-    img_prefixes = util.read_json(img_prefixes_path)
-
-    # Reformat data
-    data = list()
-    for brain_id, xyz_list in soma_coords.items():
-        data.append(
-            reformat_data(brain_id, img_prefixes, multiscale, xyz_list, 1)
-        )
-    return data
-
-
-def fetch_exaspim_somas_2024(dataset_path, img_prefixes_path, multiscale=0):
-    """
-    Fetches and formats soma data for training from the exaSPIM dataset.
-
-    Parameters
-    ----------
-    dataset_path : str
-        Path to the dataset directory containing brain-specific
-        subdirectories with "accepts" and "rejects" folders.
-    img_prefixes_path : str
-        Path to a JSON file that maps brain IDs to image S3 prefixes.
-    multiscale : int, optional
-        Level in the image pyramid that the voxel coordinates must index
-        into. Default is 0.
-
-    Returns
-    -------
-    List[tuple]
-        List of tuples where each tuple contains the following:
-            - "brain_id" (str): Unique identifier for the brain.
-            - "img_path" (str): Path to the image stored in S3.
-            - "voxels" (list): Voxel coordinates of proposed somas.
-            - "labels" (list): Labels corresponding to voxels.
-
-    """
-    data = list()
-    img_prefixes = util.read_json(img_prefixes_path)
-    for brain_id in util.list_subdirectory_names(dataset_path):
-        # Read data
-        accepts_dir = os.path.join(dataset_path, brain_id, "accepts")
-        accepts_xyz = util.read_swc_dir(accepts_dir)
-
-        rejects_dir = os.path.join(dataset_path, brain_id, "rejects")
-        rejects_xyz = util.read_swc_dir(rejects_dir)
-
-        # Reformat data
-        data.append(
-            reformat_data(brain_id, img_prefixes, multiscale, accepts_xyz, 1)
-        )
-        data.append(
-            reformat_data(brain_id, img_prefixes, multiscale, rejects_xyz, 0)
-        )
-    return data
-
-
-def reformat_data(brain_id, img_prefixes, multiscale, xyz_list, label):
-    """
-    Reformats data for training or inference by converting xyz coordinates
-    to voxel coordinates and associating them with a brain ID, image path,
-    and labels.
-
-    Parameters
-    ----------
-    brain_id : str
-        Unique identifier for the whole brain dataset.
-    img_prefixes : dict
-        Dictionary mapping brain IDs to image S3 prefixes.
-    multiscale : int
-        Level in the image pyramid that the voxel coordinates must index
-        into.
-    xyz_list : List[ArrayLike]
-        List of 3D coordinates (x, y, z).
-    label : int
-        Label associated with the given coordinates (i.e. 1 for "accepts"
-        and 0 for "rejects").
-
-    Returns
-    -------
-    tuple
-        Tuple containing the "brain_id", "img_path", "voxels", and "labels".
-
-    """
-    img_path = img_prefixes[brain_id] + str(multiscale)
-    voxels = [img_util.to_voxels(xyz, multiscale) for xyz in xyz_list]
-    labels = len(voxels) * [label]
-    return (brain_id, img_path, voxels, labels)
diff --git a/src/aind_exaspim_soma_detection/training/models.py b/src/aind_exaspim_soma_detection/training/models.py
deleted file mode 100644
index 9189e8e..0000000
--- a/src/aind_exaspim_soma_detection/training/models.py
+++ /dev/null
@@ -1,156 +0,0 @@
-"""
-Created on Mon Dec 9 13:00:00 2024
-
-@author: Anna Grim
-@email: anna.grim@alleninstitute.org
-
-Neural network architectures that classify 3D image patches containing a
-soma proposal as accept or reject.
-
-"""
-
-import torch
-import torch.nn as nn
-
-
-class Fast3dCNN(nn.Module):
-    """
-    Fast 3D convolutional neural network that uses 2.5D convolutional layers
-    to reduce computational cost.
-
-    """
-
-    def __init__(self):
-        """
-        Constructs the neural network architecture.
-
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        None
-
-        """
-        super(Fast3dCNN, self).__init__()
-
-        # Convolutional Layer 1
-        self.layer1 = nn.Sequential(
-            FastConvLayer(1, 32),
-            nn.BatchNorm3d(32),
-            nn.ReLU(),
-            nn.MaxPool3d(kernel_size=2, stride=2),
-        )
-
-        # Convolutional Layer 2
-        self.layer2 = nn.Sequential(
-            FastConvLayer(32, 64),
-            nn.BatchNorm3d(64),
-            nn.ReLU(),
-            nn.MaxPool3d(kernel_size=2, stride=2),
-        )
-
-        # Convolutional Layer 3
-        self.layer3 = nn.Sequential(
-            FastConvLayer(64, 128),
-            nn.BatchNorm3d(128),
-            nn.ReLU(),
-            nn.MaxPool3d(kernel_size=2, stride=2),
-        )
-
-        # Final fully connected layers
-        self.output = nn.Sequential(
-            nn.Linear(128 * 8**3, 256),
-            nn.ReLU(),
-            nn.Dropout(0.5),
-            nn.Linear(256, 1),
-        )
-
-    def forward(self, x):
-        """
-        Forward pass of the 2.5D convolutional neural network.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input of shape (batch_size, in_channels, depth, height, width).
-
-        Returns
-        -------
-        torch.Tensor
-            Output of shape (batch_size, 1).
-
-        """
-        # Convolutional layers
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-
-        # Output layer
-        x = torch.flatten(x, start_dim=1)
-        x = self.output(x)
-        return x
-
-
-class FastConvLayer(nn.Module):
-    """
-    Performs a single layer of 2.5D convolution: slice-wise 2D convolutions
-    along each axis whose outputs are fused by a 3D convolution.
-
-    """
-
-    def __init__(self, in_channels, out_channels):
-        """
-        Constructs a 2.5D convolutional layer.
-
-        Parameters
-        ----------
-        in_channels : int
-            Number of input channels.
-        out_channels : int
-            Number of output channels produced by the fusing 3D convolution.
-
-        Returns
-        -------
-        None
-
-        """
-        super(FastConvLayer, self).__init__()
-        self.conv_2d = nn.Conv2d(in_channels, in_channels, 3, padding=1)
-        self.conv_3d = nn.Conv3d(3 * in_channels, out_channels, 3, padding=1)
-
-    def forward(self, x):
-        """
-        Forward pass of a single 2.5D convolution block.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input of shape (batch_size, in_channels, depth, height, width).
-
-        Returns
-        -------
-        torch.Tensor
-            Output of shape (batch_size, out_channels, depth, height,
-            width).
- """ - B, C, D, H, W = x.shape - - # Process XY slices - xy_slices = x.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) - xy_features = self.conv_2d(xy_slices).reshape(B, D, -1, H, W) - xy_features = xy_features.permute(0, 2, 1, 3, 4) - - # Process XZ slices - xz_slices = x.permute(0, 3, 1, 2, 4).reshape(B * H, C, D, W) - xz_features = self.conv_2d(xz_slices).reshape(B, H, -1, D, W) - xz_features = xz_features.permute(0, 2, 3, 1, 4) - - # Process YZ slices - yz_slices = x.permute(0, 4, 1, 2, 3).reshape(B * W, C, D, H) - yz_features = self.conv_2d(yz_slices).reshape(B, W, -1, D, H) - yz_features = yz_features.permute(0, 2, 3, 4, 1) - - # Fuse features using 3D convolution - combined = torch.cat([xy_features, xz_features, yz_features], dim=1) - output = self.conv_3d(combined) - return output diff --git a/src/aind_exaspim_soma_detection/training/trainer.py b/src/aind_exaspim_soma_detection/training/trainer.py deleted file mode 100644 index d4ff792..0000000 --- a/src/aind_exaspim_soma_detection/training/trainer.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -Created on Fri Jan 3 12:30:00 2025 - -@author: Anna Grim -@email: anna.grim@alleninstitute.org - -Code used to train neural network to classify somas proposals. - -""" - -from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score - - -def evaluation_metrics(epoch, writer, y, hat_y, prefix=""): - """ - Computes and logs various evaluation metrics to a TensorBoard. - - Parameters - ---------- - epoch : int - Current training epoch. Used as the x-axis value for logging metrics - in the TensorBoard. - writer : torch.utils.tensorboard.SummaryWriter - TensorBoard writer object to log the metrics. - y : ArrayLike - True labels or ground truth values. - hat_y : ArrayLike - Predicted labels from a model. - prefix : str, optional - String prefix to prepend to the metric names when logging to - TensorBoard. Default is an empty string. - - Returns - ------- - float - F1 score for the given epoch. - - """ - # Compute metrics - accuracy = accuracy_score(y, hat_y) - accuracy_dif = accuracy - np.sum(y) / len(y) - f1 = f1_score(y, hat_y) - precision = precision_score(y, hat_y) - recall = recall_score(y, hat_y) - - # Write results to tensorboard - writer.add_scalar(prefix + "_accuracy", accuracy, epoch) - writer.add_scalar(prefix + "_accuracy_df", accuracy_dif, epoch) - writer.add_scalar(prefix + "_precision:", precision, epoch) - writer.add_scalar(prefix + "_recall:", recall, epoch) - writer.add_scalar(prefix + "_f1:", f1, epoch) - return f1