-
Notifications
You must be signed in to change notification settings - Fork 2
/
concat_data.py
53 lines (40 loc) · 1.37 KB
/
concat_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import argparse
import numpy as np
import os
import random
from data_model import Dialog
from xtrack_data2 import XTrackData2
def _collect_slot_classes(datasets):
    """Return the distinct (slot, cls_name) pairs found in `datasets`.

    Each dataset's `classes` attribute is iterated as a mapping from slot
    to an iterable of class names.  Pairs are returned in first-seen order
    and every pair appears only once, even if several datasets share it.
    """
    pairs = []
    seen = set()
    for dataset in datasets:
        for slot, slot_clss in dataset.classes.items():
            for cls_name in slot_clss:
                key = (slot, cls_name)
                if key not in seen:
                    seen.add(key)
                    pairs.append(key)
    return pairs


def concat(files, out_file=None, based_on=None):
    """Build a dummy dataset from the classes of several saved datasets.

    Every file in `files` is loaded with `XTrackData2.load`.  For each
    distinct (slot, class) pair found in the loaded datasets, except the
    null class, one single-message dummy dialog is created whose user
    utterance is "<class> <slot>" with score 1.0 and whose state is
    {slot: class}.  The dialogs are built into a new `XTrackData2` and
    saved to `out_file`.

    Args:
        files: non-empty sequence of paths accepted by `XTrackData2.load`.
        out_file: path handed to `XTrackData2.save`.  Required.
        based_on: forwarded unchanged as the `based_on` argument of
            `XTrackData2.build`.

    Raises:
        ValueError: if `files` is empty or `out_file` is not given.

    NOTE(review): the previous code read `.vocab` / `.classes` from the
    list of loaded datasets (an AttributeError) and used `based_on` and
    `out_file` without defining them (NameError).  Taking the union of the
    classes of all input files is an assumption about the intended
    behaviour -- please confirm.
    """
    if not files:
        raise ValueError('concat() needs at least one input file.')
    if out_file is None:
        raise ValueError('concat() needs an out_file to save the result to.')

    data = [XTrackData2.load(file_name) for file_name in files]

    dialogs = []
    for slot, cls_name in _collect_slot_classes(data):
        if cls_name == XTrackData2.null_class:
            continue

        d_id = 'dummy_%s_%s' % (slot, cls_name)
        d = Dialog(d_id, d_id)
        msg = '%s %s' % (cls_name, slot)
        d.add_message([(msg, 1.0)], {slot: cls_name}, Dialog.ACTOR_USER)
        dialogs.append(d)

    # Single-argument print(...) behaves the same on Python 2 and 3.
    print('> Data built.')

    # NOTE(review): dummy dialogs are created for every slot found in the
    # inputs, yet the build below is restricted to the 'food' slot --
    # confirm this hard-coded restriction is intended.
    xt = XTrackData2()
    xt.build(dialogs, slots=['food'], slot_groups={0: ['food']},
             based_on=based_on,
             oov_ins_p=0.0,
             include_system_utterances=False, n_nbest_samples=1,
             n_best_order=[0],
             score_mean=0.0, dump_text='/dev/null', replace_entities=False,
             split_dialogs=False)

    print('> Saving.')
    xt.save(out_file)


if __name__ == '__main__':
    from utils import init_logging
    init_logging('ConcatData')

    parser = argparse.ArgumentParser()
    # dest must be `files` to match concat()'s parameter name.  A plain
    # nargs='+' gives a flat list; combining nargs with action='append'
    # would wrap it in another list.
    parser.add_argument('files', nargs='+', metavar='file')
    parser.add_argument('--out_file', required=True)
    parser.add_argument('--based_on', default=None)

    args = parser.parse_args()
    concat(**vars(args))