-
Notifications
You must be signed in to change notification settings - Fork 0
/
value_generation_fusion.py
342 lines (289 loc) · 14.3 KB
/
value_generation_fusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
'''
Adding value appreciation to T5
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from logging import raiseExceptions
import os
import numpy as np
import copy
import math
import random
from typing import Optional, Tuple
from dataclasses import dataclass, field
from itertools import chain
import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import (
AutoConfig,
AutoTokenizer,
AutoModel,
AutoModelForSeq2SeqLM,
T5ForConditionalGeneration,
T5Model,
T5EncoderModel,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from transformers.modeling_outputs import Seq2SeqLMOutput,BaseModelOutput
from torch.optim import Optimizer
from utils.model_utils import Projection, LinearELU, SequenceMask
import torch.optim as optimizer_module
from transformers import logging
# `logging` here is the `transformers` logging module imported on the previous
# line (not the stdlib one); restrict it to error-level messages.
logging.set_verbosity_error()
# Output width handed to the `Projection` layers that score value candidates.
SIM_DIM = 64
# NOTE(review): not referenced anywhere in the visible code — presumably the
# width of a bag-of-words head; confirm before removing.
BOW_DIM = 1024
# Module-level default device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class T5Pretrained(PreTrainedModel):
    """Shared base for the T5 wrappers in this module.

    Supplies the weight-initialisation hook and the label-shifting helper.
    Subclasses are expected to expose the wrapped T5 network as `self.model`,
    because `_shift_right` reads `self.model.config`.
    """

    def _init_weights(self, module):
        """Initialise the shared embedding of a T5 (sub)model.

        Only `T5Model`, `T5ForConditionalGeneration` and `T5EncoderModel`
        instances are touched; every other module type is left unchanged.
        """
        # `initializer_factor` exists to make weight-init tests controllable.
        scale = self.config.initializer_factor
        t5_types = (T5Model, T5ForConditionalGeneration, T5EncoderModel)
        if not isinstance(module, t5_types):
            return
        # Mesh TensorFlow embeddings initialization, see
        # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
        module.shared.weight.data.normal_(mean=0.0, std=scale * 1.0)

    def _shift_right(self, input_ids):
        """Return `input_ids` shifted one step right along the last axis.

        Position 0 is filled with the decoder start token, the last token is
        dropped, and any -100 entries (ignored label positions) are replaced
        by the pad token id.

        Raises:
            AssertionError: if the start or pad token id is undefined, or if
                a negative id remains after the replacement.
        """
        start_id = self.model.config.decoder_start_token_id
        pad_id = self.model.config.pad_token_id

        assert (
            start_id is not None
        ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"

        # Fresh buffer: start token in the first slot, then the original
        # sequence minus its final element.
        shifted = input_ids.new_zeros(input_ids.shape)
        shifted[..., 0] = start_id
        shifted[..., 1:] = input_ids[..., :-1].clone()

        assert pad_id is not None, "self.model.config.pad_token_id has to be defined."
        # Labels mark ignored positions with -100; the decoder needs a real id.
        shifted = shifted.masked_fill(shifted == -100, pad_id)

        assert torch.all(shifted >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
        return shifted
class DSTGeneration(T5Pretrained):
    """T5 encoder-decoder with a learned selection over value candidates.

    The encoder output is split in half along the sequence axis: the first
    half is treated as the question/context, the second half as the value
    candidates. A prior distribution over candidates is computed from the
    context; when a "golden" target sequence is available a posterior is
    computed from context + target and a KL term between the two is added to
    the loss. With `knowledge_fusion == 'initDecoder'` the candidate states
    are re-weighted by that distribution before being fed to the decoder.
    """

    # Score written into candidate slots that must not be selected. A large
    # negative value drives their softmax probability to ~0.
    _MASKED_SCORE = -1e10

    def __init__(self, model_path, pretrain_config, knowledge_fusion, word_bow_loss, max_len, cache_dir, training_mode="pretrain"):
        """Load the pretrained T5 backbone and create the selection layers.

        Args:
            model_path: name or path given to `from_pretrained`; a path
                containing ".ckpt" is loaded as a TensorFlow checkpoint.
            pretrain_config: T5 configuration object.
            knowledge_fusion: fusion mode; only 'initDecoder' enables the
                re-weighting of candidate states in `forward`.
            word_bow_loss: stored on the instance; not used in this class.
            max_len: stored on the instance; not used in this class.
            cache_dir: cache directory forwarded to `from_pretrained`.
            training_mode: stored on the instance; not used in this class.
        """
        super(DSTGeneration, self).__init__(pretrain_config)
        self.config = pretrain_config
        self.model = T5ForConditionalGeneration.from_pretrained(
            model_path,
            from_tf=bool(".ckpt" in model_path),
            config=pretrain_config,
            cache_dir=cache_dir,
        )
        self.knowledge_fusion = knowledge_fusion
        self.word_bow_loss = word_bow_loss
        # Projections into the SIM_DIM space used to score the candidates.
        self.fact_layer = Projection(self.config.d_model, SIM_DIM)
        self.prior_layer = Projection(self.config.d_model, SIM_DIM)
        # The posterior sees [golden state ; context state], hence 2 * d_model.
        self.post_layer = Projection(self.config.d_model * 2, SIM_DIM)
        self.max_len = max_len
        self.decoder_init_layer = nn.Linear(self.config.d_model, self.config.d_model, bias=False)
        self.tanh = nn.Tanh()
        # Separate output head; the backbone's own lm_head is not used here.
        self.lm_head = nn.Linear(self.config.d_model, self.config.vocab_size, bias=False)
        self.dropout = nn.Dropout(0.5)
        self.training_mode = training_mode
        self.shared = nn.Embedding(self.config.vocab_size, self.config.d_model)
        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        """Spread the backbone's encoder/decoder blocks over several GPUs.

        NOTE(review): only the backbone and `lm_head` are moved; the
        selection layers created in `__init__` stay where they are — confirm
        this is intended before relying on model parallelism.
        """
        self.device_map = (
            get_device_map(len(self.model.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.model.encoder.block))
        self.model.encoder.parallelize(self.device_map)
        self.model.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.model.decoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        """Undo `parallelize`: move encoder, decoder and lm_head back to CPU."""
        self.model.encoder.deparallelize()
        self.model.decoder.deparallelize()
        self.model.encoder = self.model.encoder.to("cpu")
        self.model.decoder = self.model.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_encoder(self):
        """Return the backbone's encoder stack."""
        return self.model.encoder

    def get_decoder(self):
        """Return the backbone's decoder stack."""
        return self.model.decoder

    def sequence_mask(self, lengths, maxlen=None, dtype=torch.bool):
        """Build a mask whose row i is true for positions < lengths[i].

        Args:
            lengths: tensor of per-row lengths.
            maxlen: number of columns; defaults to `lengths.max()`.
            dtype: dtype of the returned mask.
        """
        if maxlen is None:
            maxlen = lengths.max()
        positions = torch.arange(0, maxlen, 1, device=self.device)
        mask = positions < torch.unsqueeze(lengths, dim=-1)
        return mask.to(dtype=dtype, device=self.device)

    def _safe_log(self, y):
        """Logarithm with the argument clamped to >= 1e-9 to avoid log(0)."""
        return torch.log(torch.clamp(y, 1e-9))

    def _blocked_candidates(self, value_candidate_mask):
        """Return a bool mask of candidate slots excluded from selection.

        A slot is blocked when its attention-mask entry is 0 (padding). The
        first slot is always blocked as well, matching the "unk" mask the
        original code built with `sequence_mask(ones, ...)`.
        NOTE(review): presumably slot 0 holds a placeholder candidate —
        confirm against the data pipeline.
        """
        # Comparing with 0 works for bool, integer and float masks alike,
        # unlike `~`, which is a bitwise NOT on integer tensors.
        blocked = value_candidate_mask == 0
        blocked[:, :1] = True
        return blocked

    def _candidate_distribution(self, query_projection, fact_projection, blocked):
        """Softmax over candidates of the dot product query . candidate.

        Blocked slots receive a very negative score so that they end up with
        (near-)zero probability.
        """
        # Broadcasting over the candidate axis replaces the explicit
        # unsqueeze(1).repeat(...) of the query vector.
        scores = torch.sum(query_projection.unsqueeze(1) * fact_projection, -1)
        # Stay representable in reduced-precision dtypes such as float16.
        floor = max(self._MASKED_SCORE, torch.finfo(scores.dtype).min)
        return F.softmax(scores.masked_fill(blocked, floor), dim=1)

    def forward(
        self,
        input_ids=None,
        decoder_input_ids=None,
        value_candidate_embedding=None,
        attention_mask=None,
        decoder_attention_mask=None,
        value_candidate_mask=None,
        labels=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run encoder, candidate selection, decoder and (optionally) loss.

        `value_candidate_embedding`, `value_candidate_mask`, `head_mask` and
        `inputs_embeds` are accepted for interface compatibility but are not
        read; the candidate mask is derived from the second half of
        `attention_mask`.

        When `encoder_outputs` is None the encoder is run on `input_ids`, and
        additionally on the golden decoder ids (`decoder_input_ids`, or the
        right-shifted `labels` when those are absent) to obtain the posterior
        over candidates. When `encoder_outputs` is supplied (e.g. during
        generation) only the prior is used and no KL term is computed.

        Returns:
            A `Seq2SeqLMOutput`, or a plain tuple when `return_dict` is
            false. `loss` is cross-entropy plus the KL term when the latter
            is available.

        Raises:
            ValueError: if the encoder sequence length is not a positive
                even number, so it cannot be split into two equal halves.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        golden_encoder_states = None
        if encoder_outputs is None:
            # hidden_states of encoder outputs
            encoder_outputs = self.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=return_dict,
            )
            # The posterior needs the golden target. Fall back to the shifted
            # labels, which is what prepare_decoder_input_ids_from_labels
            # would have produced.
            golden_input_ids = decoder_input_ids
            if golden_input_ids is None and labels is not None:
                golden_input_ids = self._shift_right(labels)
            if golden_input_ids is not None:
                golden_encoder_outputs = self.model.encoder(
                    input_ids=golden_input_ids,
                    attention_mask=decoder_attention_mask,
                    return_dict=return_dict,
                )
                # State at the first position summarises the golden sequence.
                golden_encoder_states = golden_encoder_outputs[0][:, 0, :]
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        encoder_last_state = encoder_outputs[0]
        seq_len = encoder_last_state.size(1)
        if seq_len < 2 or seq_len % 2 != 0:
            raise ValueError(
                "encoder sequence length must be a positive even number so it can be "
                "split into context and value-candidate halves, got {}".format(seq_len)
            )
        half = seq_len // 2
        question_context_states = encoder_last_state[:, :half]
        value_candidate_states = encoder_last_state[:, half:]

        if attention_mask is None:
            # No mask supplied: treat every candidate position as valid.
            candidate_attention = encoder_last_state.new_ones(value_candidate_states.shape[:2])
        else:
            candidate_attention = attention_mask[:, attention_mask.size(1) // 2:]
        blocked = self._blocked_candidates(candidate_attention)

        encoder_states = question_context_states[:, 0, :]
        fact_projection = self.fact_layer(value_candidate_states)

        # prior value distribution (context only)
        prior_distribution = self._candidate_distribution(
            self.prior_layer(encoder_states), fact_projection, blocked
        )

        post_distribution = None
        kld_loss = None
        if golden_encoder_states is not None:
            # post value distribution (golden target + context)
            post_inputs = torch.cat((golden_encoder_states, encoder_states), -1)
            post_distribution = self._candidate_distribution(
                self.post_layer(post_inputs), fact_projection, blocked
            )
            # KL(post || prior), averaged over candidates and over the batch.
            post_prior_ratio = torch.div(post_distribution, torch.clamp(prior_distribution, 1e-9, 1.0))
            kld_loss = post_distribution * self._safe_log(post_prior_ratio)
            kld_loss = torch.mean(kld_loss, -1)
            kld_loss = torch.sum(kld_loss) / value_candidate_states.size(0)

        if self.knowledge_fusion == 'initDecoder':
            # Use the posterior when the golden target is known, else the prior.
            selection = post_distribution if post_distribution is not None else prior_distribution
            # Same shape as value_candidate_states: each candidate state is
            # scaled by its selection probability.
            knowledge_fusion = self.dropout(value_candidate_states * selection.unsqueeze(-1))
            concatenated_encoder_states = torch.cat((question_context_states, knowledge_fusion), 1)
        else:
            concatenated_encoder_states = encoder_last_state

        encoder_hidden_states = self.decoder_init_layer(concatenated_encoder_states)
        encoder_hidden_states = self.tanh(encoder_hidden_states)

        if labels is not None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            decoder_device = self.model.decoder.first_device
            torch.cuda.set_device(decoder_device)
            encoder_hidden_states = encoder_hidden_states.to(decoder_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(decoder_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(decoder_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(decoder_device)

        decoder_outputs = self.model.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            encoder_device = self.model.encoder.first_device
            torch.cuda.set_device(encoder_device)
            self.lm_head = self.lm_head.to(encoder_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # The KL term only exists when the golden encoder pass ran.
            if kld_loss is not None:
                loss = kld_loss.to(loss.device) + loss

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):
        """Assemble the keyword arguments for one decoding step of `forward`.

        `input_ids` are the decoder ids generated so far; extra keyword
        arguments are accepted and ignored.
        """
        # cut decoder_input_ids if past is used: with cached key/values only
        # the newest token has to be fed to the decoder.
        if past is not None:
            input_ids = input_ids[:, -1:]
        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels):
        """Derive decoder input ids from labels by shifting them right."""
        return self._shift_right(labels)