#!/usr/bin/env python3
"""Get Rosette API named entity results in CoNLL 2003-style BIO format"""
import csv
import os
import sys
from getpass import getpass
from rosette.api import API, DocumentParameters
# CoNLL 2003-style B(eginning) I(nside) O(utside) tags
B, I, O = 'B-{}', 'I-{}', 'O'
# CoNLL 2003 fields
CONLL2003 = [
'word-token',
'part-of-speech-tag',
'chunk-tag',
'named-entity-tag'
]
# Default Rosette API URL
DEFAULT_ROSETTE_API_URL = 'https://api.rosette.com/rest/v1/'
def extent(obj):
"""Get the start and end offset attributes of a dict-like object
a = {'startOffset': 0, 'endOffset': 5}
b = {'startOffset': 0, 'endOffset': 10}
c = {'startOffset': 5, 'endOffset': 10}
extent(a) -> (0, 5)
extent(b) -> (0, 10)
extent(c) -> (5, 10)
extent({}) -> (-1, -1)
"""
return obj.get('startOffset', -1), obj.get('endOffset', -1)
def overlaps(*objs):
"""Find character offsets that overlap between objects
a = {'startOffset': 0, 'endOffset': 5}
b = {'startOffset': 0, 'endOffset': 10}
c = {'startOffset': 5, 'endOffset': 10}
overlaps(a, b) -> {0, 1, 2, 3, 4}
bool(overlaps(a, b)) -> True
overlaps(b, c) -> {5, 6, 7, 8, 9}
bool(overlaps(b, c)) -> True
overlaps(a, c) -> set()
bool(overlaps(a, c)) -> False
"""
return set.intersection(*(set(range(*extent(obj))) for obj in objs))
# slice text according to UTF-16 character offsets
def get_text(string, start, end, bom=True):
"""This method correctly accesses slices of strings using character
start/end offsets referring to UTF-16 encoded bytes. This allows
for using character offsets generated by Rosette (and other softwares)
that use UTF-16 native string representations under Pythons with UCS-4
support, such as Python 3.3+ (refer to https://www.python.org/dev/peps/pep-0393/).
The offsets are adjusted to account for a UTF-16 byte order mark (BOM)
(2 bytes) and also that each UTF-16 logical character consumes 2 bytes.
'character' in this context refers to logical characters for the purpose of
character offsets; an individual character can consume up to 4 bytes (32
bits for so-called 'wide' characters) and graphemes can consume even more.
"""
import codecs
if not isinstance(string, str):
raise ValueError('expected string to be of type str')
    if not (start is None or isinstance(start, int)):
        raise ValueError('expected start to be of type int or NoneType')
    if not (end is None or isinstance(end, int)):
        raise ValueError('expected end to be of type int or NoneType')
if start is not None:
start *= 2
if bom:
start += 2
if end is not None:
end *= 2
if bom:
end += 2
utf_16, _ = codecs.utf_16_encode(string)
sliced, _ = codecs.utf_16_decode(utf_16[start:end])
return sliced
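# A minimal illustrative sketch (not part of the original script) of how
# get_text() behaves when the text contains a character outside the Basic
# Multilingual Plane, where UTF-16 offsets diverge from Python's
# per-character indexing (expected values assume the default bom=True):
#
#   s = 'a\N{GRINNING FACE}b'
#   get_text(s, 0, 1)  # -> 'a'
#   get_text(s, 1, 3)  # -> the emoji (it occupies two UTF-16 code units)
#   get_text(s, 3, 4)  # -> 'b'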
def load_content(txtfile):
"""Load data from a plain-text file"""
with open(txtfile, mode='r') as f:
return f.read()
def get_entities(content, key, url, language=None):
"""Get Rosette API named entity results for the given content
This method gets results of the "entities" endpoint of the Rosette API.
The result is an A(nnotated) D(ata) M(odel) or ADM that is a Python dict
representing a document, annotations of the document content, and metadata.
content: the textual content of a document for the Rosette API to process
key: your Rosette user key
url: the URL of the Rosette API
language: an optional ISO 639-2 T language code (the Rosette API will
automatically detect the language of the content by default)
"""
api = API(user_key=key, service_url=url)
# Request result as Annotated Data Model (ADM)
api.setUrlParameter("output", "rosette")
parameters = DocumentParameters()
parameters['content'] = content
parameters['language'] = language
adm = api.entities(parameters)
return adm
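# Illustrative usage sketch (the key and text below are placeholders, not
# part of this script); the returned ADM echoes the submitted text under
# 'data' and carries annotations under 'attributes':
#
#   adm = get_entities(
#       'New York City or NYC is the most populous city in the United States.',
#       key='YOUR_ROSETTE_API_KEY',
#       url=DEFAULT_ROSETTE_API_URL
#   )
#   adm['data']                    # -> the submitted text
#   adm['attributes']['entities']  # -> named-entity annotations (see below)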
def entity_mentions(adm):
"""Generate named entity mentions from an ADM (Annotated Data Model)
The ADM contains an "entities" attribute that groups mentions of the
same entity together in a "mentions" attribute per entity. Each entity has
a head mention index, an entity type, an entity identifier, a confidence
score, and a list of mentions. Each entity mention contains additional
information including its start and end character offsets referring to the
array of characters in the document content (i.e., the adm["data"]).
Consider an ADM with the following content:
adm["data"] == "New York City or NYC is the most populous city in the United States."
Then the "entities" attribute would be:
adm["attributes"]["entities"] == {
"items": [
{
"headMentionIndex": 0,
"mentions": [
{
"source": "gazetteer",
"subsource": "/data/roots/rex/data/gazetteer/eng/accept/gaz-LE.bin",
"normalized": "New York City",
"startOffset": 0,
"endOffset": 13
},
{
"source": "gazetteer",
"subsource": "/data/roots/rex/data/gazetteer/eng/accept/gaz-LE.bin",
"normalized": "NYC",
"startOffset": 17,
"endOffset": 20
}
],
"confidence": 0.501718114501715,
"type": "LOCATION",
"entityId": "Q60"
},
{
"headMentionIndex": 0,
"mentions": [
{
"source": "gazetteer",
"subsource": "/data/roots/rex/data/gazetteer/eng/accept/gaz-LE.bin",
"normalized": "United States",
"startOffset": 55,
"endOffset": 68
}
],
"confidence": 0.08375498050536179,
"type": "LOCATION",
"entityId": "Q30"
}
],
"type": "list",
"itemType": "entities"
}
    This method generates all named entity mentions, each augmented with the
    named entity type of the entity it refers to.
entity_mentions(adm) -> <generator object entity_mentions at 0xXXXXXXXXX>
Since the mentions are grouped by the entity they refer to, it is useful to
get a list of the mentions in the order they appear in the document. We can
do this by sorting the mentions by their extent, i.e., their start and end
character offsets:
sorted(entity_mentions(adm), key=extent) -> [
{
"source": "gazetteer",
"normalized": "New York City",
"startOffset": 0,
"endOffset": 13,
"type": "LOCATION",
"subsource": "/data/roots/rex/data/gazetteer/eng/accept/gaz-LE.bin"
},
{
"source": "gazetteer",
"normalized": "NYC",
"startOffset": 17,
"endOffset": 20,
"type": "LOCATION",
"subsource": "/data/roots/rex/data/gazetteer/eng/accept/gaz-LE.bin"
},
{
"source": "gazetteer",
"normalized": "United States",
"startOffset": 55,
"endOffset": 68,
"type": "LOCATION",
"subsource": "/data/roots/rex/data/gazetteer/eng/accept/gaz-LE.bin"
}
]
"""
for entity in adm['attributes']['entities']['items']:
for mention in entity['mentions']:
# Augment mentions with the entity type of the entity they refer to
mention['type'] = entity['type']
yield mention
def conll2003(adm, use_conll_ne_tags=True):
"""Generate CoNLL 2003-style named entity rows from a Rosette API result
Taking an example ADM:
adm["data"] == "New York City or NYC is the most populous city in the United States."
Then the output would be:
conll2003(adm) -> <generator object conll2003 at 0xXXXXXXXXX>
list(conll2003(adm)) ->
[
{
"chunk-tag": "",
"part-of-speech-tag": "",
"named-entity-tag": "B-LOC",
"word-token": "New"
},
{
"chunk-tag": "",
"part-of-speech-tag": "",
"named-entity-tag": "I-LOC",
"word-token": "York"
},
{
"chunk-tag": "",
"part-of-speech-tag": "",
"named-entity-tag": "I-LOC",
"word-token": "City"
},
...
{
"chunk-tag": "",
"part-of-speech-tag": "",
"named-entity-tag": "B-LOC",
"word-token": "United"
},
{
"chunk-tag": "",
"part-of-speech-tag": "",
"named-entity-tag": "I-LOC",
"word-token": "States"
},
{
"chunk-tag": "",
"part-of-speech-tag": "",
"named-entity-tag": "O",
"word-token": "."
}
]
"""
# Map Rosette named entity types to CoNLL 2003 named entity types
CONLL2003_NE_TYPES = {
'PERSON': 'PER',
'LOCATION': 'LOC',
'ORGANIZATION': 'ORG'
}
# Access the entity mentions, sentences, and tokens from the ADM
mentions = sorted(entity_mentions(adm), key=extent)
sentences = adm['attributes']['sentence']['items']
tokens = adm['attributes']['token']['items']
# Assign a CoNLL2003-style named entity tag to each token
for token in tokens:
sentence = sentences[0] if sentences else {}
mention = mentions[0] if mentions else {}
# Add empty rows between sentences
if min(extent(token)) == min(extent(sentence)):
yield {k : '' for k in CONLL2003}
sentences.pop(0)
bio = O
if min(extent(token)) == min(extent(mention)):
bio = B
elif overlaps(token, mention):
bio = I
if max(extent(token)) == max(extent(mention)):
mentions.pop(0)
if use_conll_ne_tags:
entity_type = CONLL2003_NE_TYPES.get(mention.get('type'), 'MISC')
else:
entity_type = mention.get('type')
start, end = extent(token)
yield {
'word-token': get_text(adm['data'], start, end),
'part-of-speech-tag': '', # we can't get this with a single API call
'chunk-tag': '', # Rosette doesn't currently do syntactic chunking
'named-entity-tag': bio.format(entity_type)
}
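# A brief worked sketch of the BIO logic above for the docstring's example
# mention "New York City" (startOffset 0, endOffset 13), assuming the API
# tokenizes it as "New" (0-3), "York" (4-8), and "City" (9-13):
#
#   "New"  starts where the mention starts            -> B-LOC
#   "York" overlaps the mention without starting it   -> I-LOC
#   "City" overlaps and ends where the mention ends   -> I-LOC (mention popped)
#   "or"   overlaps no remaining mention              -> O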
def main(adm, use_conll_ne_tags, delimiter):
"""Given an ADM, write CoNLL 2003-style named entity rows to stdout"""
writer = csv.DictWriter(
sys.stdout,
fieldnames=CONLL2003,
delimiter=delimiter
)
# document header row
writer.writerow({
'word-token': '-DOCSTART-',
'part-of-speech-tag': '-X-',
'chunk-tag': O,
'named-entity-tag': O
})
writer.writerows(conll2003(adm, use_conll_ne_tags))
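# Illustrative sketch of the rows written for the docstring example in
# conll2003(), using the default space delimiter; the empty part-of-speech
# and chunk columns appear as consecutive delimiters, and an all-empty row
# separates sentences:
#
#   -DOCSTART- -X- O O
#
#   New   B-LOC
#   York   I-LOC
#   City   I-LOC
#   or   O
#   ...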
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description=__doc__
)
parser.add_argument(
'input',
help='A plain-text document to process'
)
parser.add_argument(
'-k',
'--key',
help='Rosette API Key',
default=None
    )
parser.add_argument(
'-u',
'--url',
help='Alternative API URL',
default=DEFAULT_ROSETTE_API_URL
)
parser.add_argument(
'-l',
'--language',
help=(
'A three-letter (ISO 639-2 T) code that will override automatic '
'language detection'),
default=None
)
parser.add_argument(
'-d', '--delimiter',
default=' ',
help='A delimiter to separate the token and BIO tag output columns'
)
parser.add_argument(
'--use-conll-ne-tags',
action='store_true',
help=(
"Use CoNLL 2003 named entity tags (instead of Rosette API's named "
'entity tags)'
)
)
args = parser.parse_args()
key = (
        os.environ.get('ROSETTE_USER_KEY')
or args.key
or getpass(prompt='Enter your Rosette API key: ')
)
content = load_content(args.input)
# Get ADM (Annotated Data Model) results from Rosette API
adm = get_entities(content, key, args.url, args.language)
main(adm, args.use_conll_ne_tags, args.delimiter)
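# Example invocations (sketches; the key value and file names are placeholders):
#
#   export ROSETTE_USER_KEY=...
#   ./rosette_to_conll2003.py document.txt > document.conll
#
# or, passing the key and a tab delimiter explicitly (bash syntax for the tab):
#
#   ./rosette_to_conll2003.py --key ... --delimiter $'\t' document.txt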