diff --git a/sw/apps/transformer/src/transformer.h b/sw/apps/transformer/src/transformer.h index 44722eee9..ce0abc5ad 100644 --- a/sw/apps/transformer/src/transformer.h +++ b/sw/apps/transformer/src/transformer.h @@ -925,12 +925,12 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) { snrt_cluster_hw_barrier(); - for (int i = 0; i < B_r_lin2 / num_cores; i++) { - for (int j = 0; j < B_c_lin2; j++) { - dump_idx(i * B_c_lin2 + j + ofmap_offset); - dump_debug(ofmap_lin2[i * B_c_lin2 + j + ofmap_offset]); - } - } + // for (int i = 0; i < B_r_lin2 / num_cores; i++) { + // for (int j = 0; j < B_c_lin2; j++) { + // dump_idx(i * B_c_lin2 + j + ofmap_offset); + // dump_debug(ofmap_lin2[i * B_c_lin2 + j + ofmap_offset]); + // } + // } } else { snrt_cluster_hw_barrier(); } @@ -938,8 +938,17 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) { uint32_t end_loop_inner = snrt_mcycle(); } uint32_t end_loop_outer = snrt_mcycle(); - snrt_cluster_hw_barrier(); + + // now we will add the partial results together + // in a logarithmic reduction fashion + + float reduction_depth = log2(num_heads); + dump_debug(reduction_depth); + // round to the next integer (ceiling) + uint32_t reduction_depth_int = (uint32_t)ceil(reduction_depth); + dump_id(reduction_depth_int); + }