forked from suamin/ICD-BERT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
read_data.py
322 lines (267 loc) · 11.8 KB
/
read_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
# -*- coding: utf-8 -*-
import os
import logging
import spacy
import nltk
import pandas as pd
import pickle as pkl
from concurrent.futures import ProcessPoolExecutor
from sklearn.preprocessing import MultiLabelBinarizer
from collections import Counter
from parse_icd10 import ICD10Hierarchy
from ext.CharSplit import char_split
from google.cloud import translate
class DataReader:
    """Read the NTS-ICD task files: documents, ICD-10 annotations and split ids.

    Paths read by the methods below, relative to ``data_dir``::

        nts-icd/docs-training/<doc id>.txt   training/development documents
        nts-icd/test/docs/<doc id>.txt       test documents
        nts-icd/anns_train_dev.txt           "<doc id><TAB><code>|<code>|..." lines
        nts-icd/ids_training.txt             one doc id per line
        nts-icd/ids_development.txt          one doc id per line
    """

    def __init__(self, data_dir="data"):
        # Root directory containing the ``nts-icd`` folder.
        self.data_dir = data_dir

    def read_doc(self, fname):
        """Return the stripped, non-blank lines of ``fname`` as a list.

        The file is decoded as UTF-8; undecodable bytes are silently dropped.
        """
        with open(fname, mode="r", encoding="utf-8", errors="ignore") as rf:
            stripped = (line.strip() for line in rf)
            return [line for line in stripped if line]

    def read_docs(self, train_or_test):
        """Yield ``(doc_id, text)`` for every document of a split.

        ``"train"`` reads ``docs-training``; any other value reads ``test/docs``.
        A document's non-blank lines are joined with ``"<SECTION>"`` so each
        text field can be recovered later. Files are visited in sorted name
        order so the output order is reproducible.
        """
        if train_or_test == "train":
            docs_dir = os.path.join(self.data_dir, "nts-icd", "docs-training")
        else:
            # test set not released yet
            docs_dir = os.path.join(self.data_dir, "nts-icd", "test", "docs")
        incompl_count = 0
        for datafile in sorted(os.listdir(docs_dir)):
            if datafile == "id.txt":
                continue
            # The file name without its extension is the numeric document id.
            doc_id = int(os.path.splitext(datafile)[0])
            data = self.read_doc(os.path.join(docs_dir, datafile))
            # Sanity check: per the task README each file has 6 lines of text.
            if len(data) != 6:
                incompl_count += 1
            yield doc_id, "<SECTION>".join(data)
        # report if incompletes
        if incompl_count > 0:
            print("[INFO] %d docs do not have 6 text lines" % incompl_count)

    def read_anns(self):
        """Yield ``(doc_id, set_of_icd10_codes)`` from ``anns_train_dev.txt``.

        Each non-blank line is a doc id, a tab, and ``|``-separated codes.
        Duplicate and empty codes are dropped. A line without codes yields an
        empty set (``strip`` removes a trailing tab, so there may be no tab
        left to split on).
        """
        anns_file = os.path.join(self.data_dir, "nts-icd", "anns_train_dev.txt")
        with open(anns_file, encoding="utf-8") as rf:
            for line in rf:
                line = line.strip()
                if not line:
                    continue
                doc_id, _, icd10_codes = line.partition("\t")
                yield int(doc_id), {code for code in icd10_codes.split("|") if code}

    def read_ids(self, ids_file):
        """Return the set of integer doc ids listed in ``nts-icd/<ids_file>``.

        Blank lines and a literal ``id`` header line are skipped.
        """
        ids_file = os.path.join(self.data_dir, "nts-icd", ids_file)
        ids = set()
        with open(ids_file, "r", encoding="utf-8") as rf:
            for line in rf:
                line = line.strip()
                # "id" shows up as a stray header (line 242 in train ids).
                if line and line != "id":
                    ids.add(int(line))
        return ids

    def read_data(self, train_test):
        """Load a data split.

        For ``"train"`` returns ``(data_train, data_dev)``. Each is a list of
        ``(doc text, doc id, set of codes)`` tuples for the ids listed in
        ``ids_training.txt`` / ``ids_development.txt`` that also have an
        annotation line.

        For any other value returns the test documents as a list of
        ``(doc id, doc text)`` tuples. Note the reversed order and that no
        annotations are attached.
        """
        def read(docs, anns, ids):
            # Later annotation lines for the same id win, as in a plain dict build.
            id2anns = {doc_id: codes for doc_id, codes in anns if doc_id in ids}
            # list of tuple (doc text, doc id, set of annotations)
            return [
                (text, doc_id, id2anns[doc_id])
                for doc_id, text in docs
                if doc_id in id2anns
            ]

        if train_test != "train":
            # NOTE(review): TextProcessor.process_with_context treats element 0
            # as the text, so these (id, text) pairs need reordering before
            # being fed to it.
            return list(self.read_docs("test"))

        # Documents and annotations are shared by the train and dev splits.
        docs_train_dev = list(self.read_docs("train"))
        anns_train_dev = list(self.read_anns())
        print("[INFO] num of annotations in `anns_train_dev.txt`: %d" % len(anns_train_dev))
        data_train = read(docs_train_dev, anns_train_dev, self.read_ids("ids_training.txt"))
        data_dev = read(docs_train_dev, anns_train_dev, self.read_ids("ids_development.txt"))
        return data_train, data_dev
