Skip to content

Commit

Permalink
Add tile transpose, snrt_dma_*_2d_tile_transpose, fctptr for load/sto…
Browse files Browse the repository at this point in the history
…re tile
  • Loading branch information
rogerbarton committed Feb 6, 2024
1 parent 7c2bba9 commit c056eac
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 86 deletions.
104 changes: 66 additions & 38 deletions sw/blas/gemm/data/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,76 +49,104 @@ def golden_model(alpha, a, b, beta, c):


def emit_header(**kwargs):
gemmInfo = kwargs['gemmInfo']
gemmArgs = kwargs['gemmArgs']
gemmImpl = kwargs['gemmImpl']

# Generate random input matrices
dtype = NUMPY_TYPES[str(kwargs['prec'])]
if (kwargs['prec']) == 8:
dtype = NUMPY_TYPES[str(gemmInfo['prec'])]
if (gemmInfo['prec']) == 8:
# sign -1 or 1
sign_a = np.random.randint(0, 2, (kwargs['M'], kwargs['K'])).astype(dtype)
sign_a = np.random.randint(0, 2, (gemmInfo['M'], gemmInfo['K'])).astype(dtype)
# esponent < 0b01111
exponent_a = np.random.randint(0, 16, (kwargs['M'], kwargs['K'])).astype(dtype)
exponent_a = np.random.randint(0, 16, (gemmInfo['M'], gemmInfo['K'])).astype(dtype)
# mantissa can be arbitrary
mantissa_a = np.random.randint(0, 4, (kwargs['M'], kwargs['K'])).astype(dtype)
mantissa_a = np.random.randint(0, 4, (gemmInfo['M'], gemmInfo['K'])).astype(dtype)
# sign -1 or 1
sign_b = np.random.randint(0, 2, (kwargs['K'], kwargs['N'])).astype(dtype)
sign_b = np.random.randint(0, 2, (gemmInfo['K'], gemmInfo['N'])).astype(dtype)
# esponent < 0b01111
exponent_b = np.random.randint(0, 16, (kwargs['K'], kwargs['N'])).astype(dtype)
exponent_b = np.random.randint(0, 16, (gemmInfo['K'], gemmInfo['N'])).astype(dtype)
# mantissa can be arbitrary
mantissa_b = np.random.randint(0, 4, (kwargs['K'], kwargs['N'])).astype(dtype)
mantissa_b = np.random.randint(0, 4, (gemmInfo['K'], gemmInfo['N'])).astype(dtype)
# sign -1 or 1
sign_c = np.random.randint(0, 2, (kwargs['M'], kwargs['N'])).astype(dtype)
sign_c = np.random.randint(0, 2, (gemmInfo['M'], gemmInfo['N'])).astype(dtype)
# esponent < 0b01111
exponent_c = np.random.randint(0, 16, (kwargs['M'], kwargs['N'])).astype(dtype)
exponent_c = np.random.randint(0, 16, (gemmInfo['M'], gemmInfo['N'])).astype(dtype)
# mantissa can be arbitrary
mantissa_c = np.random.randint(0, 4, (kwargs['M'], kwargs['N'])).astype(dtype)
mantissa_c = np.random.randint(0, 4, (gemmInfo['M'], gemmInfo['N'])).astype(dtype)
_a = ((-1.0)**sign_a.astype(np.double))*(2.0**(exponent_a.astype(np.double)-15.0)) \
* (1.0 + mantissa_a.astype(np.double) / (2**2))
_b = ((-1.0)**sign_b.astype(np.double))*(2.0**(exponent_b.astype(np.double)-15.0)) \
* (1.0 + mantissa_b.astype(np.double) / (2**2))
_c = ((-1.0)**sign_c.astype(np.double))*(2.0**(exponent_c.astype(np.double)-15.0)) \
* (1.0 + mantissa_c.astype(np.double) / (2**2))
result = golden_model(1, _a, _b, kwargs['beta'], _c)
result = golden_model(1, _a, _b, gemmArgs['beta'], _c)
a = sign_a << 7 | exponent_a << FP8_FORMATS['fp8']['mant'] | mantissa_a
b = sign_b << 7 | exponent_b << FP8_FORMATS['fp8']['mant'] | mantissa_b
c = sign_c << 7 | exponent_c << FP8_FORMATS['fp8']['mant'] | mantissa_c
else:
if kwargs['linspace']:
a = np.linspace(0.1, kwargs['M'] * kwargs['K'] + 0.1 -1, num=kwargs['M'] * kwargs['K']).reshape((kwargs['M'], kwargs['K'])).astype(dtype)
b = np.linspace(0.2, kwargs['K'] * kwargs['N'] + 0.2 -1, num=kwargs['K'] * kwargs['N']).reshape((kwargs['K'], kwargs['N'])).astype(dtype)
c = np.linspace(0.3, kwargs['M'] * kwargs['N'] + 0.3 -1, num=kwargs['M'] * kwargs['N']).reshape((kwargs['M'], kwargs['N'])).astype(dtype)
if kwargs['datagen']['linspace']:
a = np.linspace(0.1, gemmInfo['M'] * gemmInfo['K'] + 0.1 -1, num=gemmInfo['M'] * gemmInfo['K']).reshape((gemmInfo['M'], gemmInfo['K'])).astype(dtype)
b = np.linspace(0.2, gemmInfo['K'] * gemmInfo['N'] + 0.2 -1, num=gemmInfo['K'] * gemmInfo['N']).reshape((gemmInfo['K'], gemmInfo['N'])).astype(dtype)
c = np.linspace(0.3, gemmInfo['M'] * gemmInfo['N'] + 0.3 -1, num=gemmInfo['M'] * gemmInfo['N']).reshape((gemmInfo['M'], gemmInfo['N'])).astype(dtype)
else:
a = np.random.rand(kwargs['M'], kwargs['K']).astype(dtype)
b = np.random.rand(kwargs['K'], kwargs['N']).astype(dtype)
c = np.random.rand(kwargs['M'], kwargs['N']).astype(dtype)
result = golden_model(1, a, b, kwargs['beta'], c)
a = np.random.rand(gemmInfo['M'], gemmInfo['K']).astype(dtype)
b = np.random.rand(gemmInfo['K'], gemmInfo['N']).astype(dtype)
c = np.random.rand(gemmInfo['M'], gemmInfo['N']).astype(dtype)
result = golden_model(gemmArgs['alpha'], a, b, gemmArgs['beta'], c)

