From 6a5feb72af6eab991314bfa03e5bb2294b0731ac Mon Sep 17 00:00:00 2001
From: Viviane Potocnik
Date: Fri, 15 Sep 2023 15:48:35 +0200
Subject: [PATCH] trafo: add FlashAttention-2 for fused SoftMax

---
 sw/apps/transformer/src/transformer.h | 205 ++++++++++++++++++++++----
 1 file changed, 178 insertions(+), 27 deletions(-)

diff --git a/sw/apps/transformer/src/transformer.h b/sw/apps/transformer/src/transformer.h
index 824bb3948c..d92b8d26fa 100644
--- a/sw/apps/transformer/src/transformer.h
+++ b/sw/apps/transformer/src/transformer.h
@@ -66,6 +66,7 @@ typedef struct transformer_layer_struct {
     precision_t dtype;
 } transformer_layer_t;
 
+/* Debugging Variables */
 dump_float(debug, 8);
 dump_uint(idx, 7);
 dump_float(ifmap, 6);
@@ -451,15 +452,17 @@ static inline void transformer_layer(transformer_layer_t *const l) {
     snrt_cluster_hw_barrier();
 
     /////////////////////////////////////////////////
-    ////            Fused Attention               ////
+    ////            FlashAttention-2              ////
     /////////////////////////////////////////////////
 
-    int32_t RowBlock_Size = 64;
-    int32_t ColBlock_Size = 64;
+    int32_t B_r = 64;  // row block size
+    int32_t B_c = 64;  // column block size
 
     // Compute the storage size of the matrices
-    uint32_t RowBlock = RowBlock_Size * l->positional_embeddings * sizeof(float);
-    uint32_t ColBlock = ColBlock_Size * l->positional_embeddings * sizeof(float);
+    uint32_t RowBlock = B_r * l->positional_embeddings * sizeof(float);
+    uint32_t ColBlock = B_c * l->positional_embeddings * sizeof(float);
+    // output buffer is sized in elements, not bytes
+    uint32_t out_size = l->seq_len * l->positional_embeddings;
+    float output_matrix[out_size];
 
     // Reset the TCDM pointer
     tcdm_ptr = (float *)snrt_l1_next();
@@ -470,20 +473,26 @@ static inline void transformer_layer(transformer_layer_t *const l) {
+    // tcdm_ptr is a float *, so all cursor bumps are in elements, not bytes
     float *Q_tile = tcdm_ptr;
-    tcdm_ptr += RowBlock;
+    tcdm_ptr += B_r * l->positional_embeddings;
     float *K_tile = tcdm_ptr;
-    tcdm_ptr += ColBlock;
+    tcdm_ptr += B_c * l->positional_embeddings;
     float *V_tile = tcdm_ptr;
-    tcdm_ptr += ColBlock;
+    tcdm_ptr += B_c * l->positional_embeddings;
     float *A_tile = tcdm_ptr;
-    tcdm_ptr += RowBlock_Size * ColBlock_Size * sizeof(float);
-    float *P_tile = tcdm_ptr;
-    tcdm_ptr += RowBlock_Size * ColBlock_Size * sizeof(float);
+    tcdm_ptr += B_r * B_c;
+    float *P_tilde = tcdm_ptr;
+    tcdm_ptr += B_r * B_c;
+    float *O_tile = tcdm_ptr;
+    tcdm_ptr += B_r * l->positional_embeddings;
     float *m_i = tcdm_ptr;
-    tcdm_ptr += RowBlock_Size;
+    tcdm_ptr += B_r;
     float *m_i_prev = tcdm_ptr;
-    tcdm_ptr += RowBlock_Size;
+    tcdm_ptr += B_r;
     float *l_j = tcdm_ptr;
-    tcdm_ptr += RowBlock_Size;
+    tcdm_ptr += B_r;
     float *shifted_exp = tcdm_ptr;
-    tcdm_ptr += RowBlock_Size;
+    tcdm_ptr += B_r;
 
-    int32_t T_r = l->seq_len / RowBlock_Size;
-    int32_t T_c = l->seq_len / ColBlock_Size;
+    // compute the TCDM memory usage
+    used_memory_kB =
+        (float)((uintptr_t)tcdm_ptr - (uintptr_t)snrt_l1_next()) / 1024.0f;
+    dump_debug(used_memory_kB);
+
+    int32_t T_r = l->seq_len / B_r;  // number of row blocks
+    int32_t T_c = l->seq_len / B_c;  // number of column blocks
 
     for (int i = 0; i < T_r; i++) {
         dump_idx(i);
@@ -493,18 +502,21 @@ static inline void transformer_layer(transformer_layer_t *const l) {
 
         // load the Q_tile
         if (!snrt_is_compute_core()) {
-            snrt_dma_txid_t txid_Q_tile = snrt_dma_start_2d(
-                Q_tile,                                     /* dst */
-                q_lin + i * RowBlock_Size * l->positional_embeddings,/* src */
-                RowBlock,                                   /* size */
-                l->positional_embeddings * sizeof(float),   /* dst_stride */
-                RowBlock_Size * l->positional_embeddings * sizeof(float),/* src_stride */
-                RowBlock_Size);                             /* repetitions */
+            // the Q tile is contiguous in DRAM and TCDM, so a single 1D
+            // transfer of RowBlock bytes suffices
+            snrt_dma_txid_t txid_Q_tile = snrt_dma_start_1d(
+                Q_tile,                                     /* dst */
+                q_lin + i * B_r * l->positional_embeddings, /* src */
+                RowBlock);                                  /* size */
 
             snrt_dma_wait_all();
         }
 
         // Initialize the SoftMax statistics
-        l_j[i] = 0.0f;       // row sum
-        m_i[i] = -INFINITY;  // row max
+        for (int r = 0; r < B_r; r++) {
+            m_i[r] = -INFINITY;  // running row max
+            m_i_prev[r] = -INFINITY;
+            l_j[r] = 0.0f;       // running row sum
+        }
+        // zero the output accumulator before the first column block
+        for (int r = 0; r < B_r * l->positional_embeddings; r++) {
+            O_tile[r] = 0.0f;
+        }
 
         snrt_cluster_hw_barrier();
 
@@ -516,27 +528,166 @@ static inline void transformer_layer(transformer_layer_t *const l) {
 
         for (int j = 0; j < T_c; j++) {
 
            // load the K_tile and V_tile
            if (!snrt_is_compute_core()) {
-                snrt_dma_txid_t txid_K_tile = snrt_dma_start_2d(
-                    K_tile,                                    /* dst */
-                    k_lin + j * ColBlock_Size * l->positional_embeddings,/* src */
-                    ColBlock,                                  /* size */
-                    l->positional_embeddings * sizeof(float),  /* dst_stride */
-                    ColBlock_Size * l->positional_embeddings * sizeof(float),/* src_stride */
-                    ColBlock_Size);                            /* repetitions */
-
-                snrt_dma_txid_t txid_V_tile = snrt_dma_start_2d(
-                    V_tile,                                    /* dst */
-                    v_lin + j * ColBlock_Size * l->positional_embeddings,/* src */
-                    ColBlock,                                  /* size */
-                    l->positional_embeddings * sizeof(float),  /* dst_stride */
-                    ColBlock_Size * l->positional_embeddings * sizeof(float),/* src_stride */
-                    ColBlock_Size);                            /* repetitions */
+                // K and V tiles are contiguous in DRAM and TCDM, so one
+                // 1D transfer of ColBlock bytes per tile suffices
+                snrt_dma_txid_t txid_K_tile = snrt_dma_start_1d(
+                    K_tile,                                     /* dst */
+                    k_lin + j * B_c * l->positional_embeddings, /* src */
+                    ColBlock);                                  /* size */
+
+                snrt_dma_txid_t txid_V_tile = snrt_dma_start_1d(
+                    V_tile,                                     /* dst */
+                    v_lin + j * B_c * l->positional_embeddings, /* src */
+                    ColBlock);                                  /* size */
 
                 snrt_dma_wait_all();
             }
 
             snrt_cluster_hw_barrier();
+
+            // Compute the A_tile
+            if (snrt_is_compute_core()) {
+                // every core computes B_r / num_cores rows of the tile;
+                // row_base is this core's first row within the row block
+                uint32_t row_base = (B_r / num_cores) * compute_id;
+                uint32_t row_offset = row_base * l->positional_embeddings;
+                // this core's slice of the attention tile (row-major, ldc = B_c)
+                int32_t A_offset = row_base * B_c;
+
+                // compute a tile of the attention matrix
+                // A_tile = Q_tile * K_tile^T (B_r x B_c)
+                // we parallelize over the rows of Q_tile
+                gemm_fp32_baseline(B_r / num_cores, B_c, l->positional_embeddings,
+                                   &Q_tile[row_offset], l->positional_embeddings, 0,
+                                   K_tile, l->positional_embeddings, 0,
+                                   &A_tile[A_offset], B_c, 0.0f);
+
+                // print the current attention tile for debugging
+                if (DEBUG_VERBOSE == 1) {
+                    if (BANSHEE == 0) {
+                        for (int a = 0; a < B_r * B_c; a++) {
+                            dump_value(A_tile[a]);
+                        }
+                    } else {
+                        for (int a = 0; a < B_r * B_c; a++) {
+                            printf("A_tile[%d] = %f\n", a, A_tile[a]);
+                        }
+                    }
+                } else if (DEBUG == 1) {
+                    if (BANSHEE == 0) {
+                        float attention_start = A_tile[0];
+                        float attention_end = A_tile[B_r * B_c - 1];
+                        dump_value(attention_start);
+                        dump_value(attention_end);
+                    } else {
+                        printf("A_tile[0] = %f\n", A_tile[0]);
+                        printf("A_tile[%d] = %f\n", B_r * B_c - 1,
+                               A_tile[B_r * B_c - 1]);
+                    }
+                }
+
+                // determine the row max of the rows of A_tile (B_r x B_c)
+                // we parallelize over the rows of A_tile
+                for (int k = 0; k < (B_r / num_cores); k++) {
+                    // store the previous row max
+                    m_i_prev[row_base + k] = m_i[row_base + k];
+                    for (int c = 0; c < B_c; c++) {
+                        if (A_tile[A_offset + k * B_c + c] > m_i[row_base + k]) {
+                            m_i[row_base + k] = A_tile[A_offset + k * B_c + c];
+                        }
+                    }
+                }
+
+                // next we determine the shifted exponential of the row max
+                // shifted_exp = exp(m_i_prev - m_i)
+                for (int k = 0; k < (B_r / num_cores); k++) {
+                    shifted_exp[row_base + k] =
+                        expf(m_i_prev[row_base + k] - m_i[row_base + k]);
+                }
+
+                // we compute the P_tilde matrix (B_r x B_c)
+                // P_tilde = exp(A_tile - m_i)
+                int32_t P_tilde_offset = A_offset;
+                for (int k = 0; k < (B_r / num_cores); k++) {
+                    for (int c = 0; c < B_c; c++) {
+                        P_tilde[P_tilde_offset + k * B_c + c] =
+                            expf(A_tile[A_offset + k * B_c + c] - m_i[row_base + k]);
+                    }
+                }
+
+                // we rescale the running row sum and add the new row sums
+                // l_i = shifted_exp * l_i + rowsum(P_tilde)
+                for (int k = 0; k < (B_r / num_cores); k++) {
+                    l_j[row_base + k] =
+                        shifted_exp[row_base + k] * l_j[row_base + k];
+                    for (int c = 0; c < B_c; c++) {
+                        l_j[row_base + k] += P_tilde[P_tilde_offset + k * B_c + c];
+                    }
+                }
+
+                // finally we update the output accumulator
+                // O_i = diag(shifted_exp) * O_i_prev + P_tilde * V_tile (B_r x E)
+                // we parallelize over the rows of O_tile
+                uint32_t O_offset = row_base * l->positional_embeddings;
+                for (int k = 0; k < B_r / num_cores; k++) {
+                    for (int p = 0; p < l->positional_embeddings; p++) {
+                        // rescale the previous accumulator to the new row max;
+                        // on the first column block shifted_exp is zero
+                        O_tile[O_offset + k * l->positional_embeddings + p] =
+                            shifted_exp[row_base + k] *
+                            O_tile[O_offset + k * l->positional_embeddings + p];
+                        // accumulate P_tilde * V_tile
+                        for (int o = 0; o < B_c; o++) {
+                            O_tile[O_offset + k * l->positional_embeddings + p] +=
+                                P_tilde[P_tilde_offset + k * B_c + o] *
+                                V_tile[o * l->positional_embeddings + p];
+                        }
+                    }
+                }
+            }
+        }  // end of T_c loop
+
+        snrt_cluster_hw_barrier();
+
+        // we scale the O_tile by the final row sum
+        // O_i = diag(l_i)^-1 * O_i
+        if (snrt_is_compute_core()) {
+            // we parallelize over the rows of O_tile
+            uint32_t row_base = (B_r / num_cores) * compute_id;
+            uint32_t O_offset = row_base * l->positional_embeddings;
+            for (int k = 0; k < B_r / num_cores; k++) {
+                for (int p = 0; p < l->positional_embeddings; p++) {
+                    O_tile[O_offset + k * l->positional_embeddings + p] =
+                        (1.0f / l_j[row_base + k]) *
+                        O_tile[O_offset + k * l->positional_embeddings + p];
+                }
+            }
         }
-    }
+
+        snrt_cluster_hw_barrier();
+
+        // we write back O_i as the i-th block row of the output matrix
+        if (!snrt_is_compute_core()) {
+            // determine the current offset in the output matrix
+            uint32_t output_offset = i * B_r * l->positional_embeddings;
+
+            // the tile is contiguous, so a single 1D transfer suffices
+            snrt_dma_txid_t txid_output = snrt_dma_start_1d(
+                output_matrix + output_offset,                   /* dst */
+                O_tile,                                          /* src */
+                B_r * l->positional_embeddings * sizeof(float)); /* size */
+
+            snrt_dma_wait_all();
+        }
+    }  // end of T_r loop
 
     snrt_global_barrier();
 }
\ No newline at end of file
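
--
Note on the online-softmax recurrence in the T_c loop: the accumulator is
multiplied by shifted_exp = exp(m_i_prev - m_i) (a factor <= 1, zero on the
first column block), and the final pass divides by the row sum l_j. The
scalar sketch below checks that recurrence against a direct softmax(s) * v
computation for a single row. It is standalone illustration code, not part
of the snRuntime sources; all names are local to the sketch. It compiles
with `cc -O2 check.c -lm`.

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const float s[6] = {0.3f, 2.0f, -1.0f, 4.0f, 0.5f, 1.5f};  // scores
        const float v[6] = {1.0f, -2.0f, 0.5f, 3.0f, -1.0f, 2.0f}; // values
        const int B_c = 2; // column block size, as in the patch

        // reference: one-pass softmax(s) . v
        float m_ref = -INFINITY, l_ref = 0.0f, o_ref = 0.0f;
        for (int j = 0; j < 6; j++) m_ref = fmaxf(m_ref, s[j]);
        for (int j = 0; j < 6; j++) l_ref += expf(s[j] - m_ref);
        for (int j = 0; j < 6; j++) o_ref += expf(s[j] - m_ref) / l_ref * v[j];

        // blocked recurrence, mirroring the T_c loop of the patch
        float m = -INFINITY, l_sum = 0.0f, o = 0.0f;
        for (int j0 = 0; j0 < 6; j0 += B_c) {
            float m_prev = m;
            for (int j = j0; j < j0 + B_c; j++) m = fmaxf(m, s[j]);
            float shifted_exp = expf(m_prev - m); // zero on the first block
            l_sum = shifted_exp * l_sum;          // rescale running row sum
            o = shifted_exp * o;                  // rescale accumulator
            for (int j = j0; j < j0 + B_c; j++) {
                float p = expf(s[j] - m);         // P_tilde entry
                l_sum += p;
                o += p * v[j];
            }
        }
        o /= l_sum; // final diag(l)^-1 scaling

        // both printed values should agree up to float rounding
        printf("reference %f blocked %f\n", (double)o_ref, (double)o);
        return 0;
    }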
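Note on the DMA transfers: my reading of snrt_dma_start_2d (an assumption
worth double-checking against snrt.h) is that it copies `repeat` chunks of
`size` bytes, advancing dst and src by their strides after each chunk.
Under that reading, the removed calls with size = ColBlock, src_stride =
ColBlock_Size * E * sizeof(float), and ColBlock_Size repetitions would copy
B_c whole tiles instead of one, which is why the patch moves to plain 1D
transfers of the contiguous tiles. A host-side model of the 2D behavior,
with a hypothetical dma_2d_model name:

    #include <stdint.h>
    #include <string.h>

    // copies `repeat` chunks of `size` bytes; dst and src advance by their
    // respective strides after each chunk
    static void dma_2d_model(void *dst, const void *src, size_t size,
                             size_t dst_stride, size_t src_stride,
                             size_t repeat) {
        for (size_t r = 0; r < repeat; r++) {
            memcpy((uint8_t *)dst + r * dst_stride,
                   (const uint8_t *)src + r * src_stride, size);
        }
    }

Under this model, a B_r x E tile copy is either B_r chunks of
E * sizeof(float) bytes with matching strides, or, since both source and
destination are contiguous here, one chunk of B_r * E * sizeof(float) bytes,
which is what the 1D calls in the patch express.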
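Note on the TCDM carve-out: tcdm_ptr is a float *, so the cursor bumps in
the patch are in elements, while RowBlock and ColBlock are byte counts used
only as DMA sizes. A small helper along these lines could keep the units
straight; alloc_f32 is a hypothetical name, not an snRuntime function:

    #include <stddef.h>

    // reserve n_elems floats from a bump cursor and return their base
    static float *alloc_f32(float **cursor, size_t n_elems) {
        float *p = *cursor;
        *cursor += n_elems; // pointer arithmetic is in elements, not bytes
        return p;
    }

    // usage, mirroring the carve-out above (E = positional embeddings):
    //   float *tcdm_ptr = (float *)snrt_l1_next();
    //   float *Q_tile  = alloc_f32(&tcdm_ptr, B_r * E);
    //   float *K_tile  = alloc_f32(&tcdm_ptr, B_c * E);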