diff --git a/target/sim/sw/device/apps/snax/snax-gemmx-conv/Makefile b/target/sim/sw/device/apps/snax/snax-gemmx-conv/Makefile
new file mode 100644
index 000000000..8a8b4951b
--- /dev/null
+++ b/target/sim/sw/device/apps/snax/snax-gemmx-conv/Makefile
@@ -0,0 +1,21 @@
+# Copyright 2023 KU Leuven.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Xiaoling Yi
+
+APP = snax-gemmx-conv
+
+INCDIRS = data
+
+INCDIRS += ../../../snax/gemmx/include
+
+# Include this binary in the final build
+RISCV_LDFLAGS += ../../../snax/gemmx/build/snax-gemmx-lib.o
+
+SRCS = src/snax-gemmx-conv.c
+
+include ./data/Makefile
+include ../../common.mk
+
+$(DEP): $(DATA_H)
diff --git a/target/sim/sw/device/apps/snax/snax-gemmx-conv/data/Makefile b/target/sim/sw/device/apps/snax/snax-gemmx-conv/data/Makefile
new file mode 100644
index 000000000..18006cbf7
--- /dev/null
+++ b/target/sim/sw/device/apps/snax/snax-gemmx-conv/data/Makefile
@@ -0,0 +1,23 @@
+# Copyright 2023 KU Leuven.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Xiaoling Yi
+
+# Usage of absolute paths is required to externally include this Makefile
+MK_DIR := $(dir $(realpath $(lastword $(MAKEFILE_LIST))))
+DATA_DIR := $(realpath $(MK_DIR))
+
+DATA_CFG ?= $(DATA_DIR)/params.hjson
+
+DATA_H = $(DATA_DIR)/data.h
+
+$(DATA_H): $(DATA_DIR)/datagen.py $(DATA_CFG)
+	$< -c $(DATA_CFG) > $@
+
+.PHONY: clean-data clean
+
+clean-data:
+	rm -f $(DATA_H)
+
+clean: clean-data
diff --git a/target/sim/sw/device/apps/snax/snax-gemmx-conv/data/datagen.py b/target/sim/sw/device/apps/snax/snax-gemmx-conv/data/datagen.py
new file mode 100755
index 000000000..9c474c708
--- /dev/null
+++ b/target/sim/sw/device/apps/snax/snax-gemmx-conv/data/datagen.py
@@ -0,0 +1,646 @@
+#!/usr/bin/env python3
+
+# Copyright 2024 KU Leuven.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Xiaoling Yi
+
+import numpy as np
+import argparse
+import pathlib
+import hjson
+import sys
+import os
+
+import subprocess
+
+# Add data utility path
+sys.path.append(os.path.join(os.path.dirname(__file__),
+                             "../../../../../../../../util/sim/"))
+from data_utils import format_scalar_definition, format_vector_definition  # noqa E402
+
+# Add golden model path
+bender_command = subprocess.run(['bender', 'path', 'snitch_cluster'],
+                                capture_output=True, text=True)
+snax_utils_path = bender_command.stdout.strip()
+
+sys.path.append(snax_utils_path + "/util/sim/")
+
+from snax_utils import (  # noqa E402
+    conv2d,
+    im2col,
+    postprocessing_simd_golden_model,
+    align_wide_addr,
+)  # noqa E402
+
+np.random.seed(42)
+
+
+# Add stdint.h header
+def emit_header_file(**kwargs):
+    emit_str = "#include <stdint.h>\n\n"
+    emit_str += emit_gemmx_data(**kwargs)
+    return emit_str
+
+
+MIN = -128
+MAX = 127
+MIN_BIAS = -(2**30)
+MAX_BIAS = 2**30 - 1
+
+bankWidth = 64
+input_data_width = 8
+output_data_width = 32
+quantized_output_data_width = 8
+
+
+def emit_conv_data(**kwargs):
+    # size extraction
+    Cin = kwargs["Cin"]
+    Cout = kwargs["Cout"]
+
+    Nbatch, Cin8, H, W, _ = (
+        kwargs["Nbatch"],
+        kwargs["Cin"] // 8,
+        kwargs["H"],
+        kwargs["W"],
+        8,
+    )
+    Cout8, Cin8, Kh, Kw, _, _ = (
+        kwargs["Cout"] // 8,
+        kwargs["Cin"] // 8,
+        kwargs["Kh"],
+        kwargs["Kw"],
+        8,
+        8,
+    )
+
+    stride_h, stride_w = (kwargs["stride_h"], kwargs["stride_w"])
+    pad_h, pad_w = (kwargs["pad_h"], kwargs["pad_w"])
+
+    # make sure the output width is a multiple of 8
+    if W // stride_w % 8 != 0:
+        W = W + (stride_w * (8 - (W // stride_w) % 8)) % (stride_w * 8)
+
+    # test data generation
+    input_data = np.random.randint(-10, 10, size=(Nbatch, Cin8, H, W, 8))
+    kernel = np.random.randint(-10, 10, size=(Cout8, Cin8, Kh, Kw, 8, 8))
+
+    # inferred config from the input data and kernel
+    padding = pad_h, pad_w
+    stride = stride_h, stride_w
+
+    # Padding the input data
+    input_padding = np.pad(
+        input_data,
+        ((0, 0), (0, 0), (pad_h, pad_h), (pad_w, pad_w), (0, 0)),
+        mode="constant",
+    )
+
+    # Calculate the size of the output feature map
+    out_height = (H + 2 * pad_h - Kh) // stride_h + 1
+    out_width = (W + 2 * pad_w - Kw) // stride_w + 1
+
+    assert out_width % 8 == 0, "out_width must be a multiple of 8"
+
+    M = out_height * out_width // 8
+    K = Cin // 8 * Kh * Kw
+    N = Cout // 8
+
+    length_c = M * N * 8 * 8
+
+    broadcast_C = kwargs["broadcast_C"] == 1 and kwargs["channel_en_C"] == 1
+    disable_C = kwargs["broadcast_C"] == 0 and kwargs["channel_en_C"] == 0
+    enable_full_C = kwargs["broadcast_C"] == 0 and kwargs["channel_en_C"] == 1
+
+    assert broadcast_C or disable_C or enable_full_C, "Invalid C settings"
+    if broadcast_C == 1:
+        bias = np.random.randint(MIN_BIAS, MAX_BIAS, size=(int(length_c / 8 / 8), 8))
+        bias = np.repeat(bias, repeats=8, axis=0).reshape(-1)
+    elif enable_full_C == 1:
+        bias = np.random.randint(MIN_BIAS, MAX_BIAS, size=length_c).reshape(-1)
+    else:
+        bias = np.random.randint(0, 1, size=length_c).reshape(-1)
format_scalar_definition("int", "Nbatch", Nbatch), + format_scalar_definition("int", "H", H), + format_scalar_definition("int", "W", W), + format_scalar_definition("int", "Cin", Cin), + format_scalar_definition("int", "Cout", Cout), + format_scalar_definition("int", "Kh", Kh), + format_scalar_definition("int", "Kw", Kw), + format_scalar_definition("int", "stride_h", stride_h), + format_scalar_definition("int", "stride_w", stride_w), + format_scalar_definition("int", "pad_h", pad_h), + format_scalar_definition("int", "pad_w", pad_w), + ] + + # Generating matrix size settings + data_str += [ + format_scalar_definition("int", "Batch", Nbatch), + format_scalar_definition("int", "M", M), + format_scalar_definition("int", "K", K), + format_scalar_definition("int", "N", N), + ] + + # Generating base pointer settings + + if kwargs["interleaved_address"] == 1: + # Generating base pointer settings, interleaved memory + delta_local_a = 0 + + delta_local_b = input_padding.size + assert input_padding.size == ( + Nbatch * Cin8 * (H + 2 * pad_h) * (W + 2 * pad_w) * 8 + ) + + delta_local_b = align_wide_addr(delta_local_b, 64) + assert delta_local_b % 64 == 0 + + delta_local_c = delta_local_b + kernel.size + assert kernel.size == (Cout8 * Cin8 * Kh * Kw * 8 * 8) + delta_local_c = align_wide_addr(delta_local_c, 64) + assert delta_local_c % 64 == 0 + + delta_local_d8 = delta_local_c + length_c * 4 + delta_local_d8 = align_wide_addr(delta_local_d8, 64) + assert delta_local_d8 % 64 == 0 + + delta_local_d32 = delta_local_d8 + + # logical address is the same as physical address + delta_physical_a = delta_local_a + delta_physical_b = delta_local_b + delta_physical_c = delta_local_c + delta_physical_d8 = delta_local_d8 + delta_physical_d32 = delta_local_d32 + + assert ( + input_padding.size + kernel.size + length_c * 4 * 2 + < kwargs["memory_size"] * 1024 + ) + else: + # Generating base pointer settings + base_logical_addr_delta = kwargs["memory_size"] / 4 * 1024 + delta_local_a = 0 + delta_local_b = base_logical_addr_delta + delta_local_c = base_logical_addr_delta * 2 + delta_local_d32 = base_logical_addr_delta * 3 + delta_local_d8 = base_logical_addr_delta * 3 + + base_pyhsical_addr_delta = 64 + delta_physical_a = 0 + delta_physical_b = base_pyhsical_addr_delta + delta_physical_c = base_pyhsical_addr_delta * 2 + delta_physical_d32 = base_pyhsical_addr_delta * 3 + delta_physical_d8 = base_pyhsical_addr_delta * 3 + + assert ( + input_padding.size < base_logical_addr_delta + and kernel.size < base_logical_addr_delta + and M * N * 8 * 8 * 4 < base_logical_addr_delta + ) + + if kwargs["interleaved_address"] == 1: + # logical address is the same as physical address + data_str += [format_scalar_definition("int32_t", "set_addr_remap_index_A", 0)] + data_str += [format_scalar_definition("int32_t", "set_addr_remap_index_B", 0)] + data_str += [format_scalar_definition("int32_t", "set_addr_remap_index_C", 0)] + data_str += [format_scalar_definition("int32_t", "set_addr_remap_index_D32", 0)] + data_str += [format_scalar_definition("int32_t", "set_addr_remap_index_D8", 0)] + else: + # open the address remap + data_str += [format_scalar_definition("int32_t", "set_addr_remap_index_A", 2)] + data_str += [format_scalar_definition("int32_t", "set_addr_remap_index_B", 2)] + data_str += [format_scalar_definition("int32_t", "set_addr_remap_index_C", 2)] + data_str += [format_scalar_definition("int32_t", "set_addr_remap_index_D32", 2)] + data_str += [format_scalar_definition("int32_t", "set_addr_remap_index_D8", 2)] + + data_str += [ 
+
+    # for streamer cfg
+    # streamer setting for data mover A
+    # NC8HW8
+    Aslstride0 = 1
+    Aslstride1 = 8 * stride_w
+
+    # K dim
+    Atlbound0 = Kw
+    Atlstride0 = 8
+
+    Atlbound1 = Kh
+    Atlstride1 = 8 * (W + 2 * pad_w)
+
+    Atlbound2 = Cin8
+    Atlstride2 = 8 * (W + 2 * pad_w) * (H + 2 * pad_h)
+
+    # N dim
+    Atlbound3 = Cout // 8
+    Atlstride3 = 0
+
+    # M dim
+    Atlbound4 = out_width // 8
+    Atlstride4 = 8 * 8 * stride_w
+
+    Atlbound5 = out_height
+    Atlstride5 = 8 * (W + 2 * pad_w) * stride_h
+
+    # Batch dim
+    Atlbound6 = Nbatch
+    Atlstride6 = 8 * Cin8 * (H + 2 * pad_h) * (W + 2 * pad_w)
+
+    assert (
+        Atlstride0 % 8 == 0
+        and Atlstride1 % 8 == 0
+        and Atlstride2 % 8 == 0
+        and Atlstride3 % 8 == 0
+        and Atlstride4 % 8 == 0
+        and Atlstride5 % 8 == 0
+        and Atlstride6 % 8 == 0
+    )
+
+    assert (
+        M * K * N
+        == Atlbound0
+        * Atlbound1
+        * Atlbound2
+        * Atlbound3
+        * Atlbound4
+        * Atlbound5
+        * Atlbound6
+    )
+
+    data_str += [
+        format_scalar_definition("int32_t", "Aslstride0", Aslstride0),
+        format_scalar_definition("int32_t", "Aslstride1", Aslstride1),
+        format_scalar_definition("int32_t", "Atlbound0", Atlbound0),
+        format_scalar_definition("int32_t", "Atlstride0", Atlstride0),
+        format_scalar_definition("int32_t", "Atlbound1", Atlbound1),
+        format_scalar_definition("int32_t", "Atlstride1", Atlstride1),
+        format_scalar_definition("int32_t", "Atlbound2", Atlbound2),
+        format_scalar_definition("int32_t", "Atlstride2", Atlstride2),
+        format_scalar_definition("int32_t", "Atlbound3", Atlbound3),
+        format_scalar_definition("int32_t", "Atlstride3", Atlstride3),
+        format_scalar_definition("int32_t", "Atlbound4", Atlbound4),
+        format_scalar_definition("int32_t", "Atlstride4", Atlstride4),
+        format_scalar_definition("int32_t", "Atlbound5", Atlbound5),
+        format_scalar_definition("int32_t", "Atlstride5", Atlstride5),
+        format_scalar_definition("int32_t", "Atlbound6", Atlbound6),
+        format_scalar_definition("int32_t", "Atlstride6", Atlstride6),
+    ]
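The A-streamer bounds above are what make the convolution an "implicit im2col" GEMM: the temporal loops walk the padded NC8HW8 input so that each access fetches the 8 input channels for 8 consecutive output pixels. A minimal NumPy/itertools sketch that cross-checks these bounds against an explicit im2col enumeration (toy, pre-padded sizes; it assumes tl0 iterates innermost, which is the order the bounds are emitted in here):

```python
import itertools

# Toy, already-padded sizes: Hp x Wp input, 3x3 kernel, unit stride,
# so out_h = out_w = 8 (exactly one 8-wide output block).
Nb, Cin8, Cout8, Hp, Wp, Kh, Kw, sh, sw = 1, 2, 1, 10, 10, 3, 3, 1, 1
out_h, out_w = (Hp - Kh) // sh + 1, (Wp - Kw) // sw + 1

# Temporal bounds/strides exactly as emitted above (tl0 = innermost).
bounds = [Kw, Kh, Cin8, Cout8, out_w // 8, out_h, Nb]
strides = [8, 8 * Wp, 8 * Wp * Hp, 0, 8 * 8 * sw, 8 * Wp * sh,
           8 * Cin8 * Hp * Wp]

def streamer_accesses():
    # Outermost index first; the last index (tl0) varies fastest.
    for idx in itertools.product(*[range(b) for b in reversed(bounds)]):
        base = sum(i * s for i, s in zip(reversed(idx), strides))
        # Spatial part: 8 output pixels (Aslstride1 = 8*sw) x 8 channels.
        yield [base + p * 8 * sw + c for p in range(8) for c in range(8)]

def im2col_accesses():
    # Explicit im2col over the NC8HW8 input, same loop order.
    for n, oh, owb, _, c8, kh, kw in itertools.product(
            range(Nb), range(out_h), range(out_w // 8), range(Cout8),
            range(Cin8), range(Kh), range(Kw)):
        ih = oh * sh + kh
        yield [8 * (((n * Cin8 + c8) * Hp + ih) * Wp
                    + (owb * 8 + p) * sw + kw) + c
               for p in range(8) for c in range(8)]

assert all(a == b for a, b in zip(streamer_accesses(), im2col_accesses()))
```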
format_scalar_definition("int32_t", "Bslstride0", Bslstride0), + format_scalar_definition("int32_t", "Bslstride1", Bslstride1), + format_scalar_definition("int32_t", "Btlbound0", Btlbound0), + format_scalar_definition("int32_t", "Btlstride0", Btlstride0), + format_scalar_definition("int32_t", "Btlbound1", Btlbound1), + format_scalar_definition("int32_t", "Btlstride1", Btlstride1), + format_scalar_definition("int32_t", "Btlbound2", Btlbound2), + format_scalar_definition("int32_t", "Btlstride2", Btlstride2), + format_scalar_definition("int32_t", "Btlbound3", Btlbound3), + format_scalar_definition("int32_t", "Btlstride3", Btlstride3), + ] + + # streamer setting for data mover C + # C is int32_t so the stride is 4 times of the int8_t + # NHWC + Cslstride0 = 8 + Cslstride1 = 8 * 8 + + # N dim + Ctlbound0 = Cout // 8 + Ctlstride0 = out_height * out_width // 8 * 8 * 8 * 4 + + # M dim + # K is merged because of the block gemm output stationarity + Ctlbound1 = out_width // 8 + Ctlstride1 = 8 * 8 * 4 + + Ctlbound2 = out_height + Ctlstride2 = out_width // 8 * 8 * 8 * 4 + + # Batch dim + Ctlbound3 = Nbatch + Ctlstride3 = Cout * out_height * out_width * 4 + + assert ( + Ctlstride0 % 64 == 0 + and Ctlstride1 % 64 == 0 + and Ctlstride2 % 64 == 0 + and Ctlstride3 % 64 == 0 + ) + assert M * N == Ctlbound0 * Ctlbound1 * Ctlbound2 * Ctlbound3 + + data_str += [ + format_scalar_definition("int32_t", "Cslstride0", Cslstride0), + format_scalar_definition("int32_t", "Cslstride1", Cslstride1), + format_scalar_definition("int32_t", "Ctlbound0", Ctlbound0), + format_scalar_definition("int32_t", "Ctlstride0", Ctlstride0), + format_scalar_definition("int32_t", "Ctlbound1", Ctlbound1), + format_scalar_definition("int32_t", "Ctlstride1", Ctlstride1), + format_scalar_definition("int32_t", "Ctlbound2", Ctlbound2), + format_scalar_definition("int32_t", "Ctlstride2", Ctlstride2), + format_scalar_definition("int32_t", "Ctlbound3", Ctlbound3), + format_scalar_definition("int32_t", "Ctlstride3", Ctlstride3), + ] + + D32slstride0 = 8 + D32slstride1 = 8 * 8 + + # N dim + D32tlbound0 = Cout // 8 + D32tlstride0 = out_height * out_width // 8 * 8 * 8 * 4 + + # M dim + # K is merged because of the block gemm output stationarity + D32tlbound1 = out_width // 8 + D32tlstride1 = 8 * 8 * 4 + + D32tlbound2 = out_height + D32tlstride2 = out_width // 8 * 8 * 8 * 4 + + # Batch dim + D32tlbound3 = Nbatch + D32tlstride3 = Cout * out_height * out_width * 4 + + assert ( + D32tlstride0 % 64 == 0 + and D32tlstride1 % 64 == 0 + and D32tlstride2 % 64 == 0 + and D32tlstride3 % 64 == 0 + ) + + data_str += [ + format_scalar_definition("int32_t", "D32slstride0", D32slstride0), + format_scalar_definition("int32_t", "D32slstride1", D32slstride1), + format_scalar_definition("int32_t", "D32tlbound0", D32tlbound0), + format_scalar_definition("int32_t", "D32tlstride0", D32tlstride0), + format_scalar_definition("int32_t", "D32tlbound1", D32tlbound1), + format_scalar_definition("int32_t", "D32tlstride1", D32tlstride1), + format_scalar_definition("int32_t", "D32tlbound2", D32tlbound2), + format_scalar_definition("int32_t", "D32tlstride2", D32tlstride2), + format_scalar_definition("int32_t", "D32tlbound3", D32tlbound3), + format_scalar_definition("int32_t", "D32tlstride3", D32tlstride3), + ] + + # postprocessing D8 settings + D8slstride0 = 1 + D8slstride1 = 8 + + # N dim + D8tlbound0 = Cout // 8 + D8tlstride0 = out_height * out_width // 8 * 8 * 8 + + # M dim + # K is merged because of the block gemm output stationarity + D8tlbound1 = out_width // 8 + D8tlstride1 = 
+
+    # postprocessing D8 settings
+    D8slstride0 = 1
+    D8slstride1 = 8
+
+    # N dim
+    D8tlbound0 = Cout // 8
+    D8tlstride0 = out_height * out_width // 8 * 8 * 8
+
+    # M dim
+    # K is merged because of the block gemm output stationarity
+    D8tlbound1 = out_width // 8
+    D8tlstride1 = 8 * 8
+
+    D8tlbound2 = out_height
+    D8tlstride2 = out_width // 8 * 8 * 8
+
+    # Batch dim
+    D8tlbound3 = Nbatch
+    D8tlstride3 = Cout * out_height * out_width
+
+    assert (
+        D8tlstride0 % 64 == 0
+        and D8tlstride1 % 64 == 0
+        and D8tlstride2 % 64 == 0
+        and D8tlstride3 % 64 == 0
+    )
+    data_str += [
+        format_scalar_definition("int32_t", "D8slstride0", D8slstride0),
+        format_scalar_definition("int32_t", "D8slstride1", D8slstride1),
+        format_scalar_definition("int32_t", "D8tlbound0", D8tlbound0),
+        format_scalar_definition("int32_t", "D8tlstride0", D8tlstride0),
+        format_scalar_definition("int32_t", "D8tlbound1", D8tlbound1),
+        format_scalar_definition("int32_t", "D8tlstride1", D8tlstride1),
+        format_scalar_definition("int32_t", "D8tlbound2", D8tlbound2),
+        format_scalar_definition("int32_t", "D8tlstride2", D8tlstride2),
+        format_scalar_definition("int32_t", "D8tlbound3", D8tlbound3),
+        format_scalar_definition("int32_t", "D8tlstride3", D8tlstride3),
+    ]
+
+    # Subtraction values a and b (fixed to zero for the conv test)
+    subtraction_a = 0
+    subtraction_b = 0
+
+    # Writing the subtraction values to data.h
+    data_str += [
+        format_scalar_definition("int8_t", "subtraction_a", subtraction_a),
+        format_scalar_definition("int8_t", "subtraction_b", subtraction_b),
+    ]
+
+    # direct conv2d
+    direct_conv2d_res = conv2d(
+        input_data, kernel, stride=stride, padding=padding, mode="C8HW8"
+    )
+
+    # output in NHWC format
+    direct_conv2d_res = np.add(direct_conv2d_res.reshape(-1), bias)
+
+    # Writing testing data and golden data into data.h
+    # implicit im2col matrix and kernel, store original input data and kernel
+    data_str += [format_vector_definition("int8_t", "A", input_padding.reshape(-1))]
+    data_str += [format_vector_definition("int8_t", "B", kernel.reshape(-1))]
+    data_str += [format_vector_definition("int32_t", "C", bias.reshape(-1))]
+
+    data_str += [format_scalar_definition("int32_t", "transposed_A", 0)]
+    data_str += [format_scalar_definition("int32_t", "transposed_B", 0)]
+
+    return data_str, direct_conv2d_res
+
+
+def emit_gemmx_data(**kwargs):
+
+    data_str, D32 = emit_conv_data(**kwargs)
+
+    data_str += [format_vector_definition("int32_t", "D32", D32)]
+
+    # -----------------------------------------------------------
+    # Postprocessing
+    # -----------------------------------------------------------
+
+    bypassSIMD = kwargs["bypassSIMD"]
+    data_str += [format_scalar_definition("int32_t", "bypassSIMD", bypassSIMD)]
+
+    # Generating random constant values
+    group_num = 8
+    input_zp_i = np.random.randint(MIN, MAX)
+    output_zp_i = np.random.randint(MIN, MAX)
+    max_int_i = MAX
+    min_int_i = MIN
+    double_round_i = np.random.randint(0, 1)
+
+    shift_i = np.random.randint(0, 63, size=group_num)  # values in [0, 63)
+    multiplier_i = np.random.randint(-(2**31), 2**31 - 1, size=group_num)
+
+    # Writing the constant values to data.h
+    data_str += [
+        format_scalar_definition("int8_t", "input_zp_i", input_zp_i),
+        format_scalar_definition("int8_t", "output_zp_i", output_zp_i),
+        format_scalar_definition("int8_t", "max_int_i", max_int_i),
+        format_scalar_definition("int8_t", "min_int_i", min_int_i),
+        format_scalar_definition("int8_t", "double_round_i", double_round_i),
+    ]
+
+    shared_bitpacked_shift0 = (
+        (shift_i[3] << 24) | (shift_i[2] << 16) | (shift_i[1] << 8) | shift_i[0]
+    )
+    shared_bitpacked_shift1 = (
+        (shift_i[7] << 24) | (shift_i[6] << 16) | (shift_i[5] << 8) | shift_i[4]
+    )
+    data_str += [
+        format_scalar_definition(
+            "int32_t", "shared_bitpacked_shift0", shared_bitpacked_shift0
+        )
+    ]
+    data_str += [
+        format_scalar_definition(
+            "int32_t", "shared_bitpacked_shift1", shared_bitpacked_shift1
+        )
+    ]
+
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier0", multiplier_i[0])
+    ]
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier1", multiplier_i[1])
+    ]
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier2", multiplier_i[2])
+    ]
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier3", multiplier_i[3])
+    ]
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier4", multiplier_i[4])
+    ]
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier5", multiplier_i[5])
+    ]
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier6", multiplier_i[6])
+    ]
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier7", multiplier_i[7])
+    ]
+
+    D8 = np.zeros_like(D32, dtype=np.uint8)
+    # output channel (innermost dim) has a different scale factor
+    for i in range(group_num):
+        D8[i::group_num] = postprocessing_simd_golden_model(
+            D32[i::group_num],
+            input_zp_i,
+            output_zp_i,
+            shift_i[i],
+            max_int_i,
+            min_int_i,
+            double_round_i,
+            multiplier_i[i],
+        )
+
+    data_str += [format_vector_definition("int8_t", "D8", D8)]
+
+    data_str = "\n\n".join(data_str)
+
+    return data_str
+
+
+def main():
+    # Parsing cmd args
+    parser = argparse.ArgumentParser(description="Generate data for kernels")
+    parser.add_argument(
+        "-c",
+        "--cfg",
+        type=pathlib.Path,
+        required=True,
+        help="Select param config file kernel",
+    )
+    args = parser.parse_args()
+
+    # Load param config file
+    with args.cfg.open() as f:
+        param = hjson.loads(f.read())
+
+    # Emit header file
+    print(emit_header_file(**param))
+
+
+if __name__ == "__main__":
+
+    main()
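`postprocessing_simd_golden_model` is imported from `snax_utils` and is not part of this patch; judging from its argument list it performs the usual zero-point/multiplier/shift requantization from int32 accumulators to int8. A sketch of that scheme, as an assumption rather than the actual golden model:

```python
import numpy as np

def requant_sketch(d32, input_zp, output_zp, shift, max_int, min_int,
                   double_round, multiplier):
    # Assumed scheme: subtract the input zero point, apply the fixed-point
    # multiplier, round (optionally "double round" away from zero), shift,
    # re-add the output zero point, then clamp to the int8 range.
    v = (d32.astype(np.int64) - input_zp) * multiplier
    if double_round and shift > 0:
        v += np.where(v >= 0, 1, -1) * (1 << (shift - 1))
    v = (v >> shift) + output_zp
    return np.clip(v, min_int, max_int).astype(np.int8)
```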
diff --git a/target/sim/sw/device/apps/snax/snax-gemmx-conv/src/snax-gemmx-conv.c b/target/sim/sw/device/apps/snax/snax-gemmx-conv/src/snax-gemmx-conv.c
new file mode 100644
index 000000000..8699c3eb6
--- /dev/null
+++ b/target/sim/sw/device/apps/snax/snax-gemmx-conv/src/snax-gemmx-conv.c
@@ -0,0 +1,149 @@
+// Copyright 2024 KU Leuven.
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Xiaoling Yi
+
+#include "data.h"
+
+#include "snax-gemmx-params.h"
+
+#include "snax-gemmx-lib.h"
+
+// This is the main function for the SNAX GEMM for Conv2d.
+// We use several nested streamer loops to iterate over the input data and
+// weights, achieving implicit im2col
+int main() {
+    // Set err value for checking
+    int err = 0;
+
+    // Prepare address pointers in TCDM for DMA
+    int8_t *local_a_dma, *local_b_dma;
+    int32_t *local_c_dma, *local_d32_dma;
+    int8_t *local_d8_dma;
+
+    // Allocate space in TCDM for DMA
+    local_a_dma = (int8_t *)(snrt_l1_next() + delta_physical_a);
+    local_b_dma = (int8_t *)(snrt_l1_next() + delta_physical_b);
+    local_c_dma = (int32_t *)(snrt_l1_next() + delta_physical_c);
+    local_d32_dma = (int32_t *)(snrt_l1_next() + delta_physical_d32);
+    local_d8_dma = (int8_t *)(snrt_l1_next() + delta_physical_d8);
+
+    // Prepare address pointers in TCDM for the streamer
+    int8_t *local_a, *local_b;
+    int32_t *local_c, *local_d32;
+    int8_t *local_d8;
+
+    // Allocate space in TCDM for the streamer
+    local_a = (int8_t *)(snrt_l1_next() + delta_local_a);
+    local_b = (int8_t *)(snrt_l1_next() + delta_local_b);
+    local_c = (int32_t *)(snrt_l1_next() + delta_local_c);
+    local_d32 = (int32_t *)(snrt_l1_next() + delta_local_d32);
+    local_d8 = (int8_t *)(snrt_l1_next() + delta_local_d8);
+
+    // Transfer data from L3 to L1
+    // Using DMA only
+    if (snrt_is_dm_core()) {
+        if (interleaved_address == 1) {
+            snrt_dma_start_1d(local_a, A,
+                              Nbatch * (H + 2 * pad_h) * (W + 2 * pad_w) * Cin *
+                                  sizeof(int8_t));
+            snrt_dma_start_1d(local_b, B,
+                              Cout * Kh * Kw * Cin * sizeof(int8_t));
+        } else {
+            snrt_dma_start_2d(
+                local_a_dma, A, 64 * sizeof(int8_t), 256, 64,
+                Nbatch * (H + 2 * pad_h) * (W + 2 * pad_w) * Cin / 64);
+            snrt_dma_start_2d(local_b_dma, B, 64 * sizeof(int8_t), 256, 64,
+                              Cout * Kh * Kw * Cin / 64);
+        }
+        snrt_dma_wait_all();
+    }
+
+    // Wait for DMA to finish
+    snrt_cluster_hw_barrier();
+    if (snrt_is_dm_core()) {
+        if (interleaved_address == 1) {
+            snrt_dma_start_1d(local_c, C,
+                              M * N * meshRow * meshCol * sizeof(int32_t));
+        } else {
+            snrt_dma_start_2d(local_c_dma, C, 16 * sizeof(int32_t), 256,
+                              16 * sizeof(int32_t),
+                              M * N * meshRow * meshCol / 16);
+        }
+        snrt_dma_wait_all();
+    }
+
+    snrt_cluster_hw_barrier();
+
+    if (snrt_global_core_idx() == 0) {
+        // Set Streamer configuration CSR for conv2d
+        set_gemmx_streamer_csr(
+            Aslstride0, Aslstride1, Atlbound0, Atlstride0, Atlbound1,
+            Atlstride1, Atlbound2, Atlstride2, Atlbound3, Atlstride3, Atlbound4,
+            Atlstride4, Atlbound5, Atlstride5, set_addr_remap_index_A,
+
+            Bslstride0, Bslstride1, Btlbound0, Btlstride0, Btlbound1,
+            Btlstride1, Btlbound2, Btlstride2, set_addr_remap_index_B,
+
+            D8slstride0, D8slstride1, D8tlbound0, D8tlstride0, D8tlbound1,
+            D8tlstride1, D8tlbound2, D8tlstride2, set_addr_remap_index_D8,
+
+            Cslstride0, Cslstride1, Ctlbound0, Ctlstride0, Ctlbound1,
+            Ctlstride1, Ctlbound2, Ctlstride2, set_addr_remap_index_C,
+
+            D32slstride0, D32slstride1, D32tlbound0, D32tlstride0, D32tlbound1,
+            D32tlstride1, D32tlbound2, D32tlstride2, set_addr_remap_index_D32,
+
+            delta_local_a, delta_local_b, delta_local_d8, delta_local_c,
+            delta_local_d32, bypassSIMD, transposed_A, transposed_B,
+            channel_en_C, broadcast_C);
+
+        // Set GEMMX configuration CSR
+        uint32_t subtraction_setting =
+            gen_subtraction_config(subtraction_a, subtraction_b);
+
+        uint32_t csr0 =
+            gen_csr0_config(input_zp_i, output_zp_i, max_int_i, min_int_i);
+        uint32_t csr1 = gen_csr1_config(double_round_i);
+
+        set_gemmx_csr(
+            K, N, M, subtraction_setting, csr0, csr1, shared_bitpacked_shift0,
+            shared_bitpacked_shift1, shared_multiplier0, shared_multiplier1,
+            shared_multiplier2, shared_multiplier3, shared_multiplier4,
+            shared_multiplier5, shared_multiplier6, shared_multiplier7, M * N,
+            bypassSIMD);
+
+        // Set CSR to start Streamer for conv2d
+        set_gemmx_streamer_start();
+
+        // Set CSR to start GEMM
+        set_gemmx_start();
+
+        // Poll until Streamer and GEMM accelerator finish
+        wait_gemmx_and_streamer();
+
+        // check the result of the implicit im2col convolution
+        if (interleaved_address == 1) {
+            if (!bypassSIMD) {
+                err += check_gemmx_result_D8(local_d8, D8, Batch, M, N, false);
+            } else {
+                err +=
+                    check_gemmx_result_D32(local_d32, D32, Batch, M, N, false);
+            }
+        } else {
+            if (!bypassSIMD) {
+                err +=
+                    check_gemmx_result_D8(local_d8_dma, D8, Batch, M, N, true);
+            } else {
+                err += check_gemmx_result_D32(local_d32_dma, D32, Batch, M, N,
+                                              true);
+            }
+        }
+
+        printf("SNAX GEMM Conv2d: %s, Error: %d, bypassSIMD = %d.\n",
+               err ? "FAIL" : "PASS", err, bypassSIMD);
+    }
+
+    return err;
+}
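In the banked (`interleaved_address == 0`) mode above, `snrt_dma_start_2d(dst, src, size, dst_stride, src_stride, repeat)` copies each 64-byte source row to a destination 256 bytes further on; presumably this leaves the other three 64-byte lanes of each 256-byte stripe for the remaining operands, mirroring the `memory_size / 4` split of the logical address space in datagen.py, while remap index 2 makes each lane look contiguous to the streamers. A toy sketch of the resulting destination offsets:

```python
def banked_dst_offsets(n_rows, size=64, dst_stride=256):
    # Row i of `size` bytes lands at dst + i * dst_stride, leaving three
    # 64-byte gaps between consecutive rows of this operand.
    return [i * dst_stride for i in range(n_rows)]

print(banked_dst_offsets(4))  # [0, 256, 512, 768]
```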
diff --git a/target/sim/sw/device/apps/snax/snax-gemmx-matmul/Makefile b/target/sim/sw/device/apps/snax/snax-gemmx-matmul/Makefile
new file mode 100644
index 000000000..0159f899e
--- /dev/null
+++ b/target/sim/sw/device/apps/snax/snax-gemmx-matmul/Makefile
@@ -0,0 +1,21 @@
+# Copyright 2023 KU Leuven.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Xiaoling Yi
+
+APP = snax-gemmx-matmul
+
+INCDIRS = data
+
+INCDIRS += ../../../snax/gemmx/include
+
+# Include this binary in the final build
+RISCV_LDFLAGS += ../../../snax/gemmx/build/snax-gemmx-lib.o
+
+SRCS = src/snax-gemmx-matmul.c
+
+include ./data/Makefile
+include ../../common.mk
+
+$(DEP): $(DATA_H)
diff --git a/target/sim/sw/device/apps/snax/snax-gemmx-matmul/data/Makefile b/target/sim/sw/device/apps/snax/snax-gemmx-matmul/data/Makefile
new file mode 100644
index 000000000..18006cbf7
--- /dev/null
+++ b/target/sim/sw/device/apps/snax/snax-gemmx-matmul/data/Makefile
@@ -0,0 +1,23 @@
+# Copyright 2023 KU Leuven.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Xiaoling Yi
+
+# Usage of absolute paths is required to externally include this Makefile
+MK_DIR := $(dir $(realpath $(lastword $(MAKEFILE_LIST))))
+DATA_DIR := $(realpath $(MK_DIR))
+
+DATA_CFG ?= $(DATA_DIR)/params.hjson
+
+DATA_H = $(DATA_DIR)/data.h
+
+$(DATA_H): $(DATA_DIR)/datagen.py $(DATA_CFG)
+	$< -c $(DATA_CFG) > $@
+
+.PHONY: clean-data clean
+
+clean-data:
+	rm -f $(DATA_H)
+
+clean: clean-data
diff --git a/target/sim/sw/device/apps/snax/snax-gemmx-matmul/data/datagen.py b/target/sim/sw/device/apps/snax/snax-gemmx-matmul/data/datagen.py
new file mode 100755
index 000000000..e525f740d
--- /dev/null
+++ b/target/sim/sw/device/apps/snax/snax-gemmx-matmul/data/datagen.py
@@ -0,0 +1,423 @@
+#!/usr/bin/env python3
+
+# Copyright 2024 KU Leuven.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Xiaoling Yi
+
+import numpy as np
+import argparse
+import pathlib
+import hjson
+import sys
+import os
+
+import subprocess
+
+# Add data utility path
+sys.path.append(os.path.join(os.path.dirname(__file__),
+                             "../../../../../../../../util/sim/"))
+from data_utils import format_scalar_definition, format_vector_definition  # noqa E402
+
+# Add golden model path
+bender_command = subprocess.run(['bender', 'path', 'snitch_cluster'],
+                                capture_output=True, text=True)
+snax_utils_path = bender_command.stdout.strip()
+
+sys.path.append(snax_utils_path + "/util/sim/")
+from snax_utils import (  # noqa E402
+    conv2d,
+    im2col,
+    block_gemm_golden_model,
+    data_reshuffler_golden_model,
+    postprocessing_simd_golden_model,
+    align_wide_addr,
+)  # noqa E402
+
+np.random.seed(42)
+
+
+# Add stdint.h header
+def emit_header_file(**kwargs):
+    emit_str = "#include <stdint.h>\n\n"
+    emit_str += emit_gemmx_data(**kwargs)
+    return emit_str
+
+
+MIN = -128
+MAX = 127
+
+bankWidth = 64
+input_data_width = 8
+output_data_width = 32
+quantized_output_data_width = 8
+
+
+def emit_matmul_data(**kwargs):
+
+    meshRow = kwargs["meshRow"]
+    tileSize = kwargs["tileSize"]
+    meshCol = kwargs["meshCol"]
+
+    # matmul settings
+    data_str = []
+
+    data_str += [format_scalar_definition("int", "Batch", 1)]
+    data_str += [format_scalar_definition("int", "M", kwargs["M"])]
+    data_str += [format_scalar_definition("int", "K", kwargs["K"])]
+    data_str += [format_scalar_definition("int", "N", kwargs["N"])]
+
+    data_str += [format_scalar_definition("int32_t", "Aslstride0", 1)]
+    data_str += [format_scalar_definition("int32_t", "Aslstride1", bankWidth // 8)]
+    data_str += [format_scalar_definition("int32_t", "Atlbound0", kwargs["K"])]
+    data_str += [
+        format_scalar_definition(
+            "int32_t", "Atlstride0", input_data_width * tileSize * meshRow // 8
+        )
+    ]
+    data_str += [format_scalar_definition("int32_t", "Atlbound1", kwargs["N"])]
+    data_str += [format_scalar_definition("int32_t", "Atlstride1", 0)]
+    data_str += [format_scalar_definition("int32_t", "Atlbound2", kwargs["M"])]
+    data_str += [
+        format_scalar_definition(
+            "int32_t",
+            "Atlstride2",
+            kwargs["K"] * input_data_width * tileSize * meshRow // 8,
+        )
+    ]
+    data_str += [format_scalar_definition("int32_t", "Atlbound3", 1)]
+    data_str += [format_scalar_definition("int32_t", "Atlstride3", 0)]
+    data_str += [format_scalar_definition("int32_t", "Atlbound4", 1)]
+    data_str += [format_scalar_definition("int32_t", "Atlstride4", 0)]
+    data_str += [format_scalar_definition("int32_t", "Atlbound5", 1)]
+    data_str += [format_scalar_definition("int32_t", "Atlstride5", 0)]
+
+    data_str += [format_scalar_definition("int32_t", "Bslstride0", 1)]
+    data_str += [format_scalar_definition("int32_t", "Bslstride1", bankWidth // 8)]
+    data_str += [format_scalar_definition("int32_t", "Btlbound0", kwargs["K"])]
+    data_str += [
+        format_scalar_definition(
+            "int32_t", "Btlstride0", input_data_width * tileSize * meshCol // 8
+        )
+    ]
+    data_str += [format_scalar_definition("int32_t", "Btlbound1", kwargs["N"])]
+    data_str += [
+        format_scalar_definition(
+            "int32_t",
+            "Btlstride1",
+            kwargs["K"] * input_data_width * tileSize * meshCol // 8,
+        )
+    ]
+    data_str += [format_scalar_definition("int32_t", "Btlbound2", kwargs["M"])]
+    data_str += [format_scalar_definition("int32_t", "Btlstride2", 0)]
+
+    data_str += [format_scalar_definition("int32_t", "Cslstride0", bankWidth // 8)]
+    c32_spatial_bound_0 = 8
+    data_str += [
+        format_scalar_definition(
+            "int32_t", "Cslstride1", c32_spatial_bound_0 * (bankWidth // 8)
+        )
+    ]
+    data_str += [format_scalar_definition("int32_t", "Ctlbound0", kwargs["N"])]
+    data_str += [
+        format_scalar_definition(
+            "int32_t", "Ctlstride0", output_data_width * meshRow * meshCol // 8
+        )
+    ]
+    data_str += [format_scalar_definition("int32_t", "Ctlbound1", kwargs["M"])]
+    data_str += [
+        format_scalar_definition(
+            "int32_t",
+            "Ctlstride1",
+            kwargs["N"] * output_data_width * meshRow * meshCol // 8,
+        )
+    ]
+    data_str += [format_scalar_definition("int32_t", "Ctlbound2", 1)]
+    data_str += [format_scalar_definition("int32_t", "Ctlstride2", 0)]
+
+    data_str += [format_scalar_definition("int32_t", "D32slstride0", bankWidth // 8)]
+    d32_spatial_bound_0 = 8
+    data_str += [
+        format_scalar_definition(
+            "int32_t", "D32slstride1", d32_spatial_bound_0 * (bankWidth // 8)
+        )
+    ]
+    data_str += [format_scalar_definition("int32_t", "D32tlbound0", kwargs["N"])]
+    data_str += [
+        format_scalar_definition(
+            "int32_t", "D32tlstride0", output_data_width * meshRow * meshCol // 8
+        )
+    ]
+    data_str += [format_scalar_definition("int32_t", "D32tlbound1", kwargs["M"])]
+    data_str += [
+        format_scalar_definition(
+            "int32_t",
+            "D32tlstride1",
+            kwargs["N"] * output_data_width * meshRow * meshCol // 8,
+        )
+    ]
+    data_str += [format_scalar_definition("int32_t", "D32tlbound2", 1)]
+    data_str += [format_scalar_definition("int32_t", "D32tlstride2", 0)]
+
+    data_str += [format_scalar_definition("int32_t", "D8slstride0", 1)]
+    data_str += [format_scalar_definition("int32_t", "D8slstride1", bankWidth // 8)]
+    data_str += [format_scalar_definition("int32_t", "D8tlbound0", kwargs["N"])]
+    data_str += [
+        format_scalar_definition(
+            "int32_t",
+            "D8tlstride0",
+            quantized_output_data_width * meshRow * meshCol // 8,
+        )
+    ]
+    data_str += [format_scalar_definition("int32_t", "D8tlbound1", kwargs["M"])]
+    data_str += [
+        format_scalar_definition(
+            "int32_t",
+            "D8tlstride1",
+            kwargs["N"] * quantized_output_data_width * meshRow * meshCol // 8,
+        )
+    ]
+    data_str += [format_scalar_definition("int32_t", "D8tlbound2", 1)]
+    data_str += [format_scalar_definition("int32_t", "D8tlstride2", 0)]
+
+    delta_local_a = 0
+    delta_local_b = (
+        kwargs["K"] * kwargs["M"] * (meshRow * tileSize * input_data_width // 8)
+    )
+    delta_local_c = delta_local_b + kwargs["K"] * kwargs["N"] * (
+        meshCol * tileSize * input_data_width // 8
+    )
+    delta_local_d32 = delta_local_c + kwargs["M"] * kwargs["N"] * (
+        meshRow * meshCol * output_data_width // 8
+    )
+    delta_local_d8 = delta_local_d32
+    data_str += [format_scalar_definition("int32_t", "delta_local_a", delta_local_a)]
+    data_str += [format_scalar_definition("int32_t", "delta_local_b", delta_local_b)]
+    data_str += [
+        format_scalar_definition(
+            "int32_t",
+            "delta_local_c",
+            delta_local_c,
+        )
+    ]
+    data_str += [
+        format_scalar_definition(
+            "int32_t",
+            "delta_local_d32",
+            delta_local_d32,
+        )
+    ]
+    data_str += [
+        format_scalar_definition(
+            "int32_t",
+            "delta_local_d8",
+            delta_local_d8,
+        )
+    ]
+
+    # Generating random 8-bit integers a and b for subtraction
+    subtraction_a = np.random.randint(MIN, MAX)
+    subtraction_b = np.random.randint(MIN, MAX)
+
+    # Writing the subtraction values to data.h
+    data_str += [format_scalar_definition("int8_t", "subtraction_a", subtraction_a)]
+    data_str += [format_scalar_definition("int8_t", "subtraction_b", subtraction_b)]
+
+    A = np.random.randint(
+        MIN, MAX, size=(kwargs["M"], kwargs["K"], meshRow, tileSize)
+    ).reshape(-1)
+    data_str += [format_vector_definition("int8_t", "A", A)]
+
+    B = np.random.randint(
+        MIN, MAX, size=(kwargs["K"], kwargs["N"], tileSize, meshCol)
+    ).reshape(-1)
+    data_str += [format_vector_definition("int8_t", "B", B)]
+
+    broadcast_C = kwargs["broadcast_C"] == 1 and kwargs["channel_en_C"] == 1
+    disable_C = kwargs["broadcast_C"] == 0 and kwargs["channel_en_C"] == 0
+    enable_full_C = kwargs["broadcast_C"] == 0 and kwargs["channel_en_C"] == 1
+
+    assert broadcast_C or disable_C or enable_full_C, "Invalid C settings"
+
+    if broadcast_C == 1:
+        C = np.random.randint(MIN, MAX, size=(kwargs["M"], kwargs["N"], 1, meshCol))
+        C = np.repeat(C, repeats=8, axis=1).reshape(-1)
+    elif enable_full_C == 1:
+        C = np.random.randint(
+            MIN, MAX, size=(kwargs["M"], kwargs["N"], meshRow, meshCol)
+        ).reshape(-1)
+    else:
+        C = np.random.randint(
+            0, 1, size=(kwargs["M"], kwargs["N"], meshRow, meshCol)
+        ).reshape(-1)
+
+    if broadcast_C == 1:
+        data_str += [format_scalar_definition("int32_t", "channel_en_C", 0b11111111)]
+    elif enable_full_C == 1:
+        data_str += [
+            format_scalar_definition("int32_t", "channel_en_C", ((1 << 32) - 1))
+        ]
+    else:
+        data_str += [format_scalar_definition("int32_t", "channel_en_C", 0)]
+
+    data_str += [
+        format_scalar_definition("int32_t", "broadcast_C", kwargs["broadcast_C"])
+    ]
+    data_str += [format_vector_definition("int32_t", "C", C)]
+
+    if kwargs["transposed_A"] == 1:
+        A = A.reshape(kwargs["M"], kwargs["K"], meshRow, tileSize)
+        A = A.transpose(0, 1, 3, 2).reshape(-1)
+    if kwargs["transposed_B"] == 1:
+        B = B.reshape(kwargs["K"], kwargs["N"], tileSize, meshCol)
+        B = B.transpose(0, 1, 3, 2).reshape(-1)
+
+    data_str += [
+        format_scalar_definition("int32_t", "transposed_A", kwargs["transposed_A"])
+    ]
+    data_str += [
+        format_scalar_definition("int32_t", "transposed_B", kwargs["transposed_B"])
+    ]
+
+    D32 = block_gemm_golden_model(
+        kwargs["M"],
+        kwargs["K"],
+        kwargs["N"],
+        meshRow,
+        tileSize,
+        meshCol,
+        A,
+        B,
+        subtraction_a,
+        subtraction_b,
+        C,
+    )
+
+    return data_str, D32
+
+
+def emit_gemmx_data(**kwargs):
+    data_str, D32 = emit_matmul_data(**kwargs)
+
+    data_str += [format_vector_definition("int32_t", "D32", D32)]
+
+    # -----------------------------------------------------------
+    # Postprocessing
+    # -----------------------------------------------------------
+
+    bypassSIMD = kwargs["bypassSIMD"]
+    data_str += [format_scalar_definition("int32_t", "bypassSIMD", bypassSIMD)]
+
+    # Generating random constant values
+    group_num = 8
+    input_zp_i = np.random.randint(MIN, MAX)
+    output_zp_i = np.random.randint(MIN, MAX)
+    max_int_i = MAX
+    min_int_i = MIN
+    double_round_i = np.random.randint(0, 1)
+
+    shift_i = np.random.randint(0, 63, size=group_num)  # values in [0, 63)
+    multiplier_i = np.random.randint(-(2**31), 2**31 - 1, size=group_num)
+
+    # Writing the constant values to data.h
+    data_str += [
+        format_scalar_definition("int8_t", "input_zp_i", input_zp_i),
+        format_scalar_definition("int8_t", "output_zp_i", output_zp_i),
+        format_scalar_definition("int8_t", "max_int_i", max_int_i),
+        format_scalar_definition("int8_t", "min_int_i", min_int_i),
+        format_scalar_definition("int8_t", "double_round_i", double_round_i),
+    ]
+
+    shared_bitpacked_shift0 = (
+        (shift_i[3] << 24) | (shift_i[2] << 16) | (shift_i[1] << 8) | shift_i[0]
+    )
+    shared_bitpacked_shift1 = (
+        (shift_i[7] << 24) | (shift_i[6] << 16) | (shift_i[5] << 8) | shift_i[4]
+    )
+    data_str += [
+        format_scalar_definition(
+            "int32_t", "shared_bitpacked_shift0", shared_bitpacked_shift0
+        )
+    ]
+    data_str += [
+        format_scalar_definition(
+            "int32_t", "shared_bitpacked_shift1", shared_bitpacked_shift1
+        )
+    ]
+
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier0", multiplier_i[0])
+    ]
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier1", multiplier_i[1])
+    ]
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier2", multiplier_i[2])
+    ]
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier3", multiplier_i[3])
+    ]
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier4", multiplier_i[4])
+    ]
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier5", multiplier_i[5])
+    ]
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier6", multiplier_i[6])
+    ]
+    data_str += [
+        format_scalar_definition("int32_t", "shared_multiplier7", multiplier_i[7])
+    ]
+
+    D8 = np.zeros_like(D32, dtype=np.uint8)
+    # output channel (innermost dim) has a different scale factor
+    for i in range(group_num):
+        D8[i::group_num] = postprocessing_simd_golden_model(
+            D32[i::group_num],
+            input_zp_i,
+            output_zp_i,
+            shift_i[i],
+            max_int_i,
+            min_int_i,
+            double_round_i,
+            multiplier_i[i],
+        )
+
+    data_str += [format_vector_definition("int8_t", "D8", D8)]
+
+    data_str += [format_scalar_definition("int32_t", "set_addr_remap_index_A", 0)]
+    data_str += [format_scalar_definition("int32_t", "set_addr_remap_index_B", 0)]
+    data_str += [format_scalar_definition("int32_t", "set_addr_remap_index_C", 0)]
+    data_str += [format_scalar_definition("int32_t", "set_addr_remap_index_D32", 0)]
+    data_str += [format_scalar_definition("int32_t", "set_addr_remap_index_D8", 0)]
+
+    data_str = "\n\n".join(data_str)
+
+    return data_str
+
+
+def main():
+    # Parsing cmd args
+    parser = argparse.ArgumentParser(description="Generate data for kernels")
+    parser.add_argument(
+        "-c",
+        "--cfg",
+        type=pathlib.Path,
+        required=True,
+        help="Select param config file kernel",
+    )
+    args = parser.parse_args()
+
+    # Load param config file
+    with args.cfg.open() as f:
+        param = hjson.loads(f.read())
+
+    # Emit header file
+    print(emit_header_file(**param))
+
+
+if __name__ == "__main__":
+
+    main()
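`block_gemm_golden_model` also comes from `snax_utils`. Shape-wise, the operands generated above are M x K tiles of meshRow x tileSize (A), K x N tiles of tileSize x meshCol (B) and M x N tiles of meshRow x meshCol (C/D32). A minimal NumPy sketch of such a blocked int8 x int8 -> int32 product, as an assumption of the golden model's core loop (the real model additionally applies the subtraction_a/subtraction_b offsets to A and B, omitted here):

```python
import numpy as np

def block_gemm_sketch(M, K, N, meshRow, tileSize, meshCol, A, B, C):
    # Reshape the flat operand vectors back into their tiled layouts.
    A = A.reshape(M, K, meshRow, tileSize).astype(np.int32)
    B = B.reshape(K, N, tileSize, meshCol).astype(np.int32)
    D = C.reshape(M, N, meshRow, meshCol).astype(np.int32).copy()
    # Output-stationary accumulation over the K tile dimension.
    for m in range(M):
        for n in range(N):
            for k in range(K):
                D[m, n] += A[m, k] @ B[k, n]
    return D.reshape(-1)
```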
diff --git a/target/sim/sw/device/apps/snax/snax-gemmx-matmul/data/params.hjson b/target/sim/sw/device/apps/snax/snax-gemmx-matmul/data/params.hjson
new file mode 100644
index 000000000..fa184c055
--- /dev/null
+++ b/target/sim/sw/device/apps/snax/snax-gemmx-matmul/data/params.hjson
@@ -0,0 +1,21 @@
+// Copyright 2024 KU Leuven.
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
+// Cluster configuration for a matmul testbench system.
+{
+    // gemm configurations
+    transposed_A: 1,
+    transposed_B: 0,
+    K: 18,
+    N: 2,
+    M: 1,
+    bypassSIMD: 0,
+    broadcast_C: 1,
+    channel_en_C: 1,
+
+    // hardware parameters
+    meshRow : 8,
+    tileSize: 8,
+    meshCol : 8,
+}
diff --git a/target/sim/sw/device/apps/snax/snax-gemmx-matmul/src/snax-gemmx-matmul.c b/target/sim/sw/device/apps/snax/snax-gemmx-matmul/src/snax-gemmx-matmul.c
new file mode 100644
index 000000000..efc0c088d
--- /dev/null
+++ b/target/sim/sw/device/apps/snax/snax-gemmx-matmul/src/snax-gemmx-matmul.c
@@ -0,0 +1,112 @@
+// Copyright 2024 KU Leuven.
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Xiaoling Yi
+
+#include "data.h"
+
+#include "snax-gemmx-params.h"
+
+#include "snax-gemmx-lib.h"
+
+// This is the main function for the SNAX GEMMX tiled matrix multiplication:
+// the streamers feed the A, B and C tiles from TCDM into the GEMM core,
+// optionally followed by the SIMD postprocessing (requantization) unit
+int main() {
+    // Set err value for checking
+    int err = 0;
+
+    // Prepare address pointers in TCDM
+    int8_t *local_a, *local_b;
+    int32_t *local_c, *local_d32;
+    int8_t *local_d8;
+
+    // Allocate space in TCDM
+    local_a = (int8_t *)(snrt_l1_next() + delta_local_a);
+    local_b = (int8_t *)(snrt_l1_next() + delta_local_b);
+    local_c = (int32_t *)(snrt_l1_next() + delta_local_c);
+    local_d32 = (int32_t *)(snrt_l1_next() + delta_local_d32);
+    local_d8 = (int8_t *)(snrt_l1_next() + delta_local_d8);
+
+    // Transfer data from L3 to L1
+    // Using DMA only
+    if (snrt_is_dm_core()) {
+        snrt_dma_start_1d(local_a, A,
+                          M * K * meshRow * tileSize * sizeof(int8_t));
+        snrt_dma_start_1d(local_b, B,
+                          N * K * tileSize * meshCol * sizeof(int8_t));
+
+        snrt_dma_wait_all();
+    }
+
+    // Wait for DMA to finish
+    snrt_cluster_hw_barrier();
+    if (snrt_is_dm_core()) {
+        snrt_dma_start_1d(local_c, C,
+                          M * N * meshRow * meshCol * sizeof(int32_t));
+        snrt_dma_wait_all();
+    }
+
+    snrt_cluster_hw_barrier();
+
+    if (snrt_global_core_idx() == 0) {
+        // Set Streamer configuration CSR for the matmul
+        set_gemmx_streamer_csr(
+            Aslstride0, Aslstride1, Atlbound0, Atlstride0, Atlbound1,
+            Atlstride1, Atlbound2, Atlstride2, Atlbound3, Atlstride3, Atlbound4,
+            Atlstride4, Atlbound5, Atlstride5, set_addr_remap_index_A,
+
+            Bslstride0, Bslstride1, Btlbound0, Btlstride0, Btlbound1,
+            Btlstride1, Btlbound2, Btlstride2, set_addr_remap_index_B,
+
+            D8slstride0, D8slstride1, D8tlbound0, D8tlstride0, D8tlbound1,
+            D8tlstride1, D8tlbound2, D8tlstride2, set_addr_remap_index_D8,
+
+            Cslstride0, Cslstride1, Ctlbound0, Ctlstride0, Ctlbound1,
+            Ctlstride1, Ctlbound2, Ctlstride2, set_addr_remap_index_C,
+
+            D32slstride0, D32slstride1, D32tlbound0, D32tlstride0, D32tlbound1,
+            D32tlstride1, D32tlbound2, D32tlstride2, set_addr_remap_index_D32,
+
+            delta_local_a, delta_local_b, delta_local_d8, delta_local_c,
+            delta_local_d32, bypassSIMD, transposed_A, transposed_B,
+            channel_en_C, broadcast_C);
+
+        // Set GEMMX configuration CSR
+        uint32_t subtraction_setting =
+            gen_subtraction_config(subtraction_a, subtraction_b);
+
+        uint32_t csr0 =
+            gen_csr0_config(input_zp_i, output_zp_i, max_int_i, min_int_i);
+        uint32_t csr1 = gen_csr1_config(double_round_i);
+
+        set_gemmx_csr(
+            K, N, M, subtraction_setting, csr0, csr1, shared_bitpacked_shift0,
+            shared_bitpacked_shift1, shared_multiplier0, shared_multiplier1,
+            shared_multiplier2, shared_multiplier3, shared_multiplier4,
+            shared_multiplier5, shared_multiplier6, shared_multiplier7, M * N,
+            bypassSIMD);
+
+        // Set CSR to start the Streamer
+        set_gemmx_streamer_start();
+
+        // Set CSR to start GEMM
+        set_gemmx_start();
+
+        // Poll until Streamer and GEMM accelerator finish
+        wait_gemmx_and_streamer();
+
+        // check the result of the matmul
+        if (!bypassSIMD) {
+            err += check_gemmx_result_D8(local_d8, D8, Batch, M, N, false);
+        } else {
+            err += check_gemmx_result_D32(local_d32, D32, Batch, M, N, false);
+        }
+
+        printf("SNAX GEMM Matmul: %s, Error: %d, bypassSIMD = %d.\n",
+               err ? "FAIL" : "PASS", err, bypassSIMD);
+    }
+
+    return err;
+}
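`gen_subtraction_config`, `gen_csr0_config` and `gen_csr1_config` are provided by snax-gemmx-lib.h, which this diff does not show. Since csr0 receives exactly four int8 values, it is presumably byte-packed in the same spirit as the shared_bitpacked_shift words in datagen.py; a hypothetical sketch (both the helper name and the byte order here are assumptions):

```python
def gen_csr0_config_sketch(input_zp, output_zp, max_int, min_int):
    # Hypothetical byte order, lowest byte first; the real packing is
    # defined by gen_csr0_config in snax-gemmx-lib.
    def u8(v):
        return v & 0xFF
    return ((u8(min_int) << 24) | (u8(max_int) << 16) |
            (u8(output_zp) << 8) | u8(input_zp))
```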
"FAIL" : "PASS", err, bypassSIMD); + }; + + return err; +} diff --git a/target/sim/sw/device/snax/streamer-gemm-conv-simd/include/snax-streamer-gemm-conv-simd-lib.h b/target/sim/sw/device/snax/streamer-gemm-conv-simd/include/snax-streamer-gemm-conv-simd-lib.h index 63242c35d..66bacf58e 100644 --- a/target/sim/sw/device/snax/streamer-gemm-conv-simd/include/snax-streamer-gemm-conv-simd-lib.h +++ b/target/sim/sw/device/snax/streamer-gemm-conv-simd/include/snax-streamer-gemm-conv-simd-lib.h @@ -58,23 +58,27 @@ void set_gemmx_streamer_csr( int Aslstride0, int Aslstride1, int Atlbound0, int Atlstride0, int Atlbound1, int Atlstride1, int Atlbound2, int Atlstride2, int Atlbound3, int Atlstride3, int Atlbound4, int Atlstride4, int Atlbound5, - int Atlstride5, + int Atlstride5, int set_addr_remap_index_A, int Bslstride0, int Bslstride1, int Btlbound0, int Btlstride0, int Btlbound1, int Btlstride1, int Btlbound2, int Btlstride2, + int set_addr_remap_index_B, int D8slstride0, int D8slstride1, int D8tlbound0, int D8tlstride0, int D8tlbound1, int D8tlstride1, int D8tlbound2, int D8tlstride2, + int set_addr_remap_index_D8, int Cslstride0, int Cslstride1, int Ctlbound0, int Ctlstride0, int Ctlbound1, int Ctlstride1, int Ctlbound2, int Ctlstride2, + int set_addr_remap_index_C, int D32slstride0, int D32slstride1, int D32tlbound0, int D32tlstride0, int D32tlbound1, int D32tlstride1, int D32tlbound2, int D32tlstride2, + int set_addr_remap_index_D32, int delta_local_a, int delta_local_b, int delta_local_d8, int delta_local_c, int delta_local_d32, int bypassSIMD, int32_t transpose_A, - int32_t transpose_B, int32_t channel_en_C); + int32_t transpose_B, int32_t channel_en_C, int32_t broadcast_C); // Set CSR to start STREAMER inline void set_gemmx_streamer_start() { csrw_ss(STREAMER_START_CSR, 1); } @@ -103,7 +107,9 @@ uint32_t read_gemmx_perf_counter(); // Check the result of the implicit im2col convolution uint32_t check_gemmx_result_D8(int8_t* output, int8_t* output_golden, - int32_t Batch, int32_t M, int32_t N); + int32_t Batch, int32_t M, int32_t N, + bool banked_data_layout); uint32_t check_gemmx_result_D32(int32_t* output, int32_t* output_golden, - int32_t Batch, int32_t M, int32_t N); + int32_t Batch, int32_t M, int32_t N, + bool banked_data_layout); diff --git a/target/sim/sw/host/apps/offload/Makefile b/target/sim/sw/host/apps/offload/Makefile index 7e74faef4..38eeb69d9 100644 --- a/target/sim/sw/host/apps/offload/Makefile +++ b/target/sim/sw/host/apps/offload/Makefile @@ -32,9 +32,8 @@ RUNTIME_DIR = $(abspath $(HOST_DIR)/runtime) DEVICE_DIR = $(abspath $(HOST_DIR)/../device) # now we only include the snax app -DEVICE_APPS += snax/snax-data-reshuffler -DEVICE_APPS += snax/snax-streamer-gemm-conv -DEVICE_APPS += snax/snax-streamer-gemm-conv-simd +DEVICE_APPS += snax/snax-gemmx-matmul +DEVICE_APPS += snax/snax-gemmx-conv DEVICE_APPS += snax/snax-test-integration DEVICE_APPS += snax/snax-hypercorex-test-csr DEVICE_APPS += snax/snax-hypercorex-char-recog diff --git a/target/sim/sw/sim_elf.yaml b/target/sim/sw/sim_elf.yaml index bdc2c3fc0..0f8aa4b07 100644 --- a/target/sim/sw/sim_elf.yaml +++ b/target/sim/sw/sim_elf.yaml @@ -4,3 +4,5 @@ runs: - elf: host/apps/offload/build/offload-snax-test-integration.elf + - elf: host/apps/offload/build/offload-snax-gemmx-matmul.elf + - elf: host/apps/offload/build/offload-snax-gemmx-conv.elf