Skip to content

Commit

Permalink
Introduce SnblasGemmImpl for implementation parameters, refactor data…
Browse files Browse the repository at this point in the history
…gen params format
  • Loading branch information
rogerbarton committed Feb 6, 2024
1 parent 93e7561 commit 17ede54
Show file tree
Hide file tree
Showing 12 changed files with 117 additions and 81 deletions.
40 changes: 29 additions & 11 deletions sw/blas/gemm/data/params.hjson
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,33 @@
// Parameters for a GEMM

{
M: 192,
N: 16,
K: 16,
beta: 0,
ta: false,
tb: true, // must be true for SIMD
prec: 64,
expand: 0,
linspace: true,
bench_iters: 2,
method: "2dpipe",
gemmInfo: {
prec: 64,
M: 192,
N: 16,
K: 16,
ta: false,
tb: true, // must be true for SIMD
}

gemmArgs: { // C = alpha * A * B + beta * C
beta: 0,
alpha: 1,
}

// Implementation-selection options: these choose how the GEMM is executed,
// as opposed to gemmInfo/gemmArgs which describe the problem itself.
gemmImpl: {
// Which implementation to run.
// NOTE(review): "2dpipe" was the previous value and a "1dpipe" tiling exists
// in the sources; confirm the full set of accepted names.
method: "baseline",
// Transpose the A tile when loading into TCDM (see SnblasGemmImpl.ta_tile).
ta_tile: false,
// Presumably the same for the B and C tiles — TODO confirm.
tb_tile: false,
tc_tile: false,
// NOTE(review): meaning of `expand` is not visible here; confirm against the data generator.
expand: 0,
}

bench: {
iters: 2,
}

datagen: {
linspace: true,
}
}
11 changes: 11 additions & 0 deletions sw/blas/gemm/src/gemm_decls.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ typedef struct {
uint32_t tb;
} SnblasGemmInfo;

/**
* Options selecting which GEMM implementation variant should be used.
*
* Holds only the options that are passed as ordinary values; anything that is
* a template parameter of the kernels (presumably macro parameters such as
* FLOAT_T — confirm) is deliberately not part of this struct.
*
* The struct is passed by value to the kernel init/compute/deinit functions
* and to the tiling functions.
*/
typedef struct {
bool bench; // Enable benchmarking code (guards the snrt_mcycle() calls)
bool ta_tile; // Transpose the A tile when loading into TCDM
bool tb_tile; // Presumably: transpose the B tile when loading into TCDM — TODO confirm
bool tc_tile; // Presumably: transpose the C tile when moving to/from TCDM — TODO confirm
} SnblasGemmImpl;

#define L1_M 8
#define L1_N 8
#define L1_K 8
Expand Down
8 changes: 4 additions & 4 deletions sw/blas/gemm/src/gemm_kernel_fp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#define FLOAT_T fp16
#include "gemm_kernel_init_tpl.h"

extern void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench);
inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench) {
extern void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl);
inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl) {
uint32_t p[3], P[3];
ocrt_thread_idx(p);
ocrt_compute_thread_num(P);
Expand Down Expand Up @@ -32,7 +32,7 @@ inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo inf
// Kernel progresses by 4 values each step
const uint32_t n_frep = K / 4 - 1;

if (bench) snrt_mcycle();
if (impl.bench) snrt_mcycle();
for (uint32_t m = 0; m < M; m++) {
uint32_t n = 0;
for (uint32_t n0 = 0; n0 < N / unroll; n0++) {
Expand Down Expand Up @@ -181,7 +181,7 @@ inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo inf
}

snrt_fpu_fence();
if (bench) snrt_mcycle();
if (impl.bench) snrt_mcycle();
}

#undef FLOAT_T
8 changes: 4 additions & 4 deletions sw/blas/gemm/src/gemm_kernel_fp32.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#define FLOAT_T fp32
#include "gemm_kernel_init_tpl.h"

extern void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench);
inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench) {
extern void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl);
inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl) {
uint32_t p[3], P[3];
ocrt_thread_idx(p);
ocrt_compute_thread_num(P);
Expand Down Expand Up @@ -32,7 +32,7 @@ inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo inf
// Kernel progresses by 2 values each step
const uint32_t n_frep = K / 2 - 1;

if (bench) snrt_mcycle();
if (impl.bench) snrt_mcycle();
for (uint32_t m = 0; m < M; m++) {
uint32_t n = 0;
for (uint32_t n0 = 0; n0 < N / unroll; n0++) {
Expand Down Expand Up @@ -146,7 +146,7 @@ inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo inf
snrt_ssr_enable();
}
snrt_fpu_fence();
if (bench) snrt_mcycle();
if (impl.bench) snrt_mcycle();
}

#undef FLOAT_T
26 changes: 13 additions & 13 deletions sw/blas/gemm/src/gemm_kernel_fp64.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include "gemm_decls.h"

#define _FMADD_UNROLL 8
extern void snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info);
inline void snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info) {
extern void snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info, const SnblasGemmImpl impl);
inline void snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info, const SnblasGemmImpl impl) {
uint32_t p[3], P[3];
ocrt_thread_idx(p);
ocrt_compute_thread_num(P);
Expand Down Expand Up @@ -60,13 +60,13 @@ inline void snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info) {
snrt_ssr_enable();
}

