Skip to content

Commit

Permalink
transformer: add flash attention layer
Browse files Browse the repository at this point in the history
  • Loading branch information
Viviane Potocnik committed Oct 16, 2023
1 parent 65c73c7 commit 164e857
Showing 1 changed file with 305 additions and 2 deletions.
307 changes: 305 additions & 2 deletions sw/apps/transformer/src/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
#define DEBUG 1
#define BRIEF 1 // enable this define to only compute two iterations of the loops
#define LAYERNORM 0 // enable this define to compute the layernorm
#define LINEAR_1 1 // enable this define to compute the linear layer 1
#define LINEAR_1 0 // enable this define to compute the linear layer 1
#define FLASH_ATTENTION 1 // enable this define to compute the flash attention

/**
* @struct transformer_layer_fp64_struct
Expand Down Expand Up @@ -54,6 +55,9 @@ typedef struct transformer_layer_fp64_struct {
uint32_t S_tile_ln;
uint32_t S_tile_lin1;
uint32_t P_tile_lin1;
uint32_t Br_tile_fa;
uint32_t Bc_tile_fa;
uint32_t positional_embeddings_fa;
uint32_t embeddings;
uint32_t positional_embeddings;
uint32_t feedforward_len;
Expand All @@ -70,6 +74,9 @@ typedef struct transformer_layer_fp64_struct {
double *Q_lin;
double *K_lin;
double *V_lin;
double *Q_fa;
double *K_fa;
double *V_fa;
double *O;
double *ofmap;
double *query;
Expand Down Expand Up @@ -214,6 +221,7 @@ static inline void fused_mlp_baseline(float *input, float *output, int32_t ldI,
// Debugging functions
dump_uint(idx, 6);
dump_float(debug, 7);
dump_uint(id, 8);

/**
* @brief Transformer layer
Expand All @@ -224,13 +232,18 @@ dump_float(debug, 7);
static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {
uint32_t compute_id = snrt_global_core_idx();
uint32_t num_cores = snrt_cluster_compute_core_num();
uint32_t num_clusters = snrt_cluster_num();

uint32_t seq_len = l->seq_len;
uint32_t embeddings = l->embeddings;
uint32_t positional_embeddings = l->positional_embeddings;
uint32_t heads = l->heads;
uint32_t eps = l->eps;

// dump_id(compute_id);
// dump_id(num_cores);
dump_id(num_clusters);

/////////////////////////////////////////////////////////////////////
////////////// MULTI-HEAD SELF-ATTENTION BLOCK /////////////////
///////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -433,11 +446,11 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {
snrt_cluster_hw_barrier();

if (snrt_is_compute_core()) {
uint32_t start_gemm = snrt_mcycle();
// compute the gemm for the current row block and column block
uint32_t row_offset = (B_r_lin1 / num_cores) * compute_id * embeddings;
// printf("core %d: row_offset = %d\n", compute_id, row_offset);
// ifmap: B_r x E, weights: E x B_c
uint32_t start_gemm = snrt_mcycle();
gemm_fp64_baseline(B_r_lin1 / num_cores, B_c_lin1, l->embeddings,
&ifmap_lin1[row_offset], l->embeddings, 0,
weights_q, B_c_lin1, 0,
Expand All @@ -454,6 +467,44 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {
} else {
// snrt_cluster_hw_barrier();
}

snrt_cluster_hw_barrier();

// write the matrices back to DRAM
uint32_t write_back_offset = t_r * B_r_lin1 * l->positional_embeddings + t_c * B_c_lin1;
uint32_t start_dma_write_back = snrt_mcycle();
if (!snrt_is_compute_core()) {
snrt_dma_txid_t txid_query =
snrt_dma_start_2d(
l->Q_lin + write_back_offset, /* dst */
query, /* src */
B_c_lin1 * sizeof(double), /* size */
l->positional_embeddings * sizeof(double), /* dst_stride */
B_c_lin1 * sizeof(double), /* src_stride */
B_r_lin1 / num_cores); /* repetitions */

snrt_dma_txid_t txid_key =
snrt_dma_start_2d(
l->K_lin + write_back_offset, /* dst */
key, /* src */
B_c_lin1 * sizeof(double), /* size */
l->positional_embeddings * sizeof(double), /* dst_stride */
B_c_lin1 * sizeof(double), /* src_stride */
B_r_lin1 / num_cores); /* repetitions */

