Skip to content

Commit

Permalink
Template method
Browse files Browse the repository at this point in the history
  • Loading branch information
rogerbarton committed Jan 30, 2024
1 parent 96b2101 commit 6cb31d3
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 64 deletions.
42 changes: 39 additions & 3 deletions sw/blas/gemm/src/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,43 @@

#include "gemm_kernel.h"
#ifdef OCCAMY
// Instantiates the GEMM variants by repeatedly including gemm_tpl.h.
// gemm_tpl.h requires METHOD and FLOAT_T to be defined before it is included;
// each section below defines them, includes the template, then undefines them.
// #include "gemm_baseline.h"
// #include "gemm_1dpipe.h"
// NOTE(review): gemm_2dpipe.h is listed as deleted in the same change set --
// confirm that this include still resolves.
#include "gemm_2dpipe.h"

// -- 2D Pipeline
#define METHOD 2dpipe
#define FLOAT_T fp64
#include "gemm_tpl.h"
#undef FLOAT_T

// Disabled instantiations for the remaining precisions.
// NOTE(review): gemm_2dpipe_tpl.h is listed as deleted in the same change
// set; presumably these should include "gemm_tpl.h" when re-enabled -- confirm.
// #define FLOAT_T fp32
// #include "gemm_2dpipe_tpl.h"
// #undef FLOAT_T

// #define FLOAT_T fp16
// #include "gemm_2dpipe_tpl.h"
// #undef FLOAT_T

// #define FLOAT_T fp8
// #include "gemm_2dpipe_tpl.h"
// #undef FLOAT_T
#undef METHOD

// -- Baseline
#define METHOD baseline
#define FLOAT_T fp64
#include "gemm_tpl.h"
#undef FLOAT_T

// Disabled instantiations for the remaining precisions.
// NOTE(review): verify that gemm_baseline_tpl.h exists before re-enabling;
// the fp64 instantiation above goes through "gemm_tpl.h" instead.
// #define FLOAT_T fp32
// #include "gemm_baseline_tpl.h"
// #undef FLOAT_T

// #define FLOAT_T fp16
// #include "gemm_baseline_tpl.h"
// #undef FLOAT_T

// #define FLOAT_T fp8
// #include "gemm_baseline_tpl.h"
// #undef FLOAT_T
#undef METHOD

#endif
17 changes: 0 additions & 17 deletions sw/blas/gemm/src/gemm_2dpipe.h

This file was deleted.

30 changes: 0 additions & 30 deletions sw/blas/gemm/src/gemm_2dpipe_tpl.h

This file was deleted.

8 changes: 6 additions & 2 deletions sw/blas/gemm/src/gemm_decls.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@ NAMED_DUMP(double, a, 0xa)
NAMED_DUMP(double, b, 0xb)
NAMED_DUMP(double, c, 0xc)

// Stringification and token-pasting helpers.
//
// Each helper is split into a public macro and an *_IMPL macro. The extra
// level makes the preprocessor fully expand the arguments (which may
// themselves be macros) before '#' or '##' is applied; applying the operators
// directly would stringify/paste the unexpanded macro name instead.
//
// STR(A)            -> "A" after expanding A
// CONCAT(a,b)       -> ab
// CONCAT3(a,b,c)    -> abc
// CONCAT4(a,b,c,d)  -> abcd
#define STR_IMPL(A) #A
#define STR(A) STR_IMPL(A)
#define CONCAT_IMPL(a,b) a ## b
#define CONCAT(a,b) CONCAT_IMPL(a,b)
#define CONCAT3(a,b,c) CONCAT(a,CONCAT(b,c))
#define CONCAT4(a,b,c,d) CONCAT3(a,b,CONCAT(c,d))

#ifndef PRECISION_T
#define PRECISION_T
Expand Down
20 changes: 9 additions & 11 deletions sw/blas/gemm/src/gemm_tiling_2dpipe_tpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,16 @@
#error "Define FLOAT_T to use this template."
#endif

