-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
major upd : improve ground truth generation
- Loading branch information
anna-grim
committed
Dec 2, 2023
1 parent
acd0825
commit 19028a9
Showing
9 changed files
with
345 additions
and
269 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,35 +4,68 @@ | |
@author: Anna Grim | ||
@email: [email protected] | ||
Class of graphs that are built from swc files. | ||
Class of graphs built from swc files. Each swc file is stored as a distinct | ||
graph and each node in this graph. | ||
""" | ||
|
||
import os | ||
|
||
import networkx as nx | ||
import numpy as np | ||
from more_itertools import zip_broadcast | ||
from scipy.spatial import KDTree | ||
|
||
from deep_neurographs import swc_utils, utils | ||
from deep_neurographs.geometry_utils import dist | ||
from deep_neurographs.geometry_utils import dist as get_dist | ||
|
||
|
||
class DenseGraph: | ||
""" | ||
Class of graphs built from swc files. Each swc file is stored as a | ||
distinct graph and each node in this graph. | ||
""" | ||
|
||
def __init__(self, swc_dir): | ||
""" | ||
Constructs a DenseGraph object from a directory of swc files. | ||
Parameters | ||
---------- | ||
swc_dir : path | ||
Path to directory of swc files which are used to construct a hash | ||
table in which the entries are filename-graph pairs. | ||
Returns | ||
------- | ||
None | ||
""" | ||
self.xyz_to_node = dict() | ||
self.xyz_to_swc = dict() | ||
self.init_graphs(swc_dir) | ||
self.init_kdtree() | ||
|
||
def init_graphs(self, swc_dir): | ||
""" | ||
Initializes graphs by reading swc files in "swc_dir". Graphs are | ||
stored in a hash table where the entries are filename-graph pairs. | ||
Parameters | ||
---------- | ||
swc_dir : path | ||
Path to directory of swc files which are used to construct a hash | ||
table in which the entries are filename-graph pairs. | ||
Returns | ||
------- | ||
None | ||
""" | ||
self.graphs = dict() | ||
for f in utils.listdir(swc_dir, ext=".swc"): | ||
# Extract info | ||
path = os.path.join(swc_dir, f) | ||
|
||
# Construct Graph | ||
path = os.path.join(swc_dir, f) | ||
swc_dict = swc_utils.parse(swc_utils.read_swc(path)) | ||
graph, xyz_to_node = swc_utils.file_to_graph( | ||
swc_dict, set_attrs=True, return_dict=True | ||
|
@@ -45,49 +78,147 @@ def init_graphs(self, swc_dir): | |
self.xyz_to_swc.update(xyz_to_id) | ||
|
||
def init_kdtree(self): | ||
""" | ||
Initializes KDTree from all xyz coordinates contained in all | ||
graphs in "self.graphs". | ||
Parameters | ||
---------- | ||
None | ||
Returns | ||
------- | ||
None | ||
""" | ||
self.kdtree = KDTree(list(self.xyz_to_swc.keys())) | ||
|
||
def get_projection(self, xyz): | ||
_, idx = self.kdtree.query(xyz, k=1) | ||
proj_xyz = tuple(self.kdtree.data[idx]) | ||
proj_dist = dist(proj_xyz, xyz) | ||
return proj_xyz, proj_dist | ||
def query_kdtree(self, xyz): | ||
""" | ||
Queries "self.kdtree" for the nearest neighbor of "xyz". | ||
Parameters | ||
---------- | ||
xyz : tuple[float] | ||
Coordinate to be queried. | ||
def connect_nodes(self, graph_id, xyz_i, xyz_j, return_dist=True): | ||
i = self.xyz_to_node[graph_id][xyz_i] | ||
j = self.xyz_to_node[graph_id][xyz_j] | ||
Returns | ||
------- | ||
tuple[float] | ||
Result of query. | ||
""" | ||
_, idx = self.kdtree.query(xyz, k=1) | ||
return tuple(self.kdtree.data[idx]) | ||
|
||
def is_connected(self, xyz_1, xyz_2): | ||
""" | ||
Determines whether the points "xyz_1" and "xyz_2" belong to the same | ||
swc file (i.e. graph). | ||
Parameters | ||
---------- | ||
xyz_1 : tuple[float] | ||
Coordinate contained in some graph in "self.graph". | ||
xyz_2 : tuple[float] | ||
Coordinate contained in some graph in "self.graph". | ||
Returns | ||
------- | ||
bool | ||
Indication of whether "xyz_1" and "xyz_2" belong to the same swc | ||
file (i.e. graph). | ||
""" | ||
swc_identical = self.xyz_to_swc[xyz_1] == self.xyz_to_swc[xyz_2] | ||
return True if swc_identical else False | ||
|
||
def connect_nodes(self, xyz_1, xyz_2): | ||
""" | ||
Finds path connecting two points that belong to some graph in | ||
"self.graph". | ||
Parameters | ||
---------- | ||
xyz_1 : tuple[float] | ||
Source of path. | ||
xyz_2 : tuple[float] | ||
Target of path. | ||
Returns | ||
------- | ||
list[int] | ||
Path of nodes connecting source and target. | ||
float | ||
Length of path with respect to l2-metric. | ||
""" | ||
graph_id = self.xyz_to_swc[xyz_1] | ||
i = self.xyz_to_node[graph_id][xyz_1] | ||
j = self.xyz_to_node[graph_id][xyz_2] | ||
path = nx.shortest_path(self.graphs[graph_id], source=i, target=j) | ||
if return_dist: | ||
dist = self.compute_dist(graph_id, path) | ||
return path, dist | ||
else: | ||
return path | ||
|
||
def compute_dist(self, graph_id, path): | ||
d = 0 | ||
return path, self.path_length(graph_id, path) | ||
|
||
def path_length(self, graph_id, path): | ||
""" | ||
Computes length of path with respect to the l2-metric. | ||
Parameters | ||
---------- | ||
graph_id : str | ||
ID of graph that path belongs to. | ||
path : list[int] | ||
List of nodes that form a path. | ||
Returns | ||
------- | ||
float | ||
Length of path with respect to l2-metrics. | ||
""" | ||
path_length = 0 | ||
for i in range(1, len(path)): | ||
xyz_1 = self.graphs[graph_id].nodes[i]["xyz"] | ||
xyz_2 = self.graphs[graph_id].nodes[i - 1]["xyz"] | ||
d += dist(xyz_1, xyz_2) | ||
return d | ||
|
||
def check_aligned(self, pred_xyz_i, pred_xyz_j): | ||
# Get target graph | ||
xyz_i, _ = self.get_projection(pred_xyz_i) | ||
xyz_j, _ = self.get_projection(pred_xyz_j) | ||
graph_id = self.xyz_to_swc[xyz_i] | ||
if self.xyz_to_swc[xyz_i] != self.xyz_to_swc[xyz_j]: | ||
return False | ||
|
||
# Compute distances | ||
pred_xyz_i = np.array(pred_xyz_i) | ||
pred_xyz_j = np.array(pred_xyz_j) | ||
pred_dist = dist(pred_xyz_i, pred_xyz_j) | ||
|
||
# Check criteria | ||
target_path, target_dist = self.connect_nodes(graph_id, xyz_i, xyz_j) | ||
ratio = min(pred_dist, target_dist) / max(pred_dist, target_dist) | ||
if ratio < 0.5 and pred_dist > 10: | ||
return False | ||
else: | ||
return True | ||
path_length += get_dist(xyz_1, xyz_2) | ||
return path_length | ||
|
||
def is_aligned(self, xyz_1, xyz_2, ratio_threshold=0.5, exclude=10.0): | ||
""" | ||
Determines whether the edge proposal corresponding to "xyz_1" and | ||
"xyz_2" is aligned to the ground truth. This is determined by checking | ||
two conditions: (1) connectedness and (2) distance ratio. For (1), we | ||
project "xyz_1" and "xyz_2" onto "self.graph", then verify that they | ||
project to the same graph. For (2), we compute the ratio between the | ||
Euclidean distance "dist" from "xyz_1" to "xyz" and the path length | ||
between the corresponding projections. This ratio can be skewed if | ||
"dist" is small, so we skip this criteria if "dist" < "exclude". | ||
Parameters | ||
---------- | ||
xyz_1 : numpy.array | ||
Endpoint of edge proposal. | ||
xyz_2 : numpy.array | ||
Endpoint of edge proposal. | ||
ratio_threshold : float | ||
Lower bound on threshold used to compare similarity between "dist" | ||
and "path length". | ||
exclude : float | ||
Upper bound on threshold to ignore criteria 1. | ||
Returns | ||
------- | ||
bool | ||
Indication of whether edge proposal is aligned to ground truth. | ||
""" | ||
hat_xyz_1 = self.query_kdtree(xyz_1) | ||
hat_xyz_2 = self.query_kdtree(xyz_2) | ||
if self.is_connected(hat_xyz_1, hat_xyz_2): | ||
dist = get_dist(xyz_1, xyz_2) | ||
_, path_length = self.connect_nodes(hat_xyz_1, hat_xyz_2) | ||
ratio = min(dist, path_length) / max(dist, path_length) | ||
if ratio > ratio_threshold and dist > exclude: | ||
return True | ||
elif dist <= exclude: | ||
return True | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.