snrt_dma_txid_t txid_value =
snrt_dma_start_2d(
l->V_lin + write_back_offset, /* dst */
value, /* src */
B_c_lin1 * sizeof(double), /* size */
l->positional_embeddings * sizeof(double), /* dst_stride */
B_c_lin1 * sizeof(double), /* src_stride */
B_r_lin1 / num_cores); /* repetitions */

snrt_dma_wait_all();
}

uint32_t end_dma_write_back = snrt_mcycle();
}
uint32_t end_loop_inner = snrt_mcycle();

Expand All @@ -464,6 +515,258 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {

}

///////////////////// Layer 3: Flash Attention /////////////////////
// Input size (S x P) //
// Output size (S x P) //
// S is tiled into Row Blocks of B_r rows //
// S is tiled into Column Blocks of B_c columns //
////////////////////////////////////////////////////////////////////////

// TODO: Check if it implemented correctly!!!
if (FLASH_ATTENTION == 1) {
// compute the tiling parameters
uint32_t B_r_fa = l->Br_tile_fa; // number of rows per row block
uint32_t B_c_fa = l->Bc_tile_fa; // number of columns per column block
uint32_t T_r_fa = l->seq_len / B_r_fa; // number of row blocks
uint32_t T_c_fa = l->seq_len / B_c_fa; // number of column blocks

// dump_index(B_r_fa);
// dump_index(B_c_fa);
// dump_index(T_r_fa);
// dump_index(T_c_fa);

// compute the size of the matrices
uint32_t q_fa_size = B_r_fa * l->positional_embeddings_fa * sizeof(double);
uint32_t k_fa_size = B_c_fa * l->positional_embeddings_fa * sizeof(double);
uint32_t v_fa_size = B_c_fa * l->positional_embeddings_fa * sizeof(double);
uint32_t s_fa_size = B_r_fa * B_c_fa * sizeof(double);
uint32_t p_fa_size = B_r_fa * B_c_fa * sizeof(double);
uint32_t o_fa_size = B_r_fa * l->positional_embeddings_fa * sizeof(double);
uint32_t m_i_size = B_r_fa * sizeof(double);
uint32_t m_i_prev_size = m_i_size;
uint32_t l_i_size = B_r_fa * sizeof(double);
uint32_t shifted_exp_size = B_r_fa * sizeof(double);

// allocate memory in TCDM
void *tcdm_ptr = (double *)snrt_l1_next();
double *Q_fa = tcdm_ptr;
tcdm_ptr += q_fa_size;
double *K_fa = tcdm_ptr;
tcdm_ptr += k_fa_size;
double *V_fa = tcdm_ptr;
tcdm_ptr += v_fa_size;
double *S_fa = tcdm_ptr;
tcdm_ptr += s_fa_size;
double *P_fa = tcdm_ptr;
tcdm_ptr += p_fa_size;
double *O_fa = tcdm_ptr;
tcdm_ptr += o_fa_size;
double *m_i = tcdm_ptr;
tcdm_ptr += m_i_size;
double *m_i_prev = tcdm_ptr;
tcdm_ptr += m_i_prev_size;
double *l_i = tcdm_ptr;
tcdm_ptr += l_i_size;
double shifted_exp;
double row_sum;

double used_memory_kB = (double)((uint64_t)tcdm_ptr - (uint64_t)snrt_l1_next()) / 1024.0f;

dump_debug(used_memory_kB);

uint32_t start_loop_outer = snrt_mcycle();
for (int t_r = 0; t_r < T_r_fa; t_r++) {
uint32_t start_dma = snrt_mcycle();
if (!snrt_is_compute_core()) {
uint32_t q_fa_offset = t_r * B_r_fa * l->positional_embeddings_fa;
// printf("q_fa_offset = %d\n", q_fa_offset);
// load the Q tile
snrt_dma_txid_t txid_q_fa =
snrt_dma_start_2d(
Q_fa, /* dst */
l->Q_fa + q_fa_offset, /* src */
l->positional_embeddings_fa * sizeof(double), /* size */
l->positional_embeddings_fa * sizeof(double), /* dst_stride */
l->positional_embeddings_fa * sizeof(double), /* src_stride */
B_r_fa); /* repetitions */

snrt_dma_wait_all();

// // print matrix for debugging
// for (int i = 0; i < B_r_fa * l->positional_embeddings_fa; i++) {
// dump_debug(Q_fa[i]);
// printf("Q_fa[%d] = %f\n", i, Q_fa[i]);
// }
}
uint32_t end_dma = snrt_mcycle();

snrt_cluster_hw_barrier();

// initialize m_i, m_i_prev, l_i
for (int i = 0; i < B_r_fa / num_cores; i++) {
m_i[i + (B_r_fa / num_cores) * compute_id] = -INFINITY;
m_i_prev[i + (B_r_fa / num_cores) * compute_id] = -INFINITY;
l_i[i + (B_r_fa / num_cores) * compute_id] = 0.0f;
}

row_sum = 0.0;

uint32_t start_loop_inner = snrt_mcycle();
for (int t_c = 0; t_c < T_c_fa; t_c++) {
// K: P x B_c, V: B_c x P
uint32_t k_fa_offset = t_c * B_c_fa * l->positional_embeddings_fa;
uint32_t v_fa_offset = t_c * B_c_fa * l->positional_embeddings_fa;

uint32_t start_dma = snrt_mcycle();
if (!snrt_is_compute_core()) {
// load the K tile
snrt_dma_txid_t txid_k_fa =
snrt_dma_start_2d(
K_fa, /* dst */
l->K_fa + k_fa_offset, /* src */
B_c_fa * sizeof(double), /* size */
B_c_fa * sizeof(double), /* dst_stride */
l->positional_embeddings_fa * sizeof(double), /* src_stride */
l->positional_embeddings_fa); /* repetitions */

// load the V tile
snrt_dma_txid_t txid_v_fa =
snrt_dma_start_2d(
V_fa, /* dst */
l->V_fa + v_fa_offset, /* src */
l->positional_embeddings_fa * sizeof(double), /* size */
l->positional_embeddings_fa * sizeof(double), /* dst_stride */
l->positional_embeddings_fa * sizeof(double), /* src_stride */
B_c_fa); /* repetitions */

snrt_dma_wait_all();
}
uint32_t end_dma = snrt_mcycle();

snrt_cluster_hw_barrier();

// Matrix Multiplication S = Q * K^T
if (snrt_is_compute_core()) {
// compute the gemm for the current row block and column block
uint32_t row_offset = (B_r_fa / num_cores) * compute_id * l->positional_embeddings_fa;
// every core writes to a different row of the S matrix
uint32_t tile_offset = (B_r_fa / num_cores) * compute_id * B_c_fa;
// dump_idx(row_offset);
uint32_t start_gemm = snrt_mcycle();
gemm_fp64_baseline(B_r_fa / num_cores, B_c_fa, l->positional_embeddings_fa,
&Q_fa[row_offset], l->positional_embeddings_fa, 0,
K_fa, l->positional_embeddings_fa, 1,
&S_fa[tile_offset], B_c_fa, 0.0f);
uint32_t end_gemm = snrt_mcycle();

snrt_cluster_hw_barrier();

// debug print the S matrix
// for (int i = 0; i < (B_r_fa / num_cores); i++) {
// for (int j = 0; j < B_c_fa; j++) {
// dump_debug(S_fa[i * B_c_fa + j + tile_offset]);
// }
// }

// next we determine the maximum value of each row
uint32_t offset = (B_r_fa / num_cores) * compute_id;
uint32_t o_offset = (B_r_fa / num_cores) * compute_id * l->positional_embeddings_fa;
uint32_t start_stats = snrt_mcycle();
for (int k = 0; k < B_r_fa / num_cores; k++) {
m_i_prev[k + offset] = m_i[k + offset];

// dump_idx(k * B_c_fa + tile_offset);

for (int j = 0; j < B_c_fa; j++) {

// dump_idx(k * B_c_fa + j + tile_offset);
// dump_debug(S_fa[k * B_c_fa + j + tile_offset]);
// printf("S_fa[%d] = %f\n", k * B_c_fa + j + tile_offset, S_fa[k * B_c_fa + j + tile_offset]);

if(S_fa[tile_offset + k * B_c_fa + j] > m_i[k + offset]) {
m_i[k + offset] = S_fa[tile_offset + k * B_c_fa + j];
}
// dump_debug(m_i[k + offset]);

// determine the P matrix
P_fa[tile_offset + k * B_c_fa + j] = expf(S_fa[tile_offset + k * B_c_fa + j] - m_i[k + offset]);
// printf("P_fa[%d] = expf(%f - %f) = %f\n", tile_offset + k * B_c_fa + j, S_fa[tile_offset + k * B_c_fa + j], m_i[k + offset], P_fa[tile_offset + k * B_c_fa + j]);
// printf("P_fa[%d] = %f\n", tile_offset + k * B_c_fa + j, P_fa[tile_offset + k * B_c_fa + j]);
// dump_debug(P_fa[k * B_c_fa + j]);
row_sum += P_fa[tile_offset + k * B_c_fa + j];
}

// dump_debug(row_sum);

shifted_exp = expf(m_i[k + offset] - m_i_prev[k + offset]);
// printf("shifted_exp = expf(%f - %f) = %f\n", m_i[k + offset], m_i_prev[k + offset], shifted_exp);

// printf("BEFORE: l_i[%d] = %f\n", k + offset, l_i[k + offset]);
if (t_c == 0) {
l_i[k + offset] = row_sum;
} else {
l_i[k + offset] = l_i[k + offset] * shifted_exp + row_sum;
}
// printf("AFTER: l_i[%d] = %f\n", k + offset, l_i[k + offset]);
// dump_debug(l_i[k + offset]);

row_sum = 0.0;

// O_ij = diag(shifted_exp)^(-1) O_i(j-1) + P_ij * V_j
if (t_c == 0) {
gemm_fp64_baseline(1, B_c_fa, l->positional_embeddings_fa,
&P_fa[tile_offset + k * B_c_fa], B_c_fa, 0,
V_fa, l->positional_embeddings_fa, 0,
&O_fa[o_offset + k * l->positional_embeddings_fa], l->positional_embeddings_fa, 0.0f);
} else {
gemm_fp64_baseline(1, B_c_fa, l->positional_embeddings_fa,
&P_fa[tile_offset + k * B_c_fa], B_c_fa, 0,
V_fa, l->positional_embeddings_fa, 0,
&O_fa[o_offset + k * l->positional_embeddings_fa], l->positional_embeddings_fa, shifted_exp);
}

} // end of B_r loop
uint32_t end_stats = snrt_mcycle();

} else {
snrt_cluster_hw_barrier();
}
} // end of T_c loop

snrt_cluster_hw_barrier();

// O_i = diag(l_i_Tc) ^-1 * O_i
for (int i = 0; i < B_r_fa / num_cores; i++) {
for (int j = 0; j < l->positional_embeddings_fa; j++) {
O_fa[i * l->positional_embeddings_fa + j + (B_r_fa / num_cores) * compute_id * l->positional_embeddings_fa]
/= l_i[i + (B_r_fa / num_cores) * compute_id];
}
}

// write back O_fa as the i-th block of the output matrix
uint32_t start_dma_write_back = snrt_mcycle();
if (!snrt_is_compute_core()) {
uint32_t o_fa_offset = t_r * B_r_fa * l->positional_embeddings_fa;
// printf("o_fa_offset = %d\n", o_fa_offset);
snrt_dma_txid_t txid_o_fa =
snrt_dma_start_2d(
l->O + o_fa_offset, /* dst */
O_fa, /* src */
l->positional_embeddings_fa * sizeof(double), /* size */
l->positional_embeddings_fa * sizeof(double), /* dst_stride */
l->positional_embeddings_fa * sizeof(double), /* src_stride */
B_r_fa); /* repetitions */

snrt_dma_wait_all();
}
uint32_t end_dma_write_back = snrt_mcycle();

} // end of T_r loop
uint32_t end_loop_outer = snrt_mcycle();

snrt_cluster_hw_barrier();
}

snrt_global_barrier();
}

Expand Down

0 comments on commit 164e857

Please sign in to comment.