-
Notifications
You must be signed in to change notification settings - Fork 1
/
padding.py
106 lines (80 loc) · 3.82 KB
/
padding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import math
def unpad_input(out_, in_, indices):
    """Copy the rows of ``in_`` selected by ``indices`` into ``out_`` in place.

    The gathered rows are written through ``out_[:, :]``, so they must be
    assignable (broadcastable) to ``out_``'s shape. Nothing is returned.
    """
    selected_rows = in_[indices, :]
    out_[:, :] = selected_rows
def pad_input(out_, in_, indices):
    """Scatter the rows of ``in_`` into ``out_`` at the row positions in ``indices``.

    Rows of ``out_`` not named in ``indices`` are left untouched; the write is
    in place and nothing is returned.
    """
    source_rows = in_[:, :]
    out_[indices, :] = source_rows
def unpad_mask(out_, in_, indices):
    """Overwrite ``out_`` with the entries of the flattened ``in_`` at ``indices``.

    The write goes through ``out_[:]`` (in place); nothing is returned.
    """
    flat_source = in_.flatten()
    out_[:] = flat_source[indices]
def generate_mask(attention_mask, heads, pad=False, fuse_mask=True, unpad_fmha=False):
    """Build token indices and an additive attention mask from a padding mask.

    ``attention_mask`` is indexed as ``[batch, seq]``: row ``i`` is one sequence,
    nonzero entries are real tokens, and each row sum is taken as that
    sequence's length.
    NOTE(review): real tokens are assumed to form a prefix of each row; this
    is not checked.

    Args:
        attention_mask: ``[batch, seq]`` tensor, nonzero for real tokens.
        heads: number of attention heads; unfused masks are repeated per head.
        pad: if truthy, keep every row at the full ``seq`` width. Otherwise
            round each length up to a multiple of 16 (at least 16), capped at
            ``seq``.
        fuse_mask: if truthy, return a per-token mask; otherwise a flattened
            ``[batch, heads, seq, seq]``-style per-pair mask.
        unpad_fmha: if truthy, return variable-length metadata instead of a
            mask (``pad`` and ``fuse_mask`` are ignored).

    Returns:
        With ``unpad_fmha``: ``(indices, attention_mask, seqlen, ntokens,
        cu_seqlens, seqlen, maxseqlen)``.
          - ``indices``: flat positions of the nonzero mask entries.
          - ``seqlen``: int32 row sums.
          - ``cu_seqlens``: the ``batch + 1`` running offsets (starting at 0).
          - ``ntokens`` and ``maxseqlen``: Python ints.

        Otherwise: ``(indices, mask, seqlen, ntokens, None, None, None)``.
          - ``indices``: flat positions of the kept tokens (real plus
            round-up padding).
          - ``seqlen``: int32 CPU tensor of kept lengths.
          - ``ntokens``: their sum.
          - ``mask``: flat float16 tensor on ``attention_mask``'s device,
            ``0`` where attention is allowed and ``-10000`` where it is masked.
    """
    if unpad_fmha:
        # Variable-length metadata: no rounding, only the real tokens of each row.
        seqlen = attention_mask.sum(dim=1).to(dtype=torch.int32).flatten()
        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        maxseqlen = seqlen.max().item()
        batch = attention_mask.shape[0]
        # cu_seqlens[i] is where row i starts in the packed layout; the last
        # entry is the sum of all row lengths.
        cu_seqlens = torch.zeros(batch + 1, device=attention_mask.device, dtype=torch.int32)
        cu_seqlens[1:] = torch.cumsum(seqlen, dim=0)
        ntokens = cu_seqlens[-1].item()
        return indices, attention_mask, seqlen, ntokens, cu_seqlens, seqlen, maxseqlen

    max_len = attention_mask.shape[1]
    seqlen = attention_mask.sum(dim=1).float().cpu()
    if pad:
        # Keep every row at its full padded width.
        seqlen.fill_(max_len)
    else:
        # Round each length up to a multiple of 16, with a floor of 16 ...
        seqlen = ((seqlen + 16 - 1) / 16).floor() * 16
        seqlen[seqlen < 16] = 16
        # ... but never past the row width. Positions beyond max_len do not
        # exist, and counting them would make ntokens disagree with the
        # number of kept indices (e.g. whenever max_len < 16).
        seqlen.clamp_(max=max_len)
    seqlen = seqlen.int()
    ntokens = seqlen.sum().item()

    # Mark the first seqlen[i] positions of row i as kept, in one vectorised
    # comparison rather than a per-row slice assignment.
    positions = torch.arange(max_len, device=attention_mask.device)
    keep = positions.unsqueeze(0) < seqlen.to(attention_mask.device).unsqueeze(1)
    padded_mask = attention_mask.clone()
    padded_mask[keep] = 1
    indices = torch.nonzero(padded_mask.flatten(), as_tuple=False).flatten()

    if not pad:
        if fuse_mask:
            # Read the real flag at every kept position, so tokens added by the
            # round-up get -10000 while real tokens get 0.
            kept_flags = attention_mask.flatten()[indices].to(torch.float16)
            mask = (1 - kept_flags) * -10000.0
        else:
            # Select the (query, key) pairs where both tokens are kept, then read
            # the *real* pair validity there. Reading the kept-pair tensor itself
            # would always yield 1 and so never mask the round-up padding.
            kept_pairs = (
                (padded_mask.unsqueeze(1) * padded_mask.unsqueeze(2))
                .unsqueeze(1).half().repeat(1, heads, 1, 1)
            )
            valid_pairs = (
                (attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2))
                .unsqueeze(1).half().repeat(1, heads, 1, 1)
            )
            pair_indices = torch.nonzero(kept_pairs.flatten(), as_tuple=False).flatten()
            mask = (1 - valid_pairs.flatten()[pair_indices]) * -10000.0
    elif fuse_mask:
        mask = -10000.0 * (1 - attention_mask).half().view(-1)
    else:
        mask = -10000.0 * (
            1 - (attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2))
        ).unsqueeze(1).half().repeat(1, heads, 1, 1).view(-1)
    return indices, mask, seqlen, ntokens, None, None, None
