-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add quick double-buffered GEMM in gemm_v2
- Loading branch information
Showing
13 changed files
with
1,963 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Copyright 2024 ETH Zurich and University of Bologna. | ||
# Licensed under the Apache License, Version 2.0, see LICENSE for details. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Luca Colagrande <[email protected]> | ||
|
||
from .scripts.datagen import GemmDataGen | ||
|
||
__all__ = ['GemmDataGen'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
// Copyright 2024 ETH Zurich and University of Bologna. | ||
// Licensed under the Apache License, Version 2.0, see LICENSE for details. | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
{ | ||
setup_ssr: 1, | ||
m_tiles: 2, // number of tiles in M dimension | ||
transa: false, | ||
transb: true, // must be true for SIMD | ||
M: 96, | ||
N: 48, | ||
K: 48, | ||
alpha: 1, | ||
beta: 0, | ||
gemm_fp: "gemm_fp64_opt" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
[
## Template knobs:
##   DOUBLE_BUFFER - non-zero selects the overlapped DMA region layout below
##   N_TILES       - number of tiles for which regions are emitted
<% DOUBLE_BUFFER = 1 %>
<% N_TILES = 2 %>

    // Compute cores
    ## One region per tile for each of the 8 compute harts, at odd indices.
    % for j in range(0, 8):
    {
        "thread": "${f'hart_{j}'}",
        "roi": [
        % for i in range(0, N_TILES):
            {"idx": ${2 * i + 1}, "label": "${f'tile_{i}'}"},
        % endfor
        ]
    },
    % endfor

    // DMA core
    {
        "thread": "hart_8",
        "roi": [
        % if not DOUBLE_BUFFER:
            ## Without double buffering every tile has its own in/out pair.
            % for i in range(0, N_TILES):
            {"idx": ${4 * i + 1}, "label": "${f'tile_{i}_in'}"},
            {"idx": ${4 * i + 3}, "label": "${f'tile_{i}_out'}"},
            % endfor
        % else:
            ## With double buffering the load of tile i is listed before the
            ## write-back of tile i-1; the first load and the last write-back
            ## are emitted outside the loop.
            {"idx": 1, "label": "tile_0_in"},
            % for i in range(1, N_TILES):
            {"idx": ${4 * (i - 1) + 3}, "label": "${f'tile_{i}_in'}"},
            {"idx": ${4 * (i - 1) + 5}, "label": "${f'tile_{i-1}_out'}"},
            % endfor
            ## Expressed via N_TILES (the last loop value of i is N_TILES - 1)
            ## so neither the index nor the label depends on a leaked loop
            ## variable or on a hard-coded tile count.
            {"idx": ${4 * (N_TILES - 2) + 7}, "label": "${f'tile_{N_TILES-1}_out'}"},
        % endif
        ]
    }
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2022 ETH Zurich and University of Bologna. | ||
# Licensed under the Apache License, Version 2.0, see LICENSE for details. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Authors: Tim Fischer <[email protected]> | ||
# Luca Bertaccini <[email protected]> | ||
# Viviane Potocnik <[email protected]> | ||
# Luca Colagrande <[email protected]> | ||
|
||
import numpy as np | ||
import re | ||
import pyflexfloat as ff | ||
import sys | ||
|
||
from snitch.util.sim import data_utils | ||
from snitch.util.sim.data_utils import DataGen, format_array_declaration, \ | ||
format_struct_definition, format_array_definition, format_ifdef_wrapper | ||
|
||
|
||
# Fix the global NumPy RNG seed so that the matrices drawn with
# np.random.rand below are identical on every run.
np.random.seed(seed=42)
|
||
|
||
class GemmDataGen(DataGen):
    """Generate the input data header for the GEMM kernel.

    The generator validates the kernel configuration, draws random A, B
    and C matrices, computes the expected result and emits everything as
    C declarations/definitions.
    """

    # AXI splits bursts crossing 4KB address boundaries. To minimize
    # the occurrence of these splits the data should be aligned to 4KB
    BURST_ALIGNMENT = 4096

    def golden_model(self, alpha, a, b, beta, c):
        """Return ``alpha * (a @ b) + beta * c`` using ``np.matmul``."""
        return alpha * np.matmul(a, b) + beta * c

    def exact_golden_model(self, alpha, a, b, beta, c):
        """Return ``a @ b + beta * c`` accumulated one product at a time.

        The result starts as ``beta * c`` and every ``a[m][k] * b[k][n]``
        product is added to it in increasing ``k`` order.

        NOTE: ``alpha`` is accepted so the signature mirrors
        ``golden_model`` but it is not applied; all visible callers pass 1.

        Args:
            alpha: Unused scaling factor (see note above).
            a: Matrix indexed as ``a[m][k]``.
            b: Matrix indexed as ``b[k][n]``.
            beta: Scaling factor applied to ``c``.
            c: Matrix indexed as ``c[m][n]``.

        Returns:
            The accumulated result matrix.
        """
        M, N, K = a.shape[0], b.shape[1], b.shape[0]
        result = beta * c
        for m in range(M):
            for n in range(N):
                for k in range(K):
                    result[m][n] += a[m][k] * b[k][n]
        return result

    def infer_implementation(self, gemm_fp):
        """Split a kernel name into element size and implementation.

        Args:
            gemm_fp: Kernel name containing ``gemm_fp<bits>_<impl>``,
                e.g. ``"gemm_fp64_opt"``.

        Returns:
            Tuple ``(size_in_bytes, impl)``, e.g. ``(8, 'opt')``.

        Raises:
            ValueError: If `gemm_fp` does not contain the expected pattern.
        """
        match = re.search(r'gemm_fp(\d+)_(\w+)', gemm_fp)
        if match is None:
            raise ValueError(
                f"Cannot infer precision and implementation from '{gemm_fp}'"
                " (expected 'gemm_fp<bits>_<impl>')")
        prec, impl = match.group(1, 2)
        # Integer division keeps the precision an int, so it is emitted as
        # an integer literal in the generated args struct.
        return int(prec) // 8, impl

    def validate_config(self, gemm_fp,
                        m_tiles, transa,
                        transb, M, N, K, beta, **kwargs):
        """Check that the configuration is supported by the kernel.

        Args:
            gemm_fp: Kernel name, see `infer_implementation`.
            m_tiles: Number of tiles in the M dimension.
            transa: Whether A is stored transposed.
            transb: Whether B is stored transposed.
            M, N, K: Problem dimensions.
            beta: Scaling factor of C, must be 0 or 1.
            **kwargs: Remaining configuration entries, ignored here.

        Raises:
            AssertionError: If any constraint is violated.
            ValueError: If `gemm_fp` cannot be parsed.
        """
        # Only the M dimension is tiled here; N and K are always processed
        # as a single tile.
        n_tiles = 1
        k_tiles = 1

        # Check divisibility first so the tile sizes below are exact.
        assert (M % m_tiles) == 0, 'M is not an integer multiple of tile size'
        assert (N % n_tiles) == 0, 'N is not an integer multiple of tile size'
        assert (K % k_tiles) == 0, 'K is not an integer multiple of tile size'
        frac_m = M // m_tiles
        frac_n = N // n_tiles
        frac_k = K // k_tiles

        dtype, impl = self.infer_implementation(gemm_fp)

        # Calculate total TCDM occupation
        prec = data_utils.size_from_precision_t(dtype)
        a_size = frac_m * frac_k * prec
        b_size = frac_k * frac_n * prec
        c_size = frac_m * frac_n * prec
        total_size = a_size + b_size + c_size
        # The footprint of one tile set is counted twice.
        data_utils.validate_tcdm_footprint(2*total_size)

        assert not transa, 'SIMD kernels don\'t support transposed A matrix'
        assert (dtype == 8) or (impl == 'baseline') or (impl == 'naive') \
            or transb, 'Optimized SIMD kernels only support transposed B matrix'
        assert not transb or n_tiles == 1, \
            'Tiling in the N dimension not supported' \
            ' if B is transposed'
        assert not transb or k_tiles == 1, \
            'Tiling in the K dimension not supported' \
            ' if B is transposed'
        assert (impl == 'baseline') or (impl == 'naive') or frac_n >= 8, \
            'N dimension of tile size must be greater or equal to the unrolling factor (8) ' \
            'when using optimized kernels'
        assert beta == 0 or beta == 1, 'Only values of 0 or 1 supported for beta'
        assert not (dtype == 8 and impl == "baseline"), 'No baseline implemented' \
                                                        ' for FP64 (switch to NAIVE)'
        assert not (((dtype == 8) or (dtype == 4)) and impl == "opt_ex"), \
            'Expanding GEMM kernels' \
            ' not supported for FP64 and FP32'
        assert not (dtype == 1 and impl == "opt"), 'FP8 not supported in' \
                                                   ' optimized implementation' \
                                                   ' (switch to opt_ex)'

    def emit_header(self, **kwargs):
        """Build the C header holding the arguments, inputs and golden result.

        Args:
            **kwargs: Kernel configuration. Reads ``M``, ``N``, ``K``,
                ``gemm_fp``, ``beta``, ``transa``, ``transb`` and
                ``section`` directly; the full set is also passed to
                `validate_config` and copied into the args struct.

        Returns:
            The header contents as a single string.
        """
        header = [super().emit_header()]

        # Validate parameters
        self.validate_config(**kwargs)

        M, N, K = kwargs['M'], kwargs['N'], kwargs['K']

        prec, _ = self.infer_implementation(kwargs['gemm_fp'])

        ff_desc = data_utils.ff_desc_from_precision_t(prec)
        ctype = data_utils.ctype_from_precision_t(prec)

        # Draw order (a, b, c) is kept fixed so that a given seed always
        # produces the same matrices.
        a = ff.array(np.random.rand(M, K), ff_desc)
        b = ff.array(np.random.rand(K, N), ff_desc)
        c = ff.array(np.random.rand(M, N), ff_desc)
        # The expected result is computed before any transposition.
        result = self.exact_golden_model(1, a, b, kwargs['beta'], c)

        # Store matrices in transposed form if requested
        a = a.T if kwargs['transa'] else a
        b = b.T if kwargs['transb'] else b

        a_uid = 'a'
        b_uid = 'b'
        c_uid = 'c'

        cfg = {
            'prec': prec,
            **kwargs,
            'a': a_uid,
            'b': b_uid,
            'c': c_uid,
        }

        operands = (
            (a_uid, a.flatten()),
            (b_uid, b.flatten()),
            (c_uid, c.flatten()),
        )

        # Declarations come first so the args struct can refer to the arrays.
        header += [format_array_declaration(ctype, uid, data.shape)
                   for uid, data in operands]
        header += [format_struct_definition('gemm_args_t', 'args', cfg)]
        header += [format_array_definition(ctype, uid, data,
                                           section=kwargs['section'])
                   for uid, data in operands]
        # The golden result is wrapped in an `#ifdef BIST` style guard.
        result_def = format_array_definition(ctype, 'result', result.flatten())
        header += [format_ifdef_wrapper('BIST', result_def)]
        header = '\n\n'.join(header)

        return header
|
||
|
||
if __name__ == "__main__":
    # Run the generator and hand its return value to the shell as the
    # process exit status.
    exit_code = GemmDataGen().main()
    sys.exit(exit_code)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2023 ETH Zurich and University of Bologna. | ||
# Licensed under the Apache License, Version 2.0, see LICENSE for details. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Luca Colagrande <[email protected]> | ||
|
||
import numpy as np | ||
import sys | ||
from datagen import GemmDataGen | ||
|
||
from snitch.util.sim.verif_utils import Verifier | ||
from snitch.util.sim.data_utils import ctype_from_precision_t | ||
|
||
|
||
class GemmVerifier(Verifier):
    """Check the simulated GEMM output against the exact golden model."""

    OUTPUT_UIDS = ['c']
    # Relative tolerance passed to the result check, keyed by the `prec`
    # field read back from the kernel arguments.
    ERR_THRESHOLD = {1: 1e-4, 2: 1e-2, 4: 1e-6, 8: 1e-6}

    def __init__(self):
        """Read the kernel's ``args`` struct into ``self.func_args``."""
        super().__init__()
        # Field name -> format character, in declaration order. Only
        # `alpha` is a double ('d'); every other field uses 'I'.
        uint_fields = (
            'prec', 'setup_ssr', 'parallelize_m', 'parallelize_k',
            'm_tiles', 'n_tiles', 'k_tiles',
            'load_a', 'load_b', 'load_c',
            'transa', 'transb',
            'M', 'N', 'K',
            'a', 'b', 'beta', 'c', 'gemm_fp',
        )
        self.func_args = {'alpha': 'd', **dict.fromkeys(uint_fields, 'I')}
        self.func_args = self.get_input_from_symbol('args', self.func_args)

    def get_actual_results(self):
        """Return the contents of the output symbol ``c``."""
        ctype = ctype_from_precision_t(self.func_args['prec'])
        return self.get_output_from_symbol(self.OUTPUT_UIDS[0], ctype)

    def get_expected_results(self):
        """Recompute the expected output from the input symbols.

        Returns:
            The flattened result of ``GemmDataGen.exact_golden_model``.
        """
        args = self.func_args
        ctype = ctype_from_precision_t(args['prec'])
        a, b, c = (self.get_input_from_symbol(uid, ctype)
                   for uid in ('a', 'b', 'c'))
        rows, cols, inner = args['M'], args['N'], args['K']

        a = np.reshape(a, (rows, inner))
        if args['transb']:
            # B was stored transposed: view it as (N, K) and flip it back.
            b = np.reshape(b, (cols, inner)).transpose()
        else:
            b = np.reshape(b, (inner, cols))
        c = np.reshape(c, (rows, cols))

        expected = GemmDataGen().exact_golden_model(1, a, b, args['beta'], c)
        return expected.flatten()

    def check_results(self, *args):
        """Run the base class check with the tolerance for this precision."""
        tolerance = self.ERR_THRESHOLD[self.func_args['prec']]
        return super().check_results(*args, rtol=tolerance)
|
||
|
||
if __name__ == "__main__":
    # Run the verifier and hand its return value to the shell as the
    # process exit status.
    exit_code = GemmVerifier().main()
    sys.exit(exit_code)
Oops, something went wrong.