Skip to content

Commit

Permalink
dnn: Add optimized FusedConcatLinear layer
Browse files Browse the repository at this point in the history
  • Loading branch information
colluca committed Nov 13, 2023
1 parent a17d2cb commit 3b0a39c
Showing 1 changed file with 223 additions and 4 deletions.
227 changes: 223 additions & 4 deletions sw/dnn/fused_concat_linear/src/fused_concat_linear.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,243 @@ typedef struct {
precision_t dtype;
} fused_concat_linear_layer_t;

/**
 * Baseline implementation of the fused concat + linear layer.
 *
 * Stage 1 concatenates the `l.num_inputs` input matrices (each of shape
 * input_shape[0] x input_shape[1]) into `l.concat_output`. Stage 2 multiplies
 * the concatenated matrix by `l.weights` and writes the result to
 * `l.linear_output`.
 *
 * Note: the visible span was a diff-garbled mix of the pre- and post-commit
 * code (two signatures, two gemm calls); this body is the reconstructed
 * post-commit version.
 *
 * @param l Layer configuration: input count, shapes, buffers, and dtype.
 * @return Error count propagated from the concat stage.
 */
static inline int fused_concat_linear_baseline(fused_concat_linear_layer_t l) {
    // Stage 1: concatenate all inputs into a single
    // [input_shape[0] x (input_shape[1] * num_inputs)] matrix.
    concat_layer_t concat_layer_cfg = {
        .num_inputs = l.num_inputs,
        .input_shape = {l.input_shape[0], l.input_shape[1]},
        .inputs = l.inputs,
        .output = l.concat_output,
        .dtype = l.dtype};
    int nerr = concat_layer(concat_layer_cfg);

    // Stage 2: linear layer as a GEMM, [m x k] @ [k x n] -> [m x n],
    // where k spans the concatenated feature dimension.
    uint32_t m = l.input_shape[0];
    uint32_t k = l.input_shape[1] * l.num_inputs;
    uint32_t n = l.output_shape[1];
    // NOTE(review): gemm's parameter meanings are defined elsewhere; the
    // (1, snrt_cluster_num()) pair presumably distributes the GEMM across
    // all clusters — confirm against the gemm API declaration.
    gemm(l.dtype, 0, 0, 1, 1, 1, snrt_cluster_num(), 0, 0, m, n, k, 1.0,
         l.concat_output, l.weights, 0.0, l.linear_output);

    // All clusters must finish their GEMM tile before the result is valid.
    snrt_global_barrier();

    return nerr;
}

// static inline int fused_concat_linear_optimized(fused_concat_linear_layer_t
// l) {
// uint32_t cluster_id = snrt_cluster_idx();
// uint32_t core_id = snrt_cluster_core_idx();
// uint32_t num_clusters = snrt_cluster_num();
// uint32_t compute_id = snrt_global_core_idx();
// uint32_t cluster_core_id = snrt_cluster_core_idx();
// uint32_t num_cores = snrt_cluster_compute_core_num();
// uint32_t num_heads = l->heads;

// // compute the tiling parameters
// uint32_t B_r_lin2 = l->Br_tile_lin2; // number of rows per row block
// uint32_t B_c_lin2 = l->Bc_tile_lin2; // number of columns per column
// block uint32_t T_r_lin2 = l->seq_len / B_r_lin2; // number of row blocks
// uint32_t T_c_lin2 = l->embeddings_lin2 / B_c_lin2; // number of column
// blocks

// // every cluster computes the linear layer on a different input tensor
// // and afterward we add the partial results together

// // compute the size of the matrices
// uint32_t ifmap_tcdm_lin2 = B_r_lin2 * l->positional_embeddings *
// sizeof(double); uint32_t weights_tcdm_lin2 = l->positional_embeddings *
// B_c_lin2 * sizeof(double); uint32_t ofmap_tcdm_lin2 = B_r_lin2 * B_c_lin2
// * sizeof(double);
// // we use below variable for summing up the partial results
// uint32_t cluster_ofmap_tcdm_lin2 = ofmap_tcdm_lin2;

// // here we define the matrices that will be stored back in DRAM
// uint32_t O_lin2_size = l->seq_len * l->embeddings_lin2 * sizeof(double);