class PadInput(torch.autograd.Function):
    """Scatter packed rows into a zero-filled padded buffer, differentiably.

    The forward pass places ``input``'s rows at ``indices`` of a
    ``[batch * maxseqlen, hidden]`` float16 buffer. The backward pass gathers
    the gradient rows at the same ``indices``.
    """

    @staticmethod
    def forward(ctx, input, indices, batch, maxseqlen, hidden, ntokens):
        """Return the padded ``[batch * maxseqlen, hidden]`` float16 tensor.

        Args:
            input: packed rows, presumably ``[ntokens, hidden]`` and float16.
                The buffer is float16 and is filled by indexed assignment;
                NOTE(review): confirm callers never pass another dtype.
            indices: row positions in the padded buffer, one per row of ``input``.
            batch: number of sequences in the padded layout.
            maxseqlen: padded length of each sequence.
            hidden: feature width of every row.
            ntokens: number of packed rows; kept to size the input gradient.
        """
        ctx.save_for_backward(indices)
        ctx.hidden = hidden
        ctx.input_rows = ntokens
        # Allocate on the input's device, not the process-wide current CUDA device.
        output = torch.zeros([batch * maxseqlen, hidden], device=input.device, dtype=torch.float16)
        pad_input(output, input, indices)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Gather the gradient rows at the saved ``indices``.

        Only ``input`` receives a gradient; the other five arguments get
        ``None``.
        """
        (indices,) = ctx.saved_tensors
        grad_input = torch.zeros(
            [ctx.input_rows, ctx.hidden], device=grad_output.device, dtype=torch.float16
        )
        unpad_input(grad_input, grad_output, indices)
        return grad_input, None, None, None, None, None
class UnpadInput(torch.autograd.Function):
    """Gather selected rows of a padded tensor into a packed tensor, differentiably.

    The forward pass copies ``input``'s rows at ``indices`` into a
    ``[ntokens, hidden]`` float16 tensor. The backward pass scatters the
    gradient rows back into a zero-filled ``[batch * maxseqlen, hidden]``
    buffer.
    """

    @staticmethod
    def forward(ctx, input, indices, batch, maxseqlen, hidden, ntokens):
        """Return the packed ``[ntokens, hidden]`` float16 tensor.

        Args:
            input: padded rows, presumably ``[batch * maxseqlen, hidden]``.
                The gathered rows are written into a float16 buffer.
            indices: rows of ``input`` to keep, ``ntokens`` of them.
            batch: number of sequences in the padded layout.
            maxseqlen: padded length of each sequence.
            hidden: feature width of every row.
            ntokens: number of rows in the packed output.
        """
        ctx.save_for_backward(indices)
        ctx.hidden = hidden
        # The input gradient has one row per padded position.
        ctx.input_rows = batch * maxseqlen
        # Allocate on the input's device, not the process-wide current CUDA device.
        output = torch.zeros([ntokens, hidden], device=input.device, dtype=torch.float16)
        unpad_input(output, input, indices)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Scatter the gradient rows back to the saved ``indices``.

        Rows not listed in ``indices`` stay zero. Only ``input`` receives a
        gradient; the other five arguments get ``None``.
        """
        (indices,) = ctx.saved_tensors
        grad_input = torch.zeros(
            [ctx.input_rows, ctx.hidden], device=grad_output.device, dtype=torch.float16
        )
        pad_input(grad_input, grad_output, indices)
        return grad_input, None, None, None, None, None