Skip to content

Commit

Permalink
template gemm 2dpipe tiling
Browse files Browse the repository at this point in the history
  • Loading branch information
rogerbarton committed Jan 26, 2024
1 parent 63207c4 commit d1b52aa
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 273 deletions.
8 changes: 5 additions & 3 deletions sw/blas/gemm/src/gemm.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once

Check failure on line 1 in sw/blas/gemm/src/gemm.h

View workflow job for this annotation

GitHub Actions / Check License headers

FAILED: File does not start with comment

#include "gemm_kernel.h"
// #include "gemm_occamy_baseline.h"
// #include "gemm_occamy_1dpipe.h"
#include "gemm_occamy_2dpipe.h"
#ifdef OCCAMY
// #include "gemm_baseline.h"
// #include "gemm_1dpipe.h"
#include "gemm_2dpipe.h"
#endif
File renamed without changes.
17 changes: 17 additions & 0 deletions sw/blas/gemm/src/gemm_2dpipe.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once

Check failure on line 1 in sw/blas/gemm/src/gemm_2dpipe.h

View workflow job for this annotation

GitHub Actions / Check License headers

FAILED: File does not start with comment

// Instantiate gemm_2dpipe_tpl.h once for each FLOAT_T value below.
// The template header stops with an #error when FLOAT_T is not defined, so
// every inclusion is bracketed by its own #define / #undef pair.
#define FLOAT_T fp64
#include "gemm_2dpipe_tpl.h"
#undef FLOAT_T

#define FLOAT_T fp32
#include "gemm_2dpipe_tpl.h"
#undef FLOAT_T

#define FLOAT_T fp16
#include "gemm_2dpipe_tpl.h"
#undef FLOAT_T

#define FLOAT_T fp8
#include "gemm_2dpipe_tpl.h"
#undef FLOAT_T
30 changes: 30 additions & 0 deletions sw/blas/gemm/src/gemm_2dpipe_tpl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include "gemm_decls.h"

Check failure on line 1 in sw/blas/gemm/src/gemm_2dpipe_tpl.h

View workflow job for this annotation

GitHub Actions / Check License headers

FAILED: File does not start with comment

#ifndef FLOAT_T
#error "Define FLOAT_T to use this template."
#endif

// Instantiate template code
// USE_C2C_TILES stays defined (as true) across both inclusions below and is
// removed again afterwards.
#define USE_C2C_TILES true

// The tiling template is included twice, once per IS_DM_CORE value. It names
// its function via SNBLAS_GEMM_TILING(IS_DM_CORE, FLOAT_T), so each inclusion
// produces a distinct function for the current FLOAT_T.
#define IS_DM_CORE true
#include "gemm_tiling_2dpipe_tpl.h"
#undef IS_DM_CORE

#define IS_DM_CORE false
#include "gemm_tiling_2dpipe_tpl.h"
#undef IS_DM_CORE

#undef USE_C2C_TILES

#ifndef SNBLAS_GEMM
#define SNBLAS_GEMM(float_t) CONCAT(snblas_gemm_, float_t)
#endif

/**
 * \brief GEMM entry point for the current FLOAT_T; forwards to one of the two
 * tiling instantiations depending on the calling core.
 * \param info Forwarded unchanged to the tiling function.
 * \param args Forwarded unchanged to the tiling function.
 * \param bench Forwarded unchanged to the tiling function
 * (presumably enables benchmarking of this call -- confirm in the template).
 * \details Cores for which snrt_is_dm_core() is true run the
 * SNBLAS_GEMM_TILING(true, ...) variant, all others the
 * SNBLAS_GEMM_TILING(false, ...) variant.
 */
