import pandas as pd
import torch
import pickle
import gc
import argparse

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.sequence import pad_sequences
from torch.utils.data import TensorDataset, DataLoader
from transformers import XLMRobertaTokenizer, RobertaTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    print("GPU:", torch.cuda.get_device_name(0))


def print_data_statistics(df):
    """
    Function to print summary statistics on the keywords and titles in the dataset
    Args:
        df:pd.DataFrame - Dataframe of the dataset for analysis
    """
    print('Unique titles: ', df.title.nunique() == df.shape[0])
    print('Unique keywords: ', df.keyword.nunique() == df.shape[0])
    print('Null values: ', df.isnull().values.any())
    print('average title sentence length: ', df.title.str.split().str.len().mean())
    print('stdev title sentence length: ', df.title.str.split().str.len().std())
    print('average kw sentence length: ', df.keyword.str.split().str.len().mean())
    print('stdev kw sentence length: ', df.keyword.str.split().str.len().std())


def print_label_statistics(df, label_cols):
    """
    Function to print certain statistics on the labels in the dataset
    Args:
        df:pd.DataFrame - Dataframe of the dataset for analysis
        label_cols:list - list of labels in the dataset such as [e,s,c,i] or [e,s,ci]
    """
    print('Count of 1 per label: \n', df[label_cols].sum(), '\n')
    print('Count of 0 per label: \n', df[label_cols].eq(0).sum())


def get_mode_labels(df, mode="e_s_c_i"):
    """
    Function to process label columns according to the classification mode.
    Args:
        df:pd.DataFrame - Dataframe of the dataset for analysis
        mode:str - classification mode (e_s_c_i or e_s_ci or e_sci or es_ci)
    Returns:
        Processed dataframe column (can be appended to the dataset dataframe)
    """
    # The raw `class` column is assumed to hold the full label names
    # ("exact", "substitute", "complement", "irrelevant"), as in the e_s_c_i
    # and es_ci branches; the other branches are normalised to the same values.
    if mode == "e_s_c_i":
        return df['class'].apply(lambda x: 'e' if x == "exact" else 's' if x == "substitute" else 'c' if x == "complement" else "i")
    if mode == "e_s_ci":
        return df['class'].apply(lambda x: 'e' if x == "exact" else 's' if x == "substitute" else 'ci')
    if mode == "e_sci":
        return df['class'].apply(lambda x: 'e' if x == "exact" else 'sci')
    if mode == "es_ci":
        return df['class'].apply(lambda x: 'es' if (x == "exact" or x == "substitute") else 'ci')
    raise ValueError(f"Unknown classification mode: {mode}")
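
# A minimal sketch of how the modes collapse labels (column values assumed to be the
# full names "exact"/"substitute"/"complement"/"irrelevant"):
#
#   df = pd.DataFrame({"class": ["exact", "substitute", "complement", "irrelevant"]})
#   get_mode_labels(df, mode="e_s_c_i").tolist()  # ['e', 's', 'c', 'i']
#   get_mode_labels(df, mode="es_ci").tolist()    # ['es', 'es', 'ci', 'ci']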


def load_tokenizer(language_model="xlm"):
    """
    Function to load the tokenizer according to the language model
    Args:
        language_model:str - decides the language model to load (xlm or bert)
    Returns:
        tokenizer object from the transformers library
    """
    if language_model == "xlm":
        print("Loading XLM-RoBERTa tokenizer")
        tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base', do_lower_case=True)
    elif language_model == "bert":
        print("Loading RoBERTa tokenizer")
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    else:
        raise ValueError(f"Unknown language model: {language_model}")
    print("Loaded tokenizer")
    return tokenizer


def tokenize_inputs(text_kw_list, text_ti_list, tokenizer, num_embeddings=128):
    """
    Function to tokenize the input text into ids. Appends the appropriate special
    tokens to mark sentence boundaries and truncates or pads each sequence to the
    desired length.
    Args:
        text_kw_list:list - list of keywords
        text_ti_list:list - list of titles
        tokenizer:tokenizer object - tokenizer depending on the language model to be used
        num_embeddings:int - maximum length of tokens to be considered
    Returns:
        (List of input ids:list, List of attention masks:list):tuple
    """
    import warnings
    warnings.filterwarnings("ignore")
    text_list = zip(text_kw_list, text_ti_list)
    input_ids, attention_masks = [], []
    for x, y in tqdm(text_list, total=len(text_kw_list)):
        encoded = tokenizer.encode_plus(x, y, add_special_tokens=True, padding='max_length',
                                        max_length=num_embeddings, truncation=True)
        input_ids.append(encoded['input_ids'])
        attention_masks.append(encoded['attention_mask'])
    input_ids = pad_sequences(input_ids, maxlen=num_embeddings, dtype="long", truncating="post", padding="post")
    return list(input_ids), attention_masks
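
# A minimal usage sketch (the example strings are placeholders, not from the dataset):
#
#   tok = load_tokenizer("xlm")
#   ids, masks = tokenize_inputs(["running shoes"], ["Lightweight running shoes, size 42"],
#                                tok, num_embeddings=32)
#   len(ids) == 1 and len(ids[0]) == 32  # each (keyword, title) pair becomes one fixed-length id sequence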


def save_dataloader(path="new_dataset/trainer.tsv", mode="e_s_c_i", df_frac=1,
                    language_model="xlm", max_length=128, split_mark="train",
                    create_validation=0.05, batch_size=8):
    """
    Function to save the text dataloader after preprocessing the dataset with the tokenizer.
    Files are stored in the directory: ./dataloader/
    Args:
        path:str - Path to the dataset tsv
        mode:str - Classification mode (e_s_c_i or e_s_ci or es_ci or e_sci)
        df_frac:float - Fraction of the data to use (for large files)
        language_model:str - decides the language model to load (xlm or bert)
        max_length:int - Maximum length of tokens to be considered
        split_mark:str - marker for the processed files (e.g., train/test/valid)
        create_validation:float - fraction of the dataset to hold out as a validation set
        batch_size:int - number of samples per batch in the dataloader
    """
    df = pd.read_csv(path, sep="\t")
    df = df.sample(frac=df_frac, random_state=42)
    print_data_statistics(df)
    df = df[["keyword", "title", "class"]]
    df[mode] = get_mode_labels(df, mode)
    # One-hot encode the collapsed labels and append one column per label.
    y = pd.get_dummies(df[mode])
    for label in mode.split("_"):
        df[label] = y[label]
    df = df.drop(columns="class")
    df = df.drop(columns=mode)
    cols = df.columns
    label_cols = list(cols[2:])
    print('Label columns: ', label_cols)
    df['one_hot_labels'] = list(df[label_cols].values)
    print_label_statistics(df, label_cols)
    labels = list(df.one_hot_labels.values)
    keywords = list(df.keyword.values)
    titles = list(df.title.values)
    tokenizer = load_tokenizer(language_model)
    input_ids, attention_masks = tokenize_inputs(keywords, titles, tokenizer, num_embeddings=max_length)
    if create_validation > 0:
        train_inputs, validation_inputs, train_labels, validation_labels, train_masks, validation_masks = train_test_split(
            input_ids, labels, attention_masks, test_size=create_validation, shuffle=False)
    else:
        train_inputs, train_labels, train_masks = input_ids, labels, attention_masks
    gc.collect()
    train_inputs = torch.tensor(train_inputs)
    train_labels = torch.tensor(train_labels)
    train_masks = torch.tensor(train_masks)
    if create_validation > 0:
        validation_inputs = torch.tensor(validation_inputs)
        validation_labels = torch.tensor(validation_labels)
        validation_masks = torch.tensor(validation_masks)
    train_data = TensorDataset(train_inputs, train_masks, train_labels)
    del(train_inputs, train_masks, train_labels)
    train_dataloader = DataLoader(train_data, batch_size=batch_size)
    torch.save(train_dataloader, f'./dataloader/{split_mark}_data_loader_text_{mode}')
    print(f"{split_mark} dataloader stored in file ./dataloader/{split_mark}_data_loader_text_{mode}")
    del train_dataloader
    gc.collect()
    if create_validation > 0:
        validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)
        del(validation_inputs, validation_masks, validation_labels)
        gc.collect()
        validation_dataloader = DataLoader(validation_data, batch_size=batch_size)
        torch.save(validation_dataloader, f'./dataloader/validation_data_loader_text_{mode}')
        print(f"validation dataloader stored in file ./dataloader/validation_data_loader_text_{mode}")
        del validation_dataloader
        gc.collect()
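
# The serialized dataloader can be restored later with torch.load, e.g. (sketch,
# path shown for the default mode and split_mark):
#
#   train_dataloader = torch.load('./dataloader/train_data_loader_text_e_s_c_i')
#   for input_ids, attention_masks, labels in train_dataloader:
#       ...  # feed the batch to the model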


def get_graph_inputs(keywords):
    """
    Function to get graph inputs for a list of keywords/titles
    Args:
        keywords:list - list of text of keywords/titles.
    Returns:
        node1_list:list - list of graph neighborhoods from the saved graph pickle
    """
    print("Loading graph dataset")
    graph = pickle.load(open("two_hop_ngbrs.pkl", "rb"))
    print("Completed loading graph")
    node1_list = []
    for x in tqdm(keywords, total=len(keywords)):
        first = graph.get(x, None)
        node1_list.append(first)
    del graph
    gc.collect()
    return node1_list


def save_graph_dataloader(path="new_dataset/trainer.tsv", mode="e_s_c_i", df_frac=1,
                          split_mark="train",
                          create_validation=0.05, batch_size=8):
    """
    Function to save the graph dataloader after preprocessing the dataset.
    Files are stored in the directory: ./dataloader/
    Args:
        path:str - Path to the dataset tsv
        mode:str - Classification mode (e_s_c_i or e_s_ci or es_ci or e_sci)
        df_frac:float - Fraction of the data to use (for large files)
        split_mark:str - marker for the processed files (e.g., train/test/valid)
        create_validation:float - fraction of the dataset to hold out as a validation set
        batch_size:int - number of samples per batch in the dataloader
    """
    df = pd.read_csv(path, sep="\t")
    df = df.sample(frac=df_frac, random_state=42)
    print_data_statistics(df)
    df = df[["keyword", "title"]]
    keywords = list(df.keyword.values)
    titles = list(df.title.values)
    del df
    gc.collect()
    if create_validation > 0:
        keywords_train, keywords_validation, titles_train, titles_validation = train_test_split(
            keywords, titles, test_size=create_validation, shuffle=False)
        node1_train = get_graph_inputs(keywords_train)
        node2_train = get_graph_inputs(titles_train)
        node1_validation = get_graph_inputs(keywords_validation)
        node2_validation = get_graph_inputs(titles_validation)
    else:
        node1_train = get_graph_inputs(keywords)
        node2_train = get_graph_inputs(titles)
    gc.collect()
    with open(f'./dataloader/{split_mark}_data_loader_node1_{mode}', "wb") as file_handler:
        pickle.dump(list(node1_train), file_handler)
    del node1_train
    gc.collect()
    with open(f'./dataloader/{split_mark}_data_loader_node2_{mode}', "wb") as file_handler:
        pickle.dump(list(node2_train), file_handler)
    del node2_train
    gc.collect()
    print(f"{split_mark} dataloader stored in file ./dataloader/{split_mark}_data_loader_node1(2)_{mode}")
    if create_validation > 0:
        pickle.dump(list(node1_validation), open(f'./dataloader/validation_data_loader_node1_{mode}', "wb"))
        del node1_validation
        gc.collect()
        pickle.dump(list(node2_validation), open(f'./dataloader/validation_data_loader_node2_{mode}', "wb"))
        print(f"validation dataloader stored in file ./dataloader/validation_data_loader_node1(2)_{mode}")
        del node2_validation
        gc.collect()
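
# The pickled neighbourhood lists are plain Python lists and can be read back with
# pickle.load, e.g. (sketch, path shown for the default mode and split_mark):
#
#   with open('./dataloader/train_data_loader_node1_e_s_c_i', 'rb') as fh:
#       node1_train = pickle.load(fh)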


if __name__ == "__main__":
    """
    Main script to be called for the text and graph dataloader construction for the SMLM model.
    save_dataloader - Saves the text dataloader for the experiment.
    save_graph_dataloader - Saves the graph dataloader for the experiment.
    Directory created:
        ./dataloader/: Contains the graph and text dataloaders for the dataset.
    """
    parser = argparse.ArgumentParser(description='Construct dataloader from query-asin dataset')
    parser.add_argument('--path', metavar='P', type=str, help='path to the query-asin dataset')
    parser.add_argument('--frac', metavar='F', type=float, default=1, help='in case you need to process only a fraction of the dataset')
    parser.add_argument('--mode', metavar='M', type=str, default="e_s_c_i", help='classifier mode: e_s_c_i (or) e_s_ci (or) e_sci (or) es_ci')
    parser.add_argument('--lm', metavar='LM', type=str, default="xlm", help='language model: xlm (or) bert')
    parser.add_argument('--maxlen', metavar='ML', type=int, default=128, help='max length of the sentence to be considered')
    parser.add_argument('--split_mark', metavar='SM', type=str, default="train", help='marks the type of split: train, test or valid')
    parser.add_argument('--create_validation', metavar='CV', type=float, default=0.05, help='fraction of the train data to be used for validation')
    parser.add_argument('--batch_size', metavar='BS', type=int, default=8, help='batch size of the dataloaders')
    parser.add_argument('--mkt', metavar='MKT', type=str, default="AU", help='marketplace code (not used by this script)')
    parser.add_argument('--month', metavar='mnth', type=str, default="01", help='dataset month (not used by this script)')
    args = parser.parse_args()
    print(args)
    save_dataloader(path=args.path, mode=args.mode, df_frac=args.frac,
                    language_model=args.lm, max_length=args.maxlen, split_mark=args.split_mark,
                    create_validation=args.create_validation, batch_size=args.batch_size)
    save_graph_dataloader(path=args.path, mode=args.mode, df_frac=args.frac,
                          split_mark=args.split_mark,
                          create_validation=args.create_validation, batch_size=args.batch_size)