From 3b0a39cf09b15860f81065d299a3ab3443d919f7 Mon Sep 17 00:00:00 2001
From: Luca Colagrande
Date: Mon, 13 Nov 2023 12:12:08 +0100
Subject: [PATCH] dnn: Add optimized FusedConcatLinear layer

---
 .../src/fused_concat_linear.h                 | 232 +++++++++++++++++-
 1 file changed, 228 insertions(+), 4 deletions(-)

diff --git a/sw/dnn/fused_concat_linear/src/fused_concat_linear.h b/sw/dnn/fused_concat_linear/src/fused_concat_linear.h
index 09a7a9b7bb..57f485a960 100644
--- a/sw/dnn/fused_concat_linear/src/fused_concat_linear.h
+++ b/sw/dnn/fused_concat_linear/src/fused_concat_linear.h
@@ -30,24 +30,248 @@ typedef struct {
     precision_t dtype;
 } fused_concat_linear_layer_t;
 
-static inline int fused_concat_linear_layer(fused_concat_linear_layer_t l) {
+// Baseline FusedConcatLinear: runs concat_layer() on the inputs, then calls
+// gemm() with l.concat_output, l.weights and l.linear_output, and finally
+// waits on a global barrier. Returns the value reported by concat_layer().
+static inline int fused_concat_linear_baseline(fused_concat_linear_layer_t l) {
     // Concat layer
     concat_layer_t concat_layer_cfg = {
         .num_inputs = l.num_inputs,
         .input_shape = {l.input_shape[0], l.input_shape[1]},
         .inputs = l.inputs,
         .output = l.concat_output,
-        .dtype = l.dtype
-    };
+        .dtype = l.dtype};
     int nerr = concat_layer(concat_layer_cfg);
 
     // Linear layer
     uint32_t m = l.input_shape[0];
     uint32_t k = l.input_shape[1] * l.num_inputs;
     uint32_t n = l.output_shape[1];
-    gemm(l.dtype, 0, 0, 0, 0, m, n, k, 1.0, l.concat_output, k, l.weights, n, 0.0, l.linear_output, n);
+    gemm(l.dtype, 0, 0, 1, 1, 1, snrt_cluster_num(), 0, 0, m, n, k, 1.0,
+         l.concat_output, l.weights, 0.0, l.linear_output);
 
     snrt_global_barrier();
 
     return nerr;
 }
