diff --git a/sw/blas/gemm_v2/__init__.py b/sw/blas/gemm_v2/__init__.py new file mode 100644 index 0000000000..8432d35c70 --- /dev/null +++ b/sw/blas/gemm_v2/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2024 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Luca Colagrande + +from .scripts.datagen import GemmDataGen + +__all__ = ['GemmDataGen'] diff --git a/sw/blas/gemm_v2/data/params.json b/sw/blas/gemm_v2/data/params.json new file mode 100644 index 0000000000..e60e323664 --- /dev/null +++ b/sw/blas/gemm_v2/data/params.json @@ -0,0 +1,16 @@ +// Copyright 2024 ETH Zurich and University of Bologna. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 + +{ + setup_ssr: 1, + m_tiles: 2, // number of tiles in M dimension + transa: false, + transb: true, // must be true for SIMD + M: 96, + N: 48, + K: 48, + alpha: 1, + beta: 0, + gemm_fp: "gemm_fp64_opt" +} diff --git a/sw/blas/gemm_v2/roi.json b/sw/blas/gemm_v2/roi.json new file mode 100644 index 0000000000..1fd37dc360 --- /dev/null +++ b/sw/blas/gemm_v2/roi.json @@ -0,0 +1,36 @@ +[ + <% DOUBLE_BUFFER = 1 %> + <% N_TILES = 2 %> + + // Compute cores + % for j in range(0, 8): + { + "thread": "${f'hart_{j}'}", + "roi": [ + % for i in range(0, N_TILES): + {"idx": ${2 * i + 1}, "label": "${f'tile_{i}'}"}, + % endfor + ] + }, + % endfor + + // DMA core + { + "thread": "hart_8", + "roi": [ + % if not DOUBLE_BUFFER: + % for i in range(0, N_TILES): + {"idx": ${4 * i + 1}, "label": "${f'tile_{i}_in'}"}, + {"idx": ${4 * i + 3}, "label": "${f'tile_{i}_out'}"}, + % endfor + % else: + {"idx": 1, "label": "tile_0_in"}, + % for i in range(1, N_TILES): + {"idx": ${4 * (i - 1) + 3}, "label": "${f'tile_{i}_in'}"}, + {"idx": ${4 * (i - 1) + 5}, "label": "${f'tile_{i-1}_out'}"}, + % endfor + {"idx": ${4 * (i - 1) + 7}, "label": "tile_15_out"}, + % endif + ] + } +] \ No newline at end of file diff --git a/sw/blas/gemm_v2/scripts/datagen.py b/sw/blas/gemm_v2/scripts/datagen.py new file mode 100755 index 0000000000..32c872928a --- /dev/null +++ b/sw/blas/gemm_v2/scripts/datagen.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 +# Copyright 2022 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Authors: Tim Fischer +# Luca Bertaccini +# Viviane Potocnik +# Luca Colagrande + +import numpy as np +import re +import pyflexfloat as ff +import sys + +from snitch.util.sim import data_utils +from snitch.util.sim.data_utils import DataGen, format_array_declaration, \ + format_struct_definition, format_array_definition, format_ifdef_wrapper + + +np.random.seed(42) + + +class GemmDataGen(DataGen): + + # AXI splits bursts crossing 4KB address boundaries. 
To minimize + # the occurrence of these splits the data should be aligned to 4KB + BURST_ALIGNMENT = 4096 + + def golden_model(self, alpha, a, b, beta, c): + return alpha * np.matmul(a, b) + beta * c + + def exact_golden_model(self, alpha, a, b, beta, c): + M, N, K = a.shape[0], b.shape[1], b.shape[0] + result = beta * c + for m in range(M): + for n in range(N): + for k in range(K): + result[m][n] += a[m][k] * b[k][n] + return result + + def infer_implementation(self, gemm_fp): + # gemm_fp: "gemm_fp64_opt" + # create a regex with fp__ + prec, impl = re.search(r'gemm_fp(\d+)_(\w+)', gemm_fp).group(1, 2) + return (int(prec) / 8), impl + + def validate_config(self, gemm_fp, + m_tiles, transa, + transb, M, N, K, beta, **kwargs): + frac_m = M / m_tiles + frac_n = N / 1 + frac_k = K / 1 + + dtype, impl = self.infer_implementation(gemm_fp) + + # Calculate total TCDM occupation + prec = data_utils.size_from_precision_t(dtype) + a_size = frac_m * frac_k * prec + b_size = frac_k * frac_n * prec + c_size = frac_m * frac_n * prec + total_size = a_size + total_size += b_size + total_size += c_size + data_utils.validate_tcdm_footprint(2*total_size) + + assert (M % m_tiles) == 0, 'M is not an integer multiple of tile size' + assert (N % 1) == 0, 'N is not an integer multiple of tile size' + assert (K % 1) == 0, 'K is not an integer multiple of tile size' + assert not transa, 'SIMD kernels don\'t support transposed A matrix' + assert (dtype == 8) or (impl == 'baseline') or (impl == 'naive') \ + or transb, 'Optimized SIMD kernels only support transposed B matrix' + assert not transb or 1 == 1, 'Tiling in the N dimension not supported' \ + ' if B is transposed' + assert not transb or 1 == 1, 'Tiling in the K dimension not supported' \ + ' if B is transposed' + assert (impl == 'baseline') or (impl == 'naive') or frac_n >= 8, \ + 'N dimension of tile size must be greater or equal to the unrolling factor (8) ' \ + 'when using optimized kernels' + assert beta == 0 or beta == 1, 'Only values of 0 or 1 supported for beta' + assert not (dtype == 8 and impl == "baseline"), 'No baseline implemented' \ + ' for FP64 (switch to NAIVE)' + assert not (((dtype == 8) or (dtype == 4)) and impl == "opt_ex"), \ + 'Expanding GEMM kernels' \ + ' not supported for FP64 and FP32' + assert not (dtype == 1 and impl == "opt"), 'FP8 not supported in' \ + ' optimized implementation' \ + ' (switch to opt_ex)' + + def emit_header(self, **kwargs): + header = [super().emit_header()] + + # Validate parameters + self.validate_config(**kwargs) + + M, N, K = kwargs['M'], kwargs['N'], kwargs['K'] + + prec, _ = self.infer_implementation(kwargs['gemm_fp']) + + ff_desc = data_utils.ff_desc_from_precision_t(prec) + ctype = data_utils.ctype_from_precision_t(prec) + + a = ff.array(np.random.rand(M, K), ff_desc) + b = ff.array(np.random.rand(K, N), ff_desc) + c = ff.array(np.random.rand(M, N), ff_desc) + result = self.exact_golden_model(1, a, b, kwargs['beta'], c) + + # Store matrices in transposed form if requested + a = a.T if kwargs['transa'] else a + b = b.T if kwargs['transb'] else b + + a_uid = 'a' + b_uid = 'b' + c_uid = 'c' + + cfg = { + 'prec': prec, + **kwargs, + 'a': a_uid, + 'b': b_uid, + 'c': c_uid, + } + + a = a.flatten() + b = b.flatten() + c = c.flatten() + + header += [format_array_declaration(ctype, a_uid, a.shape)] + header += [format_array_declaration(ctype, b_uid, b.shape)] + header += [format_array_declaration(ctype, c_uid, c.shape)] + header += [format_struct_definition('gemm_args_t', 'args', cfg)] + header += 
[format_array_definition(ctype, a_uid, a, + section=kwargs['section'])] + header += [format_array_definition(ctype, b_uid, b, + section=kwargs['section'])] + header += [format_array_definition(ctype, c_uid, c, + section=kwargs['section'])] + result_def = format_array_definition(ctype, 'result', result.flatten()) + header += [format_ifdef_wrapper('BIST', result_def)] + header = '\n\n'.join(header) + + return header + + +if __name__ == "__main__": + sys.exit(GemmDataGen().main()) diff --git a/sw/blas/gemm_v2/scripts/verify.py b/sw/blas/gemm_v2/scripts/verify.py new file mode 100755 index 0000000000..40840b327b --- /dev/null +++ b/sw/blas/gemm_v2/scripts/verify.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# Copyright 2023 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Luca Colagrande + +import numpy as np +import sys +from datagen import GemmDataGen + +from snitch.util.sim.verif_utils import Verifier +from snitch.util.sim.data_utils import ctype_from_precision_t + + +class GemmVerifier(Verifier): + + OUTPUT_UIDS = ['c'] + ERR_THRESHOLD = { + 1: 1e-4, + 2: 1e-2, + 4: 1e-6, + 8: 1e-6 + } + + def __init__(self): + super().__init__() + self.func_args = { + 'alpha': 'd', + 'prec': 'I', + 'setup_ssr': 'I', + 'parallelize_m': 'I', + 'parallelize_k': 'I', + 'm_tiles': 'I', + 'n_tiles': 'I', + 'k_tiles': 'I', + 'load_a': 'I', + 'load_b': 'I', + 'load_c': 'I', + 'transa': 'I', + 'transb': 'I', + 'M': 'I', + 'N': 'I', + 'K': 'I', + 'a': 'I', + 'b': 'I', + 'beta': 'I', + 'c': 'I', + 'gemm_fp': 'I' + } + self.func_args = self.get_input_from_symbol('args', self.func_args) + + def get_actual_results(self): + prec = self.func_args['prec'] + return self.get_output_from_symbol(self.OUTPUT_UIDS[0], ctype_from_precision_t(prec)) + + def get_expected_results(self): + prec = self.func_args['prec'] + a = self.get_input_from_symbol('a', ctype_from_precision_t(prec)) + b = self.get_input_from_symbol('b', ctype_from_precision_t(prec)) + c = self.get_input_from_symbol('c', ctype_from_precision_t(prec)) + beta = self.func_args['beta'] + m = self.func_args['M'] + n = self.func_args['N'] + k = self.func_args['K'] + tb = self.func_args['transb'] + + a = np.reshape(a, (m, k)) + if tb: + b = np.reshape(b, (n, k)) + b = b.transpose() + else: + b = np.reshape(b, (k, n)) + c = np.reshape(c, (m, n)) + return GemmDataGen().exact_golden_model(1, a, b, beta, c).flatten() + + def check_results(self, *args): + prec = self.func_args['prec'] + return super().check_results(*args, rtol=self.ERR_THRESHOLD[prec]) + + +if __name__ == "__main__": + sys.exit(GemmVerifier().main()) diff --git a/sw/blas/gemm_v2/src/gemm_fp16.h b/sw/blas/gemm_v2/src/gemm_fp16.h new file mode 100644 index 0000000000..26caa4e047 --- /dev/null +++ b/sw/blas/gemm_v2/src/gemm_fp16.h @@ -0,0 +1,458 @@ +// Copyright 2023 ETH Zurich and University of Bologna. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+// SPDX-License-Identifier: Apache-2.0 +// +// Author: Tim Fischer +// Luca Bertaccini +// Luca Colagrande +// Viviane Potocnik + +void gemm_fp16_naive(uint32_t M, uint32_t N, uint32_t K, void* A_p, + uint32_t ldA, uint32_t ta, void* B_p, uint32_t ldB, + uint32_t tb, void* C_p, uint32_t ldC, uint32_t BETA, + uint32_t setup_SSR) { + __fp16* A = (__fp16*)A_p; + __fp16* B = (__fp16*)B_p; + __fp16* C = (__fp16*)C_p; + __fp16 beta = (__fp16)BETA; + + for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { + __fp16 c; + if (beta != 0) { + c = C[m * ldC + n] * beta; + } else { + c = 0.0; + } + for (uint32_t k = 0; k < K; k++) { + __fp16 a = A[m * ldA + k]; + __fp16 b; + if (tb) + b = B[n * ldB + k]; + else + b = B[k * ldB + n]; + asm volatile( + "fmv.h.x ft3, %[a]\n" + "fmv.h.x ft4, %[b]\n" + "fmv.h.x ft5, %[c]\n" + "fmul.h ft6, ft3, ft4 \n" + "fadd.h ft5, ft5, ft6 \n" + "fmv.x.h %[c], ft5\n" + : [ c ] "+r"(c) + : [ a ] "r"(a), [ b ] "r"(b)); + } + C[m * ldC + n] = c; + } + } +} + +void gemm_fp16_baseline(uint32_t M, uint32_t N, uint32_t K, void* A_p, + uint32_t ldA, uint32_t ta, void* B_p, uint32_t ldB, + uint32_t tb, void* C_p, uint32_t ldC, uint32_t BETA, + uint32_t setup_SSR) { + __fp16* A = (__fp16*)A_p; + __fp16* B = (__fp16*)B_p; + __fp16* C = (__fp16*)C_p; + + for (uint32_t m = 0; m < M; m++) { + uint32_t n = 0; + for (; n < N; n++) { + volatile register v4f16 *a_ptr, *b_ptr; + register v4f16 a, b; + volatile __fp16* c_ptr; + const register float zero = 0.0; + double c = 0.0; + v4f16 reduce_reg; + + a_ptr = (v4f16*)(&A[m * ldA]); + b_ptr = (v4f16*)(&B[n * ldB]); + c_ptr = &C[m * ldC + n]; + // Don't accumulate in first iteration + asm volatile( + "beqz %[beta], 1f \n" + // Load intermediate results + "flh ft2, 0(%[C]) \n" + "vfcvt.s.h ft2, ft2\n" + "vfcpka.s.s ft2, ft2, %[zero]\n" + // or initialize with zero + "j 2f \n" + "1: \n" + "vfcpka.s.s ft2, %[zero], %[zero]\n" + "2: \n" + // loop over the MACs + "li t0, 0 \n" + "3: \n" + "fld ft0, 0(%[a_ptr]) \n" + "fld ft1, 0(%[b_ptr]) \n" + "add %[a_ptr], %[a_ptr], 8 \n" + "add %[b_ptr], %[b_ptr], 8 \n" + "vfdotpex.s.h ft2, ft0, ft1 \n" + "addi t0, t0, 4 \n" + "blt t0, %[K], 3b \n" + // Sum reduce vector + "vfcpka.s.s ft3, %[zero], %[zero]\n" + "vfsum.s ft3, ft2 \n" + "vfcvt.h.s ft3, ft3\n" + // Store results + "fsh ft3, 0(%[C]) \n" + : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr) + : [ c ] "f"(c), [ reduce_reg ] "f"(reduce_reg), + [ C ] "r"(c_ptr), [ beta ] "r"(BETA), [ K ] "r"(K), + [ zero ] "f"(zero) + : "ft0", "ft1", "ft2", "ft3", "ft4", "t0"); + } + } +} + +void gemm_fp16_opt(uint32_t M, uint32_t N, uint32_t K, void* A_p, uint32_t ldA, + uint32_t ta, void* B_p, uint32_t ldB, uint32_t tb, void* C_p, + uint32_t ldC, uint32_t BETA, uint32_t setup_SSR) { + __fp16* A = (__fp16*)A_p; + __fp16* B = (__fp16*)B_p; + __fp16* C = (__fp16*)C_p; + + // Unrolling factor of most inner loop. 
+ // Should be at least as high as the FMA delay + // for maximum utilization + const uint32_t unroll = 8; + + // SSR strides and bounds only have to be configured + // once in the beginning + if (setup_SSR) { + uint32_t ssr0_b[4] = {unroll, K / 4, N / unroll, M}; + uint32_t ssr0_i[4] = {0, sizeof(__fp16) * 4, 0, sizeof(__fp16) * ldA}; + + uint32_t ssr1_b[4] = {unroll, K / 4, N / unroll, M}; + uint32_t ssr1_i[4] = {sizeof(__fp16) * ldB, sizeof(__fp16) * 4, + sizeof(__fp16) * unroll * ldB, 0}; + + snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3], + ssr0_i[1], ssr0_i[2], ssr0_i[3]); + snrt_ssr_repeat(SNRT_SSR_DM0, unroll); + + snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2], + ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2], ssr1_i[3]); + } + + // SSR start address need to be configured each time + snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, A); + snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, B); + snrt_ssr_enable(); + + // Kernel progresses by 4 values each step + const uint32_t n_frep = K / 4 - 1; + + for (uint32_t m = 0; m < M; m++) { + uint32_t n = 0; + for (uint32_t n0 = 0; n0 < N / unroll; n0++) { + __fp16* _C = &C[m * ldC + n]; + const register float zero = 0.0; + v4f16 c[unroll]; + v2f32 reduce_reg[unroll]; + + asm volatile( + "beqz %[beta], 1f \n" + // Load intermediate results + "flh %[reduce_reg0], 0(%[C]) \n" + "flh %[reduce_reg1], 2(%[C]) \n" + "flh %[reduce_reg2], 4(%[C]) \n" + "flh %[reduce_reg3], 6(%[C]) \n" + "flh %[reduce_reg4], 8(%[C]) \n" + "flh %[reduce_reg5], 10(%[C]) \n" + "flh %[reduce_reg6], 12(%[C]) \n" + "flh %[reduce_reg7], 14(%[C]) \n" + // Convert intermediate results before packing + "vfcvt.s.h %[reduce_reg0], %[reduce_reg0]\n" + "vfcvt.s.h %[reduce_reg1], %[reduce_reg1]\n" + "vfcvt.s.h %[reduce_reg2], %[reduce_reg2]\n" + "vfcvt.s.h %[reduce_reg3], %[reduce_reg3]\n" + "vfcvt.s.h %[reduce_reg4], %[reduce_reg4]\n" + "vfcvt.s.h %[reduce_reg5], %[reduce_reg5]\n" + "vfcvt.s.h %[reduce_reg6], %[reduce_reg6]\n" + "vfcvt.s.h %[reduce_reg7], %[reduce_reg7]\n" + // Initialize reduce register to zero + "vfcpka.s.s %[c0], %[zero], %[zero]\n" + "vfcpka.s.s %[c1], %[zero], %[zero]\n" + "vfcpka.s.s %[c2], %[zero], %[zero]\n" + "vfcpka.s.s %[c3], %[zero], %[zero]\n" + "vfcpka.s.s %[c4], %[zero], %[zero]\n" + "vfcpka.s.s %[c5], %[zero], %[zero]\n" + "vfcpka.s.s %[c6], %[zero], %[zero]\n" + "vfcpka.s.s %[c7], %[zero], %[zero]\n" + // Pack intermediate results into SIMD vector + "vfcpka.h.s %[c0], %[reduce_reg0], %[zero]\n" + "vfcpka.h.s %[c1], %[reduce_reg1], %[zero]\n" + "vfcpka.h.s %[c2], %[reduce_reg2], %[zero]\n" + "vfcpka.h.s %[c3], %[reduce_reg3], %[zero]\n" + "vfcpka.h.s %[c4], %[reduce_reg4], %[zero]\n" + "vfcpka.h.s %[c5], %[reduce_reg5], %[zero]\n" + "vfcpka.h.s %[c6], %[reduce_reg6], %[zero]\n" + "vfcpka.h.s %[c7], %[reduce_reg7], %[zero]\n" + "j 2f \n" + "1: \n" + // Initialize SIMD vector with zeros + "vfcpka.s.s %[c0], %[zero], %[zero]\n" + "vfcpka.s.s %[c1], %[zero], %[zero]\n" + "vfcpka.s.s %[c2], %[zero], %[zero]\n" + "vfcpka.s.s %[c3], %[zero], %[zero]\n" + "vfcpka.s.s %[c4], %[zero], %[zero]\n" + "vfcpka.s.s %[c5], %[zero], %[zero]\n" + "vfcpka.s.s %[c6], %[zero], %[zero]\n" + "vfcpka.s.s %[c7], %[zero], %[zero]\n" + "2: \n" + // Perform non-expanding MACs + "frep.o %[n_frep], 8, 0, 0 \n" + "vfmac.h %[c0], ft1, ft0 \n" + "vfmac.h %[c1], ft1, ft0 \n" + "vfmac.h %[c2], ft1, ft0 \n" + "vfmac.h %[c3], ft1, ft0 \n" + "vfmac.h %[c4], ft1, ft0 \n" + "vfmac.h %[c5], ft1, ft0 \n" + "vfmac.h %[c6], ft1, ft0 \n" + "vfmac.h %[c7], ft1, ft0 \n" + // 
Initialize reduce register to zero + "vfcpka.s.s %[reduce_reg0], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg1], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg2], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg3], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg4], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg5], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg6], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg7], %[zero], %[zero]\n" + // Sum-reduce vector + // EXVSUM is used for the sake of packing afterwards + "vfsumex.s.h %[reduce_reg0], %[c0] \n" + "vfsumex.s.h %[reduce_reg1], %[c1] \n" + "vfsumex.s.h %[reduce_reg2], %[c2] \n" + "vfsumex.s.h %[reduce_reg3], %[c3] \n" + "vfsumex.s.h %[reduce_reg4], %[c4] \n" + "vfsumex.s.h %[reduce_reg5], %[c5] \n" + "vfsumex.s.h %[reduce_reg6], %[c6] \n" + "vfsumex.s.h %[reduce_reg7], %[c7] \n" + // Initialize reduce register to zero + "vfcpka.s.s %[c0], %[zero], %[zero] \n" + "vfcpka.s.s %[c1], %[zero], %[zero] \n" + "vfcpka.s.s %[c2], %[zero], %[zero] \n" + "vfcpka.s.s %[c3], %[zero], %[zero] \n" + "vfcpka.s.s %[c4], %[zero], %[zero] \n" + "vfcpka.s.s %[c5], %[zero], %[zero] \n" + "vfcpka.s.s %[c6], %[zero], %[zero] \n" + "vfcpka.s.s %[c7], %[zero], %[zero] \n" + // Sum-reduce vector + "vfsum.s %[c0], %[reduce_reg0] \n" + "vfsum.s %[c1], %[reduce_reg1] \n" + "vfsum.s %[c2], %[reduce_reg2] \n" + "vfsum.s %[c3], %[reduce_reg3] \n" + "vfsum.s %[c4], %[reduce_reg4] \n" + "vfsum.s %[c5], %[reduce_reg5] \n" + "vfsum.s %[c6], %[reduce_reg6] \n" + "vfsum.s %[c7], %[reduce_reg7] \n" + // Pack results to FP16 vectors + "vfcpka.h.s %[c0], %[c0], %[c1] \n" + "vfcpkb.h.s %[c0], %[c2], %[c3] \n" + "vfcpka.h.s %[c1], %[c4], %[c5] \n" + "vfcpkb.h.s %[c1], %[c6], %[c7] \n" + : [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]), + [ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]), + [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), + [ reduce_reg0 ] "+f"(reduce_reg[0]), + [ reduce_reg1 ] "+f"(reduce_reg[1]), + [ reduce_reg2 ] "+f"(reduce_reg[2]), + [ reduce_reg3 ] "+f"(reduce_reg[3]), + [ reduce_reg4 ] "+f"(reduce_reg[4]), + [ reduce_reg5 ] "+f"(reduce_reg[5]), + [ reduce_reg6 ] "+f"(reduce_reg[6]), + [ reduce_reg7 ] "+f"(reduce_reg[7]) + : [ C ] "r"(_C), [ zero ] "f"(zero), [ n_frep ] "r"(n_frep), + [ beta ] "r"(BETA) + : "ft0", "ft1", "ft2"); + + // Store results back + ((v4f16*)_C)[0] = c[0]; + ((v4f16*)_C)[1] = c[1]; + n += unroll; + } + + // Clean up left over column + // snrt_ssr_disable(); + + // for (; n < N; n++) { + // __fp16 c = (*BETA) ? C[m * ldC + n] : 0.0; + // for (uint32_t k = 0; k < K; k++) { + // c += A[k + m * ldA] * B[k + n * ldB]; + // } + // C[m * ldC + n] = c; + // } + + // snrt_ssr_enable(); + } + + snrt_ssr_disable(); +} + +void gemm_fp16_opt_ex(uint32_t M, uint32_t N, uint32_t K, void* A_p, + uint32_t ldA, uint32_t ta, void* B_p, uint32_t ldB, + uint32_t tb, void* C_p, uint32_t ldC, uint32_t BETA, + uint32_t setup_SSR) { + __fp16* A = (__fp16*)A_p; + __fp16* B = (__fp16*)B_p; + __fp16* C = (__fp16*)C_p; + + // Unrolling factor of most inner loop. 
+ // Should be at least as high as the FMA delay + // for maximum utilization + const uint32_t unroll = 8; + + // SSR strides and bounds only have to be configured + // once in the beginning + if (setup_SSR) { + uint32_t ssr0_b[4] = {unroll, K / 4, N / unroll, M}; + uint32_t ssr0_i[4] = {0, sizeof(__fp16) * 4, 0, sizeof(__fp16) * ldA}; + + uint32_t ssr1_b[4] = {unroll, K / 4, N / unroll, M}; + uint32_t ssr1_i[4] = {sizeof(__fp16) * ldB, sizeof(__fp16) * 4, + sizeof(__fp16) * unroll * ldB, 0}; + + snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3], + ssr0_i[1], ssr0_i[2], ssr0_i[3]); + snrt_ssr_repeat(SNRT_SSR_DM0, unroll); + + snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2], + ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2], ssr1_i[3]); + } + + // SSR start address need to be configured each time + snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, A); + snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, B); + snrt_ssr_enable(); + + // Kernel progresses by 4 values each step + const uint32_t n_frep = K / 4 - 1; + + for (uint32_t m = 0; m < M; m++) { + uint32_t n = 0; + for (uint32_t n0 = 0; n0 < N / unroll; n0++) { + __fp16* _C = &C[m * ldC + n]; + const register float zero = 0.0; + v4f16 c[unroll]; + v2f32 reduce_reg[unroll]; + + asm volatile( + "beqz %[beta], 1f \n" + "flh %[reduce_reg0], 0(%[C]) \n" + "flh %[reduce_reg1], 2(%[C]) \n" + "flh %[reduce_reg2], 4(%[C]) \n" + "flh %[reduce_reg3], 6(%[C]) \n" + "flh %[reduce_reg4], 8(%[C]) \n" + "flh %[reduce_reg5], 10(%[C]) \n" + "flh %[reduce_reg6], 12(%[C]) \n" + "flh %[reduce_reg7], 14(%[C]) \n" + // Convert intermediate results before packing + "vfcvt.s.h %[reduce_reg0], %[reduce_reg0]\n" + "vfcvt.s.h %[reduce_reg1], %[reduce_reg1]\n" + "vfcvt.s.h %[reduce_reg2], %[reduce_reg2]\n" + "vfcvt.s.h %[reduce_reg3], %[reduce_reg3]\n" + "vfcvt.s.h %[reduce_reg4], %[reduce_reg4]\n" + "vfcvt.s.h %[reduce_reg5], %[reduce_reg5]\n" + "vfcvt.s.h %[reduce_reg6], %[reduce_reg6]\n" + "vfcvt.s.h %[reduce_reg7], %[reduce_reg7]\n" + // Initialize reduce register to zero + "vfcpka.s.s %[c0], %[zero], %[zero]\n" + "vfcpka.s.s %[c1], %[zero], %[zero]\n" + "vfcpka.s.s %[c2], %[zero], %[zero]\n" + "vfcpka.s.s %[c3], %[zero], %[zero]\n" + "vfcpka.s.s %[c4], %[zero], %[zero]\n" + "vfcpka.s.s %[c5], %[zero], %[zero]\n" + "vfcpka.s.s %[c6], %[zero], %[zero]\n" + "vfcpka.s.s %[c7], %[zero], %[zero]\n" + // Pack intermediate results into SIMD vector + "vfcpka.s.s %[c0], %[reduce_reg0], %[zero]\n" + "vfcpka.s.s %[c1], %[reduce_reg1], %[zero]\n" + "vfcpka.s.s %[c2], %[reduce_reg2], %[zero]\n" + "vfcpka.s.s %[c3], %[reduce_reg3], %[zero]\n" + "vfcpka.s.s %[c4], %[reduce_reg4], %[zero]\n" + "vfcpka.s.s %[c5], %[reduce_reg5], %[zero]\n" + "vfcpka.s.s %[c6], %[reduce_reg6], %[zero]\n" + "vfcpka.s.s %[c7], %[reduce_reg7], %[zero]\n" + "j 2f \n" + "1: \n" + // Initialize SIMD vector with zeros + "vfcpka.s.s %[c0], %[zero], %[zero]\n" + "vfcpka.s.s %[c1], %[zero], %[zero]\n" + "vfcpka.s.s %[c2], %[zero], %[zero]\n" + "vfcpka.s.s %[c3], %[zero], %[zero]\n" + "vfcpka.s.s %[c4], %[zero], %[zero]\n" + "vfcpka.s.s %[c5], %[zero], %[zero]\n" + "vfcpka.s.s %[c6], %[zero], %[zero]\n" + "vfcpka.s.s %[c7], %[zero], %[zero]\n" + "2: \n" + // Perform expanding sum-dotproducts + "frep.o %[n_frep], %[unroll], 0, 0 \n" + "vfdotpex.s.h %[c0], ft1, ft0 \n" + "vfdotpex.s.h %[c1], ft1, ft0 \n" + "vfdotpex.s.h %[c2], ft1, ft0 \n" + "vfdotpex.s.h %[c3], ft1, ft0 \n" + "vfdotpex.s.h %[c4], ft1, ft0 \n" + "vfdotpex.s.h %[c5], ft1, ft0 \n" + "vfdotpex.s.h %[c6], ft1, ft0 \n" + "vfdotpex.s.h %[c7], 
ft1, ft0 \n" + // Initialize reduce register to zero + "vfcpka.s.s %[reduce_reg0], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg1], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg2], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg3], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg4], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg5], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg6], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg7], %[zero], %[zero]\n" + // Sum-reduce vector + "vfsum.s %[reduce_reg0], %[c0] \n" + "vfsum.s %[reduce_reg1], %[c1] \n" + "vfsum.s %[reduce_reg2], %[c2] \n" + "vfsum.s %[reduce_reg3], %[c3] \n" + "vfsum.s %[reduce_reg4], %[c4] \n" + "vfsum.s %[reduce_reg5], %[c5] \n" + "vfsum.s %[reduce_reg6], %[c6] \n" + "vfsum.s %[reduce_reg7], %[c7] \n" + // Pack and convert results to FP16 vectors + "vfcpka.h.s %[c0], %[reduce_reg0], %[reduce_reg1] \n" + "vfcpkb.h.s %[c0], %[reduce_reg2], %[reduce_reg3] \n" + "vfcpka.h.s %[c1], %[reduce_reg4], %[reduce_reg5] \n" + "vfcpkb.h.s %[c1], %[reduce_reg6], %[reduce_reg7] \n" + : [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]), + [ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]), + [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), + [ reduce_reg0 ] "+f"(reduce_reg[0]), + [ reduce_reg1 ] "+f"(reduce_reg[1]), + [ reduce_reg2 ] "+f"(reduce_reg[2]), + [ reduce_reg3 ] "+f"(reduce_reg[3]), + [ reduce_reg4 ] "+f"(reduce_reg[4]), + [ reduce_reg5 ] "+f"(reduce_reg[5]), + [ reduce_reg6 ] "+f"(reduce_reg[6]), + [ reduce_reg7 ] "+f"(reduce_reg[7]) + : [ C ] "r"(_C), [ zero ] "f"(zero), [ n_frep ] "r"(n_frep), + [ unroll ] "i"(unroll), [ beta ] "r"(BETA) + : "ft0", "ft1", "ft2"); + + // Store results back + ((v4f16*)_C)[0] = c[0]; + ((v4f16*)_C)[1] = c[1]; + n += unroll; + } + + // Clean up left over column + // snrt_ssr_disable(); + + // for (; n < N; n++) { + // __fp16 c = (*BETA) ? C[m * ldC + n] : 0.0; + // for (uint32_t k = 0; k < K; k++) { + // c += A[k + m * ldA] * B[k + n * ldB]; + // } + // C[m * ldC + n] = c; + // } + + // snrt_ssr_enable(); + } + + snrt_ssr_disable(); +} diff --git a/sw/blas/gemm_v2/src/gemm_fp32.h b/sw/blas/gemm_v2/src/gemm_fp32.h new file mode 100644 index 0000000000..e79182b628 --- /dev/null +++ b/sw/blas/gemm_v2/src/gemm_fp32.h @@ -0,0 +1,358 @@ +// Copyright 2023 ETH Zurich and University of Bologna. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+// SPDX-License-Identifier: Apache-2.0 +// +// Author: Tim Fischer +// Luca Bertaccini +// Luca Colagrande +// Viviane Potocnik + +void gemm_fp32_naive(uint32_t M, uint32_t N, uint32_t K, void* A_p, + uint32_t ldA, uint32_t ta, void* B_p, uint32_t ldB, + uint32_t tb, void* C_p, uint32_t ldC, uint32_t BETA, + uint32_t setup_SSR) { + float* A = (float*)A_p; + float* B = (float*)B_p; + float* C = (float*)C_p; + + if (!ta && !tb) { + for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { + float c0 = multiply_opt(C[m * ldC + n], BETA); + for (uint32_t k = 0; k < K; k++) { + c0 += A[k + m * ldA] * B[k * ldB + n]; + } + C[m * ldC + n] = c0; + } + } + } else if (ta && !tb) { + for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { + float c0 = multiply_opt(C[m * ldC + n], BETA); + for (uint32_t k = 0; k < K; k++) { + c0 += A[k * M * ldA + m * ldA] * B[k * ldB + n]; + } + C[m * ldC + n] = c0; + } + } + } else if (!ta && tb) { + for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { + float c0 = multiply_opt(C[m * ldC + n], BETA); + for (uint32_t k = 0; k < K; k++) { + c0 += A[k + m * ldA] * B[k + n * ldB]; + } + C[m * ldC + n] = c0; + } + } + } else { + for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { + float c0 = multiply_opt(C[m * ldC + n], BETA); + for (uint32_t k = 0; k < K; k++) { + c0 += A[k * M * ldA + m * ldA] * B[k + n * ldB]; + } + C[m * ldC + n] = c0; + } + } + } +} + +void gemm_fp32_naive_unrolled(uint32_t M, uint32_t N, uint32_t K, void* A_p, + uint32_t ldA, uint32_t ta, void* B_p, + uint32_t ldB, uint32_t tb, void* C_p, + uint32_t ldC, uint32_t BETA, uint32_t setup_SSR) { + float* A = (float*)A_p; + float* B = (float*)B_p; + float* C = (float*)C_p; + + float c0 = 0.0f; + float c1 = 0.0f; + float c2 = 0.0f; + float c3 = 0.0f; + if (!ta && !tb) { + for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { + if (BETA == 0) { + c0 = 0.0f; + } else { + c0 = BETA * C[m * ldC + n]; + } + c1 = 0.0f; + c2 = 0.0f; + c3 = 0.0f; + for (uint32_t k = 0; k < K; k += 4) { + c0 += A[(k + 0) + m * ldA] * B[(k + 0) * ldB + n]; + c1 += A[(k + 1) + m * ldA] * B[(k + 1) * ldB + n]; + c2 += A[(k + 2) + m * ldA] * B[(k + 2) * ldB + n]; + c3 += A[(k + 3) + m * ldA] * B[(k + 3) * ldB + n]; + } + C[m * ldC + n] = c0 + c1 + c2 + c3; + } + } + } else if (ta && !tb) { + for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { + if (BETA == 0) { + c0 = 0.0f; + } else { + c0 = BETA * C[m * ldC + n]; + } + c1 = 0.0f; + c2 = 0.0f; + c3 = 0.0f; + for (uint32_t k = 0; k < K; k += 4) { + c0 += A[(k + 0) * M * ldA + m * ldA] * B[(k + 0) * ldB + n]; + c1 += A[(k + 1) * M * ldA + m * ldA] * B[(k + 1) * ldB + n]; + c2 += A[(k + 2) * M * ldA + m * ldA] * B[(k + 2) * ldB + n]; + c3 += A[(k + 3) * M * ldA + m * ldA] * B[(k + 3) * ldB + n]; + } + C[m * ldC + n] = c0 + c1 + c2 + c3; + } + } + } else if (!ta && tb) { + for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { + if (BETA == 0) { + c0 = 0.0f; + } else { + c0 = BETA * C[m * ldC + n]; + } + c1 = 0.0f; + c2 = 0.0f; + c3 = 0.0f; + for (uint32_t k = 0; k < K; k += 4) { + c0 += A[(k + 0) + m * ldA] * B[(k + 0) + n * ldB]; + c1 += A[(k + 1) + m * ldA] * B[(k + 1) + n * ldB]; + c2 += A[(k + 2) + m * ldA] * B[(k + 2) + n * ldB]; + c3 += A[(k + 3) + m * ldA] * B[(k + 3) + n * ldB]; + } + C[m * ldC + n] = c0 + c1 + c2 + c3; + } + } + } else { + for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { + register float c0 = BETA * C[m * ldC + n]; 
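+                // Fallback for the transposed-A, transposed-B case: a plain
+                // (non-unrolled) inner product; note that, unlike the other
+                // branches, BETA == 0 is not special-cased here.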
+ for (uint32_t k = 0; k < K; k++) { + c0 += A[k * M * ldA + m * ldA] * B[k + n * ldB]; + } + C[m * ldC + n] = c0; + } + } + } +} + +void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, void* A_p, + uint32_t ldA, uint32_t ta, void* B_p, uint32_t ldB, + uint32_t tb, void* C_p, uint32_t ldC, uint32_t BETA, + uint32_t setup_SSR) { + float* A = (float*)A_p; + float* B = (float*)B_p; + float* C = (float*)C_p; + + for (uint32_t m = 0; m < M; m++) { + uint32_t n = 0; + for (; n < N; n++) { + volatile register v2f32 *a_ptr, *b_ptr; + register v2f32 a, b; + volatile float* c_ptr; + const register float zero = 0.0; + double c = 0.0; + v2f32 reduce_reg; + + a_ptr = (v2f32*)(&A[m * ldA]); + b_ptr = (v2f32*)(&B[n * ldB]); + c_ptr = &C[m * ldC + n]; + // Don't accumulate in first iteration + asm volatile( + "beqz %[BETA], 1f \n" + // Load intermediate results + "flw ft2, 0(%[C]) \n" + "vfcpka.s.s ft2, ft2, %[zero]\n" + // or initialize with zero + "j 2f \n" + "1: \n" + "vfcpka.s.s ft2, %[zero], %[zero]\n" + // Don't accumulate in first iteration + "2: \n" + "fld ft0, 0(%[a_ptr]) \n" + "fld ft1, 0(%[b_ptr]) \n" + "add %[a_ptr], %[a_ptr], 8 \n" + "add %[b_ptr], %[b_ptr], 8 \n" + "vfmul.s ft3, ft0, ft1 \n" + // loop over the MACs + "li t0, 2 \n" + "3: \n" + "fld ft0, 0(%[a_ptr]) \n" + "fld ft1, 0(%[b_ptr]) \n" + "vfmac.s ft3, ft0, ft1 \n" + "add %[a_ptr], %[a_ptr], 8 \n" + "add %[b_ptr], %[b_ptr], 8 \n" + "addi t0, t0, 2 \n" + "blt t0, %[K], 3b \n" + // Sum reduce vector + "vfsum.s ft2, ft3 \n" + // Store results + "fsw ft2, 0(%[C]) \n" + : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr) + : [ c ] "f"(c), [ reduce_reg ] "f"(reduce_reg), + [ C ] "r"(c_ptr), [ BETA ] "r"(BETA), [ K ] "r"(K), + [ zero ] "f"(zero) + : "ft0", "ft1", "ft2", "ft3", "ft4", "t0"); + } + } +} + +void gemm_fp32_opt(uint32_t M, uint32_t N, uint32_t K, void* A_p, uint32_t ldA, + uint32_t ta, void* B_p, uint32_t ldB, uint32_t tb, void* C_p, + uint32_t ldC, uint32_t BETA, uint32_t setup_SSR) { + // cast void pointers to float pointers + float* A = (float*)A_p; + float* B = (float*)B_p; + float* C = (float*)C_p; + // Unrolling factor of most inner loop. 
+ // Should be at least as high as the FMA delay + // for maximum utilization + const uint32_t unroll = 8; + + // SSR strides and bounds only have to be configured + // once in the beginning + if (setup_SSR) { + uint32_t ssr0_b[4] = {unroll, K / 2, N / unroll, M}; + uint32_t ssr0_i[4] = {0, sizeof(float) * 2, 0, sizeof(float) * ldA}; + + uint32_t ssr1_b[4] = {unroll, K / 2, N / unroll, M}; + uint32_t ssr1_i[4] = {sizeof(float) * ldB, sizeof(float) * 2, + sizeof(float) * unroll * ldB, 0}; + + snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3], + ssr0_i[1], ssr0_i[2], ssr0_i[3]); + snrt_ssr_repeat(SNRT_SSR_DM0, unroll); + + snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2], + ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2], ssr1_i[3]); + } + + // SSR start address need to be configured each time + snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, A); + snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, B); + snrt_ssr_enable(); + + // Kernel progresses by 2 values each step + const uint32_t n_frep = K / 2 - 1; + + for (uint32_t m = 0; m < M; m++) { + uint32_t n = 0; + for (uint32_t n0 = 0; n0 < N / unroll; n0++) { + float* _C = &C[m * ldC + n / 2]; + const register float zero = 0.0; + v2f32 c[unroll], reduce_reg[unroll]; + + asm volatile( + "beqz %[BETA], 1f \n" + // Load intermediate results + "flw %[reduce_reg0], 0(%[C]) \n" + "flw %[reduce_reg1], 4(%[C]) \n" + "flw %[reduce_reg2], 8(%[C]) \n" + "flw %[reduce_reg3], 12(%[C]) \n" + "flw %[reduce_reg4], 16(%[C]) \n" + "flw %[reduce_reg5], 20(%[C]) \n" + "flw %[reduce_reg6], 24(%[C]) \n" + "flw %[reduce_reg7], 28(%[C]) \n" + // Pack intermediate results into SIMD vector + "vfcpka.s.s %[reduce_reg0], %[reduce_reg0], %[zero]\n" + "vfcpka.s.s %[reduce_reg1], %[reduce_reg1], %[zero]\n" + "vfcpka.s.s %[reduce_reg2], %[reduce_reg2], %[zero]\n" + "vfcpka.s.s %[reduce_reg3], %[reduce_reg3], %[zero]\n" + "vfcpka.s.s %[reduce_reg4], %[reduce_reg4], %[zero]\n" + "vfcpka.s.s %[reduce_reg5], %[reduce_reg5], %[zero]\n" + "vfcpka.s.s %[reduce_reg6], %[reduce_reg6], %[zero]\n" + "vfcpka.s.s %[reduce_reg7], %[reduce_reg7], %[zero]\n" + "j 2f \n" + "1: \n" + // Initialize SIMD vector with zeros + "vfcpka.s.s %[reduce_reg0], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg1], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg2], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg3], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg4], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg5], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg6], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg7], %[zero], %[zero]\n" + + "2: \n" + // Don't accumulate in first iteration + "vfmul.s %[c0], ft1, ft0 \n" + "vfmul.s %[c1], ft1, ft0 \n" + "vfmul.s %[c2], ft1, ft0 \n" + "vfmul.s %[c3], ft1, ft0 \n" + "vfmul.s %[c4], ft1, ft0 \n" + "vfmul.s %[c5], ft1, ft0 \n" + "vfmul.s %[c6], ft1, ft0 \n" + "vfmul.s %[c7], ft1, ft0 \n" + // frep over MACs + "frep.o %[n_frep], %[unroll], 0, 0 \n" + "vfmac.s %[c0], ft1, ft0 \n" + "vfmac.s %[c1], ft1, ft0 \n" + "vfmac.s %[c2], ft1, ft0 \n" + "vfmac.s %[c3], ft1, ft0 \n" + "vfmac.s %[c4], ft1, ft0 \n" + "vfmac.s %[c5], ft1, ft0 \n" + "vfmac.s %[c6], ft1, ft0 \n" + "vfmac.s %[c7], ft1, ft0 \n" + // Sum-reduce vector + "vfsum.s %[reduce_reg0], %[c0] \n" + "vfsum.s %[reduce_reg1], %[c1] \n" + "vfsum.s %[reduce_reg2], %[c2] \n" + "vfsum.s %[reduce_reg3], %[c3] \n" + "vfsum.s %[reduce_reg4], %[c4] \n" + "vfsum.s %[reduce_reg5], %[c5] \n" + "vfsum.s %[reduce_reg6], %[c6] \n" + "vfsum.s %[reduce_reg7], %[c7] \n" + // Pack results together again into vectors + "vfcpka.s.s 
%[c0], %[reduce_reg0], %[reduce_reg1] \n" + "vfcpka.s.s %[c1], %[reduce_reg2], %[reduce_reg3] \n" + "vfcpka.s.s %[c2], %[reduce_reg4], %[reduce_reg5] \n" + "vfcpka.s.s %[c3], %[reduce_reg6], %[reduce_reg7] \n" + : [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]), + [ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]), + [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), + [ reduce_reg0 ] "+f"(reduce_reg[0]), + [ reduce_reg1 ] "+f"(reduce_reg[1]), + [ reduce_reg2 ] "+f"(reduce_reg[2]), + [ reduce_reg3 ] "+f"(reduce_reg[3]), + [ reduce_reg4 ] "+f"(reduce_reg[4]), + [ reduce_reg5 ] "+f"(reduce_reg[5]), + [ reduce_reg6 ] "+f"(reduce_reg[6]), + [ reduce_reg7 ] "+f"(reduce_reg[7]) + : [ C ] "r"(_C), [ zero ] "f"(zero), [ n_frep ] "r"(n_frep - 1), + [ unroll ] "i"(unroll), [ BETA ] "r"(BETA) + : "ft0", "ft1", "ft2"); + + // Store results + ((v2f32*)_C)[0] = c[0]; + ((v2f32*)_C)[1] = c[1]; + ((v2f32*)_C)[2] = c[2]; + ((v2f32*)_C)[3] = c[3]; + + // progress by 2 columns each iteration of the loop + n += unroll * 2; + } + + // Clean up of leftover columns + snrt_ssr_disable(); + + for (; n < N; n++) { + float c = BETA ? C[m * ldC + n] : 0.0; + for (uint32_t k = 0; k < K; k++) { + c += A[k + m * ldA] * B[k + n * ldB]; + } + C[m * ldC + n] = c; + } + + snrt_ssr_enable(); + } + + snrt_ssr_disable(); +} diff --git a/sw/blas/gemm_v2/src/gemm_fp64.h b/sw/blas/gemm_v2/src/gemm_fp64.h new file mode 100644 index 0000000000..66b6436064 --- /dev/null +++ b/sw/blas/gemm_v2/src/gemm_fp64.h @@ -0,0 +1,189 @@ +// Copyright 2023 ETH Zurich and University of Bologna. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 +// +// Author: Tim Fischer +// Luca Bertaccini +// Luca Colagrande +// Viviane Potocnik + +void gemm_fp64_naive(uint32_t M, uint32_t N, uint32_t K, void* A_p, + uint32_t ldA, uint32_t ta, void* B_p, uint32_t ldB, + uint32_t tb, void* C_p, uint32_t ldC, uint32_t BETA, + uint32_t setup_SSR) { + double* A = (double*)A_p; + double* B = (double*)B_p; + double* C = (double*)C_p; + + if (!ta && !tb) { + for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { + double c0 = multiply_opt(C[m * ldC + n], BETA); + for (uint32_t k = 0; k < K; k++) { + c0 += A[k + m * ldA] * B[k * ldB + n]; + } + C[m * ldC + n] = c0; + } + } + } else if (ta && !tb) { + for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { + double c0 = multiply_opt(C[m * ldC + n], BETA); + for (uint32_t k = 0; k < K; k++) { + c0 += A[k * M * ldA + m * ldA] * B[k * ldB + n]; + } + C[m * ldC + n] = c0; + } + } + } else if (!ta && tb) { + for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { + double c0 = multiply_opt(C[m * ldC + n], BETA); + for (uint32_t k = 0; k < K; k++) { + c0 += A[k + m * ldA] * B[k + n * ldB]; + } + C[m * ldC + n] = c0; + } + } + } else { + for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { + double c0 = multiply_opt(C[m * ldC + n], BETA); + for (uint32_t k = 0; k < K; k++) { + c0 += A[k * M * ldA + m * ldA] * B[k + n * ldB]; + } + C[m * ldC + n] = c0; + } + } + } +} + +void gemm_fp64_opt(uint32_t M, uint32_t N, uint32_t K, void* A_p, uint32_t ldA, + uint32_t ta, void* B_p, uint32_t ldB, uint32_t tb, void* C_p, + uint32_t ldC, uint32_t BETA, uint32_t setup_SSR) { + double* A = (double*)A_p; + double* B = (double*)B_p; + double* C = (double*)C_p; + + // Unrolling factor of most inner loop. 
+ // Should be at least as high as the FMA delay + // for maximum utilization + const uint32_t unroll = 8; + + // SSR strides and bounds only have to be configured + // once in the beginning + if (setup_SSR) { + // First matrix is stored in transposed format + if (ta) { + const uint32_t ssr0_b[4] = {unroll, K, N / unroll, M}; + const uint32_t ssr0_i[4] = {0, 8 * ldA, 0, 8 * 8}; + + snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3], + ssr0_i[1], ssr0_i[2], ssr0_i[3]); + snrt_ssr_repeat(SNRT_SSR_DM0, unroll); + } else { + const uint32_t ssr0_b[4] = {unroll, K, N / unroll, M}; + const uint32_t ssr0_i[4] = {0, 8, 0, 8 * ldA}; + + snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3], + ssr0_i[1], ssr0_i[2], ssr0_i[3]); + snrt_ssr_repeat(SNRT_SSR_DM0, unroll); + } + + // Second matrix is stored in transposed format + if (tb) { + const uint32_t ssr1_b[4] = {unroll, K, N / unroll, M}; + const uint32_t ssr1_i[4] = {8 * ldB, 8, 8 * ldB * unroll, 0}; + + snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2], + ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2], + ssr1_i[3]); + } else { + const uint32_t ssr1_b[4] = {unroll, K, N / unroll, M}; + const uint32_t ssr1_i[4] = {8, 8 * ldB, 8 * unroll, 0}; + + snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2], + ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2], + ssr1_i[3]); + } + } + + // SSR start address need to be configured each time + snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, A); + snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, B); + snrt_ssr_enable(); + + for (uint32_t m = 0; m < M; m++) { + uint32_t n = 0; + for (uint32_t n0 = 0; n0 < N / unroll; n0++) { + double c[unroll]; + + // Load intermediate result + if (BETA != 0) { + c[0] = C[m * ldC + n + 0]; + c[1] = C[m * ldC + n + 1]; + c[2] = C[m * ldC + n + 2]; + c[3] = C[m * ldC + n + 3]; + c[4] = C[m * ldC + n + 4]; + c[5] = C[m * ldC + n + 5]; + c[6] = C[m * ldC + n + 6]; + c[7] = C[m * ldC + n + 7]; + } else { + c[0] = 0.0; + c[1] = 0.0; + c[2] = 0.0; + c[3] = 0.0; + c[4] = 0.0; + c[5] = 0.0; + c[6] = 0.0; + c[7] = 0.0; + } + asm volatile( + "frep.o %[n_frep], %[unroll], 0, 0 \n" + "fmadd.d %[c0], ft0, ft1, %[c0] \n" + "fmadd.d %[c1], ft0, ft1, %[c1] \n" + "fmadd.d %[c2], ft0, ft1, %[c2] \n" + "fmadd.d %[c3], ft0, ft1, %[c3] \n" + "fmadd.d %[c4], ft0, ft1, %[c4] \n" + "fmadd.d %[c5], ft0, ft1, %[c5] \n" + "fmadd.d %[c6], ft0, ft1, %[c6] \n" + "fmadd.d %[c7], ft0, ft1, %[c7] \n" + : [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]), + [ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]), + [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]) + : [ n_frep ] "r"(K - 1), [ unroll ] "i"(unroll) + : "ft0", "ft1", "ft2"); + + // Store results back + C[m * ldC + n + 0] = c[0]; + C[m * ldC + n + 1] = c[1]; + C[m * ldC + n + 2] = c[2]; + C[m * ldC + n + 3] = c[3]; + C[m * ldC + n + 4] = c[4]; + C[m * ldC + n + 5] = c[5]; + C[m * ldC + n + 6] = c[6]; + C[m * ldC + n + 7] = c[7]; + n += unroll; + } + + // Clean up of leftover columns + snrt_ssr_disable(); + + for (; n < N; n++) { + double c; + if (BETA != 0) { + c = C[m * ldC + n]; + } else { + c = 0.0; + } + for (uint32_t k = 0; k < K; k++) { + c += A[k + m * ldA] * B[k + n * ldB]; + } + C[m * ldC + n] = c; + } + + snrt_ssr_enable(); + } + + snrt_ssr_disable(); +} diff --git a/sw/blas/gemm_v2/src/gemm_fp8.h b/sw/blas/gemm_v2/src/gemm_fp8.h new file mode 100644 index 0000000000..2ec5e05103 --- /dev/null +++ b/sw/blas/gemm_v2/src/gemm_fp8.h @@ -0,0 +1,304 @@ +// Copyright 2023 ETH Zurich and University of Bologna. 
+// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 +// +// Author: Tim Fischer +// Luca Bertaccini +// Luca Colagrande +// Viviane Potocnik + +void gemm_fp8_naive(uint32_t M, uint32_t N, uint32_t K, void* A_p, uint32_t ldA, + uint32_t ta, void* B_p, uint32_t ldB, uint32_t tb, + void* C_p, uint32_t ldC, uint32_t BETA, + uint32_t setup_SSR) { + char* A = (char*)A_p; + char* B = (char*)B_p; + char* C = (char*)C_p; + + for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { + char c; + if (BETA != 0) { + c = C[m * ldC + n]; + // FIXME: get the correct beta value + asm volatile( + // "fmv.b.x ft0, %[beta]\n" + "fcvt.b.s ft0, %[beta]\n" + "fmv.b.x ft1, %[c]\n" + "fmul.b ft2, ft0, ft1\n" + "fmv.x.b %[c], ft2\n" + : [ c ] "+r"(c) + : [ beta ] "f"(1.0f) + : "ft0", "ft1", "ft2"); + } else { + c = 0.0; + } + for (uint32_t k = 0; k < K; k++) { + char a = A[k + m * ldA]; + char b; + if (tb) + b = B[n * ldB + k]; + else + b = B[k * ldB + n]; + asm volatile( + "fmv.b.x ft3, %[a]\n" + "fmv.b.x ft4, %[b]\n" + "fmv.b.x ft5, %[c]\n" + "fmul.b ft6, ft3, ft4 \n" + "fadd.b ft5, ft5, ft6 \n" + "fmv.x.b %[c], ft5\n" + : [ c ] "+r"(c) + : [ a ] "r"(a), [ b ] "r"(b)); + } + C[m * ldC + n] = c; + } + } +} + +void gemm_fp8_baseline(uint32_t M, uint32_t N, uint32_t K, void* A_p, + uint32_t ldA, uint32_t ta, void* B_p, uint32_t ldB, + uint32_t tb, void* C_p, uint32_t ldC, uint32_t BETA, + uint32_t setup_SSR) { + char* A = (char*)A_p; + char* B = (char*)B_p; + char* C = (char*)C_p; + + for (uint32_t m = 0; m < M; m++) { + uint32_t n = 0; + for (; n < N; n++) { + volatile register v8f8 *a_ptr, *b_ptr; + register v8f8 a, b; + volatile char* c_ptr; + const register float zero = 0.0; + double c = 0.0; + v8f8 reduce_reg; + + a_ptr = (v8f8*)(&A[m * ldA]); + b_ptr = (v8f8*)(&B[n * ldB]); + c_ptr = &C[m * ldC + n]; + asm volatile( + "beqz %[beta], 1f \n" + // Load intermediate results + "flb ft2, 0(%[C]) \n" + "vfcvt.s.b ft2, ft2\n" + "vfcpka.h.s ft2, ft2, %[zero]\n" + // or initialize with zero + "j 2f \n" + "1: \n" + "vfcpka.s.s ft2, %[zero], %[zero]\n" + "2: \n" + // loop over the MACs + "li t0, 0 \n" + "3: \n" + "fld ft0, 0(%[a_ptr]) \n" + "fld ft1, 0(%[b_ptr]) \n" + "add %[a_ptr], %[a_ptr], 8 \n" + "add %[b_ptr], %[b_ptr], 8 \n" + "vfdotpex.h.b ft2, ft0, ft1 \n" + "addi t0, t0, 8 \n" + "blt t0, %[K], 3b \n" + // Sum reduce vector + "vfcpka.s.s ft3, %[zero], %[zero]\n" + "vfsumex.s.h ft3, ft2 \n" + "vfcpka.s.s ft2, %[zero], %[zero]\n" + "vfsum.s ft2, ft3 \n" + "vfcvt.b.s ft2, ft2\n" + // Store results + "fsb ft2, 0(%[C]) \n" + : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr) + : [ c ] "f"(c), [ reduce_reg ] "f"(reduce_reg), + [ C ] "r"(c_ptr), [ beta ] "r"(BETA), [ K ] "r"(K), + [ zero ] "f"(zero) + : "ft0", "ft1", "ft2", "ft3", "ft4", "t0"); + } + } +} + +void gemm_fp8_opt_ex(uint32_t M, uint32_t N, uint32_t K, void* A_p, + uint32_t ldA, uint32_t ta, void* B_p, uint32_t ldB, + uint32_t tb, void* C_p, uint32_t ldC, uint32_t BETA, + uint32_t setup_SSR) { + char* A = (char*)A_p; + char* B = (char*)B_p; + char* C = (char*)C_p; + // Unrolling factor of most inner loop. 
+ // Should be at least as high as the FMA delay + // for maximum utilization + const uint32_t unroll = 8; + + // SSR strides and bounds only have to be configured + // once in the beginning + if (setup_SSR) { + uint32_t ssr0_b[4] = {unroll, K / 8, N / unroll, M}; + uint32_t ssr0_i[4] = {0, sizeof(char) * 8, 0, sizeof(char) * ldA}; + + uint32_t ssr1_b[4] = {unroll, K / 8, N / unroll, M}; + uint32_t ssr1_i[4] = {sizeof(char) * ldB, sizeof(char) * 8, + sizeof(char) * unroll * ldB, 0}; + + snrt_ssr_loop_3d(SNRT_SSR_DM0, ssr0_b[1], ssr0_b[2], ssr0_b[3], + ssr0_i[1], ssr0_i[2], ssr0_i[3]); + snrt_ssr_repeat(SNRT_SSR_DM0, unroll); + + snrt_ssr_loop_4d(SNRT_SSR_DM1, ssr1_b[0], ssr1_b[1], ssr1_b[2], + ssr1_b[3], ssr1_i[0], ssr1_i[1], ssr1_i[2], ssr1_i[3]); + } + + // SSR start address need to be configured each time + snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_4D, A); + snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_4D, B); + snrt_ssr_enable(); + + // Kernel progresses by 8 values each step + const uint32_t n_frep = K / 8 - 1; + + for (uint32_t m = 0; m < M; m++) { + uint32_t n = 0; + for (uint32_t n0 = 0; n0 < N / unroll; n0++) { + char* _C = &C[m * ldC + n]; + const register float zero = 0.0; + v8f8 c[unroll]; + v4f16 reduce_reg[unroll]; + + asm volatile( + "beqz %[beta], 1f \n" + "flb %[reduce_reg0], 0(%[C]) \n" + "flb %[reduce_reg1], 1(%[C]) \n" + "flb %[reduce_reg2], 2(%[C]) \n" + "flb %[reduce_reg3], 3(%[C]) \n" + "flb %[reduce_reg4], 4(%[C]) \n" + "flb %[reduce_reg5], 5(%[C]) \n" + "flb %[reduce_reg6], 6(%[C]) \n" + "flb %[reduce_reg7], 7(%[C]) \n" + // Convert intermediate results before packing + "vfcvt.s.b %[reduce_reg0], %[reduce_reg0]\n" + "vfcvt.s.b %[reduce_reg1], %[reduce_reg1]\n" + "vfcvt.s.b %[reduce_reg2], %[reduce_reg2]\n" + "vfcvt.s.b %[reduce_reg3], %[reduce_reg3]\n" + "vfcvt.s.b %[reduce_reg4], %[reduce_reg4]\n" + "vfcvt.s.b %[reduce_reg5], %[reduce_reg5]\n" + "vfcvt.s.b %[reduce_reg6], %[reduce_reg6]\n" + "vfcvt.s.b %[reduce_reg7], %[reduce_reg7]\n" + // Initialize reduce register to zero + "vfcpka.s.s %[c0], %[zero], %[zero]\n" + "vfcpka.s.s %[c1], %[zero], %[zero]\n" + "vfcpka.s.s %[c2], %[zero], %[zero]\n" + "vfcpka.s.s %[c3], %[zero], %[zero]\n" + "vfcpka.s.s %[c4], %[zero], %[zero]\n" + "vfcpka.s.s %[c5], %[zero], %[zero]\n" + "vfcpka.s.s %[c6], %[zero], %[zero]\n" + "vfcpka.s.s %[c7], %[zero], %[zero]\n" + // Pack intermediate results into SIMD vector + "vfcpka.h.s %[c0], %[reduce_reg0], %[zero]\n" + "vfcpka.h.s %[c1], %[reduce_reg1], %[zero]\n" + "vfcpka.h.s %[c2], %[reduce_reg2], %[zero]\n" + "vfcpka.h.s %[c3], %[reduce_reg3], %[zero]\n" + "vfcpka.h.s %[c4], %[reduce_reg4], %[zero]\n" + "vfcpka.h.s %[c5], %[reduce_reg5], %[zero]\n" + "vfcpka.h.s %[c6], %[reduce_reg6], %[zero]\n" + "vfcpka.h.s %[c7], %[reduce_reg7], %[zero]\n" + "j 2f \n" + "1: \n" + // Initialize SIMD vector with zeros + "vfcpka.s.s %[c0], %[zero], %[zero]\n" + "vfcpka.s.s %[c1], %[zero], %[zero]\n" + "vfcpka.s.s %[c2], %[zero], %[zero]\n" + "vfcpka.s.s %[c3], %[zero], %[zero]\n" + "vfcpka.s.s %[c4], %[zero], %[zero]\n" + "vfcpka.s.s %[c5], %[zero], %[zero]\n" + "vfcpka.s.s %[c6], %[zero], %[zero]\n" + "vfcpka.s.s %[c7], %[zero], %[zero]\n" + "2: \n" + // Perform expanding sum-dotproducts + "frep.o %[n_frep], %[unroll], 0, 0 \n" + "vfdotpex.h.b %[c0], ft1, ft0 \n" + "vfdotpex.h.b %[c1], ft1, ft0 \n" + "vfdotpex.h.b %[c2], ft1, ft0 \n" + "vfdotpex.h.b %[c3], ft1, ft0 \n" + "vfdotpex.h.b %[c4], ft1, ft0 \n" + "vfdotpex.h.b %[c5], ft1, ft0 \n" + "vfdotpex.h.b %[c6], ft1, ft0 \n" + "vfdotpex.h.b %[c7], ft1, ft0 \n" + // 
Initialize reduce register to zero + "vfcpka.s.s %[reduce_reg0], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg1], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg2], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg3], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg4], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg5], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg6], %[zero], %[zero]\n" + "vfcpka.s.s %[reduce_reg7], %[zero], %[zero]\n" + // Sum-reduce vector + "vfsumex.s.h %[reduce_reg0], %[c0] \n" + "vfsumex.s.h %[reduce_reg1], %[c1] \n" + "vfsumex.s.h %[reduce_reg2], %[c2] \n" + "vfsumex.s.h %[reduce_reg3], %[c3] \n" + "vfsumex.s.h %[reduce_reg4], %[c4] \n" + "vfsumex.s.h %[reduce_reg5], %[c5] \n" + "vfsumex.s.h %[reduce_reg6], %[c6] \n" + "vfsumex.s.h %[reduce_reg7], %[c7] \n" + // + // Initialize reduce register to zero + "vfcpka.s.s %[c0], %[zero], %[zero] \n" + "vfcpka.s.s %[c1], %[zero], %[zero] \n" + "vfcpka.s.s %[c2], %[zero], %[zero] \n" + "vfcpka.s.s %[c3], %[zero], %[zero] \n" + "vfcpka.s.s %[c4], %[zero], %[zero] \n" + "vfcpka.s.s %[c5], %[zero], %[zero] \n" + "vfcpka.s.s %[c6], %[zero], %[zero] \n" + "vfcpka.s.s %[c7], %[zero], %[zero] \n" + // Sum-reduce vector + "vfsum.s %[c0], %[reduce_reg0] \n" + "vfsum.s %[c1], %[reduce_reg1] \n" + "vfsum.s %[c2], %[reduce_reg2] \n" + "vfsum.s %[c3], %[reduce_reg3] \n" + "vfsum.s %[c4], %[reduce_reg4] \n" + "vfsum.s %[c5], %[reduce_reg5] \n" + "vfsum.s %[c6], %[reduce_reg6] \n" + "vfsum.s %[c7], %[reduce_reg7] \n" + // Pack and convert results to FP8 vectors + "vfcpka.b.s %[c0], %[c0], %[c1] \n" + "vfcpkb.b.s %[c0], %[c2], %[c3] \n" + "vfcpkc.b.s %[c0], %[c4], %[c5] \n" + "vfcpkd.b.s %[c0], %[c6], %[c7] \n" + // // // Pack and convert results to FP8 vectors + // "vfcpka.b.s %[c0], %[reduce_reg0], %[reduce_reg1] \n" + // "vfcpkb.b.s %[c0], %[reduce_reg2], %[reduce_reg3] \n" + // "vfcpkc.b.s %[c0], %[reduce_reg4], %[reduce_reg5] \n" + // "vfcpkd.b.s %[c0], %[reduce_reg6], %[reduce_reg7] \n" + : [ c0 ] "+f"(c[0]), [ c1 ] "+f"(c[1]), [ c2 ] "+f"(c[2]), + [ c3 ] "+f"(c[3]), [ c4 ] "+f"(c[4]), [ c5 ] "+f"(c[5]), + [ c6 ] "+f"(c[6]), [ c7 ] "+f"(c[7]), + [ reduce_reg0 ] "+f"(reduce_reg[0]), + [ reduce_reg1 ] "+f"(reduce_reg[1]), + [ reduce_reg2 ] "+f"(reduce_reg[2]), + [ reduce_reg3 ] "+f"(reduce_reg[3]), + [ reduce_reg4 ] "+f"(reduce_reg[4]), + [ reduce_reg5 ] "+f"(reduce_reg[5]), + [ reduce_reg6 ] "+f"(reduce_reg[6]), + [ reduce_reg7 ] "+f"(reduce_reg[7]) + : [ C ] "r"(_C), [ n_frep ] "r"(n_frep), [ beta ] "r"(BETA), + [ unroll ] "i"(unroll), [ zero ] "f"(zero) + : "ft0", "ft1", "ft2"); + + // Store results back + ((v8f8*)_C)[0] = c[0]; + n += unroll; + } + + // Clean up left over column + // snrt_ssr_disable(); + + // for (; n < N; n++) { + // char c = (*BETA) ? C[m * ldC + n] : 0.0; + // for (uint32_t k = 0; k < K; k++) { + // c += A[k + m * ldA] * B[k + n * ldB]; + // } + // C[m * ldC + n] = c; + // } + + // snrt_ssr_enable(); + } + + snrt_ssr_disable(); +} diff --git a/sw/blas/gemm_v2/src/gemm_v2.h b/sw/blas/gemm_v2/src/gemm_v2.h new file mode 100644 index 0000000000..02f8ebe4ed --- /dev/null +++ b/sw/blas/gemm_v2/src/gemm_v2.h @@ -0,0 +1,262 @@ +// Copyright 2023 ETH Zurich and University of Bologna. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+// SPDX-License-Identifier: Apache-2.0 +// +// Author: Tim Fischer +// Luca Bertaccini +// Luca Colagrande +// Viviane Potocnik + +#include + +#include "snrt.h" + +#pragma once + +// Guard to avoid conflict with DNN header file +// TODO: move this definition to Snitch math library to solve problem +#ifndef PRECISION_T +#define PRECISION_T +typedef enum { FP64 = 8, FP32 = 4, FP16 = 2, FP8 = 1 } precision_t; +#endif + +typedef float v2f32 __attribute__((vector_size(8))); +typedef __fp16 v4f16 __attribute__((vector_size(8))); +typedef char v8f8 __attribute__((vector_size(8))); + +// Floating-point multiplications by zero cannot be optimized as in some +// edge cases they do not yield zero: +// - 0f * NaN = NaN +// - 0f * INFINITY == NaN +// Thus in order to optimize it, we need to test for zero. You can use this +// function for free when `multiplier` is a constant. +static inline double multiply_opt(double multiplicand, double multiplier) { + if (multiplier) + return multiplicand * multiplier; + else + return 0; +} + +#include "gemm_fp16.h" +#include "gemm_fp32.h" +#include "gemm_fp64.h" +#include "gemm_fp8.h" + +// define the gemm_fp function pointer +typedef void (*gemm_fp_t)(uint32_t m, uint32_t n, uint32_t k, void* a, + uint32_t lda, uint32_t transa, void* b, + uint32_t transb, uint32_t ldb, void* c, uint32_t ldc, + uint32_t beta, uint32_t setup_ssr); + +typedef struct { + double alpha; + uint32_t prec; + uint32_t setup_ssr; + uint32_t parallelize_m; + uint32_t parallelize_k; + uint32_t m_tiles; + uint32_t n_tiles; + uint32_t k_tiles; + uint32_t load_a; + uint32_t load_b; + uint32_t load_c; + uint32_t transa; + uint32_t transb; + uint32_t M; + uint32_t N; + uint32_t K; + void* a; + void* b; + uint32_t beta; + void* c; + void* gemm_fp; +} gemm_args_t; + +// BLAS compliant single-cluster single-tile GEMM kernel, with some additional +// arguments at the beginning to specify Snitch implementation details. Matrix +// sizes and pointers are for the whole cluster computation. Within a cluster +// the computation is parallelized by assigning distinct output rows to +// distinct cores. 
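+// As a rough illustration of the row-interleaved work split described above
+// (a minimal sketch only, not the SSR/FREP kernels used below; the function
+// name is illustrative and unused): core `core_id` of `num_cores` computes
+// rows core_id, core_id + num_cores, ... of C.
+static inline void sketch_row_interleaved_gemm(uint32_t M, uint32_t N,
+                                               uint32_t K, const double* A,
+                                               const double* B, double* C,
+                                               uint32_t core_id,
+                                               uint32_t num_cores) {
+    for (uint32_t m = core_id; m < M; m += num_cores) {
+        for (uint32_t n = 0; n < N; n++) {
+            double acc = 0.0;
+            for (uint32_t k = 0; k < K; k++)
+                acc += A[m * K + k] * B[k * N + n];
+            C[m * N + n] = acc;
+        }
+    }
+}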
+// TODO: beta (and alpha) should be of floating-point type (same precision as +// operands) +void sc_st_gemm(gemm_args_t* gemm_args, void* a, void* b, uint32_t beta, + void* c) { + gemm_fp_t impl = (gemm_fp_t)gemm_args->gemm_fp; + precision_t prec = gemm_args->prec; + uint32_t setup_ssr = gemm_args->setup_ssr; + uint32_t transa = gemm_args->transa; + uint32_t transb = gemm_args->transb; + + uint32_t m = gemm_args->M / gemm_args->m_tiles; + uint32_t n = gemm_args->N; + uint32_t k = gemm_args->K; + + uint32_t lda = k; + uint32_t ldb; + if (transb) { + ldb = k; + } else { + ldb = n; + } + uint32_t ldc = n; + + double alpha = gemm_args->alpha; + + if (snrt_is_compute_core()) { + const uint32_t compute_num = snrt_cluster_compute_core_num(); + const uint32_t compute_id = snrt_cluster_core_idx(); + + // Compute cores work not on contiguous blocks but on strided rows + uint32_t lda_strided = compute_num * lda; + uint32_t ldc_strided = compute_num * ldc; + + // Compute cores access A and C at offsets of one row from each other + uint32_t offsetA = compute_id * lda * prec; + uint32_t offsetC = compute_id * ldc * prec; + + // Compute fraction of C rows every core computes + uint32_t frac_m = m / compute_num; + uint32_t rem_m = m % compute_num; + if (snrt_cluster_core_idx() < rem_m) frac_m++; + + if (frac_m > 0) + impl(frac_m, n, k, a + offsetA, lda_strided, transa, b, ldb, transb, + c + offsetC, ldc_strided, (float)beta, setup_ssr); + } +} + +// Multiple-cluster multiple-tile GEMM implementation. +// If parallelize_m, assigns a distinct subset of M-tiles to distinct clusters. +// If parallelize_k, then K-tiles are distributed to distinct clusters; a +// binary reduction tree is implemented to accumulate these tiles together. +// Note: in the current implementation, parallelize_m and parallelize_k +// should be mutually-exclusive. The load_* options allow to bypass the DMA +// transfers and operate directly on the a, b and c inputs. 
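+// The tile loop below is organized as a three-stage software pipeline over
+// two TCDM buffers: in iteration i the DMA core writes back tile i - 2 and
+// then loads tile i, while the compute cores work on tile i - 1. A minimal
+// sketch of that schedule, assuming hypothetical load_tile/compute_tile/
+// store_tile callbacks (illustrative only; no DMA or barrier calls):
+static inline void sketch_double_buffered_schedule(
+    int n_tiles, void (*load_tile)(int tile, int buf),
+    void (*compute_tile)(int tile, int buf),
+    void (*store_tile)(int tile, int buf)) {
+    // Two extra iterations drain the pipeline after the last load.
+    for (int i = 0; i < n_tiles + 2; i++) {
+        if (i > 1) store_tile(i - 2, (i - 2) % 2);  // DMA out
+        if (i < n_tiles) load_tile(i, i % 2);       // DMA in
+        if (i > 0 && i < n_tiles + 1)
+            compute_tile(i - 1, (i - 1) % 2);       // compute
+    }
+}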
+// m_tiles: number of tiles in M dimension +// k_tiles: number of tiles in K dimension +// n_tiles: number of tiles in N dimension +int gemm(gemm_args_t* args) { + gemm_args_t* local_args = snrt_l1_next(); + + // Copy the arguments to local memory + if (snrt_is_dm_core()) { + snrt_dma_start_1d(local_args, args, sizeof(gemm_args_t)); + snrt_dma_wait_all(); + } + snrt_cluster_hw_barrier(); + + uint32_t m = local_args->M; + uint32_t n = local_args->N; + uint32_t k = local_args->K; + precision_t prec = (precision_t)local_args->prec; + uint32_t setup_ssr = local_args->setup_ssr; + uint32_t m_tiles = local_args->m_tiles; + uint32_t transa = local_args->transa; + uint32_t transb = local_args->transb; + double alpha = local_args->alpha; + void* a = local_args->a; + void* b = local_args->b; + uint32_t beta = local_args->beta; + void* c = local_args->c; + + // Calculate tile sizes + uint32_t frac_m = m / m_tiles; + uint32_t frac_n = n; + uint32_t frac_k = k; + uint32_t frac_a = frac_m * frac_k; + uint32_t frac_c = frac_m * frac_n; + uint32_t size_frac_a = frac_a * prec; + uint32_t size_frac_b = frac_k * frac_n * prec; + uint32_t size_frac_c = frac_c * prec; + + // Allocate space in TCDM + void *local_a[2]; + void *local_b[2]; + void *local_c[2]; + void* heap_ptr = (void*)local_args + sizeof(gemm_args_t); + local_a[0] = heap_ptr; + heap_ptr += size_frac_a; + local_b[0] = heap_ptr; + heap_ptr += size_frac_b; + local_c[0] = heap_ptr; + heap_ptr += size_frac_c; + local_a[1] = heap_ptr; + heap_ptr += size_frac_a; + local_b[1] = heap_ptr; + heap_ptr += size_frac_b; + local_c[1] = heap_ptr; + + // Calculate number of iterations + int n_tiles = args->m_tiles; + int iterations = n_tiles + 2; + int buff_idx; + int i, i_dma_out, i_dma_in, i_compute; + + // Iterate over all tiles + for (i = 0; i < iterations; i++) { + if (snrt_is_dm_core()) { + // DMA out + // (out before in to avoid overwriting data) + if (i > 1) { + snrt_mcycle(); + + // Compute tile and buffer indices + i_dma_out = i - 2; + buff_idx = i_dma_out % 2; + + // Copy job outputs from TCDM + snrt_dma_store_2d_tile(c, local_c[buff_idx], i_dma_out, 0, + frac_m, frac_n, n, prec); + snrt_dma_wait_all(); + + snrt_mcycle(); + } + + // DMA in + if (i < n_tiles) { + snrt_mcycle(); + + // Compute tile and buffer indices + i_dma_in = i; + buff_idx = i_dma_in % 2; + + // Copy job operands in TCDM + snrt_dma_load_2d_tile(local_a[buff_idx], a, i_dma_in, + 0, frac_m, frac_k, k, + prec); + snrt_dma_start_1d(local_b[buff_idx], b, frac_k * frac_n * prec); + snrt_dma_wait_all(); + + snrt_mcycle(); + } + } + + // Compute + if (snrt_is_compute_core()) { + if (i > 0 && i < (n_tiles + 1)) { + snrt_mcycle(); + + // Compute tile and buffer indices + i_compute = i - 1; + buff_idx = i_compute % 2; + + // Perform tile computation + volatile uint32_t ldb = frac_n; + volatile uint32_t ldc = frac_n; + if (transb) { + ldb = frac_k; + } + sc_st_gemm(local_args, local_a[buff_idx], local_b[buff_idx], beta, + local_c[buff_idx]); + + snrt_mcycle(); + } + } + + // Synchronize cores after every iteration + snrt_cluster_hw_barrier(); + } + + + return 0; +} diff --git a/sw/blas/gemm_v2/src/main.c b/sw/blas/gemm_v2/src/main.c new file mode 100644 index 0000000000..00bf414e7b --- /dev/null +++ b/sw/blas/gemm_v2/src/main.c @@ -0,0 +1,87 @@ +// Copyright 2023 ETH Zurich and University of Bologna. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+// SPDX-License-Identifier: Apache-2.0 +// +// Author: Tim Fischer +// Luca Colagrande +// Viviane Potocnik + +#include +#include + +#include "gemm_v2.h" + +#include "data.h" +#include "snrt.h" + +int main() { + int retcode = gemm(&args); + + snrt_cluster_hw_barrier(); + +// TODO: currently only works for single cluster otherwise need to +// synchronize all cores here +#ifdef BIST + void *local_a, *local_b, *local_c; + void *remote_a, *remote_b, *remote_c; + + // Aliases + uint32_t M = args.M; + uint32_t N = args.N; + uint32_t K = args.K; + uint32_t dtype_size = args.prec; + + // Calculate size and pointers for each cluster + uint32_t frac_m = M / snrt_cluster_num(); + uint32_t frac_a = frac_m * K; + uint32_t frac_c = frac_m * N; + uint32_t size_frac_a = frac_a * dtype_size; + uint32_t size_b = K * N * dtype_size; + uint32_t size_frac_c = frac_c * dtype_size; + uint32_t offset_a = frac_a * snrt_cluster_idx(); + uint32_t offset_c = frac_c * snrt_cluster_idx(); + remote_a = a + offset_a; + remote_b = b; + remote_c = c + offset_c; + + // Allocate space in TCDM + local_a = (void *)snrt_l1_next(); + local_b = local_a + size_frac_a; + local_c = local_b + size_b; + + uint32_t errors = M * N; + + if (snrt_cluster_core_idx() == 0) { + for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { + uint32_t idx = m * N + n; + switch (dtype_size) { + case FP64: + if (fabs(result[idx] - ((double *)local_c)[idx]) < + fabs(result[idx] * 0.00001)) + errors--; + break; + case FP32: + if (fabs(result[idx] - ((float *)local_c)[idx]) < + fabs(result[idx] * 0.0001)) + errors--; + break; + case FP16: + if (fabs(result[idx] - ((__fp16 *)local_c)[idx]) < + fabs(result[idx] * 0.005)) + errors--; + case FP8: + printf("No golden model yet for fp8!\n"); + return -1; + break; + } + } + } + printf("%d/%d Errors\n", errors, M * N); + } + + return errors; +#endif + + return retcode; +} diff --git a/target/snitch_cluster/sw.mk b/target/snitch_cluster/sw.mk index e4456fdfc3..46cf7a3bb6 100644 --- a/target/snitch_cluster/sw.mk +++ b/target/snitch_cluster/sw.mk @@ -50,6 +50,7 @@ include sw/tests/tests.mk APPS = sw/apps/nop APPS += sw/apps/blas/axpy APPS += sw/apps/blas/gemm +APPS += sw/apps/blas/gemm_v2 APPS += sw/apps/blas/dot APPS += sw/apps/blas/syrk APPS += sw/apps/dnn/batchnorm diff --git a/target/snitch_cluster/sw/apps/blas/gemm_v2/app.mk b/target/snitch_cluster/sw/apps/blas/gemm_v2/app.mk new file mode 100644 index 0000000000..f22121b8b6 --- /dev/null +++ b/target/snitch_cluster/sw/apps/blas/gemm_v2/app.mk @@ -0,0 +1,14 @@ +# Copyright 2023 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Luca Colagrande + +APP := gemm_v2 +$(APP)_BUILD_DIR ?= $(ROOT)/target/snitch_cluster/sw/apps/blas/$(APP)/build +SRC_DIR := $(ROOT)/sw/blas/$(APP)/src +SRCS := $(SRC_DIR)/main.c +$(APP)_INCDIRS := $(ROOT)/sw/blas + +include $(ROOT)/sw/apps/common.mk +include $(ROOT)/target/snitch_cluster/sw/apps/common.mk
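Note on the BIST check in sw/blas/gemm_v2/src/main.c: a result element is accepted when its absolute difference from the golden value stays below a precision-dependent relative tolerance. A minimal standalone sketch of that comparison (the helper name and the FP64 tolerance shown are illustrative, mirroring the values used in main.c):

    #include <math.h>

    // Accept `actual` if it is within a relative tolerance of `expected`,
    // as done per element in the BIST loop of main.c.
    static inline int rel_close(double expected, double actual, double rtol) {
        return fabs(expected - actual) < fabs(expected * rtol);
    }

    // e.g. rel_close(result[idx], ((double *)local_c)[idx], 0.00001) for FP64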