void SNBLAS_GEMM(FLOAT_T) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, bool bench) {
    // Non-DM cores are handled first so the DM path is the fall-through case.
    if (!snrt_is_dm_core()) {
        SNBLAS_GEMM_TILING(false, FLOAT_T)(info, args, bench);
        return;
    }

    SNBLAS_GEMM_TILING(true, FLOAT_T)(info, args, bench);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,54 +9,7 @@
#include "gemm.h"
#include "snrt.h"

#include "dump.h"
NAMED_DUMP(uint32_t, aIdx, 0x1a)
NAMED_DUMP(uint32_t, bIdx, 0x1b)
NAMED_DUMP(uint32_t, cIdx, 0x1c)
NAMED_DUMP(uint32_t, ib, 0x10)
NAMED_DUMP(uint32_t, jb, 0x11)
NAMED_DUMP(uint32_t, kb, 0x12)
NAMED_DUMP(double, a, 0xa)
NAMED_DUMP(double, b, 0xb)
NAMED_DUMP(double, c, 0xc)

/**
 * \brief Implements a reversing loop for an index range
 * \param i Loop index variable. Must be a plain identifier: it is also pasted
 * into the names of the helper constants i_end_floor, i_first and i_last.
 * \param begin Beginning of the range
 * \param end End of the range (exclusive)
 * \param stride Step between consecutive indices
 * \param dir Direction flag (modifiable lvalue). It is toggled on entry, so
 * successive uses alternate direction. After toggling, true means the loop
 * starts at begin and counts up; false means it starts at the last index and
 * counts down.
 * \param i_prev Set to the first index of the traversal; must be updated
 * manually at the end of the loop body.
 * \details i_end_floor will contain the exact end with the stride, s.t. the
 * reversed loop starts at the correct index.
 *
 * All arguments are parenthesised in the expansion so that expressions such
 * as `base + 1` can be passed safely. Arguments are still evaluated more than
 * once, so they must not have side effects. The macro expands to several
 * statements followed by a `for` header: it declares constants in the
 * enclosing scope and therefore cannot be used as the body of an unbraced
 * `if`, nor twice with the same index name in one scope.
 */
#define FOR_EACH(i, begin, end, stride, dir, i_prev)                         \
    (dir) = !(dir);                                                          \
    const int i##_end_floor =                                                \
        (((end) - (begin) + (stride) - 1) / (stride)) * (stride) - (stride) + \
        (begin);                                                             \
    const int i##_first = (dir) ? (begin) : i##_end_floor;                   \
    const int i##_last = (dir) ? i##_end_floor : (begin);                    \
    i = i##_first;                                                           \
    (i_prev) = i;                                                            \
    for (; (dir) ? i <= i##_last : i >= i##_last;                            \
         i = (dir) ? i + (stride) : i - (stride))

/*
 * Extents of the tiles held in the TCDM buffers below.
 * NOTE(review): the original comments mention 128 as an alternative value for
 * each extent -- confirm which value is intended before changing these.
 */
#define L1_M 8
#define L1_N 8
#define L1_K 8

/* Leading dimensions of the A, B and C tiles, tied to the extents above. */
#define L1_LDA L1_K
#define L1_LDB L1_N
#define L1_LDC L1_N

/**
 * \brief Maps the layout of the TCDM. May be double buffered.
 */
typedef struct {
    double A[L1_M * L1_K]; /* L1_M * L1_K elements */
    double B[L1_K * L1_N]; /* L1_K * L1_N elements */
    double C[L1_M * L1_N]; /* L1_M * L1_N elements */
} TcdmLayout;

NAMED_DUMP(TcdmLayout*, l1, 0x8)
#include "gemm_decls.h"

/**
* \brief Each cluster performs a GEMM for A, B, C inside each TCDM
Expand All @@ -71,16 +24,11 @@ void gemm_cluster_kernel(double alpha, double beta, uint32_t M, uint32_t N,
for (uint32_t i = p[0]; i < M; i += P[0]) {
for (uint32_t j = 0; j < N; j++) {
uint32_t cIdx = i * ldc + j; // C[i][j]
// dump_cIdx(cIdx);
// dump_c(C[cIdx]);
register double c0 = beta * C[cIdx];

for (uint32_t k = 0; k < K; k++) {
uint32_t aIdx = i * lda + k; // A[i][k]
uint32_t bIdx = k * ldb + j; // B[k][j]
// dump_aIdx(aIdx);
// dump_bIdx(bIdx);
// dump_a(A[aIdx]);
// dump_b(B[bIdx]);

c0 += A[aIdx] * B[bIdx];
}
Expand Down
212 changes: 0 additions & 212 deletions sw/blas/gemm/src/gemm_occamy_2dpipe.h

This file was deleted.

2 changes: 1 addition & 1 deletion sw/blas/gemm/src/gemm_tiling_2dpipe_tpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include "gemm_kernel.h"

#ifndef SNBLAS_GEMM_TILING
#define SNBLAS_GEMM_TILING(is_dm_core, float_t) CONCAT3(gemm_, is_dm_core, float_t)
#define SNBLAS_GEMM_TILING(is_dm_core, float_t) CONCAT3(snblas_gemm_, is_dm_core, float_t)
#endif

void SNBLAS_GEMM_TILING(IS_DM_CORE, FLOAT_T) (const SnblasGemmInfo info, const SNBLAS_GEMM_ARGS(FLOAT_T) args, const bool bench) {
Expand Down
6 changes: 3 additions & 3 deletions sw/blas/gemm/src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ int main() {
uint32_t ldb = N;
uint32_t ldc = N;

GemmInfo gemmInfo = {0};
SnblasGemmInfo gemmInfo = {0};
gemmInfo.M = M;
gemmInfo.N = N;
gemmInfo.K = K;
Expand All @@ -39,7 +39,7 @@ int main() {
gemmInfo.ta = TA;
gemmInfo.tb = TB;

GemmArgs gemmArgs = {0};
SNBLAS_GEMM_ARGS(DTYPE) gemmArgs = {0};
gemmArgs.A = a;
gemmArgs.B = b;
gemmArgs.C = c;
Expand All @@ -48,7 +48,7 @@ int main() {

for (volatile int i = iters; i > 0; --i) {
// if (i == 1) snrt_mcycle(); // start
gemm_oc(gemmInfo, gemmArgs, i == 1);
SNBLAS_GEMM(DTYPE)(gemmInfo, gemmArgs, i == 1);
// gemm_oc(data_dtype_size, data_expand, setup_ssr, data_TA, data_TB, data_M, data_N, data_K, 1,
// data_a, lda, data_b, ldb, data_BETA, data_c, ldc);
if (i == 1) snrt_mcycle(); // end
Expand Down

0 comments on commit d1b52aa

Please sign in to comment.