diff --git a/sw/apps/transformer/src/transformer.h b/sw/apps/transformer/src/transformer.h index ddbf70038..7ae0c8c87 100644 --- a/sw/apps/transformer/src/transformer.h +++ b/sw/apps/transformer/src/transformer.h @@ -897,19 +897,42 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) { snrt_dma_wait_all(); - for (int i = 0; i < B_c_lin2 * l->positional_embeddings_fa; i++) { - dump_idx(i + weights_offset); - dump_debug(weights_lin2[i]); - // printf("weights[%d] = %f\n", i + weights_offset, weights_lin2[i]); - } + // for (int i = 0; i < B_c_lin2 * l->positional_embeddings_fa; i++) { + // dump_idx(i + weights_offset); + // dump_debug(weights_lin2[i]); + // // printf("weights[%d] = %f\n", i + weights_offset, weights_lin2[i]); + // } } uint32_t end_dma = snrt_mcycle(); snrt_cluster_hw_barrier(); + + if (snrt_is_compute_core()) { + // compute the gemm for the current row block and column block + uint32_t row_offset = (B_r_lin2 / num_cores) * compute_id * l->positional_embeddings_fa; + uint32_t ofmap_offset = (B_r_lin2 / num_cores) * compute_id * B_c_lin2; + uint32_t start_gemm = snrt_mcycle(); + gemm_fp64_baseline(B_r_lin2 / num_cores, B_c_lin2, l->positional_embeddings_fa, + &ifmap_lin2[row_offset], l->positional_embeddings_fa, 0, + weights_lin2, l->positional_embeddings_fa, 0, + &ofmap_lin2[ofmap_offset], B_c_lin2, 0.0f); + uint32_t end_gemm = snrt_mcycle(); + + snrt_cluster_hw_barrier(); + + for (int i = 0; i < B_r_lin2 / num_cores; i++) { + for (int j = 0; j < B_c_lin2; j++) { + dump_idx(i * B_c_lin2 + j + ofmap_offset); + dump_debug(ofmap_lin2[i * B_c_lin2 + j + ofmap_offset]); + } + } + } else { + snrt_cluster_hw_barrier(); + } } + uint32_t end_loop_inner = snrt_mcycle(); } - - + uint32_t end_loop_outer = snrt_mcycle(); snrt_cluster_hw_barrier(); }