Skip to content

Commit

Permalink
Add quick double-buffered GEMM in gemm_v2
Browse files Browse the repository at this point in the history
  • Loading branch information
colluca committed Aug 21, 2024
1 parent 14dad2a commit 515e2ca
Show file tree
Hide file tree
Showing 13 changed files with 1,963 additions and 0 deletions.
9 changes: 9 additions & 0 deletions sw/blas/gemm_v2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright 2024 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Luca Colagrande <[email protected]>

from .scripts.datagen import GemmDataGen

__all__ = ['GemmDataGen']
16 changes: 16 additions & 0 deletions sw/blas/gemm_v2/data/params.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright 2024 ETH Zurich and University of Bologna.
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

{
setup_ssr: 1,
m_tiles: 2, // number of tiles in M dimension
transa: false,
transb: true, // must be true for SIMD
M: 96,
N: 48,
K: 48,
alpha: 1,
beta: 0,
gemm_fp: "gemm_fp64_opt"
}
36 changes: 36 additions & 0 deletions sw/blas/gemm_v2/roi.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
[
<% DOUBLE_BUFFER = 1 %>
<% N_TILES = 2 %>

// Compute cores
% for j in range(0, 8):
{
"thread": "${f'hart_{j}'}",
"roi": [
% for i in range(0, N_TILES):
{"idx": ${2 * i + 1}, "label": "${f'tile_{i}'}"},
% endfor
]
},
% endfor

// DMA core
{
"thread": "hart_8",
"roi": [
% if not DOUBLE_BUFFER:
% for i in range(0, N_TILES):
{"idx": ${4 * i + 1}, "label": "${f'tile_{i}_in'}"},
{"idx": ${4 * i + 3}, "label": "${f'tile_{i}_out'}"},
% endfor
% else:
{"idx": 1, "label": "tile_0_in"},
% for i in range(1, N_TILES):
{"idx": ${4 * (i - 1) + 3}, "label": "${f'tile_{i}_in'}"},
{"idx": ${4 * (i - 1) + 5}, "label": "${f'tile_{i-1}_out'}"},
% endfor
{"idx": ${4 * (i - 1) + 7}, "label": "tile_15_out"},
% endif
]
}
]
146 changes: 146 additions & 0 deletions sw/blas/gemm_v2/scripts/datagen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# Copyright 2022 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Authors: Tim Fischer <[email protected]>
# Luca Bertaccini <[email protected]>
# Viviane Potocnik <[email protected]>
# Luca Colagrande <[email protected]>

import numpy as np
import re
import pyflexfloat as ff
import sys

from snitch.util.sim import data_utils
from snitch.util.sim.data_utils import DataGen, format_array_declaration, \
format_struct_definition, format_array_definition, format_ifdef_wrapper


np.random.seed(42)


class GemmDataGen(DataGen):

# AXI splits bursts crossing 4KB address boundaries. To minimize
# the occurrence of these splits the data should be aligned to 4KB
BURST_ALIGNMENT = 4096

def golden_model(self, alpha, a, b, beta, c):
return alpha * np.matmul(a, b) + beta * c

def exact_golden_model(self, alpha, a, b, beta, c):
M, N, K = a.shape[0], b.shape[1], b.shape[0]
result = beta * c
for m in range(M):
for n in range(N):
for k in range(K):
result[m][n] += a[m][k] * b[k][n]
return result

def infer_implementation(self, gemm_fp):
# gemm_fp: "gemm_fp64_opt"
# create a regex with fp_<type>_<implementation>
prec, impl = re.search(r'gemm_fp(\d+)_(\w+)', gemm_fp).group(1, 2)
return (int(prec) / 8), impl

def validate_config(self, gemm_fp,
m_tiles, transa,
transb, M, N, K, beta, **kwargs):
frac_m = M / m_tiles
frac_n = N / 1
frac_k = K / 1

dtype, impl = self.infer_implementation(gemm_fp)

# Calculate total TCDM occupation
prec = data_utils.size_from_precision_t(dtype)
a_size = frac_m * frac_k * prec
b_size = frac_k * frac_n * prec
c_size = frac_m * frac_n * prec
total_size = a_size
total_size += b_size
total_size += c_size
data_utils.validate_tcdm_footprint(2*total_size)

assert (M % m_tiles) == 0, 'M is not an integer multiple of tile size'
assert (N % 1) == 0, 'N is not an integer multiple of tile size'
assert (K % 1) == 0, 'K is not an integer multiple of tile size'
assert not transa, 'SIMD kernels don\'t support transposed A matrix'
assert (dtype == 8) or (impl == 'baseline') or (impl == 'naive') \
or transb, 'Optimized SIMD kernels only support transposed B matrix'
assert not transb or 1 == 1, 'Tiling in the N dimension not supported' \
' if B is transposed'
assert not transb or 1 == 1, 'Tiling in the K dimension not supported' \
' if B is transposed'
assert (impl == 'baseline') or (impl == 'naive') or frac_n >= 8, \
'N dimension of tile size must be greater or equal to the unrolling factor (8) ' \
'when using optimized kernels'
assert beta == 0 or beta == 1, 'Only values of 0 or 1 supported for beta'
assert not (dtype == 8 and impl == "baseline"), 'No baseline implemented' \
' for FP64 (switch to NAIVE)'
assert not (((dtype == 8) or (dtype == 4)) and impl == "opt_ex"), \
'Expanding GEMM kernels' \
' not supported for FP64 and FP32'
assert not (dtype == 1 and impl == "opt"), 'FP8 not supported in' \
' optimized implementation' \
' (switch to opt_ex)'

