diff --git a/sw/apps/transformer/src/transformer.h b/sw/apps/transformer/src/transformer.h index 7301a0d5c1..bd09a85973 100644 --- a/sw/apps/transformer/src/transformer.h +++ b/sw/apps/transformer/src/transformer.h @@ -14,46 +14,108 @@ #define BANSHEE 0 #define DEBUG_VERBOSE 0 #define DEBUG 1 +#define BRIEF 1 // enable this define to only compute two iterations of the loops +#define LAYERNORM 0 // enable this define to compute the layernorm +#define LINEAR_1 1 // enable this define to compute the linear layer 1 /** - * @struct Transformer_layer_struct + * @struct transformer_layer_fp64_struct * @brief This structure contains all parameters necessary * for computing a Transformer layer - * @var transformer_layer_struct::seq_len + * @var transformer_layer_fp64_struct::seq_len * Number of input tokens - * @var transformer_layer_struct::embeddings + * @var transformer_layer_fp64_struct::embeddings * Vector size of each input token - * @var transformer_layer_struct::positional_embeddings + * @var transformer_layer_fp64_struct::positional_embeddings * Vector size of each positional embedding - * @var transformer_layer_struct::ifmap + * @var transformer_layer_fp64_struct::ifmap * Pointer to input feature map - * @var transformer_layer_struct::bias + * @var transformer_layer_fp64_struct::bias * Pointer to bias for each head - * @var transformer_layer_struct::weights_q + * @var transformer_layer_fp64_struct::weights_q * Pointer to weights for query - * @var transformer_layer_struct::weights_k + * @var transformer_layer_fp64_struct::weights_k * Pointer to weights for key - * @var transformer_layer_struct::weights_v + * @var transformer_layer_fp64_struct::weights_v * Pointer to weights for value - * @var transformer_layer_struct::ofmap + * @var transformer_layer_fp64_struct::ofmap * Pointer to output feature map - * @var transformer_layer_struct::query + * @var transformer_layer_fp64_struct::query * Pointer to the golden model output for query - * @var transformer_layer_struct::key + * @var transformer_layer_fp64_struct::key * Pointer to the golden model output for key - * @var transformer_layer_struct::value + * @var transformer_layer_fp64_struct::value * Pointer to the golden model output for value */ -typedef struct transformer_layer_struct { +typedef struct transformer_layer_fp64_struct { uint32_t BATCH_SIZE; uint32_t seq_len; + uint32_t S_tile_ln; + uint32_t S_tile_lin1; + uint32_t P_tile_lin1; uint32_t embeddings; uint32_t positional_embeddings; + uint32_t feedforward_len; uint32_t heads; - uint32_t eps; - uint32_t attn_seq_len; - uint32_t attn_embeddings; + double eps; + + double *ifmap; + double *attn_ifmap; + double *bias; + double *weights_q; + double *weights_k; + double *weights_v; + double *weights_o; + double *Q_lin; + double *K_lin; + double *V_lin; + double *O; + double *ofmap; + double *query; + double *key; + double *value; + + precision_t dtype; +} transformer_layer_fp64_t; + +/** + * @struct transformer_layer_fp32_struct + * @brief This structure contains all parameters necessary + * for computing a Transformer layer + * @var transformer_layer_fp32_struct::seq_len + * Number of input tokens + * @var transformer_layer_fp32_struct::embeddings + * Vector size of each input token + * @var transformer_layer_fp32_struct::positional_embeddings + * Vector size of each positional embedding + * @var transformer_layer_fp32_struct::ifmap + * Pointer to input feature map + * @var transformer_layer_fp32_struct::bias + * Pointer to bias for each head + * @var transformer_layer_fp32_struct::weights_q + * Pointer to weights for query + * @var transformer_layer_fp32_struct::weights_k + * Pointer to weights for key + * @var transformer_layer_fp32_struct::weights_v + * Pointer to weights for value + * @var transformer_layer_fp32_struct::ofmap + * Pointer to output feature map + * @var transformer_layer_fp32_struct::query + * Pointer to the golden model output for query + * @var transformer_layer_fp32_struct::key + * Pointer to the golden model output for key + * @var transformer_layer_fp32_struct::value + * Pointer to the golden model output for value + */ + +typedef struct transformer_layer_fp32_struct { + uint32_t BATCH_SIZE; + uint32_t seq_len; + uint32_t embeddings; + uint32_t positional_embeddings; + uint32_t heads; + float eps; float *ifmap; float *attn_ifmap; @@ -62,8 +124,6 @@ typedef struct transformer_layer_struct { float *weights_k; float *weights_v; float *weights_o; - float *attn_weights; - float *attn_out; float *Q_lin; float *K_lin; float *V_lin; @@ -74,20 +134,20 @@ typedef struct transformer_layer_struct { float *value; precision_t dtype; -} transformer_layer_t; +} transformer_layer_fp32_t; /* Debugging Variables*/ -dump_float(debug, 8); -dump_uint(idx, 7); -dump_float(ifmap, 6); -dump_float(weights, 10); // = 0xa -dump_float(value, 12); // = 0xc +// dump_float(debug, 8); +// dump_uint(idx, 7); +// dump_float(ifmap, 6); +// dump_float(weights, 10); // = 0xa +// dump_float(value, 12); // = 0xc /** - * @brief Generic DMA transfer function + * @brief Documentation for the DMA transfer function * @param dst Pointer to destination * @param src Pointer to source - * @param size How many bytes to transfer in total + * @param size How many bytes to transfer in total in one repetition/row * @param dst_stride Stride of the destination, i.e. how many bytes to jump * between two consecutive transfers in the same row of * the destination @@ -97,24 +157,6 @@ dump_float(value, 12); // = 0xc * @param repetitions How many rows to transfer */ -static inline void dma_transfer(void *dst, void *src, uint32_t size, uint32_t dst_stride, - uint32_t src_stride, uint32_t repetitions) { - if (!snrt_is_compute_core()) { - snrt_dma_txid_t txid = - snrt_dma_start_2d( - dst, /* dst */ - src, /* src */ - size, /* size */ - dst_stride, /* dst_stride */ - src_stride, /* src_stride */ - repetitions); /* repetitions */ - - snrt_dma_wait_all(); - } - snrt_cluster_hw_barrier(); -} - - /** * Implementation of the GELU layer */ @@ -166,747 +208,923 @@ static inline void fused_mlp_baseline(float *input, float *output, int32_t ldI, output[s * ldO + f] = transformer_gelu_fp32(acc); } } - } +// Debugging functions +dump_uint(idx, 6); +dump_float(debug, 7); + /** * @brief Transformer layer * - * @param l transformer_layer struct that holds addresses and parameters + * @param l transformer_layer struct that holds addresses and parameters in FP64 * */ -static inline void transformer_layer(transformer_layer_t *const l) { +static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) { uint32_t compute_id = snrt_global_core_idx(); uint32_t num_cores = snrt_cluster_compute_core_num(); - ///////////////////////////////////////////////////////////////////// - ////////////// MULTI-HEAD SELF-ATTENTION LAYER ///////////////// - ///////////////////////////////////////////////////////////////////// - - - uint32_t S_TILE = 24; - uint32_t P_TILE = 4; - uint32_t NUM_S_TILES = l->seq_len / S_TILE; - uint32_t NUM_P_TILES = l->embeddings / P_TILE; - // Matrices for the linear mapping - uint32_t ifmap_size = S_TILE * l->embeddings * sizeof(float); - uint32_t weights_q_tiled_size = l->embeddings * P_TILE * sizeof(float); - uint32_t weights_k_tiled_size = weights_q_tiled_size; - uint32_t weights_v_tiled_size = weights_q_tiled_size; - uint32_t key_matrix = S_TILE * P_TILE * sizeof(float); - uint32_t query_matrix = S_TILE * P_TILE * sizeof(float); - uint32_t value_matrix = S_TILE * P_TILE * sizeof(float); - uint32_t attention_matrix = S_TILE * S_TILE * sizeof(float); - - // here we define the matrices that will be stored back in DRAM - uint32_t q_lin_size = l->seq_len * l->positional_embeddings * sizeof(float); - uint32_t k_lin_size = l->positional_embeddings * l->seq_len * sizeof(float); - uint32_t v_lin_size = l->positional_embeddings * l->seq_len * sizeof(float); - uint32_t output_size = l->seq_len * l->positional_embeddings * sizeof(float); - - void *tcdm_ptr = (float *)snrt_l1_next(); - float *ifmap = tcdm_ptr; - tcdm_ptr += ifmap_size; - float *weights_q = tcdm_ptr; - tcdm_ptr += weights_q_tiled_size; - float *weights_k = tcdm_ptr; - tcdm_ptr += weights_k_tiled_size; - float *weights_v = tcdm_ptr; - tcdm_ptr += weights_v_tiled_size; - float *key = tcdm_ptr; - tcdm_ptr += key_matrix; - float *query = tcdm_ptr; - tcdm_ptr += query_matrix; - float *value = tcdm_ptr; - tcdm_ptr += value_matrix; - float *attention = tcdm_ptr; - tcdm_ptr += attention_matrix; - - float used_memory_kB = (float)((uint64_t)tcdm_ptr - (uint64_t)snrt_l1_next()) / 1024.0f; - - dump_debug(used_memory_kB); - - snrt_cluster_hw_barrier(); - - ///////////////////////////////////////////////// - //// Linear mapping //// - ///////////////////////////////////////////////// - - for (int s_tile = 0; s_tile < NUM_S_TILES; s_tile++) { - dump_idx(s_tile); - // printf("Loaded %d/%d s_tiles\n", s_tile, NUM_S_TILES); - // we load the ifmap for the current s_tile - uint32_t ifmap_offset = s_tile * S_TILE * l->embeddings; - if (!snrt_is_compute_core()) { - write_csr(0x7d0,0); - // load the input feature map: 24 * 768 * 4 = 73728 bytes - snrt_dma_txid_t txid_ifmap = - snrt_dma_start_2d( - ifmap, /* dst */ - l->ifmap + ifmap_offset, /* src */ - S_TILE * sizeof(float), /* size */ - l->embeddings, /* dst_stride */ - l->embeddings, /* src_stride */ - S_TILE); /* repetitions */ - - snrt_dma_wait_all(); - - // print the current ifmap for debugging - if (DEBUG_VERBOSE == 1) { - if (BANSHEE == 0) { - for (int i = 0; i < S_TILE * l->embeddings; i++) { - dump_ifmap(ifmap[i]); - } - } else { - for (int i = 0; i < S_TILE * l->embeddings; i++) { - printf("ifmap[%d] = %f\n", i, ifmap[i]); - } - } - } else if (DEBUG == 1) - { - if (BANSHEE == 0) { - dump_ifmap(ifmap[0]); - dump_ifmap(ifmap[S_TILE * l->embeddings - 1]); - } else { - printf("ifmap[0] = %f\n", ifmap[0]); - printf("ifmap[%d] = %f\n", S_TILE * l->embeddings - 1, ifmap[S_TILE * l->embeddings - 1]); - } - - } - - } - - for (int p_tile = 0; p_tile < NUM_P_TILES; p_tile++) { - dump_idx(p_tile); - // printf("Loaded %d/%d p_tiles\n", p_tile, NUM_P_TILES); - - // we load the weights for the current p_tile - // weights matrix: E x P = 768 x 64 - uint32_t weights_offset = p_tile * P_TILE; + uint32_t seq_len = l->seq_len; + uint32_t embeddings = l->embeddings; + uint32_t positional_embeddings = l->positional_embeddings; + uint32_t heads = l->heads; + uint32_t eps = l->eps; - if(!snrt_is_compute_core()) { - // load the weights for query: 768 * 4 = 3072 bytes - snrt_dma_txid_t txid_weights_q = - snrt_dma_start_2d( - weights_q, /* dst */ - l->weights_q + weights_offset, /* src */ - P_TILE * sizeof(float), /* size */ - P_TILE * sizeof(float), /* dst_stride */ - l->positional_embeddings * sizeof(float), /* src_stride */ - l->embeddings); /* repetitions */ - - // load the weights for key - snrt_dma_txid_t txid_weights_k = - snrt_dma_start_2d( - weights_k, /* dst */ - l->weights_k + weights_offset, /* src */ - P_TILE * sizeof(float), /* size */ - P_TILE * sizeof(float), /* dst_stride */ - l->positional_embeddings * sizeof(float), /* src_stride */ - l->embeddings); /* repetitions */ - - // load the weights for value - snrt_dma_txid_t txid_weights_v = + ///////////////////////////////////////////////////////////////////// + ////////////// MULTI-HEAD SELF-ATTENTION BLOCK ///////////////// + /////////////////////////////////////////////////////////////////// + + ///////////////////// Layer 1: LayerNorm ///////////////////// + // Input size (S x E) // + // Output size (S x E) // + // S is tiled into Row Blocks of B_r rows // + ////////////////////////////////////////////////////////////////// + + if(LAYERNORM == 1) { + // compute the tiling parameters + uint32_t B_r = l->S_tile_ln; // number of rows per row block + uint32_t T_r = l->seq_len / B_r; // number of row blocks + + dump_idx(B_r); + dump_idx(T_r); + + // compute the size of the matrices + uint32_t ifmap_tcdm = B_r * embeddings * sizeof(double); + + // allocate memory in TCDM + void *tcdm_ptr = (double *)snrt_l1_next(); + double *ifmap = tcdm_ptr; + tcdm_ptr += ifmap_tcdm; + + uint32_t start_loop_outer = snrt_mcycle(); + for (int t_r = 0; t_r < T_r; t_r++) { + // dump_index(t_r); + uint32_t ifmap_offset = t_r * B_r * embeddings; + if (!snrt_is_compute_core()) { + // load the input feature map: B_r * E + // DMA transfer the ifmap into the cluster TCDM + uint32_t start_dma = snrt_mcycle(); + snrt_dma_txid_t txid_ifmap = snrt_dma_start_2d( - weights_v, /* dst */ - l->weights_v + weights_offset, /* src */ - P_TILE * sizeof(float), /* size */ - P_TILE * sizeof(float), /* dst_stride */ - l->positional_embeddings * sizeof(float), /* src_stride */ - l->embeddings); /* repetitions */ - - snrt_dma_wait_all(); - - if (DEBUG_VERBOSE == 1) { - if (BANSHEE == 0) { - // print the current query weights for debugging - for (int i = 0; i < l->embeddings * P_TILE; i++) { - dump_weights(weights_q[i]); - } - - // print the current key weights for debugging - for (int i = 0; i < l->embeddings * P_TILE; i++) { - dump_weights(weights_k[i]); - } - - // print the current value weights for debugging - for (int i = 0; i < l->embeddings * P_TILE; i++) { - dump_weights(weights_v[i]); - } - } else { - // print the current query weights for debugging - for (int i = 0; i < l->embeddings * P_TILE; i++) { - printf("weights_q[%d] = %f\n", i, weights_q[i]); - } - - // print the current key weights for debugging - for (int i = 0; i < l->embeddings * P_TILE; i++) { - printf("weights_k[%d] = %f\n", i, weights_k[i]); - } - - // print the current value weights for debugging - for (int i = 0; i < l->embeddings * P_TILE; i++) { - printf("weights_v[%d] = %f\n", i, weights_v[i]); - } - } - } else if (DEBUG == 1) - { - if (BANSHEE == 0) { - dump_weights(weights_q[0]); - dump_weights(weights_q[l->embeddings * P_TILE - 1]); - dump_weights(weights_k[0]); - dump_weights(weights_k[l->embeddings * P_TILE - 1]); - dump_weights(weights_v[0]); - dump_weights(weights_v[l->embeddings * P_TILE - 1]); - } else { - printf("weights_q[0] = %f\n", weights_q[0]); - printf("weights_q[%d] = %f\n", l->embeddings * P_TILE - 1, weights_q[l->embeddings * P_TILE - 1]); - printf("weights_k[0] = %f\n", weights_k[0]); - printf("weights_k[%d] = %f\n", l->embeddings * P_TILE - 1, weights_k[l->embeddings * P_TILE - 1]); - printf("weights_v[0] = %f\n", weights_v[0]); - printf("weights_v[%d] = %f\n", l->embeddings * P_TILE - 1, weights_v[l->embeddings * P_TILE - 1]); - } - } - - + ifmap, /* dst */ + l->ifmap + ifmap_offset, /* src */ + embeddings * sizeof(double), /* size */ + embeddings * sizeof(double), /* dst_stride */ + embeddings * sizeof(double), /* src_stride */ + B_r); /* repetitions */ + + snrt_dma_wait_all(); + uint32_t end_dma = snrt_mcycle(); + + // dump the ifmap for debugging + // for (int i = 0; i < B_r * embeddings; i++) { + // dump_debug(ifmap[i]); + // } } snrt_cluster_hw_barrier(); - // compute the query matrix - // void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, double* A, - // uint32_t ldA, uint32_t ta, float* B, uint32_t ldB, - // uint32_t tb, float* C, uint32_t ldC, float ALPHA) - if (snrt_is_compute_core()) { // determine the row offset for the current compute core - // every core computes S_TILE / num_cores rows - // and jumps by S_TILE rows in the next iteration - uint32_t row_offset = (s_tile * S_TILE + (S_TILE / num_cores) * compute_id) * l->embeddings; - - // compute a tile of the query matrix - // ifmap (S_TILE x E) * weights_q (E x P) = query (S_TILE x P) - if (p_tile == 0) { - gemm_fp32_baseline(S_TILE / num_cores, P_TILE, l->embeddings, - &ifmap[row_offset], l->embeddings, 0, weights_q, - P_TILE, 0, &query[row_offset], P_TILE, 0.0f); - } else { - gemm_fp32_baseline(S_TILE / num_cores, P_TILE, l->embeddings, - &ifmap[row_offset], l->embeddings, 0, weights_q, - P_TILE, 0, &query[row_offset], P_TILE, 1.0f); - } - - // print the current query matrix for debugging - if (DEBUG_VERBOSE == 1) { - if (BANSHEE == 0) { - for (int i = 0; i < S_TILE * P_TILE; i++) { - dump_query(query[i]); - } - } else { - for (int i = 0; i < S_TILE * P_TILE; i++) { - printf("query[%d] = %f\n", i, query[i]); - } - } - } else if (DEBUG == 1) - { - if (BANSHEE == 0) { - float query_start = query[row_offset + 0]; - float query_end = query[S_TILE * P_TILE - 1]; - dump_query(query_start); - dump_query(query_end); - } else { - printf("query[0] = %f\n", query[row_offset + 0]); - printf("query[%d] = %f\n", S_TILE * P_TILE - 1, query[S_TILE * P_TILE - 1]); - } - } - - - // compute the key matrix - if (p_tile == 0) { - gemm_fp32_baseline(S_TILE / num_cores, P_TILE, l->embeddings, - &ifmap[row_offset], l->embeddings, 0, weights_k, - P_TILE, 0, &key[row_offset], P_TILE, 0.0f); - } else { - gemm_fp32_baseline(S_TILE / num_cores, P_TILE, l->embeddings, - &ifmap[row_offset], l->embeddings, 0, weights_k, - P_TILE, 0, &key[row_offset], P_TILE, 1.0f); - } - - // print the current key matrix for debugging - if (DEBUG_VERBOSE == 1) { - if (BANSHEE == 0) { - for (int i = 0; i < S_TILE * P_TILE; i++) { - dump_weights(key[i]); - } - } else { - for (int i = 0; i < S_TILE * P_TILE; i++) { - printf("key[%d] = %f\n", i, key[i]); - } - } - } else if (DEBUG == 1) - { - if (BANSHEE == 0) { - float key_start = key[row_offset + 0]; - float key_end = key[S_TILE * P_TILE - 1]; - dump_weights(key_start); - dump_weights(key_end); - } else { - printf("key[0] = %f\n", key[0]); - printf("key[%d] = %f\n", S_TILE * P_TILE - 1, key[S_TILE * P_TILE - 1]); - } - } - - - // compute the value matrix - if (p_tile == 0) { - gemm_fp32_baseline(S_TILE / num_cores, P_TILE, l->embeddings, - &ifmap[row_offset], l->embeddings, 0, weights_v, - P_TILE, 0, &value[row_offset], P_TILE, 0.0f); - } else { - gemm_fp32_baseline(S_TILE / num_cores, P_TILE, l->embeddings, - &ifmap[row_offset], l->embeddings, 0, weights_v, - P_TILE, 0, &value[row_offset], P_TILE, 1.0f); - } - - // print the current value matrix for debugging - if (DEBUG_VERBOSE == 1) { - if (BANSHEE == 0) { - for (int i = 0; i < S_TILE * P_TILE; i++) { - dump_value(value[i]); - } - } else { - for (int i = 0; i < S_TILE * P_TILE; i++) { - printf("value[%d] = %f\n", i, value[i]); - } - } - } else if (DEBUG == 1) - { - if (BANSHEE == 0) { - float value_start = value[row_offset + 0]; - float value_end = value[S_TILE * P_TILE - 1]; - dump_value(value_start); - dump_value(value_end); - } else { - printf("value[0] = %f\n", value[0]); - printf("value[%d] = %f\n", S_TILE * P_TILE - 1, value[S_TILE * P_TILE - 1]); - } - } - - - + uint32_t row_offset = (B_r / num_cores) * compute_id * embeddings; + uint32_t ldI = embeddings; + uint32_t start_layernom = snrt_mcycle(); + transformer_layernorm_fp64(&ifmap[row_offset], ldI, B_r / num_cores, embeddings, eps); + uint32_t end_layernom = snrt_mcycle(); } else { - + snrt_cluster_hw_barrier(); } - snrt_cluster_hw_barrier(); + } + uint32_t end_loop_outer = snrt_mcycle(); - // Write back the tiled matrices to DRAM - if (!snrt_is_compute_core()) { - // printf ("Writing back the tiled matrices to DRAM\n"); - // determine the current offset in the output matrices - uint32_t q_lin_offset = s_tile * S_TILE * l->positional_embeddings + p_tile * P_TILE; + snrt_cluster_hw_barrier(); - // write back the query matrix - snrt_dma_txid_t txid_q_lin = - snrt_dma_start_2d( - &l->Q_lin[q_lin_offset], /* dst */ - query, /* src */ - P_TILE * sizeof(float), /* size */ - l->positional_embeddings * sizeof(float), /* dst_stride */ - P_TILE * sizeof(float), /* src_stride */ - S_TILE ); /* repetitions */ - - // write back the key matrix - snrt_dma_txid_t txid_k_lin = - snrt_dma_start_2d( - &l->K_lin[q_lin_offset], /* dst */ - key, /* src */ - P_TILE * sizeof(float), /* size */ - l->positional_embeddings * sizeof(float), /* dst_stride */ - P_TILE * sizeof(float), /* src_stride */ - S_TILE ); /* repetitions */ - - // write back the value matrix - snrt_dma_txid_t txid_v_lin = - snrt_dma_start_2d( - &l->V_lin[q_lin_offset], /* dst */ - value, /* src */ - P_TILE * sizeof(float), /* size */ - l->positional_embeddings * sizeof(float), /* dst_stride */ - P_TILE * sizeof(float), /* src_stride */ - S_TILE ); /* repetitions */ + } - snrt_dma_wait_all(); - - } + snrt_global_barrier(); +} - } +/** + * @brief Transformer layer + * + * @param l transformer_layer struct that holds addresses and parameters in FP32 + * + */ +static inline void transformer_layer_fp32(transformer_layer_fp32_t *const l) { + uint32_t compute_id = snrt_global_core_idx(); + uint32_t num_cores = snrt_cluster_compute_core_num(); + ///////////////////////////////////////////////////////////////////// + ////////////// MULTI-HEAD SELF-ATTENTION BLOCK ///////////////// + /////////////////////////////////////////////////////////////////// + + ///////////////////// Layer 1: LayerNorm ///////////////////// + // Input size (S x E) // + // Output size (S x E) // + // S is tiled into Row Blocks of B_r rows // + ////////////////////////////////////////////////////////////////// + uint32_t seq_len = l->seq_len; + uint32_t embeddings = l->embeddings; + uint32_t positional_embeddings = l->positional_embeddings; + uint32_t heads = l->heads; + uint32_t eps = l->eps; + + // compute the tiling parameters + uint32_t B_r; // number of rows per row block + uint32_t T_r; // number of row blocks + if (BRIEF == 0) { + B_r = 40; + T_r = l->seq_len / B_r; + } else { + T_r = 2; + B_r = 16; + embeddings = 48; } - snrt_cluster_hw_barrier(); - - ///////////////////////////////////////////////// - //// FlashAttention-2 //// - ///////////////////////////////////////////////// - - int32_t B_r = 64; // Row block size - int32_t B_c = 64; // Column block size - - // Compute the storage size of the matrices - uint32_t RowBlock = B_r * l->positional_embeddings * sizeof(float); - uint32_t ColBlock = B_c * l->positional_embeddings * sizeof(float); - - // Reset the TCDM pointer - tcdm_ptr = (float *)snrt_l1_next(); - float *Q_tile = tcdm_ptr; - tcdm_ptr += RowBlock; - float *K_tile = tcdm_ptr; - tcdm_ptr += ColBlock; - float *V_tile = tcdm_ptr; - tcdm_ptr += ColBlock; - float *A_tile = tcdm_ptr; - tcdm_ptr += B_r * B_c * sizeof(float); - float *P_tilde = tcdm_ptr; - tcdm_ptr += B_r * B_c * sizeof(float); - float *O_tile = tcdm_ptr; - tcdm_ptr += B_r * l->positional_embeddings * sizeof(float); - float *m_i = tcdm_ptr; - tcdm_ptr += B_r; - float *m_i_prev = tcdm_ptr; - tcdm_ptr += B_r; - float *l_j = tcdm_ptr; - tcdm_ptr += B_r; - float *shifted_exp = tcdm_ptr; - tcdm_ptr += B_r; - - // compute memory usage - used_memory_kB = (float)((uint64_t)tcdm_ptr - (uint64_t)snrt_l1_next()) / 1024.0f; - dump_debug(used_memory_kB); - - int32_t T_r = l->seq_len / B_r; // Number of row blocks - int32_t T_c = l->seq_len / B_c; // Number of column blocks - - for (int i = 0; i < T_r; i++) { - dump_idx(i); - // load the Q_tile + // compute the size of the matrices + uint32_t ifmap_tcdm = B_r * embeddings * sizeof(float); + + // allocate memory in TCDM + void *tcdm_ptr = (float *)snrt_l1_next(); + float *ifmap = tcdm_ptr; + tcdm_ptr += ifmap_tcdm; + + uint32_t start_loop_outer = snrt_mcycle(); + for (int t_r = 0; t_r < T_r; t_r++) { + // dump_index(t_r); + uint32_t ifmap_offset = t_r * B_r * embeddings; + // DMA transfer the ifmap into the cluster TCDM + uint32_t start_dma = snrt_mcycle(); if (!snrt_is_compute_core()) { - // printf("Loading Q_tile %d/%d\n", i, T_r); - snrt_dma_txid_t txid_Q_tile = + // load the input feature map: B_r * E + snrt_dma_txid_t txid_ifmap = snrt_dma_start_2d( - Q_tile, /* dst */ - &l->Q_lin[i * B_r * l->positional_embeddings], /* src */ - l->positional_embeddings * sizeof(float), /* size */ - l->positional_embeddings * sizeof(float), /* dst_stride */ - l->positional_embeddings * sizeof(float), /* src_stride */ - B_r); /* repetitions */ - - snrt_dma_wait_all(); - } - - // Initialize the SoftMax statistics - for (int r = 0; r < B_r / num_cores; r++) { - m_i[r + compute_id * (B_r / num_cores)] = -INFINITY; - m_i_prev[r + compute_id * (B_r / num_cores)] = -INFINITY; - l_j[r + compute_id * (B_r / num_cores)] = 0.0f; - } - - snrt_cluster_hw_barrier(); - - for (int j = 0; j < T_c; j++) { - dump_idx(j); - // load the K_tile and V_tile - if (!snrt_is_compute_core()) { - // printf("Loading K_tile %d/%d\n", j, T_c); - snrt_dma_txid_t txid_K_tile = - snrt_dma_start_2d( - K_tile, /* dst */ - &l->K_lin[j * B_c * l->positional_embeddings], /* src */ - l->positional_embeddings * sizeof(float), /* size */ - l->positional_embeddings * sizeof(float), /* dst_stride */ - l->positional_embeddings * sizeof(float), /* src_stride */ - B_c); /* repetitions */ - - snrt_dma_txid_t txid_V_tile = - snrt_dma_start_2d( - V_tile, /* dst */ - &l->V_lin[j * B_c * l->positional_embeddings], /* src */ - l->positional_embeddings * sizeof(float), /* size */ - l->positional_embeddings * sizeof(float), /* dst_stride */ - l->positional_embeddings * sizeof(float), /* src_stride */ - B_c); /* repetitions */ + ifmap, /* dst */ + l->ifmap + ifmap_offset, /* src */ + embeddings * sizeof(float), /* size */ + embeddings * sizeof(float), /* dst_stride */ + embeddings * sizeof(float), /* src_stride */ + B_r); /* repetitions */ snrt_dma_wait_all(); - } - - snrt_cluster_hw_barrier(); - - // Compute the A_tile - if (snrt_is_compute_core()) { - // printf("Computing A_tile %d/%d\n", i, T_r); - // determine the row offset for the current compute core - // every core computes B_r / num_cores rows - // and jumps by B_r rows in the next iteration - uint32_t row_offset = (i * B_r + (B_r / num_cores) * compute_id) * l->positional_embeddings; - // determine the column offset for the current compute core - // every core computes B_c / num_cores columns - // and jumps by B_c columns in the next iteration - uint32_t col_offset = (j * B_c + (B_c / num_cores) * compute_id) * l->positional_embeddings; - - // compute a tile of the attention matrix - // A_tile = Q_tile * K_tile^T (B_r x B_c) - // we parallelize the rows of Q_tile - gemm_fp32_baseline(B_r / num_cores, B_c, l->positional_embeddings, - &Q_tile[row_offset], l->positional_embeddings, 0, K_tile, - l->positional_embeddings, 1, &A_tile[compute_id * (B_r / num_cores) * B_c], B_c, 0.0f); - - // print the current attention matrix for debugging - if (DEBUG_VERBOSE == 1) { - if (BANSHEE == 0) { - for (int i = 0; i < B_r * B_c; i++) { - dump_value(A_tile[i]); - } - } else { - for (int i = 0; i < B_r * B_c; i++) { - printf("A_tile[%d] = %f\n", i, A_tile[i]); - } - } - } else if (DEBUG == 1) - { - if (BANSHEE == 0) { - float attention_start = A_tile[0]; - float attention_end = A_tile[B_r * B_c - 1]; - dump_value(attention_start); - dump_value(attention_end); - } else { - printf("A_tile[0] = %f\n", A_tile[0]); - printf("A_tile[%d] = %f\n", B_r * B_c - 1, A_tile[B_r * B_c - 1]); - } - } - - // determine the row max of the rows of A_tile (B_r x B_c) - // we parallelize over the rows of A_tile - int32_t A_offset = compute_id * (B_r / num_cores) * B_c; - for (int k = 0; k < (B_r / num_cores); k++) { - // store the previous row max - m_i_prev[k + compute_id * (B_r / num_cores)] = m_i[k + compute_id * (B_r / num_cores)]; - for (int l = 0; l < B_c; l++) { - if (A_tile[A_offset + k * B_c + l] > m_i[k + compute_id * (B_r / num_cores)]) { - m_i[k + compute_id * (B_r / num_cores)] = A_tile[A_offset + k * B_c + l]; - } - } - } - - // next we determine the shifted exponential of A_tile - // shifted_exp = exp(m_i_prev - m_i) - for (int k = 0; k < (B_r / num_cores); k++) { - shifted_exp[k + compute_id * (B_r / num_cores)] = - expf(m_i_prev[k + compute_id * (B_r / num_cores)] - m_i[k + compute_id * (B_r / num_cores)]); - } - - // we compute the P_tilde matrix (B_r x B_c) - // P_tilde = exp(A_tile - m_i) - int32_t P_tilde_offset = A_offset; - for (int k = 0; k < (B_r / num_cores); k++) { - for (int l = 0; l < B_c; l++) { - P_tilde[P_tilde_offset + k * B_c + l] = - expf(A_tile[A_offset + k * B_c + l] - m_i[k + compute_id * (B_r / num_cores)]); - } - } - - // we scale the row sum by the shifted exponential - // l_i = shifted_exp * l_i - for (int k = 0; k < (B_r / num_cores); k++) { - l_j[k + compute_id * (B_r / num_cores)] = - shifted_exp[k + compute_id * (B_r / num_cores)] * l_j[k + compute_id * (B_r / num_cores)]; - for (int l = 0; l < B_c; l++) { - l_j[k + compute_id * (B_r / num_cores)] += P_tilde[P_tilde_offset + k * B_c + l]; - } - } - // finally we compute the fused softmax - // O_i = diag(shifted_exp)^-1 * O_i_prev + P_tilde * V_tile (B_r x E) - // we parallelize over the rows of O_tile - uint32_t O_offset = compute_id * (B_r / num_cores) * l->positional_embeddings; - for (int k = 0; k < B_r / num_cores; k++) { - for (int p = 0; p < l->positional_embeddings; p++) { - // the current row will be scaled by the shifted exponential - if (i == 0) { - O_tile[O_offset + k * l->positional_embeddings + p] = 0.0f; - } else { - O_tile[O_offset + k * l->positional_embeddings + p] = - (1 / shifted_exp[k + compute_id * (B_r / num_cores)]) * O_tile[O_offset + k * l->positional_embeddings + p]; - } - // we compute the matrix multiplication of P_tilde and V_tile - for (int o = 0; o < B_c; o++) { - O_tile[O_offset + k * (l->positional_embeddings) + p] += - P_tilde[P_tilde_offset + k * B_c + o] * V_tile[o * (l->positional_embeddings) + p]; - } - } - } - - } - } // end of T_c loop + // dump the ifmap for debugging + // for (int i = 0; i < B_r * embeddings; i++) { + // dump_debug(ifmap[i]); + // } + } + uint32_t end_dma = snrt_mcycle(); snrt_cluster_hw_barrier(); - // we scale the O_tile by the final row sum - // O_i = diag(l_i)^-1 * O_i if (snrt_is_compute_core()) { - // printf("Scaling O_tile %d/%d\n", i, T_r); - // we parallelize over the rows of O_tile - uint32_t O_offset = compute_id * (B_r / num_cores) * l->positional_embeddings; - for (int k = 0; k < B_r / num_cores; k++) { - for (int p = 0; p < l->positional_embeddings; p++) { - O_tile[O_offset + k * l->positional_embeddings + p] = - (1 / l_j[k + compute_id * (B_r / num_cores)]) * O_tile[O_offset + k * l->positional_embeddings + p]; - } - } + // determine the row offset for the current compute core + uint32_t row_offset = (B_r / num_cores) * compute_id * embeddings; + uint32_t ldI = embeddings; + uint32_t start_layernom = snrt_mcycle(); + transformer_layernorm_fp32(&ifmap[row_offset], ldI, B_r / num_cores, embeddings, eps); + uint32_t end_layernom = snrt_mcycle(); + } else { + snrt_cluster_hw_barrier(); } - snrt_cluster_hw_barrier(); - - // we write back O_i as the i-th block row of the output matrix O to DRAM - // O_i is og size B_r x E - if (!snrt_is_compute_core()) { - // printf("Writing back O_tile %d/%d\n", i, T_r); - // determine the current offset in the output matrix - uint32_t output_offset = i * B_r * l->positional_embeddings; - - // write back the output matrix - snrt_dma_txid_t txid_output = - snrt_dma_start_2d( - &l->O[output_offset], /* dst */ - O_tile, /* src */ - l->positional_embeddings * sizeof(float), /* size */ - l->positional_embeddings * sizeof(float), /* dst_stride */ - l->positional_embeddings * sizeof(float), /* src_stride */ - B_r); /* repetitions */ - - snrt_dma_wait_all(); - } - } // end of T_r loop - - // TODO: This step requires a GLOBAL barrier!!! - snrt_cluster_hw_barrier(); - - // Self Attention - // (S x HP) = (2048 x 768) - // Every cluster computes ceil(2048 / num_clusters) = 86 rows - // and every core computes ceil(86 / num_cores) = 11 rows - // (HP x E) = (768 x 768) - // for simplicity we we only perform two iterations - // of the inner and outer loops - - // define the tile sizes - int32_t s_tile_size = 8; - int32_t e_tile_size = 32; - - int32_t num_s_tiles = l->attn_seq_len / s_tile_size; - int32_t num_e_tiles = l->attn_embeddings / e_tile_size; - - // A_OUT (S x HP) - // ATTN_WEIGHTS (HP x E) - // MSHA_OUT (S x E) - int32_t a_out_size = s_tile_size * l->positional_embeddings * l->heads; - int32_t attn_weights_size = l->heads * l->positional_embeddings * l->attn_embeddings; - int32_t mhsa_out_size = s_tile_size * e_tile_size; - - // reset the TCDM pointer - tcdm_ptr = (float *)snrt_l1_next(); - float *a_out = tcdm_ptr; - tcdm_ptr += a_out_size; - float *attn_weights = tcdm_ptr; - tcdm_ptr += attn_weights_size; - float *mhsa_out = tcdm_ptr; - tcdm_ptr += mhsa_out_size; - - // compute memory usage - used_memory_kB = (float)((uint64_t)tcdm_ptr - (uint64_t)snrt_l1_next()) / 1024.0f; - dump_debug(used_memory_kB); - - for (int s_tile = 0; s_tile < num_s_tiles; s_tile++) { - // load the a_out tile from DRAM - if (!snrt_is_compute_core()) { - // compute the offset from which to load - uint32_t a_out_offset = s_tile * s_tile_size * l->positional_embeddings * l->heads; - - snrt_dma_txid_t txid_a_out = - snrt_dma_start_2d( - a_out, /* dst */ - &l->attn_ifmap[a_out_offset], /* src */ - l->positional_embeddings * l->heads * sizeof(float),/* size */ - l->positional_embeddings * l->heads * sizeof(float),/* dst_stride */ - l->positional_embeddings * l->heads * sizeof(float),/* src_stride */ - s_tile); /* repetitions */ - snrt_dma_wait_all(); - } + } + uint32_t end_loop_outer = snrt_mcycle(); - snrt_cluster_hw_barrier(); + ///////////////////// Layer 2: Linear ///////////////////// + // Input size (S x E) // + // Weights size (E x P) // + // Output size (S x P) // + // S is tiled into Row Blocks of B_r rows // + // P is tiled into Column Blocks of B_c columns // + /////////////////////////////////////////////////////////////// - for (int e_tile = 0; e_tile < num_e_tiles; e_tile++) { - // we load the column block of the attention weights - if (!snrt_is_compute_core()) { - // compute the offset from which to load - uint32_t attn_weights_offset = e_tile * e_tile_size * l->positional_embeddings * l->attn_embeddings; - snrt_dma_txid_t txid_attn_weights = - snrt_dma_start_2d( - attn_weights, /* dst */ - &l->attn_weights[attn_weights_offset], /* src */ - e_tile * sizeof(float), /* size */ - l->attn_embeddings * sizeof(float), /* dst_stride */ - e_tile * sizeof(float), /* src_stride */ - l->heads * l->positional_embeddings); /* repetitions */ - snrt_dma_wait_all(); - } + + // uint32_t S_TILE = 24; + // uint32_t P_TILE = 4; + // uint32_t NUM_S_TILES = l->seq_len / S_TILE; + // uint32_t NUM_P_TILES = l->embeddings / P_TILE; + // // Matrices for the linear mapping + // uint32_t ifmap_size = S_TILE * l->embeddings * sizeof(float); + // uint32_t weights_q_tiled_size = l->embeddings * P_TILE * sizeof(float); + // uint32_t weights_k_tiled_size = weights_q_tiled_size; + // uint32_t weights_v_tiled_size = weights_q_tiled_size; + // uint32_t key_matrix = S_TILE * P_TILE * sizeof(float); + // uint32_t query_matrix = S_TILE * P_TILE * sizeof(float); + // uint32_t value_matrix = S_TILE * P_TILE * sizeof(float); + // uint32_t attention_matrix = S_TILE * S_TILE * sizeof(float); + + // // here we define the matrices that will be stored back in DRAM + // uint32_t q_lin_size = l->seq_len * l->positional_embeddings * sizeof(float); + // uint32_t k_lin_size = l->positional_embeddings * l->seq_len * sizeof(float); + // uint32_t v_lin_size = l->positional_embeddings * l->seq_len * sizeof(float); + // uint32_t output_size = l->seq_len * l->positional_embeddings * sizeof(float); + + // void *tcdm_ptr = (float *)snrt_l1_next(); + // float *ifmap = tcdm_ptr; + // tcdm_ptr += ifmap_size; + // float *weights_q = tcdm_ptr; + // tcdm_ptr += weights_q_tiled_size; + // float *weights_k = tcdm_ptr; + // tcdm_ptr += weights_k_tiled_size; + // float *weights_v = tcdm_ptr; + // tcdm_ptr += weights_v_tiled_size; + // float *key = tcdm_ptr; + // tcdm_ptr += key_matrix; + // float *query = tcdm_ptr; + // tcdm_ptr += query_matrix; + // float *value = tcdm_ptr; + // tcdm_ptr += value_matrix; + // float *attention = tcdm_ptr; + // tcdm_ptr += attention_matrix; + + // float used_memory_kB = (float)((uint64_t)tcdm_ptr - (uint64_t)snrt_l1_next()) / 1024.0f; + + // dump_debug(used_memory_kB); + + // snrt_cluster_hw_barrier(); + + // ///////////////////////////////////////////////// + // //// Linear mapping //// + // ///////////////////////////////////////////////// + + // for (int s_tile = 0; s_tile < NUM_S_TILES; s_tile++) { + // dump_idx(s_tile); + // // printf("Loaded %d/%d s_tiles\n", s_tile, NUM_S_TILES); + // // we load the ifmap for the current s_tile + // uint32_t ifmap_offset = s_tile * S_TILE * l->embeddings; + // if (!snrt_is_compute_core()) { + // write_csr(0x7d0,0); + // // load the input feature map: 24 * 768 * 4 = 73728 bytes + // snrt_dma_txid_t txid_ifmap = + // snrt_dma_start_2d( + // ifmap, /* dst */ + // l->ifmap + ifmap_offset, /* src */ + // S_TILE * sizeof(float), /* size */ + // l->embeddings, /* dst_stride */ + // l->embeddings, /* src_stride */ + // S_TILE); /* repetitions */ + + // snrt_dma_wait_all(); + + // // print the current ifmap for debugging + // if (DEBUG_VERBOSE == 1) { + // if (BANSHEE == 0) { + // for (int i = 0; i < S_TILE * l->embeddings; i++) { + // dump_ifmap(ifmap[i]); + // } + // } else { + // for (int i = 0; i < S_TILE * l->embeddings; i++) { + // printf("ifmap[%d] = %f\n", i, ifmap[i]); + // } + // } + // } else if (DEBUG == 1) + // { + // if (BANSHEE == 0) { + // dump_ifmap(ifmap[0]); + // dump_ifmap(ifmap[S_TILE * l->embeddings - 1]); + // } else { + // printf("ifmap[0] = %f\n", ifmap[0]); + // printf("ifmap[%d] = %f\n", S_TILE * l->embeddings - 1, ifmap[S_TILE * l->embeddings - 1]); + // } + + // } + + // } + + // for (int p_tile = 0; p_tile < NUM_P_TILES; p_tile++) { + // dump_idx(p_tile); + // // printf("Loaded %d/%d p_tiles\n", p_tile, NUM_P_TILES); + + // // we load the weights for the current p_tile + // // weights matrix: E x P = 768 x 64 + // uint32_t weights_offset = p_tile * P_TILE; + + // if(!snrt_is_compute_core()) { + // // load the weights for query: 768 * 4 = 3072 bytes + // snrt_dma_txid_t txid_weights_q = + // snrt_dma_start_2d( + // weights_q, /* dst */ + // l->weights_q + weights_offset, /* src */ + // P_TILE * sizeof(float), /* size */ + // P_TILE * sizeof(float), /* dst_stride */ + // l->positional_embeddings * sizeof(float), /* src_stride */ + // l->embeddings); /* repetitions */ + + // // load the weights for key + // snrt_dma_txid_t txid_weights_k = + // snrt_dma_start_2d( + // weights_k, /* dst */ + // l->weights_k + weights_offset, /* src */ + // P_TILE * sizeof(float), /* size */ + // P_TILE * sizeof(float), /* dst_stride */ + // l->positional_embeddings * sizeof(float), /* src_stride */ + // l->embeddings); /* repetitions */ + + // // load the weights for value + // snrt_dma_txid_t txid_weights_v = + // snrt_dma_start_2d( + // weights_v, /* dst */ + // l->weights_v + weights_offset, /* src */ + // P_TILE * sizeof(float), /* size */ + // P_TILE * sizeof(float), /* dst_stride */ + // l->positional_embeddings * sizeof(float), /* src_stride */ + // l->embeddings); /* repetitions */ + + // snrt_dma_wait_all(); + + // if (DEBUG_VERBOSE == 1) { + // if (BANSHEE == 0) { + // // print the current query weights for debugging + // for (int i = 0; i < l->embeddings * P_TILE; i++) { + // dump_weights(weights_q[i]); + // } + + // // print the current key weights for debugging + // for (int i = 0; i < l->embeddings * P_TILE; i++) { + // dump_weights(weights_k[i]); + // } + + // // print the current value weights for debugging + // for (int i = 0; i < l->embeddings * P_TILE; i++) { + // dump_weights(weights_v[i]); + // } + // } else { + // // print the current query weights for debugging + // for (int i = 0; i < l->embeddings * P_TILE; i++) { + // printf("weights_q[%d] = %f\n", i, weights_q[i]); + // } + + // // print the current key weights for debugging + // for (int i = 0; i < l->embeddings * P_TILE; i++) { + // printf("weights_k[%d] = %f\n", i, weights_k[i]); + // } + + // // print the current value weights for debugging + // for (int i = 0; i < l->embeddings * P_TILE; i++) { + // printf("weights_v[%d] = %f\n", i, weights_v[i]); + // } + // } + // } else if (DEBUG == 1) + // { + // if (BANSHEE == 0) { + // dump_weights(weights_q[0]); + // dump_weights(weights_q[l->embeddings * P_TILE - 1]); + // dump_weights(weights_k[0]); + // dump_weights(weights_k[l->embeddings * P_TILE - 1]); + // dump_weights(weights_v[0]); + // dump_weights(weights_v[l->embeddings * P_TILE - 1]); + // } else { + // printf("weights_q[0] = %f\n", weights_q[0]); + // printf("weights_q[%d] = %f\n", l->embeddings * P_TILE - 1, weights_q[l->embeddings * P_TILE - 1]); + // printf("weights_k[0] = %f\n", weights_k[0]); + // printf("weights_k[%d] = %f\n", l->embeddings * P_TILE - 1, weights_k[l->embeddings * P_TILE - 1]); + // printf("weights_v[0] = %f\n", weights_v[0]); + // printf("weights_v[%d] = %f\n", l->embeddings * P_TILE - 1, weights_v[l->embeddings * P_TILE - 1]); + // } + // } + - snrt_cluster_hw_barrier(); + // } - // compute the matrix multiplication - if (snrt_is_compute_core()) { - // compute the row offset for the current compute core - uint32_t a_out_offset = compute_id * (s_tile_size / num_cores) * l->positional_embeddings * l->heads; - uint32_t mhsa_out_offset = (s_tile * s_tile_size + (s_tile_size / num_cores) * compute_id) * e_tile_size; - - // compute a tile of the MHSA output - // A_OUT (S x HP) * ATTN_WEIGHTS (HP x E) = MSHA_OUT (S x E) - // we parallelize the rows of A_OUT - gemm_fp32_baseline(s_tile_size / num_cores, e_tile_size, l->positional_embeddings * l->heads, - &a_out[a_out_offset], l->positional_embeddings * l->heads, 0, attn_weights, - l->positional_embeddings * l->heads, 0, &mhsa_out[mhsa_out_offset], e_tile_size, 0.0f); - } + // snrt_cluster_hw_barrier(); - snrt_cluster_hw_barrier(); + // // compute the query matrix + // // void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, double* A, + // // uint32_t ldA, uint32_t ta, float* B, uint32_t ldB, + // // uint32_t tb, float* C, uint32_t ldC, float ALPHA) - // write back the MHSA output to DRAM - if (!snrt_is_compute_core()) { - // compute the offset from which to load - uint32_t mhsa_out_offset = (s_tile * s_tile_size + (s_tile_size / num_cores) * compute_id) * e_tile_size; + // if (snrt_is_compute_core()) { + // // determine the row offset for the current compute core + // // every core computes S_TILE / num_cores rows + // // and jumps by S_TILE rows in the next iteration + // uint32_t row_offset = (s_tile * S_TILE + (S_TILE / num_cores) * compute_id) * l->embeddings; + + // // compute a tile of the query matrix + // // ifmap (S_TILE x E) * weights_q (E x P) = query (S_TILE x P) + // if (p_tile == 0) { + // gemm_fp32_baseline(S_TILE / num_cores, P_TILE, l->embeddings, + // &ifmap[row_offset], l->embeddings, 0, weights_q, + // P_TILE, 0, &query[row_offset], P_TILE, 0.0f); + // } else { + // gemm_fp32_baseline(S_TILE / num_cores, P_TILE, l->embeddings, + // &ifmap[row_offset], l->embeddings, 0, weights_q, + // P_TILE, 0, &query[row_offset], P_TILE, 1.0f); + // } + + // // print the current query matrix for debugging + // if (DEBUG_VERBOSE == 1) { + // if (BANSHEE == 0) { + // for (int i = 0; i < S_TILE * P_TILE; i++) { + // dump_query(query[i]); + // } + // } else { + // for (int i = 0; i < S_TILE * P_TILE; i++) { + // printf("query[%d] = %f\n", i, query[i]); + // } + // } + // } else if (DEBUG == 1) + // { + // if (BANSHEE == 0) { + // float query_start = query[row_offset + 0]; + // float query_end = query[S_TILE * P_TILE - 1]; + // dump_query(query_start); + // dump_query(query_end); + // } else { + // printf("query[0] = %f\n", query[row_offset + 0]); + // printf("query[%d] = %f\n", S_TILE * P_TILE - 1, query[S_TILE * P_TILE - 1]); + // } + // } + - snrt_dma_txid_t txid_mhsa_out = - snrt_dma_start_2d( - &l->attn_out[mhsa_out_offset], /* dst */ - mhsa_out, /* src */ - e_tile_size * sizeof(float), /* size */ - l->attn_embeddings * sizeof(float), /* dst_stride */ - e_tile_size * sizeof(float), /* src_stride */ - s_tile_size); /* repetitions */ - snrt_dma_wait_all(); - } + // // compute the key matrix + // if (p_tile == 0) { + // gemm_fp32_baseline(S_TILE / num_cores, P_TILE, l->embeddings, + // &ifmap[row_offset], l->embeddings, 0, weights_k, + // P_TILE, 0, &key[row_offset], P_TILE, 0.0f); + // } else { + // gemm_fp32_baseline(S_TILE / num_cores, P_TILE, l->embeddings, + // &ifmap[row_offset], l->embeddings, 0, weights_k, + // P_TILE, 0, &key[row_offset], P_TILE, 1.0f); + // } + + // // print the current key matrix for debugging + // if (DEBUG_VERBOSE == 1) { + // if (BANSHEE == 0) { + // for (int i = 0; i < S_TILE * P_TILE; i++) { + // dump_weights(key[i]); + // } + // } else { + // for (int i = 0; i < S_TILE * P_TILE; i++) { + // printf("key[%d] = %f\n", i, key[i]); + // } + // } + // } else if (DEBUG == 1) + // { + // if (BANSHEE == 0) { + // float key_start = key[row_offset + 0]; + // float key_end = key[S_TILE * P_TILE - 1]; + // dump_weights(key_start); + // dump_weights(key_end); + // } else { + // printf("key[0] = %f\n", key[0]); + // printf("key[%d] = %f\n", S_TILE * P_TILE - 1, key[S_TILE * P_TILE - 1]); + // } + // } + - } - } + // // compute the value matrix + // if (p_tile == 0) { + // gemm_fp32_baseline(S_TILE / num_cores, P_TILE, l->embeddings, + // &ifmap[row_offset], l->embeddings, 0, weights_v, + // P_TILE, 0, &value[row_offset], P_TILE, 0.0f); + // } else { + // gemm_fp32_baseline(S_TILE / num_cores, P_TILE, l->embeddings, + // &ifmap[row_offset], l->embeddings, 0, weights_v, + // P_TILE, 0, &value[row_offset], P_TILE, 1.0f); + // } + + // // print the current value matrix for debugging + // if (DEBUG_VERBOSE == 1) { + // if (BANSHEE == 0) { + // for (int i = 0; i < S_TILE * P_TILE; i++) { + // dump_value(value[i]); + // } + // } else { + // for (int i = 0; i < S_TILE * P_TILE; i++) { + // printf("value[%d] = %f\n", i, value[i]); + // } + // } + // } else if (DEBUG == 1) + // { + // if (BANSHEE == 0) { + // float value_start = value[row_offset + 0]; + // float value_end = value[S_TILE * P_TILE - 1]; + // dump_value(value_start); + // dump_value(value_end); + // } else { + // printf("value[0] = %f\n", value[0]); + // printf("value[%d] = %f\n", S_TILE * P_TILE - 1, value[S_TILE * P_TILE - 1]); + // } + // } + - snrt_cluster_hw_barrier(); - ///////////////////////////////////////////////// - //// MULTI-LAYER PERCEPTRON //// - ///////////////////////////////////////////////// + // } else { + + // } + + // snrt_cluster_hw_barrier(); + + // // Write back the tiled matrices to DRAM + // if (!snrt_is_compute_core()) { + // // printf ("Writing back the tiled matrices to DRAM\n"); + // // determine the current offset in the output matrices + // uint32_t q_lin_offset = s_tile * S_TILE * l->positional_embeddings + p_tile * P_TILE; + + // // write back the query matrix + // snrt_dma_txid_t txid_q_lin = + // snrt_dma_start_2d( + // &l->Q_lin[q_lin_offset], /* dst */ + // query, /* src */ + // P_TILE * sizeof(float), /* size */ + // l->positional_embeddings * sizeof(float), /* dst_stride */ + // P_TILE * sizeof(float), /* src_stride */ + // S_TILE ); /* repetitions */ + + // // write back the key matrix + // snrt_dma_txid_t txid_k_lin = + // snrt_dma_start_2d( + // &l->K_lin[q_lin_offset], /* dst */ + // key, /* src */ + // P_TILE * sizeof(float), /* size */ + // l->positional_embeddings * sizeof(float), /* dst_stride */ + // P_TILE * sizeof(float), /* src_stride */ + // S_TILE ); /* repetitions */ + + // // write back the value matrix + // snrt_dma_txid_t txid_v_lin = + // snrt_dma_start_2d( + // &l->V_lin[q_lin_offset], /* dst */ + // value, /* src */ + // P_TILE * sizeof(float), /* size */ + // l->positional_embeddings * sizeof(float), /* dst_stride */ + // P_TILE * sizeof(float), /* src_stride */ + // S_TILE ); /* repetitions */ + + // snrt_dma_wait_all(); + + // } + + // } + + // } + + // snrt_cluster_hw_barrier(); + + // ///////////////////////////////////////////////// + // //// FlashAttention-2 //// + // ///////////////////////////////////////////////// + + // int32_t B_r = 64; // Row block size + // int32_t B_c = 64; // Column block size + + // // Compute the storage size of the matrices + // uint32_t RowBlock = B_r * l->positional_embeddings * sizeof(float); + // uint32_t ColBlock = B_c * l->positional_embeddings * sizeof(float); + + // // Reset the TCDM pointer + // tcdm_ptr = (float *)snrt_l1_next(); + // float *Q_tile = tcdm_ptr; + // tcdm_ptr += RowBlock; + // float *K_tile = tcdm_ptr; + // tcdm_ptr += ColBlock; + // float *V_tile = tcdm_ptr; + // tcdm_ptr += ColBlock; + // float *A_tile = tcdm_ptr; + // tcdm_ptr += B_r * B_c * sizeof(float); + // float *P_tilde = tcdm_ptr; + // tcdm_ptr += B_r * B_c * sizeof(float); + // float *O_tile = tcdm_ptr; + // tcdm_ptr += B_r * l->positional_embeddings * sizeof(float); + // float *m_i = tcdm_ptr; + // tcdm_ptr += B_r; + // float *m_i_prev = tcdm_ptr; + // tcdm_ptr += B_r; + // float *l_j = tcdm_ptr; + // tcdm_ptr += B_r; + // float *shifted_exp = tcdm_ptr; + // tcdm_ptr += B_r; + + // // compute memory usage + // used_memory_kB = (float)((uint64_t)tcdm_ptr - (uint64_t)snrt_l1_next()) / 1024.0f; + // dump_debug(used_memory_kB); + + // int32_t T_r = l->seq_len / B_r; // Number of row blocks + // int32_t T_c = l->seq_len / B_c; // Number of column blocks + + // for (int i = 0; i < T_r; i++) { + // dump_idx(i); + // // load the Q_tile + // if (!snrt_is_compute_core()) { + // // printf("Loading Q_tile %d/%d\n", i, T_r); + // snrt_dma_txid_t txid_Q_tile = + // snrt_dma_start_2d( + // Q_tile, /* dst */ + // &l->Q_lin[i * B_r * l->positional_embeddings], /* src */ + // l->positional_embeddings * sizeof(float), /* size */ + // l->positional_embeddings * sizeof(float), /* dst_stride */ + // l->positional_embeddings * sizeof(float), /* src_stride */ + // B_r); /* repetitions */ + + // snrt_dma_wait_all(); + // } + + // // Initialize the SoftMax statistics + // for (int r = 0; r < B_r / num_cores; r++) { + // m_i[r + compute_id * (B_r / num_cores)] = -INFINITY; + // m_i_prev[r + compute_id * (B_r / num_cores)] = -INFINITY; + // l_j[r + compute_id * (B_r / num_cores)] = 0.0f; + // } + + // snrt_cluster_hw_barrier(); + + // for (int j = 0; j < T_c; j++) { + // dump_idx(j); + // // load the K_tile and V_tile + // if (!snrt_is_compute_core()) { + // // printf("Loading K_tile %d/%d\n", j, T_c); + // snrt_dma_txid_t txid_K_tile = + // snrt_dma_start_2d( + // K_tile, /* dst */ + // &l->K_lin[j * B_c * l->positional_embeddings], /* src */ + // l->positional_embeddings * sizeof(float), /* size */ + // l->positional_embeddings * sizeof(float), /* dst_stride */ + // l->positional_embeddings * sizeof(float), /* src_stride */ + // B_c); /* repetitions */ + + // snrt_dma_txid_t txid_V_tile = + // snrt_dma_start_2d( + // V_tile, /* dst */ + // &l->V_lin[j * B_c * l->positional_embeddings], /* src */ + // l->positional_embeddings * sizeof(float), /* size */ + // l->positional_embeddings * sizeof(float), /* dst_stride */ + // l->positional_embeddings * sizeof(float), /* src_stride */ + // B_c); /* repetitions */ + + // snrt_dma_wait_all(); + // } + + // snrt_cluster_hw_barrier(); + + // // Compute the A_tile + // if (snrt_is_compute_core()) { + // // printf("Computing A_tile %d/%d\n", i, T_r); + // // determine the row offset for the current compute core + // // every core computes B_r / num_cores rows + // // and jumps by B_r rows in the next iteration + // uint32_t row_offset = (i * B_r + (B_r / num_cores) * compute_id) * l->positional_embeddings; + // // determine the column offset for the current compute core + // // every core computes B_c / num_cores columns + // // and jumps by B_c columns in the next iteration + // uint32_t col_offset = (j * B_c + (B_c / num_cores) * compute_id) * l->positional_embeddings; + + // // compute a tile of the attention matrix + // // A_tile = Q_tile * K_tile^T (B_r x B_c) + // // we parallelize the rows of Q_tile + // gemm_fp32_baseline(B_r / num_cores, B_c, l->positional_embeddings, + // &Q_tile[row_offset], l->positional_embeddings, 0, K_tile, + // l->positional_embeddings, 1, &A_tile[compute_id * (B_r / num_cores) * B_c], B_c, 0.0f); + + // // print the current attention matrix for debugging + // if (DEBUG_VERBOSE == 1) { + // if (BANSHEE == 0) { + // for (int i = 0; i < B_r * B_c; i++) { + // dump_value(A_tile[i]); + // } + // } else { + // for (int i = 0; i < B_r * B_c; i++) { + // printf("A_tile[%d] = %f\n", i, A_tile[i]); + // } + // } + // } else if (DEBUG == 1) + // { + // if (BANSHEE == 0) { + // float attention_start = A_tile[0]; + // float attention_end = A_tile[B_r * B_c - 1]; + // dump_value(attention_start); + // dump_value(attention_end); + // } else { + // printf("A_tile[0] = %f\n", A_tile[0]); + // printf("A_tile[%d] = %f\n", B_r * B_c - 1, A_tile[B_r * B_c - 1]); + // } + // } + + // // determine the row max of the rows of A_tile (B_r x B_c) + // // we parallelize over the rows of A_tile + // int32_t A_offset = compute_id * (B_r / num_cores) * B_c; + // for (int k = 0; k < (B_r / num_cores); k++) { + // // store the previous row max + // m_i_prev[k + compute_id * (B_r / num_cores)] = m_i[k + compute_id * (B_r / num_cores)]; + // for (int l = 0; l < B_c; l++) { + // if (A_tile[A_offset + k * B_c + l] > m_i[k + compute_id * (B_r / num_cores)]) { + // m_i[k + compute_id * (B_r / num_cores)] = A_tile[A_offset + k * B_c + l]; + // } + // } + // } + + // // next we determine the shifted exponential of A_tile + // // shifted_exp = exp(m_i_prev - m_i) + // for (int k = 0; k < (B_r / num_cores); k++) { + // shifted_exp[k + compute_id * (B_r / num_cores)] = + // expf(m_i_prev[k + compute_id * (B_r / num_cores)] - m_i[k + compute_id * (B_r / num_cores)]); + // } + + // // we compute the P_tilde matrix (B_r x B_c) + // // P_tilde = exp(A_tile - m_i) + // int32_t P_tilde_offset = A_offset; + // for (int k = 0; k < (B_r / num_cores); k++) { + // for (int l = 0; l < B_c; l++) { + // P_tilde[P_tilde_offset + k * B_c + l] = + // expf(A_tile[A_offset + k * B_c + l] - m_i[k + compute_id * (B_r / num_cores)]); + // } + // } + + // // we scale the row sum by the shifted exponential + // // l_i = shifted_exp * l_i + // for (int k = 0; k < (B_r / num_cores); k++) { + // l_j[k + compute_id * (B_r / num_cores)] = + // shifted_exp[k + compute_id * (B_r / num_cores)] * l_j[k + compute_id * (B_r / num_cores)]; + // for (int l = 0; l < B_c; l++) { + // l_j[k + compute_id * (B_r / num_cores)] += P_tilde[P_tilde_offset + k * B_c + l]; + // } + // } + + // // finally we compute the fused softmax + // // O_i = diag(shifted_exp)^-1 * O_i_prev + P_tilde * V_tile (B_r x E) + // // we parallelize over the rows of O_tile + // uint32_t O_offset = compute_id * (B_r / num_cores) * l->positional_embeddings; + // for (int k = 0; k < B_r / num_cores; k++) { + // for (int p = 0; p < l->positional_embeddings; p++) { + // // the current row will be scaled by the shifted exponential + // if (i == 0) { + // O_tile[O_offset + k * l->positional_embeddings + p] = 0.0f; + // } else { + // O_tile[O_offset + k * l->positional_embeddings + p] = + // (1 / shifted_exp[k + compute_id * (B_r / num_cores)]) * O_tile[O_offset + k * l->positional_embeddings + p]; + // } + // // we compute the matrix multiplication of P_tilde and V_tile + // for (int o = 0; o < B_c; o++) { + // O_tile[O_offset + k * (l->positional_embeddings) + p] += + // P_tilde[P_tilde_offset + k * B_c + o] * V_tile[o * (l->positional_embeddings) + p]; + // } + // } + // } + + // } + // } // end of T_c loop + + // snrt_cluster_hw_barrier(); + + // // we scale the O_tile by the final row sum + // // O_i = diag(l_i)^-1 * O_i + // if (snrt_is_compute_core()) { + // // printf("Scaling O_tile %d/%d\n", i, T_r); + // // we parallelize over the rows of O_tile + // uint32_t O_offset = compute_id * (B_r / num_cores) * l->positional_embeddings; + // for (int k = 0; k < B_r / num_cores; k++) { + // for (int p = 0; p < l->positional_embeddings; p++) { + // O_tile[O_offset + k * l->positional_embeddings + p] = + // (1 / l_j[k + compute_id * (B_r / num_cores)]) * O_tile[O_offset + k * l->positional_embeddings + p]; + // } + // } + // } + + // snrt_cluster_hw_barrier(); + + // // we write back O_i as the i-th block row of the output matrix O to DRAM + // // O_i is og size B_r x E + // if (!snrt_is_compute_core()) { + // // printf("Writing back O_tile %d/%d\n", i, T_r); + // // determine the current offset in the output matrix + // uint32_t output_offset = i * B_r * l->positional_embeddings; + + // // write back the output matrix + // snrt_dma_txid_t txid_output = + // snrt_dma_start_2d( + // &l->O[output_offset], /* dst */ + // O_tile, /* src */ + // l->positional_embeddings * sizeof(float), /* size */ + // l->positional_embeddings * sizeof(float), /* dst_stride */ + // l->positional_embeddings * sizeof(float), /* src_stride */ + // B_r); /* repetitions */ + + // snrt_dma_wait_all(); + // } + // } // end of T_r loop + + // // TODO: This step requires a GLOBAL barrier!!! + // snrt_cluster_hw_barrier(); + + // // Self Attention + // // (S x HP) = (2048 x 768) + // // Every cluster computes ceil(2048 / num_clusters) = 86 rows + // // and every core computes ceil(86 / num_cores) = 11 rows + // // (HP x E) = (768 x 768) + // // for simplicity we we only perform two iterations + // // of the inner and outer loops + + // // define the tile sizes + // int32_t s_tile_size = 8; + // int32_t e_tile_size = 32; + + // int32_t num_s_tiles = l->attn_seq_len / s_tile_size; + // int32_t num_e_tiles = l->attn_embeddings / e_tile_size; + + // // A_OUT (S x HP) + // // ATTN_WEIGHTS (HP x E) + // // MSHA_OUT (S x E) + // int32_t a_out_size = s_tile_size * l->positional_embeddings * l->heads; + // int32_t attn_weights_size = l->heads * l->positional_embeddings * l->attn_embeddings; + // int32_t mhsa_out_size = s_tile_size * e_tile_size; + + // // reset the TCDM pointer + // tcdm_ptr = (float *)snrt_l1_next(); + // float *a_out = tcdm_ptr; + // tcdm_ptr += a_out_size; + // float *attn_weights = tcdm_ptr; + // tcdm_ptr += attn_weights_size; + // float *mhsa_out = tcdm_ptr; + // tcdm_ptr += mhsa_out_size; + + // // compute memory usage + // used_memory_kB = (float)((uint64_t)tcdm_ptr - (uint64_t)snrt_l1_next()) / 1024.0f; + // dump_debug(used_memory_kB); + + // // for (int s_tile = 0; s_tile < num_s_tiles; s_tile++) { + // // // load the a_out tile from DRAM + // // if (!snrt_is_compute_core()) { + // // // compute the offset from which to load + // // uint32_t a_out_offset = s_tile * s_tile_size * l->positional_embeddings * l->heads; + + // // snrt_dma_txid_t txid_a_out = + // // snrt_dma_start_2d( + // // a_out, /* dst */ + // // &l->attn_ifmap[a_out_offset], /* src */ + // // l->positional_embeddings * l->heads * sizeof(float),/* size */ + // // l->positional_embeddings * l->heads * sizeof(float),/* dst_stride */ + // // l->positional_embeddings * l->heads * sizeof(float),/* src_stride */ + // // s_tile); /* repetitions */ + // // snrt_dma_wait_all(); + // // } + + // // snrt_cluster_hw_barrier(); + + // // for (int e_tile = 0; e_tile < num_e_tiles; e_tile++) { + // // // we load the column block of the attention weights + // // if (!snrt_is_compute_core()) { + // // // compute the offset from which to load + // // uint32_t attn_weights_offset = e_tile * e_tile_size * l->positional_embeddings * l->attn_embeddings; + + // // snrt_dma_txid_t txid_attn_weights = + // // snrt_dma_start_2d( + // // attn_weights, /* dst */ + // // &l->attn_weights[attn_weights_offset], /* src */ + // // e_tile * sizeof(float), /* size */ + // // l->attn_embeddings * sizeof(float), /* dst_stride */ + // // e_tile * sizeof(float), /* src_stride */ + // // l->heads * l->positional_embeddings); /* repetitions */ + // // snrt_dma_wait_all(); + // // } + + // // snrt_cluster_hw_barrier(); + + // // // compute the matrix multiplication + // // if (snrt_is_compute_core()) { + // // // compute the row offset for the current compute core + // // uint32_t a_out_offset = compute_id * (s_tile_size / num_cores) * l->positional_embeddings * l->heads; + // // uint32_t mhsa_out_offset = (s_tile * s_tile_size + (s_tile_size / num_cores) * compute_id) * e_tile_size; + + // // // compute a tile of the MHSA output + // // // A_OUT (S x HP) * ATTN_WEIGHTS (HP x E) = MSHA_OUT (S x E) + // // // we parallelize the rows of A_OUT + // // gemm_fp32_baseline(s_tile_size / num_cores, e_tile_size, l->positional_embeddings * l->heads, + // // &a_out[a_out_offset], l->positional_embeddings * l->heads, 0, attn_weights, + // // l->positional_embeddings * l->heads, 0, &mhsa_out[mhsa_out_offset], e_tile_size, 0.0f); + // // } + + // // snrt_cluster_hw_barrier(); + + // // // write back the MHSA output to DRAM + // // if (!snrt_is_compute_core()) { + // // // compute the offset from which to load + // // uint32_t mhsa_out_offset = (s_tile * s_tile_size + (s_tile_size / num_cores) * compute_id) * e_tile_size; + + // // snrt_dma_txid_t txid_mhsa_out = + // // snrt_dma_start_2d( + // // &l->attn_out[mhsa_out_offset], /* dst */ + // // mhsa_out, /* src */ + // // e_tile_size * sizeof(float), /* size */ + // // l->attn_embeddings * sizeof(float), /* dst_stride */ + // // e_tile_size * sizeof(float), /* src_stride */ + // // s_tile_size); /* repetitions */ + // // snrt_dma_wait_all(); + // // } + + // // } + // // } + + // snrt_cluster_hw_barrier(); + + // ///////////////////////////////////////////////// + // //// MULTI-LAYER PERCEPTRON //// + // ///////////////////////////////////////////////// - // reset the TCDM pointer - tcdm_ptr = (float *)snrt_l1_next(); + // // reset the TCDM pointer + // tcdm_ptr = (float *)snrt_l1_next(); - snrt_global_barrier(); + // snrt_global_barrier(); } \ No newline at end of file