extern void snblas_gemm_cluster_kernel_deinit_fp64(const SnblasGemmInfo info);
inline void snblas_gemm_cluster_kernel_deinit_fp64(const SnblasGemmInfo info) {
extern void snblas_gemm_cluster_kernel_deinit_fp64(const SnblasGemmInfo info, const SnblasGemmImpl impl);
inline void snblas_gemm_cluster_kernel_deinit_fp64(const SnblasGemmInfo info, const SnblasGemmImpl impl) {
snrt_ssr_disable();
}

extern void snblas_gemm_cluster_kernel_compute_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args, bool bench);
inline void snblas_gemm_cluster_kernel_compute_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args, bool bench) {
extern void snblas_gemm_cluster_kernel_compute_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args, const SnblasGemmImpl impl);
inline void snblas_gemm_cluster_kernel_compute_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args, const SnblasGemmImpl impl) {
uint32_t p[3], P[3];
ocrt_thread_idx(p);
ocrt_compute_thread_num(P);
Expand Down Expand Up @@ -94,7 +94,7 @@ inline void snblas_gemm_cluster_kernel_compute_fp64(const SnblasGemmInfo info, c
snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, (void*) A);
snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, (void*) B);

if (bench) snrt_mcycle();
if (impl.bench) snrt_mcycle();
for (uint32_t m = 0; m < M; m++) {
uint32_t n = 0;
for (; n < N; n += unroll) {
Expand Down Expand Up @@ -156,16 +156,16 @@ inline void snblas_gemm_cluster_kernel_compute_fp64(const SnblasGemmInfo info, c
}

snrt_fpu_fence();
if (bench) snrt_mcycle();
if (impl.bench) snrt_mcycle();
}

/**
* \brief Perform a single GEMM computation on data that is already in TCDM.
* When running the computation multiple times, call `init`, `compute` and `deinit` directly to get maximum performance, so that setup and teardown are not repeated for every run.
*/
extern void snblas_gemm_cluster_kernel_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args);
inline void snblas_gemm_cluster_kernel_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args) {
snblas_gemm_cluster_kernel_init_fp64(info);
snblas_gemm_cluster_kernel_compute_fp64(info, args, false);
snblas_gemm_cluster_kernel_deinit_fp64(info);
extern void snblas_gemm_cluster_kernel_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args, const SnblasGemmImpl impl);
inline void snblas_gemm_cluster_kernel_fp64(const SnblasGemmInfo info, const SnblasGemmArgs_fp64 args, const SnblasGemmImpl impl) {
snblas_gemm_cluster_kernel_init_fp64(info, impl);
snblas_gemm_cluster_kernel_compute_fp64(info, args, impl);
snblas_gemm_cluster_kernel_deinit_fp64(info, impl);
}
8 changes: 4 additions & 4 deletions sw/blas/gemm/src/gemm_kernel_fp8.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#define FLOAT_T fp8
#include "gemm_kernel_init_tpl.h"