# Store matrices in transposed form if requested
a = a.T if kwargs['ta'] else a
b = b.T if kwargs['tb'] else b
a = a.T if gemmInfo['ta'] else a
b = b.T if gemmInfo['tb'] else b
c = c.T if gemmInfo['tc'] else c
result = result.T if gemmInfo['tc'] else result

data_str = [emit_license()]
data_str = ["#pragma once"]
data_str += [format_scalar_definition('uint32_t', 'bench_iters', kwargs['bench_iters'])]
data_str += [format_scalar_definition('uint32_t', 'M', kwargs['M'])]
data_str += [format_scalar_definition('uint32_t', 'N', kwargs['N'])]
data_str += [format_scalar_definition('uint32_t', 'K', kwargs['K'])]
data_str += [format_scalar_definition('uint32_t', 'TA', int(kwargs['ta']))]
data_str += [format_scalar_definition('uint32_t', 'TB', int(kwargs['tb']))]
data_str += [format_scalar_definition('double', 'BETA', kwargs['beta'])]
# data_str += [format_scalar_definition('uint32_t', 'dtype_size', kwargs['prec']//8)]
data_str += [f"#define DTYPE fp{kwargs['prec']}"]
data_str += [f"#define METHOD {kwargs['method']}"]
data_str += [format_scalar_definition('uint32_t', 'expand', kwargs['expand'])]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'a', a.flatten(),

data_str += ["// -- gemmInfo"]
data_str += [f"#define DTYPE fp{gemmInfo['prec']}"]
# data_str += [format_scalar_definition('uint32_t', 'dtype_size', gemmInfo['prec']//8)]
data_str += [format_scalar_definition('uint32_t', 'M', gemmInfo['M'])]
data_str += [format_scalar_definition('uint32_t', 'N', gemmInfo['N'])]
data_str += [format_scalar_definition('uint32_t', 'K', gemmInfo['K'])]
data_str += [format_scalar_definition('uint32_t', 'TA', int(gemmInfo['ta']))]
data_str += [format_scalar_definition('uint32_t', 'TB', int(gemmInfo['tb']))]
data_str += [format_scalar_definition('uint32_t', 'TC', int(gemmInfo['tc']))]

