-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_utils.py
147 lines (121 loc) · 5.26 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
'''
@Description: low-level func for reading data from files
@Author: bairongz ([email protected])
@Date: 2020-01-03 23:05:17
@LastEditTime : 2020-01-03 23:53:07
'''
from model import Transformer
import random
import torch
import constants as C
import os
import re
def clean(s):
    """Normalize a raw sentence: lowercase, split punctuation, unify symbols.

    These patterns were written for cleaning the dailydialog dataset.

    Arguments:
        s {str} -- raw sentence

    Returns:
        str -- cleaned, lowercased sentence
    """
    s = s.strip().lower()
    # separate a period glued between two words: "a.b" -> "a . b"
    s = re.sub(r'(\w+)\.(\w+)', r'\1 . \2', s)
    # break hyphenated compounds into two tokens: "x-y" -> "x y"
    s = re.sub(r'(\w+)-(\w+)', r'\1 \2', s)
    # unify the full-width period with the ASCII one
    s = s.replace('。', '.')
    # collapse ellipses into a comma
    s = s.replace('...', ',')
    # merge spelled-out meridiem abbreviations into single tokens.
    # BUG FIX: the original also replaced ' P . m . ', which was dead code --
    # s is lowercased above, so an uppercase 'P' can never appear here.
    s = s.replace(' p . m . ', ' pm ')
    s = s.replace(' a . m . ', ' am ')
    # this pattern are defined for cleaning the ubuntu dataset
    # ....
    return s
def _read_raw_separate(src_path, trg_path):
    """Read parallel data from two files, one sentence per line per file.

    Arguments:
        src_path {str} -- path of the file holding source sentences
        trg_path {str} -- path of the file holding target sentences

    Returns:
        list -- src-trg token-sequence pairs, e.g. [(src1, trg1), (src2, trg2), ...]
    """
    pairs = []
    with open(src_path) as fs, open(trg_path) as ft:
        for src_line, trg_line in zip(fs, ft):
            # NOTE 512 tokens
            src_toks = clean(src_line).replace('__eou__', '').strip().split()
            trg_toks = clean(trg_line).replace('__eou__', '').strip().split()
            # overly long sources keep only their last 300 tokens
            if len(src_toks) > 300:
                src_toks = src_toks[-300:]
            # overly long targets are cut down to their first 50 tokens
            # NOTE(review): the 50-vs-300 asymmetry looks deliberate but is
            # unverified -- confirm against the training setup
            if len(trg_toks) > 300:
                trg_toks = trg_toks[:50]
            pairs.append((src_toks, trg_toks))
    return pairs
def read_raw_weibo_data(src_path, trg_path):
    """Read weibo-style paired data, wrapping each pair as a one-turn dialog.

    Arguments:
        src_path {str} -- path of the source-side file
        trg_path {str} -- path of the target-side file

    Returns:
        list -- one single-element list per pair, e.g. [[(src1, trg1)], ...]
    """
    pairs = _read_raw_separate(src_path, trg_path)
    # wrap every pair in its own list for backward compatibility
    return [[pair] for pair in pairs]
def read_raw_iwslt14_data(path, test_only=False):
    """Load the IWSLT14 splits stored under *path*.

    Expects files named src-{train,dev,test}.txt / tgt-{train,dev,test}.txt.

    Arguments:
        path {str} -- directory containing the split files

    Keyword Arguments:
        test_only {bool} -- when True, load only the test split, wrapped
            in dialog form (default: False)

    Returns:
        the wrapped test split when test_only, otherwise a
        (train_data, valid_data, test_data) triple
    """
    def _split_paths(name):
        # each split is a (source, target) file pair
        return (os.path.join(path, "src-%s.txt" % name),
                os.path.join(path, "tgt-%s.txt" % name))

    if test_only:
        return read_raw_weibo_data(*_split_paths("test"))
    train_data = _read_raw_separate(*_split_paths("train"))
    valid_data = _read_raw_separate(*_split_paths("dev"))
    test_data = _read_raw_separate(*_split_paths("test"))
    return train_data, valid_data, test_data
