From 27757d93aeef32a2265a27c85ff2da537573b109 Mon Sep 17 00:00:00 2001 From: Xiaoling Yi <143962462+xiaoling-yi@users.noreply.github.com> Date: Fri, 8 Dec 2023 19:30:44 +0100 Subject: [PATCH] Add software test and golden model for subtraction (#53) * sw: add test and golden model for subtraction * lint: formatting datagen.py * Apply suggestions from code review Co-authored-by: Josse Van Delm * sw: format snax-gemm-lib.c --------- Co-authored-by: Josse Van Delm --- Bender.yml | 2 +- .../sw/apps/snax-block-gemm/data/datagen.py | 25 +++++++++++- .../snax-block-gemm/src/snax-block-gemm.c | 17 +++++--- .../sw/snax/gemm/include/snax-gemm-lib.h | 21 +++++++--- .../sw/snax/gemm/src/snax-gemm-lib.c | 40 ++++++++++++++----- 5 files changed, 79 insertions(+), 26 deletions(-) diff --git a/Bender.yml b/Bender.yml index 9aaf9b70b..58fd4dd77 100644 --- a/Bender.yml +++ b/Bender.yml @@ -29,7 +29,7 @@ dependencies: tech_cells_generic: { git: https://github.com/pulp-platform/tech_cells_generic, version: 0.2.11 } riscv-dbg: { git: https://github.com/pulp-platform/riscv-dbg, version: 0.8.0 } hwpe-mac-engine: { git: https://github.com/KULeuven-MICAS/hwpe-mac-engine.git, rev: 5d3b4525b665169fc8321c8a811f3c83ad3c72e8 } - snax-gemm: { git: https://github.com/KULeuven-MICAS/snax-gemm.git, rev: db87157ad619bdbfe317b48c9a2790bce156db6f } + snax-gemm: { git: https://github.com/KULeuven-MICAS/snax-gemm.git, rev: 7e04484633da28513bdbe96a73a86a8476f69da4 } vendor_package: - name: musl diff --git a/target/snitch_cluster/sw/apps/snax-block-gemm/data/datagen.py b/target/snitch_cluster/sw/apps/snax-block-gemm/data/datagen.py index cb2941d0a..1edf79ca4 100755 --- a/target/snitch_cluster/sw/apps/snax-block-gemm/data/datagen.py +++ b/target/snitch_cluster/sw/apps/snax-block-gemm/data/datagen.py @@ -22,7 +22,8 @@ # Golden model in python -def block_gemm_golden_model(m, k, n, row, size, col, a, b): +def block_gemm_golden_model(m, k, n, row, size, col, a, b, + subtraction_a, subtraction_b): c = np.zeros(m 
* row * n * col, dtype=(np.int32)) for mm in range(m): for nn in range(n): @@ -48,7 +49,9 @@ def block_gemm_golden_model(m, k, n, row, size, col, a, b): + cc * size + ss ) - c[c_index] = c[c_index] + a[a_index] * b[b_index] + c[c_index] = c[c_index] + \ + (a[a_index] - subtraction_a) * \ + (b[b_index] - subtraction_b) return c @@ -115,6 +118,22 @@ def emit_gemm_data(**kwargs): ) ] + # Generating random 8-bit integers a and b for subtraction + subtraction_a = np.random.randint(MIN, MAX) + subtraction_b = np.random.randint(MIN, MAX) + + # Writing the subtraction value to data.h + data_str += [ + format_scalar_definition( + "int8_t", "subtraction_a", subtraction_a + ) + ] + data_str += [ + format_scalar_definition( + "int8_t", "subtraction_b", subtraction_b + ) + ] + # Generate random input matrices length_a = ( kwargs["M"] * kwargs["K"] * kwargs["meshRow"] * kwargs["tileSize"] @@ -135,6 +154,8 @@ def emit_gemm_data(**kwargs): kwargs["meshCol"], a, b, + subtraction_a, + subtraction_b ) c_init = np.zeros(c_golden.shape) c_cpu = np.zeros(c_golden.shape) diff --git a/target/snitch_cluster/sw/apps/snax-block-gemm/src/snax-block-gemm.c b/target/snitch_cluster/sw/apps/snax-block-gemm/src/snax-block-gemm.c index f5aece549..f463fc24d 100644 --- a/target/snitch_cluster/sw/apps/snax-block-gemm/src/snax-block-gemm.c +++ b/target/snitch_cluster/sw/apps/snax-block-gemm/src/snax-block-gemm.c @@ -50,12 +50,16 @@ int main() { // Pack matrix size setting to one CSR uint32_t size_setting = gen_size_config(Batch, M, K, N); + uint32_t subtraction_setting = + gen_subtraction_config(subtraction_a, subtraction_b); + uint32_t gemm_start = snrt_mcycle(); // Set GEMM configuration CSR - set_batch_gemm(size_setting, local_a, local_b, local_c, - strideInnermostA, strideInnermostB, strideInnermostC, - ldA, ldB, ldC, strideA, strideB, strideC); + set_batch_gemm(size_setting, local_a, local_b, subtraction_setting, + local_c, strideInnermostA, strideInnermostB, + strideInnermostC, ldA, ldB, ldC, strideA, 
strideB, + strideC); // Set CSR to start GEMM and poll until GEMM accelerator finishes start_batch_gemm(); @@ -77,9 +81,10 @@ int main() { // region for benchmarking) uint32_t start_cycle = snrt_mcycle(); - batch_gemm_cpu(Batch, M, K, N, local_a, local_b, C_cpu, - strideInnermostA, strideInnermostB, strideInnermostC, - ldA, ldB, ldC, strideA, strideB, strideC); + batch_gemm_cpu(Batch, M, K, N, local_a, local_b, subtraction_a, + subtraction_b, C_cpu, strideInnermostA, strideInnermostB, + strideInnermostC, ldA, ldB, ldC, strideA, strideB, + strideC); // Read the mcycle CSR uint32_t end_cycle = snrt_mcycle(); diff --git a/target/snitch_cluster/sw/snax/gemm/include/snax-gemm-lib.h b/target/snitch_cluster/sw/snax/gemm/include/snax-gemm-lib.h index 7d360734b..a96466e2d 100644 --- a/target/snitch_cluster/sw/snax/gemm/include/snax-gemm-lib.h +++ b/target/snitch_cluster/sw/snax/gemm/include/snax-gemm-lib.h @@ -13,13 +13,21 @@ // Pack matrix size setting to one CSR int32_t gen_size_config(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N); +// Pack two subtraction values to one CSR +int32_t gen_subtraction_config(int8_t subtraction_a, int8_t subtraction_b); + +// Performance counter for GEMM busy cycles +uint32_t read_performance_counter(); + // Golden model for base gemm void base_gemm(uint8_t m, uint8_t k, uint8_t n, int8_t* A, int8_t* B, - int32_t* C_cpu, bool new_batch); + int8_t subtraction_a, int8_t subtraction_b, int32_t* C_cpu, + bool new_batch); // Golden model for batch gemm void batch_gemm_cpu(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N, int8_t* A, - int8_t* B, int32_t* C, uint32_t strideInnermostA, + int8_t* B, int8_t subtraction_a, int8_t subtraction_b, + int32_t* C, uint32_t strideInnermostA, uint32_t strideInnermostB, uint32_t strideInnermostC, uint32_t ldA, uint32_t ldB, uint32_t ldC, uint32_t strideA, uint32_t strideB, uint32_t strideC); @@ -33,10 +41,11 @@ void load_input_data(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N, // Set GEMM configuration CSR 
void set_batch_gemm(uint32_t size_setting, int8_t* local_a, int8_t* local_b, - int32_t* local_c, uint32_t strideInnermostA, - uint32_t strideInnermostB, uint32_t strideInnermostC, - uint32_t ldA, uint32_t ldB, uint32_t ldC, uint32_t strideA, - uint32_t strideB, uint32_t strideC); + int32_t subtractions, int32_t* local_c, + uint32_t strideInnermostA, uint32_t strideInnermostB, + uint32_t strideInnermostC, uint32_t ldA, uint32_t ldB, + uint32_t ldC, uint32_t strideA, uint32_t strideB, + uint32_t strideC); // Set CSR to start GEMM void start_batch_gemm(); diff --git a/target/snitch_cluster/sw/snax/gemm/src/snax-gemm-lib.c b/target/snitch_cluster/sw/snax/gemm/src/snax-gemm-lib.c index cdbd157f4..e0aacea5f 100644 --- a/target/snitch_cluster/sw/snax/gemm/src/snax-gemm-lib.c +++ b/target/snitch_cluster/sw/snax/gemm/src/snax-gemm-lib.c @@ -13,8 +13,19 @@ int32_t gen_size_config(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N) { (int32_t)N; } +int32_t gen_subtraction_config(int8_t subtraction_a, int8_t subtraction_b) { + return ((uint8_t)subtraction_b << 8) | (uint8_t)subtraction_a; +} + +uint32_t read_performance_counter() { + uint32_t performance_counter; + performance_counter = read_csr(0x3cd); + return performance_counter; +}; + void base_gemm(uint8_t m, uint8_t k, uint8_t n, int8_t* A, int8_t* B, - int32_t* C_cpu, bool clear) { + int8_t subtraction_a, int8_t subtraction_b, int32_t* C_cpu, + bool clear) { for (int i = 0; i < m; i++) { for (int j = 0; j < n; j++) { // clear memory first before start matrix multiplication @@ -23,15 +34,18 @@ void base_gemm(uint8_t m, uint8_t k, uint8_t n, int8_t* A, int8_t* B, C_cpu[i * n + j] = 0; } for (int s = 0; s < k; s++) { - C_cpu[i * n + j] = C_cpu[i * n + j] + (int32_t)A[i * k + s] * - (int32_t)B[s + j * k]; + C_cpu[i * n + j] = + C_cpu[i * n + j] + + ((int32_t)A[i * k + s] - (int32_t)subtraction_a) * + ((int32_t)B[s + j * k] - (int32_t)subtraction_b); } } } }; void batch_gemm_cpu(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N, 
int8_t* A, - int8_t* B, int32_t* C, uint32_t strideInnermostA, + int8_t* B, int8_t subtraction_a, int8_t subtraction_b, + int32_t* C, uint32_t strideInnermostA, uint32_t strideInnermostB, uint32_t strideInnermostC, uint32_t ldA, uint32_t ldB, uint32_t ldC, uint32_t strideA, uint32_t strideB, uint32_t strideC) { @@ -62,7 +76,7 @@ void batch_gemm_cpu(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N, int8_t* A, // when k == 0, clear the memory clear = k == 0; base_gemm(meshRow, tileSize, meshCol, addr_a, addr_b, - addr_c, clear); + subtraction_a, subtraction_b, addr_c, clear); } } } @@ -120,10 +134,11 @@ void load_input_data(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N, } void set_batch_gemm(uint32_t size_setting, int8_t* local_a, int8_t* local_b, - int32_t* local_c, uint32_t strideInnermostA, - uint32_t strideInnermostB, uint32_t strideInnermostC, - uint32_t ldA, uint32_t ldB, uint32_t ldC, uint32_t strideA, - uint32_t strideB, uint32_t strideC) { + int32_t subtractions, int32_t* local_c, + uint32_t strideInnermostA, uint32_t strideInnermostB, + uint32_t strideInnermostC, uint32_t ldA, uint32_t ldB, + uint32_t ldC, uint32_t strideA, uint32_t strideB, + uint32_t strideC) { // Set matrix size write_csr(0x3c0, size_setting); @@ -144,13 +159,16 @@ void set_batch_gemm(uint32_t size_setting, int8_t* local_a, int8_t* local_b, write_csr(0x3ca, strideA); write_csr(0x3cb, strideB); write_csr(0x3cc, strideC); + + // Set subtraction values + write_csr(0x3ce, subtractions); } void start_batch_gemm() { - // 0x3ce is the CSR address for accelerator status + // 0x3cf is the CSR address for accelerator status // set the lowest bit of state CSR (state CSR[0]) to set start signal in // GEMM - write_csr(0x3ce, 1); + write_csr(0x3cf, 1); } void wait_batch_gemm() { @@ -158,7 +176,7 @@ void wait_batch_gemm() { while (1) { // poll the state CSR[1] to see if GEMM is still busy - break_poll = read_csr(0x3ce); + break_poll = read_csr(0x3cf); if ((break_poll >> 1) == 1) { break; };