def emit_header(self, **kwargs):
header = [super().emit_header()]

# Validate parameters
self.validate_config(**kwargs)

M, N, K = kwargs['M'], kwargs['N'], kwargs['K']

prec, _ = self.infer_implementation(kwargs['gemm_fp'])

ff_desc = data_utils.ff_desc_from_precision_t(prec)
ctype = data_utils.ctype_from_precision_t(prec)

a = ff.array(np.random.rand(M, K), ff_desc)
b = ff.array(np.random.rand(K, N), ff_desc)
c = ff.array(np.random.rand(M, N), ff_desc)
result = self.exact_golden_model(1, a, b, kwargs['beta'], c)

# Store matrices in transposed form if requested
a = a.T if kwargs['transa'] else a
b = b.T if kwargs['transb'] else b

a_uid = 'a'
b_uid = 'b'
c_uid = 'c'

cfg = {
'prec': prec,
**kwargs,
'a': a_uid,
'b': b_uid,
'c': c_uid,
}

a = a.flatten()
b = b.flatten()
c = c.flatten()

header += [format_array_declaration(ctype, a_uid, a.shape)]
header += [format_array_declaration(ctype, b_uid, b.shape)]
header += [format_array_declaration(ctype, c_uid, c.shape)]
header += [format_struct_definition('gemm_args_t', 'args', cfg)]
header += [format_array_definition(ctype, a_uid, a,
section=kwargs['section'])]
header += [format_array_definition(ctype, b_uid, b,
section=kwargs['section'])]
header += [format_array_definition(ctype, c_uid, c,
section=kwargs['section'])]
result_def = format_array_definition(ctype, 'result', result.flatten())
header += [format_ifdef_wrapper('BIST', result_def)]
header = '\n\n'.join(header)

return header


if __name__ == "__main__":
sys.exit(GemmDataGen().main())
83 changes: 83 additions & 0 deletions sw/blas/gemm_v2/scripts/verify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#!/usr/bin/env python3
# Copyright 2023 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Luca Colagrande <[email protected]>

import numpy as np
import sys
from datagen import GemmDataGen

from snitch.util.sim.verif_utils import Verifier
from snitch.util.sim.data_utils import ctype_from_precision_t


class GemmVerifier(Verifier):

OUTPUT_UIDS = ['c']
ERR_THRESHOLD = {
1: 1e-4,
2: 1e-2,
4: 1e-6,
8: 1e-6
}

def __init__(self):
super().__init__()
self.func_args = {
'alpha': 'd',
'prec': 'I',
'setup_ssr': 'I',
'parallelize_m': 'I',
'parallelize_k': 'I',
'm_tiles': 'I',
'n_tiles': 'I',
'k_tiles': 'I',
'load_a': 'I',
'load_b': 'I',
'load_c': 'I',
'transa': 'I',
'transb': 'I',
'M': 'I',
'N': 'I',
'K': 'I',
'a': 'I',
'b': 'I',
'beta': 'I',
'c': 'I',
'gemm_fp': 'I'
}
self.func_args = self.get_input_from_symbol('args', self.func_args)

def get_actual_results(self):
prec = self.func_args['prec']
return self.get_output_from_symbol(self.OUTPUT_UIDS[0], ctype_from_precision_t(prec))

def get_expected_results(self):
prec = self.func_args['prec']
a = self.get_input_from_symbol('a', ctype_from_precision_t(prec))
b = self.get_input_from_symbol('b', ctype_from_precision_t(prec))
c = self.get_input_from_symbol('c', ctype_from_precision_t(prec))
beta = self.func_args['beta']
m = self.func_args['M']
n = self.func_args['N']
k = self.func_args['K']
tb = self.func_args['transb']

a = np.reshape(a, (m, k))
if tb:
b = np.reshape(b, (n, k))
b = b.transpose()
else:
b = np.reshape(b, (k, n))
c = np.reshape(c, (m, n))
return GemmDataGen().exact_golden_model(1, a, b, beta, c).flatten()

def check_results(self, *args):
prec = self.func_args['prec']
return super().check_results(*args, rtol=self.ERR_THRESHOLD[prec])


if __name__ == "__main__":
sys.exit(GemmVerifier().main())
Loading

0 comments on commit 515e2ca

Please sign in to comment.