Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix-kul-cluster-sw-lib #14

Merged
merged 1 commit into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,75 +35,79 @@ int main() {

// Transfer data from L3 to L1
// Using DMA only
if (snrt_is_dm_core()) {
load_conv_input_data(Nbatch, H + 2 * pad_h, W + 2 * pad_w, Cin, local_a,
A);
load_weight_data(Cout, Kh, Kw, Cin, local_b, B);
}
if(snrt_cluster_idx() == 0){

// Wait for DMA to finish
snrt_cluster_hw_barrier();
if (snrt_is_dm_core()) {
load_conv_input_data(Nbatch, H + 2 * pad_h, W + 2 * pad_w, Cin, local_a,
A);
load_weight_data(Cout, Kh, Kw, Cin, local_b, B);
}

if (snrt_is_dm_core()) {
snrt_dma_start_1d(local_c, C,
M * N * meshRow * meshCol * sizeof(int32_t));
}
// Wait for DMA to finish
snrt_cluster_hw_barrier();

snrt_cluster_hw_barrier();
if (snrt_is_dm_core()) {
snrt_dma_start_1d(local_c, C,
M * N * meshRow * meshCol * sizeof(int32_t));
}

if (snrt_global_core_idx() == 0) {
// Set Streamer configuration CSR for conv2d
set_gemmx_streamer_csr(
Aslstride0, Aslstride1, Atlbound0, Atlstride0, Atlbound1,
Atlstride1, Atlbound2, Atlstride2, Atlbound3, Atlstride3, Atlbound4,
Atlstride4, Atlbound5, Atlstride5,
snrt_cluster_hw_barrier();

Bslstride0, Bslstride1, Btlbound0, Btlstride0, Btlbound1,
Btlstride1, Btlbound2, Btlstride2,
if (snrt_global_core_idx() == 0) {
// Set Streamer configuration CSR for conv2d
set_gemmx_streamer_csr(
Aslstride0, Aslstride1, Atlbound0, Atlstride0, Atlbound1,
Atlstride1, Atlbound2, Atlstride2, Atlbound3, Atlstride3, Atlbound4,
Atlstride4, Atlbound5, Atlstride5,

D8slstride0, D8slstride1, D8tlbound0, D8tlstride0, D8tlbound1,
D8tlstride1, D8tlbound2, D8tlstride2,
Bslstride0, Bslstride1, Btlbound0, Btlstride0, Btlbound1,
Btlstride1, Btlbound2, Btlstride2,

Cslstride0, Cslstride1, Ctlbound0, Ctlstride0, Ctlbound1,
Ctlstride1, Ctlbound2, Ctlstride2,
D8slstride0, D8slstride1, D8tlbound0, D8tlstride0, D8tlbound1,
D8tlstride1, D8tlbound2, D8tlstride2,

D32slstride0, D32slstride1, D32tlbound0, D32tlstride0, D32tlbound1,
D32tlstride1, D32tlbound2, D32tlstride2,
Cslstride0, Cslstride1, Ctlbound0, Ctlstride0, Ctlbound1,
Ctlstride1, Ctlbound2, Ctlstride2,

delta_local_a, delta_local_b, delta_local_d8, delta_local_c,
delta_local_d32, bypassSIMD);
D32slstride0, D32slstride1, D32tlbound0, D32tlstride0, D32tlbound1,
D32tlstride1, D32tlbound2, D32tlstride2,

// Set CSR to start Streamer for conv2d
set_gemmx_streamer_start();
delta_local_a, delta_local_b, delta_local_d8, delta_local_c,
delta_local_d32, bypassSIMD);

// Set GEMM configuration CSR
uint32_t subtraction_setting =
gen_subtraction_config(subtraction_a, subtraction_b);
// Set CSR to start Streamer for conv2d
set_gemmx_streamer_start();

uint32_t csr0 =
gen_csr0_config(input_zp_i, output_zp_i, shift_i, max_int_i);
uint32_t csr1 = gen_csr1_config(min_int_i, double_round_i);
uint32_t csr2 = gen_csr2_config(multiplier_i);
// Set GEMM configuration CSR
uint32_t subtraction_setting =
gen_subtraction_config(subtraction_a, subtraction_b);

set_gemmx_csr(K, N, M, subtraction_setting, csr0, csr1, csr2, M * N,
bypassSIMD);
uint32_t csr0 =
gen_csr0_config(input_zp_i, output_zp_i, shift_i, max_int_i);
uint32_t csr1 = gen_csr1_config(min_int_i, double_round_i);
uint32_t csr2 = gen_csr2_config(multiplier_i);

// Set CSR to start GEMM
set_gemmx_start();
set_gemmx_csr(K, N, M, subtraction_setting, csr0, csr1, csr2, M * N,
bypassSIMD);

// Poll until Streamer and GEMM accelerator finish
wait_gemmx_and_streamer();
// Set CSR to start GEMM
set_gemmx_start();

// Poll until Streamer and GEMM accelerator finish
wait_gemmx_and_streamer();

// check the result of the implicit im2col convolution
if (!bypassSIMD) {
err +=
check_gemmx_result_D8(local_d8, D8_direct_conv2d, Batch, M, N);
} else {
err += check_gemmx_result_D32(local_d32, D32_direct_conv2d, Batch,
M, N);
}
printf("SNAX GEMM Conv2d: %s, err = %d . bypassSIMD = %d .\n",
err ? "FAIL" : "PASS", err, bypassSIMD);
};

// check the result of the implicit im2col convolution
if (!bypassSIMD) {
err +=
check_gemmx_result_D8(local_d8, D8_direct_conv2d, Batch, M, N);
} else {
err += check_gemmx_result_D32(local_d32, D32_direct_conv2d, Batch,
M, N);
}
printf("SNAX GEMM Conv2d: %s, err = %d . bypassSIMD = %d .\n",
err ? "FAIL" : "PASS", err, bypassSIMD);
};

return err;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,50 +148,50 @@ void set_gemmx_streamer_csr(
}

// Set CSR to start STREAMER
void set_gemmx_streamer_start() { write_csr(1005 + 6, 1); }
void set_gemmx_streamer_start() { write_csr(1011, 1); }

// Set GEMM configuration CSR
void set_gemmx_csr(int tempLoop0, int tempLoop1, int tempLoop2,
int subtractions, uint32_t csr0, uint32_t csr1,
uint32_t csr2, uint32_t temporal_loop_bound,
uint32_t bypassSIMD) {
// set loop bounds, from innermost to outermost, aka from K to N to M
write_csr(1007 + 6, tempLoop0);
write_csr(1008 + 6, tempLoop1);
write_csr(1009 + 6, tempLoop2);
write_csr(1014, tempLoop0);
write_csr(1015, tempLoop1);
write_csr(1016, tempLoop2);

// set subtraction a and b
write_csr(1010 + 6, subtractions);
write_csr(1017, subtractions);

// set the constants for the SIMD unit
write_csr(1011 + 6, csr0);
write_csr(1012 + 6, csr1);
write_csr(1013 + 6, csr2);
write_csr(1018, csr0);
write_csr(1019, csr1);
write_csr(1020, csr2);

// set the temporal loop bound
write_csr(1014 + 6, temporal_loop_bound);
write_csr(1015 + 6, bypassSIMD);
write_csr(1021, temporal_loop_bound);
write_csr(1022, bypassSIMD);
}

// Set CSR to start GEMM
void set_gemmx_start() { write_csr(1016 + 6, 1); }
void set_gemmx_start() { write_csr(1023, 1); }

// Stall until Streamer and GEMM accelerator finish
void wait_gemmx_and_streamer() {
write_csr(1005 + 6, 0);
write_csr(1005 + 6, 0);
write_csr(1016 + 6, 0);
write_csr(1011, 0);
write_csr(1011, 0);
write_csr(1023, 0);
}

// Read performance counter of the Streamer, a read-only CSR
uint32_t read_gemmx_streamer_perf_counter() {
uint32_t perf_counter = read_csr(1006);
uint32_t perf_counter = read_csr(1013);
return perf_counter;
}

// Read performance counter of GEMM, a read-only CSR
uint32_t read_gemmx_perf_counter() {
uint32_t perf_counter = read_csr(1017);
uint32_t perf_counter = read_csr(1025);
return perf_counter;
}

Expand Down
Loading