From 0fd7990f2ba5af6a8b7c7915bade1a9fa38df42c Mon Sep 17 00:00:00 2001 From: Viviane Potocnik Date: Mon, 23 Oct 2023 20:34:06 +0200 Subject: [PATCH] transformer: finalize concat layer with DRAM writeback of result --- sw/apps/transformer/src/transformer.h | 158 ++++++++++++-------------- 1 file changed, 72 insertions(+), 86 deletions(-) diff --git a/sw/apps/transformer/src/transformer.h b/sw/apps/transformer/src/transformer.h index eb081e492..0d5311a66 100644 --- a/sw/apps/transformer/src/transformer.h +++ b/sw/apps/transformer/src/transformer.h @@ -254,52 +254,6 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) { // dump_id(num_cores); // dump_id(num_clusters); - // TODO: below code for debugging purposes only!! - // now we will add the partial results together - // in a logarithmic reduction fashion - uint32_t cl_offset = 0x40000; - uint32_t is_active = 0; - uint32_t is_sender = 0; - // num_levels: number of levels in the reduction tree - int num_levels = ceil(log2(num_clusters)); - // stride: distance between two clusters in the reduction tree - for (int level = 0; level < num_levels; level++) { - // determine whether the current cluster is an active cluster - // an active cluster is a cluster that is part of the reduction tree - is_active = (cluster_id % (1 << level)) == 0; - if (is_active == 1) { - // check if the current cluster is a sender or a receiver - if (cluster_id == 0) { - is_sender = 0; - } else { - is_sender = (cluster_id % (1 << (level + 1))) != 0; - } - - dump_id(level); // CSR 5 - dump_idx(is_sender); // CSR 6 - dump_ct(cluster_id); // CSR b - - // if the cluster is a sender we perform a DMA transfer - // if (is_sender == 1) { - // // determine the destination address - // double *data_dst = cluster_ofmap_lin2 - (1 << level) * cl_offset; - // if (!snrt_compute_core_idx()) { - // // printf("cluster_id = %d, level = %d, data_src = %d, data_dst = %d\n", cluster_id, level, data_src, data_dst); - // snrt_dma_txid_t txid = - // snrt_dma_start_1d( - // data_dst, /* dst */ - // ofmap_lin2, /* src */ - // ofmap_tcdm_lin2 * sizeof(double)); /* size */ - - // snrt_dma_wait_all(); - // } - // } - - } - - } - // TODO: end of debugging code - ///////////////////////////////////////////////////////////////////// ////////////// MULTI-HEAD SELF-ATTENTION BLOCK ///////////////// /////////////////////////////////////////////////////////////////// @@ -981,55 +935,87 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) { } else { snrt_cluster_hw_barrier(); } - } + } // end of T_c loop uint32_t end_loop_inner = snrt_mcycle(); - } - uint32_t end_loop_outer = snrt_mcycle(); - snrt_cluster_hw_barrier(); - // now we will add the partial results together - // in a logarithmic reduction fashion - uint32_t cl_offset = 0x40000; - uint32_t is_active = 0; - uint32_t is_sender = 0; - // num_levels: number of levels in the reduction tree - int num_levels = ceil(log2(num_clusters)); - // stride: distance between two clusters in the reduction tree - for (int level = 0; level < num_levels; level++) { - // determine whether the current cluster is an active cluster - // an active cluster is a cluster that is part of the reduction tree - is_active = (cluster_id % (1 << level)) == 0; - if (is_active == 1) { - // check if the current cluster is a sender or a receiver - if (cluster_id == 0) { - is_sender = 0; - } else { - is_sender = (cluster_id % (1 << (level + 1))) != 0; - } + snrt_cluster_hw_barrier(); + // now we will add the partial results together + // in a logarithmic reduction fashion + uint32_t cl_offset = 0x40000; + uint32_t is_active = 0; + uint32_t is_sender = 0; + // num_levels: number of levels in the reduction tree + int num_levels = ceil(log2(num_clusters)); + uint32_t start_reduction = snrt_mcycle(); + for (int level = 0; level < num_levels; level++) { + // determine whether the current cluster is an active cluster + // an active cluster is a cluster that is part of the reduction tree + is_active = (cluster_id % (1 << level)) == 0; + if (is_active == 1) { + // check if the current cluster is a sender or a receiver + if (cluster_id == 0) { + is_sender = 0; + } else { + is_sender = (cluster_id % (1 << (level + 1))) != 0; + } - // if the cluster is a sender we perform a DMA transfer - if (is_sender == 1) { - // determine the destination address - double *data_dst = cluster_ofmap_lin2 - (1 << level) * cl_offset; - if (!snrt_compute_core_idx()) { - // printf("cluster_id = %d, level = %d, data_src = %d, data_dst = %d\n", cluster_id, level, data_src, data_dst); - snrt_dma_txid_t txid = - snrt_dma_start_1d( - data_dst, /* dst */ - ofmap_lin2, /* src */ - ofmap_tcdm_lin2 * sizeof(double)); /* size */ - - snrt_dma_wait_all(); + // if the cluster is a sender we perform a DMA transfer + if (is_sender == 1) { + // determine the destination address + double *data_dst = cluster_ofmap_lin2 - (1 << level) * cl_offset; + if (!snrt_compute_core_idx()) { + // printf("cluster_id = %d, level = %d, data_src = %d, data_dst = %d\n", cluster_id, level, data_src, data_dst); + snrt_dma_txid_t txid = + snrt_dma_start_1d( + data_dst, /* dst */ + ofmap_lin2, /* src */ + ofmap_tcdm_lin2 * sizeof(double)); /* size */ + + snrt_dma_wait_all(); + } } - } - } + snrt_cluster_hw_barrier(); - } + // active clusters that are not a sender perform the addition of the partial tiles + if (is_active == 1 && is_sender == 0) { + // perform the addition + uint32_t row_offset = core_id * (B_r_lin2 / num_cores) * B_c_lin2; + for (int row = 0; row < B_r_lin2 / num_cores; row++) { + for (int col = 0; col < B_c_lin2; col++) { + ofmap_lin2[row_offset + row * B_c_lin2 + col] += cluster_ofmap_lin2[row_offset + row * B_c_lin2 + col]; + } + } + } - } + snrt_cluster_hw_barrier(); + } + + } + uint32_t end_reduction = snrt_mcycle(); + // write back O_lin2 as the i-th block of the output matrix + uint32_t start_dma_write_back = snrt_mcycle(); + if (!snrt_is_compute_core()) { + uint32_t o_lin2_offset = t_r * B_r_lin2 * l->embeddings_lin2; + // printf("o_lin2_offset = %d\n", o_lin2_offset); + snrt_dma_txid_t txid_o_lin2 = + snrt_dma_start_2d( + l->O_lin2 + o_lin2_offset, /* dst */ + ofmap_lin2, /* src */ + l->embeddings_lin2 * sizeof(double), /* size */ + l->embeddings_lin2 * sizeof(double), /* dst_stride */ + B_c_lin2 * sizeof(double), /* src_stride */ + B_r_lin2); /* repetitions */ + + snrt_dma_wait_all(); + } + uint32_t end_dma_write_back = snrt_mcycle(); + } // end of T_r loop + uint32_t end_loop_outer = snrt_mcycle(); + snrt_cluster_hw_barrier(); + } // end of CONCAT } /**