Skip to content

Commit

Permalink
added documentation and reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Nov 6, 2023
1 parent 4b70d8f commit 54b3654
Show file tree
Hide file tree
Showing 8 changed files with 431 additions and 33 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta"

[project]
name = "deep-neurographs"
description = "Generated from aind-library-template"
description = "Neuron reconstruction framework that detects and corrects false splits in a predicted segmentation."
license = {text = "MIT"}
requires-python = ">=3.7"
authors = [
{name = "Allen Institute for Neural Dynamics"}
{name = "Anna Grim"}
]
classifiers = [
"Programming Language :: Python :: 3"
Expand Down
193 changes: 190 additions & 3 deletions src/deep_neurographs/deep_learning/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
import torchio as tio
from torch.utils.data import Dataset

from deep_neurographs import utils


Expand All @@ -21,14 +22,59 @@ class ProposalDataset(Dataset):
proposals.
"""

def __init__(self, inputs, labels):
    """
    Builds a dataset of edge proposal feature vectors and their labels.

    Parameters
    ----------
    inputs : np.array
        Feature matrix in which row i is the feature vector of the i-th
        edge proposal. Cast to float32 before being stored.
    labels : np.array
        Binary vector whose i-th entry says whether the i-th edge proposal
        should be added to or omitted from a reconstruction. Stored after
        being passed through "reformat".

    Returns
    -------
    None

    """
    feature_matrix = inputs.astype(np.float32)
    self.inputs = feature_matrix
    self.labels = reformat(labels)

def __len__(self):
"""
Computes number of examples in dataset.
Parameters
----------
None
Returns
-------
int
Number of examples in dataset.
"""
return len(self.labels)

def __getitem__(self, idx):
"""
Gets example (i.e. input and label) corresponding to "idx".
Parameters
----------
idx : int
Index of example to be returned.
Returns
-------
dict
Example corresponding to "idx".
"""
return {"inputs": self.inputs[idx], "labels": self.labels[idx]}


Expand All @@ -38,15 +84,64 @@ class ImgProposalDataset(Dataset):
proposals.
"""

def __init__(self, inputs, labels, transform=True):
    """
    Builds a dataset of image chunks around edge proposals and their labels.

    Parameters
    ----------
    inputs : numpy.array
        Feature tensor in which each submatrix is an image chunk that
        contains an edge proposal; the midpoint of the proposal is the
        center point of the chunk. Cast to float32 before being stored.
    labels : np.array
        Binary vector whose i-th entry says whether the i-th edge proposal
        should be added to or omitted from a reconstruction. Stored after
        being passed through "reformat".
    transform : bool, optional
        Whether data augmentation should be applied to the inputs. When
        truthy an "Augmentator" is created, otherwise no transform is
        stored. The default is True.

    Returns
    -------
    None

    """
    img_chunks = inputs.astype(np.float32)
    self.inputs = img_chunks
    self.labels = reformat(labels)
    if transform:
        self.transform = Augmentator()
    else:
        self.transform = None

def __len__(self):
"""
Computes number of examples in dataset.
Parameters
----------
None
Returns
-------
int
Number of examples in dataset.
"""
return len(self.labels)

def __getitem__(self, idx):
"""
Gets example (i.e. input and label) corresponding to "idx".
Parameters
----------
idx : int
Index of example to be returned.
Returns
-------
dict
Example corresponding to "idx".
"""
if self.transform:
inputs = utils.normalize_img(self.inputs[idx])
inputs = self.transform.run(inputs)
Expand All @@ -61,16 +156,66 @@ class MultiModalDataset(Dataset):
chunks that correspond to edge proposals.
"""

def __init__(self, inputs, labels, transform=True):
    """
    Builds a dataset that pairs image chunks with feature vectors.

    Parameters
    ----------
    inputs : dict
        Feature dictionary with the keys (1) "imgs" and (2) "features",
        which hold (1) a feature tensor containing image chunks and (2) the
        feature vectors. Both values are cast to float32 before being
        stored.
    labels : np.array
        Binary vector whose i-th entry says whether the i-th edge proposal
        should be added to or omitted from a reconstruction. Stored after
        being passed through "reformat".
    transform : bool, optional
        Whether augmentation should be applied to "inputs["imgs"]". When
        truthy an "Augmentator" is created, otherwise no transform is
        stored. The default is True.

    Returns
    -------
    None

    """
    img_chunks = inputs["imgs"]
    feature_vecs = inputs["features"]
    self.img_inputs = img_chunks.astype(np.float32)
    self.feature_inputs = feature_vecs.astype(np.float32)
    self.labels = reformat(labels)
    if transform:
        self.transform = Augmentator()
    else:
        self.transform = None

def __len__(self):
"""
Computes number of examples in dataset.
Parameters
----------
None
Returns
-------
int
Number of examples in dataset.
"""
return len(self.labels)

def __getitem__(self, idx):
"""
Gets example (i.e. input and label) corresponding to "idx".
Parameters
----------
idx : int
Index of example to be returned.
Returns
-------
dict
Example corresponding to "idx".
"""
if self.transform:
img_inputs = utils.normalize_img(self.img_inputs[idx])
img_inputs = self.transform.run(img_inputs)
Expand All @@ -86,24 +231,66 @@ class Augmentator:
Applies augmentation to an image chunk.
"""

def __init__(self):
    """
    Constructs an Augmentator object.

    Sets up the three torchio transforms that "run" applies in sequence:
    a random blur, random noise, and a composed geometric transform that
    currently consists of a single random affine transform.

    Parameters
    ----------
    None

    Returns
    -------
    None

    """
    self.blur = tio.RandomBlur(std=(0, 0.4))
    self.noise = tio.RandomNoise(std=(0, 0.03))
    # The transforms are passed as a list rather than a set so that the
    # order in which they are applied stays well defined if more transforms
    # are added to the composition.
    self.apply_geometric = tio.Compose(
        [
            # Random flips are currently disabled:
            # tio.RandomFlip(axes=(0, 1, 2)),
            tio.RandomAffine(
                degrees=20, scales=(0.8, 1), image_interpolation="nearest"
            ),
        ]
    )

def run(self, arr):
"""
Runs an image through the data augmentation pipeline.
Parameters
----------
arr : torch.array
Array that contains an image chunk.
Returns
-------
torch.array
Transformed array after being run through augmentation pipeline.
"""
arr = self.blur(arr)
arr = self.noise(arr)
arr = self.apply_geometric(arr)
return arr


def reformat(arr):
    """
    Reformats a label vector for training by adding a dimension and casting it
    to float32.

    Parameters
    ----------
    arr : numpy.ndarray
        Label vector.

    Returns
    -------
    numpy.ndarray
        Label vector with a new axis inserted at position 1 (e.g. shape
        (n,) becomes (n, 1)) and dtype float32.

    """
    return np.expand_dims(arr, axis=1).astype(np.float32)
Loading

0 comments on commit 54b3654

Please sign in to comment.