From 56e4a0f1b1c59210060565523f1bc2b7b497bf83 Mon Sep 17 00:00:00 2001 From: Luca Colagrande Date: Sat, 10 Aug 2024 06:37:38 +0200 Subject: [PATCH] Fix Vivi's initial PR proposal --- sw/blas/gemm/scripts/datagen.py | 92 ++++++++--------- sw/dnn/flashattention_2/__init__.py | 13 --- sw/dnn/flashattention_2/scripts/datagen.py | 99 ++++++------------- sw/dnn/fused_concat_linear/__init__.py | 9 -- sw/dnn/fused_concat_linear/scripts/datagen.py | 1 - sw/dnn/fused_concat_linear/scripts/verify.py | 1 - sw/dnn/layernorm/__init__.py | 10 -- util/container/Dockerfile | 1 + 8 files changed, 72 insertions(+), 154 deletions(-) delete mode 100644 sw/dnn/flashattention_2/__init__.py delete mode 100644 sw/dnn/fused_concat_linear/__init__.py delete mode 100644 sw/dnn/layernorm/__init__.py diff --git a/sw/blas/gemm/scripts/datagen.py b/sw/blas/gemm/scripts/datagen.py index e1c130ce50..da7f8ba578 100755 --- a/sw/blas/gemm/scripts/datagen.py +++ b/sw/blas/gemm/scripts/datagen.py @@ -45,49 +45,55 @@ def infer_implementation(self, gemm_fp): prec, impl = re.search(r'gemm_fp(\d+)_(\w+)', gemm_fp).group(1, 2) return (int(prec) / 8), impl - def load_params(self, params): - self.M = params.get('M') - self.N = params.get('N') - self.K = params.get('K') - self.m_tiles = params.get('m_tiles') - self.n_tiles = params.get('n_tiles') - self.k_tiles = params.get('k_tiles') - self.load_a = params.get('load_a') - self.load_b = params.get('load_b') - self.load_c = params.get('load_c') - self.setup_ssr = params.get('setup_ssr') - self.parallelize_m = params.get('parallelize_m') - self.parallelize_k = params.get('parallelize_k') - self.gemm_fp = params.get('gemm_fp') - self.transa = params.get('transa') - self.transb = params.get('transb') - self.alpha = params.get('alpha', 1) - self.beta = params.get('beta') - self.section = params.get('section') - self.dtype, self.impl = self.infer_implementation(self.gemm_fp) - self.prec = data_utils.size_from_precision_t(self.dtype) - self.ff_desc = data_utils.ff_desc_from_precision_t(self.dtype) - self.ctype = data_utils.ctype_from_precision_t(self.dtype) - - def validate(self): - frac_m = self.M / self.m_tiles - frac_n = self.N / self.n_tiles - frac_k = self.K / self.k_tiles - - a_size = frac_m * frac_k * self.prec - b_size = frac_k * frac_n * self.prec - c_size = frac_m * frac_n * self.prec + def validate_config(self, gemm_fp, parallelize_m, + parallelize_k, m_tiles, n_tiles, k_tiles, transa, + transb, M, N, K, beta, **kwargs): + frac_m = M / m_tiles + frac_n = N / n_tiles + frac_k = K / k_tiles + + dtype, impl = self.infer_implementation(gemm_fp) + + # Calculate total TCDM occupation + # Note: doesn't account for double buffering + prec = data_utils.size_from_precision_t(dtype) + a_size = frac_m * frac_k * prec + b_size = frac_k * frac_n * prec + c_size = frac_m * frac_n * prec total_size = a_size total_size += b_size total_size += c_size data_utils.validate_tcdm_footprint(total_size) + assert (M % m_tiles) == 0, 'M is not an integer multiple of tile size' + assert (N % n_tiles) == 0, 'N is not an integer multiple of tile size' + assert (K % k_tiles) == 0, 'K is not an integer multiple of tile size' + assert not (parallelize_m and parallelize_k), 'Cannot parallelize K and M simultaneously' + assert not transa, 'SIMD kernels don\'t support transposed A matrix' + assert (dtype == 8) or (impl == 'baseline') or (impl == 'naive') \ + or transb, 'Optimized SIMD kernels only support transposed B matrix' + assert not transb or n_tiles == 1, 'Tiling in the N dimension not supported' \ + ' if B is transposed' + assert not transb or k_tiles == 1, 'Tiling in the K dimension not supported' \ + ' if B is transposed' + assert (impl == 'baseline') or (impl == 'naive') or frac_n >= 8, \ + 'N dimension of tile size must be greater or equal to the unrolling factor (8) ' \ + 'when using optimized kernels' + assert beta == 0 or beta == 1, 'Only values of 0 or 1 supported for beta' + assert not (dtype == 8 and impl == "baseline"), 'No baseline implemented' \ + ' for FP64 (switch to NAIVE)' + assert not (((dtype == 8) or (dtype == 4)) and impl == "opt_ex"), \ + 'Expanding GEMM kernels' \ + ' not supported for FP64 and FP32' + assert not (dtype == 1 and impl == "opt"), 'FP8 not supported in' \ + ' optimized implementation' \ + ' (switch to opt_ex)' + def emit_header(self, **kwargs): header = [super().emit_header()] - self.load_params(kwargs) # Validate parameters - self.validate() + self.validate_config(**kwargs) M, N, K = kwargs['M'], kwargs['N'], kwargs['K'] @@ -111,23 +117,7 @@ def emit_header(self, **kwargs): cfg = { 'prec': prec, - 'setup_ssr': kwargs['setup_ssr'], - 'parallelize_m': kwargs['parallelize_m'], - 'parallelize_k': kwargs['parallelize_k'], - 'm_tiles': kwargs['m_tiles'], - 'n_tiles': kwargs['n_tiles'], - 'k_tiles': kwargs['k_tiles'], - 'load_a': kwargs['load_a'], - 'load_b': kwargs['load_b'], - 'load_c': kwargs['load_c'], - 'transa': kwargs['transa'], - 'transb': kwargs['transb'], - 'M': M, - 'N': N, - 'K': K, - 'alpha': kwargs['alpha'], - 'beta': kwargs['beta'], - 'gemm_fp': kwargs['gemm_fp'], + **kwargs, 'a': a_uid, 'b': b_uid, 'c': c_uid, diff --git a/sw/dnn/flashattention_2/__init__.py b/sw/dnn/flashattention_2/__init__.py deleted file mode 100644 index 319745fc7d..0000000000 --- a/sw/dnn/flashattention_2/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Luca Colagrande -# Viviane Potocnik - -from .scripts.datagen import exact_golden_model, exact_flexfloat_golden_model, \ - get_gemm_implementation, load_params, \ - validate, emit_header - -__all__ = ['exact_golden_model', 'exact_flexfloat_golden_model', - 'get_gemm_implementation', 'load_params', 'validate', 'emit_header'] diff --git a/sw/dnn/flashattention_2/scripts/datagen.py b/sw/dnn/flashattention_2/scripts/datagen.py index 39103b8559..f54ff7a08f 100755 --- a/sw/dnn/flashattention_2/scripts/datagen.py +++ b/sw/dnn/flashattention_2/scripts/datagen.py @@ -21,7 +21,9 @@ np.random.seed(42) torch.manual_seed(42) -np.set_printoptions(formatter={'object': str}) +# AXI splits bursts crossing 4KB address boundaries. To minimize +# the occurrence of these splits the data should be aligned to 4KB +BURST_ALIGNMENT = 4096 def torch_golden_model(Q, K, V): @@ -135,38 +137,22 @@ def exact_flexfloat_golden_model(Q, K, V, B_r, B_c, desc): return np.concatenate(O_tiles, 0) -def load_params(self, params): - self.L = params['L'] - self.S = params['S'] - self.d = params['d'] - self.B_r = params['B_r'] - self.B_c = params['B_c'] - self.dtype = params['dtype'] - self.baseline = params['baseline'] - self.use_mask = params['use_mask'] - self.gemm_impl = self.get_gemm_implementation(params) - # self.torch_type = data_utils.torch_type_from_precision_t(self.dtype) - self.ff_desc = data_utils.ff_desc_from_precision_t(self.dtype) - self.ctype = data_utils.ctype_from_precision_t(self.dtype) - self.prec = data_utils.size_from_precision_t(self.dtype) - - # Verify layer parameters are valid -def validate(self): - assert (self.L % self.B_r) == 0, 'L is not an integer multiple of B_r' - assert (self.S % self.B_c) == 0, 'S is not an integer multiple of B_c' - assert self.dtype != 'FP64', 'FP64 precision is not supported yet' +def validate_config(L, S, d, B_r, B_c, dtype, baseline, gemm_impl): + assert (L % B_r) == 0, 'L is not an integer multiple of B_r' + assert (S % B_c) == 0, 'S is not an integer multiple of B_c' + assert dtype != 'FP64', 'FP64 precision is not supported yet' # Calculate total TCDM occupation - q_fa_size = self.B_r * self.d * self.prec - k_fa_size = self.B_c * self.d * self.prec - v_fa_size = self.B_c * self.d * self.prec - s_fa_size = self.B_r * self.B_c * self.prec - p_fa_size = self.B_r * self.B_c * self.prec - o_fa_size = self.B_r * self.d * self.prec - m_i_size = self.B_r * self.prec - l_i_size = self.B_r * self.prec - mask_size = self.B_r * self.B_c * self.prec if self.use_mask else 0 + prec = data_utils.size_from_precision_t(dtype) + q_fa_size = B_r * d * prec + k_fa_size = B_c * d * prec + v_fa_size = B_c * d * prec + s_fa_size = B_r * B_c * prec + p_fa_size = B_r * B_c * prec + o_fa_size = B_r * d * prec + m_i_size = B_r * prec + l_i_size = B_r * prec total_size = q_fa_size total_size += k_fa_size total_size += v_fa_size * 2 # V and V^t @@ -175,51 +161,26 @@ def validate(self): total_size += o_fa_size total_size += m_i_size * 2 # m_i and m_i_prev total_size += l_i_size - total_size += mask_size data_utils.validate_tcdm_footprint(total_size) # Q*K^t - gemm_gen = gemm.GemmDataGen() - gemm_params = { - 'gemm_fp': self.gemm_impl, - 'parallelize_m': 0, - 'parallelize_k': 0, - 'm_tiles': 1, - 'n_tiles': 1, - 'k_tiles': 1, - 'transa': 0, - 'transb': 1, - 'M': self.B_r, - 'N': self.B_c, - 'K': self.d, - 'beta': 0 - } - gemm_gen.load_params(gemm_params) - gemm_gen.validate() + gemm.GemmDataGen().validate_config( + gemm_fp=gemm_impl, parallelize_m=0, parallelize_k=0, m_tiles=1, n_tiles=1, + k_tiles=1, transa=0, transb=1, M=B_r, N=B_c, K=d, beta=0 + ) # P*V - gemm_params = { - 'gemm_fp': self.gemm_impl, - 'parallelize_m': 0, - 'parallelize_k': 0, - 'm_tiles': 1, - 'n_tiles': 1, - 'k_tiles': 1, - 'transa': 0, - 'M': self.B_r, - 'N': self.d, - 'K': self.B_c, - 'beta': 1 - } - if self.baseline: - gemm_params['transb'] = 0 - gemm_gen.load_params(gemm_params) - gemm_gen.validate() + if baseline: + gemm.GemmDataGen().validate_config( + gemm_fp=gemm_impl, parallelize_m=0, parallelize_k=0, m_tiles=1, n_tiles=1, + k_tiles=1, transa=0, transb=0, M=B_r, N=d, K=B_c, beta=1 + ) else: # P*(V^t)^t - gemm_params['transb'] = 1 - gemm_gen.load_params(gemm_params) - gemm_gen.validate() + gemm.GemmDataGen().validate_config( + gemm_fp=gemm_impl, parallelize_m=0, parallelize_k=0, m_tiles=1, n_tiles=1, + k_tiles=1, transa=0, transb=1, M=B_r, N=d, K=B_c, beta=1 + ) def get_gemm_implementation(params): @@ -243,7 +204,7 @@ def emit_header(section, params): prec = params['dtype'] gemm_impl = get_gemm_implementation(params) - # TODO: Add validation (vivianep) + validate_config(gemm_impl=gemm_impl, **params) # torch_type = data_utils.torch_type_from_precision_t(prec) ff_desc = data_utils.ff_desc_from_precision_t(prec) diff --git a/sw/dnn/fused_concat_linear/__init__.py b/sw/dnn/fused_concat_linear/__init__.py deleted file mode 100644 index 4a88f99695..0000000000 --- a/sw/dnn/fused_concat_linear/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2024 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Luca Colagrande - -from .scripts.datagen import golden_model, emit_header - -__all__ = ['golden_model', 'emit_header'] diff --git a/sw/dnn/fused_concat_linear/scripts/datagen.py b/sw/dnn/fused_concat_linear/scripts/datagen.py index bffd80a2a1..5c7fdc62ef 100755 --- a/sw/dnn/fused_concat_linear/scripts/datagen.py +++ b/sw/dnn/fused_concat_linear/scripts/datagen.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: Apache-2.0 # # Luca Colagrande -# Tim Fischer # Viviane Potocnik import argparse diff --git a/sw/dnn/fused_concat_linear/scripts/verify.py b/sw/dnn/fused_concat_linear/scripts/verify.py index 2bd0f36114..0bb898a0cc 100755 --- a/sw/dnn/fused_concat_linear/scripts/verify.py +++ b/sw/dnn/fused_concat_linear/scripts/verify.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: Apache-2.0 # # Luca Colagrande -# Tim Fischer # Viviane Potocnik import sys diff --git a/sw/dnn/layernorm/__init__.py b/sw/dnn/layernorm/__init__.py deleted file mode 100644 index 65f0cf5a3f..0000000000 --- a/sw/dnn/layernorm/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Luca Colagrande - -from .scripts.datagen import golden_model, golden_model_torch, \ - validate_config, emit_header - -__all__ = ['golden_model', 'golden_model_torch', 'validate_config', 'emit_header'] diff --git a/util/container/Dockerfile b/util/container/Dockerfile index d9b52238e9..0bbd03730e 100644 --- a/util/container/Dockerfile +++ b/util/container/Dockerfile @@ -158,6 +158,7 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" COPY . /tmp/snitch_cluster RUN pip install -r /tmp/snitch_cluster/python-requirements.txt RUN pip install /tmp/snitch_cluster/ +RUN rm -rf /tmp/snitch_cluster # Add Verilator to PATH ENV PATH "/tools/verilator/bin:${PATH}"