Skip to content

Commit

Permalink
gemm: Verify fp32 kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
fischeti committed Jan 19, 2024
1 parent e53f1fa commit 6547da0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
29 changes: 24 additions & 5 deletions sw/blas/gemm/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
sys.path.append(str(Path(__file__).parent / '../../../util/sim/'))
import verification # noqa: E402
from elf import Elf # noqa: E402
from data_utils import bytes_to_doubles, bytes_to_uint32s # noqa: E402
from data_utils import bytes_to_doubles, bytes_to_floats, bytes_to_uint32s # noqa: E402


ERR_THRESHOLD = 0.001
Expand All @@ -27,21 +27,40 @@ def main():
symbols_bin=args.symbols_bin,
log=args.log,
output_uids=['c'])
c_actual = np.array(bytes_to_doubles(raw_results['c']))

# Extract input operands from ELF file
if args.symbols_bin:
elf = Elf(args.symbols_bin)
else:
elf = Elf(args.snitch_bin)
a = np.array(bytes_to_doubles(elf.get_symbol_contents('a')))
b = np.array(bytes_to_doubles(elf.get_symbol_contents('b')))
c = np.array(bytes_to_doubles(elf.get_symbol_contents('c')))
beta = bytes_to_uint32s(elf.get_symbol_contents('BETA'))[0]
m = bytes_to_uint32s(elf.get_symbol_contents('M'))[0]
n = bytes_to_uint32s(elf.get_symbol_contents('N'))[0]
k = bytes_to_uint32s(elf.get_symbol_contents('K'))[0]
tb = bytes_to_uint32s(elf.get_symbol_contents('TB'))[0]
if elf.get_symbol_size('a') / (m * k) == 8:
a = np.array(bytes_to_doubles(elf.get_symbol_contents('a')))
elif elf.get_symbol_size('a') / (m * k) == 4:
a = np.array(bytes_to_floats(elf.get_symbol_contents('a')))
else:
raise ValueError('Unknown data type for a')

if elf.get_symbol_size('b') / (n * k) == 8:
b = np.array(bytes_to_doubles(elf.get_symbol_contents('b')))
elif elf.get_symbol_size('b') / (n * k) == 4:
b = np.array(bytes_to_floats(elf.get_symbol_contents('b')))
else:
raise ValueError('Unknown data type for b')

if elf.get_symbol_size('c') / (m * n) == 8:
c = np.array(bytes_to_doubles(elf.get_symbol_contents('c')))
c_actual = np.array(bytes_to_doubles(raw_results['c']))
elif elf.get_symbol_size('c') / (m * n) == 4:
c = np.array(bytes_to_floats(elf.get_symbol_contents('c')))
c_actual = np.array(bytes_to_floats(raw_results['c']))
else:
raise ValueError('Unknown data type for c')

a = np.reshape(a, (m, k))
if tb:
b = np.reshape(b, (n, k))
Expand Down
12 changes: 12 additions & 0 deletions util/sim/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ def bytes_to_doubles(byte_array):
doubles.append(double)
return doubles

def bytes_to_floats(byte_array):
float_size = struct.calcsize('f') # Size of a float in bytes
num_floats = len(byte_array) // float_size

# Unpack the byte array into a list of doubles
floats = []
for i in range(num_floats):
float_bytes = byte_array[i * float_size:(i + 1) * float_size]
float = struct.unpack('<f', float_bytes)[0]
floats.append(float)
return floats


def bytes_to_uint32s(byte_array):
uint32_size = struct.calcsize('I') # Size of a uint32 in bytes
Expand Down

0 comments on commit 6547da0

Please sign in to comment.