# gemmArgs
data_str += ["// -- gemmArgs"]
data_str += [format_scalar_definition('double', 'ALPHA', gemmArgs['alpha'])]
data_str += [format_scalar_definition('double', 'BETA', gemmArgs['beta'])]

# gemmImpl
data_str += ["// -- gemmImpl"]
data_str += [f"#define USE_METHOD {gemmImpl['method']}"]
data_str += [f"#define L1_M {gemmImpl['L1_M']}"]
data_str += [f"#define L1_N {gemmImpl['L1_N']}"]
data_str += [f"#define L1_K {gemmImpl['L1_K']}"]
data_str += [format_scalar_definition('uint32_t', 'TA_TILE', int(gemmImpl['ta_tile']))]
data_str += [format_scalar_definition('uint32_t', 'TB_TILE', int(gemmImpl['tb_tile']))]
data_str += [format_scalar_definition('uint32_t', 'TC_TILE', int(gemmImpl['tc_tile']))]
data_str += [format_scalar_definition('uint32_t', 'expand', gemmImpl['expand'])]
data_str += [f"#define FMADD_D_UNROLL {gemmImpl['fmadd_d_unroll']}"]

# bench
data_str += ["// -- bench"]
data_str += [format_scalar_definition('uint32_t', 'bench_iters', kwargs['bench']['iters'])]

# datagen
data_str += ["// -- datagen"]
data_str += [format_vector_definition(C_TYPES[str(gemmInfo['prec'])], 'a', a.flatten(),
alignment=BURST_ALIGNMENT, section=kwargs['section'])]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'b', b.flatten(),
data_str += [format_vector_definition(C_TYPES[str(gemmInfo['prec'])], 'b', b.flatten(),
alignment=BURST_ALIGNMENT, section=kwargs['section'])]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'c', c.flatten(),
data_str += [format_vector_definition(C_TYPES[str(gemmInfo['prec'])], 'c', c.flatten(),
alignment=BURST_ALIGNMENT, section=kwargs['section'])]
if kwargs['prec'] == 8:
if gemmInfo['prec'] == 8:
result_def = format_vector_definition(C_TYPES['64'], 'result', result.flatten())
else:
result_def = format_vector_definition(C_TYPES[str(kwargs['prec'])],
result_def = format_vector_definition(C_TYPES[str(gemmInfo['prec'])],
'result',
result.flatten())
data_str += [format_ifdef_wrapper('BIST', result_def)]
Expand Down
7 changes: 6 additions & 1 deletion sw/blas/gemm/data/params.hjson
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
K: 16,
ta: false,
tb: true, // must be true for SIMD
tc: false, // not implemented
}

