
Commit

Fix Vivi's initial PR proposal
colluca committed Aug 10, 2024
1 parent b0a1bfb commit 56e4a0f
Showing 8 changed files with 72 additions and 154 deletions.
92 changes: 41 additions & 51 deletions sw/blas/gemm/scripts/datagen.py
@@ -45,49 +45,55 @@ def infer_implementation(self, gemm_fp):
prec, impl = re.search(r'gemm_fp(\d+)_(\w+)', gemm_fp).group(1, 2)
return (int(prec) / 8), impl

def load_params(self, params):
self.M = params.get('M')
self.N = params.get('N')
self.K = params.get('K')
self.m_tiles = params.get('m_tiles')
self.n_tiles = params.get('n_tiles')
self.k_tiles = params.get('k_tiles')
self.load_a = params.get('load_a')
self.load_b = params.get('load_b')
self.load_c = params.get('load_c')
self.setup_ssr = params.get('setup_ssr')
self.parallelize_m = params.get('parallelize_m')
self.parallelize_k = params.get('parallelize_k')
self.gemm_fp = params.get('gemm_fp')
self.transa = params.get('transa')
self.transb = params.get('transb')
self.alpha = params.get('alpha', 1)
self.beta = params.get('beta')
self.section = params.get('section')
self.dtype, self.impl = self.infer_implementation(self.gemm_fp)
self.prec = data_utils.size_from_precision_t(self.dtype)
self.ff_desc = data_utils.ff_desc_from_precision_t(self.dtype)
self.ctype = data_utils.ctype_from_precision_t(self.dtype)

def validate(self):
frac_m = self.M / self.m_tiles
frac_n = self.N / self.n_tiles
frac_k = self.K / self.k_tiles

a_size = frac_m * frac_k * self.prec
b_size = frac_k * frac_n * self.prec
c_size = frac_m * frac_n * self.prec
def validate_config(self, gemm_fp, parallelize_m,
parallelize_k, m_tiles, n_tiles, k_tiles, transa,
transb, M, N, K, beta, **kwargs):
frac_m = M / m_tiles
frac_n = N / n_tiles
frac_k = K / k_tiles

dtype, impl = self.infer_implementation(gemm_fp)

# Calculate total TCDM occupation
# Note: doesn't account for double buffering
prec = data_utils.size_from_precision_t(dtype)
a_size = frac_m * frac_k * prec
b_size = frac_k * frac_n * prec
c_size = frac_m * frac_n * prec
total_size = a_size
total_size += b_size
total_size += c_size
data_utils.validate_tcdm_footprint(total_size)

assert (M % m_tiles) == 0, 'M is not an integer multiple of tile size'
assert (N % n_tiles) == 0, 'N is not an integer multiple of tile size'
assert (K % k_tiles) == 0, 'K is not an integer multiple of tile size'
assert not (parallelize_m and parallelize_k), 'Cannot parallelize K and M simultaneously'
assert not transa, 'SIMD kernels don\'t support transposed A matrix'
assert (dtype == 8) or (impl == 'baseline') or (impl == 'naive') \
or transb, 'Optimized SIMD kernels only support transposed B matrix'
assert not transb or n_tiles == 1, 'Tiling in the N dimension not supported' \
' if B is transposed'
assert not transb or k_tiles == 1, 'Tiling in the K dimension not supported' \
' if B is transposed'
assert (impl == 'baseline') or (impl == 'naive') or frac_n >= 8, \
'N dimension of tile size must be greater or equal to the unrolling factor (8) ' \
'when using optimized kernels'
assert beta == 0 or beta == 1, 'Only values of 0 or 1 supported for beta'
assert not (dtype == 8 and impl == "baseline"), 'No baseline implemented' \
' for FP64 (switch to NAIVE)'
assert not (((dtype == 8) or (dtype == 4)) and impl == "opt_ex"), \
'Expanding GEMM kernels' \
' not supported for FP64 and FP32'
assert not (dtype == 1 and impl == "opt"), 'FP8 not supported in' \
' optimized implementation' \
' (switch to opt_ex)'
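
As a sanity check of the TCDM footprint computation in validate_config above, a minimal standalone sketch with purely hypothetical problem and tiling sizes (none of these values appear in the commit):

M, N, K = 64, 64, 64
m_tiles, n_tiles, k_tiles = 2, 2, 2
prec = 8  # bytes per element, e.g. FP64

frac_m, frac_n, frac_k = M // m_tiles, N // n_tiles, K // k_tiles  # 32 each
a_size = frac_m * frac_k * prec  # 32 * 32 * 8 = 8192 B
b_size = frac_k * frac_n * prec  # 8192 B
c_size = frac_m * frac_n * prec  # 8192 B
total_size = a_size + b_size + c_size  # 24576 B, i.e. 24 KiB of TCDM (no double buffering)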

def emit_header(self, **kwargs):
header = [super().emit_header()]
self.load_params(kwargs)

# Validate parameters
self.validate()
self.validate_config(**kwargs)

M, N, K = kwargs['M'], kwargs['N'], kwargs['K']

Expand All @@ -111,23 +117,7 @@ def emit_header(self, **kwargs):

cfg = {
'prec': prec,
'setup_ssr': kwargs['setup_ssr'],
'parallelize_m': kwargs['parallelize_m'],
'parallelize_k': kwargs['parallelize_k'],
'm_tiles': kwargs['m_tiles'],
'n_tiles': kwargs['n_tiles'],
'k_tiles': kwargs['k_tiles'],
'load_a': kwargs['load_a'],
'load_b': kwargs['load_b'],
'load_c': kwargs['load_c'],
'transa': kwargs['transa'],
'transb': kwargs['transb'],
'M': M,
'N': N,
'K': K,
'alpha': kwargs['alpha'],
'beta': kwargs['beta'],
'gemm_fp': kwargs['gemm_fp'],
**kwargs,
'a': a_uid,
'b': b_uid,
'c': c_uid,
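
The cfg dictionary above now relies on **kwargs unpacking instead of listing every parameter explicitly. This works because, in a Python dict literal, entries written after an unpacking override the unpacked ones, while earlier entries can be overridden by it. A minimal standalone sketch (illustrative names only, not part of the diff):

kwargs = {'M': 16, 'a': 'from_kwargs'}
cfg = {
    'prec': 8,     # overridden only if kwargs also contains 'prec'
    **kwargs,      # contributes 'M' and 'a'
    'a': 'a_uid',  # written after the unpacking, so it wins
}
assert cfg == {'prec': 8, 'M': 16, 'a': 'a_uid'}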
13 changes: 0 additions & 13 deletions sw/dnn/flashattention_2/__init__.py

This file was deleted.

99 changes: 30 additions & 69 deletions sw/dnn/flashattention_2/scripts/datagen.py
@@ -21,7 +21,9 @@
np.random.seed(42)
torch.manual_seed(42)

np.set_printoptions(formatter={'object': str})
# AXI splits bursts crossing 4KB address boundaries. To minimize
# the occurrence of these splits, the data should be aligned to 4KB.
BURST_ALIGNMENT = 4096
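# (Illustrative sketch, not part of this commit: a common way to apply
# such an alignment is to round buffer offsets up to the next multiple
# of BURST_ALIGNMENT, e.g.
#   aligned = ((offset + BURST_ALIGNMENT - 1) // BURST_ALIGNMENT) * BURST_ALIGNMENT
# so that buffers start on a 4KB boundary and burst splits are minimized.)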


def torch_golden_model(Q, K, V):
@@ -135,38 +137,22 @@ def exact_flexfloat_golden_model(Q, K, V, B_r, B_c, desc):
return np.concatenate(O_tiles, 0)


def load_params(self, params):
self.L = params['L']
self.S = params['S']
self.d = params['d']
self.B_r = params['B_r']
self.B_c = params['B_c']
self.dtype = params['dtype']
self.baseline = params['baseline']
self.use_mask = params['use_mask']
self.gemm_impl = self.get_gemm_implementation(params)
# self.torch_type = data_utils.torch_type_from_precision_t(self.dtype)
self.ff_desc = data_utils.ff_desc_from_precision_t(self.dtype)
self.ctype = data_utils.ctype_from_precision_t(self.dtype)
self.prec = data_utils.size_from_precision_t(self.dtype)


# Verify layer parameters are valid
def validate(self):
assert (self.L % self.B_r) == 0, 'L is not an integer multiple of B_r'
assert (self.S % self.B_c) == 0, 'S is not an integer multiple of B_c'
assert self.dtype != 'FP64', 'FP64 precision is not supported yet'
def validate_config(L, S, d, B_r, B_c, dtype, baseline, gemm_impl):
assert (L % B_r) == 0, 'L is not an integer multiple of B_r'
assert (S % B_c) == 0, 'S is not an integer multiple of B_c'
assert dtype != 'FP64', 'FP64 precision is not supported yet'

# Calculate total TCDM occupation
q_fa_size = self.B_r * self.d * self.prec
k_fa_size = self.B_c * self.d * self.prec
v_fa_size = self.B_c * self.d * self.prec
s_fa_size = self.B_r * self.B_c * self.prec
p_fa_size = self.B_r * self.B_c * self.prec
o_fa_size = self.B_r * self.d * self.prec
m_i_size = self.B_r * self.prec
l_i_size = self.B_r * self.prec
mask_size = self.B_r * self.B_c * self.prec if self.use_mask else 0
prec = data_utils.size_from_precision_t(dtype)
q_fa_size = B_r * d * prec
k_fa_size = B_c * d * prec
v_fa_size = B_c * d * prec
s_fa_size = B_r * B_c * prec
p_fa_size = B_r * B_c * prec
o_fa_size = B_r * d * prec
m_i_size = B_r * prec
l_i_size = B_r * prec
total_size = q_fa_size
total_size += k_fa_size
total_size += v_fa_size * 2 # V and V^t
Expand All @@ -175,51 +161,26 @@ def validate(self):
total_size += o_fa_size
total_size += m_i_size * 2 # m_i and m_i_prev
total_size += l_i_size
total_size += mask_size
data_utils.validate_tcdm_footprint(total_size)

# Q*K^t
gemm_gen = gemm.GemmDataGen()
gemm_params = {
'gemm_fp': self.gemm_impl,
'parallelize_m': 0,
'parallelize_k': 0,
'm_tiles': 1,
'n_tiles': 1,
'k_tiles': 1,
'transa': 0,
'transb': 1,
'M': self.B_r,
'N': self.B_c,
'K': self.d,
'beta': 0
}
gemm_gen.load_params(gemm_params)
gemm_gen.validate()
gemm.GemmDataGen().validate_config(
gemm_fp=gemm_impl, parallelize_m=0, parallelize_k=0, m_tiles=1, n_tiles=1,
k_tiles=1, transa=0, transb=1, M=B_r, N=B_c, K=d, beta=0
)

# P*V
gemm_params = {
'gemm_fp': self.gemm_impl,
'parallelize_m': 0,
'parallelize_k': 0,
'm_tiles': 1,
'n_tiles': 1,
'k_tiles': 1,
'transa': 0,
'M': self.B_r,
'N': self.d,
'K': self.B_c,
'beta': 1
}
if self.baseline:
gemm_params['transb'] = 0
gemm_gen.load_params(gemm_params)
gemm_gen.validate()
if baseline:
gemm.GemmDataGen().validate_config(
gemm_fp=gemm_impl, parallelize_m=0, parallelize_k=0, m_tiles=1, n_tiles=1,
k_tiles=1, transa=0, transb=0, M=B_r, N=d, K=B_c, beta=1
)
else:
# P*(V^t)^t
gemm_params['transb'] = 1
gemm_gen.load_params(gemm_params)
gemm_gen.validate()
gemm.GemmDataGen().validate_config(
gemm_fp=gemm_impl, parallelize_m=0, parallelize_k=0, m_tiles=1, n_tiles=1,
k_tiles=1, transa=0, transb=1, M=B_r, N=d, K=B_c, beta=1
)


def get_gemm_implementation(params):
Expand All @@ -243,7 +204,7 @@ def emit_header(section, params):
prec = params['dtype']
gemm_impl = get_gemm_implementation(params)

# TODO: Add validation (vivianep)
validate_config(gemm_impl=gemm_impl, **params)

# torch_type = data_utils.torch_type_from_precision_t(prec)
ff_desc = data_utils.ff_desc_from_precision_t(prec)
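
For reference, a hedged sketch of how the refactored validate_config above could be exercised on its own; every value below, and the kernel name 'gemm_fp32_opt', is a hypothetical example rather than something taken from this commit:

validate_config(
    L=128, S=128, d=64,          # sequence lengths and head dimension
    B_r=32, B_c=32,              # tile sizes; must divide L and S respectively
    dtype='FP32',                # FP64 is rejected by the first assertions
    baseline=False,              # exercises the P*(V^t)^t GEMM check
    gemm_impl='gemm_fp32_opt',   # assumed name matching gemm_fp(\d+)_(\w+)
)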
9 changes: 0 additions & 9 deletions sw/dnn/fused_concat_linear/__init__.py

This file was deleted.

1 change: 0 additions & 1 deletion sw/dnn/fused_concat_linear/scripts/datagen.py
@@ -4,7 +4,6 @@
# SPDX-License-Identifier: Apache-2.0
#
# Luca Colagrande <[email protected]>
# Tim Fischer <[email protected]>
# Viviane Potocnik <[email protected]>

import argparse
1 change: 0 additions & 1 deletion sw/dnn/fused_concat_linear/scripts/verify.py
@@ -4,7 +4,6 @@
# SPDX-License-Identifier: Apache-2.0
#
# Luca Colagrande <[email protected]>
# Tim Fischer <[email protected]>
# Viviane Potocnik <[email protected]>

import sys
10 changes: 0 additions & 10 deletions sw/dnn/layernorm/__init__.py

This file was deleted.

1 change: 1 addition & 0 deletions util/container/Dockerfile
@@ -158,6 +158,7 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
COPY . /tmp/snitch_cluster
RUN pip install -r /tmp/snitch_cluster/python-requirements.txt
RUN pip install /tmp/snitch_cluster/
RUN rm -rf /tmp/snitch_cluster

# Add Verilator to PATH
ENV PATH "/tools/verilator/bin:${PATH}"
