-
Notifications
You must be signed in to change notification settings - Fork 0
/
loading_saving_utils.py
108 lines (70 loc) · 3.28 KB
/
loading_saving_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import checklist
import copy
import torch
import random
import numpy as np
import pickle
import wandb
from textattack import metrics
from math import floor
from checklist.test_types import MFT, INV, DIR
from checklist.test_suite import TestSuite
from sst_model import *
from config import *
from textattack.coverage import neuronMultiSectionCoverage
from textattack.datasets import HuggingFaceDataset
from datasets import load_dataset
from coverage_args import *
from coverage_utils import *
from repr_utils import *
def load_model_dataset(args):
    """Load the pretrained model and string datasets for the requested test suite.

    Args:
        args: namespace providing at least ``suite`` ('sentiment' or 'qqp'),
            ``base_model``, ``max_seq_len`` and ``mask``; when
            ``mask == 'imask'`` it must also provide
            ``word_importance_threshold`` and
            ``interaction_importance_threshold``.
            ``args.model_name_or_path`` is set here as a side effect.

    Returns:
        Tuple ``(model, trainset_str, validset_str, testset_str,
        word_importance, interaction_importance)``. The importance tensors are
        binary (0/1) masks, or ``None`` when ``args.mask != 'imask'``.

    Raises:
        ValueError: if ``args.suite`` is not a recognized suite name.
    """
    # Map each suite to the checkpoint suffix of its pretrained textattack model.
    # NOTE(review): the 'qqp' suite loads an MRPC-finetuned checkpoint, not a
    # QQP one (the original comment said QQP). Presumably deliberate because a
    # textattack QQP checkpoint is unavailable — confirm before relying on it.
    checkpoint_suffix = {'sentiment': '-SST-2', 'qqp': '-MRPC'}
    if args.suite not in checkpoint_suffix:
        # The original called quit(): a silent exit with status 0. Raising an
        # explicit error is safer for callers and for batch jobs.
        raise ValueError(f"unknown suite: {args.suite!r}")

    args.model_name_or_path = 'textattack/' + args.base_model + checkpoint_suffix[args.suite]
    print('='*5, 'Loading ', args.model_name_or_path, '='*5)
    print()
    print()
    model = SSTModel(args.model_name_or_path, args.max_seq_len)

    trainset_str, validset_str = [], []
    if args.suite == 'sentiment':
        text_key = 'sentence'
        trainset = HuggingFaceDataset('glue', 'sst2', 'train', shuffle = True)
        validset = HuggingFaceDataset('glue', 'sst2', 'validation', shuffle = False)
        trainset_str = [example[0][text_key] for example in trainset]
        # GLUE test labels are hidden, so the official validation split serves
        # as the test set and the last 20% of train becomes the dev split.
        testset_str = [example[0][text_key] for example in validset]
        split = floor(0.8 * len(trainset_str))
        validset_str = trainset_str[split:]
        trainset_str = trainset_str[:split]
    if args.suite == 'qqp':
        trainset = HuggingFaceDataset('glue', 'qqp', 'train', shuffle = True)
        validset = HuggingFaceDataset('glue', 'qqp', 'validation', shuffle = False)
        testset = HuggingFaceDataset('glue', 'qqp', 'test', shuffle = False)

        def _as_pairs(dataset):
            # Each QQP example is a (question1, question2) sentence pair.
            return [(ex[0]['question1'], ex[0]['question2']) for ex in dataset]

        trainset_str = _as_pairs(trainset)
        validset_str = _as_pairs(validset)
        testset_str = _as_pairs(testset)

    if args.mask == 'imask':
        # Binarize the precomputed per-token importance scores at the
        # configured thresholds: 1 where score >= threshold, else 0.
        interaction_importance = torch.from_numpy(
            np.load('masks/' + args.suite + '_' + args.base_model + '-interaction.npy'))
        word_importance = torch.from_numpy(
            np.load('masks/' + args.suite + '_' + args.base_model + '-word.npy'))
        word_importance = torch.where(
            word_importance >= args.word_importance_threshold,
            torch.ones_like(word_importance),
            torch.zeros_like(word_importance))
        interaction_importance = torch.where(
            interaction_importance >= args.interaction_importance_threshold,
            torch.ones_like(interaction_importance),
            torch.zeros_like(interaction_importance))
    else:
        word_importance = None
        interaction_importance = None
    return model, trainset_str, validset_str, testset_str, word_importance, interaction_importance