Skip to content

Commit

Permalink
transformer: switch to cluster core id due to multi cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
Viviane Potocnik committed Oct 20, 2023
1 parent 84a39a0 commit 17dfdc9
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions sw/apps/transformer/src/transformer.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -790,6 +790,7 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {
uint32_t core_id = snrt_cluster_core_idx();
uint32_t num_clusters = snrt_cluster_num();
uint32_t compute_id = snrt_global_core_idx();
+ uint32_t cluster_core_id = snrt_cluster_core_idx();
uint32_t num_cores = snrt_cluster_compute_core_num();
uint32_t num_heads = l->heads;

Expand Down Expand Up @@ -911,10 +912,10 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {

if (snrt_is_compute_core()) {
// compute the gemm for the current row block and column block
- uint32_t row_offset = (B_r_lin2 / num_cores) * compute_id * l->positional_embeddings_fa;
- dump_id(compute_id);
+ uint32_t row_offset = (B_r_lin2 / num_cores) * cluster_core_id * l->positional_embeddings_fa;
+ dump_id(cluster_core_id);
dump_ct(row_offset);
- uint32_t ofmap_offset = (B_r_lin2 / num_cores) * compute_id * B_c_lin2;
+ uint32_t ofmap_offset = (B_r_lin2 / num_cores) * cluster_core_id * B_c_lin2;
uint32_t start_gemm = snrt_mcycle();
if (cluster_id == 1) {
gemm_fp64_baseline(B_r_lin2 / num_cores, B_c_lin2, l->positional_embeddings_fa,
Expand Down

0 comments on commit 17dfdc9

Please sign in to comment.