diff --git a/sw/blas/gemm/data/datagen.py b/sw/blas/gemm/data/datagen.py
index b33eb4afcd..3848653f1b 100755
--- a/sw/blas/gemm/data/datagen.py
+++ b/sw/blas/gemm/data/datagen.py
@@ -40,8 +40,8 @@
 }
 
 
-def golden_model(a, b, alpha, c):
-    return np.matmul(a, b) + alpha * c
+def golden_model(alpha, a, b, beta, c):
+    return alpha * np.matmul(a, b) + beta * c
 
 
 def emit_header(**kwargs):
@@ -73,7 +73,7 @@ def emit_header(**kwargs):
             * (1.0 + mantissa_b.astype(np.double) / (2**2))
         _c = ((-1.0)**sign_c.astype(np.double))*(2.0**(exponent_c.astype(np.double)-15.0)) \
             * (1.0 + mantissa_c.astype(np.double) / (2**2))
-        result = np.matmul(_a, _b) + kwargs['alpha'] * _c
+        result = golden_model(1, _a, _b, kwargs['beta'], _c)
         a = sign_a << 7 | exponent_a << FP8_FORMATS['fp8']['mant'] | mantissa_a
         b = sign_b << 7 | exponent_b << FP8_FORMATS['fp8']['mant'] | mantissa_b
         c = sign_c << 7 | exponent_c << FP8_FORMATS['fp8']['mant'] | mantissa_c
@@ -81,7 +81,7 @@ def emit_header(**kwargs):
         a = np.random.rand(kwargs['M'], kwargs['K']).astype(dtype)
         b = np.random.rand(kwargs['K'], kwargs['N']).astype(dtype)
         c = np.random.rand(kwargs['M'], kwargs['N']).astype(dtype)
-        result = golden_model(a, b, kwargs['alpha'], c)
+        result = golden_model(1, a, b, kwargs['beta'], c)
 
     # Store matrices in transposed form if requested
     a = a.T if kwargs['ta'] else a
@@ -93,7 +93,7 @@ def emit_header(**kwargs):
     data_str += [format_scalar_definition('uint32_t', 'K', kwargs['K'])]
     data_str += [format_scalar_definition('uint32_t', 'TA', int(kwargs['ta']))]
     data_str += [format_scalar_definition('uint32_t', 'TB', int(kwargs['tb']))]
-    data_str += [format_scalar_definition('uint32_t', 'ALPHA', kwargs['alpha'])]
+    data_str += [format_scalar_definition('uint32_t', 'BETA', kwargs['beta'])]
     data_str += [format_scalar_definition('uint32_t', 'dtype_size', kwargs['prec']//8)]
     data_str += [format_scalar_definition('uint32_t', 'expand', kwargs['expand'])]
     data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'a', a.flatten())]
diff --git a/sw/blas/gemm/data/params.hjson b/sw/blas/gemm/data/params.hjson
index 23a4100cf8..7b3ab59b90 100644
--- a/sw/blas/gemm/data/params.hjson
+++ b/sw/blas/gemm/data/params.hjson
@@ -8,7 +8,7 @@
     M: 192,
     N: 16,
     K: 16,
-    alpha: 0,
+    beta: 0,
     ta: false,
     tb: true, // must be true for SIMD
     prec: 64,
diff --git a/sw/blas/gemm/src/gemm.h b/sw/blas/gemm/src/gemm.h
index ab0f17285e..e46a328a88 100644
--- a/sw/blas/gemm/src/gemm.h
+++ b/sw/blas/gemm/src/gemm.h
@@ -23,11 +23,11 @@ typedef char v8f8 __attribute__((vector_size(8)));
 
 void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
                         uint32_t ldA, uint32_t ta, double* B, uint32_t ldB,
