import itertools
import pickle

import torch
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import f1_score, accuracy_score
from transformers import AdamW
from tqdm import tqdm
from model import Classifier
"""
Sample program to train the SMLM model on marketplace agnostic dataset for the
task of "e vs s vs ci" classification on cuda device
Run as: python train.py
Marketplace-specific finetuning of the model recommended prior to use on test datasets.
"""
def chunked(it, size):
    """
    Build a batching iterator over the graph lists for the graph dataloaders.
    Args:
        it: iterable - list of graphs
        size: int - number of items yielded per step (= batch_size)
    """
    it = iter(it)
    while True:
        p = list(itertools.islice(it, size))
        if not p:
            break
        yield p
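# Illustrative usage: chunked groups any iterable into batch-sized lists, with a
# possibly smaller final chunk, mirroring a dataloader's last partial batch, e.g.
#   list(chunked(range(5), 2)) == [[0, 1], [2, 3], [4]]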
# Define parameters for loading dataloaders
split_mark = "train"
mode = "e_s_ci"
epochs = 4
device = "cuda"
batch_size = 8
# Loading Dataloaders
print("Loading Dataloaders")
train_dataloader = torch.load(f'dataloader/{split_mark}_data_loader_text_{mode}')
validation_dataloader = torch.load(f'dataloader/validation_data_loader_text_{mode}')
train1_graph = pickle.load(open(f"./dataloader/{split_mark}_data_loader_node1_{mode}","rb"))
train2_graph = pickle.load(open(f"./dataloader/{split_mark}_data_loader_node2_{mode}","rb"))
validation1_graph = pickle.load(open(f"./dataloader/validation_data_loader_node1_{mode}","rb"))
validation2_graph = pickle.load(open(f"./dataloader/validation_data_loader_node2_{mode}","rb"))
print("Loaded Dataloaders")
# Load the Classifier SMLM model and move it to the CUDA device
model = Classifier(lm="xlm", mode="e_s_ci", graph=1)
model.to(device)
# Set custom optimization parameters: no weight decay on bias/LayerNorm terms.
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},  # AdamW reads 'weight_decay'; a 'weight_decay_rate' key is silently ignored
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5, correct_bias=True)
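# Excluding bias and LayerNorm (gamma/beta) parameters from weight decay is the
# standard practice for BERT-style models; decay applies only to weight matrices.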
# Store the training loss for plotting
train_loss_set = []
# Set number of labels based on the task
num_labels = len(mode.split("_"))
# Set the class weights (calculated offline using sklearn)
weights = {"e_s_c_i":[0.4115, 1.3260, 1.6203, 5.0389],"e_s_ci":[0.4115, 1.3260, 3.3296], "e_sci":[0.4115,2.6617]}
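# A minimal sketch of how such weights can be derived offline, assuming a flat
# list `train_labels` of integer class ids (hypothetical, not built in this script):
#   import numpy as np
#   from sklearn.utils.class_weight import compute_class_weight
#   class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
# Note: BCEWithLogitsLoss's `weight` argument below rescales the loss term of each
# class (broadcast over the batch); it is distinct from `pos_weight`, which weights
# only the positive examples.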
for _ in range(epochs):
    # Training
    # Set the model to training mode
    model.train()
    # Tracking variables
    tr_loss = 0  # running loss
    nb_tr_examples, nb_tr_steps = 0, 0
    train1_loader = chunked(train1_graph, batch_size)
    train2_loader = chunked(train2_graph, batch_size)
    # Train for one epoch
    for step, batch in enumerate(tqdm(train_dataloader)):
        # Move batch to GPU
        batch = tuple(t.to(device) for t in batch)
        # Unpack the inputs from the dataloader
        b_input_ids, b_input_mask, b_labels = batch
        # Clear out the gradients (by default they accumulate)
        optimizer.zero_grad()
        graph1 = next(train1_loader)
        graph2 = next(train2_loader)
        for g_i in range(len(graph1)):
            if graph1[g_i] is None:
                continue
            graph1[g_i].x = graph1[g_i].x.float()
            graph1[g_i] = graph1[g_i].cuda()
        for g_i in range(len(graph2)):
            if graph2[g_i] is None:
                continue
            graph2[g_i].x = graph2[g_i].x.float()
            graph2[g_i] = graph2[g_i].cuda()
        # Forward pass for multilabel classification
        logits = model(b_input_ids, attention_mask=b_input_mask, data1=graph1, data2=graph2)
        # Class-weighted BCE on the raw logits (labels cast to float)
        loss_func = BCEWithLogitsLoss(weight=torch.tensor(weights[mode], dtype=torch.float32, device=device))
        loss = loss_func(logits.view(-1, num_labels), b_labels.type_as(logits).view(-1, num_labels))
        # Unweighted alternative on sigmoid probabilities:
        # loss_func = BCELoss()
        # loss = loss_func(torch.sigmoid(logits.view(-1, num_labels)), b_labels.type_as(logits).view(-1, num_labels))
        train_loss_set.append(loss.item())
        # Backward pass
        loss.backward()
        # Update parameters using the computed gradients
        optimizer.step()
        # Update tracking variables
        tr_loss += loss.item()
        nb_tr_examples += b_input_ids.size(0)
        nb_tr_steps += 1
    print("Train loss: {}".format(tr_loss / nb_tr_steps))
    ###########################################################################
    # Validation
    # Put the model in evaluation mode to evaluate on the validation set
    model.eval()
    # Variables to gather the full output
    logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []
    validation1_loader = chunked(validation1_graph, batch_size)
    validation2_loader = chunked(validation2_graph, batch_size)
    # Predict
    for i, batch in enumerate(validation_dataloader):
        batch = tuple(t.to(device) for t in batch)
        # Unpack the inputs from the dataloader
        b_input_ids, b_input_mask, b_labels = batch
        with torch.no_grad():
            # Forward pass
            graph1 = next(validation1_loader)
            graph2 = next(validation2_loader)
            for g_i in range(len(graph1)):
                if graph1[g_i] is None:
                    continue
                graph1[g_i].x = graph1[g_i].x.float()
                graph1[g_i] = graph1[g_i].cuda()
            for g_i in range(len(graph2)):
                if graph2[g_i] is None:
                    continue
                graph2[g_i].x = graph2[g_i].x.float()
                graph2[g_i] = graph2[g_i].cuda()
            outs = model(b_input_ids, attention_mask=b_input_mask, data1=graph1, data2=graph2)
            b_logit_pred = outs
            pred_label = torch.sigmoid(b_logit_pred)
            b_logit_pred = b_logit_pred.detach().cpu().numpy()
            pred_label = pred_label.to('cpu').numpy()
            b_labels = b_labels.to('cpu').numpy()
        tokenized_texts.append(b_input_ids)
        logit_preds.append(b_logit_pred)
        true_labels.append(b_labels)
        pred_labels.append(pred_label)
    # Flatten outputs
    pred_labels = [item for sublist in pred_labels for item in sublist]
    true_labels = [item for sublist in true_labels for item in sublist]
    # Calculate accuracy at a 0.5 sigmoid threshold
    threshold = 0.50
    pred_bools = [pl > threshold for pl in pred_labels]
    true_bools = [tl == 1 for tl in true_labels]
    val_f1_accuracy = f1_score(true_bools, pred_bools, average='micro') * 100
    val_flat_accuracy = accuracy_score(true_bools, pred_bools) * 100
    print('F1 Validation Accuracy: ', val_f1_accuracy)
    print('Flat Validation Accuracy: ', val_flat_accuracy)
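    # Note: with 2-D multilabel boolean arrays, sklearn's accuracy_score reports
    # exact-match (subset) accuracy: every label of a sample must be correct.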
# Save the trained marketplace-agnostic model checkpoint for this mode
torch.save(model.state_dict(), f'{mode}_mlm_structurelm_new_model_amazon_ckpt')
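# Illustrative reload sketch (assumes the same Classifier constructor arguments):
#   model = Classifier(lm="xlm", mode="e_s_ci", graph=1)
#   model.load_state_dict(torch.load(f'{mode}_mlm_structurelm_new_model_amazon_ckpt'))
#   model.eval()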