From ef44af05e37eb2e1d95b1680be77e664bf784490 Mon Sep 17 00:00:00 2001
From: Viviane Potocnik
Date: Mon, 16 Oct 2023 14:55:46 +0200
Subject: [PATCH] transformer: add flash attention layer

---
 sw/apps/transformer/src/transformer.h | 307 +++++++++++++++++++++++++-
 1 file changed, 305 insertions(+), 2 deletions(-)

diff --git a/sw/apps/transformer/src/transformer.h b/sw/apps/transformer/src/transformer.h
index 5b835a08a2..384bfaa089 100644
--- a/sw/apps/transformer/src/transformer.h
+++ b/sw/apps/transformer/src/transformer.h
@@ -16,7 +16,8 @@
 #define DEBUG 1
 #define BRIEF 1 // enable this define to only compute two iterations of the loops
 #define LAYERNORM 0 // enable this define to compute the layernorm
-#define LINEAR_1 1 // enable this define to compute the linear layer 1
+#define LINEAR_1 0 // enable this define to compute the linear layer 1
+#define FLASH_ATTENTION 1 // enable this define to compute the flash attention
 
 /**
  * @struct transformer_layer_fp64_struct
@@ -54,6 +55,9 @@ typedef struct transformer_layer_fp64_struct {
     uint32_t S_tile_ln;
     uint32_t S_tile_lin1;
     uint32_t P_tile_lin1;
+    uint32_t Br_tile_fa;
+    uint32_t Bc_tile_fa;
+    uint32_t positional_embeddings_fa;
     uint32_t embeddings;
     uint32_t positional_embeddings;
     uint32_t feedforward_len;
@@ -70,6 +74,9 @@ typedef struct transformer_layer_fp64_struct {
     double *Q_lin;
     double *K_lin;
     double *V_lin;
+    double *Q_fa;
+    double *K_fa;
+    double *V_fa;
     double *O;
     double *ofmap;
     double *query;
@@ -214,6 +221,7 @@ static inline void fused_mlp_baseline(float *input, float *output, int32_t ldI,
 // Debugging functions
 dump_uint(idx, 6);
 dump_float(debug, 7);
+dump_uint(id, 8);
 
 /**
  * @brief Transformer layer
@@ -224,6 +232,7 @@ dump_float(debug, 7);
 static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {
     uint32_t compute_id = snrt_global_core_idx();
     uint32_t num_cores = snrt_cluster_compute_core_num();
+    uint32_t num_clusters = snrt_cluster_num();
 
     uint32_t seq_len = l->seq_len;
     uint32_t embeddings = l->embeddings;
@@ -231,6 +240,10 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {
     uint32_t heads = l->heads;
     uint32_t eps = l->eps;
 
+    // dump_id(compute_id);
+    // dump_id(num_cores);
+    dump_id(num_clusters);
+
     /////////////////////////////////////////////////////////////////////
     //////////////    MULTI-HEAD SELF-ATTENTION BLOCK    /////////////////
     ///////////////////////////////////////////////////////////////////
@@ -433,11 +446,11 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {
         snrt_cluster_hw_barrier();
 
         if (snrt_is_compute_core()) {
-            uint32_t start_gemm = snrt_mcycle();
             // compute the gemm for the current row block and column block
             uint32_t row_offset = (B_r_lin1 / num_cores) * compute_id * embeddings;
             // printf("core %d: row_offset = %d\n", compute_id, row_offset);
             // ifmap: B_r x E, weights: E x B_c
+            uint32_t start_gemm = snrt_mcycle();
             gemm_fp64_baseline(B_r_lin1 / num_cores, B_c_lin1, l->embeddings,
                                &ifmap_lin1[row_offset], l->embeddings, 0,
                                weights_q, B_c_lin1, 0,
@@ -454,6 +467,44 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {
         } else {
             // snrt_cluster_hw_barrier();
         }
+
+        snrt_cluster_hw_barrier();
+
+        // write the matrices back to DRAM
+        uint32_t write_back_offset = t_r * B_r_lin1 * l->positional_embeddings + t_c * B_c_lin1;
+        uint32_t start_dma_write_back = snrt_mcycle();
+        if (!snrt_is_compute_core()) {
+            snrt_dma_txid_t txid_query =
+                snrt_dma_start_2d(
+                    l->Q_lin + write_back_offset,              /* dst */
+                    query,                                     /* src */
+                    B_c_lin1 * sizeof(double),                 /* size */
+                    l->positional_embeddings * sizeof(double), /* dst_stride */
+                    B_c_lin1 * sizeof(double),                 /* src_stride */
+                    B_r_lin1 / num_cores);                     /* repetitions */
+
+            snrt_dma_txid_t txid_key =
+                snrt_dma_start_2d(
+                    l->K_lin + write_back_offset,              /* dst */
+                    key,                                       /* src */
+                    B_c_lin1 * sizeof(double),                 /* size */
+                    l->positional_embeddings * sizeof(double), /* dst_stride */
+                    B_c_lin1 * sizeof(double),                 /* src_stride */
+                    B_r_lin1 / num_cores);                     /* repetitions */
+
+            snrt_dma_txid_t txid_value =
+                snrt_dma_start_2d(
+                    l->V_lin + write_back_offset,              /* dst */
+                    value,                                     /* src */
+                    B_c_lin1 * sizeof(double),                 /* size */
+                    l->positional_embeddings * sizeof(double), /* dst_stride */
+                    B_c_lin1 * sizeof(double),                 /* src_stride */
+                    B_r_lin1 / num_cores);                     /* repetitions */
+
+            snrt_dma_wait_all();
+        }
+
+        uint32_t end_dma_write_back = snrt_mcycle();
         }
         uint32_t end_loop_inner = snrt_mcycle();
 
@@ -464,6 +515,258 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {
 
     }
 
+    ///////////////////// Layer 3: Flash Attention /////////////////////
+    //                      Input size  (S x P)                       //
+    //                      Output size (S x P)                       //
+    // The sequence dimension S is tiled into T_r row blocks of B_r   //
+    // rows each (for Q and O) and T_c column blocks of B_c rows      //
+    // each (for K and V).                                            //
+    /////////////////////////////////////////////////////////////////////
+
+    // TODO: Check that this is implemented correctly!!!
+    if (FLASH_ATTENTION == 1) {
+        // compute the tiling parameters
+        uint32_t B_r_fa = l->Br_tile_fa;       // number of rows per row block
+        uint32_t B_c_fa = l->Bc_tile_fa;       // number of columns per column block
+        uint32_t T_r_fa = l->seq_len / B_r_fa; // number of row blocks
+        uint32_t T_c_fa = l->seq_len / B_c_fa; // number of column blocks
+
+        // dump_index(B_r_fa);
+        // dump_index(B_c_fa);
+        // dump_index(T_r_fa);
+        // dump_index(T_c_fa);
+
+        // compute the size of the matrices
+        uint32_t q_fa_size = B_r_fa * l->positional_embeddings_fa * sizeof(double);
+        uint32_t k_fa_size = B_c_fa * l->positional_embeddings_fa * sizeof(double);
+        uint32_t v_fa_size = B_c_fa * l->positional_embeddings_fa * sizeof(double);
+        uint32_t s_fa_size = B_r_fa * B_c_fa * sizeof(double);
+        uint32_t p_fa_size = B_r_fa * B_c_fa * sizeof(double);
+        uint32_t o_fa_size = B_r_fa * l->positional_embeddings_fa * sizeof(double);
+        uint32_t m_i_size = B_r_fa * sizeof(double);
+        uint32_t m_i_prev_size = m_i_size;
+        uint32_t l_i_size = B_r_fa * sizeof(double);
+        uint32_t shifted_exp_size = B_r_fa * sizeof(double);
+
+        // allocate memory in TCDM
+        void *tcdm_ptr = snrt_l1_next();
+        double *Q_fa = tcdm_ptr;
+        tcdm_ptr += q_fa_size;
+        double *K_fa = tcdm_ptr;
+        tcdm_ptr += k_fa_size;
+        double *V_fa = tcdm_ptr;
+        tcdm_ptr += v_fa_size;
+        double *S_fa = tcdm_ptr;
+        tcdm_ptr += s_fa_size;
+        double *P_fa = tcdm_ptr;
+        tcdm_ptr += p_fa_size;
+        double *O_fa = tcdm_ptr;
+        tcdm_ptr += o_fa_size;
+        double *m_i = tcdm_ptr;
+        tcdm_ptr += m_i_size;
+        double *m_i_prev = tcdm_ptr;
+        tcdm_ptr += m_i_prev_size;
+        double *l_i = tcdm_ptr;
+        tcdm_ptr += l_i_size;
+        double shifted_exp;
+        double row_sum;
+
+        double used_memory_kB =
+            (double)((uintptr_t)tcdm_ptr - (uintptr_t)snrt_l1_next()) / 1024.0;
+
+        dump_debug(used_memory_kB);
+
+        uint32_t start_loop_outer = snrt_mcycle();
+        for (int t_r = 0; t_r < T_r_fa; t_r++) {
+            uint32_t start_dma = snrt_mcycle();
+            if (!snrt_is_compute_core()) {
+                uint32_t q_fa_offset = t_r * B_r_fa * l->positional_embeddings_fa;
+                // printf("q_fa_offset = %d\n", q_fa_offset);
+                // load the Q tile
+                snrt_dma_txid_t txid_q_fa =
+                    snrt_dma_start_2d(
+                        Q_fa,                                         /* dst */
+                        l->Q_fa + q_fa_offset,                        /* src */
+                        l->positional_embeddings_fa * sizeof(double), /* size */
+                        l->positional_embeddings_fa * sizeof(double), /* dst_stride */
+                        l->positional_embeddings_fa * sizeof(double), /* src_stride */
+                        B_r_fa);                                      /* repetitions */
+
+                snrt_dma_wait_all();
+
+                // // print matrix for debugging
+                // for (int i = 0; i < B_r_fa * l->positional_embeddings_fa; i++) {
+                //     printf("Q_fa[%d] = %f\n", i, Q_fa[i]);
+                // }
+            }
+            uint32_t end_dma = snrt_mcycle();
+
+            snrt_cluster_hw_barrier();
+
+            // initialize m_i, m_i_prev, l_i (compute cores only: the DM core's
+            // global index would address these vectors out of bounds)
+            if (snrt_is_compute_core()) {
+                for (int i = 0; i < B_r_fa / num_cores; i++) {
+                    m_i[i + (B_r_fa / num_cores) * compute_id] = -INFINITY;
+                    m_i_prev[i + (B_r_fa / num_cores) * compute_id] = -INFINITY;
+                    l_i[i + (B_r_fa / num_cores) * compute_id] = 0.0;
+                }
+            }
+
+            row_sum = 0.0;
+
+            uint32_t start_loop_inner = snrt_mcycle();
+            for (int t_c = 0; t_c < T_c_fa; t_c++) {
+                // K: B_c x P (transposed inside the GEMM), V: B_c x P
+                uint32_t k_fa_offset = t_c * B_c_fa * l->positional_embeddings_fa;
+                uint32_t v_fa_offset = t_c * B_c_fa * l->positional_embeddings_fa;
+
+                uint32_t start_dma = snrt_mcycle();
+                if (!snrt_is_compute_core()) {
+                    // load the K tile row-major, same layout as the V tile;
+                    // the GEMM below accesses it transposed
+                    snrt_dma_txid_t txid_k_fa =
+                        snrt_dma_start_2d(
+                            K_fa,                                         /* dst */
+                            l->K_fa + k_fa_offset,                        /* src */
+                            l->positional_embeddings_fa * sizeof(double), /* size */
+                            l->positional_embeddings_fa * sizeof(double), /* dst_stride */
+                            l->positional_embeddings_fa * sizeof(double), /* src_stride */
+                            B_c_fa);                                      /* repetitions */
+
+                    // load the V tile
+                    snrt_dma_txid_t txid_v_fa =
+                        snrt_dma_start_2d(
+                            V_fa,                                         /* dst */
+                            l->V_fa + v_fa_offset,                        /* src */
+                            l->positional_embeddings_fa * sizeof(double), /* size */
+                            l->positional_embeddings_fa * sizeof(double), /* dst_stride */
+                            l->positional_embeddings_fa * sizeof(double), /* src_stride */
+                            B_c_fa);                                      /* repetitions */
+
+                    snrt_dma_wait_all();
+                }
+                uint32_t end_dma = snrt_mcycle();
+
+                snrt_cluster_hw_barrier();
+
+                // Matrix Multiplication S = Q * K^T
+                if (snrt_is_compute_core()) {
+                    // compute the gemm for the current row block and column block
+                    uint32_t row_offset = (B_r_fa / num_cores) * compute_id * l->positional_embeddings_fa;
+                    // every core writes to a different row of the S matrix
+                    uint32_t tile_offset = (B_r_fa / num_cores) * compute_id * B_c_fa;
+                    // dump_idx(row_offset);
+                    uint32_t start_gemm = snrt_mcycle();
+                    gemm_fp64_baseline(B_r_fa / num_cores, B_c_fa, l->positional_embeddings_fa,
+                                       &Q_fa[row_offset], l->positional_embeddings_fa, 0,
+                                       K_fa, l->positional_embeddings_fa, 1,
+                                       &S_fa[tile_offset], B_c_fa, 0.0);
+                    uint32_t end_gemm = snrt_mcycle();
+
+                    // debug print the S matrix
+                    // for (int i = 0; i < (B_r_fa / num_cores); i++) {
+                    //     for (int j = 0; j < B_c_fa; j++) {
+                    //         dump_debug(S_fa[i * B_c_fa + j + tile_offset]);
+                    //     }
+                    // }
+
+                    // next we determine the maximum value of each row
+                    uint32_t offset = (B_r_fa / num_cores) * compute_id;
+                    uint32_t o_offset = (B_r_fa / num_cores) * compute_id * l->positional_embeddings_fa;
+                    uint32_t start_stats = snrt_mcycle();
+                    for (int k = 0; k < B_r_fa / num_cores; k++) {
+                        m_i_prev[k + offset] = m_i[k + offset];
+
+                        for (int j = 0; j < B_c_fa; j++) {
+                            if (S_fa[tile_offset + k * B_c_fa + j] > m_i[k + offset]) {
+                                m_i[k + offset] = S_fa[tile_offset + k * B_c_fa + j];
+                            }
+                        }
+                        // dump_debug(m_i[k + offset]);
+
+                        // rescaling factor for the running row statistics:
+                        // shifted_exp = exp(m_old - m_new) <= 1
+                        shifted_exp = exp(m_i_prev[k + offset] - m_i[k + offset]);
+
+                        // determine the P matrix and the row sum, using the
+                        // final row maximum of this tile
+                        for (int j = 0; j < B_c_fa; j++) {
+                            P_fa[tile_offset + k * B_c_fa + j] =
+                                exp(S_fa[tile_offset + k * B_c_fa + j] - m_i[k + offset]);
+                            row_sum += P_fa[tile_offset + k * B_c_fa + j];
+                        }
+
+                        if (t_c == 0) {
+                            l_i[k + offset] = row_sum;
+                        } else {
+                            l_i[k + offset] = l_i[k + offset] * shifted_exp + row_sum;
+                        }
+                        // dump_debug(l_i[k + offset]);
+
+                        row_sum = 0.0;
+
+                        // O_ij = diag(shifted_exp) * O_i(j-1) + P_ij * V_j
+                        if (t_c == 0) {
+                            gemm_fp64_baseline(1, l->positional_embeddings_fa, B_c_fa,
+                                               &P_fa[tile_offset + k * B_c_fa], B_c_fa, 0,
+                                               V_fa, l->positional_embeddings_fa, 0,
+                                               &O_fa[o_offset + k * l->positional_embeddings_fa],
+                                               l->positional_embeddings_fa, 0.0);
+                        } else {
+                            gemm_fp64_baseline(1, l->positional_embeddings_fa, B_c_fa,
+                                               &P_fa[tile_offset + k * B_c_fa], B_c_fa, 0,
+                                               V_fa, l->positional_embeddings_fa, 0,
+                                               &O_fa[o_offset + k * l->positional_embeddings_fa],
+                                               l->positional_embeddings_fa, shifted_exp);
+                        }
+
+                    } // end of B_r loop
+                    uint32_t end_stats = snrt_mcycle();
+
+                    // synchronize here so that the DMA core does not prefetch
+                    // the next K/V tile while it is still being consumed
+                    snrt_cluster_hw_barrier();
+
+                } else {
+                    snrt_cluster_hw_barrier();
+                }
+            } // end of T_c loop
+
+            snrt_cluster_hw_barrier();
+
+            // O_i = diag(l_i_Tc)^-1 * O_i
+            if (snrt_is_compute_core()) {
+                for (int i = 0; i < B_r_fa / num_cores; i++) {
+                    for (int j = 0; j < l->positional_embeddings_fa; j++) {
+                        O_fa[(B_r_fa / num_cores) * compute_id * l->positional_embeddings_fa +
+                             i * l->positional_embeddings_fa + j] /=
+                            l_i[i + (B_r_fa / num_cores) * compute_id];
+                    }
+                }
+            }
+
+            // make sure the normalization is complete before the write-back
+            snrt_cluster_hw_barrier();
+
+            // write back O_fa as the i-th block of the output matrix
+            uint32_t start_dma_write_back = snrt_mcycle();
+            if (!snrt_is_compute_core()) {
+                uint32_t o_fa_offset = t_r * B_r_fa * l->positional_embeddings_fa;
+                // printf("o_fa_offset = %d\n", o_fa_offset);
+                snrt_dma_txid_t txid_o_fa =
+                    snrt_dma_start_2d(
+                        l->O + o_fa_offset,                           /* dst */
+                        O_fa,                                         /* src */
+                        l->positional_embeddings_fa * sizeof(double), /* size */
+                        l->positional_embeddings_fa * sizeof(double), /* dst_stride */
+                        l->positional_embeddings_fa * sizeof(double), /* src_stride */
+                        B_r_fa);                                      /* repetitions */
+
+                snrt_dma_wait_all();
+            }
+            uint32_t end_dma_write_back = snrt_mcycle();
+
+        } // end of T_r loop
+        uint32_t end_loop_outer = snrt_mcycle();
+
+        snrt_cluster_hw_barrier();
+    }
+
     snrt_global_barrier();
 }
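
Review note, not part of the patch: the TODO above asks whether the streaming-softmax recurrence is implemented correctly. Below is a minimal, self-contained host-side cross-check of that recurrence, written for this review; every name in it (flash_attention_ref, SEQ_LEN, P_DIM, B_R, B_C) is invented here and does not exist in the repository. It runs the same tile-wise scheme as the kernel (per-row running maximum m, running denominator l, output rescaled by exp(m_old - m_new), final division by l) on a single core and compares it against a naive softmax(Q K^T) V. Note that, like the kernel, it omits the usual 1/sqrt(P) score scaling. Compile with `gcc -O2 check_fa.c -lm` and expect a max abs error around 1e-16.

    // Hypothetical cross-check for the tiled softmax recurrence used above.
    #include <math.h>
    #include <stdio.h>
    #include <stdlib.h>

    #define SEQ_LEN 8
    #define P_DIM   4
    #define B_R     4  // rows per Q/O tile
    #define B_C     2  // rows per K/V tile

    static void flash_attention_ref(const double *Q, const double *K,
                                    const double *V, double *O) {
        double m[SEQ_LEN], l[SEQ_LEN];
        for (int tr = 0; tr < SEQ_LEN / B_R; tr++) {
            for (int i = 0; i < B_R; i++) {
                int r = tr * B_R + i;
                m[r] = -INFINITY;
                l[r] = 0.0;
                for (int d = 0; d < P_DIM; d++) O[r * P_DIM + d] = 0.0;
            }
            for (int tc = 0; tc < SEQ_LEN / B_C; tc++) {
                for (int i = 0; i < B_R; i++) {
                    int r = tr * B_R + i;
                    double s[B_C], m_prev = m[r];
                    // scores of this row against the current K tile, and new max
                    for (int j = 0; j < B_C; j++) {
                        s[j] = 0.0;
                        for (int d = 0; d < P_DIM; d++)
                            s[j] += Q[r * P_DIM + d] * K[(tc * B_C + j) * P_DIM + d];
                        if (s[j] > m[r]) m[r] = s[j];
                    }
                    // rescale running sum and output by exp(m_old - m_new)
                    double shift = exp(m_prev - m[r]);
                    double row_sum = 0.0;
                    for (int j = 0; j < B_C; j++) row_sum += exp(s[j] - m[r]);
                    l[r] = l[r] * shift + row_sum;
                    for (int d = 0; d < P_DIM; d++) {
                        double acc = O[r * P_DIM + d] * shift;
                        for (int j = 0; j < B_C; j++)
                            acc += exp(s[j] - m[r]) * V[(tc * B_C + j) * P_DIM + d];
                        O[r * P_DIM + d] = acc;
                    }
                }
            }
            // final normalization by the accumulated denominator
            for (int i = 0; i < B_R; i++) {
                int r = tr * B_R + i;
                for (int d = 0; d < P_DIM; d++) O[r * P_DIM + d] /= l[r];
            }
        }
    }

    int main(void) {
        double Q[SEQ_LEN * P_DIM], K[SEQ_LEN * P_DIM], V[SEQ_LEN * P_DIM];
        double O[SEQ_LEN * P_DIM], O_ref[SEQ_LEN * P_DIM];
        srand(42);
        for (int i = 0; i < SEQ_LEN * P_DIM; i++) {
            Q[i] = (double)rand() / RAND_MAX - 0.5;
            K[i] = (double)rand() / RAND_MAX - 0.5;
            V[i] = (double)rand() / RAND_MAX - 0.5;
        }
        flash_attention_ref(Q, K, V, O);
        // naive reference: O_ref = softmax(Q K^T) V, row by row
        for (int r = 0; r < SEQ_LEN; r++) {
            double s[SEQ_LEN], mx = -INFINITY, sum = 0.0;
            for (int c = 0; c < SEQ_LEN; c++) {
                s[c] = 0.0;
                for (int d = 0; d < P_DIM; d++)
                    s[c] += Q[r * P_DIM + d] * K[c * P_DIM + d];
                if (s[c] > mx) mx = s[c];
            }
            for (int c = 0; c < SEQ_LEN; c++) { s[c] = exp(s[c] - mx); sum += s[c]; }
            for (int d = 0; d < P_DIM; d++) {
                O_ref[r * P_DIM + d] = 0.0;
                for (int c = 0; c < SEQ_LEN; c++)
                    O_ref[r * P_DIM + d] += s[c] / sum * V[c * P_DIM + d];
            }
        }
        double max_err = 0.0;
        for (int i = 0; i < SEQ_LEN * P_DIM; i++) {
            double e = fabs(O[i] - O_ref[i]);
            if (e > max_err) max_err = e;
        }
        printf("max abs error: %g\n", max_err);
        return 0;
    }

The key property this exercises is that the per-tile rescaling factor must be exp(m_old - m_new), which is at most 1; defining it the other way around (exp(m_new - m_old), as the unedited patch did) inflates both the running denominator and the partial output.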