class TextProcessor:
    """Sentence- and word-tokenize German documents, optionally adding an
    English machine translation.

    Input documents are strings whose sections are joined by ``"<SECTION>"``.
    """

    def __init__(self, en_translate=False, split_compound=False):
        # spaCy pipelines are loaded only for their tokenizers; the tagger,
        # parser and NER components are disabled because they are not used.
        self.word_tokenizers = {
            "de": spacy.load('de_core_news_sm', disable=['tagger', 'parser', 'ner']).tokenizer,
            "en": spacy.load('en_core_web_sm', disable=['tagger', 'parser', 'ner']).tokenizer
        }
        # NLTK Punkt sentence tokenizers, one per language.
        self.sent_tokenizers = {
            "de": nltk.data.load('tokenizers/punkt/german.pickle').tokenize,
            "en": nltk.data.load('tokenizers/punkt/english.pickle').tokenize
        }
        # Separator tokens used to serialize sentences and sections.
        self.sent_sep_tok = "<SENT>"
        self.section_sep_tok = "<SECTION>"
        # Google Cloud Translation client; only created when translation is on.
        self.en_translate = en_translate
        if en_translate:
            self.translate_client = translate.Client()
        # NOTE(review): stored but not read by any method of this class.
        self.split_compound = split_compound

    def _tokenize_section(self, text, lang):
        """Return ``text`` as ``"<SENT>"``-joined sentences of space-joined tokens."""
        sentences = (
            " ".join(self.words_tokenize(sent, lang))
            for sent in self.sents_tokenize(text, lang)
        )
        return self.sent_sep_tok.join(sentences)

    def process_doc(self, doc):
        """Tokenize one ``"<SECTION>"``-separated document.

        Returns:
            ``(sections, doc_de)``, or ``(sections, doc_de, doc_en)`` when
            translation is enabled. ``sections`` is the list of original
            section strings. ``doc_de``/``doc_en`` re-join the tokenized
            sections with ``"<SECTION>"``; within a section, sentences are
            joined with ``"<SENT>"`` and tokens with single spaces.
            ``doc_en`` is built from the English translation of each section.
        """
        sections = doc.split(self.section_sep_tok)
        doc_de = self.section_sep_tok.join(
            self._tokenize_section(section, "de") for section in sections
        )
        if not self.en_translate:
            return sections, doc_de
        # Each section is translated separately, so section boundaries survive.
        doc_en = self.section_sep_tok.join(
            self._tokenize_section(self.translate(section), "en") for section in sections
        )
        return sections, doc_de, doc_en

    def sents_tokenize(self, text, lang):
        """Yield the stripped, non-empty sentences of ``text`` for ``lang``."""
        for sent in self.sent_tokenizers[lang](text):
            sent = sent.strip()
            if sent:
                yield sent

    def words_tokenize(self, text, lang):
        """Yield the stripped, non-empty token strings of ``text`` for ``lang``."""
        for token in self.word_tokenizers[lang](text):
            token = token.text.strip()
            if token:
                yield token

    def translate(self, text):
        """Return the ``translatedText`` of the client's translation of ``text`` into English."""
        return self.translate_client.translate(text, target_language="en")['translatedText']

    @staticmethod
    def de_compounds_split(word, t=0.8):
        """Try to split a German compound ``word`` with CharSplit.

        Takes the first entry of ``char_split.split_compound(word)``. If its
        element 0 (presumably a score) is at least ``t``, the remaining
        elements (presumably the compound parts) are returned. Otherwise
        ``word`` is returned unchanged.

        NOTE(review): the two branches return different types (a slice vs. the
        input string); callers must handle both.
        """
        res = char_split.split_compound(word)[0]
        if res[0] >= t:
            return res[1:]
        else:
            return word

    def process_with_context(self, text_and_context):
        """Process element 0 and pass the remaining elements through.

        Returns ``(process_doc(text), *context)`` as a tuple.
        """
        text = text_and_context[0]
        context = text_and_context[1:]
        return tuple([self.process_doc(text)] + list(context))

    @staticmethod
    def _collect_with_progress(results):
        """Drain ``results`` into a list, printing progress every 100 items."""
        collected = []
        for idx, result in enumerate(results):
            if idx % 100 == 0 and idx != 0:
                print("[INFO]: {} documents proceesed".format(idx))
            collected.append(result)
        return collected

    def mp_process(self, data, max_workers=8, chunksize=512):
        """Run ``process_with_context`` over many records, optionally in parallel.

        Args:
            data: iterable of tuples ``(doc text, *context)``. In this file the
                records are ``(doc text, doc id, set of labels)``. Only the
                first element is processed; the rest is passed through.
            max_workers: number of worker processes. ``<= 1`` processes
                serially in the current process.
            chunksize: number of records sent to a worker per task.

        Returns:
            List of ``(process_doc(text), *context)`` tuples in input order.
        """
        if max_workers <= 1:
            return self._collect_with_progress(map(self.process_with_context, data))
        # NOTE(review): executor.map pickles the bound method and therefore
        # ``self`` (tokenizers, and the translate client when enabled); confirm
        # these objects are picklable.
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            results = executor.map(self.process_with_context, data, chunksize=chunksize)
            return self._collect_with_progress(results)
