From 8c1f0e73e90e069d9820dadc9c26a1b47fa63081 Mon Sep 17 00:00:00 2001 From: Luca Colagrande Date: Thu, 28 Sep 2023 17:17:05 +0200 Subject: [PATCH] gemm: Add `--section` flag and burst size alignment --- sw/blas/gemm/Makefile | 6 ++++-- sw/blas/gemm/data/datagen.py | 17 ++++++++++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/sw/blas/gemm/Makefile b/sw/blas/gemm/Makefile index 604556ed11..9605f07d7c 100644 --- a/sw/blas/gemm/Makefile +++ b/sw/blas/gemm/Makefile @@ -9,16 +9,18 @@ MK_DIR := $(dir $(realpath $(lastword $(MAKEFILE_LIST)))) DATA_DIR := $(realpath $(MK_DIR)/data) SRC_DIR := $(realpath $(MK_DIR)/src) +DATA_CFG ?= $(DATA_DIR)/params.hjson +SECTION ?= + APP ?= gemm SRCS ?= $(realpath $(SRC_DIR)/main.c) INCDIRS ?= $(DATA_DIR) $(SRC_DIR) -DATA_CFG ?= $(DATA_DIR)/params.hjson DATAGEN_PY = $(DATA_DIR)/datagen.py DATA_H = $(DATA_DIR)/data.h $(DATA_H): $(DATAGEN_PY) $(DATA_CFG) - $< -c $(DATA_CFG) > $@ + $< -c $(DATA_CFG) --section="$(SECTION)" > $@ .PHONY: clean-data clean diff --git a/sw/blas/gemm/data/datagen.py b/sw/blas/gemm/data/datagen.py index 3848653f1b..44b71ffad4 100755 --- a/sw/blas/gemm/data/datagen.py +++ b/sw/blas/gemm/data/datagen.py @@ -39,6 +39,9 @@ 'fp8alt': {'exp': 4, 'mant': 3} } +# AXI splits bursts crossing 4KB address boundaries. To minimize +# the occurrence of these splits the data should be aligned to 4KB +BURST_ALIGNMENT = 4096 def golden_model(alpha, a, b, beta, c): return alpha * np.matmul(a, b) + beta * c @@ -96,9 +99,12 @@ def emit_header(**kwargs): data_str += [format_scalar_definition('uint32_t', 'BETA', kwargs['beta'])] data_str += [format_scalar_definition('uint32_t', 'dtype_size', kwargs['prec']//8)] data_str += [format_scalar_definition('uint32_t', 'expand', kwargs['expand'])] - data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'a', a.flatten())] - data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'b', b.flatten())] - data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'c', c.flatten())] + data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'a', a.flatten(), + alignment=BURST_ALIGNMENT, section=kwargs['section'])] + data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'b', b.flatten(), + alignment=BURST_ALIGNMENT, section=kwargs['section'])] + data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'c', c.flatten(), + alignment=BURST_ALIGNMENT, section=kwargs['section'])] if kwargs['prec'] == 8: result_def = format_vector_definition(C_TYPES['64'], 'result', result.flatten()) else: @@ -120,11 +126,16 @@ def main(): required=True, help='Select param config file kernel' ) + parser.add_argument( + '--section', + type=str, + help='Section to store matrices in') args = parser.parse_args() # Load param config file with args.cfg.open() as f: param = hjson.loads(f.read()) + param['section'] = args.section # Emit header file print(emit_header(**param))