Skip to content

Commit

Permalink
transformer: test only 2nd cluster
Browse files: browse the repository at this point in the history
  • Loading branch information
Viviane Potocnik committed Oct 20, 2023
1 parent 959d7ce commit b58f63b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 19 deletions.
38 changes: 21 additions & 17 deletions sw/apps/transformer/src/transformer.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -868,11 +868,13 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {

snrt_dma_wait_all();

// for (int i = 0; i < B_r_lin2 * l->positional_embeddings_fa; i++) {
// dump_idx(i + ifmap_offset);
// dump_debug(ifmap_lin2[i]);
// // printf("ifmap[%d] = %f\n", i + ifmap_offset, ifmap_lin2[i]);
// }
if (cluster_id == 1) {
for (int i = 0; i < B_r_lin2 * l->positional_embeddings_fa; i++) {
dump_idx(i + ifmap_offset);
dump_debug(ifmap_lin2[i]);
// printf("ifmap[%d] = %f\n", i + ifmap_offset, ifmap_lin2[i]);
}
}

}
uint32_t end_dma = snrt_mcycle();
Expand Down Expand Up @@ -914,22 +916,24 @@ static inline void transformer_layer_fp64(transformer_layer_fp64_t *const l) {
uint32_t row_offset = (B_r_lin2 / num_cores) * compute_id * l->positional_embeddings_fa;
uint32_t ofmap_offset = (B_r_lin2 / num_cores) * compute_id * B_c_lin2;
uint32_t start_gemm = snrt_mcycle();
gemm_fp64_baseline(B_r_lin2 / num_cores, B_c_lin2, l->positional_embeddings_fa,
&ifmap_lin2[row_offset], l->positional_embeddings_fa, 0,
weights_lin2, l->positional_embeddings_fa, 0,
&ofmap_lin2[ofmap_offset], B_c_lin2, 0.0f);
if (cluster_id == 1) {
gemm_fp64_baseline(B_r_lin2 / num_cores, B_c_lin2, l->positional_embeddings_fa,
&ifmap_lin2[row_offset], l->positional_embeddings_fa, 0,
weights_lin2, l->positional_embeddings_fa, 0,
&ofmap_lin2[ofmap_offset], B_c_lin2, 0.0f);
}
uint32_t end_gemm = snrt_mcycle();

snrt_cluster_hw_barrier();

if (cluster_id == 0) {
for (int i = 0; i < B_r_lin2 / num_cores; i++) {
for (int j = 0; j < B_c_lin2; j++) {
dump_idx(i * B_c_lin2 + j + ofmap_offset);
dump_debug(ofmap_lin2[i * B_c_lin2 + j + ofmap_offset]);
}
}
}
// if (cluster_id == 0) {
// for (int i = 0; i < B_r_lin2 / num_cores; i++) {
// for (int j = 0; j < B_c_lin2; j++) {
// dump_idx(i * B_c_lin2 + j + ofmap_offset);
// dump_debug(ofmap_lin2[i * B_c_lin2 + j + ofmap_offset]);
// }
// }
// }
} else {
snrt_cluster_hw_barrier();
}
Expand Down
4 changes: 2 additions & 2 deletions sw/blas/gemm/src/gemm.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -34,8 +34,8 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
double c0 = BETA * C[m * ldC + n];
for (uint32_t k = 0; k < K; k++) {
// dump_index(k + m * ldA);
// dump_gemm(A[k + m * ldA]);
c0 += A[k + m * ldA] * B[k * ldB + n];
dump_gemm(A[k + m * ldA]);
// c0 += A[k + m * ldA] * B[k * ldB + n];
}
C[m * ldC + n] = c0;
}
Expand Down

0 comments on commit b58f63b

Please sign in to comment.