diff --git a/hw/chisel_acc/src/main/scala/snax_acc/gemm_simd/BlockGemmRescaleSIMD.scala b/hw/chisel_acc/src/main/scala/snax_acc/gemm_simd/BlockGemmRescaleSIMD.scala index 0db9c534b..b5e4a3bb8 100644 --- a/hw/chisel_acc/src/main/scala/snax_acc/gemm_simd/BlockGemmRescaleSIMD.scala +++ b/hw/chisel_acc/src/main/scala/snax_acc/gemm_simd/BlockGemmRescaleSIMD.scala @@ -122,3 +122,66 @@ object BlockGemmRescaleSIMD extends App { Array("--target-dir", "generated/gemmx") ) } + +/** Emits Verilog for a BlockGemmRescaleSIMD whose mesh shape is given on the command line. + * + * Required arguments: --meshRow, --meshCol, --tileSize (integers) and --withPipeline (true/false). + */ +object BlockGemmRescaleSIMDGen { + def main(args: Array[String]): Unit = { + // Parse "--key value" pairs into a Map; throws unless all four required keys are present + def parseArgs(args: Array[String]): Map[String, String] = { + val parsed_args = args + .sliding(2, 2) + .collect { + case Array(key, value) if key.startsWith("--") => key.drop(2) -> value + } + .toMap + if (!Seq("meshRow", "meshCol", "tileSize", "withPipeline").forall(parsed_args.contains)) { + throw new Exception( + "Please provide the meshRow, meshCol, tileSize, and withPipeline. Example usage: sbt 'runMain snax_acc.gemmx.BlockGemmRescaleSIMDGen --meshRow 2 --meshCol 2 --tileSize 16 --withPipeline true'" + ) + } + parsed_args + } + + // Parse the arguments + val argMap = parseArgs(args) + + // Retrieve the mesh dimensions; parseArgs has already checked that the keys exist + val meshRow = argMap("meshRow") + val meshCol = argMap("meshCol") + val tileSize = argMap("tileSize") + + // set the parameters for the gemm module + // other parameters are set to default values + val gemmParams = GemmParams( + GemmConstant.dataWidthA, + GemmConstant.dataWidthB, + GemmConstant.dataWidthMul, + GemmConstant.dataWidthC, + GemmConstant.dataWidthAccum, + GemmConstant.subtractionCfgWidth, + meshRow.toInt, + tileSize.toInt, + meshCol.toInt, + GemmConstant.addrWidth, + GemmConstant.sizeConfigWidth + ) + + val withPipeline = argMap("withPipeline").toBoolean + + emitVerilog( + new BlockGemmRescaleSIMD( + BlockGemmRescaleSIMDParams( + gemmParams, + (if (withPipeline) + snax_acc.simd.PipelinedConfig.rescaleSIMDConfig + else 
snax_acc.simd.DefaultConfig.rescaleSIMDConfig), + false + ) + ), + Array("--target-dir", "generated/gemmx") + ) + } +} diff --git a/target/snitch_cluster/cfg/snax-kul-cluster-mixed-narrow-wide-xdma.hjson b/target/snitch_cluster/cfg/snax-kul-cluster-mixed-narrow-wide-xdma.hjson index 1e67437ad..d9159bf75 100644 --- a/target/snitch_cluster/cfg/snax-kul-cluster-mixed-narrow-wide-xdma.hjson +++ b/target/snitch_cluster/cfg/snax-kul-cluster-mixed-narrow-wide-xdma.hjson @@ -105,6 +105,10 @@ snax_wide_tcdm_ports: 48, snax_num_rw_csr: 10, snax_num_ro_csr: 2, + snax_gemmx_mesh_row: 8, + snax_gemmx_tile_size: 8, + snax_gemmx_mesh_col: 8, + with_pipeline: true, snax_streamer_cfg: {$ref: "#/snax_streamer_gemmX_streamer_template" } }, snax_use_custom_ports: false, diff --git a/target/snitch_cluster/cfg/snax-kul-cluster-mixed-narrow-wide.hjson b/target/snitch_cluster/cfg/snax-kul-cluster-mixed-narrow-wide.hjson index a0fab8ef2..db098dd43 100644 --- a/target/snitch_cluster/cfg/snax-kul-cluster-mixed-narrow-wide.hjson +++ b/target/snitch_cluster/cfg/snax-kul-cluster-mixed-narrow-wide.hjson @@ -106,6 +106,10 @@ snax_wide_tcdm_ports: 48, snax_num_rw_csr: 10, snax_num_ro_csr: 2, + snax_gemmx_mesh_row: 8, + snax_gemmx_tile_size: 8, + snax_gemmx_mesh_col: 8, + with_pipeline: false, snax_streamer_cfg: {$ref: "#/snax_streamer_gemmX_streamer_template" } }, snax_use_custom_ports: false, diff --git a/target/snitch_cluster/sw/apps/snax-gemmx/data/datagen.py b/target/snitch_cluster/sw/apps/snax-gemmx/data/datagen.py index 38d1bb882..088cf2f1a 100755 --- a/target/snitch_cluster/sw/apps/snax-gemmx/data/datagen.py +++ b/target/snitch_cluster/sw/apps/snax-gemmx/data/datagen.py @@ -40,6 +40,11 @@ def emit_header_file(**kwargs): MIN = -128 MAX = 127 +bankWidth = 64 +input_data_width = 8 +output_data_width = 32 +quantized_output_data_width = 8 + def emit_conv_data(**kwargs): Cin = kwargs["Cin"] @@ -553,6 +558,11 @@ def emit_conv_data(**kwargs): def emit_matmul_data(**kwargs): + + meshRow = 
kwargs["meshRow"] + tileSize = kwargs["tileSize"] + meshCol = kwargs["meshCol"] + # matmul settings data_str = [] @@ -562,13 +572,23 @@ def emit_matmul_data(**kwargs): data_str += [format_scalar_definition("int", "N", kwargs["N"])] data_str += [format_scalar_definition("int32_t", "Aslstride0", 1)] - data_str += [format_scalar_definition("int32_t", "Aslstride1", 8)] + data_str += [format_scalar_definition("int32_t", "Aslstride1", bankWidth // 8)] data_str += [format_scalar_definition("int32_t", "Atlbound0", kwargs["K"])] - data_str += [format_scalar_definition("int32_t", "Atlstride0", 64)] + data_str += [ + format_scalar_definition( + "int32_t", "Atlstride0", input_data_width * tileSize * meshRow // 8 + ) + ] data_str += [format_scalar_definition("int32_t", "Atlbound1", kwargs["N"])] data_str += [format_scalar_definition("int32_t", "Atlstride1", 0)] data_str += [format_scalar_definition("int32_t", "Atlbound2", kwargs["M"])] - data_str += [format_scalar_definition("int32_t", "Atlstride2", kwargs["K"] * 64)] + data_str += [ + format_scalar_definition( + "int32_t", + "Atlstride2", + kwargs["K"] * input_data_width * tileSize * meshRow // 8, + ) + ] data_str += [format_scalar_definition("int32_t", "Atlbound3", 1)] data_str += [format_scalar_definition("int32_t", "Atlstride3", 0)] data_str += [format_scalar_definition("int32_t", "Atlbound4", 1)] @@ -577,70 +597,115 @@ def emit_matmul_data(**kwargs): data_str += [format_scalar_definition("int32_t", "Atlstride5", 0)] data_str += [format_scalar_definition("int32_t", "Bslstride0", 1)] - data_str += [format_scalar_definition("int32_t", "Bslstride1", 8)] + data_str += [format_scalar_definition("int32_t", "Bslstride1", bankWidth // 8)] data_str += [format_scalar_definition("int32_t", "Btlbound0", kwargs["K"])] - data_str += [format_scalar_definition("int32_t", "Btlstride0", 64)] + data_str += [ + format_scalar_definition( + "int32_t", "Btlstride0", input_data_width * tileSize * meshCol // 8 + ) + ] data_str += 
[format_scalar_definition("int32_t", "Btlbound1", kwargs["N"])] - data_str += [format_scalar_definition("int32_t", "Btlstride1", 64 * kwargs["K"])] + data_str += [ + format_scalar_definition( + "int32_t", + "Btlstride1", + kwargs["K"] * input_data_width * tileSize * meshCol // 8, + ) + ] data_str += [format_scalar_definition("int32_t", "Btlbound2", kwargs["M"])] data_str += [format_scalar_definition("int32_t", "Btlstride2", 0)] data_str += [format_scalar_definition("int32_t", "Cslstride0", 4)] - data_str += [format_scalar_definition("int32_t", "Cslstride1", 8)] + data_str += [format_scalar_definition("int32_t", "Cslstride1", bankWidth // 8)] data_str += [format_scalar_definition("int32_t", "Ctlbound0", kwargs["N"])] - data_str += [format_scalar_definition("int32_t", "Ctlstride0", 256)] + data_str += [ + format_scalar_definition( + "int32_t", "Ctlstride0", output_data_width * meshRow * meshCol // 8 + ) + ] data_str += [format_scalar_definition("int32_t", "Ctlbound1", kwargs["M"])] - data_str += [format_scalar_definition("int32_t", "Ctlstride1", 256 * kwargs["N"])] + data_str += [ + format_scalar_definition( + "int32_t", + "Ctlstride1", + kwargs["N"] * output_data_width * meshRow * meshCol // 8, + ) + ] data_str += [format_scalar_definition("int32_t", "Ctlbound2", 1)] data_str += [format_scalar_definition("int32_t", "Ctlstride2", 0)] data_str += [format_scalar_definition("int32_t", "D32slstride0", 4)] - data_str += [format_scalar_definition("int32_t", "D32slstride1", 8)] + data_str += [format_scalar_definition("int32_t", "D32slstride1", bankWidth // 8)] data_str += [format_scalar_definition("int32_t", "D32tlbound0", kwargs["N"])] - data_str += [format_scalar_definition("int32_t", "D32tlstride0", 256)] + data_str += [ + format_scalar_definition( + "int32_t", "D32tlstride0", output_data_width * meshRow * meshCol // 8 + ) + ] data_str += [format_scalar_definition("int32_t", "D32tlbound1", kwargs["M"])] - data_str += [format_scalar_definition("int32_t", "D32tlstride1", 256 * 
kwargs["N"])] + data_str += [ + format_scalar_definition( + "int32_t", + "D32tlstride1", + kwargs["N"] * output_data_width * meshRow * meshCol // 8, + ) + ] data_str += [format_scalar_definition("int32_t", "D32tlbound2", 1)] data_str += [format_scalar_definition("int32_t", "D32tlstride2", 0)] data_str += [format_scalar_definition("int32_t", "D8slstride0", 1)] - data_str += [format_scalar_definition("int32_t", "D8slstride1", 8)] + data_str += [format_scalar_definition("int32_t", "D8slstride1", bankWidth // 8)] data_str += [format_scalar_definition("int32_t", "D8tlbound0", kwargs["N"])] - data_str += [format_scalar_definition("int32_t", "D8tlstride0", 64)] + data_str += [ + format_scalar_definition( + "int32_t", + "D8tlstride0", + quantized_output_data_width * meshRow * meshCol // 8, + ) + ] data_str += [format_scalar_definition("int32_t", "D8tlbound1", kwargs["M"])] - data_str += [format_scalar_definition("int32_t", "D8tlstride1", 64 * kwargs["N"])] - data_str += [format_scalar_definition("int32_t", "D8tlbound2", 1)] - data_str += [format_scalar_definition("int32_t", "D8tlstride2", 0)] - - data_str += [format_scalar_definition("int32_t", "delta_local_a", 0)] data_str += [ format_scalar_definition( - "int32_t", "delta_local_b", 64 * kwargs["K"] * kwargs["M"] + "int32_t", + "D8tlstride1", + kwargs["N"] * quantized_output_data_width * meshRow * meshCol // 8, ) ] + data_str += [format_scalar_definition("int32_t", "D8tlbound2", 1)] + data_str += [format_scalar_definition("int32_t", "D8tlstride2", 0)] + + delta_local_a = 0 + delta_local_b = ( + kwargs["K"] * kwargs["M"] * (meshRow * tileSize * input_data_width // 8) + ) + delta_local_c = delta_local_b + kwargs["K"] * kwargs["N"] * ( + meshCol * tileSize * input_data_width // 8 + ) + delta_local_d32 = delta_local_c + kwargs["M"] * kwargs["N"] * ( + meshRow * meshCol * output_data_width // 8 + ) + delta_local_d8 = delta_local_d32 + data_str += [format_scalar_definition("int32_t", "delta_local_a", delta_local_a)] + data_str += 
[format_scalar_definition("int32_t", "delta_local_b", delta_local_b)] data_str += [ format_scalar_definition( "int32_t", "delta_local_c", - 64 * kwargs["K"] * kwargs["M"] + 64 * kwargs["K"] * kwargs["N"], + delta_local_c, ) ] data_str += [ format_scalar_definition( "int32_t", "delta_local_d32", - 64 * kwargs["K"] * kwargs["M"] - + 64 * kwargs["K"] * kwargs["N"] - + 256 * kwargs["M"] * kwargs["N"], + delta_local_d32, ) ] data_str += [ format_scalar_definition( "int32_t", "delta_local_d8", - 64 * kwargs["K"] * kwargs["M"] - + 64 * kwargs["K"] * kwargs["N"] - + 256 * kwargs["M"] * kwargs["N"], + delta_local_d8, ) ] @@ -652,18 +717,24 @@ def emit_matmul_data(**kwargs): data_str += [format_scalar_definition("int8_t", "subtraction_a", subtraction_a)] data_str += [format_scalar_definition("int8_t", "subtraction_b", subtraction_b)] - A = np.random.randint(MIN, MAX, size=(kwargs["M"], kwargs["K"], 8, 8)).reshape(-1) + A = np.random.randint( + MIN, MAX, size=(kwargs["M"], kwargs["K"], meshRow, tileSize) + ).reshape(-1) data_str += [format_vector_definition("int8_t", "A", A)] - B = np.random.randint(MIN, MAX, size=(kwargs["K"], kwargs["N"], 8, 8)).reshape(-1) + B = np.random.randint( + MIN, MAX, size=(kwargs["K"], kwargs["N"], tileSize, meshCol) + ).reshape(-1) data_str += [format_vector_definition("int8_t", "B", B)] - C = np.random.randint(MIN, MAX, size=(kwargs["M"], kwargs["N"], 8, 8)).reshape(-1) + C = np.random.randint( + MIN, MAX, size=(kwargs["M"], kwargs["N"], meshRow, meshCol) + ).reshape(-1) data_str += [format_vector_definition("int32_t", "C", C)] if kwargs["transposed_A"] == 1: - A = A.reshape(kwargs["M"], kwargs["K"], 8, 8) + A = A.reshape(kwargs["M"], kwargs["K"], meshRow, tileSize) A = A.transpose(0, 1, 3, 2).reshape(-1) if kwargs["transposed_B"] == 1: - B = B.reshape(kwargs["K"], kwargs["N"], 8, 8) + B = B.reshape(kwargs["K"], kwargs["N"], tileSize, meshCol) B = B.transpose(0, 1, 3, 2).reshape(-1) data_str += [ @@ -677,9 +748,9 @@ def 
emit_matmul_data(**kwargs): kwargs["M"], kwargs["K"], kwargs["N"], - 8, - 8, - 8, + meshRow, + tileSize, + meshCol, A, B, subtraction_a, diff --git a/target/snitch_cluster/sw/apps/snax-gemmx/data/params.hjson b/target/snitch_cluster/sw/apps/snax-gemmx/data/params.hjson index 450f06df1..fef6b1b79 100644 --- a/target/snitch_cluster/sw/apps/snax-gemmx/data/params.hjson +++ b/target/snitch_cluster/sw/apps/snax-gemmx/data/params.hjson @@ -24,4 +24,9 @@ M: 1, bypassSIMD: 1, ifTestMatmul: 1, + + // hardware parameters + meshRow : 8, + tileSize: 8, + meshCol : 8, } diff --git a/target/snitch_cluster/sw/snax/gemmx/include/snax-gemmx-params.h b/target/snitch_cluster/sw/snax/gemmx/include/snax-gemmx-params.h index 7066333b8..cea91a4b9 100644 --- a/target/snitch_cluster/sw/snax/gemmx/include/snax-gemmx-params.h +++ b/target/snitch_cluster/sw/snax/gemmx/include/snax-gemmx-params.h @@ -6,6 +6,6 @@ #pragma once -#define tileSize 8 #define meshRow 8 +#define tileSize 8 #define meshCol 8 diff --git a/target/snitch_cluster/sw/snax/gemmx/src/snax-gemmx-lib.c b/target/snitch_cluster/sw/snax/gemmx/src/snax-gemmx-lib.c index d3a48970f..5c2361c5a 100644 --- a/target/snitch_cluster/sw/snax/gemmx/src/snax-gemmx-lib.c +++ b/target/snitch_cluster/sw/snax/gemmx/src/snax-gemmx-lib.c @@ -6,6 +6,7 @@ #include "snax-gemmx-lib.h" #include +#include "snax-gemmx-params.h" #include "snrt.h" #include "stdint.h" #include "streamer_csr_addr_map.h" @@ -156,8 +157,10 @@ void set_gemmx_streamer_csr( csrw_ss(T_STRIDE_READER_WRITER_1_2, D32tlstride2); // set the transpose +#ifdef TRANSPOSE_EXTENSION_ENABLE csrw_ss(TRANSPOSE_CSR_READER_0, transpose_A == 0 ? 1 : 0); csrw_ss(TRANSPOSE_CSR_READER_1, transpose_B == 0 ? 
1 : 0); +#endif } // Set CSR to start STREAMER @@ -234,7 +237,7 @@ uint32_t check_gemmx_result_D8(int8_t* output, int8_t* output_golden, int32_t Batch, int32_t M, int32_t N) { uint32_t err = 0; uint32_t size = 0; - size = Batch * M * N * 8 * 8; + size = Batch * M * N * meshRow * meshCol; for (int i = 0; i < size; i++) { if (output[i] != output_golden[i]) { @@ -248,7 +251,7 @@ uint32_t check_gemmx_result_D32(int32_t* output, int32_t* output_golden, int32_t Batch, int32_t M, int32_t N) { uint32_t err = 0; uint32_t size = 0; - size = Batch * M * N * 8 * 8; + size = Batch * M * N * meshRow * meshCol; for (int i = 0; i < size; i++) { if (output[i] != output_golden[i]) { diff --git a/util/snaxgen/snaxgen.py b/util/snaxgen/snaxgen.py index 342d33d32..e7c1524dd 100755 --- a/util/snaxgen/snaxgen.py +++ b/util/snaxgen/snaxgen.py @@ -70,7 +70,7 @@ def gen_chisel_file(chisel_path, chisel_param, gen_path): mill Snax.runMain {chisel_param} {gen_path}" print(f"Running command: {cmd}") if os.system(cmd) != 0: - raise ChildProcessError('Chisel generation error. ') + raise ChildProcessError("Chisel generation error. 
") return @@ -80,20 +80,20 @@ def streamer_csr_num(acc_cfgs): # Regardless if shared or not, it is the same total # This is the total number of loop dimension registers num_loop_dim = 0 - if ("data_reader_params" in acc_cfgs["snax_streamer_cfg"]): + if "data_reader_params" in acc_cfgs["snax_streamer_cfg"]: num_loop_dim += sum( acc_cfgs["snax_streamer_cfg"]["data_reader_params"]["temporal_dim"] - ) + ) - if ("data_writer_params" in acc_cfgs["snax_streamer_cfg"]): + if "data_writer_params" in acc_cfgs["snax_streamer_cfg"]: num_loop_dim += sum( acc_cfgs["snax_streamer_cfg"]["data_writer_params"]["temporal_dim"] - ) + ) - if ("data_reader_writer_params" in acc_cfgs["snax_streamer_cfg"]): + if "data_reader_writer_params" in acc_cfgs["snax_streamer_cfg"]: num_loop_dim += sum( acc_cfgs["snax_streamer_cfg"]["data_reader_writer_params"]["temporal_dim"] - ) + ) # Calculation of data movers num_data_reader = 0 @@ -103,36 +103,36 @@ def streamer_csr_num(acc_cfgs): if "data_reader_params" in acc_cfgs["snax_streamer_cfg"]: num_data_reader = len( - acc_cfgs["snax_streamer_cfg"]["data_reader_params"]["num_channel"] # noqa: E501 + acc_cfgs["snax_streamer_cfg"]["data_reader_params"][ + "num_channel" + ] # noqa: E501 ) if "data_writer_params" in acc_cfgs["snax_streamer_cfg"]: num_data_writer = len( - acc_cfgs["snax_streamer_cfg"]["data_writer_params"]["num_channel"] # noqa: E501 + acc_cfgs["snax_streamer_cfg"]["data_writer_params"][ + "num_channel" + ] # noqa: E501 ) if "data_reader_writer_params" in acc_cfgs["snax_streamer_cfg"]: num_data_reader_writer = len( - acc_cfgs["snax_streamer_cfg"]["data_reader_writer_params"]["num_channel"] # noqa: E501 + acc_cfgs["snax_streamer_cfg"]["data_reader_writer_params"][ + "num_channel" + ] # noqa: E501 ) # This sets the total number of base pointers - num_data_mover = num_data_reader + num_data_writer \ - + num_data_reader_writer + num_data_mover = num_data_reader + num_data_writer + num_data_reader_writer streamer_csr_num = ( - # Total temporal 
loop dimensions and strides - 2 * num_loop_dim + \ - # Number of spatial strides - num_data_mover + \ - # Number of base pointers - 2 * num_data_mover + \ - # Start register - 1 + \ - # Performance counter - 1 + \ - # Busy register - 1 + # Total temporal loop dimensions and strides + 2 * num_loop_dim + + num_data_mover # Number of spatial strides + + 2 * num_data_mover # Number of base pointers + + 1 # Start register + + 1 # Performance counter + + 1 # Busy register ) # transpose csr @@ -181,8 +181,7 @@ def main(): help="Bypass default accelerator generation", ) parser.add_argument( - "--gen_path", type=str, default="./", - help="Points to the output directory" + "--gen_path", type=str, default="./", help="Points to the output directory" ) parser.add_argument( "--get_bender_targets", @@ -207,6 +206,7 @@ def main(): # For generating all bender targets if args.get_bender_targets: + def get_bender_targets(cfg): targets = [] # If cfg is dictionary, then first check if it has @@ -234,7 +234,7 @@ def get_bender_targets(cfg): print(" Generating accelerator specific wrappers ") print("------------------------------------------------") - if (args.bypass_accgen == "false"): + if args.bypass_accgen == "false": for i in range(num_cores): if "snax_acc_cfg" in cfg_cores[i]: num_core_w_acc += 1 @@ -246,8 +246,7 @@ def get_bender_targets(cfg): # TCDM configurations tcdm_data_width = cfg["cluster"]["data_width"] acc_cfgs[i]["tcdm_data_width"] = tcdm_data_width - acc_cfgs[i]["tcdm_dma_data_width"] = \ - cfg["cluster"]["dma_data_width"] + acc_cfgs[i]["tcdm_dma_data_width"] = cfg["cluster"]["dma_data_width"] tcdm_depth = ( cfg["cluster"]["tcdm"]["size"] * 1024 @@ -257,8 +256,7 @@ def get_bender_targets(cfg): acc_cfgs[i]["tcdm_depth"] = tcdm_depth tcdm_num_banks = cfg["cluster"]["tcdm"]["banks"] acc_cfgs[i]["tcdm_num_banks"] = tcdm_num_banks - tcdm_addr_width = tcdm_num_banks * \ - tcdm_depth * (tcdm_data_width // 8) + tcdm_addr_width = tcdm_num_banks * tcdm_depth * (tcdm_data_width // 
8) tcdm_addr_width = int(math.log2(tcdm_addr_width)) acc_cfgs[i]["tcdm_addr_width"] = tcdm_addr_width # Chisel parameter tag names @@ -270,8 +268,7 @@ def get_bender_targets(cfg): for i in range(len(acc_cfgs)): # First part is for chisel generation # Generate the parameter files for chisel streamer generation - chisel_target_path = args.chisel_path + \ - "src/main/scala/snax/streamer/" + chisel_target_path = args.chisel_path + "src/main/scala/snax/streamer/" file_name = "StreamParamGen.scala" tpl_scala_param_file = args.tpl_path + "stream_param_gen.scala.tpl" tpl_scala_param = get_template(tpl_scala_param_file) @@ -284,11 +281,11 @@ def get_bender_targets(cfg): # CSR manager scala parameter generation if not acc_cfgs[i].get("snax_disable_csr_manager", False): - chisel_target_path = args.chisel_path + \ - "src/main/scala/snax/csr_manager/" + chisel_target_path = ( + args.chisel_path + "src/main/scala/snax/csr_manager/" + ) file_name = "CsrManParamGen.scala" - tpl_scala_param_file = args.tpl_path + \ - "csrman_param_gen.scala.tpl" + tpl_scala_param_file = args.tpl_path + "csrman_param_gen.scala.tpl" tpl_scala_param = get_template(tpl_scala_param_file) gen_file( cfg=acc_cfgs[i], @@ -297,15 +294,13 @@ def get_bender_targets(cfg): file_name=file_name, ) - rtl_target_path = args.gen_path + \ - acc_cfgs[i]["snax_acc_name"] + "/" + rtl_target_path = args.gen_path + acc_cfgs[i]["snax_acc_name"] + "/" # This is for RTL wrapper and chisel generation # This first one generates the CSR manager wrapper if not acc_cfgs[i].get("snax_disable_csr_manager", False): file_name = acc_cfgs[i]["snax_acc_name"] + "_csrman_wrapper.sv" - tpl_csrman_wrapper_file = args.tpl_path + \ - "snax_csrman_wrapper.sv.tpl" + tpl_csrman_wrapper_file = args.tpl_path + "snax_csrman_wrapper.sv.tpl" tpl_csrman_wrapper = get_template(tpl_csrman_wrapper_file) gen_file( cfg=acc_cfgs[i], @@ -316,7 +311,9 @@ def get_bender_targets(cfg): # This first one generates the streamer wrapper file_name = 
acc_cfgs[i]["snax_acc_name"] + "_streamer_wrapper.sv" - tpl_streamer_wrapper_file = args.tpl_path + "snax_streamer_wrapper.sv.tpl" # noqa: E501 + tpl_streamer_wrapper_file = ( + args.tpl_path + "snax_streamer_wrapper.sv.tpl" + ) # noqa: E501 tpl_streamer_wrapper = get_template(tpl_streamer_wrapper_file) gen_file( cfg=acc_cfgs[i], @@ -376,24 +373,43 @@ def get_bender_targets(cfg): chisel_acc_path = args.chisel_path + "../chisel_acc" rtl_target_path = args.gen_path + acc_cfgs[i]["snax_acc_name"] + "/" - if (acc_cfgs[i]["snax_acc_name"] == "snax_streamer_gemmX"): + if acc_cfgs[i]["snax_acc_name"] == "snax_streamer_gemmX": + if not ( + "snax_gemmx_tile_size" in acc_cfgs[i] + and "snax_gemmx_mesh_row" in acc_cfgs[i] + and "snax_gemmx_mesh_col" in acc_cfgs[i] + and "with_pipeline" in acc_cfgs[i] + ): + raise ValueError( + "Missing gemmX configuration. \n" + "Please set snax_gemmx_mesh_row, snax_gemmx_mesh_col, " + "snax_gemmx_tile_size, with_pipeline" + ) gen_chisel_file( - chisel_path=chisel_acc_path, - chisel_param="snax_acc.gemmx.BlockGemmRescaleSIMD", - gen_path=rtl_target_path, - ) - elif (acc_cfgs[i]["snax_acc_name"] == "snax_streamer_gemm_add_c"): + chisel_path=chisel_acc_path, + chisel_param="snax_acc.gemmx.BlockGemmRescaleSIMDGen " + + " --meshRow " + + str(acc_cfgs[i]["snax_gemmx_mesh_row"]) + + " --meshCol " + + str(acc_cfgs[i]["snax_gemmx_mesh_col"]) + + " --tileSize " + + str(acc_cfgs[i]["snax_gemmx_tile_size"]) + + " --withPipeline " + + str(acc_cfgs[i]["with_pipeline"]), + gen_path=rtl_target_path, + ) + elif acc_cfgs[i]["snax_acc_name"] == "snax_streamer_gemm_add_c": gen_chisel_file( - chisel_path=chisel_acc_path, - chisel_param="snax_acc.gemm.BlockGemm", - gen_path=rtl_target_path, - ) - elif (acc_cfgs[i]["snax_acc_name"] == "snax_data_reshuffler"): + chisel_path=chisel_acc_path, + chisel_param="snax_acc.gemm.BlockGemm", + gen_path=rtl_target_path, + ) + elif acc_cfgs[i]["snax_acc_name"] == "snax_data_reshuffler": gen_chisel_file( - 
chisel_path=chisel_acc_path, - chisel_param="snax_acc.reshuffle.Reshuffler", - gen_path=rtl_target_path, - ) + chisel_path=chisel_acc_path, + chisel_param="snax_acc.reshuffle.Reshuffler", + gen_path=rtl_target_path, + ) else: print("Nothing to generate ") @@ -407,7 +423,7 @@ def get_bender_targets(cfg): for i in range(num_cores): if "snax_xdma_cfg" in cfg_cores[i]: snax_xdma_cfg = cfg_cores[i]["snax_xdma_cfg"] - if (snax_xdma_cfg is not None): + if snax_xdma_cfg is not None: tpl_rtl_wrapper_file = args.tpl_path + "snax_xdma_wrapper.sv.tpl" tpl_rtl_wrapper = get_template(tpl_rtl_wrapper_file) @@ -427,20 +443,36 @@ def get_bender_targets(cfg): gen_chisel_file( chisel_path=args.chisel_path, chisel_param="snax.xdma.xdmaTop.xdmaTopGen", - gen_path=" --clusterName " + str(cfg["cluster"]["name"]) + - " --tcdmDataWidth " + str(cfg["cluster"]["data_width"]) + - " --axiDataWidth " + str(cfg["cluster"]["dma_data_width"]) + - " --axiAddrWidth " + str(cfg["cluster"]["addr_width"]) + - " --tcdmSize " + str(cfg["cluster"]["tcdm"]["size"]) + - " --readerSpatialBounds " + str(snax_xdma_cfg["reader_agu_spatial_bounds"]) + - " --readerTemporalDimension " + str(snax_xdma_cfg["reader_agu_temporal_dimension"]) + - " --writerSpatialBounds " + str(snax_xdma_cfg["writer_agu_spatial_bounds"]) + - " --writerTemporalDimension " + str(snax_xdma_cfg["writer_agu_temporal_dimension"]) + - " --readerBufferDepth " + str(snax_xdma_cfg["reader_buffer"]) + - " --writerBufferDepth " + str(snax_xdma_cfg["writer_buffer"]) + - xdma_extension_arg + " --hw-target-dir " + args.gen_path + - cfg["cluster"]["name"] + "_xdma/" + - " --sw-target-dir " + args.gen_path + "../sw/snax/xdma" + gen_path=" --clusterName " + + str(cfg["cluster"]["name"]) + + " --tcdmDataWidth " + + str(cfg["cluster"]["data_width"]) + + " --axiDataWidth " + + str(cfg["cluster"]["dma_data_width"]) + + " --axiAddrWidth " + + str(cfg["cluster"]["addr_width"]) + + " --tcdmSize " + + str(cfg["cluster"]["tcdm"]["size"]) + + " 
--readerSpatialBounds " + + str(snax_xdma_cfg["reader_agu_spatial_bounds"]) + + " --readerTemporalDimension " + + str(snax_xdma_cfg["reader_agu_temporal_dimension"]) + + " --writerSpatialBounds " + + str(snax_xdma_cfg["writer_agu_spatial_bounds"]) + + " --writerTemporalDimension " + + str(snax_xdma_cfg["writer_agu_temporal_dimension"]) + + " --readerBufferDepth " + + str(snax_xdma_cfg["reader_buffer"]) + + " --writerBufferDepth " + + str(snax_xdma_cfg["writer_buffer"]) + + xdma_extension_arg + + " --hw-target-dir " + + args.gen_path + + cfg["cluster"]["name"] + + "_xdma/" + + " --sw-target-dir " + + args.gen_path + + "../sw/snax/xdma", ) # --------------------------------------- @@ -449,17 +481,22 @@ def get_bender_targets(cfg): cluster_schema_path = "../../docs/schema/snitch_cluster.schema.json" harness_cfg = read_schema(cluster_schema_path) - if ("enable_debug" not in cfg["cluster"]): - cfg["cluster"]["enable_debug"] = \ - harness_cfg["properties"]["enable_debug"]["default"] - - if ("iso_crossings" not in cfg["cluster"]["timing"]): - cfg["cluster"]["timing"]["iso_crossings"] = \ - harness_cfg["properties"]["timing"]["properties"]["iso_crossings"]["default"] # noqa: E501 - - if ("sram_cfg_expose" not in cfg["cluster"]): - cfg["cluster"]["sram_cfg_expose"] = \ - harness_cfg["properties"]["sram_cfg_expose"]["default"] + if "enable_debug" not in cfg["cluster"]: + cfg["cluster"]["enable_debug"] = harness_cfg["properties"]["enable_debug"][ + "default" + ] + + if "iso_crossings" not in cfg["cluster"]["timing"]: + cfg["cluster"]["timing"]["iso_crossings"] = harness_cfg["properties"]["timing"][ + "properties" + ]["iso_crossings"][ + "default" + ] # noqa: E501 + + if "sram_cfg_expose" not in cfg["cluster"]: + cfg["cluster"]["sram_cfg_expose"] = harness_cfg["properties"][ + "sram_cfg_expose" + ]["default"] test_target_path = args.test_path file_name = "testharness.sv"