def save(fname, data):
    """Pickle ``data`` into ``fname``.

    The file is opened in binary write mode, so any existing file is replaced.
    """
    with open(fname, mode="wb") as handle:
        pkl.dump(data, handle)
def prepare_processed(en_translate=False, out_dir="tmp",
                      credentials_path="/home/mlt/saad/tmp/translate-de-to-en-eacaff80b066.json"):
    """Tokenize (and optionally translate) the train/dev splits and pickle them.

    Args:
        en_translate: also add an English translation of every document via
            the Google Cloud Translation API. Off by default because
            translation takes time and the translated texts already exist in
            ``tmp/``.
        out_dir: directory the pickles are written to (created if missing).
        credentials_path: service-account file exported as
            ``GOOGLE_APPLICATION_CREDENTIALS``; only used when translating.

    Writes ``<out_dir>/train_data_de.pkl`` and ``<out_dir>/dev_data_de.pkl``
    (``..._de_en.pkl`` when translating). Each record is::

        ((doc_orig, doc_de, doc_en), doc_id, doc_labels)

    where
        doc_orig   : list of the original German sections (the document split
                     on "<SECTION>")
        doc_de     : "<SECTION>"-separated sections. Each section is split into
                     sentences separated by "<SENT>", and each sentence's tokens
                     are separated by whitespace (German).
        doc_en     : same as doc_de except the text is translated to English;
                     absent when translation is turned off
        doc_id     : unique document id obtained from the task's file names
        doc_labels : set of the document's original labels
    """
    if en_translate:
        os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = credentials_path
    os.makedirs(out_dir, exist_ok=True)
    dr = DataReader()
    train_data, dev_data = dr.read_data("train")
    # TextProcessor's translation switch is the ``en_translate`` keyword.
    tp = TextProcessor(en_translate=en_translate)
    train_data = tp.mp_process(train_data)
    dev_data = tp.mp_process(dev_data)
    # Compute the suffix on its own so the conditional applies only to the
    # suffix; otherwise both files would be named ".pkl" and collide.
    suffix = "_en.pkl" if en_translate else ".pkl"
    save(os.path.join(out_dir, "train_data_de" + suffix), train_data)
    save(os.path.join(out_dir, "dev_data_de" + suffix), dev_data)
