From 2f532e97b14a3228cf455ff3c7075b3b3cb93cfd Mon Sep 17 00:00:00 2001
From: Tim Fischer
Date: Thu, 15 Feb 2024 15:31:22 +0100
Subject: [PATCH] sw: Add GEMM baselines without SSR & frep extensions (#83)

* gemm: Add fp32 SIMD baseline kernel without SSR and frep extensions

* gemm: Remove clean-up of columns

This is not needed for non-unrolled kernels

* gemm: Implement fp16 baseline

* gemm: Implement fp8 baseline

* gemm: Try to verify fp8

* lint: Python sources

* gemm: Parametrize type of kernels as baseline/optimized + align gemm signatures

* gemm: Rename naive baseline kernels to `naive`

* gemm: Clean up comments

* gemm: Format code

* gemm: Fix function arguments

* gemm: Formatting

* gemm: Format back with `clang-format-10`

* gemm: Remove deprecated verify function

* gemm: Fix error threshold

* verify: Remove unused constant numpy types
---
 sw/blas/gemm/src/gemm.h | 253 ++++++++++++++++++++++++++++++----------
 sw/blas/gemm/verify.py  |  18 +--
 2 files changed, 204 insertions(+), 67 deletions(-)

diff --git a/sw/blas/gemm/src/gemm.h b/sw/blas/gemm/src/gemm.h
index 4542fe59f..bab6c339c 100644
--- a/sw/blas/gemm/src/gemm.h
+++ b/sw/blas/gemm/src/gemm.h
@@ -35,9 +35,9 @@ static inline double multiply_opt(double multiplicand, double multiplier) {
     return 0;
 }
 
-void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, float* A,
-                        uint32_t ldA, uint32_t ta, float* B, uint32_t ldB,
-                        uint32_t tb, float* C, uint32_t ldC, float BETA) {
+void gemm_fp32_naive(uint32_t M, uint32_t N, uint32_t K, float* A, uint32_t ldA,
+                     uint32_t ta, float* B, uint32_t ldB, uint32_t tb, float* C,
+                     uint32_t ldC, float BETA) {
     if (!ta && !tb) {
         for (uint32_t m = 0; m < M; m++) {
             for (uint32_t n = 0; n < N; n++) {
@@ -81,9 +81,9 @@ void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, float* A,
     }
 }
 
-void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
-                        uint32_t ldA, uint32_t ta, double* B, uint32_t ldB,
-                        uint32_t tb, double* C, uint32_t ldC, double BETA) {
+void gemm_fp64_naive(uint32_t M, uint32_t N, uint32_t K, double* A,
+                     uint32_t ldA, uint32_t ta, double* B, uint32_t ldB,
+                     uint32_t tb, double* C, uint32_t ldC, double BETA) {
     if (!ta && !tb) {
         for (uint32_t m = 0; m < M; m++) {
             for (uint32_t n = 0; n < N; n++) {
@@ -127,26 +127,9 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
     }
 }
 
-/* params:
- * M: number of rows of A and C
- * N: number of columns of B and C
- * K: number of columns of A and rows of B
- * A: pointer to matrix A
- * ldA: row stride of A
- * ta: transpose A
- * B: pointer to matrix B
- * ldB: row stride of B
- * tb: transpose B
- * C: pointer to matrix C
- * ldC: row stride of C
- * BETA: scalar beta
- * A is MxK, B is KxN, C is MxN
- */
-void gemm_fp32_baseline_unrolled(uint32_t M, uint32_t N, uint32_t K, float* A,
-                                 uint32_t ldA, uint32_t ta, float* B,
-                                 uint32_t ldB, uint32_t tb, float* C,
-                                 uint32_t ldC, float BETA) {
-    // float c0, c1, c2, c3 = 0;
+void gemm_fp32_naive_unrolled(uint32_t M, uint32_t N, uint32_t K, float* A,
+                              uint32_t ldA, uint32_t ta, float* B, uint32_t ldB,
+                              uint32_t tb, float* C, uint32_t ldC, float BETA) {
     float c0 = 0.0f;
     float c1 = 0.0f;
     float c2 = 0.0f;
@@ -154,8 +137,6 @@ void gemm_fp32_baseline_unrolled(uint32_t M, uint32_t N, uint32_t K, float* A,
     if (!ta && !tb) {
         for (uint32_t m = 0; m < M; m++) {
             for (uint32_t n = 0; n < N; n++) {
-                // register float c0 = BETA * C[m * ldC + n];
-                // c0, c1, c2, c3 = 0;
                 if (BETA == 0.0f) {
                     c0 = 0.0f;
                 } else {
@@ -176,7 +157,6 @@ void gemm_fp32_baseline_unrolled(uint32_t M, uint32_t N, uint32_t K, float* A,
     } else if (ta && !tb) {
         for (uint32_t m = 0; m < M; m++) {
             for (uint32_t n = 0; n < N; n++) {
-                // register float c0 = BETA * C[m * ldC + n];
                 if (BETA == 0.0f) {
                     c0 = 0.0f;
                 } else {
@@ -197,7 +177,6 @@ void gemm_fp32_baseline_unrolled(uint32_t M, uint32_t N, uint32_t K, float* A,
     } else if (!ta && tb) {
         for (uint32_t m = 0; m < M; m++) {
             for (uint32_t n = 0; n < N; n++) {
-                // register float c0 = BETA * C[m * ldC + n];
                 if (BETA == 0.0f) {
                     c0 = 0.0f;
                 } else {
@@ -207,13 +186,11 @@ void gemm_fp32_baseline_unrolled(uint32_t M, uint32_t N, uint32_t K, float* A,
                 c2 = 0.0f;
                 c3 = 0.0f;
                 for (uint32_t k = 0; k < K; k += 4) {
-                    // c0 += A[k + m * ldA] * B[k + n * ldB];
                     c0 += A[(k + 0) + m * ldA] * B[(k + 0) + n * ldB];
                     c1 += A[(k + 1) + m * ldA] * B[(k + 1) + n * ldB];
                     c2 += A[(k + 2) + m * ldA] * B[(k + 2) + n * ldB];
                     c3 += A[(k + 3) + m * ldA] * B[(k + 3) + n * ldB];
                 }
-                // C[m * ldC + n] = c0;
                 C[m * ldC + n] = c0 + c1 + c2 + c3;
             }
         }
@@ -238,18 +215,6 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
     // for maximum utilization
     const uint32_t unroll = 8;
 
-    // A is of size MxK, B is of size KxN, C is of size MxN
-    // for (uint32_t m = 0; m < M; m++) {
-    //     for (uint32_t n = 0; n < N / unroll; n++) {
-    //         double c0 = BETA * C[m * ldC + n];
-    //         for (uint32_t k = 0; k < K; k++) {
-    //             for (uint32_t j = 0; j < unroll; j++) {
-    //                 c0 += A[k + m * ldA] * B[k + (n + j) * ldB];
-    //             }
-    //         }
-    //         C[m * ldC + n] = c0;
-    //     }
-    // }
     // SSR strides and bounds only have to be configured
     // once in the beginning
     if (setup_SSR) {
@@ -265,7 +230,6 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
         const uint32_t ssr0_b[4] = {unroll, K, N / unroll, M};
         const uint32_t ssr0_i[4] = {0, 8, 0, 8 * ldA};
 
-        // A[k + unroll * m * ldA]
         snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3],
                          ssr0_i[1], ssr0_i[2], ssr0_i[3]);
         snrt_ssr_repeat(SNRT_SSR_DM0, unroll);
@@ -283,7 +247,6 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
         const uint32_t ssr1_b[4] = {unroll, K, N / unroll, M};
         const uint32_t ssr1_i[4] = {8, 8 * ldB, 8 * unroll, 0};
 
-        // B[k + unroll * n * ldB]
         snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2],
                          ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2],
                          ssr1_i[3]);
@@ -370,6 +333,63 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
     snrt_ssr_disable();
 }
 
+void gemm_fp32_baseline(const uint32_t M, const uint32_t N, const uint32_t K,
+                        float* A, const uint32_t ldA, float* B,
+                        const uint32_t ldB, float* C, const uint32_t ldC,
+                        const uint32_t BETA) {
+    for (uint32_t m = 0; m < M; m++) {
+        uint32_t n = 0;
+        for (; n < N; n++) {
+            volatile register v2f32 *a_ptr, *b_ptr;
+            register v2f32 a, b;
+            volatile float* c_ptr;
+            const register float zero = 0.0;
+            double c = 0.0;
+            v2f32 reduce_reg;
+
+            a_ptr = (v2f32*)(&A[m * ldA]);
+            b_ptr = (v2f32*)(&B[n * ldB]);
+            c_ptr = &C[m * ldC + n];
+            // Don't accumulate in first iteration
+            asm volatile(
+                "beqz %[BETA], 1f \n"
+                // Load intermediate results
+                "flw ft2, 0(%[C]) \n"
+                "vfcpka.s.s ft2, ft2, %[zero]\n"
+                // or initialize with zero
+                "j 2f \n"
+                "1: \n"
+                "vfcpka.s.s ft2, %[zero], %[zero]\n"
+                // Don't accumulate in first iteration
+                "2: \n"
+                "fld ft0, 0(%[a_ptr]) \n"
+                "fld ft1, 0(%[b_ptr]) \n"
+                "add %[a_ptr], %[a_ptr], 8 \n"
+                "add %[b_ptr], %[b_ptr], 8 \n"
+                "vfmul.s ft3, ft0, ft1 \n"
+                // loop over the MACs
+                "li t0, 2 \n"
+                "3: \n"
+                "fld ft0, 0(%[a_ptr]) \n"
+                "fld ft1, 0(%[b_ptr]) \n"
+                "vfmac.s ft3, ft0, ft1 \n"
+                "add %[a_ptr], %[a_ptr], 8 \n"
+                "add %[b_ptr], %[b_ptr], 8 \n"
+                "addi t0, t0, 2 \n"
+                "blt t0, %[K], 3b \n"
+                // Sum reduce vector
+                "vfsum.s ft2, ft3 \n"
+                // Store results
+                "fsw ft2, 0(%[C]) \n"
+                : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr)
+                : [ c ] "f"(c), [ reduce_reg ] "f"(reduce_reg),
+                  [ C ] "r"(c_ptr), [ BETA ] "r"(BETA), [ K ] "r"(K),
+                  [ zero ] "f"(zero)
+                : "ft0", "ft1", "ft2", "ft3", "ft4", "t0");
+        }
+    }
+}
+
 void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
                    float* A, const uint32_t ldA, float* B, const uint32_t ldB,
                    float* C, const uint32_t ldC, const uint32_t* BETA,
@@ -521,6 +541,61 @@ void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
     snrt_ssr_disable();
 }
 
+void gemm_fp16_baseline(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
+                        uint32_t ldA, __fp16* B, uint32_t ldB, __fp16* C,
+                        uint32_t ldC, const uint32_t* BETA,
+                        uint32_t setup_SSR) {
+    for (uint32_t m = 0; m < M; m++) {
+        uint32_t n = 0;
+        for (; n < N; n++) {
+            volatile register v4f16 *a_ptr, *b_ptr;
+            register v4f16 a, b;
+            volatile __fp16* c_ptr;
+            const register float zero = 0.0;
+            double c = 0.0;
+            v4f16 reduce_reg;
+
+            a_ptr = (v4f16*)(&A[m * ldA]);
+            b_ptr = (v4f16*)(&B[n * ldB]);
+            c_ptr = &C[m * ldC + n];
+            // Don't accumulate in first iteration
+            asm volatile(
+                "lw t0, 0(%[BETA]) \n"
+                "beqz t0, 1f \n"
+                // Load intermediate results
+                "flh ft2, 0(%[C]) \n"
+                "vfcvt.s.h ft2, ft2\n"
+                "vfcpka.s.s ft2, ft2, %[zero]\n"
+                // or initialize with zero
+                "j 2f \n"
+                "1: \n"
+                "vfcpka.s.s ft2, %[zero], %[zero]\n"
+                "2: \n"
+                // loop over the MACs
+                "li t0, 0 \n"
+                "3: \n"
+                "fld ft0, 0(%[a_ptr]) \n"
+                "fld ft1, 0(%[b_ptr]) \n"
+                "add %[a_ptr], %[a_ptr], 8 \n"
+                "add %[b_ptr], %[b_ptr], 8 \n"
+                "vfdotpex.s.h ft2, ft0, ft1 \n"
+                "addi t0, t0, 4 \n"
+                "blt t0, %[K], 3b \n"
+                // Sum reduce vector
+                "vfcpka.s.s ft3, %[zero], %[zero]\n"
+                "vfsum.s ft3, ft2 \n"
+                "vfcvt.h.s ft3, ft3\n"
+                // Store results
+                "fsh ft3, 0(%[C]) \n"
+                : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr)
+                : [ c ] "f"(c), [ reduce_reg ] "f"(reduce_reg),
+                  [ C ] "r"(c_ptr), [ BETA ] "r"(BETA), [ K ] "r"(K),
+                  [ zero ] "f"(zero)
+                : "ft0", "ft1", "ft2", "ft3", "ft4", "t0");
+        }
+    }
+}
+
 void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,
                    __fp16* B, uint32_t ldB, __fp16* C, uint32_t ldC,
                    const uint32_t* BETA, uint32_t setup_SSR) {
@@ -869,6 +944,61 @@ void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
     snrt_ssr_disable();
 }
 
+void gemm_fp8_baseline(uint32_t M, uint32_t N, uint32_t K, char* A,
+                       uint32_t ldA, char* B, uint32_t ldB, char* C,
+                       uint32_t ldC, const uint32_t* BETA, uint32_t setup_SSR) {
+    for (uint32_t m = 0; m < M; m++) {
+        uint32_t n = 0;
+        for (; n < N; n++) {
+            volatile register v8f8 *a_ptr, *b_ptr;
+            register v8f8 a, b;
+            volatile char* c_ptr;
+            const register float zero = 0.0;
+            double c = 0.0;
+            v8f8 reduce_reg;
+
+            a_ptr = (v8f8*)(&A[m * ldA]);
+            b_ptr = (v8f8*)(&B[n * ldB]);
+            c_ptr = &C[m * ldC + n];
+            asm volatile(
+                "lw t0, 0(%[BETA]) \n"
+                "beqz t0, 1f \n"
+                // Load intermediate results
+                "flb ft2, 0(%[C]) \n"
+                "vfcvt.s.b ft2, ft2\n"
+                "vfcpka.h.s ft2, ft2, %[zero]\n"
+                // or initialize with zero
+                "j 2f \n"
+                "1: \n"
+                "vfcpka.s.s ft2, %[zero], %[zero]\n"
+                "2: \n"
+                // loop over the MACs
+                "li t0, 0 \n"
+                "3: \n"
+                "fld ft0, 0(%[a_ptr]) \n"
+                "fld ft1, 0(%[b_ptr]) \n"
+                "add %[a_ptr], %[a_ptr], 8 \n"
+                "add %[b_ptr], %[b_ptr], 8 \n"
+                "vfdotpex.h.b ft2, ft0, ft1 \n"
+                "addi t0, t0, 8 \n"
+                "blt t0, %[K], 3b \n"
+                // Sum reduce vector
+                "vfcpka.s.s ft3, %[zero], %[zero]\n"
+                "vfsumex.s.h ft3, ft2 \n"
+                "vfcpka.s.s ft2, %[zero], %[zero]\n"
+                "vfsum.s ft2, ft3 \n"
+                "vfcvt.b.s ft2, ft2\n"
+                // Store results
+                "fsb ft2, 0(%[C]) \n"
+                : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr)
+                : [ c ] "f"(c), [ reduce_reg ] "f"(reduce_reg),
+                  [ C ] "r"(c_ptr), [ BETA ] "r"(BETA), [ K ] "r"(K),
+                  [ zero ] "f"(zero)
+                : "ft0", "ft1", "ft2", "ft3", "ft4", "t0");
+        }
+    }
+}
+
 void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
                      char* B, uint32_t ldB, char* C, uint32_t ldC,
                      const uint32_t* BETA, uint32_t setup_SSR) {
@@ -1088,11 +1218,11 @@ void sc_st_gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
         switch (prec) {
             case FP64:
-                if (baseline == 1) {
-                    gemm_fp64_baseline(frac_m, n, k, (double*)a + offsetA,
-                                       lda_strided, transa, (double*)b, ldb,
-                                       transb, (double*)c + offsetC,
-                                       ldc_strided, (double)beta);
+                if (baseline) {
+                    gemm_fp64_naive(frac_m, n, k, (double*)a + offsetA,
+                                    lda_strided, transa, (double*)b, ldb,
+                                    transb, (double*)c + offsetC, ldc_strided,
+                                    (double)beta);
                 } else {
                     gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA,
                                   lda_strided, transa, (double*)b, ldb, transb,
@@ -1101,11 +1231,10 @@ void sc_st_gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
                 }
                 break;
             case FP32:
-                if (baseline == 1) {
+                if (baseline) {
                     gemm_fp32_baseline(frac_m, n, k, (float*)a + offsetA,
-                                       lda_strided, transa, (float*)b, ldb,
-                                       transb, (float*)c + offsetC, ldc_strided,
-                                       (float)beta);
+                                       lda_strided, (float*)b, ldb,
+                                       (float*)c + offsetC, ldc_strided, beta);
                 } else {
                     gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA,
                                   lda_strided, (float*)b, ldb,
@@ -1128,9 +1257,15 @@ void sc_st_gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
                 }
                 break;
             case FP8:
-                gemm_fp8_ex_opt(frac_m, n, k, (char*)a + offsetA, lda, (char*)b,
-                                ldb, (char*)c + offsetC, ldc_strided, &beta,
-                                setup_ssr);
+                if (baseline) {
+                    gemm_fp8_baseline(frac_m, n, k, (char*)a + offsetA, lda,
+                                      (char*)b, ldb, (char*)c + offsetC,
+                                      ldc_strided, &beta, setup_ssr);
+                } else {
+                    gemm_fp8_ex_opt(frac_m, n, k, (char*)a + offsetA, lda,
+                                    (char*)b, ldb, (char*)c + offsetC,
+                                    ldc_strided, &beta, setup_ssr);
+                }
                 break;
         }
     }
@@ -1292,4 +1427,4 @@ int gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
     }
 
     return 0;
-}
\ No newline at end of file
+}
diff --git a/sw/blas/gemm/verify.py b/sw/blas/gemm/verify.py
index 14899d1a6..365b141c1 100755
--- a/sw/blas/gemm/verify.py
+++ b/sw/blas/gemm/verify.py
@@ -10,23 +10,25 @@
 import numpy as np
 from data.datagen import golden_model
 
-sys.path.append(str(Path(__file__).parent / '../../../util/sim/'))
+sys.path.append(str(Path(__file__).parent / "../../../util/sim/"))
 import verification  # noqa: E402
 from elf import Elf  # noqa: E402
 from data_utils import from_buffer, ctype_from_precision_t  # noqa: E402
 
-ERR_THRESHOLD = 0.001
+ERR_THRESHOLD = {8: 1e-6, 4: 1e-6, 2: 1e-2, 1: 1e-1}
 
 
 def main():
     # Run simulation and get outputs
     args = verification.parse_args()
-    raw_results = verification.simulate(sim_bin=args.sim_bin,
-                                        snitch_bin=args.snitch_bin,
-                                        symbols_bin=args.symbols_bin,
-                                        log=args.log,
-                                        output_uids=['c'])
+    raw_results = verification.simulate(
+        sim_bin=args.sim_bin,
+        snitch_bin=args.snitch_bin,
+        symbols_bin=args.symbols_bin,
+        log=args.log,
+        output_uids=["c"],
+    )
 
     # Extract input operands from ELF file
     if args.symbols_bin:
@@ -55,7 +57,7 @@ def main():
     c_golden = golden_model(1, a, b, beta, c).flatten()
 
     absolute_err = np.absolute(c_golden - c_actual)
-    fail = np.any(absolute_err > ERR_THRESHOLD)
+    fail = np.any(absolute_err > ERR_THRESHOLD[prec])
     if (fail):
         print('Simulation results are incorrect.')
         verification.dump_results_to_csv([c_golden, c_actual, absolute_err],
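
Reviewer note: as a sanity reference, the standalone sketch below mirrors the row-major, non-transposed semantics that the naive kernels in this patch implement (C = A*B + BETA*C with row strides ldA/ldB/ldC), including the BETA == 0 special case that skips reading C. The `gemm_ref` helper and the `main` driver are illustrative names only and are not part of the patch; this is a plain-C model of the kernel contract, not the Snitch implementation.

// Host-side reference sketch for the naive GEMM contract (hypothetical names).
#include <stdint.h>
#include <stdio.h>

static void gemm_ref(uint32_t M, uint32_t N, uint32_t K, const float* A,
                     uint32_t ldA, const float* B, uint32_t ldB, float* C,
                     uint32_t ldC, float BETA) {
    for (uint32_t m = 0; m < M; m++) {
        for (uint32_t n = 0; n < N; n++) {
            // BETA == 0 skips the read of C, as in the kernels above
            float c0 = (BETA == 0.0f) ? 0.0f : BETA * C[m * ldC + n];
            for (uint32_t k = 0; k < K; k++) {
                c0 += A[m * ldA + k] * B[k * ldB + n];
            }
            C[m * ldC + n] = c0;
        }
    }
}

int main(void) {
    // 2x2 example: C = A * B with BETA = 0
    float A[] = {1.0f, 2.0f, 3.0f, 4.0f};  // [[1, 2], [3, 4]]
    float B[] = {5.0f, 6.0f, 7.0f, 8.0f};  // [[5, 6], [7, 8]]
    float C[] = {0.0f, 0.0f, 0.0f, 0.0f};
    gemm_ref(2, 2, 2, A, 2, B, 2, C, 2, 0.0f);
    // Expected result: [[19, 22], [43, 50]]
    printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);
    return 0;
}

Outputs produced this way can then be compared elementwise the same way verify.py does, where ERR_THRESHOLD is presumably keyed by element size in bytes (8/4/2/1 for FP64/FP32/FP16/FP8), reflecting the looser accuracy expected of the narrower formats.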