diff --git a/target/sim/sw/device/apps/snax/snax-streamer-gemm-conv-simd/src/snax-streamer-gemm-conv-simd.c b/target/sim/sw/device/apps/snax/snax-streamer-gemm-conv-simd/src/snax-streamer-gemm-conv-simd.c index 66c96eb50..48c54d139 100644 --- a/target/sim/sw/device/apps/snax/snax-streamer-gemm-conv-simd/src/snax-streamer-gemm-conv-simd.c +++ b/target/sim/sw/device/apps/snax/snax-streamer-gemm-conv-simd/src/snax-streamer-gemm-conv-simd.c @@ -35,75 +35,79 @@ int main() { // Transfer data from L3 to L1 // Using DMA only - if (snrt_is_dm_core()) { - load_conv_input_data(Nbatch, H + 2 * pad_h, W + 2 * pad_w, Cin, local_a, - A); - load_weight_data(Cout, Kh, Kw, Cin, local_b, B); - } + if(snrt_cluster_idx() == 0){ - // Wait for DMA to finish - snrt_cluster_hw_barrier(); + if (snrt_is_dm_core()) { + load_conv_input_data(Nbatch, H + 2 * pad_h, W + 2 * pad_w, Cin, local_a, + A); + load_weight_data(Cout, Kh, Kw, Cin, local_b, B); + } - if (snrt_is_dm_core()) { - snrt_dma_start_1d(local_c, C, - M * N * meshRow * meshCol * sizeof(int32_t)); - } + // Wait for DMA to finish + snrt_cluster_hw_barrier(); - snrt_cluster_hw_barrier(); + if (snrt_is_dm_core()) { + snrt_dma_start_1d(local_c, C, + M * N * meshRow * meshCol * sizeof(int32_t)); + } - if (snrt_global_core_idx() == 0) { - // Set Streamer configuration CSR for conv2d - set_gemmx_streamer_csr( - Aslstride0, Aslstride1, Atlbound0, Atlstride0, Atlbound1, - Atlstride1, Atlbound2, Atlstride2, Atlbound3, Atlstride3, Atlbound4, - Atlstride4, Atlbound5, Atlstride5, + snrt_cluster_hw_barrier(); - Bslstride0, Bslstride1, Btlbound0, Btlstride0, Btlbound1, - Btlstride1, Btlbound2, Btlstride2, + if (snrt_global_core_idx() == 0) { + // Set Streamer configuration CSR for conv2d + set_gemmx_streamer_csr( + Aslstride0, Aslstride1, Atlbound0, Atlstride0, Atlbound1, + Atlstride1, Atlbound2, Atlstride2, Atlbound3, Atlstride3, Atlbound4, + Atlstride4, Atlbound5, Atlstride5, - D8slstride0, D8slstride1, D8tlbound0, D8tlstride0, D8tlbound1, - D8tlstride1, D8tlbound2, D8tlstride2, + Bslstride0, Bslstride1, Btlbound0, Btlstride0, Btlbound1, + Btlstride1, Btlbound2, Btlstride2, - Cslstride0, Cslstride1, Ctlbound0, Ctlstride0, Ctlbound1, - Ctlstride1, Ctlbound2, Ctlstride2, + D8slstride0, D8slstride1, D8tlbound0, D8tlstride0, D8tlbound1, + D8tlstride1, D8tlbound2, D8tlstride2, - D32slstride0, D32slstride1, D32tlbound0, D32tlstride0, D32tlbound1, - D32tlstride1, D32tlbound2, D32tlstride2, + Cslstride0, Cslstride1, Ctlbound0, Ctlstride0, Ctlbound1, + Ctlstride1, Ctlbound2, Ctlstride2, - delta_local_a, delta_local_b, delta_local_d8, delta_local_c, - delta_local_d32, bypassSIMD); + D32slstride0, D32slstride1, D32tlbound0, D32tlstride0, D32tlbound1, + D32tlstride1, D32tlbound2, D32tlstride2, - // Set CSR to start Streamer for conv2d - set_gemmx_streamer_start(); + delta_local_a, delta_local_b, delta_local_d8, delta_local_c, + delta_local_d32, bypassSIMD); - // Set GEMM configuration CSR - uint32_t subtraction_setting = - gen_subtraction_config(subtraction_a, subtraction_b); + // Set CSR to start Streamer for conv2d + set_gemmx_streamer_start(); - uint32_t csr0 = - gen_csr0_config(input_zp_i, output_zp_i, shift_i, max_int_i); - uint32_t csr1 = gen_csr1_config(min_int_i, double_round_i); - uint32_t csr2 = gen_csr2_config(multiplier_i); + // Set GEMM configuration CSR + uint32_t subtraction_setting = + gen_subtraction_config(subtraction_a, subtraction_b); - set_gemmx_csr(K, N, M, subtraction_setting, csr0, csr1, csr2, M * N, - bypassSIMD); + uint32_t csr0 = + gen_csr0_config(input_zp_i, output_zp_i, shift_i, max_int_i); + uint32_t csr1 = gen_csr1_config(min_int_i, double_round_i); + uint32_t csr2 = gen_csr2_config(multiplier_i); - // Set CSR to start GEMM - set_gemmx_start(); + set_gemmx_csr(K, N, M, subtraction_setting, csr0, csr1, csr2, M * N, + bypassSIMD); - // Poll until Streamer and GEMM accelerator finish - wait_gemmx_and_streamer(); + // Set CSR to start GEMM + set_gemmx_start(); + + // Poll until Streamer and GEMM accelerator finish + wait_gemmx_and_streamer(); + + // check the result of the implicit im2col convolution + if (!bypassSIMD) { + err += + check_gemmx_result_D8(local_d8, D8_direct_conv2d, Batch, M, N); + } else { + err += check_gemmx_result_D32(local_d32, D32_direct_conv2d, Batch, + M, N); + } + printf("SNAX GEMM Conv2d: %s, err = %d . bypassSIMD = %d .\n", + err ? "FAIL" : "PASS", err, bypassSIMD); + }; - // check the result of the implicit im2col convolution - if (!bypassSIMD) { - err += - check_gemmx_result_D8(local_d8, D8_direct_conv2d, Batch, M, N); - } else { - err += check_gemmx_result_D32(local_d32, D32_direct_conv2d, Batch, - M, N); - } - printf("SNAX GEMM Conv2d: %s, err = %d . bypassSIMD = %d .\n", - err ? "FAIL" : "PASS", err, bypassSIMD); }; return err; diff --git a/target/sim/sw/device/snax/streamer-gemm-conv-simd/src/snax-streamer-gemm-conv-simd-lib.c b/target/sim/sw/device/snax/streamer-gemm-conv-simd/src/snax-streamer-gemm-conv-simd-lib.c index 34a31aff9..5d0f7101d 100644 --- a/target/sim/sw/device/snax/streamer-gemm-conv-simd/src/snax-streamer-gemm-conv-simd-lib.c +++ b/target/sim/sw/device/snax/streamer-gemm-conv-simd/src/snax-streamer-gemm-conv-simd-lib.c @@ -148,7 +148,7 @@ void set_gemmx_streamer_csr( } // Set CSR to start STREAMER -void set_gemmx_streamer_start() { write_csr(1005 + 6, 1); } +void set_gemmx_streamer_start() { write_csr(1011, 1); } // Set GEMM configuration CSR void set_gemmx_csr(int tempLoop0, int tempLoop1, int tempLoop2, @@ -156,42 +156,42 @@ void set_gemmx_csr(int tempLoop0, int tempLoop1, int tempLoop2, uint32_t csr2, uint32_t temporal_loop_bound, uint32_t bypassSIMD) { // set loop bounds, from innermost to outermost, aka from K to N to M - write_csr(1007 + 6, tempLoop0); - write_csr(1008 + 6, tempLoop1); - write_csr(1009 + 6, tempLoop2); + write_csr(1014, tempLoop0); + write_csr(1015, tempLoop1); + write_csr(1016, tempLoop2); // set subtraction a and b - write_csr(1010 + 6, subtractions); + write_csr(1017, subtractions); // set the constants for the SIMD unit - write_csr(1011 + 6, csr0); - write_csr(1012 + 6, csr1); - write_csr(1013 + 6, csr2); + write_csr(1018, csr0); + write_csr(1019, csr1); + write_csr(1020, csr2); // set the temporal loop bound - write_csr(1014 + 6, temporal_loop_bound); - write_csr(1015 + 6, bypassSIMD); + write_csr(1021, temporal_loop_bound); + write_csr(1022, bypassSIMD); } // Set CSR to start GEMM -void set_gemmx_start() { write_csr(1016 + 6, 1); } +void set_gemmx_start() { write_csr(1023, 1); } // Stall until Streamer and GEMM accelerator finish void wait_gemmx_and_streamer() { - write_csr(1005 + 6, 0); - write_csr(1005 + 6, 0); - write_csr(1016 + 6, 0); + write_csr(1011, 0); + write_csr(1011, 0); + write_csr(1023, 0); } // Read performance counter of the Streamer, a read-only CSR uint32_t read_gemmx_streamer_perf_counter() { - uint32_t perf_counter = read_csr(1006); + uint32_t perf_counter = read_csr(1013); return perf_counter; } // Read performance counter of GEMM, a read-only CSR uint32_t read_gemmx_perf_counter() { - uint32_t perf_counter = read_csr(1017); + uint32_t perf_counter = read_csr(1025); return perf_counter; }