gemmArgs: { // C = alpha * A * B + beta * C
Expand All @@ -21,10 +22,14 @@

gemmImpl: {
method: "baseline",
L1_M: 8,
L1_N: 8,
L1_K: 8,
ta_tile: false,
tb_tile: false,
tc_tile: false,
tc_tile: false, // not implemented
expand: 0,
fmadd_d_unroll: 8,
}

bench: {
Expand Down
13 changes: 9 additions & 4 deletions sw/blas/gemm/src/gemm_decls.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ typedef struct {
uint32_t ldc;
uint32_t ta;
uint32_t tb;
uint32_t tc;
} SnblasGemmInfo;

/**
Expand All @@ -60,9 +61,9 @@ typedef struct {
bool tc_tile;
} SnblasGemmImpl;

#define L1_M 8
#define L1_N 8
#define L1_K 8
// #define L1_M 8 // Moved to datagen
// #define L1_N 8
// #define L1_K 8
#define L1_LDA L1_K
#define L1_LDB L1_N
#define L1_LDC L1_N
Expand Down Expand Up @@ -100,4 +101,8 @@ typedef struct {
const int i##_last = dir ? i##_end_floor : begin; \
i = i##_first; \
for (; dir ? i <= i##_last : i >= i##_last; \
i = dir ? i + stride : i - stride)
i = dir ? i + stride : i - stride)


// -- Function pointer typedefs
typedef snrt_dma_txid_t (*snrt_dma_load_2d_tile_transpose_t)(void *, void *, size_t, size_t, size_t, size_t, size_t, uint32_t);
5 changes: 2 additions & 3 deletions sw/blas/gemm/src/gemm_kernel_fp64.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "gemm_decls.h"

Check failure on line 1 in sw/blas/gemm/src/gemm_kernel_fp64.h

View workflow job for this annotation

GitHub Actions / Check License headers

FAILED: File does not start with comment

#define _FMADD_UNROLL 8
extern void snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info, const SnblasGemmImpl impl);
inline void snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info, const SnblasGemmImpl impl) {
uint32_t p[3], P[3];
Expand All @@ -19,7 +18,7 @@ inline void snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info, cons
// Unrolling factor of most inner loop.
// Should be at least as high as the FMA delay
// for maximum utilization
const uint32_t unroll = _FMADD_UNROLL;
const uint32_t unroll = FMADD_D_UNROLL;

// SSR strides and bounds only have to be configured
// once in the beginning
Expand Down Expand Up @@ -88,7 +87,7 @@ inline void snblas_gemm_cluster_kernel_compute_fp64(const SnblasGemmInfo info, c

// Unroll by at least fmadd.d latency to fill pipeline
// Additional unrolling reduces indexing overhead but needs available registers
const uint32_t unroll = _FMADD_UNROLL;
const uint32_t unroll = FMADD_D_UNROLL;

// SSR start address need to be configured each time
snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, (void*) A);
Expand Down
23 changes: 14 additions & 9 deletions sw/blas/gemm/src/gemm_tiling_1dpipe_tpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,15 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
tileInfo.lda = L1_LDA;
tileInfo.ldb = L1_LDB;
tileInfo.ldc = L1_LDC;
tileInfo.ta = false;
tileInfo.tb = false;
tileInfo.ta = info.ta ^ impl.ta_tile;
tileInfo.tb = info.tb ^ impl.tb_tile;
tileInfo.tc = info.tc ^ impl.tc_tile; // TODO: implement transposed blocking

// create function ptr for dma loading
const snrt_dma_load_2d_tile_transpose_t load_tile_A = impl.ta_tile ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
const snrt_dma_load_2d_tile_transpose_t load_tile_B = impl.tb_tile ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
const snrt_dma_load_2d_tile_transpose_t load_tile_C = impl.tc_tile ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
const snrt_dma_load_2d_tile_transpose_t store_tile_C = impl.tc_tile ? &snrt_dma_store_2d_tile_transpose : &snrt_dma_store_2d_tile;

if (impl.bench) snrt_mcycle();

Expand All @@ -103,7 +110,7 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
if (IS_DM_CORE) {
dump_ib(ib);
dump_jb(jb);
snrt_dma_load_2d_tile(l1_C, (void*) C, ib, jb, L1_M, L1_N, ldc, FP64);
(*load_tile_C)(l1_C, (void*) C, ib, jb, L1_M, L1_N, ldc, FP64);
if (ib_prev >= 0 /* && jb_prev >= 0 */) storeC = true;
}