def read_train_file(processed_train_file, label_threshold=25):
    """Load processed training records, drop rare labels, and binarize labels.

    Args:
        processed_train_file: pickle holding a list of records. Elements 0 and
            1 are kept as-is; the last element is the record's set of labels
            (as written by ``prepare_processed``).
        label_threshold: labels occurring in fewer than this many records are
            discarded; ``None`` keeps every label. Records left with no label
            are dropped in either case.

    Returns:
        ``(train_data, mlb, discard_labels)``:
            train_data: list of ``(record[0], record[1], binary label row)``
            mlb: the fitted ``MultiLabelBinarizer``
            discard_labels: set of labels that were removed
    """
    # NOTE: pickle.load can execute code embedded in the file; only load
    # pickles produced by this project.
    with open(processed_train_file, "rb") as rf:
        train_data = pkl.load(rf)
    count_labels = Counter(label for val in train_data for label in val[-1])
    print("[INFO] top most frequent labels:", count_labels.most_common(10))
    if label_threshold is None:
        discard_labels = set()
    else:
        discard_labels = {k for k, v in count_labels.items() if v < label_threshold}
    kept = []
    for val in train_data:
        val_labels = {label for label in val[-1] if label not in discard_labels}
        if val_labels:
            # Records saved by prepare_processed are tuples, which do not
            # support item assignment, so build a new record instead.
            kept.append((val[0], val[1], val_labels))
    if discard_labels:
        print(
            "[INFO] discarded %d labels with counts less than %d; remaining labels %d"
            % (len(discard_labels), label_threshold, len(count_labels) - len(discard_labels))
        )
    print("[INFO] no. of data points removed %d" % (len(train_data) - len(kept)))
    mlb = MultiLabelBinarizer()
    labels = mlb.fit_transform([val[-1] for val in kept])
    train_data = [(val[0], val[1], labels[idx, :]) for idx, val in enumerate(kept)]
    return train_data, mlb, discard_labels
def read_dev_file(proceesed_dev_file, mlb, discard_labels):
    """Load processed dev records and binarize them with a training-fitted ``mlb``.

    Labels in ``discard_labels`` or absent from ``mlb.classes_`` are removed.
    Records left with no label are dropped.

    Args:
        proceesed_dev_file: pickle holding a list of records. Elements 0 and 1
            are kept as-is; the last element is the record's set of labels.
        mlb: a fitted multi-label binarizer exposing ``classes_`` and
            ``transform`` (e.g. the one returned by ``read_train_file``).
        discard_labels: labels removed during training preprocessing.

    Returns:
        List of ``(record[0], record[1], binary label row)``.
    """
    # NOTE: pickle.load can execute code embedded in the file; only load
    # pickles produced by this project.
    with open(proceesed_dev_file, "rb") as rf:
        dev_data = pkl.load(rf)
    count_labels = Counter(label for val in dev_data for label in val[-1])
    print("[INFO] top most frequent labels:", count_labels.most_common(10))
    # Build the lookup once instead of once per label per record.
    known_labels = set(mlb.classes_)
    kept = []
    for val in dev_data:
        # discard any labels and keep only ones seen in training
        val_labels = {
            label for label in val[-1]
            if label not in discard_labels and label in known_labels
        }
        if val_labels:
            # Records may be tuples (no item assignment), so build a new one.
            kept.append((val[0], val[1], val_labels))
    print("[INFO] no. of data points removed %d" % (len(dev_data) - len(kept)))
    labels = mlb.transform([val[-1] for val in kept])
    return [(val[0], val[1], labels[idx, :]) for idx, val in enumerate(kept)]
def prepare_varying_train_data(train_file, dev_file,
                               thresholds=(None, 5, 10, 15, 20, 25, 50), out_dir="tmp"):
    """Build one train/dev/binarizer set per label-frequency threshold.

    For every threshold ``t`` the training labels are filtered with
    ``read_train_file(train_file, t)`` and the dev set is aligned to the
    resulting binarizer. Four pickles are written to ``out_dir``:
    ``train_data``, ``dev_data``, ``mlb`` and ``discarded``, each with the
    suffix ``_t{t}_c{number of classes}.pkl``. A ``None`` threshold is written
    as ``t0``.

    Args:
        train_file: processed training pickle.
        dev_file: processed dev pickle.
        thresholds: label-count thresholds to sweep (``None`` = keep all labels).
        out_dir: output directory, created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)
    for t in thresholds:
        train_data, mlb, discard_labels = read_train_file(train_file, t)
        dev_data = read_dev_file(dev_file, mlb, discard_labels)
        suffix = "_t{}_c{}.pkl".format(0 if t is None else t, len(mlb.classes_))
        outputs = (
            ("train_data", train_data),
            ("dev_data", dev_data),
            ("mlb", mlb),
            ("discarded", discard_labels),
        )
        for prefix, obj in outputs:
            save(os.path.join(out_dir, prefix + suffix), obj)