+
+// static inline int fused_concat_linear_optimized(fused_concat_linear_layer_t
+// l) {
+//     uint32_t cluster_id = snrt_cluster_idx();
+//     uint32_t core_id = snrt_cluster_core_idx();
+//     uint32_t num_clusters = snrt_cluster_num();
+//     uint32_t compute_id = snrt_global_core_idx();
+//     uint32_t cluster_core_id = snrt_cluster_core_idx();
+//     uint32_t num_cores = snrt_cluster_compute_core_num();
+//     uint32_t num_heads = l->heads;
+
+//     // compute the tiling parameters
+//     uint32_t B_r_lin2 = l->Br_tile_lin2;  // number of rows per row block
+//     uint32_t B_c_lin2 = l->Bc_tile_lin2;  // number of columns per column
+//     block uint32_t T_r_lin2 = l->seq_len / B_r_lin2;  // number of row blocks
+//     uint32_t T_c_lin2 = l->embeddings_lin2 / B_c_lin2;  // number of column
+//     blocks
+
+//     // every cluster computes the linear layer on a different input tensor
+//     // and afterward we add the partial results together
+
+//     // compute the size of the matrices
+//     uint32_t ifmap_tcdm_lin2 = B_r_lin2 * l->positional_embeddings *
+//     sizeof(double); uint32_t weights_tcdm_lin2 = l->positional_embeddings *
+//     B_c_lin2 * sizeof(double); uint32_t ofmap_tcdm_lin2 = B_r_lin2 * B_c_lin2
+//     * sizeof(double);
+//     // we use below variable for summing up the partial results
+//     uint32_t cluster_ofmap_tcdm_lin2 = ofmap_tcdm_lin2;
+
+//     // here we define the matrices that will be stored back in DRAM
+//     uint32_t O_lin2_size = l->seq_len * l->embeddings_lin2 * sizeof(double);
+
+//     // allocate memory in TCDM
+//     void *tcdm_ptr = (double *)snrt_l1_next();
+//     double *ifmap_lin2 = tcdm_ptr;
+//     tcdm_ptr += ifmap_tcdm_lin2;
+//     double *weights_lin2 = tcdm_ptr;
+//     tcdm_ptr += weights_tcdm_lin2;
+//     double *ofmap_lin2 = tcdm_ptr;
+//     tcdm_ptr += ofmap_tcdm_lin2;
+//     double *cluster_ofmap_lin2 = tcdm_ptr;
+//     tcdm_ptr += cluster_ofmap_tcdm_lin2;
+
+//     // double used_memory_kB = (double)((uint64_t)tcdm_ptr -
+//     (uint64_t)snrt_l1_next()) / 1024.0f;
+//     // dump_debug(used_memory_kB);
+
+//     // determine the column offset of the ifmap for the current cluster
+//     uint32_t cluster_ifmap_offset = cluster_id * l->seq_len *
+//     l->positional_embeddings_fa;
+//     // if (core_id == 0) {
+//     //     dump_idx(cluster_id);
+//     //     dump_idx(cluster_ifmap_offset);
+//     // }
+//     // determine the column offset of the weights for the current cluster
+//     uint32_t cluster_weights_offset = cluster_id *
+//     l->positional_embeddings_fa * l->embeddings_lin2;
+
+//     uint32_t start_loop_outer = snrt_mcycle();
+//     for (int t_r = 0; t_r < T_r_lin2; t_r++) {
+//         uint32_t start_dma = snrt_mcycle();
+//         // Load the ifmap tile
+//         if (!snrt_is_compute_core()) {
+//             uint32_t ifmap_offset = t_r * B_r_lin2 *
+//             l->positional_embeddings_fa + cluster_ifmap_offset;
+//             snrt_dma_txid_t txid_ifmap =
+//                 snrt_dma_start_2d(
+//                     ifmap_lin2, /* dst
+//                     */ l->ifmap_lin2 + ifmap_offset, /*
+//                     src */ l->positional_embeddings_fa * sizeof(double), /*
+//                     size */ l->positional_embeddings_fa * sizeof(double), /*
+//                     dst_stride */ l->positional_embeddings_fa *
+//                     sizeof(double), /* src_stride */ B_r_lin2); /*
+//                     repetitions */
+
+//             snrt_dma_wait_all();
+//         }
+//         uint32_t end_dma = snrt_mcycle();
+
+//         snrt_cluster_hw_barrier();
+
+//         uint32_t start_loop_inner = snrt_mcycle();
+//         for (int t_c = 0; t_c < T_c_lin2; t_c++) {
+//             // weights: P x B_c
+//             uint32_t weights_offset = t_c * B_c_lin2 *
+//             l->positional_embeddings_fa + cluster_weights_offset; uint32_t
+//             start_dma = snrt_mcycle(); if (!snrt_is_compute_core()) {
+//                 // load the weights tile
+//                 snrt_dma_txid_t txid_weights =
+//                     snrt_dma_start_2d(
+//                         weights_lin2, /*
+//                         dst */ l->weights_lin2 + weights_offset, /* src */
+//                         B_c_lin2 * sizeof(double), /*
+//                         size */ B_c_lin2 * sizeof(double), /* dst_stride */
+//                         l->positional_embeddings_fa * sizeof(double), /*
+//                         src_stride */ l->positional_embeddings_fa); /*
+//                         repetitions */
+
+//                 snrt_dma_wait_all();
+//             }
+//             uint32_t end_dma = snrt_mcycle();
+
+//             snrt_cluster_hw_barrier();
+
+//             if (snrt_is_compute_core()) {
+//                 // compute the gemm for the current row block and column
+//                 block uint32_t row_offset = (B_r_lin2 / num_cores) *
+//                 cluster_core_id * l->positional_embeddings_fa;
+//                 // dump_id(cluster_core_id);
+//                 // dump_ct(row_offset);
+//                 uint32_t ofmap_offset = (B_r_lin2 / num_cores) *
+//                 cluster_core_id * B_c_lin2; uint32_t start_gemm =
+//                 snrt_mcycle(); gemm_fp64_baseline(B_r_lin2 / num_cores,
+//                 B_c_lin2, l->positional_embeddings_fa,
+//                                     &ifmap_lin2[row_offset],
+//                                     l->positional_embeddings_fa, 0,
+//                                     weights_lin2,
+//                                     l->positional_embeddings_fa, 0,
+//                                     &ofmap_lin2[ofmap_offset], B_c_lin2,
+//                                     0.0f);
+//                 uint32_t end_gemm = snrt_mcycle();
+
+//                 snrt_cluster_hw_barrier();
+//             } else {
+//                 snrt_cluster_hw_barrier();
+//             }
+//         } // end of T_c loop
+//         uint32_t end_loop_inner = snrt_mcycle();
+
+//         snrt_cluster_hw_barrier();
+//         // now we will add the partial results together
+//         // in a logarithmic reduction fashion
+//         uint32_t cl_offset = 0x40000;
+//         uint32_t is_active = 0;
+//         uint32_t is_sender = 0;
+//         // num_levels: number of levels in the reduction tree
+//         int num_levels = ceil(log2(num_clusters));
+//         uint32_t start_reduction = snrt_mcycle();
+//         for (int level = 0; level < num_levels; level++) {
+//             // determine whether the current cluster is an active cluster
+//             // an active cluster is a cluster that is part of the reduction
+//             tree is_active = (cluster_id % (1 << level)) == 0; if (is_active
+//             == 1) {
+//                 // check if the current cluster is a sender or a receiver
+//                 if (cluster_id == 0) {
+//                     is_sender = 0;
+//                 } else {
+//                     is_sender = (cluster_id % (1 << (level + 1))) != 0;
+//                 }
+
+//                 // if the cluster is a sender we perform a DMA transfer
+//                 if (is_sender == 1) {
+//                     // determine the destination address
+//                     double *data_dst = cluster_ofmap_lin2 - (1 << level) *
+//                     cl_offset; if (!snrt_is_compute_core()) {
+//                         for (int i = 0; i < B_r_lin2 * B_c_lin2; i++) {
+//                             dump_debug(ofmap_lin2[i]);
+//                         }
+//                         // printf("cluster_id = %d, level = %d, data_src =
+//                         %d, data_dst = %d\n", cluster_id, level, data_src,
+//                         data_dst); snrt_dma_txid_t txid =
+//                             snrt_dma_start_1d(
+//                                 data_dst, /* dst */
+//                                 ofmap_lin2, /* src */
+//                                 ofmap_tcdm_lin2 * sizeof(double)); /* size */
+
+//                         snrt_dma_wait_all();
+//                     }
+//                 }
+
+//                 snrt_cluster_hw_barrier();
+
+//                 // active clusters that are not a sender perform the addition
+//                 of the partial tiles
+//                 // if (is_active == 1 && is_sender == 0) {
+//                 //     // perform the addition
+//                 //     uint32_t row_offset = core_id * (B_r_lin2 / num_cores)
+//                 * B_c_lin2;
+//                 //     for (int row = 0; row < B_r_lin2 / num_cores; row++) {
+//                 //         for (int col = 0; col < B_c_lin2; col++) {
+//                 //             ofmap_lin2[row_offset + row * B_c_lin2 + col]
+//                 += cluster_ofmap_lin2[row_offset + row * B_c_lin2 + col];
+//                 //         }
+//                 //     }
+//                 // }
+
+//                 // snrt_cluster_hw_barrier();
+
+//             }
+
+//         }
+//         uint32_t end_reduction = snrt_mcycle();
+
+//         // write back O_lin2 as the i-th block of the output matrix
+//         uint32_t start_dma_write_back = snrt_mcycle();
+//         if (!snrt_is_compute_core()) {
+//             uint32_t o_lin2_offset = t_r * B_r_lin2 * l->embeddings_lin2;
+//             // printf("o_lin2_offset = %d\n", o_lin2_offset);
+//             snrt_dma_txid_t txid_o_lin2 =
+//                 snrt_dma_start_2d(
+//                     l->O_lin2 + o_lin2_offset, /* dst */
+//                     ofmap_lin2, /* src */
+//                     l->embeddings_lin2 * sizeof(double), /* size
+//                     */ l->embeddings_lin2 * sizeof(double), /*
+//                     dst_stride */ B_c_lin2 * sizeof(double), /* src_stride */
+//                     B_r_lin2); /*
+//                     repetitions */
+
+//             snrt_dma_wait_all();
+//         }
+//         uint32_t end_dma_write_back = snrt_mcycle();
+//     } // end of T_r loop
+//     uint32_t end_loop_outer = snrt_mcycle();
+//     snrt_cluster_hw_barrier();
+// }
+
+// Entry point of the FusedConcatLinear layer. Currently dispatches to the
+// baseline implementation and forwards its return value to the caller.
+static inline int fused_concat_linear_layer(fused_concat_linear_layer_t l) {
+    return fused_concat_linear_baseline(l);
+}