// // allocate memory in TCDM
// void *tcdm_ptr = (double *)snrt_l1_next();
// double *ifmap_lin2 = tcdm_ptr;
// tcdm_ptr += ifmap_tcdm_lin2;
// double *weights_lin2 = tcdm_ptr;
// tcdm_ptr += weights_tcdm_lin2;
// double *ofmap_lin2 = tcdm_ptr;
// tcdm_ptr += ofmap_tcdm_lin2;
// double *cluster_ofmap_lin2 = tcdm_ptr;
// tcdm_ptr += cluster_ofmap_tcdm_lin2;

// // double used_memory_kB = (double)((uint64_t)tcdm_ptr -
// (uint64_t)snrt_l1_next()) / 1024.0f;
// // dump_debug(used_memory_kB);

// // determine the column offset of the ifmap for the current cluster
// uint32_t cluster_ifmap_offset = cluster_id * l->seq_len *
// l->positional_embeddings_fa;
// // if (core_id == 0) {
// // dump_idx(cluster_id);
// // dump_idx(cluster_ifmap_offset);
// // }
// // determine the column offset of the weights for the current cluster
// uint32_t cluster_weights_offset = cluster_id *
// l->positional_embeddings_fa * l->embeddings_lin2;

// uint32_t start_loop_outer = snrt_mcycle();
// for (int t_r = 0; t_r < T_r_lin2; t_r++) {
// uint32_t start_dma = snrt_mcycle();
// // Load the ifmap tile
// if (!snrt_is_compute_core()) {
// uint32_t ifmap_offset = t_r * B_r_lin2 *
// l->positional_embeddings_fa + cluster_ifmap_offset;
// snrt_dma_txid_t txid_ifmap =
// snrt_dma_start_2d(
// ifmap_lin2, /* dst
// */ l->ifmap_lin2 + ifmap_offset, /*
// src */ l->positional_embeddings_fa * sizeof(double), /*
// size */ l->positional_embeddings_fa * sizeof(double), /*
// dst_stride */ l->positional_embeddings_fa *
// sizeof(double), /* src_stride */ B_r_lin2); /*
// repetitions */

// snrt_dma_wait_all();
// }
// uint32_t end_dma = snrt_mcycle();

// snrt_cluster_hw_barrier();

// uint32_t start_loop_inner = snrt_mcycle();
// for (int t_c = 0; t_c < T_c_lin2; t_c++) {
// // weights: P x B_c
// uint32_t weights_offset = t_c * B_c_lin2 *
// l->positional_embeddings_fa + cluster_weights_offset; uint32_t
// start_dma = snrt_mcycle(); if (!snrt_is_compute_core()) {
// // load the weights tile
// snrt_dma_txid_t txid_weights =
// snrt_dma_start_2d(
// weights_lin2, /*
// dst */ l->weights_lin2 + weights_offset, /* src */
// B_c_lin2 * sizeof(double), /*
// size */ B_c_lin2 * sizeof(double), /* dst_stride */
// l->positional_embeddings_fa * sizeof(double), /*
// src_stride */ l->positional_embeddings_fa); /*
// repetitions */

// snrt_dma_wait_all();
// }
// uint32_t end_dma = snrt_mcycle();

// snrt_cluster_hw_barrier();

// if (snrt_is_compute_core()) {
// // compute the gemm for the current row block and column
// block uint32_t row_offset = (B_r_lin2 / num_cores) *
// cluster_core_id * l->positional_embeddings_fa;
// // dump_id(cluster_core_id);
// // dump_ct(row_offset);
// uint32_t ofmap_offset = (B_r_lin2 / num_cores) *
// cluster_core_id * B_c_lin2; uint32_t start_gemm =
// snrt_mcycle(); gemm_fp64_baseline(B_r_lin2 / num_cores,
// B_c_lin2, l->positional_embeddings_fa,
// &ifmap_lin2[row_offset],
// l->positional_embeddings_fa, 0,
// weights_lin2,
// l->positional_embeddings_fa, 0,
// &ofmap_lin2[ofmap_offset], B_c_lin2,
// 0.0f);
// uint32_t end_gemm = snrt_mcycle();

