diff --git a/sw/blas/gemm/data/datagen.py b/sw/blas/gemm/data/datagen.py
index 449fba8d3..d19e372a1 100755
--- a/sw/blas/gemm/data/datagen.py
+++ b/sw/blas/gemm/data/datagen.py
@@ -49,76 +49,104 @@ def golden_model(alpha, a, b, beta, c):
 
 
 def emit_header(**kwargs):
+    gemmInfo = kwargs['gemmInfo']
+    gemmArgs = kwargs['gemmArgs']
+    gemmImpl = kwargs['gemmImpl']
     # Generate random input matrices
-    dtype = NUMPY_TYPES[str(kwargs['prec'])]
-    if (kwargs['prec']) == 8:
+    dtype = NUMPY_TYPES[str(gemmInfo['prec'])]
+    if (gemmInfo['prec']) == 8:
         # sign -1 or 1
-        sign_a = np.random.randint(0, 2, (kwargs['M'], kwargs['K'])).astype(dtype)
+        sign_a = np.random.randint(0, 2, (gemmInfo['M'], gemmInfo['K'])).astype(dtype)
         # esponent < 0b01111
-        exponent_a = np.random.randint(0, 16, (kwargs['M'], kwargs['K'])).astype(dtype)
+        exponent_a = np.random.randint(0, 16, (gemmInfo['M'], gemmInfo['K'])).astype(dtype)
         # mantissa can be arbitrary
-        mantissa_a = np.random.randint(0, 4, (kwargs['M'], kwargs['K'])).astype(dtype)
+        mantissa_a = np.random.randint(0, 4, (gemmInfo['M'], gemmInfo['K'])).astype(dtype)
         # sign -1 or 1
-        sign_b = np.random.randint(0, 2, (kwargs['K'], kwargs['N'])).astype(dtype)
+        sign_b = np.random.randint(0, 2, (gemmInfo['K'], gemmInfo['N'])).astype(dtype)
         # esponent < 0b01111
-        exponent_b = np.random.randint(0, 16, (kwargs['K'], kwargs['N'])).astype(dtype)
+        exponent_b = np.random.randint(0, 16, (gemmInfo['K'], gemmInfo['N'])).astype(dtype)
         # mantissa can be arbitrary
-        mantissa_b = np.random.randint(0, 4, (kwargs['K'], kwargs['N'])).astype(dtype)
+        mantissa_b = np.random.randint(0, 4, (gemmInfo['K'], gemmInfo['N'])).astype(dtype)
         # sign -1 or 1
-        sign_c = np.random.randint(0, 2, (kwargs['M'], kwargs['N'])).astype(dtype)
+        sign_c = np.random.randint(0, 2, (gemmInfo['M'], gemmInfo['N'])).astype(dtype)
         # esponent < 0b01111
-        exponent_c = np.random.randint(0, 16, (kwargs['M'], kwargs['N'])).astype(dtype)
+        exponent_c = np.random.randint(0, 16, (gemmInfo['M'], gemmInfo['N'])).astype(dtype)
         # mantissa can be arbitrary
-        mantissa_c = np.random.randint(0, 4, (kwargs['M'], kwargs['N'])).astype(dtype)
+        mantissa_c = np.random.randint(0, 4, (gemmInfo['M'], gemmInfo['N'])).astype(dtype)
         _a = ((-1.0)**sign_a.astype(np.double))*(2.0**(exponent_a.astype(np.double)-15.0)) \
             * (1.0 + mantissa_a.astype(np.double) / (2**2))
         _b = ((-1.0)**sign_b.astype(np.double))*(2.0**(exponent_b.astype(np.double)-15.0)) \
             * (1.0 + mantissa_b.astype(np.double) / (2**2))
         _c = ((-1.0)**sign_c.astype(np.double))*(2.0**(exponent_c.astype(np.double)-15.0)) \
             * (1.0 + mantissa_c.astype(np.double) / (2**2))
-        result = golden_model(1, _a, _b, kwargs['beta'], _c)
+        result = golden_model(1, _a, _b, gemmArgs['beta'], _c)
         a = sign_a << 7 | exponent_a << FP8_FORMATS['fp8']['mant'] | mantissa_a
         b = sign_b << 7 | exponent_b << FP8_FORMATS['fp8']['mant'] | mantissa_b
         c = sign_c << 7 | exponent_c << FP8_FORMATS['fp8']['mant'] | mantissa_c
     else:
-        if kwargs['linspace']:
-            a = np.linspace(0.1, kwargs['M'] * kwargs['K'] + 0.1 -1, num=kwargs['M'] * kwargs['K']).reshape((kwargs['M'], kwargs['K'])).astype(dtype)
-            b = np.linspace(0.2, kwargs['K'] * kwargs['N'] + 0.2 -1, num=kwargs['K'] * kwargs['N']).reshape((kwargs['K'], kwargs['N'])).astype(dtype)
-            c = np.linspace(0.3, kwargs['M'] * kwargs['N'] + 0.3 -1, num=kwargs['M'] * kwargs['N']).reshape((kwargs['M'], kwargs['N'])).astype(dtype)
+        if kwargs['datagen']['linspace']:
+            a = np.linspace(0.1, gemmInfo['M'] * gemmInfo['K'] + 0.1 -1, num=gemmInfo['M'] * gemmInfo['K']).reshape((gemmInfo['M'], gemmInfo['K'])).astype(dtype)
+            b = np.linspace(0.2, gemmInfo['K'] * gemmInfo['N'] + 0.2 -1, num=gemmInfo['K'] * gemmInfo['N']).reshape((gemmInfo['K'], gemmInfo['N'])).astype(dtype)
+            c = np.linspace(0.3, gemmInfo['M'] * gemmInfo['N'] + 0.3 -1, num=gemmInfo['M'] * gemmInfo['N']).reshape((gemmInfo['M'], gemmInfo['N'])).astype(dtype)
         else:
-            a = np.random.rand(kwargs['M'], kwargs['K']).astype(dtype)
-            b = np.random.rand(kwargs['K'], kwargs['N']).astype(dtype)
-            c = np.random.rand(kwargs['M'], kwargs['N']).astype(dtype)
-        result = golden_model(1, a, b, kwargs['beta'], c)
+            a = np.random.rand(gemmInfo['M'], gemmInfo['K']).astype(dtype)
+            b = np.random.rand(gemmInfo['K'], gemmInfo['N']).astype(dtype)
+            c = np.random.rand(gemmInfo['M'], gemmInfo['N']).astype(dtype)
+        result = golden_model(gemmArgs['alpha'], a, b, gemmArgs['beta'], c)
 
     # Store matrices in transposed form if requested
-    a = a.T if kwargs['ta'] else a
-    b = b.T if kwargs['tb'] else b
+    a = a.T if gemmInfo['ta'] else a
+    b = b.T if gemmInfo['tb'] else b
+    c = c.T if gemmInfo['tc'] else c
+    result = result.T if gemmInfo['tc'] else result
 
     data_str = [emit_license()]
     data_str = ["#pragma once"]
-    data_str += [format_scalar_definition('uint32_t', 'bench_iters', kwargs['bench_iters'])]
-    data_str += [format_scalar_definition('uint32_t', 'M', kwargs['M'])]
-    data_str += [format_scalar_definition('uint32_t', 'N', kwargs['N'])]
-    data_str += [format_scalar_definition('uint32_t', 'K', kwargs['K'])]
-    data_str += [format_scalar_definition('uint32_t', 'TA', int(kwargs['ta']))]
-    data_str += [format_scalar_definition('uint32_t', 'TB', int(kwargs['tb']))]
-    data_str += [format_scalar_definition('double', 'BETA', kwargs['beta'])]
-    # data_str += [format_scalar_definition('uint32_t', 'dtype_size', kwargs['prec']//8)]
-    data_str += [f"#define DTYPE fp{kwargs['prec']}"]
-    data_str += [f"#define METHOD {kwargs['method']}"]
-    data_str += [format_scalar_definition('uint32_t', 'expand', kwargs['expand'])]
-    data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'a', a.flatten(),
+
+    data_str += ["// -- gemmInfo"]
+    data_str += [f"#define DTYPE fp{gemmInfo['prec']}"]
+    # data_str += [format_scalar_definition('uint32_t', 'dtype_size', gemmInfo['prec']//8)]
+    data_str += [format_scalar_definition('uint32_t', 'M', gemmInfo['M'])]
+    data_str += [format_scalar_definition('uint32_t', 'N', gemmInfo['N'])]
+    data_str += [format_scalar_definition('uint32_t', 'K', gemmInfo['K'])]
+    data_str += [format_scalar_definition('uint32_t', 'TA', int(gemmInfo['ta']))]
+    data_str += [format_scalar_definition('uint32_t', 'TB', int(gemmInfo['tb']))]
+    data_str += [format_scalar_definition('uint32_t', 'TC', int(gemmInfo['tc']))]
+
+    # gemmArgs
+    data_str += ["// -- gemmArgs"]
+    data_str += [format_scalar_definition('double', 'ALPHA', gemmArgs['alpha'])]
+    data_str += [format_scalar_definition('double', 'BETA', gemmArgs['beta'])]
+
+    # gemmImpl
+    data_str += ["// -- gemmImpl"]
+    data_str += [f"#define USE_METHOD {gemmImpl['method']}"]
+    data_str += [f"#define L1_M {gemmImpl['L1_M']}"]
+    data_str += [f"#define L1_N {gemmImpl['L1_N']}"]
+    data_str += [f"#define L1_K {gemmImpl['L1_K']}"]
+    data_str += [format_scalar_definition('uint32_t', 'TA_TILE', int(gemmImpl['ta_tile']))]
+    data_str += [format_scalar_definition('uint32_t', 'TB_TILE', int(gemmImpl['tb_tile']))]
+    data_str += [format_scalar_definition('uint32_t', 'TC_TILE', int(gemmImpl['tc_tile']))]
+    data_str += [format_scalar_definition('uint32_t', 'expand', gemmImpl['expand'])]
+    data_str += [f"#define FMADD_D_UNROLL {gemmImpl['fmadd_d_unroll']}"]
+
+    # bench
+    data_str += ["// -- bench"]
+    data_str += [format_scalar_definition('uint32_t', 'bench_iters', kwargs['bench']['iters'])]
+
+    # datagen
+    data_str += ["// -- datagen"]
+    data_str += [format_vector_definition(C_TYPES[str(gemmInfo['prec'])], 'a', a.flatten(),
                                           alignment=BURST_ALIGNMENT, section=kwargs['section'])]
-    data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'b', b.flatten(),
+    data_str += [format_vector_definition(C_TYPES[str(gemmInfo['prec'])], 'b', b.flatten(),
                                           alignment=BURST_ALIGNMENT, section=kwargs['section'])]
-    data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'c', c.flatten(),
+    data_str += [format_vector_definition(C_TYPES[str(gemmInfo['prec'])], 'c', c.flatten(),
                                           alignment=BURST_ALIGNMENT, section=kwargs['section'])]
-    if kwargs['prec'] == 8:
+    if gemmInfo['prec'] == 8:
         result_def = format_vector_definition(C_TYPES['64'], 'result', result.flatten())
     else:
-        result_def = format_vector_definition(C_TYPES[str(kwargs['prec'])],
+        result_def = format_vector_definition(C_TYPES[str(gemmInfo['prec'])],
                                               'result', result.flatten())
     data_str += [format_ifdef_wrapper('BIST', result_def)]
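For reference, the FP8 branch above packs sign, exponent, and mantissa fields by hand and checks the kernel output against a double-precision decode. A minimal C sketch of that decode, assuming the 2-bit mantissa and bias of 15 implied by the generator's arithmetic (fp8_decode is a hypothetical helper, not part of the runtime):

    #include <math.h>
    #include <stdint.h>

    // Mirror of datagen.py: (-1)^sign * 2^(exponent - 15) * (1 + mantissa / 4).
    // Assumes FP8_FORMATS['fp8']['mant'] == 2, i.e. the exponent sits at bits [5:2].
    static double fp8_decode(uint8_t x) {
        int sign     = (x >> 7) & 0x1;
        int exponent = (x >> 2) & 0xF;
        int mantissa = x & 0x3;
        return (sign ? -1.0 : 1.0) * ldexp(1.0 + mantissa / 4.0, exponent - 15);
    }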
diff --git a/sw/blas/gemm/data/params.hjson b/sw/blas/gemm/data/params.hjson
index b58eecb38..46b725ded 100644
--- a/sw/blas/gemm/data/params.hjson
+++ b/sw/blas/gemm/data/params.hjson
@@ -12,6 +12,7 @@
         K: 16,
         ta: false,
         tb: true, // must be true for SIMD
+        tc: false, // not implemented
     }
     gemmArgs: {
        // C = alpha * A * B + beta * C
@@ -21,10 +22,14 @@
 
    gemmImpl: {
        method: "baseline",
+       L1_M: 8,
+       L1_N: 8,
+       L1_K: 8,
        ta_tile: false,
        tb_tile: false,
-       tc_tile: false,
+       tc_tile: false, // not implemented
        expand: 0,
+       fmadd_d_unroll: 8,
    }
 
    bench: {
diff --git a/sw/blas/gemm/src/gemm_decls.h b/sw/blas/gemm/src/gemm_decls.h
index 783a34e32..982f16f4c 100644
--- a/sw/blas/gemm/src/gemm_decls.h
+++ b/sw/blas/gemm/src/gemm_decls.h
@@ -47,6 +47,7 @@ typedef struct {
     uint32_t ldc;
     uint32_t ta;
     uint32_t tb;
+    uint32_t tc;
 } SnblasGemmInfo;
 
 /**
@@ -60,9 +61,9 @@ typedef struct {
     bool tc_tile;
 } SnblasGemmImpl;
 
-#define L1_M 8
-#define L1_N 8
-#define L1_K 8
+// #define L1_M 8 // Moved to datagen
+// #define L1_N 8
+// #define L1_K 8
 #define L1_LDA L1_K
 #define L1_LDB L1_N
 #define L1_LDC L1_N
@@ -100,4 +101,8 @@ typedef struct {
     const int i##_last = dir ? i##_end_floor : begin; \
     i = i##_first; \
     for (; dir ? i <= i##_last : i >= i##_last; \
-        i = dir ? i + stride : i - stride)
\ No newline at end of file
+        i = dir ? i + stride : i - stride)
+
+
+// -- Function pointer typedefs
+typedef snrt_dma_txid_t (*snrt_dma_load_2d_tile_transpose_t)(void *, void *, size_t, size_t, size_t, size_t, size_t, uint32_t);
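With L1_M/L1_N/L1_K moved out of gemm_decls.h, the generated data.h now carries the tile geometry alongside the problem size. Roughly, its preamble would look as follows for the params.hjson above in an fp64 run (abridged; assumes format_scalar_definition emits plain `uint32_t name = value;` lines, and omits the gemmArgs/bench values not shown in this diff):

    // -- gemmInfo
    #define DTYPE fp64
    uint32_t M = 16;
    uint32_t N = 16;
    uint32_t K = 16;
    uint32_t TA = 0;
    uint32_t TB = 1;
    uint32_t TC = 0;

    // -- gemmImpl
    #define USE_METHOD baseline
    #define L1_M 8
    #define L1_N 8
    #define L1_K 8
    uint32_t TA_TILE = 0;
    uint32_t TB_TILE = 0;
    uint32_t TC_TILE = 0;
    uint32_t expand = 0;
    #define FMADD_D_UNROLL 8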
diff --git a/sw/blas/gemm/src/gemm_kernel_fp64.h b/sw/blas/gemm/src/gemm_kernel_fp64.h
index 590a38ad2..9e1efdc8f 100644
--- a/sw/blas/gemm/src/gemm_kernel_fp64.h
+++ b/sw/blas/gemm/src/gemm_kernel_fp64.h
@@ -1,6 +1,5 @@
 #include "gemm_decls.h"
 
-#define _FMADD_UNROLL 8
 extern void snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info, const SnblasGemmImpl impl);
 inline void snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info, const SnblasGemmImpl impl) {
     uint32_t p[3], P[3];
@@ -19,7 +18,7 @@ inline void snblas_gemm_cluster_kernel_init_fp64(const SnblasGemmInfo info, cons
     // Unrolling factor of most inner loop.
     // Should be at least as high as the FMA delay
     // for maximum utilization
-    const uint32_t unroll = _FMADD_UNROLL;
+    const uint32_t unroll = FMADD_D_UNROLL;
 
     // SSR strides and bounds only have to be configured
     // once in the beginning
@@ -88,7 +87,7 @@ inline void snblas_gemm_cluster_kernel_compute_fp64(const SnblasGemmInfo info, c
 
     // Unroll by at least fmadd.d latency to fill pipeline
     // Additional unrolling reduces indexing overhead but needs available registers
-    const uint32_t unroll = _FMADD_UNROLL;
+    const uint32_t unroll = FMADD_D_UNROLL;
 
     // SSR start address need to be configured each time
     snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, (void*) A);
diff --git a/sw/blas/gemm/src/gemm_tiling_1dpipe_tpl.h b/sw/blas/gemm/src/gemm_tiling_1dpipe_tpl.h
index b1410d311..5c02fe951 100644
--- a/sw/blas/gemm/src/gemm_tiling_1dpipe_tpl.h
+++ b/sw/blas/gemm/src/gemm_tiling_1dpipe_tpl.h
@@ -82,8 +82,15 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
     tileInfo.lda = L1_LDA;
     tileInfo.ldb = L1_LDB;
     tileInfo.ldc = L1_LDC;
-    tileInfo.ta = false;
-    tileInfo.tb = false;
+    tileInfo.ta = info.ta ^ impl.ta_tile;
+    tileInfo.tb = info.tb ^ impl.tb_tile;
+    tileInfo.tc = info.tc ^ impl.tc_tile; // TODO: implement transposed blocking
+
+    // create function ptr for dma loading
+    const snrt_dma_load_2d_tile_transpose_t load_tile_A = impl.ta_tile ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
+    const snrt_dma_load_2d_tile_transpose_t load_tile_B = impl.tb_tile ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
+    const snrt_dma_load_2d_tile_transpose_t load_tile_C = impl.tc_tile ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
+    const snrt_dma_load_2d_tile_transpose_t store_tile_C = impl.tc_tile ? &snrt_dma_store_2d_tile_transpose : &snrt_dma_store_2d_tile;
 
     if (impl.bench) snrt_mcycle();
 
@@ -103,7 +110,7 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
         if (IS_DM_CORE) {
             dump_ib(ib);
             dump_jb(jb);
-            snrt_dma_load_2d_tile(l1_C, (void*) C, ib, jb, L1_M, L1_N, ldc, FP64);
+            (*load_tile_C)(l1_C, (void*) C, ib, jb, L1_M, L1_N, ldc, FP64);
             if (ib_prev >= 0 /* && jb_prev >= 0 */) storeC = true;
         }
 
@@ -123,12 +130,12 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
             if (IS_DM_CORE) {
                 dump_kb(kb);
                 if (loadA) {
-                    snrt_dma_load_2d_tile(l1_A, (void*) A, ib, kb, L1_M, L1_K, lda, FP64);
+                    (*load_tile_A)(l1_A, (void*) A, ib, kb, L1_M, L1_K, lda, FP64);
                     // FLOAT_T* const c2c_A = c2cL1_A[l1Id_A].A;
                     // snrt_dma_start_1d(l1_A, c2c_A, L1_M * L1_K * FP64);
                 }
                 if (loadB) {
-                    snrt_dma_load_2d_tile(l1_B, (void*) B, kb, jb, L1_K, L1_N, ldb, FP64);
+                    (*load_tile_B)(l1_B, (void*) B, kb, jb, L1_K, L1_N, ldb, FP64);
                     if (p[1] == 0) {
                         // immediately broadcast to other clusters
                         for (int pt = 1; pt < P[1]; ++pt) {
@@ -159,8 +166,7 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
             if (IS_DM_CORE) {
                 if (storeC) {
                     storeC = false;
-                    snrt_dma_store_2d_tile(C, l1[!l1Id_C].C, ib_prev,
-                                           jb_prev, L1_M, L1_N, ldc, FP64);
+                    (*store_tile_C)(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N, ldc, FP64);
                 }
             }
             kb_prev = kb;
@@ -178,8 +184,7 @@ void SNBLAS_GEMM_TILING(1dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
 
         // store final tile
         // if (ib_prev >= 0 && jb_prev >= 0) {
-        snrt_dma_store_2d_tile(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N,
-                               ldc, FP64);
+        (*store_tile_C)(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N, ldc, FP64);
         snrt_dma_wait_all();
         // }
     } else {
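The XOR composition `tileInfo.ta = info.ta ^ impl.ta_tile` captures that a transpose applied during the DMA copy flips the layout the compute kernel sees, so the tile-level flag toggles the operand-level one:

    info.ta  impl.ta_tile  tileInfo.ta  L1 tile as seen by the kernel
    0        0             0            copied as-is, row-major
    1        0             1            copied as-is, kernel strides around the transpose
    1        1             0            DMA un-transposes the tile; kernel reads row-major
    0        1             1            DMA transposes the tile; kernel must compensate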
diff --git a/sw/blas/gemm/src/gemm_tiling_2dpipe_tpl.h b/sw/blas/gemm/src/gemm_tiling_2dpipe_tpl.h
index ad75affaf..6b4540f20 100644
--- a/sw/blas/gemm/src/gemm_tiling_2dpipe_tpl.h
+++ b/sw/blas/gemm/src/gemm_tiling_2dpipe_tpl.h
@@ -97,8 +97,15 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
     tileInfo.lda = L1_LDA;
     tileInfo.ldb = L1_LDB;
     tileInfo.ldc = L1_LDC;
-    tileInfo.ta = false;
-    tileInfo.tb = false;
+    tileInfo.ta = info.ta ^ impl.ta_tile;
+    tileInfo.tb = info.tb ^ impl.tb_tile;
+    tileInfo.tc = info.tc ^ impl.tc_tile; // TODO: implement transposed blocking
+
+    // create function ptr for dma loading
+    const snrt_dma_load_2d_tile_transpose_t load_tile_A = impl.ta_tile ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
+    const snrt_dma_load_2d_tile_transpose_t load_tile_B = impl.tb_tile ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
+    const snrt_dma_load_2d_tile_transpose_t load_tile_C = impl.tc_tile ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
+    const snrt_dma_load_2d_tile_transpose_t store_tile_C = impl.tc_tile ? &snrt_dma_store_2d_tile_transpose : &snrt_dma_store_2d_tile;
 
     if (impl.bench) snrt_mcycle();
 
@@ -130,7 +137,7 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
         if (IS_DM_CORE) {
             dump_ib(ib);
             dump_jb(jb);
-            snrt_dma_load_2d_tile(l1_C, (void*) C, ib, jb, L1_M, L1_N, ldc, FP64);
+            (*load_tile_C)(l1_C, (void*) C, ib, jb, L1_M, L1_N, ldc, FP64);
             if (ib_prev >= 0 /* && jb_prev >= 0 */) storeC = true;
         }
 
@@ -151,8 +158,7 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
                 dump_kb(kb);
                 if (loadA) {
                     if (c2cL1_A == NULL)
-                        snrt_dma_load_2d_tile(l1_A, (void*) A, ib, kb, L1_M, L1_K, lda,
-                                              FP64);
+                        (*load_tile_A)(l1_A, (void*) A, ib, kb, L1_M, L1_K, lda, FP64);
                     else {
                         FLOAT_T* const c2c_A = c2cL1_A[l1Id_A].A;
                         snrt_dma_start_1d(l1_A, c2c_A, L1_M * L1_K * FP64);
@@ -160,8 +166,7 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
                 }
                 if (loadB) {
                     if (c2cL1_B == NULL)
-                        snrt_dma_load_2d_tile(l1_B, (void*) B, kb, jb, L1_K, L1_N, ldb,
-                                              FP64);
+                        (*load_tile_B)(l1_B, (void*) B, kb, jb, L1_K, L1_N, ldb, FP64);
                     else {
                         FLOAT_T* const c2c_B = c2cL1_B[l1Id_B].B;
                         snrt_dma_start_1d(l1_B, c2c_B, L1_K * L1_N * FP64);
@@ -198,8 +203,7 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
             if (IS_DM_CORE) {
                 if (storeC) {
                     storeC = false;
-                    snrt_dma_store_2d_tile(C, l1[!l1Id_C].C, ib_prev,
-                                           jb_prev, L1_M, L1_N, ldc, FP64);
+                    (*store_tile_C)(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N, ldc, FP64);
                 }
             }
             kb_prev = kb;
@@ -220,8 +224,7 @@ void SNBLAS_GEMM_TILING(2dpipe, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo info,
 
         // store final tile
        // if (ib_prev >= 0 && jb_prev >= 0) {
-        snrt_dma_store_2d_tile(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N,
-                               ldc, FP64);
+        (*store_tile_C)(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N, ldc, FP64);
         snrt_dma_wait_all();
         // }
     } else {
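The same four-pointer selection now appears verbatim in all three tiling templates. Since the plain and transposing movers share one signature, it could be factored into small helpers; a possible sketch using the typedef from gemm_decls.h (snblas_bind_load/snblas_bind_store are hypothetical, not part of this patch):

    #include <stdbool.h>

    // Hypothetical helpers: pick the transposing mover when the tile flag is set.
    static inline snrt_dma_load_2d_tile_transpose_t snblas_bind_load(bool transpose) {
        return transpose ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
    }

    static inline snrt_dma_load_2d_tile_transpose_t snblas_bind_store(bool transpose) {
        return transpose ? &snrt_dma_store_2d_tile_transpose : &snrt_dma_store_2d_tile;
    }

    // Usage in a tiling template:
    //   const snrt_dma_load_2d_tile_transpose_t load_tile_A = snblas_bind_load(impl.ta_tile);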
diff --git a/sw/blas/gemm/src/gemm_tiling_baseline_tpl.h b/sw/blas/gemm/src/gemm_tiling_baseline_tpl.h
index 28916337c..bf0a5f429 100644
--- a/sw/blas/gemm/src/gemm_tiling_baseline_tpl.h
+++ b/sw/blas/gemm/src/gemm_tiling_baseline_tpl.h
@@ -68,9 +68,16 @@ void SNBLAS_GEMM_TILING(baseline, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo inf
     tileInfo.lda = L1_LDA;
     tileInfo.ldb = L1_LDB;
     tileInfo.ldc = L1_LDC;
-    tileInfo.ta = false;
-    tileInfo.tb = false;
-
+    tileInfo.ta = info.ta ^ impl.ta_tile;
+    tileInfo.tb = info.tb ^ impl.tb_tile;
+    tileInfo.tc = info.tc ^ impl.tc_tile; // TODO: implement transposed blocking
+
+    // create function ptr for dma loading
+    const snrt_dma_load_2d_tile_transpose_t load_tile_A = impl.ta_tile ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
+    const snrt_dma_load_2d_tile_transpose_t load_tile_B = impl.tb_tile ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
+    const snrt_dma_load_2d_tile_transpose_t load_tile_C = impl.tc_tile ? &snrt_dma_load_2d_tile_transpose : &snrt_dma_load_2d_tile;
+    const snrt_dma_load_2d_tile_transpose_t store_tile_C = impl.tc_tile ? &snrt_dma_store_2d_tile_transpose : &snrt_dma_store_2d_tile;
+
+    // TODO: place memory barrier before sync
     if (impl.bench) snrt_mcycle();
 
@@ -86,7 +93,7 @@ void SNBLAS_GEMM_TILING(baseline, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo inf
         if (IS_DM_CORE) {
             dump_ib(ib);
             dump_jb(jb);
-            snrt_dma_load_2d_tile(l1_C, (void*) C, ib, jb, L1_M, L1_N, ldc, FP64);
+            (*load_tile_C)(l1_C, (void*) C, ib, jb, L1_M, L1_N, ldc, FP64);
             if (ib_prev >= 0 && jb_prev >= 0) storeC = true;
         }
 
@@ -100,10 +107,8 @@ void SNBLAS_GEMM_TILING(baseline, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo inf
             if (IS_DM_CORE) {
                 dump_kb(kb);
-                snrt_dma_load_2d_tile(l1_A, (void*) A, ib, kb, L1_M, L1_K, lda,
-                                      FP64);
-                snrt_dma_load_2d_tile(l1_B, (void*) B, kb, jb, L1_K, L1_N, ldb,
-                                      FP64);
+                (*load_tile_A)(l1_A, (void*) A, ib, kb, L1_M, L1_K, lda, FP64);
+                (*load_tile_B)(l1_B, (void*) B, kb, jb, L1_K, L1_N, ldb, FP64);
                 snrt_dma_wait_all();
             } else {
                 GemmArgs tileArgs = {0};
@@ -123,7 +128,7 @@ void SNBLAS_GEMM_TILING(baseline, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo inf
             if (IS_DM_CORE) {
                 if (storeC) {
                     storeC = false;
-                    snrt_dma_store_2d_tile(C, l1[!l1Id_C].C, ib_prev,
+                    (*store_tile_C)(C, l1[!l1Id_C].C, ib_prev,
                                            jb_prev, L1_M, L1_N, ldc, FP64);
                 }
             }
@@ -141,7 +146,7 @@ void SNBLAS_GEMM_TILING(baseline, FLOAT_T, IS_DM_CORE) (const SnblasGemmInfo inf
 
         // store final tile
         // if (ib_prev >= 0 && jb_prev >= 0) {
-        snrt_dma_store_2d_tile(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N,
+        (*store_tile_C)(C, l1[!l1Id_C].C, ib_prev, jb_prev, L1_M, L1_N,
                                ldc, FP64);
         snrt_dma_wait_all();
         // }
diff --git a/sw/blas/gemm/src/main.c b/sw/blas/gemm/src/main.c
index 13ece7fbb..018726a86 100644
--- a/sw/blas/gemm/src/main.c
+++ b/sw/blas/gemm/src/main.c
@@ -11,6 +11,9 @@
 
 #include "snrt.h"
 
+#define BIST
+#include "data.h"
+
 #include "gemm.h"
 #include "dma_xfer_test.h"
 
@@ -40,6 +43,7 @@ int main() {
     gemmInfo.ldc = ldc;
     gemmInfo.ta = TA;
     gemmInfo.tb = TB;
+    gemmInfo.tc = TC;
 
     SNBLAS_GEMM_ARGS(DTYPE) gemmArgs = {0};
     gemmArgs.A = a;
@@ -54,13 +58,14 @@ int main() {
     gemmImpl.tc_tile = TC_TILE;
 
     for (volatile int i = iters; i > 0; --i) {
-        dump_bench_iter(-i);
         // if (i == 1) snrt_mcycle(); // start
         gemmImpl.bench = i == 1;
-        SNBLAS_GEMM(METHOD, DTYPE)(gemmInfo, gemmArgs, gemmImpl);
+        SNBLAS_GEMM(USE_METHOD, DTYPE)(gemmInfo, gemmArgs, gemmImpl);
         // dma_xfer_test(c, M*N, i == 1);
         if (i == 1) snrt_mcycle(); // end
+        if (snrt_global_core_idx() == 0)
+            dump_bench_iter(-i);
         snrt_fpu_fence();
         snrt_global_barrier();
     }
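Defining BIST before including data.h exposes the golden `result` array that datagen wraps in format_ifdef_wrapper. A hedged sketch of the self-test this enables after the benchmark loop (the tolerance and the fp64 casts are illustrative; the actual element type follows DTYPE):

    // After the benchmark loop, core 0 compares the computed C tile-by-tile
    // against the golden result emitted by datagen.py.
    uint32_t n_err = 0;
    if (snrt_global_core_idx() == 0) {
        for (uint32_t idx = 0; idx < M * N; idx++) {
            double d = (double)c[idx] - (double)result[idx];
            if (d < 0.0) d = -d;
            if (d > 1e-9 * (double)K) n_err++;  // tolerance scales with reduction length
        }
    }
    return n_err;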
diff --git a/sw/snRuntime/src/dma.c b/sw/snRuntime/src/dma.c
index f646fe10a..cbd31a4a4 100644
--- a/sw/snRuntime/src/dma.c
+++ b/sw/snRuntime/src/dma.c
@@ -20,3 +20,23 @@ extern snrt_dma_txid_t snrt_dma_start_2d(void *dst, const void *src,
 
 extern void snrt_dma_wait(snrt_dma_txid_t tid);
 
 extern void snrt_dma_wait_all();
+
+extern snrt_dma_txid_t snrt_dma_load_2d_tile(
+    void *dst, void *src, size_t tile_x1_idx, size_t tile_x0_idx,
+    size_t tile_x1_size, size_t tile_x0_size, size_t full_x0_size,
+    uint32_t prec);
+
+extern snrt_dma_txid_t snrt_dma_load_2d_tile_transpose(
+    void *dst, void *src, size_t tile_x1_idx, size_t tile_x0_idx,
+    size_t tile_x1_size, size_t tile_x0_size, size_t full_x0_size,
+    uint32_t prec);
+
+extern snrt_dma_txid_t snrt_dma_store_2d_tile(
+    void *dst, void *src, size_t tile_x1_idx, size_t tile_x0_idx,
+    size_t tile_x1_size, size_t tile_x0_size, size_t full_x0_size,
+    uint32_t prec);
+
+extern snrt_dma_txid_t snrt_dma_store_2d_tile_transpose(
+    void *dst, void *src, size_t tile_x1_idx, size_t tile_x0_idx,
+    size_t tile_x1_size, size_t tile_x0_size, size_t full_x0_size,
+    uint32_t prec);
\ No newline at end of file
diff --git a/sw/snRuntime/src/dma.h b/sw/snRuntime/src/dma.h
index 578cba15d..182f6b345 100644
--- a/sw/snRuntime/src/dma.h
+++ b/sw/snRuntime/src/dma.h
@@ -233,8 +233,8 @@ inline snrt_dma_txid_t snrt_dma_load_2d_tile(
     );
 }
 
-/// Transfer a tile and transpose it
-inline void snrt_dma_load_2d_tile_transpose(
+/// Load a tile and transpose it
+inline snrt_dma_txid_t snrt_dma_load_2d_tile_transpose(
     void *dst, void *src, size_t tile_x1_idx, size_t tile_x0_idx,
     size_t tile_x1_size, size_t tile_x0_size, size_t full_x0_size,
     uint32_t prec) {
@@ -243,18 +243,22 @@ inline void snrt_dma_load_2d_tile_transpose(
     src_offset += tile_x0_idx * tile_x0_size;
     src_offset += tile_x1_idx * tile_x1_size * full_x0_size;
     src_offset *= prec;
+
+    snrt_dma_txid_t prev_txid = -1;
     // Initiate transfer
     for (uint32_t i = 0; i < tile_x0_size; i++) {
-        snrt_dma_start_2d(dst + i * tile_x1_size * prec,  // dst
-                          src + src_offset + i * prec,    // src
-                          prec,                           // size
-                          prec,                           // dst_stride
-                          full_x0_size * prec,            // src_stride
-                          tile_x1_size                    // repeat
+        prev_txid = snrt_dma_start_2d(dst + i * tile_x1_size * prec,  // dst
+                                      src + src_offset + i * prec,    // src
+                                      prec,                           // size
+                                      prec,                           // dst_stride
+                                      full_x0_size * prec,            // src_stride
+                                      tile_x1_size                    // repeat
         );
     }
+
+    return prev_txid;
 }
 
 /// Store a 2D-tile of shape (tile_x1_size, tile_x0_size) to the 2D array
@@ -280,6 +284,34 @@ inline snrt_dma_txid_t snrt_dma_store_2d_tile(
     );
 }
 
+/// Store a tile and transpose it
+inline snrt_dma_txid_t snrt_dma_store_2d_tile_transpose(
+    void *dst, void *src, size_t tile_x1_idx, size_t tile_x0_idx,
+    size_t tile_x1_size, size_t tile_x0_size, size_t full_x0_size,
+    uint32_t prec) {
+    size_t dst_offset = 0;
+    // Advance dst array in x0 and x1 dimensions, and convert to byte offset
+    dst_offset += tile_x0_idx * tile_x0_size;
+    dst_offset += tile_x1_idx * tile_x1_size * full_x0_size;
+    dst_offset *= prec;
+
+    snrt_dma_txid_t prev_txid = -1;
+    // Initiate transfer: row i of the transposed L1 tile becomes
+    // column i of the destination tile (mirror of the load above)
+    for (uint32_t i = 0; i < tile_x0_size; i++) {
+        prev_txid = snrt_dma_start_2d(dst + dst_offset + i * prec,    // dst
+                                      src + i * tile_x1_size * prec,  // src
+                                      prec,                           // size
+                                      full_x0_size * prec,            // dst_stride
+                                      prec,                           // src_stride
+                                      tile_x1_size                    // repeat
+        );
+    }
+
+    return prev_txid;
+}
+
 //================================================================================
 // Reduction functions
 //================================================================================
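One practical note on the transposing movers: each call issues tile_x0_size separate 2D transfers whose element size is only `prec` bytes, so an 8x8 fp64 tile moves as 64 single-element bursts instead of 8 contiguous 64-byte rows; the transpose trades DMA efficiency for layout. A usage sketch (tile geometry from the params.hjson above; FP64 is the runtime's element-size constant):

    // Load tile (ib, kb) of A, transposed into L1. Waiting on the returned
    // txid covers the whole tile only if transfers complete in order, which
    // is the assumption behind returning the last txid from the loop.
    snrt_dma_txid_t id = snrt_dma_load_2d_tile_transpose(
        l1_A, (void*) A, ib, kb, L1_M, L1_K, lda, FP64);
    snrt_dma_wait(id);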