-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
anna-grim
committed
Dec 7, 2023
1 parent
78f495e
commit 8198f58
Showing
5 changed files
with
228 additions
and
136 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,12 +4,12 @@ | |
@author: Anna Grim | ||
@email: [email protected] | ||
Evaluates performance of edge classifier. | ||
Evaluates performance of edge classification model. | ||
""" | ||
import numpy as np | ||
|
||
STATS_LIST = [ | ||
METRICS_LIST = [ | ||
"precision", | ||
"recall", | ||
"f1", | ||
|
@@ -18,108 +18,193 @@ | |
] | ||
|
||
|
||
def run_evaluation( | ||
target_graphs, pred_graphs, y_pred, block_to_idxs, idx_to_edge, blocks | ||
): | ||
stats = dict([(s, []) for s in STATS_LIST]) | ||
stats_by_type = { | ||
"simple": dict([(s, []) for s in STATS_LIST]), | ||
"complex": dict([(s, []) for s in STATS_LIST]), | ||
def run_evaluation(neurographs, blocks, pred_edges): | ||
""" | ||
Runs an evaluation on the accuracy of the predictions generated by an edge | ||
classification model. | ||
Parameters | ||
---------- | ||
neurographs : list[NeuroGraph] | ||
Predicted neurographs. | ||
y_pred : numpy.ndarray | ||
Binary predictions of edges generated by classification model. | ||
blocks_to_idxs : dict | ||
Dictionary that stores which indices in "y_pred" correspond to edges | ||
in "neurographs[block_id]". | ||
idx_to_edge : dict | ||
Dictionary that stores the correspondence between an index from | ||
"y_pred" and edge from "neurographs[block_id]" for some "block_id". | ||
blocks : list[str] | ||
List of block_ids that indicate which predictions to evaluate. | ||
Returns | ||
------- | ||
stats : dict | ||
Dictionary that stores the accuracy of the edge classification model | ||
on all edges (i.e. "Overall"), simple edges, and complex edges. The | ||
metrics contained in this dictionary are identical to "METRICS_LIST". | ||
""" | ||
stats = { | ||
"Overall": dict([(metric, []) for metric in METRICS_LIST]), | ||
"Simple": dict([(metric, []) for metric in METRICS_LIST]), | ||
"Complex": dict([(metric, []) for metric in METRICS_LIST]), | ||
} | ||
print(blocks) | ||
for block_id in blocks: | ||
# Get predicted edges | ||
pred_edges = get_predictions( | ||
block_to_idxs[block_id], idx_to_edge, y_pred | ||
# Compute accuracy | ||
overall_stats_i = get_stats( | ||
neurographs[block_id], | ||
neurographs[block_id].mutable_edges, | ||
pred_edges[block_id] | ||
) | ||
|
||
# Overall performance | ||
num_fixes, num_mistakes = __reconstruction_stats( | ||
target_graphs[block_id], pred_graphs[block_id], pred_edges | ||
simple_stats_i = get_stats( | ||
neurographs[block_id], | ||
neurographs[block_id].get_simple_proposals(), | ||
pred_edges[block_id], | ||
) | ||
stats["# splits fixed"].append(num_fixes) | ||
stats["# merges created"].append(num_mistakes) | ||
|
||
# In-depth performance | ||
simple_stats, complex_stats = __reconstruction_type_stats( | ||
target_graphs[block_id], pred_graphs[block_id], pred_edges | ||
complex_stats_i = get_stats( | ||
neurographs[block_id], | ||
neurographs[block_id].get_complex_proposals(), | ||
pred_edges[block_id], | ||
) | ||
if True: | ||
print("simple stats:", simple_stats) | ||
print("complex stats:", complex_stats) | ||
print("") | ||
for key in STATS_LIST: | ||
stats_by_type["simple"][key].append(simple_stats[key]) | ||
stats_by_type["complex"][key].append(complex_stats[key]) | ||
return stats, stats_by_type | ||
|
||
# Store results | ||
for metric in METRICS_LIST: | ||
stats["Overall"][metric].append(overall_stats_i[metric]) | ||
stats["Simple"][metric].append(simple_stats_i[metric]) | ||
stats["Complex"][metric].append(complex_stats_i[metric]) | ||
|
||
return stats | ||
|
||
|
||
def get_predictions(idxs, idx_to_edge, y_pred):
    """
    Gets edges that are predicted to be target edges for some "block_id".

    Parameters
    ----------
    idxs : set
        Indices of entries in "y_pred" that belong to a given block.
    idx_to_edge : dict
        Dictionary that maps an index from "y_pred" to the corresponding
        edge.
    y_pred : numpy.ndarray
        Predictions generated by classification model. An entry counts as a
        positive prediction when it is strictly greater than zero.

    Returns
    -------
    set
        Edges whose prediction is positive and whose index is in "idxs".

    """
    # Only the first axis of "np.where" is used, so "y_pred" is treated as a
    # flat vector of per-edge predictions.
    positive_idxs = set(np.where(y_pred > 0)[0]).intersection(idxs)
    return {idx_to_edge[idx] for idx in positive_idxs}
|
||
|
||
def __reconstruction_stats(target_graph, pred_graph, pred_edges):
    """
    Counts how many predicted edges are target edges of "pred_graph" and how
    many are not.

    Parameters
    ----------
    target_graph : object
        Not used by this routine; kept so that existing callers keep working.
    pred_graph : object
        Graph whose "target_edges" attribute is used for membership tests.
    pred_edges : iterable
        Edges that were predicted to be target edges.

    Returns
    -------
    true_positives : int
        Number of predicted edges contained in "pred_graph.target_edges".
    false_positives : int
        Number of predicted edges not contained in "pred_graph.target_edges".

    """
    hits = [edge in pred_graph.target_edges for edge in pred_edges]
    true_positives = sum(hits)
    false_positives = len(hits) - true_positives
    return true_positives, false_positives
|
||
|
||
def __reconstruction_type_stats(target_graph, pred_graph, pred_edges):
    """
    Splits the predicted edges into "simple" and "complex" edges and computes
    the accuracy of the predictions for each type. An edge is simple when
    both of its nodes have immutable degree 1 in "pred_graph"; otherwise it
    is complex.

    Parameters
    ----------
    target_graph : object
        Not used by this routine; kept so that existing callers keep working.
    pred_graph : object
        Graph that provides "target_edges" and "immutable_degree".
    pred_edges : iterable
        Edges that were predicted to be target edges. Each edge must unpack
        into exactly two nodes.

    Returns
    -------
    simple_stats : dict
        Result of "compute_accuracy" for the simple edges.
    complex_stats : dict
        Result of "compute_accuracy" for the complex edges.

    """
    simple_stats = {s: 0 for s in STATS_LIST}
    complex_stats = {s: 0 for s in STATS_LIST}
    for edge in pred_edges:
        i, j = tuple(edge)
        deg_i = pred_graph.immutable_degree(i)
        deg_j = pred_graph.immutable_degree(j)
        stats = simple_stats if deg_i == 1 and deg_j == 1 else complex_stats
        if edge in pred_graph.target_edges:
            stats["# splits fixed"] += 1
        else:
            stats["# merges created"] += 1

    # Recall is measured against the number of target edges of each type.
    num_simple, num_complex = compute_edge_type(pred_graph)
    simple_stats = compute_accuracy(simple_stats, num_simple)
    complex_stats = compute_accuracy(complex_stats, num_complex)
    return simple_stats, complex_stats
|
||
|
||
def compute_edge_type(graph):
    """
    Counts the simple and complex target edges of a graph. A target edge is
    simple when both of its nodes have immutable degree 1; otherwise it is
    complex.

    Parameters
    ----------
    graph : object
        Graph that provides "target_edges" and "immutable_degree". Each
        target edge must unpack into exactly two nodes.

    Returns
    -------
    num_simple : int
        Number of simple target edges.
    num_complex : int
        Number of complex target edges.

    """
    num_simple = 0
    num_complex = 0
    for edge in graph.target_edges:
        i, j = tuple(edge)
        deg_i = graph.immutable_degree(i)
        deg_j = graph.immutable_degree(j)
        if deg_i == 1 and deg_j == 1:
            num_simple += 1
        else:
            num_complex += 1
    return num_simple, num_complex
|
||
|
||
def compute_accuracy(stats, num_edges):
    """
    Adds precision, recall, and f1-score to "stats", computed from the counts
    stored under "# splits fixed" and "# merges created".

    Parameters
    ----------
    stats : dict
        Dictionary containing the keys "# splits fixed" and
        "# merges created". It is updated in place.
    num_edges : int
        Number of edges that recall is measured against.

    Returns
    -------
    dict
        The same "stats" dictionary with "precision", "recall", and "f1"
        filled in.

    """
    num_fixed = stats["# splits fixed"]
    num_predicted = stats["# merges created"] + num_fixed

    # An empty denominator is scored as perfect rather than undefined.
    r = 1 if num_edges == 0 else num_fixed / num_edges
    p = 1 if num_predicted == 0 else num_fixed / num_predicted
    stats["f1"] = 0 if r + p == 0 else (2 * r * p) / (r + p)
    stats["precision"] = p
    stats["recall"] = r

    # Callers assign the result of this call, so the dictionary is returned.
    return stats
def get_stats(neurograph, proposals, pred_edges):
    """
    Accuracy of the predictions generated by an edge classification model on
    a given block and "edge_type" (e.g. overall, simple, or complex).

    Parameters
    ----------
    neurograph : NeuroGraph
        Predicted neurograph.
    proposals : set[frozenset]
        Set of edge proposals for a given "edge_type".
    pred_edges : container
        Edges that the model predicted to be target edges; passed through
        unchanged to "get_accuracy".

    Returns
    -------
    dict
        Dictionary containing results of evaluation with the keys
        "# splits fixed", "# merges created", "precision", "recall", and
        "f1".

    """
    tp, fp, p, r, f1 = get_accuracy(neurograph, proposals, pred_edges)
    return {
        "# splits fixed": tp,
        "# merges created": fp,
        "precision": p,
        "recall": r,
        "f1": f1,
    }
|
||
|
||
def get_accuracy(neurograph, proposals, pred_edges):
    """
    Computes the following metrics for a given set of predicted edges:
    (1) true positives, (2) false positives, (3) precision, (4) recall, and
    (5) f1-score.

    Parameters
    ----------
    neurograph : NeuroGraph
        Predicted neurograph.
    proposals : set[frozenset]
        Set of edge proposals for a given "edge_type".
    pred_edges : container
        Edges that the model predicted to be target edges; passed through
        unchanged to "get_accuracy_counts".

    Returns
    -------
    tp : int
        Number of true positives.
    fp : int
        Number of false positives.
    p : float
        Precision. Defined as 1 when there are no positive predictions.
    r : float
        Recall. Defined as 1 when there are no positive targets.
    f1 : float
        F1-score. Defined as 0 when precision and recall are both 0.

    """
    tp, fp, fn = get_accuracy_counts(neurograph, proposals, pred_edges)
    p = 1 if tp + fp == 0 else tp / (tp + fp)
    r = 1 if tp + fn == 0 else tp / (tp + fn)

    # Only the exact zero denominator is guarded. Clamping the denominator to
    # a small constant would understate f1 whenever 0 < r + p < constant.
    f1 = 0.0 if r + p == 0 else (2 * r * p) / (r + p)
    return tp, fp, p, r, f1
|
||
|
||
def get_accuracy_counts(neurograph, proposals, pred_edges):
    """
    Computes the following values: (1) true positives, (2) false positives,
    and (3) false negatives. Only edges in "proposals" are counted.

    Parameters
    ----------
    neurograph : NeuroGraph
        Predicted neurograph; its "target_edges" attribute defines which
        proposals are positives.
    proposals : set[frozenset]
        Set of edge proposals for a given "edge_type".
    pred_edges : container
        Edges that the model predicted to be target edges; only membership
        tests are performed on it.

    Returns
    -------
    tp : int
        Number of true positives.
    fp : int
        Number of false positives.
    fn : int
        Number of false negatives.

    """
    tp = 0
    fp = 0
    fn = 0
    for edge in proposals:
        is_predicted = edge in pred_edges
        if edge in neurograph.target_edges:
            if is_predicted:
                tp += 1
            else:
                fn += 1
        elif is_predicted:
            fp += 1
    return tp, fp, fn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.