-
Notifications
You must be signed in to change notification settings - Fork 2
/
toy_gen.py
80 lines (62 loc) · 2.2 KB
/
toy_gen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import argparse
import numpy as np
import os
import random
from data_model import Dialog
from xtrack_data2 import XTrackData2
def _random_utterance(rng, vocab, slot_vocab, input_len):
    """Generate the tokens of one synthetic user turn.

    Draws between 1 and ``input_len`` filler tokens from ``vocab``, then
    inserts one value drawn from ``slot_vocab``. With probability ~0.35 the
    value is immediately followed by the marker token ``'X'``, which flags it
    as the new goal for the turn.

    The RNG calls are made in a fixed order (randint, choices, choice,
    randint, random) so a given seed always yields the same data.

    Args:
        rng: Anything exposing ``choice``/``randint``/``random`` (the
            ``random`` module or a ``random.Random`` instance).
        vocab: Filler tokens.
        slot_vocab: Candidate slot values.
        input_len: Maximum number of filler tokens (must be >= 1).

    Returns:
        ``(tokens, new_goal)``: the token list and the marked value, or
        ``None`` when this turn does not set a goal.
    """
    tokens = [rng.choice(vocab) for _ in range(rng.randint(1, input_len))]
    val = rng.choice(slot_vocab)
    # randint is inclusive on both ends, so ndx is in 1..len(tokens): the
    # value is never the first token but may be appended at the very end.
    ndx = rng.randint(1, len(tokens))
    new_goal = None
    if rng.random() > 0.65:
        tokens.insert(ndx, 'X')
        new_goal = val
    # Inserting val at the same index pushes a just-placed 'X' one position to
    # the right, so a marked value reads "<val> X". Unmarked values are
    # distractors.
    tokens.insert(ndx, val)
    return tokens, new_goal


def build(out_file, based_on, input_len, dialog_len, dialog_cnt,
          vocab_size=500, n_slot_values=70, seed=None):
    """Generate a toy 'food' slot-tracking dataset and save it via XTrackData2.

    Each of ``dialog_cnt`` dialogs has ``dialog_len`` user turns built by
    ``_random_utterance``. The 'food' label of a turn is the most recent value
    marked with ``'X'`` in that dialog. It is ``None`` until the first marked
    turn, and it persists across unmarked turns.

    Args:
        out_file: Path passed to ``XTrackData2.save``.
        based_on: Forwarded unchanged to ``XTrackData2.build``. Its meaning is
            defined there (not visible here).
        input_len: Maximum number of filler tokens per turn (>= 1).
        dialog_len: Number of user turns per dialog.
        dialog_cnt: Number of dialogs to generate.
        vocab_size: Number of filler tokens ``"0"``..``str(vocab_size - 1)``.
        n_slot_values: Number of slot values ``"val0"``, ``"val1"``, ...
        seed: Optional seed for a private ``random.Random``. When ``None``
            (default), the global ``random`` module is used, so callers that
            seed it globally keep their reproducibility.

    Raises:
        ValueError: If ``input_len``, ``vocab_size`` or ``n_slot_values``
            is smaller than 1.
    """
    if input_len < 1:
        raise ValueError("input_len must be >= 1, got %r" % (input_len,))
    if vocab_size < 1 or n_slot_values < 1:
        raise ValueError("vocab_size and n_slot_values must be >= 1")

    rng = random if seed is None else random.Random(seed)
    vocab = [str(i) for i in range(vocab_size)]
    slot_vocab = ["val%d" % i for i in range(n_slot_values)]

    dialogs = []
    for d_id in range(dialog_cnt):
        d = Dialog(str(d_id), str(d_id))
        goal = None  # carried over between turns until a new value is marked
        for _ in range(dialog_len):
            tokens, new_goal = _random_utterance(rng, vocab, slot_vocab,
                                                 input_len)
            if new_goal is not None:
                goal = new_goal
            # A single hypothesis with probability 1.0, i.e. log-score 0.0.
            score = 1.0
            d.add_message([(" ".join(tokens), np.log(score))],
                          {'food': goal},
                          Dialog.ACTOR_USER)
        dialogs.append(d)
    print('> Data built.')

    xt = XTrackData2()
    # User-only turns with a single hypothesis. The text and CCA dumps are
    # discarded by writing them to /dev/null (POSIX-only path).
    xt.build(dialogs,
             slots=['food'],
             slot_groups={0: ['food']},
             based_on=based_on,
             oov_ins_p=0.0,
             include_system_utterances=False,
             n_nbest_samples=1,
             n_best_order=[0],
             score_mean=0.0,
             dump_text='/dev/null',
             dump_cca='/dev/null',
             score_bins=[0.2, 0.4, 0.6, 0.8, 1.0],
             word_drop_p=0.0)
    print('> Saving.')
    xt.save(out_file)
if __name__ == '__main__':
    # Command-line entry point: each flag maps one-to-one onto a parameter of
    # build(), so the parsed namespace can be splatted straight into it.
    cli = argparse.ArgumentParser()
    cli.add_argument('--out_file', required=True, help="Output file.")
    cli.add_argument('--based_on', default=None)
    # Integer size knobs, registered in the same order as before.
    int_flags = (
        ('--input_len', 15),
        ('--dialog_len', 5),
        ('--dialog_cnt', 10000),
    )
    for flag_name, flag_default in int_flags:
        cli.add_argument(flag_name, default=flag_default, type=int)
    options = cli.parse_args()
    build(**vars(options))