From 17ede54f50e3e5225a494c234e0132e7c41f25e6 Mon Sep 17 00:00:00 2001 From: Roger Barton Date: Tue, 6 Feb 2024 17:21:31 +0100 Subject: [PATCH] Introduce SnblasGemmImpl for implementation parameters, refactor datagen params format --- sw/blas/gemm/data/params.hjson | 40 +++++++++++++++------ sw/blas/gemm/src/gemm_decls.h | 11 ++++++ sw/blas/gemm/src/gemm_kernel_fp16.h | 8 ++--- sw/blas/gemm/src/gemm_kernel_fp32.h | 8 ++--- sw/blas/gemm/src/gemm_kernel_fp64.h | 26 +++++++------- sw/blas/gemm/src/gemm_kernel_fp8.h | 8 ++--- sw/blas/gemm/src/gemm_kernel_init_tpl.h | 20 +++++------ sw/blas/gemm/src/gemm_tiling_1dpipe_tpl.h | 20 +++++------ sw/blas/gemm/src/gemm_tiling_2dpipe_tpl.h | 22 ++++++------ sw/blas/gemm/src/gemm_tiling_baseline_tpl.h | 18 +++++----- sw/blas/gemm/src/gemm_tpl.h | 8 ++--- sw/blas/gemm/src/main.c | 9 ++++- 12 files changed, 117 insertions(+), 81 deletions(-) diff --git a/sw/blas/gemm/data/params.hjson b/sw/blas/gemm/data/params.hjson index 5fe43dead..b58eecb38 100644 --- a/sw/blas/gemm/data/params.hjson +++ b/sw/blas/gemm/data/params.hjson @@ -5,15 +5,33 @@ // Parameters for a GEMM { - M: 192, - N: 16, - K: 16, - beta: 0, - ta: false, - tb: true, // must be true for SIMD - prec: 64, - expand: 0, - linspace: true, - bench_iters: 2, - method: "2dpipe", + gemmInfo: { + prec: 64, + M: 192, + N: 16, + K: 16, + ta: false, + tb: true, // must be true for SIMD + } + + gemmArgs: { // C = alpha * A * B + beta * C + beta: 0, + alpha: 1, + } + + gemmImpl: { + method: "baseline", + ta_tile: false, + tb_tile: false, + tc_tile: false, + expand: 0, + } + + bench: { + iters: 2, + } + + datagen: { + linspace: true, + } } diff --git a/sw/blas/gemm/src/gemm_decls.h b/sw/blas/gemm/src/gemm_decls.h index 670db7c7f..783a34e32 100644 --- a/sw/blas/gemm/src/gemm_decls.h +++ b/sw/blas/gemm/src/gemm_decls.h @@ -49,6 +49,17 @@ typedef struct { uint32_t tb; } SnblasGemmInfo; +/** + * Constants related to which implementation should be used. 
+ * Only for non-template parameters +*/ +typedef struct { + bool bench; // Enable benchmarking code + bool ta_tile; // Transpose the A tile when loading into TCDM + bool tb_tile; + bool tc_tile; +} SnblasGemmImpl; + #define L1_M 8 #define L1_N 8 #define L1_K 8 diff --git a/sw/blas/gemm/src/gemm_kernel_fp16.h b/sw/blas/gemm/src/gemm_kernel_fp16.h index 70ed716a6..70f22205b 100644 --- a/sw/blas/gemm/src/gemm_kernel_fp16.h +++ b/sw/blas/gemm/src/gemm_kernel_fp16.h @@ -2,8 +2,8 @@ #define FLOAT_T fp16 #include "gemm_kernel_init_tpl.h" -extern void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench); -inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench) { +extern void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl); +inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl) { uint32_t p[3], P[3]; ocrt_thread_idx(p); ocrt_compute_thread_num(P); @@ -32,7 +32,7 @@ inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo inf // Kernel progresses by 4 values each step const uint32_t n_frep = K / 4 - 1; - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); for (uint32_t m = 0; m < M; m++) { uint32_t n = 0; for (uint32_t n0 = 0; n0 < N / unroll; n0++) { @@ -181,7 +181,7 @@ inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo inf } snrt_fpu_fence(); - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); } #undef FLOAT_T \ No newline at end of file diff --git a/sw/blas/gemm/src/gemm_kernel_fp32.h b/sw/blas/gemm/src/gemm_kernel_fp32.h index bf6c50481..2ccc5df89 100644 --- a/sw/blas/gemm/src/gemm_kernel_fp32.h +++ b/sw/blas/gemm/src/gemm_kernel_fp32.h @@ -2,8 +2,8 @@ #define FLOAT_T fp32 #include "gemm_kernel_init_tpl.h" -extern void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench); -inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench) { +extern void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl); +inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl) { uint32_t p[3], P[3]; ocrt_thread_idx(p); ocrt_compute_thread_num(P); @@ -32,7 +32,7 @@ inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo inf // Kernel progresses by 2 values each step const uint32_t n_frep = K / 2 - 1; - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); for (uint32_t m = 0; m < M; m++) { uint32_t n = 0; for (uint32_t n0 = 0; n0 < N / unroll; n0++) { @@ -146,7 +146,7 @@ inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo inf snrt_ssr_enable(); } snrt_fpu_fence(); - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); } #undef FLOAT_T \ No newline at end of file diff --git a/sw/blas/gemm/src/gemm_kernel_fp64.h b/sw/blas/gemm/src/gemm_kernel_fp64.h index 24def0cda..590a38ad2 100644 --- a/sw/blas/gemm/src/gemm_kernel_fp64.h +++ b/sw/blas/gemm/src/gemm_kernel_fp64.h @@ -1,8 +1,8 @@ #include "gemm_decls.h" #define _FMADD_UNROLL 8 -extern void 
snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info); -inline void snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info) { +extern void snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info, const SnblasGemmImpl impl); +inline void snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info, const SnblasGemmImpl impl) { uint32_t p[3], P[3]; ocrt_thread_idx(p); ocrt_compute_thread_num(P); @@ -60,13 +60,13 @@ inline void snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info) { snrt_ssr_enable(); } -extern void snblas_gemm_cluster_kernel_deinit_fp64(const SnblasGemmInfo info); -inline void snblas_gemm_cluster_kernel_deinit_fp64(const SnblasGemmInfo info) { +extern void snblas_gemm_cluster_kernel_deinit_fp64(const SnblasGemmInfo info, const SnblasGemmImpl impl); +inline void snblas_gemm_cluster_kernel_deinit_fp64(const SnblasGemmInfo info, const SnblasGemmImpl impl) { snrt_ssr_disable(); } -extern void snblas_gemm_cluster_kernel_compute_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args, bool bench); -inline void snblas_gemm_cluster_kernel_compute_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args, bool bench) { +extern void snblas_gemm_cluster_kernel_compute_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args, const SnblasGemmImpl impl); +inline void snblas_gemm_cluster_kernel_compute_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args, const SnblasGemmImpl impl) { uint32_t p[3], P[3]; ocrt_thread_idx(p); ocrt_compute_thread_num(P); @@ -94,7 +94,7 @@ inline void snblas_gemm_cluster_kernel_compute_fp64(const SnblasGemmInfo info, c snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, (void*) A); snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, (void*) B); - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); for (uint32_t m = 0; m < M; m++) { uint32_t n = 0; for (; n < N; n += unroll) { @@ -156,16 +156,16 @@ inline void snblas_gemm_cluster_kernel_compute_fp64(const SnblasGemmInfo info, c } snrt_fpu_fence(); - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); } /** * \brief Perform a one-time gemm computation for data in TCDM. * Use the `init`, `compute` and `deinit` directly to get maximum performance when running multiple times. 
*/ -extern void snblas_gemm_cluster_kernel_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args); -inline void snblas_gemm_cluster_kernel_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args) { - snblas_gemm_cluster_kernel_init_fp64(info); - snblas_gemm_cluster_kernel_compute_fp64(info, args, false); - snblas_gemm_cluster_kernel_deinit_fp64(info); +extern void snblas_gemm_cluster_kernel_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args, const SnblasGemmImpl impl); +inline void snblas_gemm_cluster_kernel_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args, const SnblasGemmImpl impl) { + snblas_gemm_cluster_kernel_init_fp64(info, impl); + snblas_gemm_cluster_kernel_compute_fp64(info, args, impl); + snblas_gemm_cluster_kernel_deinit_fp64(info, impl); } \ No newline at end of file diff --git a/sw/blas/gemm/src/gemm_kernel_fp8.h b/sw/blas/gemm/src/gemm_kernel_fp8.h index df653c63c..1f94b4bdb 100644 --- a/sw/blas/gemm/src/gemm_kernel_fp8.h +++ b/sw/blas/gemm/src/gemm_kernel_fp8.h @@ -2,8 +2,8 @@ #define FLOAT_T fp8 #include "gemm_kernel_init_tpl.h" -extern void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench); -inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench) { +extern void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl); +inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl) { uint32_t p[3], P[3]; ocrt_thread_idx(p); ocrt_compute_thread_num(P); @@ -32,7 +32,7 @@ inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo inf // Kernel progresses by 8 values each step const uint32_t n_frep = K / 8 - 1; - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); for (uint32_t m = 0; m < M; m++) { uint32_t n = 0; for (uint32_t n0 = 0; n0 < N / unroll; n0++) { @@ -184,7 +184,7 @@ inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo inf } snrt_fpu_fence(); - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); } #undef FLOAT_T diff --git a/sw/blas/gemm/src/gemm_kernel_init_tpl.h b/sw/blas/gemm/src/gemm_kernel_init_tpl.h index 9667da137..72fb2955e 100644 --- a/sw/blas/gemm/src/gemm_kernel_init_tpl.h +++ b/sw/blas/gemm/src/gemm_kernel_init_tpl.h @@ -11,8 +11,8 @@ #define SNBLAS_GEMM_CLUSTER_KERNEL(float_t) CONCAT(snblas_gemm_cluster_kernel_, float_t) #endif -extern void SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(const SnblasGemmInfo info); -inline void SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(const SnblasGemmInfo info) { +extern void SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(const SnblasGemmInfo info, const SnblasGemmImpl impl); +inline void SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(const SnblasGemmInfo info, const SnblasGemmImpl impl) { uint32_t p[3], P[3]; ocrt_thread_idx(p); ocrt_compute_thread_num(P); @@ -50,20 +50,20 @@ inline void SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(const SnblasGemmInfo info) snrt_ssr_enable(); } -extern void SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(const SnblasGemmInfo info); -inline void SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(const SnblasGemmInfo info) { +extern void SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(const SnblasGemmInfo info, const SnblasGemmImpl impl); +inline void SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(const SnblasGemmInfo info, 
const SnblasGemmImpl impl) { snrt_ssr_disable(); } -void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench); +void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl); /** * \brief Perform a one-time gemm computation for data in TCDM. * Use the `init`, `compute` and `deinit` directly to get maximum performance when running multiple times. */ -extern void SNBLAS_GEMM_CLUSTER_KERNEL(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench); -inline void SNBLAS_GEMM_CLUSTER_KERNEL(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench) { - SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(info); - SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(info, args, bench); - SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(info); +extern void SNBLAS_GEMM_CLUSTER_KERNEL(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl); +inline void SNBLAS_GEMM_CLUSTER_KERNEL(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl) { + SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(info, impl); + SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(info, args, impl); + SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(info, impl); } \ No newline at end of file diff --git a/sw/blas/gemm/src/gemm_tiling_1dpipe_tpl.h b/sw/blas/gemm/src/gemm_tiling_1dpipe_tpl.h index 707bd49fa..b1410d311 100644 --- a/sw/blas/gemm/src/gemm_tiling_1dpipe_tpl.h +++ b/sw/blas/gemm/src/gemm_tiling_1dpipe_tpl.h @@ -11,13 +11,13 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) -void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const bool bench) { +void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl) { typedef SNBLAS_GEMM_TCDM(FLOAT_T) TcdmLayout; typedef SnblasGemmInfo GemmInfo; typedef SNBLAS_GEMM_ARGS(FLOAT_T) GemmArgs; - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); const uint32_t M = info.M; const uint32_t N = info.N; @@ -85,15 +85,15 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, tileInfo.ta = false; tileInfo.tb = false; - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); if (!IS_DM_CORE) { - SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(tileInfo); + SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(tileInfo, impl); // DMA core is one index ahead asm volatile ("" ::: "memory"); snrt_global_barrier(); - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); } FOR_EACH(ib, pi, M / L1_M, PI, ib_dir, ib_prev) { @@ -139,7 +139,7 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, } snrt_dma_wait_all(); - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); } else { // solve block already in l1, parallelize inside each cluster @@ -150,11 +150,11 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, tileArgs.alpha = alpha; tileArgs.beta = beta; - SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(tileInfo, tileArgs, bench); + SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(tileInfo, tileArgs, impl); } snrt_global_barrier(); - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); if (IS_DM_CORE) { if (storeC) { @@ -183,8 +183,8 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) 
(const SnblasGemmInfo info, snrt_dma_wait_all(); // } } else { - SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(tileInfo); + SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(tileInfo, impl); } - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); } diff --git a/sw/blas/gemm/src/gemm_tiling_2dpipe_tpl.h b/sw/blas/gemm/src/gemm_tiling_2dpipe_tpl.h index b20551a4c..ad75affaf 100644 --- a/sw/blas/gemm/src/gemm_tiling_2dpipe_tpl.h +++ b/sw/blas/gemm/src/gemm_tiling_2dpipe_tpl.h @@ -8,7 +8,7 @@ #include "gemm_kernel.h" -void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const bool bench) { +void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl) { #define USE_C2C_TILES true @@ -22,7 +22,7 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, typedef SnblasGemmInfo GemmInfo; typedef SNBLAS_GEMM_ARGS(FLOAT_T) GemmArgs; - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); const uint32_t M = info.M; const uint32_t N = info.N; @@ -100,10 +100,10 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, tileInfo.ta = false; tileInfo.tb = false; - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); if (!IS_DM_CORE) { - SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(tileInfo); + SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(tileInfo, impl); // DMA core is one index ahead asm volatile ("" ::: "memory"); @@ -111,7 +111,7 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, snrt_global_barrier(); else snrt_cluster_hw_barrier(); - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); } // Wait for pipeline to be filled @@ -120,7 +120,7 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, snrt_global_barrier(); else snrt_cluster_hw_barrier(); - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); } FOR_EACH(ib, pi, M / L1_M, PI, ib_dir, ib_prev) { @@ -169,7 +169,7 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, } snrt_dma_wait_all(); - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); } else { // solve block already in l1, parallelize inside each // cluster @@ -186,14 +186,14 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, tileArgs.alpha = alpha; tileArgs.beta = beta; - SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(tileInfo, tileArgs, bench); + SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(tileInfo, tileArgs, impl); } if (USE_C2C_TILES) snrt_global_barrier(); else snrt_cluster_hw_barrier(); - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); if (IS_DM_CORE) { if (storeC) { @@ -225,7 +225,7 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, snrt_dma_wait_all(); // } } else { - SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(tileInfo); + SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(tileInfo, impl); } // Wait for pipeline to be emptied @@ -233,5 +233,5 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, snrt_global_barrier(); } - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); } diff --git a/sw/blas/gemm/src/gemm_tiling_baseline_tpl.h b/sw/blas/gemm/src/gemm_tiling_baseline_tpl.h index dfe39ea6c..28916337c 100644 --- a/sw/blas/gemm/src/gemm_tiling_baseline_tpl.h +++ b/sw/blas/gemm/src/gemm_tiling_baseline_tpl.h @@ -8,7 +8,7 @@ #include "gemm_kernel.h" -void 
SNBLAS_GEMM_TILING(baseline, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const bool bench) { +void SNBLAS_GEMM_TILING(baseline, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl) { /** * Problem is double buffered in L1. The buffer that is used is toggled at @@ -20,7 +20,7 @@ void SNBLAS_GEMM_TILING(baseline, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo inf typedef SnblasGemmInfo GemmInfo; typedef SNBLAS_GEMM_ARGS(FLOAT_T) GemmArgs; - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); const uint32_t M = info.M; const uint32_t N = info.N; @@ -72,10 +72,10 @@ void SNBLAS_GEMM_TILING(baseline, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo inf tileInfo.tb = false; // TODO: place memory barrier before sync - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); if (!IS_DM_CORE) { - SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(tileInfo); + SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(tileInfo, impl); snrt_cluster_hw_barrier(); // DMA core is one index ahead } @@ -113,12 +113,12 @@ void SNBLAS_GEMM_TILING(baseline, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo inf tileArgs.alpha = alpha; tileArgs.beta = beta; - SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(tileInfo, tileArgs, bench); + SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(tileInfo, tileArgs, impl); } - // if (bench) snrt_mcycle(); + // if (impl.bench) snrt_mcycle(); snrt_cluster_hw_barrier(); - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); if (IS_DM_CORE) { if (storeC) { @@ -146,8 +146,8 @@ void SNBLAS_GEMM_TILING(baseline, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo inf snrt_dma_wait_all(); // } } else { - SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(tileInfo); + SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(tileInfo, impl); } - if (bench) snrt_mcycle(); + if (impl.bench) snrt_mcycle(); } diff --git a/sw/blas/gemm/src/gemm_tpl.h b/sw/blas/gemm/src/gemm_tpl.h index 65d47247c..f9f89b0f9 100644 --- a/sw/blas/gemm/src/gemm_tpl.h +++ b/sw/blas/gemm/src/gemm_tpl.h @@ -27,11 +27,11 @@ #define SNBLAS_GEMM(method, float_t) CONCAT3(snblas_gemm_, method, float_t) #endif -extern void SNBLAS_GEMM(METHOD, FLOAT_T) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench); -inline void SNBLAS_GEMM(METHOD, FLOAT_T) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench) { +extern void SNBLAS_GEMM(METHOD, FLOAT_T) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl); +inline void SNBLAS_GEMM(METHOD, FLOAT_T) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl) { if (snrt_is_dm_core()) { - SNBLAS_GEMM_TILING(METHOD, FLOAT_T, true)(info, args, bench); + SNBLAS_GEMM_TILING(METHOD, FLOAT_T, true)(info, args, impl); } else { - SNBLAS_GEMM_TILING(METHOD, FLOAT_T, false)(info, args, bench); + SNBLAS_GEMM_TILING(METHOD, FLOAT_T, false)(info, args, impl); } } diff --git a/sw/blas/gemm/src/main.c b/sw/blas/gemm/src/main.c index 93a50d558..13ece7fbb 100644 --- a/sw/blas/gemm/src/main.c +++ b/sw/blas/gemm/src/main.c @@ -48,9 +48,16 @@ int main() { gemmArgs.alpha = 1; gemmArgs.beta = BETA; + SnblasGemmImpl gemmImpl = {0}; + gemmImpl.ta_tile = TA_TILE; + gemmImpl.tb_tile = TB_TILE; + gemmImpl.tc_tile = TC_TILE; + for (volatile int i = iters; i > 0; --i) { + dump_bench_iter(-i); // if (i == 1) snrt_mcycle(); // start - SNBLAS_GEMM(METHOD, DTYPE)(gemmInfo, gemmArgs, i == 1); + gemmImpl.bench = i == 1; + SNBLAS_GEMM(METHOD, DTYPE)(gemmInfo, 
gemmArgs, gemmImpl); // dma_xfer_test(c, M*N, i == 1); if (i == 1) snrt_mcycle(); // end
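
Usage sketch (illustration only, not part of the diff): the doc comments above recommend calling `init`, `compute` and `deinit` directly when running the cluster kernel multiple times, and this patch threads the new SnblasGemmImpl parameter through all three. The sketch below assumes the fp64 entry points and struct fields shown in this patch; run_tiles_fp64, n_tiles and the per-tile argument updates are placeholders.

    #include "gemm_decls.h"
    #include "gemm_kernel.h"  // pulls in the per-precision kernels, as in the tiling templates

    // Hypothetical helper: reuse one SSR setup across several tiles already resident in TCDM.
    void run_tiles_fp64(SnblasGemmInfo info, SnblasGemmArgs_fp64 args, int n_tiles) {
        SnblasGemmImpl impl = {0};
        impl.bench   = false;  // no snrt_mcycle() markers inside the kernel
        impl.ta_tile = false;  // do not transpose the A tile when loading into TCDM
        impl.tb_tile = false;
        impl.tc_tile = false;

        snblas_gemm_cluster_kernel_init_fp64(info, impl);       // configure SSRs/FREP once
        for (int t = 0; t < n_tiles; ++t) {
            // ... point args at the next tile and update alpha/beta as needed ...
            snblas_gemm_cluster_kernel_compute_fp64(info, args, impl);
        }
        snblas_gemm_cluster_kernel_deinit_fp64(info, impl);      // disable SSRs
    }

The one-shot wrapper snblas_gemm_cluster_kernel_fp64(info, args, impl) performs the same init/compute/deinit sequence for a single invocation.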