Skip to content

Commit

Permalink
transformer: add concat gemm
Browse files: browse the repository at this point in the history
  • Loading branch information
Viviane Potocnik committed Oct 20, 2023
1 parent f83b1dc commit cbc0d29
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions sw/apps/transformer/src/transformer.h
Original file line number | Diff line number | Diff line change
Expand Up
@@ -897,19 +897,42 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {

snrt_dma_wait_all();

for (int i = 0; i < B_c_lin2 * l->positional_embeddings_fa; i++) {
dump_idx(i + weights_offset);
dump_debug(weights_lin2[i]);
// printf("weights[%d] = %f\n", i + weights_offset, weights_lin2[i]);
}
// for (int i = 0; i < B_c_lin2 * l->positional_embeddings_fa; i++) {
// dump_idx(i + weights_offset);
// dump_debug(weights_lin2[i]);
// // printf("weights[%d] = %f\n", i + weights_offset, weights_lin2[i]);
// }
}
uint32_t end_dma = snrt_mcycle();

snrt_cluster_hw_barrier();

if (snrt_is_compute_core()) {
// compute the gemm for the current row block and column block
uint32_t row_offset = (B_r_lin2 / num_cores) * compute_id * l->positional_embeddings_fa;
uint32_t ofmap_offset = (B_r_lin2 / num_cores) * compute_id * B_c_lin2;
uint32_t start_gemm = snrt_mcycle();
gemm_fp64_baseline(B_r_lin2 / num_cores, B_c_lin2, l->positional_embeddings_fa,
&ifmap_lin2[row_offset], l->positional_embeddings_fa, 0,
weights_lin2, l->positional_embeddings_fa, 0,
&ofmap_lin2[ofmap_offset], B_c_lin2, 0.0f);
uint32_t end_gemm = snrt_mcycle();

snrt_cluster_hw_barrier();

for (int i = 0; i < B_r_lin2 / num_cores; i++) {
for (int j = 0; j < B_c_lin2; j++) {
dump_idx(i * B_c_lin2 + j + ofmap_offset);
dump_debug(ofmap_lin2[i * B_c_lin2 + j + ofmap_offset]);
}
}
} else {
snrt_cluster_hw_barrier();
}
}
uint32_t end_loop_inner = snrt_mcycle();
}


uint32_t end_loop_outer = snrt_mcycle();

snrt_cluster_hw_barrier();
}
Expand Down

0 comments on commit cbc0d29

Please sign in to comment.