-
Notifications
You must be signed in to change notification settings - Fork 1
/
scorer.py
executable file
·156 lines (129 loc) · 6.25 KB
/
scorer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""The Official CONLL 2016 Shared Task Scorer
"""
import argparse
import json
from confusion_matrix import ConfusionMatrix, Alphabet
import validator
def evaluate(gold_list, predicted_list):
    """Score the predicted relations against the gold ones and print a report

    Input:
        gold_list : list of gold standard relations
        predicted_list : list of relations produced by the system

    Returns:
        A tuple (sense confusion matrix, precision, recall, F1) where the
        three scores are the micro-averaged values reported by the matrix.
    """
    confusion = evaluate_sense(gold_list, predicted_list)

    print('Sense classification--------------')
    confusion.print_summary()

    print('Overall parser performance --------------')
    micro_scores = confusion.compute_micro_average_f1()
    precision, recall, f1 = micro_scores
    print('Precision %1.4f Recall %1.4f F1 %1.4f' % (precision, recall, f1))

    return confusion, precision, recall, f1
def spans_exact_matching(gold_doc_id_spans, predicted_doc_id_spans):
    """Matching two lists of spans

    Input:
        gold_doc_id_spans : (DocID , a list of lists of tuples of token addresses)
        predicted_doc_id_spans : (DocID , a list of lists of token indices)

    Returns:
        True if both sides carry the same DocID, hold the same number of
        spans, and every aligned pair of spans matches exactly according to
        span_exact_matching. False otherwise.
    """
    gold_docID, gold_spans = gold_doc_id_spans[0], gold_doc_id_spans[1]
    predicted_docID, predicted_spans = \
        predicted_doc_id_spans[0], predicted_doc_id_spans[1]

    # Spans from different documents can never match, even when both span
    # lists are empty and the loop below would have nothing to compare.
    if gold_docID != predicted_docID:
        return False

    # zip() stops at the shorter list, so without this guard a missing or
    # extra span would silently be ignored and reported as an exact match.
    if len(gold_spans) != len(predicted_spans):
        return False

    return all(
        span_exact_matching((gold_docID, gold_span),
                            (predicted_docID, predicted_span))
        for gold_span, predicted_span in zip(gold_spans, predicted_spans)
    )
def span_exact_matching(gold_span, predicted_span):
    """Matching two spans

    Input:
        gold_span : (DocID, list of token addresses); element 2 of each
            address is the value compared against the prediction
        predicted_span : (DocID, list of token indices)

    Returns:
        True if the DocIDs are equal and the list of gold token indices
        equals the predicted list, False otherwise
    """
    # Spans that live in different documents cannot match.
    if gold_span[0] != predicted_span[0]:
        return False

    gold_token_indices = [address[2] for address in gold_span[1]]
    return gold_token_indices == predicted_span[1]
def evaluate_sense(gold_list, predicted_list):
    """Evaluate sense classifier

    The label ConfusionMatrix.NEGATIVE_CLASS is for the relations
    that are missed by the system
    because the arguments don't match any of the gold relations.

    Gold relations are scored only when their first sense is among the
    senses returned by validator.identify_valid_senses. A linked prediction
    counts as correct when its first sense appears anywhere in the gold
    relation's sense list. Predicted senses unknown to the alphabet are
    recorded as ConfusionMatrix.NEGATIVE_CLASS.

    Returns:
        The populated ConfusionMatrix
    """
    negative = ConfusionMatrix.NEGATIVE_CLASS

    # The alphabet holds the valid first senses of the gold relations,
    # followed by the negative class.
    sense_alphabet = Alphabet()
    valid_senses = validator.identify_valid_senses(gold_list)
    for relation in gold_list:
        first_sense = relation['Sense'][0]
        if first_sense in valid_senses:
            sense_alphabet.add(first_sense)
    sense_alphabet.add(negative)

    sense_cm = ConfusionMatrix(sense_alphabet)
    gold_to_predicted_map, predicted_to_gold_map = _link_gold_predicted(
        gold_list, predicted_list, spans_exact_matching)

    # Pass 1: every scorable gold relation contributes one entry.
    for gold_index, gold_relation in enumerate(gold_list):
        gold_sense = gold_relation['Sense'][0]
        if gold_sense not in valid_senses:
            continue

        if gold_index not in gold_to_predicted_map:
            # No prediction has matching arguments: the system missed it.
            sense_cm.add(negative, gold_sense)
            continue

        predicted_sense = gold_to_predicted_map[gold_index]['Sense'][0]
        if predicted_sense in gold_relation['Sense']:
            # Any of the gold senses is accepted as a correct answer.
            sense_cm.add(predicted_sense, predicted_sense)
        elif sense_cm.alphabet.has_label(predicted_sense):
            sense_cm.add(predicted_sense, gold_sense)
        else:
            sense_cm.add(negative, gold_sense)

    # Pass 2: predictions that were linked to no gold relation.
    for predicted_index, predicted_relation in enumerate(predicted_list):
        if predicted_index in predicted_to_gold_map:
            continue
        predicted_sense = predicted_relation['Sense'][0]
        if not sense_cm.alphabet.has_label(predicted_sense):
            predicted_sense = negative
        sense_cm.add(predicted_sense, negative)

    return sense_cm
def _link_gold_predicted(gold_list, predicted_list, matching_fn):
    """Link gold standard relations to the predicted relations

    Every gold relation is compared with every predicted relation by calling
    matching_fn on their (DocID, (Arg1 TokenList, Arg2 TokenList)) pairs.
    We do this because we want to evaluate sense classification later.
    When several candidates match, the one visited last is kept.

    Returns:
        A tuple of two dictionaries:
        1) mapping from gold relation index to the matching predicted relation
        2) mapping from predicted relation index to the matching gold relation
    """
    def doc_and_args(relation):
        # The key handed to matching_fn: the document plus both argument spans.
        return (relation['DocID'],
                (relation['Arg1']['TokenList'], relation['Arg2']['TokenList']))

    gold_keys = [doc_and_args(relation) for relation in gold_list]
    predicted_keys = [doc_and_args(relation) for relation in predicted_list]

    gold_to_predicted_map = {}
    predicted_to_gold_map = {}
    for gold_index, gold_key in enumerate(gold_keys):
        for predicted_index, predicted_key in enumerate(predicted_keys):
            if not matching_fn(gold_key, predicted_key):
                continue
            gold_to_predicted_map[gold_index] = predicted_list[predicted_index]
            predicted_to_gold_map[predicted_index] = gold_list[gold_index]
    return gold_to_predicted_map, predicted_to_gold_map
def _load_relations(path):
    """Read a file holding one JSON-encoded relation per line

    The file is closed as soon as it has been read. Lines that contain only
    whitespace are skipped so that a trailing blank line does not make
    json.loads fail.
    """
    with open(path) as relation_file:
        return [json.loads(line) for line in relation_file if line.strip()]


def main():
    """Command line entry point

    Reads the gold standard file and the system output file given on the
    command line and prints three evaluations: all relations, relations
    whose 'Type' is 'Explicit', and all remaining relations.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate system's output against the gold standard")
    parser.add_argument('gold', help='Gold standard file')
    parser.add_argument('predicted', help='System output file')
    args = parser.parse_args()

    gold_list = _load_relations(args.gold)
    predicted_list = _load_relations(args.predicted)

    print('\n================================================')
    print('Evaluation for all discourse relations')
    evaluate(gold_list, predicted_list)

    print('\n================================================')
    print('Evaluation for explicit discourse relations only')
    explicit_gold_list = [x for x in gold_list if x['Type'] == 'Explicit']
    explicit_predicted_list = [x for x in predicted_list if x['Type'] == 'Explicit']
    evaluate(explicit_gold_list, explicit_predicted_list)

    print('\n================================================')
    print('Evaluation for non-explicit discourse relations only (Implicit, EntRel, AltLex)')
    non_explicit_gold_list = [x for x in gold_list if x['Type'] != 'Explicit']
    non_explicit_predicted_list = [x for x in predicted_list if x['Type'] != 'Explicit']
    evaluate(non_explicit_gold_list, non_explicit_predicted_list)
# Run the scorer only when this file is executed as a script.
if __name__ == '__main__':
    main()