#ifndef USE_C2C_TILES
#error "Define USE_C2C_TILES to use this template."
#endif

#ifndef IS_DM_CORE
#error "Define IS_DM_CORE to use this template."
#endif

#include "gemm_kernel.h"

#ifndef SNBLAS_GEMM_TILING
#define SNBLAS_GEMM_TILING(is_dm_core, float_t) CONCAT3(snblas_gemm_, is_dm_core, float_t)
#endif
void SNBLAS_GEMM_TILING(2dpipe, IS_DM_CORE, FLOAT_T) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const bool bench) {

#define USE_C2C_TILES true

void SNBLAS_GEMM_TILING(IS_DM_CORE, FLOAT_T) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const bool bench) {

/**
* Problem is double buffered in L1. The buffer that is used is toggled at
* each iteration. The DMA cores are one index step ahead so they load the
Expand Down Expand Up @@ -108,7 +102,7 @@ void SNBLAS_GEMM_TILING(IS_DM_CORE, FLOAT_T) (const SnblasGemmInfo info, const S

if (!IS_DM_CORE) {
SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(tileInfo);
snrt_global_barrier(); // DMA core is one index ahead
// DMA core is one index ahead
}

// Wait for pipeline to be filled
Expand Down Expand Up @@ -203,7 +197,11 @@ void SNBLAS_GEMM_TILING(IS_DM_CORE, FLOAT_T) (const SnblasGemmInfo info, const S
}

if (IS_DM_CORE) {
snrt_global_barrier(); // DMA core is one index ahead
// DMA core is one index ahead
if (USE_C2C_TILES)
snrt_global_barrier();
else
snrt_cluster_hw_barrier();

// store final tile
// if (ib_prev >= 0 && jb_prev >= 0) {
Expand Down
153 changes: 153 additions & 0 deletions sw/blas/gemm/src/gemm_tiling_baseline_tpl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#ifndef FLOAT_T
#error "Define FLOAT_T to use this template."
#endif

#ifndef IS_DM_CORE
#error "Define IS_DM_CORE to use this template."
#endif

#include "gemm_kernel.h"

/**
 * Baseline tiled GEMM for one core role. This template is instantiated once
 * per (IS_DM_CORE, FLOAT_T) combination.
 *
 * C is walked in L1_M x L1_N tiles and, for each C tile, K is walked in L1_K
 * steps. With IS_DM_CORE set the function issues the DMA loads/stores of the
 * tiles; otherwise it runs the cluster compute kernel on the tiles held in
 * L1. Both roles call snrt_cluster_hw_barrier() once per K step.
 *
 * @param info  Problem description (M, N, K and the leading dimensions).
 * @param args  Operands A, B, C and the scalars alpha and beta.
 * @param bench When true, snrt_mcycle() is called at several points of the
 *              function.
 */
void SNBLAS_GEMM_TILING(baseline, IS_DM_CORE, FLOAT_T) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const bool bench) {
    /**
     * Problem is double buffered in L1. The buffer that is used is toggled at
     * each iteration. The DMA cores are one index step ahead so they load the
     * data in advance into the buffer that will be used.
     */

    typedef SNBLAS_GEMM_TCDM(FLOAT_T) TcdmLayout;
    typedef SnblasGemmInfo GemmInfo;
    typedef SNBLAS_GEMM_ARGS(FLOAT_T) GemmArgs;

    if (bench) snrt_mcycle();

    const uint32_t M = info.M;
    const uint32_t N = info.N;
    const uint32_t K = info.K;
    const uint32_t lda = info.lda;
    const uint32_t ldb = info.ldb;
    const uint32_t ldc = info.ldc;

    const FLOAT_T* const A = args.A;
    const FLOAT_T* const B = args.B;
    FLOAT_T* const C = args.C;
    const FLOAT_T alpha = args.alpha;
    const FLOAT_T beta = args.beta;

    uint32_t p[3] = {0, 0, 0};
    uint32_t P[3] = {0, 0, 0};
    ocrt_thread_idx(p);
    ocrt_compute_thread_num(P);

    // Setup layout for TCDM L1
    // For double buffering l1 is a size 2 array
    TcdmLayout* l1 = snrt_l1_next();

    // Which buffer is the valid data in for computation
    bool l1Id_A = true;
    bool l1Id_B = true;
    bool l1Id_C = false;

    // Initialize indices: p[1] is split into a row (pi) and column (pj)
    // position on a fixed PI x PJ grid; tiles are distributed with that stride.
    // NOTE(review): PI/PJ are hard-coded to 2 -- presumably a 2x2 cluster
    // arrangement; confirm against the target configuration.
    const uint32_t PI = 2, PJ = 2;
    const uint32_t pi = p[1] / PJ;
    const uint32_t pj = p[1] % PJ;

    int ib, jb, kb;
    // Indices of the previously processed C tile; -1 means "none yet".
    int ib_prev = -1, jb_prev = -1;

    bool storeC = false;

    // Every kernel call works on one L1-sized, non-transposed tile.
    GemmInfo tileInfo = {0};
    tileInfo.M = L1_M;
    tileInfo.N = L1_N;
    tileInfo.K = L1_K;
    tileInfo.lda = L1_LDA;
    tileInfo.ldb = L1_LDB;
    tileInfo.ldc = L1_LDC;
    tileInfo.ta = false;
    tileInfo.tb = false;

    // TODO: place memory barrier before sync
    if (bench) snrt_mcycle();

    if (!IS_DM_CORE) {
        SNBLAS_GEMM_CLUSTER_KERNEL_INIT(FLOAT_T)(tileInfo);
        // Extra barrier on the compute side only; it is balanced by the extra
        // barrier the DM core executes after the loops.
        snrt_cluster_hw_barrier();  // DMA core is one index ahead
    }

    for (ib = pi; ib < M / L1_M; ib += PI) {
        for (jb = pj; jb < N / L1_N; jb += PJ) {
            FLOAT_T* const l1_C = l1[l1Id_C].C;

            if (IS_DM_CORE) {
                dump_ib(ib);
                dump_jb(jb);
                // NOTE(review): the DMA calls pass FP64 regardless of FLOAT_T;
                // confirm before instantiating other precisions.
                snrt_dma_load_2d_tile(l1_C, (void*) C, ib, jb, L1_M, L1_N, ldc, FP64);
                // A previous tile exists, so its result must be written back.
                if (ib_prev >= 0 && jb_prev >= 0) storeC = true;
            }

            for (kb = 0; kb < K / L1_K; kb++) {
                // Switch buffers when the indices have changed
                l1Id_A = !l1Id_A;
                l1Id_B = !l1Id_B;

                FLOAT_T* const l1_A = l1[l1Id_A].A;
                FLOAT_T* const l1_B = l1[l1Id_B].B;

                if (IS_DM_CORE) {
                    dump_kb(kb);
                    snrt_dma_load_2d_tile(l1_A, (void*) A, ib, kb, L1_M, L1_K, lda,
                                          FP64);
                    snrt_dma_load_2d_tile(l1_B, (void*) B, kb, jb, L1_K, L1_N, ldb,
                                          FP64);
                    snrt_dma_wait_all();
                } else {
                    GemmArgs tileArgs = {0};
                    tileArgs.A = l1_A;
                    tileArgs.B = l1_B;
                    tileArgs.C = l1_C;
                    tileArgs.alpha = alpha;
                    tileArgs.beta = beta;

                    SNBLAS_GEMM_CLUSTER_KERNEL_COMPUTE(FLOAT_T)(tileInfo, tileArgs, bench);
                }

                // if (bench) snrt_mcycle();
                snrt_cluster_hw_barrier();
                if (bench) snrt_mcycle();

                if (IS_DM_CORE) {
                    // Write back the previous C tile from the other C buffer,
                    // once per C tile (storeC is cleared after the first store).
                    if (storeC) {
                        storeC = false;
                        snrt_dma_store_2d_tile(C, l1[!l1Id_C].C, ib_prev,
                                               jb_prev, L1_M, L1_N, ldc, FP64);
                    }
                }
            }

            l1Id_C = !l1Id_C;
            jb_prev = jb;
            ib_prev = ib;
        }
    }

    if (IS_DM_CORE) {
        snrt_cluster_hw_barrier();  // DMA core is one index ahead

        // Store the final tile. Skip it when the loops never ran for this
        // core: ib_prev/jb_prev are then still -1 and there is no tile to
        // write back.
        if (ib_prev >= 0 && jb_prev >= 0) {
            snrt_dma_store_2d_tile(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N,
                                   ldc, FP64);
            snrt_dma_wait_all();
        }
    } else {
        SNBLAS_GEMM_CLUSTER_KERNEL_DEINIT(FLOAT_T)(tileInfo);
    }

    if (bench) snrt_mcycle();
}
37 changes: 37 additions & 0 deletions sw/blas/gemm/src/gemm_tpl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include "gemm_decls.h"

