Skip to content

Commit

Permalink
transformer: finalize concat layer with DRAM writeback of result
Browse files Browse the repository at this point in the history
  • Loading branch information
Viviane Potocnik committed Oct 23, 2023
1 parent 8da7484 commit 0fd7990
Showing 1 changed file with 72 additions and 86 deletions.
158 changes: 72 additions & 86 deletions sw/apps/transformer/src/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,52 +254,6 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {
// dump_id(num_cores);
// dump_id(num_clusters);

// TODO: below code for debugging purposes only!!
// now we will add the partial results together
// in a logarithmic reduction fashion
uint32_t cl_offset = 0x40000;
uint32_t is_active = 0;
uint32_t is_sender = 0;
// num_levels: number of levels in the reduction tree
int num_levels = ceil(log2(num_clusters));
// stride: distance between two clusters in the reduction tree
for (int level = 0; level < num_levels; level++) {
// determine whether the current cluster is an active cluster
// an active cluster is a cluster that is part of the reduction tree
is_active = (cluster_id % (1 << level)) == 0;
if (is_active == 1) {
// check if the current cluster is a sender or a receiver
if (cluster_id == 0) {
is_sender = 0;
} else {
is_sender = (cluster_id % (1 << (level + 1))) != 0;
}

dump_id(level); // CSR 5
dump_idx(is_sender); // CSR 6
dump_ct(cluster_id); // CSR b

// if the cluster is a sender we perform a DMA transfer
// if (is_sender == 1) {
// // determine the destination address
// double *data_dst = cluster_ofmap_lin2 - (1 << level) * cl_offset;
// if (!snrt_compute_core_idx()) {
// // printf("cluster_id = %d, level = %d, data_src = %d, data_dst = %d\n", cluster_id, level, data_src, data_dst);
// snrt_dma_txid_t txid =
// snrt_dma_start_1d(
// data_dst, /* dst */
// ofmap_lin2, /* src */
// ofmap_tcdm_lin2 * sizeof(double)); /* size */

// snrt_dma_wait_all();
// }
// }

}

}
// TODO: end of debugging code

/////////////////////////////////////////////////////////////////////
////////////// MULTI-HEAD SELF-ATTENTION BLOCK /////////////////
///////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -981,55 +935,87 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {
} else {
snrt_cluster_hw_barrier();
}
}
} // end of T_c loop
uint32_t end_loop_inner = snrt_mcycle();
}
uint32_t end_loop_outer = snrt_mcycle();
snrt_cluster_hw_barrier();

// now we will add the partial results together
// in a logarithmic reduction fashion
uint32_t cl_offset = 0x40000;
uint32_t is_active = 0;
uint32_t is_sender = 0;
// num_levels: number of levels in the reduction tree
int num_levels = ceil(log2(num_clusters));
// stride: distance between two clusters in the reduction tree
for (int level = 0; level < num_levels; level++) {
// determine whether the current cluster is an active cluster
// an active cluster is a cluster that is part of the reduction tree
is_active = (cluster_id % (1 << level)) == 0;
if (is_active == 1) {
// check if the current cluster is a sender or a receiver
if (cluster_id == 0) {
is_sender = 0;
} else {
is_sender = (cluster_id % (1 << (level + 1))) != 0;
}
snrt_cluster_hw_barrier();
// now we will add the partial results together
// in a logarithmic reduction fashion
uint32_t cl_offset = 0x40000;
uint32_t is_active = 0;
uint32_t is_sender = 0;
// num_levels: number of levels in the reduction tree
int num_levels = ceil(log2(num_clusters));
uint32_t start_reduction = snrt_mcycle();
for (int level = 0; level < num_levels; level++) {
// determine whether the current cluster is an active cluster
// an active cluster is a cluster that is part of the reduction tree
is_active = (cluster_id % (1 << level)) == 0;
if (is_active == 1) {
// check if the current cluster is a sender or a receiver
if (cluster_id == 0) {
is_sender = 0;
} else {
is_sender = (cluster_id % (1 << (level + 1))) != 0;
}

// if the cluster is a sender we perform a DMA transfer
if (is_sender == 1) {
// determine the destination address
double *data_dst = cluster_ofmap_lin2 - (1 << level) * cl_offset;
if (!snrt_compute_core_idx()) {
// printf("cluster_id = %d, level = %d, data_src = %d, data_dst = %d\n", cluster_id, level, data_src, data_dst);
snrt_dma_txid_t txid =
snrt_dma_start_1d(
data_dst, /* dst */
ofmap_lin2, /* src */
ofmap_tcdm_lin2 * sizeof(double)); /* size */

snrt_dma_wait_all();
// if the cluster is a sender we perform a DMA transfer
if (is_sender == 1) {
// determine the destination address
double *data_dst = cluster_ofmap_lin2 - (1 << level) * cl_offset;
if (!snrt_compute_core_idx()) {
// printf("cluster_id = %d, level = %d, data_src = %d, data_dst = %d\n", cluster_id, level, data_src, data_dst);
snrt_dma_txid_t txid =
snrt_dma_start_1d(
data_dst, /* dst */
ofmap_lin2, /* src */
ofmap_tcdm_lin2 * sizeof(double)); /* size */

snrt_dma_wait_all();
}
}
}

}
snrt_cluster_hw_barrier();

}
// active clusters that are not a sender perform the addition of the partial tiles
if (is_active == 1 && is_sender == 0) {
// perform the addition
uint32_t row_offset = core_id * (B_r_lin2 / num_cores) * B_c_lin2;
for (int row = 0; row < B_r_lin2 / num_cores; row++) {
for (int col = 0; col < B_c_lin2; col++) {
ofmap_lin2[row_offset + row * B_c_lin2 + col] += cluster_ofmap_lin2[row_offset + row * B_c_lin2 + col];
}
}
}

}
snrt_cluster_hw_barrier();

}

}
uint32_t end_reduction = snrt_mcycle();

// write back O_lin2 as the i-th block of the output matrix
uint32_t start_dma_write_back = snrt_mcycle();
if (!snrt_is_compute_core()) {
uint32_t o_lin2_offset = t_r * B_r_lin2 * l->embeddings_lin2;
// printf("o_lin2_offset = %d\n", o_lin2_offset);
snrt_dma_txid_t txid_o_lin2 =
snrt_dma_start_2d(
l->O_lin2 + o_lin2_offset, /* dst */
ofmap_lin2, /* src */
l->embeddings_lin2 * sizeof(double), /* size */
l->embeddings_lin2 * sizeof(double), /* dst_stride */
B_c_lin2 * sizeof(double), /* src_stride */
B_r_lin2); /* repetitions */

snrt_dma_wait_all();
}
uint32_t end_dma_write_back = snrt_mcycle();
} // end of T_r loop
uint32_t end_loop_outer = snrt_mcycle();
snrt_cluster_hw_barrier();
} // end of CONCAT
}

/**
Expand Down

0 comments on commit 0fd7990

Please sign in to comment.