-
Notifications
You must be signed in to change notification settings - Fork 1
/
process_doc.py
67 lines (49 loc) · 1.5 KB
/
process_doc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import argparse
import json
from datasets import load_dataset
# Split name passed to load_dataset() for the "document_domain" config in process_doc.
DOC_DOMAIN_SPLIT = "train"
# Translation table mapping both line-break characters to a single space.
_LINEBREAKS_TO_SPACE = str.maketrans("\n\r", "  ")


def text2line(text):
    """Flatten *text* onto one line.

    Each ``\\n`` and ``\\r`` character is replaced by one space, then leading
    and trailing whitespace is stripped.
    """
    flattened = text.translate(_LINEBREAKS_TO_SPACE)
    return flattened.strip()
def btag(tag, text):
    """Wrap *text* in a ``<tag>...</tag>`` pair.

    The content is first flattened to a single line with ``text2line``.
    """
    inner = text2line(text)
    return f"<{tag}>{inner}</{tag}>"
def _doc_text(ex):
    """Build the single-line tagged text for one ``document_domain`` example.

    The title keeps only the part before the first ``#`` and is wrapped in a
    ``<title>`` tag. Spans follow in order: a span whose ``tag`` is ``"u"`` is
    emitted as flattened plain text; any other span is wrapped in
    ``<tag>...</tag>``. All pieces are joined with single spaces.
    """
    parts = [btag("title", ex["title"].split("#")[0])]
    for d_span in ex["spans"]:
        tag = d_span["tag"]
        text_sp = d_span["text_sp"]
        # btag() already flattens its content via text2line(), so no extra
        # text2line() call is needed on the tagged branch.
        parts.append(text2line(text_sp) if tag == "u" else btag(tag, text_sp))
    return " ".join(parts)


def process_doc(args):
    """Convert the doc2dial ``document_domain`` split into ``docs.json``.

    Loads the dataset through the local ``doc2dial.py`` loading script and
    builds a mapping ``doc_id -> {"text": <tagged single-line text>}``. The
    mapping is written as indented JSON to ``<args.output_dir>/docs.json``.

    Args:
        args: Parsed CLI namespace; reads ``cache_dir`` (passed to
            ``load_dataset``) and ``output_dir`` (destination directory).
    """
    doc_dataset = load_dataset(
        "doc2dial.py",
        name="document_domain",
        split=DOC_DOMAIN_SPLIT,
        cache_dir=args.cache_dir,
    )
    d_doc = {}
    for ex in doc_dataset:
        # A repeated doc_id overwrites the earlier entry, as before.
        d_doc[ex["doc_id"]] = {"text": _doc_text(ex)}

    # open() raises FileNotFoundError if the directory is missing, so create
    # it. An empty output_dir means "current directory" and must not be passed
    # to makedirs (which would raise on "").
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    out_path = os.path.join(args.output_dir, "docs.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(d_doc, f, indent=4)
def main():
    """Parse command-line options and run :func:`process_doc`."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cache_dir",
        type=str,
        help="Path for caching the downloaded data by HuggingFace Datasets",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        # process_doc joins "docs.json" onto this value, so it names a
        # directory, not the output file itself.
        help="Directory in which docs.json is written",
    )
    args = parser.parse_args()
    process_doc(args)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()