-
Notifications
You must be signed in to change notification settings - Fork 2
/
evaluate.py
115 lines (95 loc) · 3.96 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/env python
# coding: utf-8
import pandas as pd
import numpy as np
import torch
from sklearn.metrics import classification_report
import pickle
from transformers import *
from tqdm import tqdm, trange
from ast import literal_eval
import sys
import itertools
from model import Classifier
"""
Sample program to evaluate the SMLM model on marketplace <mkt> dataset for the
task of classification on cuda device <cuda_id>
Run as: python evaluate.py <mkt> <cuda_id>
We can parallely run the evaluation for different marketplaces by using multiple devices.
Recommend not using cuda_id = 0 because that will be used for pytorch's reserve memory.
"""
script, mkt, cuda_id = sys.argv
mode = "e_s_ci" #classes
label_cols = mode.split("_")
def chunked(it, size):
    """
    Lazily yield successive lists of at most `size` items from `it`.

    Used to step the pickled graph lists in lockstep with the text
    dataloader batches.

    Args:
        it: iterable object - iterable list of graphs
        size: int - size of each iteration yield (=batch_size)
    """
    source = iter(it)
    chunk = list(itertools.islice(source, size))
    while chunk:
        yield chunk
        chunk = list(itertools.islice(source, size))
# ---------------------------------------------------------------------------
# Script body: runs at import time (no `if __name__ == "__main__"` guard).
# ---------------------------------------------------------------------------
# Define parameters for loading dataloaders
split_mark = "train"
epochs = 4  # NOTE(review): unused in this script — evaluation only
device = "cuda"  # NOTE(review): dead assignment — overwritten with torch.device below
batch_size = 8  # must match the batch size the saved dataloaders were built with
num_labels = 3  # NOTE(review): unused here; label count comes from `mode`
# Loading dataloaders
print("Loading Dataloaders")
# NOTE(review): the train dataloader/graphs are loaded but never used below —
# only the validation data feeds the evaluation loop. Confirm they are needed.
train_dataloader = torch.load(f'dataloader/monthly/{mkt}_{split_mark}_data_loader_text_{mode}')
validation_dataloader = torch.load(f'dataloader/monthly/{mkt}_validation_data_loader_text_{mode}')
train1_graph = pickle.load(open(f"./dataloader/monthly/{mkt}_{split_mark}_data_loader_node1_{mode}","rb"))
# NOTE(review): the three loads below read from dataloader/test/ while the ones
# above read from dataloader/monthly/ — confirm this directory mix is intended.
train2_graph = pickle.load(open(f"./dataloader/test/{mkt}_{split_mark}_data_loader_node2_{mode}","rb"))
validation1_graph = pickle.load(open(f"./dataloader/test/{mkt}_validation_data_loader_node1_{mode}","rb"))
validation2_graph = pickle.load(open(f"./dataloader/test/{mkt}_validation_data_loader_node2_{mode}","rb"))
print("Loaded Dataloaders")
# Define device
device = torch.device(f"cuda:{cuda_id}")
# Load Classifier SMLM model
model = Classifier(lm="xlm", mode=mode, graph=1)
model.to(device)
# Load weights trained on marketplace-specific dataset
# NOTE(review): checkpoint path hard-codes "e_s_ci" rather than interpolating
# `mode` — breaks silently if `mode` is ever changed. Confirm intended.
model.load_state_dict(torch.load(f'finetuned_models/{mkt}_e_s_ci_mlm_structurelm_new_model_amazon_ckpt'))
# Run evaluation over the validation datasets.
model.eval()
logit_preds,true_labels,pred_labels,tokenized_texts = [],[],[],[]
# Chunk the graph lists so they advance in lockstep with the text batches.
# Assumes both graph lists cover at least as many examples as the text
# dataloader; otherwise next() below raises StopIteration — TODO confirm.
validation1_loader = chunked(validation1_graph,batch_size)
validation2_loader = chunked(validation2_graph,batch_size)
# Predict
for i, batch in tqdm(enumerate(validation_dataloader),total=len(validation_dataloader)):
    # Move every tensor in the batch to the evaluation device.
    batch = tuple(t.to(device) for t in batch)
    # Unpack the inputs from our dataloader
    b_input_ids, b_input_mask, b_labels = batch
    with torch.no_grad():
        # Forward pass: text tokens plus the two graph views for this batch.
        graph1 = next(validation1_loader)
        graph2 = next(validation2_loader)
        outs = model(b_input_ids, attention_mask=b_input_mask, data1=graph1, data2=graph2)
        b_logit_pred = outs
        # Sigmoid turns raw logits into independent per-label scores
        # (multi-label style); argmax over them happens after the loop.
        pred_label = torch.sigmoid(b_logit_pred)
        b_logit_pred = b_logit_pred.detach().cpu().numpy()
        pred_label = pred_label.to('cpu').numpy()
        b_labels = b_labels.to('cpu').numpy()
    # Accumulate per-batch arrays; flattened after the loop.
    tokenized_texts.append(b_input_ids)
    logit_preds.append(b_logit_pred)
    true_labels.append(b_labels)
    pred_labels.append(pred_label)
# Flatten outputs
tokenized_texts = [item for sublist in tokenized_texts for item in sublist]
pred_labels = [item for sublist in pred_labels for item in sublist]
true_labels = [item for sublist in true_labels for item in sublist]
# Converting flattened binary values to boolean values
true_bools = [tl==1 for tl in true_labels]
# Generate classification reports for predictions vs ground truth.
# argmax collapses the per-label sigmoid scores to one predicted class per
# example, so the report is effectively single-label over `label_cols`.
clf_report_optimized = classification_report(np.argmax(true_bools,axis=1),np.argmax(pred_labels,axis=1), target_names=label_cols, digits=4)
# Save the classification reports in the "reports/" directory.
# NOTE(review): pickle.dump writes a binary pickle into a .txt file — the
# result is not human-readable and must be re-read with pickle.load. Also the
# file handle is never closed explicitly. Confirm this is intended.
pickle.dump(clf_report_optimized, open(f'reports/{mkt}_{mode}_mlm_classification_report_optimized.txt','wb'))
# Print for immediate user check.
print("MKT:",mkt)
print(clf_report_optimized)