// flash_attention_2.cu (forked from tspeterkim/flash-attention-minimal)
// Minimal FlashAttention-2 forward and backward kernels with PyTorch (libtorch) launchers.
#include <torch/types.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cmath>   // INFINITY, ceil, sqrt
#include <vector>  // std::vector return type of the host launchers

__global__
void flash_attention_2_forward_kernel(
    const float* Q,
    const float* K,
    const float* V,
    const int N,
    const int d,
    const int Tc,
    const int Tr,
    const int Bc,
    const int Br,
    const float softmax_scale,
    float* L,
    float* O
) {
    int tx = threadIdx.x;
    int bx = blockIdx.x; int by = blockIdx.y;  // batch and head index

    // Offset into Q,K,V,O - different for each batch and head
    int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d);  // gridDim.y = nh
    int lm_offset = (bx * gridDim.y * N) + (by * N);  // offset for L

    // Define SRAM for Q,K,V,S
    extern __shared__ float sram[];
    int tile_size = Bc * d;  // size of Qi, Kj, Vj (the launcher sets Br == Bc)
    float* Qi = sram;
    float* Kj = &sram[tile_size];
    float* Vj = &sram[tile_size * 2];
    float* S = &sram[tile_size * 3];

    for (int i = 0; i < Tr; ++i) {
        if (i * Br + tx >= N)
            break;  // break if we are done with the sequence

        // Load Qi from HBM to SRAM, l and m to registers
        for (int x = 0; x < d; x++) {
            Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
        }
        float row_m_prev = -INFINITY;
        float row_l_prev = 0;

        // Causal mask: j <= i
        for (int j = 0; j <= i; ++j) {
            __syncthreads();
            // Load Kj, Vj from HBM to SRAM
            for (int x = 0; x < d; x++) {
                Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
                Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
            }
            __syncthreads();

            // S_i^j = softmax_scale * QiKj^T
            // S_i^j[tx][y] = softmax_scale * Sum_{x = 0}^{d-1} Qi[tx][x] * Kj[y][x]
            float row_m = -INFINITY;
            for (int y = 0; y < Bc; y++) {
                if (j * Bc + y >= N)
                    break;  // break if we are done with the sequence
                if (i * Br + tx < j * Bc + y)
                    break;  // causal mask: keys to the right of this query row are skipped
                float sum = 0;
                for (int x = 0; x < d; x++)
                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
                sum *= softmax_scale;
                S[(Bc * tx) + y] = sum;

                if (sum > row_m)
                    row_m = sum;
            }

            // m_i^j = max(m_i^{j-1}, row_max(S_i^j))
            float new_row_m = max(row_m_prev, row_m);

            // P_i^j = exp(S_i^j - m_i^j)
            // P_i^j[tx][y] = exp(S_i^j[tx][y] - m_i^j)
            float row_l = 0;
            for (int y = 0; y < Bc; y++) {
                if (j * Bc + y >= N)
                    break;  // break if we are done with the sequence
                if (i * Br + tx < j * Bc + y)
                    break;  // causal mask
                S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - new_row_m);
                row_l += S[(Bc * tx) + y];
            }

            // l_i^j = (exp(m_i^{j-1} - m_i^j) * l_i^{j-1}) + row_sum(P_i^j)
            float row_m_exp = __expf(row_m_prev - new_row_m);
            float new_row_l = (row_m_exp * row_l_prev) + row_l;

            // O_i^j = diag(exp(m_i^{j-1} - m_i^j)) * O_i^{j-1} + P_i^j * Vj
            for (int x = 0; x < d; x++) {
                float pv = 0;  // Pij * Vj
                for (int y = 0; y < Bc; y++) {
                    if (j * Bc + y >= N)
                        break;  // break if we are done with the sequence
                    if (i * Br + tx < j * Bc + y)
                        break;  // causal mask
                    pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
                }
                O[qkv_offset + (tile_size * i) + (tx * d) + x] =
                    row_m_exp * O[qkv_offset + (tile_size * i) + (tx * d) + x] + pv;
            }

            // Update m and l
            row_m_prev = new_row_m;
            row_l_prev = new_row_l;
        }

        // O_i = diag(l_i^{Tc})^-1 * O_i^{Tc}
        for (int x = 0; x < d; x++)
            O[qkv_offset + (tile_size * i) + (tx * d) + x] /= row_l_prev;
        // L_i = m_i^{Tc} + log(l_i^{Tc})
        L[lm_offset + (Br * i) + tx] = row_m_prev + __logf(row_l_prev);
    }
}
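
// The kernel above implements the "online softmax" recurrence per query row: a
// running row max m, a running denominator l, and an unnormalized output that is
// rescaled by exp(m_prev - m_new) whenever the max grows, followed by a final
// division by l. The host-side sketch below is NOT part of the original file; it
// is a hypothetical single-row reference that replays the same recurrence on
// plain arrays (scores are assumed to be pre-scaled by softmax_scale), which can
// be handy when sanity-checking kernel output.
inline std::vector<float> online_softmax_row_reference(
    const std::vector<float>& scores,          // one query row of scaled QK^T, length = #keys
    const std::vector<std::vector<float>>& V   // V rows, one per key, each of length d
) {
    const size_t d = V.empty() ? 0 : V[0].size();
    float m = -INFINITY;              // running row max (m_i^j in the kernel)
    float l = 0.0f;                   // running denominator (l_i^j in the kernel)
    std::vector<float> o(d, 0.0f);    // unnormalized output row (O_i^j in the kernel)
    for (size_t j = 0; j < scores.size(); ++j) {
        float m_new = std::fmax(m, scores[j]);
        float scale = std::exp(m - m_new);      // exp(m_prev - m_new), the kernel's row_m_exp
        float p = std::exp(scores[j] - m_new);  // P entry for this key
        l = scale * l + p;
        for (size_t x = 0; x < d; ++x)
            o[x] = scale * o[x] + p * V[j][x];
        m = m_new;
    }
    for (size_t x = 0; x < d; ++x)
        o[x] /= l;                    // final diag(l)^-1 normalization
    return o;
}
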
__global__
void flash_attention_2_backward_kernel(
    const float* Q,
    const float* K,
    const float* V,
    const float* O,
    const float* dO,
    const float* L,
    const int N,
    const int d,
    const int Tc,
    const int Tr,
    const int Bc,
    const int Br,
    const float softmax_scale,
    float* dQ,
    float* dK,
    float* dV
) {
    int tx = threadIdx.x;
    int bx = blockIdx.x; int by = blockIdx.y;  // batch and head index

    // Offset into Q,K,V,O - different for each batch and head
    int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d);  // gridDim.y = nh
    int lm_offset = (bx * gridDim.y * N) + (by * N);  // offset for L

    // Define SRAM for Q,K,V,S
    extern __shared__ float sram[];
    int col_tile_size = Bc * d;  // size of Kj, Vj
    int row_tile_size = Br * d;  // size of Qi, Oi, dOi
    float* Kj = sram;
    float* Vj = &sram[col_tile_size];
    float* dKj = &sram[col_tile_size * 2];
    float* dVj = &sram[col_tile_size * 3];
    float* Qi = &sram[col_tile_size * 4];
    float* Oi = &sram[col_tile_size * 4 + row_tile_size];
    float* dOi = &sram[col_tile_size * 4 + row_tile_size * 2];
    // We also use S for P. Likewise, we use dS for dP.
    // We can reuse the same memory because we don't need S and P at the same time.
    // We also don't need dS and dP at the same time.
    float* S = &sram[col_tile_size * 4 + row_tile_size * 3];
    float* dS = &sram[col_tile_size * 4 + row_tile_size * 3 + Bc * Br];

    for (int j = 0; j < Tc; j++) {
        // Load Kj, Vj to SRAM
        for (int x = 0; x < d; x++) {
            Kj[(tx * d) + x] = K[qkv_offset + (col_tile_size * j) + (tx * d) + x];
            Vj[(tx * d) + x] = V[qkv_offset + (col_tile_size * j) + (tx * d) + x];
        }
        // Initialize dKj, dVj to 0
        for (int x = 0; x < d; x++) {
            dKj[(tx * d) + x] = 0;
            dVj[(tx * d) + x] = 0;
        }

        for (int i = j; i < Tr; i++) {
            __syncthreads();
            // Load Qi, Oi, dOi from HBM to SRAM; accumulate Di = rowsum(dOi * Oi)
            // and load the logsumexp Li into a register
            float Di = 0;
            for (int x = 0; x < d; x++) {
                Qi[(tx * d) + x] = Q[qkv_offset + (row_tile_size * i) + (tx * d) + x];
                Oi[(tx * d) + x] = O[qkv_offset + (row_tile_size * i) + (tx * d) + x];
                dOi[(tx * d) + x] = dO[qkv_offset + (row_tile_size * i) + (tx * d) + x];
                Di += dOi[(tx * d) + x] * Oi[(tx * d) + x];
            }
            float l_curr = L[lm_offset + (Br * i) + tx];

            // Sij = softmax_scale * QiKj^T
            // Sij[tx][y] = softmax_scale * Sum_{x = 0}^{d-1} Qi[tx][x] * Kj[y][x]
            for (int y = 0; y < Bc; y++) {
                float sum = 0;
                for (int x = 0; x < d; x++) {
                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
                }
                sum *= softmax_scale;
                if (i * Br + tx < j * Bc + y)
                    sum = -INFINITY;  // causal mask
                S[(Bc * tx) + y] = sum;
            }

            // Pij = exp(Sij - Li), where Li = mi + log(li),
            // i.e. Pij = diag(li)^-1 * exp(Sij - mi)
            for (int y = 0; y < Bc; y++) {
                if (i * Br + tx < j * Bc + y)
                    S[(Bc * tx) + y] = 0;
                else
                    S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - l_curr);
            }
            __syncthreads();

            // dVj <- dVj + Pij^T * dOi
            // dVj[tx][x] = dVj[tx][x] + Sum_{y = 0}^{Br-1} Pij[y][tx] * dOi[y][x]
            for (int x = 0; x < d; x++) {
                float sum = 0;
                for (int y = 0; y < Br; y++) {
                    sum += S[(Bc * y) + tx] * dOi[(y * d) + x];
                }
                atomicAdd(&dVj[(tx * d) + x], sum);
            }

            // dPij <- dOi * Vj^T
            // dPij[tx][y] = Sum_{x = 0}^{d-1} dOi[tx][x] * Vj[y][x]
            for (int y = 0; y < Bc; y++) {
                float sum = 0;
                for (int x = 0; x < d; x++) {
                    sum += dOi[(tx * d) + x] * Vj[(y * d) + x];
                }
                dS[(Bc * tx) + y] = sum;
            }

            // dSij <- Pij * (dPij - Di)
            // dSij[tx][y] = Pij[tx][y] * (dPij[tx][y] - Di[tx])
            for (int y = 0; y < Bc; ++y) {
                dS[(Bc * tx) + y] = S[(Bc * tx) + y] * (dS[(Bc * tx) + y] - Di);
            }

            // dQi <- dQi + softmax_scale * dSij * Kj
            // dQ[tx][x] = dQ[tx][x] + softmax_scale * Sum_{y = 0}^{Bc-1} dSij[tx][y] * Kj[y][x]
            for (int x = 0; x < d; x++) {
                float sum = 0;
                for (int y = 0; y < Bc; y++) {
                    sum += dS[(Bc * tx) + y] * Kj[(y * d) + x];
                }
                sum *= softmax_scale;
                atomicAdd(&dQ[qkv_offset + (row_tile_size * i) + (tx * d) + x], sum);
            }
            __syncthreads();

            // dKj <- dKj + softmax_scale * dSij^T * Qi
            // dKj[tx][x] = dKj[tx][x] + softmax_scale * Sum_{y = 0}^{Br-1} dSij[y][tx] * Qi[y][x]
            for (int x = 0; x < d; x++) {
                float sum = 0;
                for (int y = 0; y < Br; y++) {
                    sum += dS[(Bc * y) + tx] * Qi[(y * d) + x];
                }
                sum *= softmax_scale;
                atomicAdd(&dKj[(tx * d) + x], sum);
            }
        }

        // Write dKj, dVj back to HBM
        for (int x = 0; x < d; x++) {
            dK[qkv_offset + (col_tile_size * j) + (tx * d) + x] = dKj[(tx * d) + x];
            dV[qkv_offset + (col_tile_size * j) + (tx * d) + x] = dVj[(tx * d) + x];
        }
    }
}
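
// The backward kernel relies on the row-wise softmax gradient identity
// dS = P * (dP - D) (elementwise), with D = rowsum(dO * O); this is why it only
// needs the logsumexp L from the forward pass rather than the full probability
// matrix. The sketch below is NOT part of the original file; it is a hypothetical
// single-row reference for that identity alone and does not cover the tiled
// dQ/dK/dV accumulation.
inline std::vector<float> softmax_grad_row_reference(
    const std::vector<float>& P,    // softmax probabilities for one query row
    const std::vector<float>& dP,   // upstream gradient w.r.t. P (one row of dO * V^T)
    const std::vector<float>& dO,   // gradient w.r.t. the output row
    const std::vector<float>& O     // output row
) {
    float D = 0.0f;  // D_i = sum_x dO[x] * O[x], matching Di in the kernel
    for (size_t x = 0; x < dO.size(); ++x)
        D += dO[x] * O[x];
    std::vector<float> dS(P.size());
    for (size_t y = 0; y < P.size(); ++y)
        dS[y] = P[y] * (dP[y] - D);  // dSij = Pij * (dPij - Di)
    return dS;
}
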
std::vector<torch::Tensor> flash_attention_2_forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V) {
    // TODO: determine Bc, Br dynamically
    const int Bc = 32; const int Br = 32;

    const int B = Q.size(0); const int nh = Q.size(1);
    const int N = Q.size(2); const int d = Q.size(3);

    const int Tc = ceil((float) N / Bc); const int Tr = ceil((float) N / Br);
    const float softmax_scale = 1.0 / sqrt(d);

    // Initialize O and L in HBM
    auto O = torch::zeros_like(Q);
    auto L = torch::zeros({B, nh, N});
    torch::Device device(torch::kCUDA);
    L = L.to(device);

    // Calculate SRAM size needed per block
    int col_tile_size = Bc * d;  // size of Kj, Vj
    int row_tile_size = Br * d;  // size of Qi
    const int sram_size =
        (2 * col_tile_size * sizeof(float))  // SRAM size for Kj, Vj
        + (row_tile_size * sizeof(float))    // SRAM size for Qi
        + (Bc * Br * sizeof(float));         // SRAM size for S
    int max_sram_size;
    cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
    printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, sram_size);

    dim3 grid_dim(B, nh);  // batch_size x num_heads
    dim3 block_dim(Br);    // Br threads per block

    flash_attention_2_forward_kernel<<<grid_dim, block_dim, sram_size>>>(
        Q.data_ptr<float>(), K.data_ptr<float>(), V.data_ptr<float>(),
        N, d, Tc, Tr, Bc, Br, softmax_scale,
        L.data_ptr<float>(), O.data_ptr<float>()
    );
    return {O, L};
}
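
// The launcher above hard-codes Bc = Br = 32 and leaves a TODO to pick the tile
// sizes dynamically. The helper below is NOT part of the original file; it is a
// hypothetical sketch that reuses the launcher's shared-memory accounting to check
// whether a candidate (Bc, Br, d) configuration fits within the per-block shared
// memory limit. It only mirrors the sizing math; it is not a tuned tile-size policy.
inline bool forward_tiles_fit_in_smem(int Bc, int Br, int d, int device = 0) {
    // Same layout as flash_attention_2_forward_kernel:
    //   Kj, Vj: Bc * d floats each; Qi: Br * d floats; S: Br * Bc floats.
    size_t bytes = (2 * (size_t)Bc * d + (size_t)Br * d + (size_t)Br * Bc) * sizeof(float);
    int max_sram_size = 0;
    cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, device);
    return bytes <= (size_t)max_sram_size;
}
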
std::vector<torch::Tensor> flash_attention_2_backward(
    torch::Tensor Q,
    torch::Tensor K,
    torch::Tensor V,
    torch::Tensor O,
    torch::Tensor dO,
    torch::Tensor L
) {
    // TODO: determine Bc, Br dynamically
    const int Bc = 16; const int Br = 16;

    const int B = Q.size(0); const int nh = Q.size(1);
    const int N = Q.size(2); const int d = Q.size(3);

    const int Tc = ceil((float) N / Bc); const int Tr = ceil((float) N / Br);
    const float softmax_scale = 1.0 / sqrt(d);

    auto dQ = torch::zeros_like(Q);
    auto dK = torch::zeros_like(K);
    auto dV = torch::zeros_like(V);

    // Calculate SRAM size needed per block
    int col_tile_size = Bc * d;  // size of Kj, Vj
    int row_tile_size = Br * d;  // size of Qi, Oi, dOi
    const int sram_size =
        (4 * col_tile_size * sizeof(float))    // SRAM size for Kj, Vj, dKj, dVj
        + (3 * row_tile_size * sizeof(float))  // SRAM size for Qi, Oi, dOi
        + (2 * Br * Bc * sizeof(float));       // SRAM size for S, dS
    int max_sram_size;
    cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
    printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, sram_size);

    dim3 grid_dim(B, nh);  // batch_size x num_heads
    dim3 block_dim(Br);    // Br threads per block (Br == Bc here)

    flash_attention_2_backward_kernel<<<grid_dim, block_dim, sram_size>>>(
        Q.data_ptr<float>(), K.data_ptr<float>(), V.data_ptr<float>(),
        O.data_ptr<float>(), dO.data_ptr<float>(),
        L.data_ptr<float>(),
        N, d, Tc, Tr, Bc, Br, softmax_scale,
        dQ.data_ptr<float>(), dK.data_ptr<float>(), dV.data_ptr<float>()
    );
    return {dQ, dK, dV};
}
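
// One common way to call the launchers above from Python is to build this file as
// a PyTorch CUDA extension (e.g. with torch.utils.cpp_extension.load). The binding
// below is NOT part of the original file; it is a hypothetical sketch, and the
// exported names "forward" / "backward" are assumptions. When built through the
// torch extension machinery, TORCH_EXTENSION_NAME is defined for you, so the guard
// keeps the file compilable on its own.
#ifdef TORCH_EXTENSION_NAME
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &flash_attention_2_forward,
          "FlashAttention-2 forward (returns O and the row-wise logsumexp L)");
    m.def("backward", &flash_attention_2_backward,
          "FlashAttention-2 backward (returns dQ, dK, dV)");
}
#endif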