extern void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench);
inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench) {
extern void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl);
inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl) {
uint32_t p[3], P[3];
ocrt_thread_idx(p);
ocrt_compute_thread_num(P);
Expand Down Expand Up @@ -32,7 +32,7 @@ inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo inf
// Kernel progresses by 8 values each step
const uint32_t n_frep = K / 8 - 1;

if (bench) snrt_mcycle();
if (impl.bench) snrt_mcycle();
for (uint32_t m = 0; m < M; m++) {
uint32_t n = 0;
for (uint32_t n0 = 0; n0 < N / unroll; n0++) {
Expand Down Expand Up @@ -184,7 +184,7 @@ inline void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo inf
}

snrt_fpu_fence();
if (bench) snrt_mcycle();
if (impl.bench) snrt_mcycle();
}

#undef FLOAT_T
20 changes: 10 additions & 10 deletions sw/blas/gemm/src/gemm_kernel_init_tpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
#define SNBLAS_GEMM_CLUSTER_KERNEL(float_t) CONCAT(snblas_gemm_cluster_kernel_, float_t)
#endif

extern void SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(const SnblasGemmInfo info);
inline void SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(const SnblasGemmInfo info) {
extern void SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(const SnblasGemmInfo info, const SnblasGemmImpl impl);
inline void SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(const SnblasGemmInfo info, const SnblasGemmImpl impl) {
uint32_t p[3], P[3];
ocrt_thread_idx(p);
ocrt_compute_thread_num(P);
Expand Down Expand Up @@ -50,20 +50,20 @@ inline void SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(const SnblasGemmInfo info)
snrt_ssr_enable();
}

extern void SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(const SnblasGemmInfo info);
inline void SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(const SnblasGemmInfo info) {
extern void SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(const SnblasGemmInfo info, const SnblasGemmImpl impl);
inline void SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(const SnblasGemmInfo info, const SnblasGemmImpl impl) {
snrt_ssr_disable();
}

void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench);
void SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl);

/**
* \brief Perform a single GEMM computation on data that is already in TCDM.
* When running the computation multiple times, call `init`, `compute` and `deinit` directly to get maximum performance, so that setup and teardown are not repeated for every run.
*/
extern void SNBLAS_GEMM_CLUSTER_KERNEL(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench);
inline void SNBLAS_GEMM_CLUSTER_KERNEL(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench) {
SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(info);
SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(info, args, bench);
SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(info);
extern void SNBLAS_GEMM_CLUSTER_KERNEL(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl);
inline void SNBLAS_GEMM_CLUSTER_KERNEL(FLOAT_T)(const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl) {
SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(info, impl);
SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(info, args, impl);
SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(info, impl);
}
20 changes: 10 additions & 10 deletions sw/blas/gemm/src/gemm_tiling_1dpipe_tpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const bool bench) {
void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const SnblasGemmImpl impl) {

typedef SNBLAS_GEMM_TCDM(FLOAT_T) TcdmLayout;
typedef SnblasGemmInfo GemmInfo;
typedef SNBLAS_GEMM_ARGS(FLOAT_T) GemmArgs;

if (bench) snrt_mcycle();
if (impl.bench) snrt_mcycle();

const uint32_t M = info.M;
const uint32_t N = info.N;
Expand Down Expand Up @@ -85,15 +85,15 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
tileInfo.ta = false;
tileInfo.tb = false;

if (bench) snrt_mcycle();
if (impl.bench) snrt_mcycle();

if (!IS_DM_CORE) {
SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(tileInfo);
SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(tileInfo, impl);
// DMA core is one index ahead

asm volatile ("" ::: "memory");
snrt_global_barrier();
if (bench) snrt_mcycle();
if (impl.bench) snrt_mcycle();
}

FOR_EACH(ib, pi, M / L1_M, PI, ib_dir, ib_prev) {
Expand Down Expand Up @@ -139,7 +139,7 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
}

snrt_dma_wait_all();
if (bench) snrt_mcycle();
if (impl.bench) snrt_mcycle();
} else {
// solve block already in l1, parallelize inside each cluster

Expand All @@ -150,11 +150,11 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
tileArgs.alpha = alpha;
tileArgs.beta = beta;

SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(tileInfo, tileArgs, bench);
SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(tileInfo, tileArgs, impl);
}

snrt_global_barrier();
if (bench) snrt_mcycle();
if (impl.bench) snrt_mcycle();

if (IS_DM_CORE) {
if (storeC) {
Expand Down Expand Up @@ -183,8 +183,8 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
snrt_dma_wait_all();
// }
} else {
SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(tileInfo);
SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(tileInfo, impl);
}

if (bench) snrt_mcycle();
if (impl.bench) snrt_mcycle();
}
Loading

0 comments on commit 17ede54

Please sign in to comment.