#!/usr/bin/python3
# -*- coding: utf-8 -*-
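"""Sentence-pair similarity scoring with a fine-tuned BERT classifier.

Loads a fine-tuned BERT checkpoint from `bert_sim_dir` (bert_config.json,
vocab.txt and the latest model checkpoint), rebuilds the classification head
(output_weights / output_bias) on top of the pooled [CLS] output, and exposes
`BertSim.predict`, which returns softmax class probabilities for each
[text_a, text_b] pair.
"""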
import os

from termcolor import colored

from helper import import_tf, set_logger

__all__ = ['BertSim']

# Maximum sequence length the model was exported with:
# [CLS] + text_a + [SEP] + text_b + [SEP], padded/truncated to this many word pieces.
MAX_SEQ_LEN = 45


class BertSim(object):
    def __init__(self, gpu_no, log_dir, bert_sim_dir, verbose=False):
        self.bert_sim_dir = bert_sim_dir
        self.logger = set_logger(colored('BS', 'cyan'), log_dir, verbose)
        self.tf = import_tf(gpu_no, verbose)

        # WordPiece tokenizer built from the vocabulary shipped with the checkpoint.
        from bert import tokenization
        self.tokenizer = tokenization.FullTokenizer(os.path.join(bert_sim_dir, 'vocab.txt'))

        # Placeholders for a batch of fixed-length (MAX_SEQ_LEN) examples.
        self.input_ids = self.tf.placeholder(self.tf.int32, (None, MAX_SEQ_LEN), 'input_ids')
        self.input_mask = self.tf.placeholder(self.tf.int32, (None, MAX_SEQ_LEN), 'input_mask')
        self.input_type_ids = self.tf.placeholder(self.tf.int32, (None, MAX_SEQ_LEN), 'input_type_ids')

        self._init_graph()
    def _init_graph(self):
        """
        Build the BERT graph, attach the fine-tuned classification head and
        restore the latest checkpoint from `bert_sim_dir`.
        """
        try:
            from bert import modeling
            bert_config = modeling.BertConfig.from_json_file(os.path.join(self.bert_sim_dir, 'bert_config.json'))
            self.model = modeling.BertModel(config=bert_config,
                                            is_training=False,
                                            input_ids=self.input_ids,
                                            input_mask=self.input_mask,
                                            token_type_ids=self.input_type_ids,
                                            use_one_hot_embeddings=False)

            # The classification head (output_weights / output_bias) is not part
            # of BertModel, so read its values straight from the checkpoint.
            ckpt = self.tf.train.get_checkpoint_state(self.bert_sim_dir).all_model_checkpoint_paths[-1]
            reader = self.tf.train.NewCheckpointReader(ckpt)
            output_weights = reader.get_tensor('output_weights')
            output_bias = reader.get_tensor('output_bias')

            # logits = pooled [CLS] output x W^T + b, softmaxed into class probabilities.
            output_layer = self.model.get_pooled_output()
            logits = self.tf.matmul(output_layer, output_weights, transpose_b=True)
            logits = self.tf.nn.bias_add(logits, output_bias)
            self.probabilities = self.tf.nn.softmax(logits, axis=-1)

            sess_config = self.tf.ConfigProto(allow_soft_placement=True)
            sess_config.gpu_options.allow_growth = True
            saver = self.tf.train.Saver()
            self.sess = self.tf.Session(config=sess_config, graph=self.probabilities.graph)
            # Restoring the checkpoint initializes every variable in the graph,
            # so no separate global_variables_initializer() run is needed.
            saver.restore(self.sess, ckpt)
        except Exception as e:
            self.logger.error(e)
    def predict(self, request_list):
        """
        Run the similarity model on a batch of sentence pairs.
        :param request_list: list of [text_a, text_b] pairs
        :return: softmax probabilities, one row per pair, or None on failure
        """
        input_ids = []
        input_masks = []
        segment_ids = []
        for text_a, text_b in request_list:
            input_id, input_mask, segment_id = self._convert_single_example(text_a, text_b)
            input_ids.append(input_id)
            input_masks.append(input_mask)
            segment_ids.append(segment_id)

        try:
            return self.sess.run(self.probabilities, feed_dict={self.input_ids: input_ids,
                                                                self.input_mask: input_masks,
                                                                self.input_type_ids: segment_ids})
        except Exception as e:
            self.logger.error(e)
            return None
    def _convert_single_example(self, text_a, text_b):
        """
        Convert a sentence pair to model inputs: tokenize, add [CLS]/[SEP]
        markers and pad to MAX_SEQ_LEN.
        :param text_a: first sentence
        :param text_b: second sentence
        :return: input ids, input mask, segment ids
        """
        tokens = []
        input_ids = []
        segment_ids = []
        input_mask = []
        try:
            tokens_a = self.tokenizer.tokenize(text_a)
            tokens_b = self.tokenizer.tokenize(text_b)
            self._truncate_seq_pair(tokens_a, tokens_b)

            # Layout: [CLS] text_a [SEP] text_b [SEP]; segment id 0 marks the
            # first sentence (including [CLS] and its [SEP]), 1 the second.
            tokens.append('[CLS]')
            segment_ids.append(0)
            for token in tokens_a:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append('[SEP]')
            segment_ids.append(0)
            for token in tokens_b:
                tokens.append(token)
                segment_ids.append(1)
            tokens.append('[SEP]')
            segment_ids.append(1)

            input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
            # Real tokens get mask 1; zero-padding up to MAX_SEQ_LEN gets mask 0.
            input_mask = [1] * len(input_ids)
            while len(input_ids) < MAX_SEQ_LEN:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
        except Exception as e:
            self.logger.error(e)
        return input_ids, input_mask, segment_ids
    def _truncate_seq_pair(self, tokens_a, tokens_b):
        """
        Truncate a token-pair in place until it fits the maximum length,
        leaving room for [CLS] and two [SEP] markers.
        :param tokens_a: tokens of the first sentence
        :param tokens_b: tokens of the second sentence
        """
        # Trim one token at a time from the longer sequence, so both sentences
        # keep a comparable share of the length budget.
        while len(tokens_a) + len(tokens_b) > MAX_SEQ_LEN - 3:
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()
if __name__ == '__main__':
    bs = BertSim(gpu_no=0, log_dir='log', bert_sim_dir='bert_sim_model', verbose=True)
    text_a = '华为还准备起诉美国政府'  # "Huawei is also preparing to sue the US government"
    text_b = '飞机出现后货舱火警信息'  # "The aircraft showed a rear cargo hold fire warning"
    print(bs.predict([[text_a, text_b]]))
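    # predict() is batched: several [text_a, text_b] pairs can be scored in one
    # session run. A minimal sketch (the result has one row of class
    # probabilities per pair; the number and order of labels depend on how the
    # checkpoint was fine-tuned):
    probs = bs.predict([[text_a, text_b], [text_a, text_a]])
    if probs is not None:
        print(probs)  # e.g. shape (2, num_labels)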