// This header is a template: the includer must define METHOD (the tiling
// variant) and FLOAT_T (the element type) before including it.
#ifndef METHOD
#error "Define METHOD to use this template."
#endif

#ifndef FLOAT_T
#error "Define FLOAT_T to use this template."
#endif

// Instantiate template code
// SNBLAS_GEMM_TILING builds the name of one tiling function by pasting
// "snblas_gemm_", the method, the core-role flag and the float type together.
#ifndef SNBLAS_GEMM_TILING
#define SNBLAS_GEMM_TILING(method, is_dm_core, float_t) CONCAT4(snblas_gemm_, method, is_dm_core, float_t)
#endif

// The tiling template's file name is derived from METHOD
// ("gemm_tiling_" METHOD "_tpl.h") and included twice, once per core role:
// first with IS_DM_CORE defined as true, then as false.
#define GEMM_TILING_TPL_H STR(CONCAT3(gemm_tiling_, METHOD, _tpl.h))
#define IS_DM_CORE true
#include GEMM_TILING_TPL_H
#undef IS_DM_CORE

#define IS_DM_CORE false
#include GEMM_TILING_TPL_H
#undef IS_DM_CORE


// SNBLAS_GEMM(method, float_t) names the entry point for one
// (method, float type) pair: snblas_gemm_<method><float_t>.
#ifndef SNBLAS_GEMM
#define SNBLAS_GEMM(method, float_t) CONCAT3(snblas_gemm_, method, float_t)
#endif

