From 6547da086af3281782383918c05458add0a65d74 Mon Sep 17 00:00:00 2001 From: Tim Fischer Date: Fri, 19 Jan 2024 14:13:18 +0100 Subject: [PATCH] gemm: Verify fp32 kernels --- sw/blas/gemm/verify.py | 29 ++++++++++++++++++++++++----- util/sim/data_utils.py | 12 ++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/sw/blas/gemm/verify.py b/sw/blas/gemm/verify.py index b6f886b7b0..3fad658d7d 100755 --- a/sw/blas/gemm/verify.py +++ b/sw/blas/gemm/verify.py @@ -13,7 +13,7 @@ sys.path.append(str(Path(__file__).parent / '../../../util/sim/')) import verification # noqa: E402 from elf import Elf # noqa: E402 -from data_utils import bytes_to_doubles, bytes_to_uint32s # noqa: E402 +from data_utils import bytes_to_doubles, bytes_to_floats, bytes_to_uint32s # noqa: E402 ERR_THRESHOLD = 0.001 @@ -27,21 +27,40 @@ def main(): symbols_bin=args.symbols_bin, log=args.log, output_uids=['c']) - c_actual = np.array(bytes_to_doubles(raw_results['c'])) # Extract input operands from ELF file if args.symbols_bin: elf = Elf(args.symbols_bin) else: elf = Elf(args.snitch_bin) - a = np.array(bytes_to_doubles(elf.get_symbol_contents('a'))) - b = np.array(bytes_to_doubles(elf.get_symbol_contents('b'))) - c = np.array(bytes_to_doubles(elf.get_symbol_contents('c'))) beta = bytes_to_uint32s(elf.get_symbol_contents('BETA'))[0] m = bytes_to_uint32s(elf.get_symbol_contents('M'))[0] n = bytes_to_uint32s(elf.get_symbol_contents('N'))[0] k = bytes_to_uint32s(elf.get_symbol_contents('K'))[0] tb = bytes_to_uint32s(elf.get_symbol_contents('TB'))[0] + if elf.get_symbol_size('a') / (m * k) == 8: + a = np.array(bytes_to_doubles(elf.get_symbol_contents('a'))) + elif elf.get_symbol_size('a') / (m * k) == 4: + a = np.array(bytes_to_floats(elf.get_symbol_contents('a'))) + else: + raise ValueError('Unknown data type for a') + + if elf.get_symbol_size('b') / (n * k) == 8: + b = np.array(bytes_to_doubles(elf.get_symbol_contents('b'))) + elif elf.get_symbol_size('b') / (n * k) == 4: + b = np.array(bytes_to_floats(elf.get_symbol_contents('b'))) + else: + raise ValueError('Unknown data type for b') + + if elf.get_symbol_size('c') / (m * n) == 8: + c = np.array(bytes_to_doubles(elf.get_symbol_contents('c'))) + c_actual = np.array(bytes_to_doubles(raw_results['c'])) + elif elf.get_symbol_size('c') / (m * n) == 4: + c = np.array(bytes_to_floats(elf.get_symbol_contents('c'))) + c_actual = np.array(bytes_to_floats(raw_results['c'])) + else: + raise ValueError('Unknown data type for c') + a = np.reshape(a, (m, k)) if tb: b = np.reshape(b, (n, k)) diff --git a/util/sim/data_utils.py b/util/sim/data_utils.py index 2ed260d3f1..11c6dc9a60 100644 --- a/util/sim/data_utils.py +++ b/util/sim/data_utils.py @@ -68,6 +68,18 @@ def bytes_to_doubles(byte_array): doubles.append(double) return doubles +def bytes_to_floats(byte_array): + float_size = struct.calcsize('f') # Size of a float in bytes + num_floats = len(byte_array) // float_size + + # Unpack the byte array into a list of doubles + floats = [] + for i in range(num_floats): + float_bytes = byte_array[i * float_size:(i + 1) * float_size] + float = struct.unpack('