// snrt_cluster_hw_barrier();
// } else {
// snrt_cluster_hw_barrier();
// }
// } // end of T_c loop
// uint32_t end_loop_inner = snrt_mcycle();

// snrt_cluster_hw_barrier();
// // now we will add the partial results together
// // in a logarithmic reduction fashion
// uint32_t cl_offset = 0x40000;
// uint32_t is_active = 0;
// uint32_t is_sender = 0;
// // num_levels: number of levels in the reduction tree
// int num_levels = ceil(log2(num_clusters));
// uint32_t start_reduction = snrt_mcycle();
// for (int level = 0; level < num_levels; level++) {
// // determine whether the current cluster is an active cluster
// // an active cluster is a cluster that is part of the reduction
// tree is_active = (cluster_id % (1 << level)) == 0; if (is_active
// == 1) {
// // check if the current cluster is a sender or a receiver
// if (cluster_id == 0) {
// is_sender = 0;
// } else {
// is_sender = (cluster_id % (1 << (level + 1))) != 0;
// }

// // if the cluster is a sender we perform a DMA transfer
// if (is_sender == 1) {
// // determine the destination address
// double *data_dst = cluster_ofmap_lin2 - (1 << level) *
// cl_offset; if (!snrt_is_compute_core()) {
// for (int i = 0; i < B_r_lin2 * B_c_lin2; i++) {
// dump_debug(ofmap_lin2[i]);
// }
// // printf("cluster_id = %d, level = %d, data_src =
// %d, data_dst = %d\n", cluster_id, level, data_src,
// data_dst); snrt_dma_txid_t txid =
// snrt_dma_start_1d(
// data_dst, /* dst */
// ofmap_lin2, /* src */
// ofmap_tcdm_lin2 * sizeof(double)); /* size */

// snrt_dma_wait_all();
// }
// }

// snrt_cluster_hw_barrier();

// // active clusters that are not a sender perform the addition
// of the partial tiles
// // if (is_active == 1 && is_sender == 0) {
// // // perform the addition
// // uint32_t row_offset = core_id * (B_r_lin2 / num_cores)
// * B_c_lin2;
// // for (int row = 0; row < B_r_lin2 / num_cores; row++) {
// // for (int col = 0; col < B_c_lin2; col++) {
// // ofmap_lin2[row_offset + row * B_c_lin2 + col]
// += cluster_ofmap_lin2[row_offset + row * B_c_lin2 + col];
// // }
// // }
// // }

// // snrt_cluster_hw_barrier();

// }

// }
// uint32_t end_reduction = snrt_mcycle();

// // write back O_lin2 as the i-th block of the output matrix
// uint32_t start_dma_write_back = snrt_mcycle();
// if (!snrt_is_compute_core()) {
// uint32_t o_lin2_offset = t_r * B_r_lin2 * l->embeddings_lin2;
// // printf("o_lin2_offset = %d\n", o_lin2_offset);
// snrt_dma_txid_t txid_o_lin2 =
// snrt_dma_start_2d(
// l->O_lin2 + o_lin2_offset, /* dst */
// ofmap_lin2, /* src */
// l->embeddings_lin2 * sizeof(double), /* size
// */ l->embeddings_lin2 * sizeof(double), /*
// dst_stride */ B_c_lin2 * sizeof(double), /* src_stride */
// B_r_lin2); /*
// repetitions */

// snrt_dma_wait_all();
// }
// uint32_t end_dma_write_back = snrt_mcycle();
// } // end of T_r loop
// uint32_t end_loop_outer = snrt_mcycle();
// snrt_cluster_hw_barrier();
// }

/**
 * Public entry point for the fused concat + linear layer.
 *
 * Currently dispatches to the baseline implementation (the optimized
 * variant above is commented out).
 *
 * @param l Layer configuration.
 * @return Error count from the selected implementation.
 */
static inline int fused_concat_linear_layer(fused_concat_linear_layer_t l) {
    // Fix: previously the callee's result was discarded and control fell off
    // the end of a non-void function — undefined behavior if the caller
    // uses the return value (C11 6.9.1p12).
    return fused_concat_linear_baseline(l);
}

0 comments on commit 3b0a39c

Please sign in to comment.