/**
 * GEMM entry point for the current METHOD / FLOAT_T instantiation.
 *
 * Forwards the call unchanged to the tiling function generated for the
 * calling core's role, as reported by snrt_is_dm_core().
 *
 * @param info  Problem description, passed through.
 * @param args  Operands and scalars, passed through.
 * @param bench Benchmark flag, passed through.
 */
extern void SNBLAS_GEMM(METHOD, FLOAT_T) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench);
inline void SNBLAS_GEMM(METHOD, FLOAT_T) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench) {
    // Compute cores take the IS_DM_CORE == false instantiation.
    if (!snrt_is_dm_core()) {
        SNBLAS_GEMM_TILING(METHOD, false, FLOAT_T)(info, args, bench);
        return;
    }

    // Remaining case: the DM core.
    SNBLAS_GEMM_TILING(METHOD, true, FLOAT_T)(info, args, bench);
}
2 changes: 1 addition & 1 deletion sw/blas/gemm/src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ int main() {

for (volatile int i = iters; i > 0; --i) {
// if (i == 1) snrt_mcycle(); // start
SNBLAS_GEMM(DTYPE)(gemmInfo, gemmArgs, i == 1);
SNBLAS_GEMM(2dpipe, DTYPE)(gemmInfo, gemmArgs, i == 1);
// dma_xfer_test(c, M*N, i == 1);

if (i == 1) snrt_mcycle(); // end
Expand Down

0 comments on commit 6cb31d3

Please sign in to comment.