-                        uint32_t tb, double* C, uint32_t ldC, double ALPHA) {
+                        uint32_t tb, double* C, uint32_t ldC, double BETA) {
     if (!ta && !tb) {
         for (uint32_t m = 0; m < M; m++) {
             for (uint32_t n = 0; n < N; n++) {
-                register double c0 = ALPHA * C[m * ldC + n];
+                register double c0 = BETA * C[m * ldC + n];
                 for (uint32_t k = 0; k < K; k++) {
                     c0 += A[k + m * ldA] * B[k * ldB + n];
                 }
@@ -37,7 +37,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
     } else if (ta && !tb) {
         for (uint32_t m = 0; m < M; m++) {
             for (uint32_t n = 0; n < N; n++) {
-                register double c0 = ALPHA * C[m * ldC + n];
+                register double c0 = BETA * C[m * ldC + n];
                 for (uint32_t k = 0; k < K; k++) {
                     c0 += A[k * M * ldA + m * ldA] * B[k * ldB + n];
                 }
@@ -47,7 +47,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
     } else if (!ta && tb) {
         for (uint32_t m = 0; m < M; m++) {
             for (uint32_t n = 0; n < N; n++) {
-                register double c0 = ALPHA * C[m * ldC + n];
+                register double c0 = BETA * C[m * ldC + n];
                 for (uint32_t k = 0; k < K; k++) {
                     c0 += A[k + m * ldA] * B[k + n * ldB];
                 }
@@ -57,7 +57,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
     } else {
         for (uint32_t m = 0; m < M; m++) {
             for (uint32_t n = 0; n < N; n++) {
-                register double c0 = ALPHA * C[m * ldC + n];
+                register double c0 = BETA * C[m * ldC + n];
                 for (uint32_t k = 0; k < K; k++) {
                     c0 += A[k * M * ldA + m * ldA] * B[k + n * ldB];
                 }
@@ -69,7 +69,7 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
 
 void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
                    uint32_t ta, double* B, uint32_t ldB, uint32_t tb, double* C,
-                   uint32_t ldC, const uint32_t* ALPHA, uint32_t setup_SSR) {
+                   uint32_t ldC, const uint32_t* BETA, uint32_t setup_SSR) {
     // Unrolling factor of most inner loop.
     // Should be at least as high as the FMA delay
     // for maximum utilization
@@ -124,7 +124,7 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
             double c[unroll];
 
             // Load intermediate result
-            if (*ALPHA) {
+            if (*BETA) {
                 c[0] = C[m * ldC + n + 0];
                 c[1] = C[m * ldC + n + 1];
                 c[2] = C[m * ldC + n + 2];
@@ -177,7 +177,7 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
         for (; n < N; n++) {
             double c;
 
-            if (*ALPHA) {
+            if (*BETA) {
                 c = C[m * ldC + n];
             } else {
                 c = 0.0;
@@ -196,7 +196,7 @@ void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, double* A, uint32_t ldA,
 
 void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
                    float* A, const uint32_t ldA, float* B, const uint32_t ldB,
-                   float* C, const uint32_t ldC, const uint32_t* ALPHA,
+                   float* C, const uint32_t ldC, const uint32_t* BETA,
                    const uint32_t setup_SSR) {
     // Unrolling factor of most inner loop.
     // Should be at least as high as the FMA delay
@@ -237,7 +237,7 @@ void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
             v2f32 c[unroll], reduce_reg[unroll];
 
             asm volatile(
-                "lw t0, 0(%[ALPHA]) \n"
+                "lw t0, 0(%[BETA]) \n"
                 "beqz t0, 1f \n"
                 // Load intermediate results
                 "flw %[reduce_reg0], 0(%[C]) \n"
@@ -315,7 +315,7 @@ void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
                   [ reduce_reg6 ] "+f"(reduce_reg[6]),
                   [ reduce_reg7 ] "+f"(reduce_reg[7])
                 : [ C ] "r"(_C), [ zero ] "f"(zero), [ n_frep ] "r"(n_frep - 1),
-                  [ ALPHA ] "r"(ALPHA)
+                  [ BETA ] "r"(BETA)
                 : "ft0", "ft1", "ft2");
 
             // Store results
@@ -332,7 +332,7 @@ void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
         snrt_ssr_disable();
 
         for (; n < N; n++) {
-            float c = (*ALPHA) ? C[m * ldC + n] : 0.0;
+            float c = (*BETA) ? C[m * ldC + n] : 0.0;
             for (uint32_t k = 0; k < K; k++) {
                 c += A[k + m * ldA] * B[k + n * ldB];
             }
@@ -347,7 +347,7 @@ void gemm_fp32_opt(const uint32_t M, const uint32_t N, const uint32_t K,
 
 void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,
                    __fp16* B, uint32_t ldB, __fp16* C, uint32_t ldC,
-                   const uint32_t* ALPHA, uint32_t setup_SSR) {
+                   const uint32_t* BETA, uint32_t setup_SSR) {
     // Unrolling factor of most inner loop.
     // Should be at least as high as the FMA delay
     // for maximum utilization
@@ -386,11 +386,11 @@ void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,
             const register float zero = 0.0;
             v4f16 c[unroll];
             v2f32 reduce_reg[unroll];
-            uint32_t alpha;
+            uint32_t beta;
 
             asm volatile(
-                "lw %[alpha], 0(%[ALPHA]) \n"
-                "beqz %[alpha], 1f \n"
+                "lw %[beta], 0(%[BETA]) \n"
+                "beqz %[beta], 1f \n"
                 // Load intermediate results
                 "flh %[reduce_reg0], 0(%[C]) \n"
                 "flh %[reduce_reg1], 2(%[C]) \n"
@@ -493,7 +493,7 @@ void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,
                 "vfcpkb.h.s %[c1], %[c6], %[c7] \n"
                 : [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]),
                   [ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]),
-                  [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ alpha ] "=r"(alpha),
+                  [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ beta ] "=r"(beta),
                   [ reduce_reg0 ] "+f"(reduce_reg[0]),
                   [ reduce_reg1 ] "+f"(reduce_reg[1]),
                   [ reduce_reg2 ] "+f"(reduce_reg[2]),
@@ -503,7 +503,7 @@ void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,
                   [ reduce_reg6 ] "+f"(reduce_reg[6]),
                   [ reduce_reg7 ] "+f"(reduce_reg[7])
                 : [ C ] "r"(_C), [ zero ] "f"(zero), [ n_frep ] "r"(n_frep),
-                  [ ALPHA ] "r"(ALPHA)
+                  [ BETA ] "r"(BETA)
                 : "ft0", "ft1", "ft2");
 
             // Store results back
@@ -516,7 +516,7 @@ void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,
         // snrt_ssr_disable();
 
         // for (; n < N; n++) {
-        //     __fp16 c = (*ALPHA) ? C[m * ldC + n] : 0.0;
+        //     __fp16 c = (*BETA) ? C[m * ldC + n] : 0.0;
         //     for (uint32_t k = 0; k < K; k++) {
         //         c += A[k + m * ldA] * B[k + n * ldB];
         //     }
@@ -531,7 +531,7 @@ void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A, uint32_t ldA,
 
 void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
                       uint32_t ldA, __fp16* B, uint32_t ldB, __fp16* C,
-                      uint32_t ldC, const uint32_t* ALPHA, uint32_t setup_SSR) {
+                      uint32_t ldC, const uint32_t* BETA, uint32_t setup_SSR) {
     // Unrolling factor of most inner loop.
     // Should be at least as high as the FMA delay
     // for maximum utilization
@@ -570,11 +570,11 @@ void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
             const register float zero = 0.0;
             v4f16 c[unroll];
             v2f32 reduce_reg[unroll];
-            uint32_t alpha;
+            uint32_t beta;
 
             asm volatile(
-                "lw %[alpha], 0(%[ALPHA]) \n"
-                "beqz %[alpha], 1f \n"
+                "lw %[beta], 0(%[BETA]) \n"
+                "beqz %[beta], 1f \n"
                 "flh %[reduce_reg0], 0(%[C]) \n"
                 "flh %[reduce_reg1], 2(%[C]) \n"
                 "flh %[reduce_reg2], 4(%[C]) \n"
@@ -657,7 +657,7 @@ void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
                 "vfcpkb.h.s %[c1], %[reduce_reg6], %[reduce_reg7] \n"
                 : [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]),
                   [ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]),
-                  [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ alpha ] "=r"(alpha),
+                  [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ beta ] "=r"(beta),
                   [ reduce_reg0 ] "+f"(reduce_reg[0]),
                   [ reduce_reg1 ] "+f"(reduce_reg[1]),
                   [ reduce_reg2 ] "+f"(reduce_reg[2]),
@@ -667,7 +667,7 @@ void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
                   [ reduce_reg6 ] "+f"(reduce_reg[6]),
                   [ reduce_reg7 ] "+f"(reduce_reg[7])
                 : [ C ] "r"(_C), [ zero ] "f"(zero), [ n_frep ] "r"(n_frep),
-                  [ ALPHA ] "r"(ALPHA)
+                  [ BETA ] "r"(BETA)
                 : "ft0", "ft1", "ft2");
 
             // Store results back
@@ -680,7 +680,7 @@ void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
         // snrt_ssr_disable();
 
         // for (; n < N; n++) {
-        //     __fp16 c = (*ALPHA) ? C[m * ldC + n] : 0.0;
+        //     __fp16 c = (*BETA) ? C[m * ldC + n] : 0.0;
         //     for (uint32_t k = 0; k < K; k++) {
         //         c += A[k + m * ldA] * B[k + n * ldB];
         //     }
@@ -695,9 +695,9 @@ void gemm_fp16_ex_opt(uint32_t M, uint32_t N, uint32_t K, __fp16* A,
 
 void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
                      char* B, uint32_t ldB, char* C, uint32_t ldC,
-                     const uint32_t* ALPHA, uint32_t setup_SSR) {
+                     const uint32_t* BETA, uint32_t setup_SSR) {
     // Accumulating currently not implemented
-    if (*ALPHA != 0) return;
+    if (*BETA != 0) return;
 
     // Unrolling factor of most inner loop.
     // Should be at least as high as the FMA delay
@@ -737,11 +737,11 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
             const register float zero = 0.0;
             v8f8 c[unroll];
             v4f16 reduce_reg[unroll];
-            uint32_t alpha;
+            uint32_t beta;
 
             asm volatile(
-                "lw %[alpha], 0(%[ALPHA]) \n"
-                "beqz %[alpha], 1f \n"
+                "lw %[beta], 0(%[BETA]) \n"
+                "beqz %[beta], 1f \n"
                 "flb %[reduce_reg0], 0(%[C]) \n"
                 "flb %[reduce_reg1], 1(%[C]) \n"
                 "flb %[reduce_reg2], 2(%[C]) \n"
@@ -848,7 +848,7 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
                 // "vfcpkd.b.s %[c0], %[reduce_reg6], %[reduce_reg7] \n"
                 : [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]),
                   [ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]),
-                  [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ alpha ] "=r"(alpha),
+                  [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), [ beta ] "=r"(beta),
                   [ reduce_reg0 ] "+f"(reduce_reg[0]),
                   [ reduce_reg1 ] "+f"(reduce_reg[1]),
                   [ reduce_reg2 ] "+f"(reduce_reg[2]),
@@ -857,7 +857,7 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
                   [ reduce_reg5 ] "+f"(reduce_reg[5]),
                   [ reduce_reg6 ] "+f"(reduce_reg[6]),
                   [ reduce_reg7 ] "+f"(reduce_reg[7])
-                : [ C ] "r"(_C), [ n_frep ] "r"(n_frep), [ ALPHA ] "r"(ALPHA),
+                : [ C ] "r"(_C), [ n_frep ] "r"(n_frep), [ BETA ] "r"(BETA),
                   [ zero ] "f"(zero)
                 : "ft0", "ft1", "ft2");
 
@@ -870,7 +870,7 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
         // snrt_ssr_disable();
 
        // for (; n < N; n++) {
-        //     char c = (*ALPHA) ? C[m * ldC + n] : 0.0;
+        //     char c = (*BETA) ? C[m * ldC + n] : 0.0;
        //     for (uint32_t k = 0; k < K; k++) {
        //         c += A[k + m * ldA] * B[k + n * ldB];
        //     }
@@ -886,12 +886,12 @@ void gemm_fp8_ex_opt(uint32_t M, uint32_t N, uint32_t K, char* A, uint32_t ldA,
 // BLAS compliant GEMM kernel, with some additional arguments at the beginning
 // to specify Snitch implementation details. Matrix sizes and pointers are for
 // the whole cluster computation
-// TODO: alpha (and beta) should be of floating-point type (same precision as
+// TODO: beta (and alpha) should be of floating-point type (same precision as
 // operands)
 void gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
           uint32_t transa, uint32_t transb, uint32_t m, uint32_t n, uint32_t k,
-          uint32_t alpha, void* a, uint32_t lda, void* b, uint32_t ldb,
-          double beta, void* c, uint32_t ldc) {
+          double alpha, void* a, uint32_t lda, void* b, uint32_t ldb,
+          uint32_t beta, void* c, uint32_t ldc) {
     const uint32_t compute_num = snrt_cluster_compute_core_num();
     const uint32_t compute_id = snrt_cluster_core_idx();
 
@@ -910,27 +910,27 @@ void gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
             case FP64:
                 gemm_fp64_opt(frac_m, n, k, (double*)a + offsetA, lda_strided,
                               transa, (double*)b, ldb, transb, (double*)c + offsetC,
-                              ldc_strided, &alpha, setup_ssr);
+                              ldc_strided, &beta, setup_ssr);
                 break;
             case FP32:
                 gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA, lda_strided,
                               (float*)b, ldb, (float*)c + offsetC, ldc_strided,
-                              &alpha, setup_ssr);
+                              &beta, setup_ssr);
                 break;
             case FP16:
                 if (expand) {
                     gemm_fp16_ex_opt(
                         frac_m, n, k, (__fp16*)a + offsetA, lda_strided, (__fp16*)b,
-                        ldb, (__fp16*)c + offsetC, ldc_strided, &alpha, setup_ssr);
+                        ldb, (__fp16*)c + offsetC, ldc_strided, &beta, setup_ssr);
                 } else {
                     gemm_fp16_opt(frac_m, n, k, (__fp16*)a + offsetA, lda_strided,
                                   (__fp16*)b, ldb, (__fp16*)c + offsetC,
-                                  ldc_strided, &alpha, setup_ssr);
+                                  ldc_strided, &beta, setup_ssr);
                 }
                 break;
            case FP8:
                gemm_fp8_ex_opt(frac_m, n, k, (char*)a + offsetA, lda, (char*)b,
-                                ldb, (char*)c + offsetC, ldc_strided, &alpha,
+                                ldb, (char*)c + offsetC, ldc_strided, &beta,
                                 setup_ssr);
                 break;
         }
diff --git a/sw/blas/gemm/src/main.c b/sw/blas/gemm/src/main.c
index 3da55c2ab0..33779abf73 100644
--- a/sw/blas/gemm/src/main.c
+++ b/sw/blas/gemm/src/main.c
@@ -61,8 +61,8 @@ int main() {
             ldb = K;
         }
 
-        gemm(dtype_size, expand, setup_ssr, TA, TB, frac_m, N, K, ALPHA,
-             local_a, lda, local_b, ldb, 1, local_c, ldc);
+        gemm(dtype_size, expand, setup_ssr, TA, TB, frac_m, N, K, 1, local_a,
+             lda, local_b, ldb, BETA, local_c, ldc);
 
         uint32_t end_cycle = snrt_mcycle();
     }
diff --git a/sw/blas/gemm/verify.py b/sw/blas/gemm/verify.py
index 3bae7f8015..03c6ab8131 100755
--- a/sw/blas/gemm/verify.py
+++ b/sw/blas/gemm/verify.py
@@ -37,7 +37,7 @@ def main():
    a = np.array(bytes_to_doubles(elf.get_symbol_contents('a')))
    b = np.array(bytes_to_doubles(elf.get_symbol_contents('b')))
    c = np.array(bytes_to_doubles(elf.get_symbol_contents('c')))
-    alpha = bytes_to_uint32s(elf.get_symbol_contents('ALPHA'))[0]
+    beta = bytes_to_uint32s(elf.get_symbol_contents('BETA'))[0]
    m = bytes_to_uint32s(elf.get_symbol_contents('M'))[0]
    n = bytes_to_uint32s(elf.get_symbol_contents('N'))[0]
    k = bytes_to_uint32s(elf.get_symbol_contents('K'))[0]
@@ -49,7 +49,7 @@ def main():
    c = np.reshape(c, (m, n))
 
    # Verify results
-    c_golden = golden_model(a, b, alpha, c).flatten()
+    c_golden = golden_model(1, a, b, beta, c).flatten()
 
    absolute_err = np.absolute(c_golden - c_actual)
    fail = np.any(absolute_err > ERR_THRESHOLD)
diff --git a/util/sim/data_utils.py b/util/sim/data_utils.py
index 664e2624b3..2ed260d3f1 100644
--- a/util/sim/data_utils.py
+++ b/util/sim/data_utils.py
@@ -9,7 +9,7 @@
 
 
 def emit_license():
-    s = (f"// Copyright {datetime.now().year} ETH Zurich and University of Bologna."
+    s = (f"// Copyright {datetime.now().year} ETH Zurich and University of Bologna.\n"
         f"// Licensed under the Apache License, Version 2.0, see LICENSE for details.\n"
         f"// SPDX-License-Identifier: Apache-2.0\n\n")
    return s