Skip to content

Commit

Permalink
major upd : improve ground truth generation
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Dec 2, 2023
1 parent acd0825 commit 19028a9
Show file tree
Hide file tree
Showing 9 changed files with 345 additions and 269 deletions.
4 changes: 1 addition & 3 deletions src/deep_neurographs/deep_learning/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import torchio as tio
from torch.utils.data import Dataset

from deep_neurographs import utils


# Custom datasets
class ProposalDataset(Dataset):
Expand Down Expand Up @@ -250,7 +248,7 @@ def __init__(self):
tio.RandomFlip(axes=(0, 1, 2)),
tio.RandomAffine(
degrees=20, scales=(0.8, 1), image_interpolation="nearest"
)
),
]
)

Expand Down
15 changes: 6 additions & 9 deletions src/deep_neurographs/deep_learning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.profilers import PyTorchProfiler
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
from torch.utils.data import DataLoader
from torch.nn.functional import sigmoid
from torch.utils.data import DataLoader
from torcheval.metrics.functional import (
binary_accuracy,
binary_f1_score,
binary_precision,
binary_recall,
Expand Down Expand Up @@ -113,9 +112,7 @@ def train_network(

# Configure trainer
model = LitNeuralNet(net=net, lr=lr)
ckpt_callback = ModelCheckpoint(
save_top_k=1, monitor="val_f1", mode="max"
)
ckpt_callback = ModelCheckpoint(save_top_k=1, monitor="val_f1", mode="max")
profiler = PyTorchProfiler() if profile else None

# Fit model
Expand All @@ -131,7 +128,7 @@ def train_network(
profiler=profiler,
)
trainer.fit(model, train_loader, valid_loader)

# Return best model
ckpt = torch.load(ckpt_callback.best_model_path)
model.net.load_state_dict(ckpt["state_dict"])
Expand All @@ -158,7 +155,7 @@ def __init__(self, net=None, lr=10e-3):
super().__init__()
self.net = net
self.lr = lr

def forward(self, batch):
    """
    Runs the wrapped network on the "inputs" entry of "batch".

    Parameters
    ----------
    batch : subscriptable
        Batch from which "self.get_example" extracts the entry stored
        under the key "inputs".

    Returns
    -------
    Output of "self.net" applied to the extracted inputs.

    """
    return self.net(self.get_example(batch, "inputs"))
Expand Down Expand Up @@ -192,5 +189,5 @@ def compute_stats(self, y_hat, y, prefix=""):
def get_example(self, batch, key):
    """
    Looks up a single entry of "batch".

    Parameters
    ----------
    batch : subscriptable
        Container holding the examples of one batch.
    key : hashable
        Key (or index) of the entry to be returned.

    Returns
    -------
    The value stored at "batch[key]".

    """
    example = batch[key]
    return example

def state_dict(self, destination=None, prefix="", keep_vars=False):
    """
    Returns the state dict of the wrapped network "self.net" instead of
    the state dict of this wrapper, so the saved keys can be loaded
    directly into the network.

    Parameters
    ----------
    destination : dict, optional
        Mapping that the state is written into. The default is None.
    prefix : str, optional
        Prefix prepended to every key. The default is "".
    keep_vars : bool, optional
        Forwarded unchanged to "self.net.state_dict". The default is
        False.

    Returns
    -------
    dict
        Result of "self.net.state_dict".

    """
    # Forward by keyword: torch.nn.Module.state_dict documents these as
    # keyword arguments (positional use is discouraged upstream).
    return self.net.state_dict(
        destination=destination, prefix=prefix, keep_vars=keep_vars
    )
221 changes: 176 additions & 45 deletions src/deep_neurographs/densegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,68 @@
@author: Anna Grim
@email: [email protected]
Class of graphs that are built from swc files.
Class of graphs built from swc files. Each swc file is stored as a distinct
graph, and each node in this graph stores its "xyz" coordinate.
"""

import os

import networkx as nx
import numpy as np
from more_itertools import zip_broadcast
from scipy.spatial import KDTree

from deep_neurographs import swc_utils, utils
from deep_neurographs.geometry_utils import dist
from deep_neurographs.geometry_utils import dist as get_dist


class DenseGraph:
"""
Class of graphs built from swc files. Each swc file is stored as a
distinct graph, and each node in this graph stores its "xyz" coordinate.
"""

def __init__(self, swc_dir):
    """
    Builds a DenseGraph from every swc file found in "swc_dir".

    Parameters
    ----------
    swc_dir : path
        Directory of swc files. Each file becomes one graph, stored in a
        hash table whose entries are filename-graph pairs.

    Returns
    -------
    None

    """
    # Lookup tables that are filled in while the graphs are read:
    #   xyz_to_node -- graph id -> {xyz coordinate -> node id}
    #   xyz_to_swc  -- xyz coordinate -> id of the graph containing it
    self.xyz_to_node, self.xyz_to_swc = {}, {}

    # The KD-tree is built from "xyz_to_swc", so graphs must come first
    self.init_graphs(swc_dir)
    self.init_kdtree()

def init_graphs(self, swc_dir):
"""
Initializes graphs by reading swc files in "swc_dir". Graphs are
stored in a hash table where the entries are filename-graph pairs.
Parameters
----------
swc_dir : path
Path to directory of swc files which are used to construct a hash
table in which the entries are filename-graph pairs.
Returns
-------
None
"""
self.graphs = dict()
for f in utils.listdir(swc_dir, ext=".swc"):
# Extract info
path = os.path.join(swc_dir, f)

# Construct Graph
path = os.path.join(swc_dir, f)
swc_dict = swc_utils.parse(swc_utils.read_swc(path))
graph, xyz_to_node = swc_utils.file_to_graph(
swc_dict, set_attrs=True, return_dict=True
Expand All @@ -45,49 +78,147 @@ def init_graphs(self, swc_dir):
self.xyz_to_swc.update(xyz_to_id)

def init_kdtree(self):
    """
    Builds "self.kdtree" from every xyz coordinate registered in
    "self.xyz_to_swc", i.e. from the coordinates of all graphs.

    Parameters
    ----------
    None

    Returns
    -------
    None

    """
    # Iterating a dict yields its keys, which are the xyz coordinates
    coordinates = list(self.xyz_to_swc)
    self.kdtree = KDTree(coordinates)

def get_projection(self, xyz):
    """
    Projects "xyz" onto the nearest coordinate stored in "self.kdtree".

    Parameters
    ----------
    xyz : tuple[float]
        Coordinate to be projected.

    Returns
    -------
    tuple[float]
        Nearest coordinate found in "self.kdtree".
    float
        Distance between the projection and "xyz" as computed by "dist".

    """
    nearest = self.kdtree.query(xyz, k=1)[1]
    proj_xyz = tuple(self.kdtree.data[nearest])
    return proj_xyz, dist(proj_xyz, xyz)
def query_kdtree(self, xyz):
    """
    Finds the nearest neighbor of "xyz" among the coordinates stored in
    "self.kdtree".

    Parameters
    ----------
    xyz : tuple[float]
        Coordinate to be queried.

    Returns
    -------
    tuple[float]
        Coordinate in "self.kdtree" that is closest to "xyz".

    """
    # query returns (distance, index); only the index is needed here
    nearest = self.kdtree.query(xyz, k=1)[1]
    return tuple(self.kdtree.data[nearest])

def is_connected(self, xyz_1, xyz_2):
    """
    Checks whether "xyz_1" and "xyz_2" are registered under the same swc
    file (i.e. graph) in "self.xyz_to_swc".

    Parameters
    ----------
    xyz_1 : tuple[float]
        Coordinate contained in some graph in "self.graphs".
    xyz_2 : tuple[float]
        Coordinate contained in some graph in "self.graphs".

    Returns
    -------
    bool
        True if both coordinates belong to the same swc file, False
        otherwise.

    """
    return bool(self.xyz_to_swc[xyz_1] == self.xyz_to_swc[xyz_2])

def connect_nodes(self, xyz_1, xyz_2, return_dist=True):
    """
    Finds the path connecting two points that belong to the same graph in
    "self.graphs".

    Parameters
    ----------
    xyz_1 : tuple[float]
        Source of path. Its graph is looked up in "self.xyz_to_swc".
    xyz_2 : tuple[float]
        Target of path. Must belong to the same graph as "xyz_1".
    return_dist : bool, optional
        Indication of whether to also return the length of the path. The
        default is True.

    Returns
    -------
    list[int]
        Path of nodes connecting source and target.
    float
        Length of path with respect to l2-metric (only returned if
        "return_dist" is True).

    Raises
    ------
    KeyError
        If "xyz_1" is not registered in "self.xyz_to_swc" or "xyz_2" is
        not a node of the graph that contains "xyz_1".

    """
    graph_id = self.xyz_to_swc[xyz_1]
    xyz_to_node = self.xyz_to_node[graph_id]
    path = nx.shortest_path(
        self.graphs[graph_id],
        source=xyz_to_node[xyz_1],
        target=xyz_to_node[xyz_2],
    )
    if return_dist:
        return path, self.path_length(graph_id, path)
    return path

def compute_dist(self, graph_id, path):
    """
    Alias of "path_length", kept so that callers using the previous name
    keep working.

    Parameters
    ----------
    graph_id : str
        ID of graph that path belongs to.
    path : list[int]
        List of nodes that form a path.

    Returns
    -------
    float
        Length of path with respect to l2-metric.

    """
    return self.path_length(graph_id, path)

def path_length(self, graph_id, path):
    """
    Computes length of path with respect to the l2-metric.

    Parameters
    ----------
    graph_id : str
        ID of graph that path belongs to.
    path : list[int]
        List of nodes that form a path.

    Returns
    -------
    float
        Length of path with respect to l2-metric. A path with fewer than
        two nodes has length 0.

    """
    nodes = self.graphs[graph_id].nodes
    length = 0
    # Walk consecutive node pairs. The node ids come from "path" itself;
    # the position within the path is not a node id.
    for prev_node, node in zip(path, path[1:]):
        length += get_dist(nodes[node]["xyz"], nodes[prev_node]["xyz"])
    return length

def check_aligned(self, pred_xyz_i, pred_xyz_j):
    """
    Determines whether the edge proposal between "pred_xyz_i" and
    "pred_xyz_j" is aligned to the ground truth. Both endpoints are
    projected onto the ground truth graphs; the proposal is rejected if
    the projections lie in different graphs, or if it is longer than 10
    and the ratio between its length and the length of the ground truth
    path is below 0.5.

    Parameters
    ----------
    pred_xyz_i : array-like
        Endpoint of edge proposal.
    pred_xyz_j : array-like
        Endpoint of edge proposal.

    Returns
    -------
    bool
        Indication of whether edge proposal is aligned to ground truth.

    """
    # Get target graph
    xyz_i, _ = self.get_projection(pred_xyz_i)
    xyz_j, _ = self.get_projection(pred_xyz_j)
    if self.xyz_to_swc[xyz_i] != self.xyz_to_swc[xyz_j]:
        return False

    # Short proposals are always accepted; this also keeps the ratio
    # below from ever dividing by zero.
    pred_dist = dist(np.array(pred_xyz_i), np.array(pred_xyz_j))
    if pred_dist <= 10:
        return True

    # Check criteria (connect_nodes takes the two projected coordinates)
    _, target_dist = self.connect_nodes(xyz_i, xyz_j)
    ratio = min(pred_dist, target_dist) / max(pred_dist, target_dist)
    return bool(ratio >= 0.5)

def is_aligned(self, xyz_1, xyz_2, ratio_threshold=0.5, exclude=10.0):
    """
    Determines whether the edge proposal corresponding to "xyz_1" and
    "xyz_2" is aligned to the ground truth. This is determined by checking
    two conditions: (1) connectedness and (2) distance ratio. For (1), we
    project "xyz_1" and "xyz_2" onto "self.graphs", then verify that they
    project to the same graph. For (2), we compute the ratio between the
    Euclidean distance "dist" from "xyz_1" to "xyz_2" and the path length
    between the corresponding projections. This ratio can be skewed if
    "dist" is small, so we skip this criterion if "dist" <= "exclude".

    Parameters
    ----------
    xyz_1 : numpy.array
        Endpoint of edge proposal.
    xyz_2 : numpy.array
        Endpoint of edge proposal.
    ratio_threshold : float
        Lower bound that the ratio between "dist" and the path length
        must exceed for the proposal to be accepted.
    exclude : float
        Upper bound on "dist" below which criterion (2) is ignored.

    Returns
    -------
    bool
        Indication of whether edge proposal is aligned to ground truth.

    """
    # Criterion (1): both endpoints must project onto the same graph
    hat_xyz_1 = self.query_kdtree(xyz_1)
    hat_xyz_2 = self.query_kdtree(xyz_2)
    if not self.is_connected(hat_xyz_1, hat_xyz_2):
        return False

    # Short proposals skip criterion (2); deciding this first avoids an
    # unnecessary shortest-path search.
    dist = get_dist(xyz_1, xyz_2)
    if dist <= exclude:
        return True

    # Criterion (2): proposal length vs. ground truth path length
    _, path_length = self.connect_nodes(hat_xyz_1, hat_xyz_2)
    longest = max(dist, path_length)
    if longest == 0:
        # Both lengths are zero (only reachable with a negative
        # "exclude"); the lengths agree exactly.
        ratio = 1.0
    else:
        ratio = min(dist, path_length) / longest
    return bool(ratio > ratio_threshold)
5 changes: 2 additions & 3 deletions src/deep_neurographs/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ def generate_mutable_img_chunk_features(
img, labels = utils.get_superchunks(
img_path, labels_path, origin, neurograph.shape, from_center=False
)

#img = utils.normalize_img(img)
img = utils.normalize_img(img)
for edge in neurograph.mutable_edges:
# Compute image coordinates
i, j = tuple(edge)
Expand All @@ -88,7 +87,7 @@ def generate_mutable_img_chunk_features(
# Mark path
d = int(geometry_utils.dist(xyz_i, xyz_j) + 5)
img_coords_i = np.round(xyz_i - midpoint + HALF_CHUNK_SIZE).astype(int)
img_coords_j = np.round(xyz_j - midpoint + HALF_CHUNK_SIZE).astype(int)
img_coords_j = np.round(xyz_j - midpoint + HALF_CHUNK_SIZE).astype(int)
path = geometry_utils.make_line(img_coords_i, img_coords_j, d)

img_chunk = utils.normalize_img(img_chunk)
Expand Down
Loading

0 comments on commit 19028a9

Please sign in to comment.