#!/usr/bin/env python
# coding: utf-8
"""
Sample program to finetune the SMLM model for marketplace <mkt> on the
classification task, running on CUDA device <cuda_id>.
Run as: python finetune.py <mkt> <cuda_id>
Finetuning for different marketplaces can be run in parallel on separate devices.
We recommend not using cuda_id = 0, since PyTorch's reserved memory lives on that device.
"""
import itertools
import pickle
import sys

import torch
from sklearn.metrics import f1_score, accuracy_score
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm
from transformers import AdamW

from model import Classifier
script, mkt, cuda_id = sys.argv
mode = "e_s_ci"  # classes (task identifier used in dataloader and checkpoint filenames)
def chunked(it, size):
    """
    Build a chunked iterator over the graph dataloaders.
    Args:
        it: iterable - iterable list of graphs
        size: int - number of items per yielded chunk (= batch_size)
    """
    it = iter(it)
    while True:
        p = list(itertools.islice(it, size))
        if not p:
            break
        yield p
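# e.g. list(chunked(range(5), 2)) == [[0, 1], [2, 3], [4]]; the final short
# chunk mirrors the last (possibly partial) batch of a dataloader.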
# Define parameters for loading the dataloaders
split_mark = "train"
batch_size = 8
num_labels = 3
# Loading Dataloaders
print("Loading Dataloaders")
train_dataloader = torch.load(f'dataloader/test/{mkt}_{split_mark}_data_loader_text_{mode}')
validation_dataloader = torch.load(f'dataloader/test/{mkt}_validation_data_loader_text_{mode}')
train1_graph = pickle.load(open(f"./dataloader/test/{mkt}_{split_mark}_data_loader_node1_{mode}","rb"))
train2_graph = pickle.load(open(f"./dataloader/test/{mkt}_{split_mark}_data_loader_node2_{mode}","rb"))
validation1_graph = pickle.load(open(f"./dataloader/test/{mkt}_validation_data_loader_node1_{mode}","rb"))
validation2_graph = pickle.load(open(f"./dataloader/test/{mkt}_validation_data_loader_node2_{mode}","rb"))
print("Loaded Dataloaders")
# Define device
device = torch.device(f"cuda:{cuda_id}")
# Load Classifier SMLM model
model = Classifier(lm="xlm", mode=mode, graph=1)
model.to(device)
# Load weights trained on the marketplace-agnostic dataset;
# map_location keeps the checkpoint tensors off the default GPU (cuda:0)
model.load_state_dict(torch.load('e_s_ci_mlm_structurelm_model_amazon_ckpt', map_location=device))
# Set custom optimization parameters (note: transformers' AdamW expects the
# key 'weight_decay'; 'weight_decay_rate' would be silently ignored)
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5, correct_bias=True)
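# Excluding bias and LayerNorm ('gamma'/'beta') parameters from weight decay
# follows the standard BERT-style finetuning recipe.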
# Store the training loss for plotting
train_loss_set = []
# Number of training epochs
epochs = 6
for _ in range(epochs):
    # Training
    # Set our model to training mode (as opposed to evaluation mode)
    model.train()
    # Tracking variables
    tr_loss = 0  # running loss
    nb_tr_examples, nb_tr_steps = 0, 0
    # Train the data for one epoch
    train1_loader = chunked(train1_graph, batch_size)
    train2_loader = chunked(train2_graph, batch_size)
    for step, batch in enumerate(tqdm(train_dataloader)):
        # Add batch to GPU
        batch = tuple(t.to(device) for t in batch)
        # Unpack the inputs from our dataloader
        b_input_ids, b_input_mask, b_labels = batch
        # Clear out the gradients (by default they accumulate)
        optimizer.zero_grad()
        # Forward pass for multilabel classification
        graph1 = next(train1_loader)
        graph2 = next(train2_loader)
        outputs = model(b_input_ids, attention_mask=b_input_mask, data1=graph1, data2=graph2)
        logits = outputs
        loss_func = BCEWithLogitsLoss()
        # Convert labels to float for the loss calculation
        loss = loss_func(logits.view(-1, num_labels), b_labels.type_as(logits).view(-1, num_labels))
        train_loss_set.append(loss.item())
        # Backward pass
        loss.backward()
        # Update parameters and take a step using the computed gradient
        optimizer.step()
        # Update tracking variables
        tr_loss += loss.item()
        nb_tr_examples += b_input_ids.size(0)
        nb_tr_steps += 1
    print("Train loss: {}".format(tr_loss / nb_tr_steps))
    ###########################################################################
    # Validation
    # Put model in evaluation mode to evaluate loss on the validation set
    model.eval()
    validation1_loader = chunked(validation1_graph, batch_size)
    validation2_loader = chunked(validation2_graph, batch_size)
    # Variables to gather full output
    logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []
    # Predict
    for i, batch in enumerate(validation_dataloader):
        batch = tuple(t.to(device) for t in batch)
        # Unpack the inputs from our dataloader
        b_input_ids, b_input_mask, b_labels = batch
        with torch.no_grad():
            # Forward pass
            graph1 = next(validation1_loader)
            graph2 = next(validation2_loader)
            outs = model(b_input_ids, attention_mask=b_input_mask, data1=graph1, data2=graph2)
            b_logit_pred = outs
            pred_label = torch.sigmoid(b_logit_pred)
            b_logit_pred = b_logit_pred.detach().cpu().numpy()
            pred_label = pred_label.to('cpu').numpy()
            b_labels = b_labels.to('cpu').numpy()
        tokenized_texts.append(b_input_ids)
        logit_preds.append(b_logit_pred)
        true_labels.append(b_labels)
        pred_labels.append(pred_label)
    # Flatten outputs
    pred_labels = [item for sublist in pred_labels for item in sublist]
    true_labels = [item for sublist in true_labels for item in sublist]
    # Calculate metrics at a 0.50 sigmoid threshold
    threshold = 0.50
    pred_bools = [pl > threshold for pl in pred_labels]
    true_bools = [tl == 1 for tl in true_labels]
    val_f1_accuracy = f1_score(true_bools, pred_bools, average='micro') * 100
    val_flat_accuracy = accuracy_score(true_bools, pred_bools) * 100
    print('Validation micro-F1: ', val_f1_accuracy)
    print('Validation flat (subset) accuracy: ', val_flat_accuracy)
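    # Note: micro-F1 pools true/false positives across all labels, while
    # sklearn's accuracy_score on multilabel arrays is subset accuracy: a
    # sample counts as correct only if all num_labels predictions match.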
# Save the fine-tuned model for marketplace {mkt} and mode {mode}
torch.save(model.state_dict(), f'./finetuned_models/{mkt}_{mode}_mlm_structurelm_new_model_amazon_ckpt')