From f07926e432149be2706f4ea1c6b5633d4b71bd28 Mon Sep 17 00:00:00 2001
From: Viviane Potocnik
Date: Sat, 16 Sep 2023 19:02:03 +0200
Subject: [PATCH] trafo: add tiled self attention to MHSA block

---
 sw/apps/transformer/src/transformer.h | 117 ++++++++++++++++++++++++--
 1 file changed, 112 insertions(+), 5 deletions(-)

diff --git a/sw/apps/transformer/src/transformer.h b/sw/apps/transformer/src/transformer.h
index 88e8766bdd..f3af1c86fd 100644
--- a/sw/apps/transformer/src/transformer.h
+++ b/sw/apps/transformer/src/transformer.h
@@ -52,13 +52,18 @@ typedef struct transformer_layer_struct {
     uint32_t positional_embeddings;
     uint32_t heads;
     uint32_t eps;
+    uint32_t attn_seq_len;
+    uint32_t attn_embeddings;
 
     float *ifmap;
+    float *attn_ifmap;
     float *bias;
     float *weights_q;
     float *weights_k;
     float *weights_v;
     float *weights_o;
+    float *attn_weights;
+    float *attn_out;
     float *Q_lin;
     float *K_lin;
     float *V_lin;
@@ -696,14 +701,116 @@ static inline void transformer_layer(transformer_layer_t *const l) {
         }
     }  // end of T_r loop
 
+    // TODO: This step requires a GLOBAL barrier!!!
     snrt_cluster_hw_barrier();
 
     // Self Attention
-    // determine the tile sizes
-    uint32_t s_Size = l->seq_len * sizeof(float);
-    int32_t num_s_tiles = l->seq_len / s_Size;
-    uint32_t e_Size = l->embeddings * sizeof(float) / 2;
-    int32_t num_e_tiles = l->embeddings / e_Size;
+    // (S x HP) = (2048 x 768)
+    // Every cluster computes ceil(2048 / num_clusters) = 86 rows
+    // and every core computes ceil(86 / num_cores) = 11 rows
+    // (HP x E) = (768 x 768)
+    // for simplicity we only perform two iterations
+    // of the inner and outer loops
+
+    // define the tile sizes
+    int32_t s_tile_size = 8;
+    int32_t e_tile_size = 32;
+
+    int32_t num_s_tiles = l->attn_seq_len / s_tile_size;
+    int32_t num_e_tiles = l->attn_embeddings / e_tile_size;
+
+    // A_OUT (S x HP)
+    // ATTN_WEIGHTS (HP x E), buffered as one (HP x e_tile_size) column block
+    // MHSA_OUT (S x E)
+    int32_t a_out_size = s_tile_size * l->positional_embeddings * l->heads;
+    int32_t attn_weights_size = l->heads * l->positional_embeddings * e_tile_size;
+    int32_t mhsa_out_size = s_tile_size * e_tile_size;
+
+    // reset the TCDM pointer
+    tcdm_ptr = (float *)snrt_l1_next();
+    float *a_out = tcdm_ptr;
+    tcdm_ptr += a_out_size;
+    float *attn_weights = tcdm_ptr;
+    tcdm_ptr += attn_weights_size;
+    float *mhsa_out = tcdm_ptr;
+    tcdm_ptr += mhsa_out_size;
+
+    // compute memory usage
+    used_memory_kB = (float)((uint64_t)tcdm_ptr - (uint64_t)snrt_l1_next()) / 1024.0f;
+    dump_debug(used_memory_kB);
+
+    for (int s_tile = 0; s_tile < num_s_tiles; s_tile++) {
+        // load the a_out tile from DRAM
+        if (!snrt_is_compute_core()) {
+            // compute the offset from which to load
+            uint32_t a_out_offset = s_tile * s_tile_size * l->positional_embeddings * l->heads;
+
+            snrt_dma_txid_t txid_a_out =
+                snrt_dma_start_2d(
+                    a_out,                                                /* dst */
+                    &l->attn_ifmap[a_out_offset],                         /* src */
+                    l->positional_embeddings * l->heads * sizeof(float),  /* size */
+                    l->positional_embeddings * l->heads * sizeof(float),  /* dst_stride */
+                    l->positional_embeddings * l->heads * sizeof(float),  /* src_stride */
+                    s_tile_size);                                         /* repetitions */
+            snrt_dma_wait_all();
+        }
+
+        snrt_cluster_hw_barrier();
+
+        for (int e_tile = 0; e_tile < num_e_tiles; e_tile++) {
+            // we load the column block of the attention weights
+            if (!snrt_is_compute_core()) {
+                // compute the offset from which to load
+                uint32_t attn_weights_offset = e_tile * e_tile_size;
+
+                snrt_dma_txid_t txid_attn_weights =
+                    snrt_dma_start_2d(
+                        attn_weights,                           /* dst */
+                        &l->attn_weights[attn_weights_offset], /* src */
+                        e_tile_size * sizeof(float),           /* size */
+                        e_tile_size * sizeof(float),           /* dst_stride */
+                        l->attn_embeddings * sizeof(float),    /* src_stride */
+                        l->heads * l->positional_embeddings);  /* repetitions */
+                snrt_dma_wait_all();
+            }
+
+            snrt_cluster_hw_barrier();
+
+            // compute the matrix multiplication
+            if (snrt_is_compute_core()) {
+                // compute the row offset for the current compute core
+                uint32_t a_out_offset = compute_id * (s_tile_size / num_cores) * l->positional_embeddings * l->heads;
+                uint32_t mhsa_out_offset = compute_id * (s_tile_size / num_cores) * e_tile_size;
+
+                // compute a tile of the MHSA output
+                // A_OUT (S x HP) * ATTN_WEIGHTS (HP x E) = MHSA_OUT (S x E)
+                // we parallelize the rows of A_OUT
+                gemm_fp32_baseline(s_tile_size / num_cores, e_tile_size, l->positional_embeddings * l->heads,
+                                   &a_out[a_out_offset], l->positional_embeddings * l->heads, 0, attn_weights,
+                                   e_tile_size, 0, &mhsa_out[mhsa_out_offset], e_tile_size, 0.0f);
+            }
+
+            snrt_cluster_hw_barrier();
+
+            // write back the MHSA output to DRAM
+            if (!snrt_is_compute_core()) {
+                // compute the offset to which to store
+                uint32_t mhsa_out_offset = s_tile * s_tile_size * l->attn_embeddings + e_tile * e_tile_size;
+
+                snrt_dma_txid_t txid_mhsa_out =
+                    snrt_dma_start_2d(
+                        &l->attn_out[mhsa_out_offset],       /* dst */
+                        mhsa_out,                            /* src */
+                        e_tile_size * sizeof(float),         /* size */
+                        l->attn_embeddings * sizeof(float),  /* dst_stride */
+                        e_tile_size * sizeof(float),         /* src_stride */
+                        s_tile_size);                        /* repetitions */
+                snrt_dma_wait_all();
+            }
+
+        }
+    }
 
     snrt_cluster_hw_barrier();
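
Review aid, not part of the patch: the standalone C sketch below models the tiling and stride arithmetic of the new loop nest with small placeholder dimensions. S, HP, E, S_TILE, and E_TILE are illustrative stand-ins for attn_seq_len, heads * positional_embeddings, attn_embeddings, s_tile_size, and e_tile_size; the values are chosen only so the program runs quickly. It computes MHSA_OUT (S x E) = A_OUT (S x HP) * ATTN_WEIGHTS (HP x E) one (S_TILE x E_TILE) block at a time and checks the result against an untiled reference.

/*
 * Review sketch (not part of the patch): plain-C model of the tiling used in
 * the self-attention output projection above. All dimensions are placeholders.
 */
#include <math.h>
#include <stdio.h>
#include <string.h>

#define S      16  /* placeholder for l->attn_seq_len */
#define HP     12  /* placeholder for l->heads * l->positional_embeddings */
#define E      32  /* placeholder for l->attn_embeddings */
#define S_TILE 8   /* s_tile_size */
#define E_TILE 16  /* e_tile_size */

int main(void) {
    static float a_out[S * HP];    /* A_OUT, row-major (S x HP) in "DRAM" */
    static float attn_w[HP * E];   /* ATTN_WEIGHTS, row-major (HP x E) */
    static float mhsa_out[S * E];  /* MHSA_OUT, row-major (S x E) */
    static float ref[S * E];

    for (int i = 0; i < S * HP; i++) a_out[i] = (float)(i % 7) * 0.25f;
    for (int i = 0; i < HP * E; i++) attn_w[i] = (float)(i % 5) * 0.5f;

    /* untiled reference */
    for (int i = 0; i < S; i++)
        for (int j = 0; j < E; j++) {
            float acc = 0.0f;
            for (int k = 0; k < HP; k++) acc += a_out[i * HP + k] * attn_w[k * E + j];
            ref[i * E + j] = acc;
        }

    /* "TCDM" buffers: one row block of A_OUT, one column block of the
     * weights, one output block */
    float a_tile[S_TILE * HP];
    float w_tile[HP * E_TILE];
    float o_tile[S_TILE * E_TILE];

    for (int s_tile = 0; s_tile < S / S_TILE; s_tile++) {
        /* row block of A_OUT: contiguous rows, like the first DMA transfer */
        memcpy(a_tile, &a_out[s_tile * S_TILE * HP], sizeof(a_tile));

        for (int e_tile = 0; e_tile < E / E_TILE; e_tile++) {
            /* column block of ATTN_WEIGHTS: E_TILE floats per row, source
             * stride E, destination stride E_TILE, HP repetitions */
            for (int r = 0; r < HP; r++)
                memcpy(&w_tile[r * E_TILE],
                       &attn_w[r * E + e_tile * E_TILE],
                       E_TILE * sizeof(float));

            /* (S_TILE x HP) * (HP x E_TILE) = (S_TILE x E_TILE) */
            for (int i = 0; i < S_TILE; i++)
                for (int j = 0; j < E_TILE; j++) {
                    float acc = 0.0f;
                    for (int k = 0; k < HP; k++)
                        acc += a_tile[i * HP + k] * w_tile[k * E_TILE + j];
                    o_tile[i * E_TILE + j] = acc;
                }

            /* write the block back: E_TILE floats per row, destination
             * stride E, S_TILE repetitions */
            for (int i = 0; i < S_TILE; i++)
                memcpy(&mhsa_out[(s_tile * S_TILE + i) * E + e_tile * E_TILE],
                       &o_tile[i * E_TILE],
                       E_TILE * sizeof(float));
        }
    }

    float max_err = 0.0f;
    for (int i = 0; i < S * E; i++) {
        float err = fabsf(mhsa_out[i] - ref[i]);
        if (err > max_err) max_err = err;
    }
    printf("max abs error: %f\n", max_err);
    return 0;
}

The per-row copy of the weight column block mirrors the second snrt_dma_start_2d call (e_tile_size bytes of payload per repetition, source stride attn_embeddings, destination stride e_tile_size, heads * positional_embeddings repetitions), and the write-back loop mirrors the final one; the inner triple loop stands in for gemm_fp32_baseline.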