Expand All @@ -123,12 +130,12 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
if (IS_DM_CORE) {
dump_kb(kb);
if (loadA) {
snrt_dma_load_2d_tile(l1_A, (void*) A, ib, kb, L1_M, L1_K, lda, FP64);
(*load_tile_A)(l1_A, (void*) A, ib, kb, L1_M, L1_K, lda, FP64);
// FLOAT_T* const c2c_A = c2cL1_A[l1Id_A].A;
// snrt_dma_start_1d(l1_A, c2c_A, L1_M * L1_K * FP64);
}
if (loadB) {
snrt_dma_load_2d_tile(l1_B, (void*) B, kb, jb, L1_K, L1_N, ldb, FP64);
(*load_tile_B)(l1_B, (void*) B, kb, jb, L1_K, L1_N, ldb, FP64);
if (p[1] == 0) {
// immediately broadcast to other clusters
for (int pt = 1; pt < P[1]; ++pt) {
Expand Down Expand Up @@ -159,8 +166,7 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
if (IS_DM_CORE) {
if (storeC) {
storeC = false;
snrt_dma_store_2d_tile(C, l1[!l1Id_C].C, ib_prev,
jb_prev, L1_M, L1_N, ldc, FP64);
(*store_tile_C)(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N, ldc, FP64);
}
}
kb_prev = kb;
Expand All @@ -178,8 +184,7 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,

// store final tile
// if (ib_prev >= 0 && jb_prev >= 0) {
snrt_dma_store_2d_tile(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N,
ldc, FP64);
(*store_tile_C)(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N, ldc, FP64);
snrt_dma_wait_all();
// }
} else {
Expand Down
25 changes: 14 additions & 11 deletions sw/blas/gemm/src/gemm_tiling_2dpipe_tpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,15 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
tileInfo.lda = L1_LDA;
tileInfo.ldb = L1_LDB;
tileInfo.ldc = L1_LDC;
tileInfo.ta = false;
tileInfo.tb = false;
tileInfo.ta = info.ta ^ impl.ta_tile;
tileInfo.tb = info.tb ^ impl.tb_tile;
tileInfo.tc = info.tc ^ impl.tc_tile; // TODO: implement transposed blocking

// create function ptr for dma loading
const snrt_dma_load_2d_tile_transpose_t load_tile_A = impl.ta_tile ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
const snrt_dma_load_2d_tile_transpose_t load_tile_B = impl.tb_tile ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
const snrt_dma_load_2d_tile_transpose_t load_tile_C = impl.tc_tile ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
const snrt_dma_load_2d_tile_transpose_t store_tile_C = impl.tc_tile ? &snrt_dma_store_2d_tile_transpose : &snrt_dma_store_2d_tile;

if (impl.bench) snrt_mcycle();

Expand Down Expand Up @@ -130,7 +137,7 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
if (IS_DM_CORE) {
dump_ib(ib);
dump_jb(jb);
snrt_dma_load_2d_tile(l1_C, (void*) C, ib, jb, L1_M, L1_N, ldc, FP64);
(*load_tile_C)(l1_C, (void*) C, ib, jb, L1_M, L1_N, ldc, FP64);
if (ib_prev >= 0 /* && jb_prev >= 0 */) storeC = true;
}

Expand All @@ -151,17 +158,15 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
dump_kb(kb);
if (loadA) {
if (c2cL1_A == NULL)
snrt_dma_load_2d_tile(l1_A, (void*) A, ib, kb, L1_M, L1_K, lda,
FP64);
(*load_tile_A)(l1_A, (void*) A, ib, kb, L1_M, L1_K, lda, FP64);
else {
FLOAT_T* const c2c_A = c2cL1_A[l1Id_A].A;
snrt_dma_start_1d(l1_A, c2c_A, L1_M * L1_K * FP64);
}
}
if (loadB) {
if (c2cL1_B == NULL)
snrt_dma_load_2d_tile(l1_B, (void*) B, kb, jb, L1_K, L1_N, ldb,
FP64);
(*load_tile_B)(l1_B, (void*) B, kb, jb, L1_K, L1_N, ldb, FP64);
else {
FLOAT_T* const c2c_B = c2cL1_B[l1Id_B].B;
snrt_dma_start_1d(l1_B, c2c_B, L1_K * L1_N * FP64);
Expand Down Expand Up @@ -198,8 +203,7 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
if (IS_DM_CORE) {
if (storeC) {
storeC = false;
snrt_dma_store_2d_tile(C, l1[!l1Id_C].C, ib_prev,
jb_prev, L1_M, L1_N, ldc, FP64);
(*store_tile_C)(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N, ldc, FP64);
}
}
kb_prev = kb;
Expand All @@ -220,8 +224,7 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,

// store final tile
// if (ib_prev >= 0 && jb_prev >= 0) {
snrt_dma_store_2d_tile(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N,
ldc, FP64);
(*store_tile_C)(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N, ldc, FP64);
snrt_dma_wait_all();
// }
} else {
Expand Down
Loading

0 comments on commit c056eac

Please sign in to comment.