-
Notifications
You must be signed in to change notification settings - Fork 8
/
demo1.py
120 lines (95 loc) · 3.37 KB
/
demo1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, AdamW
import numpy as np
from dataSet import PhpDataset
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from pathlib import Path
from pretreatment.code_pre import code_pre
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import time
import NNModel
import os
def walkFile(file):
    """Recursively collect the paths of all ``.php`` files under a directory.

    Args:
        file: Root directory to walk (a directory path, despite the name).

    Returns:
        A list with one entry per matching file, each being the containing
        directory joined with the file name, in ``os.walk`` order. Empty
        when nothing matches.
    """
    file_list = []
    for root, _dirs, files in os.walk(file):
        # Match on the real extension: comparing only the last three
        # characters would also accept names such as "notphp".
        file_list.extend(
            os.path.join(root, f) for f in files if f.endswith('.php')
        )
    return file_list
class testDataset(Dataset):
    """Dataset of PHP source files discovered under a directory tree.

    Each item is one file that is read as text, passed through ``code_pre``,
    cut to its first 10000 characters and tokenized with the
    ``microsoft/codebert-base`` tokenizer into fixed-length (512) tensors.
    """

    def __init__(self, root='./'):
        """Load the tokenizer and encoder, then index the PHP files.

        Args:
            root: Directory scanned recursively for files. Defaults to the
                current working directory.
        """
        super().__init__()
        self.tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
        # NOTE(review): nothing in this class reads self.model; it is kept
        # only so the attribute stays available to external code.
        self.model = RobertaModel.from_pretrained("microsoft/codebert-base")
        self.df = walkFile(root)

    def __getitem__(self, item):
        """Return the tokenized contents of the file at position *item*.

        Args:
            item: Integer index into the list of discovered file paths.

        Returns:
            A dict with ``ids``, ``mask`` and ``token_type_ids`` as
            ``torch.long`` tensors, plus ``filename`` holding the file path.

        Raises:
            OSError: If the file cannot be opened or read.
        """
        path = self.df[item]
        # A context manager closes the handle on every path and lets a
        # failing open() surface its own error instead of a NameError from
        # cleanup code touching an unbound variable.
        with open(path, 'r', encoding='utf-8', errors='ignore') as rf:
            data = rf.read()
        data = code_pre(data)[:10000]
        inputs = self.tokenizer.encode_plus(
            data,
            None,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
        )
        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(inputs["token_type_ids"], dtype=torch.long),
            'filename': path,
        }

    def __len__(self):
        """Return the number of discovered files."""
        return len(self.df)
if __name__ == '__main__':
    # Classify every PHP file found by testDataset and print the raw model
    # output and path of each file predicted as class 0.
    data_set = testDataset()
    # Inference only: a fixed order keeps the printed output reproducible.
    data_loading = DataLoader(dataset=data_set, batch_size=32, shuffle=False, num_workers=2, drop_last=False)
    model = NNModel.TextCNNClassifer().cpu()
    # The model lives on the CPU, so remap the checkpoint's tensors there;
    # otherwise a checkpoint saved from a GPU fails to load on CPU-only hosts.
    model.load_state_dict(torch.load('model/cls_model_0.pth', map_location='cpu'))
    # Switch layers such as dropout / batch-norm to inference behaviour.
    model.eval()
    # No gradients are needed for prediction; skip building autograd graphs.
    with torch.no_grad():
        for data in data_loading:
            time_start = time.time()
            ids = data['ids'].cpu()
            mask = data['mask'].cpu()
            token_type_ids = data['token_type_ids'].cpu()
            outputs = model(ids, mask, token_type_ids)
            # Index of the highest-scoring class for each sample in the batch.
            pred_choice = outputs.max(1)[1]
            # torch.where on a 1-D mask returns a 1-tuple of index tensors.
            for i in torch.where(pred_choice == 0)[0].tolist():
                print(outputs[i])
                print(data['filename'][i])
            time_end = time.time()
            # Elapsed wall-clock time for this batch.
            print('totally cost', time_end - time_start)