trafo: add FlashAttention-2 for fused SoftMax
Viviane Potocnik committed Sep 15, 2023
1 parent ba247a5 commit 6a5feb7
Showing 1 changed file with 178 additions and 27 deletions.
205 changes: 178 additions & 27 deletions sw/apps/transformer/src/transformer.h
@@ -66,6 +66,7 @@ typedef struct transformer_layer_struct {
precision_t dtype;
} transformer_layer_t;

/* Debugging variables */
dump_float(debug, 8);
dump_uint(idx, 7);
dump_float(ifmap, 6);
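// (trace-dump helpers: each dump_<name>(value) call logs `value`
// to the simulation trace under the id given above)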
@@ -451,15 +452,17 @@ static inline void transformer_layer(transformer_layer_t *const l) {
snrt_cluster_hw_barrier();

/////////////////////////////////////////////////
//// FlashAttention-2 ////
/////////////////////////////////////////////////
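// FlashAttention-2 tiling: Q is split into T_r row blocks of size B_r,
// K and V into T_c column blocks of size B_c. For each row block we
// stream all column blocks through L1 and maintain an online softmax:
// S = Q_i * K_j^T
// m_i = max(m_i_prev, rowmax(S))
// P_tilde = exp(S - m_i)
// l_j = exp(m_i_prev - m_i) * l_j + rowsum(P_tilde)
// O_i = diag(exp(m_i_prev - m_i)) * O_i + P_tilde * V_j
// and normalize O_i with diag(l_j)^-1 after the last column block.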

int32_t B_r = 64; // Row block size
int32_t B_c = 64; // Column block size

// Compute the storage size of the matrices
uint32_t RowBlock = B_r * l->positional_embeddings * sizeof(float);
uint32_t ColBlock = B_c * l->positional_embeddings * sizeof(float);
uint32_t out_size = l->seq_len * l->positional_embeddings; // elements, not bytes
float output_matrix[out_size];

// Reset the TCDM pointer
tcdm_ptr = (float *)snrt_l1_next();
@@ -470,20 +473,26 @@ static inline void transformer_layer(transformer_layer_t *const l) {
float *V_tile = tcdm_ptr;
// tcdm_ptr is a float *, so all increments are in elements, not bytes
tcdm_ptr += B_c * l->positional_embeddings;
float *A_tile = tcdm_ptr;
tcdm_ptr += B_r * B_c;
float *P_tilde = tcdm_ptr;
tcdm_ptr += B_r * B_c;
float *O_tile = tcdm_ptr;
tcdm_ptr += B_r * l->positional_embeddings;
float *m_i = tcdm_ptr;
tcdm_ptr += B_r;
float *m_i_prev = tcdm_ptr;
tcdm_ptr += B_r;
float *l_j = tcdm_ptr;
tcdm_ptr += B_r;
float *shifted_exp = tcdm_ptr;
tcdm_ptr += B_r;

// compute memory usage
used_memory_kB = (float)((uint64_t)tcdm_ptr - (uint64_t)snrt_l1_next()) / 1024.0f;
dump_debug(used_memory_kB);

int32_t T_r = l->seq_len / B_r; // Number of row blocks
int32_t T_c = l->seq_len / B_c; // Number of column blocks
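// e.g. seq_len = 256 with B_r = B_c = 64 yields T_r = T_c = 4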

for (int i = 0; i < T_r; i++) {
dump_idx(i);
@@ -493,18 +502,21 @@ static inline void transformer_layer(transformer_layer_t *const l) {
// snrt_dma_start_2d copies `size` bytes `repetitions` times, advancing
// dst and src by their strides after each copy; here it moves the B_r
// rows of block row i of Q row by row into Q_tile
snrt_dma_txid_t txid_Q_tile =
snrt_dma_start_2d(
Q_tile, /* dst */
q_lin + i * B_r * l->positional_embeddings, /* src */
l->positional_embeddings * sizeof(float), /* size */
l->positional_embeddings * sizeof(float), /* dst_stride */
l->positional_embeddings * sizeof(float), /* src_stride */
B_r); /* repetitions */

snrt_dma_wait_all();
}

// Initialize the SoftMax statistics for this block row
for (int r = 0; r < B_r; r++) {
m_i[r] = -INFINITY; // running row max
m_i_prev[r] = -INFINITY; // row max of the previous iteration
l_j[r] = 0.0f; // running row sum
}
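// with m_i = -INFINITY, the first column block gives
// shifted_exp = expf(-INFINITY) = 0.0f, which implicitly clears the
// O_tile accumulator before its first update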

snrt_cluster_hw_barrier();

@@ -516,27 +528,166 @@ static inline void transformer_layer(transformer_layer_t *const l) {
snrt_dma_txid_t txid_K_tile =
snrt_dma_start_2d(
K_tile, /* dst */
k_lin + j * B_c * l->positional_embeddings, /* src */
l->positional_embeddings * sizeof(float), /* size */
l->positional_embeddings * sizeof(float), /* dst_stride */
l->positional_embeddings * sizeof(float), /* src_stride */
B_c); /* repetitions */

snrt_dma_txid_t txid_V_tile =
snrt_dma_start_2d(
V_tile, /* dst */
v_lin + j * B_c * l->positional_embeddings, /* src */
l->positional_embeddings * sizeof(float), /* size */
l->positional_embeddings * sizeof(float), /* dst_stride */
l->positional_embeddings * sizeof(float), /* src_stride */
B_c); /* repetitions */

snrt_dma_wait_all();
}

snrt_cluster_hw_barrier();

// Compute the A_tile
if (snrt_is_compute_core()) {
// printf("Computing A_tile %d/%d\n", i, T_r);
// every core computes B_r / num_cores rows of the current row block;
// Q_tile already holds block row i, so only the per-core offset is needed
uint32_t row_offset = (B_r / num_cores) * compute_id * l->positional_embeddings;
// each core writes its rows into the matching slice of A_tile
uint32_t A_offset = (B_r / num_cores) * compute_id * B_c;

// compute a tile of the attention matrix
// A_tile = Q_tile * K_tile^T (B_r x B_c)
// we parallelize over the rows of Q_tile
gemm_fp32_baseline(B_r / num_cores, B_c, l->positional_embeddings,
&Q_tile[row_offset], l->positional_embeddings, 0, K_tile,
l->positional_embeddings, 1, &A_tile[A_offset], B_c, 0.0f);
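// note: the 1/sqrt(E) attention scaling is assumed to be folded into
// q_lin upstream; otherwise A_tile would need to be scaled here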

// print the current attention tile for debugging
if (DEBUG_VERBOSE == 1) {
if (BANSHEE == 0) {
for (int idx = 0; idx < B_r * B_c; idx++) {
dump_value(A_tile[idx]);
}
} else {
for (int idx = 0; idx < B_r * B_c; idx++) {
printf("A_tile[%d] = %f\n", idx, A_tile[idx]);
}
}
} else if (DEBUG == 1) {
if (BANSHEE == 0) {
float attention_start = A_tile[0];
float attention_end = A_tile[B_r * B_c - 1];
dump_value(attention_start);
dump_value(attention_end);
} else {
printf("A_tile[0] = %f\n", A_tile[0]);
printf("A_tile[%d] = %f\n", B_r * B_c - 1, A_tile[B_r * B_c - 1]);
}
}

// determine the row max of the rows of A_tile (B_r x B_c)
// we parallelize over the rows of A_tile, reusing A_offset
// to address this core's slice
for (int k = 0; k < (B_r / num_cores); k++) {
// store the previous row max
m_i_prev[k + compute_id * (B_r / num_cores)] = m_i[k + compute_id * (B_r / num_cores)];
for (int c = 0; c < B_c; c++) {
if (A_tile[A_offset + k * B_c + c] > m_i[k + compute_id * (B_r / num_cores)]) {
m_i[k + compute_id * (B_r / num_cores)] = A_tile[A_offset + k * B_c + c];
}
}
}

// next we determine the shifted exponential of the row max
// shifted_exp = exp(m_i_prev - m_i)
for (int k = 0; k < (B_r / num_cores); k++) {
shifted_exp[k + compute_id * (B_r / num_cores)] =
expf(m_i_prev[k + compute_id * (B_r / num_cores)] - m_i[k + compute_id * (B_r / num_cores)]);
}

// we compute the P_tilde matrix (B_r x B_c)
// P_tilde = exp(A_tile - m_i)
uint32_t P_tilde_offset = A_offset;
for (int k = 0; k < (B_r / num_cores); k++) {
for (int c = 0; c < B_c; c++) {
P_tilde[P_tilde_offset + k * B_c + c] =
expf(A_tile[A_offset + k * B_c + c] - m_i[k + compute_id * (B_r / num_cores)]);
}
}

// we rescale the running row sum by the shifted exponential
// and accumulate the new row sum of P_tilde:
// l_j = shifted_exp * l_j + rowsum(P_tilde)
for (int k = 0; k < (B_r / num_cores); k++) {
l_j[k + compute_id * (B_r / num_cores)] =
shifted_exp[k + compute_id * (B_r / num_cores)] * l_j[k + compute_id * (B_r / num_cores)];
for (int c = 0; c < B_c; c++) {
l_j[k + compute_id * (B_r / num_cores)] += P_tilde[P_tilde_offset + k * B_c + c];
}
}

// finally we update the output accumulator
// O_i = diag(shifted_exp) * O_i_prev + P_tilde * V_tile (B_r x E)
// we parallelize over the rows of O_tile
uint32_t O_offset = compute_id * (B_r / num_cores) * l->positional_embeddings;
for (int k = 0; k < B_r / num_cores; k++) {
for (int p = 0; p < l->positional_embeddings; p++) {
// rescale the current row by shifted_exp = exp(m_i_prev - m_i),
// the same factor that rescales l_j above
O_tile[O_offset + k * l->positional_embeddings + p] =
shifted_exp[k + compute_id * (B_r / num_cores)] * O_tile[O_offset + k * l->positional_embeddings + p];
// we accumulate the matrix multiplication of P_tilde and V_tile
for (int o = 0; o < B_c; o++) {
O_tile[O_offset + k * (l->positional_embeddings) + p] +=
P_tilde[P_tilde_offset + k * B_c + o] * V_tile[o * (l->positional_embeddings) + p];
}
}
}
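// note: FlashAttention-2 keeps O_i unnormalized across the column
// blocks and applies the diag(l_j)^-1 scaling only once, after the
// T_c loop below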

}
} // end of T_c loop

snrt_cluster_hw_barrier();

// we scale the O_tile by the final row sum
// O_i = diag(l_j)^-1 * O_i
if (snrt_is_compute_core()) {
// printf("Scaling O_tile %d/%d\n", i, T_r);
// we parallelize over the rows of O_tile
uint32_t O_offset = compute_id * (B_r / num_cores) * l->positional_embeddings;
for (int k = 0; k < B_r / num_cores; k++) {
for (int p = 0; p < l->positional_embeddings; p++) {
O_tile[O_offset + k * l->positional_embeddings + p] =
(1 / l_j[k + compute_id * (B_r / num_cores)]) * O_tile[O_offset + k * l->positional_embeddings + p];
}
}
}
}

snrt_cluster_hw_barrier();

// we write back O_i as the i-th block row of the output matrix O to DRAM
if (!snrt_is_compute_core()) {
// printf("Writing back O_tile %d/%d\n", i, T_r);
// determine the current offset in the output matrix
uint32_t output_offset = i * B_r * l->positional_embeddings;

// write back the output matrix
snrt_dma_txid_t txid_output =
snrt_dma_start_2d(
output_matrix + output_offset, /* dst */
O_tile, /* src */
B_r * l->positional_embeddings * sizeof(float), /* size */
l->positional_embeddings * sizeof(float), /* dst_stride */
B_r * l->positional_embeddings * sizeof(float), /* src_stride */
1); /* repetitions */

snrt_dma_wait_all();
}
} // end of T_r loop

snrt_global_barrier();
}