def read_raw_wmt14_data(path):
    """Load the WMT14 en-de train/valid/test splits located in *path*.

    Expects files named {train,valid,test}.en / {train,valid,test}.de.

    Arguments:
        path {str} -- directory containing the split files

    Returns:
        tuple -- (train_data, valid_data, test_data)
    """
    splits = []
    for name in ("train", "valid", "test"):
        en_path = os.path.join(path, name + ".en")
        de_path = os.path.join(path, name + ".de")
        # English is the source side, German the target side
        splits.append(_read_raw_separate(en_path, de_path))
    return tuple(splits)
def read_raw_weibo1m_data(path, n_valid_samples=10000, n_test_samples=10000):
    """Read the 1M weibo corpus and split it into train/valid/test.

    Arguments:
        path {str} -- directory containing weibo_train_1m.src / weibo_train_1m.trg

    Keyword Arguments:
        n_valid_samples {int} -- size of the validation split (default: 10000)
        n_test_samples {int} -- size of the test split (default: 10000)

    Returns:
        tuple -- (train_dialogs, valid_dialogs, test_dialogs)
    """
    src_file_path = os.path.join(path, "weibo_train_1m.src")
    trg_file_path = os.path.join(path, "weibo_train_1m.trg")
    dialogs = _read_raw_separate(src_file_path, trg_file_path)
    # BUG FIX: the original called bare `shuffle(...)`, but only the `random`
    # module is imported, so it raised NameError -- qualify the call.
    random.shuffle(dialogs)
    valid_dialogs = dialogs[:n_valid_samples]
    test_dialogs = dialogs[n_valid_samples:n_valid_samples + n_test_samples]
    train_dialogs = dialogs[n_valid_samples + n_test_samples:]
    return train_dialogs, valid_dialogs, test_dialogs
def read_raw_rand_data(n_samples, n_vocab=5000):
    """Generate random copy-task data where the source equals the target.

    Arguments:
        n_samples {int} -- number of training pairs to generate

    Keyword Arguments:
        n_vocab {int} -- token ids are drawn from [0, n_vocab] inclusive
            (default: 5000)

    Returns:
        tuple -- (train_data, valid_data, test_data); each is a list of
            (seq, seq) pairs of identical token-id lists
    """
    def _rand_pairs(count):
        # one pair = a sequence of 10-30 random ids paired with its copy;
        # extracted to remove the original's three copy-pasted loops
        pairs = []
        for _ in range(count):
            seq = [random.randint(0, n_vocab)
                   for _ in range(random.randint(10, 30))]
            pairs.append((seq, seq.copy()))
        return pairs

    # valid/test are each 5% of the training size, capped at 5000
    n_eval = min(int(0.05 * n_samples), 5000)
    return _rand_pairs(n_samples), _rand_pairs(n_eval), _rand_pairs(n_eval)
def padding_for_trs(batch):
    """Pad a batch to uniform length and build the transformer masks.

    Arguments:
        batch {list} -- elements are (src, trg, src_pos, trg_pos) sequences

    Returns:
        tuple -- padded src/trg/src_pos/trg_pos tensors followed by the four
            masks produced by Transformer.get_masks
    """
    def _pad(seqs):
        return torch.nn.utils.rnn.pad_sequence(seqs, padding_value=C.PAD)

    src_seqs, trg_seqs, src_pos_seqs, trg_pos_seqs = zip(*batch)
    padded_src = _pad(src_seqs)
    padded_trg = _pad(trg_seqs)
    src_pos = _pad(src_pos_seqs)
    trg_pos = _pad(trg_pos_seqs)
    # the decoder input drops the final target step, hence padded_trg[:-1]
    trg_mask, src_key_padding_mask, trg_key_padding_mask, memory_key_padding_mask = \
        Transformer.get_masks(padded_src, padded_trg[:-1], PAD=C.PAD)
    return (padded_src, padded_trg, src_pos, trg_pos,
            trg_mask, src_key_padding_mask, trg_key_padding_mask,
            memory_key_padding_mask)