diff --git a/sw/apps/transformer/src/transformer.h b/sw/apps/transformer/src/transformer.h index ff84e7615..0bb06ec3e 100644 --- a/sw/apps/transformer/src/transformer.h +++ b/sw/apps/transformer/src/transformer.h @@ -868,11 +868,13 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) { snrt_dma_wait_all(); - // for (int i = 0; i < B_r_lin2 * l->positional_embeddings_fa; i++) { - // dump_idx(i + ifmap_offset); - // dump_debug(ifmap_lin2[i]); - // // printf("ifmap[%d] = %f\n", i + ifmap_offset, ifmap_lin2[i]); - // } + if (cluster_id == 1) { + for (int i = 0; i < B_r_lin2 * l->positional_embeddings_fa; i++) { + dump_idx(i + ifmap_offset); + dump_debug(ifmap_lin2[i]); + // printf("ifmap[%d] = %f\n", i + ifmap_offset, ifmap_lin2[i]); + } + } } uint32_t end_dma = snrt_mcycle(); @@ -914,22 +916,24 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) { uint32_t row_offset = (B_r_lin2 / num_cores) * compute_id * l->positional_embeddings_fa; uint32_t ofmap_offset = (B_r_lin2 / num_cores) * compute_id * B_c_lin2; uint32_t start_gemm = snrt_mcycle(); - gemm_fp64_baseline(B_r_lin2 / num_cores, B_c_lin2, l->positional_embeddings_fa, - &ifmap_lin2[row_offset], l->positional_embeddings_fa, 0, - weights_lin2, l->positional_embeddings_fa, 0, - &ofmap_lin2[ofmap_offset], B_c_lin2, 0.0f); + if (cluster_id == 1) { + gemm_fp64_baseline(B_r_lin2 / num_cores, B_c_lin2, l->positional_embeddings_fa, + &ifmap_lin2[row_offset], l->positional_embeddings_fa, 0, + weights_lin2, l->positional_embeddings_fa, 0, + &ofmap_lin2[ofmap_offset], B_c_lin2, 0.0f); + } uint32_t end_gemm = snrt_mcycle(); snrt_cluster_hw_barrier(); - if (cluster_id == 0) { - for (int i = 0; i < B_r_lin2 / num_cores; i++) { - for (int j = 0; j < B_c_lin2; j++) { - dump_idx(i * B_c_lin2 + j + ofmap_offset); - dump_debug(ofmap_lin2[i * B_c_lin2 + j + ofmap_offset]); - } - } - } + // if (cluster_id == 0) { + // for (int i = 0; i < B_r_lin2 / num_cores; i++) { + // for (int j = 0; j < B_c_lin2; j++) { + // dump_idx(i * B_c_lin2 + j + ofmap_offset); + // dump_debug(ofmap_lin2[i * B_c_lin2 + j + ofmap_offset]); + // } + // } + // } } else { snrt_cluster_hw_barrier(); } diff --git a/sw/blas/gemm/src/gemm.h b/sw/blas/gemm/src/gemm.h index baab57047..aedd981bb 100644 --- a/sw/blas/gemm/src/gemm.h +++ b/sw/blas/gemm/src/gemm.h @@ -34,8 +34,8 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A, double c0 = BETA * C[m * ldC + n]; for (uint32_t k = 0; k < K; k++) { // dump_index(k + m * ldA); - // dump_gemm(A[k + m * ldA]); - c0 += A[k + m * ldA] * B[k * ldB + n]; + dump_gemm(A[k + m * ldA]); + // c0 += A[k + m * ldA] * B[k * ldB + n]; } C[m * ldC + n] = c0; }